Add more type hints
This commit is contained in:
+34
-26
@@ -14,32 +14,39 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Optional, Awaitable
|
||||
import re
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import orm
|
||||
|
||||
from telethon.tl.types import UserProfilePhoto
|
||||
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
|
||||
|
||||
from .db import Puppet as DBPuppet
|
||||
from . import util, matrix
|
||||
from . import util
|
||||
|
||||
config = None
|
||||
if TYPE_CHECKING:
|
||||
from .matrix import MatrixHandler
|
||||
from .config import Config
|
||||
from .context import Context
|
||||
|
||||
config = None # type: Config
|
||||
|
||||
|
||||
class Puppet:
|
||||
log = logging.getLogger("mau.puppet")
|
||||
db = None
|
||||
log = logging.getLogger("mau.puppet") # type: logging.Logger
|
||||
db = None # type: orm.Session
|
||||
az = None # type: AppService
|
||||
mx = None # type: matrix.MatrixHandler
|
||||
mx = None # type: MatrixHandler
|
||||
loop = None # type: asyncio.AbstractEventLoop
|
||||
mxid_regex = None
|
||||
username_template = None
|
||||
hs_domain = None
|
||||
cache = {}
|
||||
by_custom_mxid = {}
|
||||
mxid_regex = None # type: Pattern
|
||||
username_template = None # type: str
|
||||
hs_domain = None # type: str
|
||||
cache = {} # type: Dict[str, Puppet]
|
||||
by_custom_mxid = {} # type: Dict[str, Puppet]
|
||||
|
||||
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
|
||||
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
|
||||
@@ -71,7 +78,8 @@ class Puppet:
|
||||
def tgid(self):
|
||||
return self.id
|
||||
|
||||
async def is_logged_in(self):
|
||||
@staticmethod
|
||||
async def is_logged_in():
|
||||
return True
|
||||
|
||||
# region Custom puppet management
|
||||
@@ -154,12 +162,12 @@ class Puppet:
|
||||
def filter_events(self, events):
|
||||
new_events = []
|
||||
for event in events:
|
||||
type = event.get("type", None)
|
||||
evt_type = event.get("type", None)
|
||||
event.setdefault("content", {})
|
||||
if type == "m.typing":
|
||||
if evt_type == "m.typing":
|
||||
is_typing = self.custom_mxid in event["content"].get("user_ids", [])
|
||||
event["content"]["user_ids"] = [self.custom_mxid] if is_typing else []
|
||||
elif type == "m.receipt":
|
||||
elif evt_type == "m.receipt":
|
||||
val = None
|
||||
evt = None
|
||||
for event_id in event["content"]:
|
||||
@@ -273,7 +281,7 @@ class Puppet:
|
||||
return round(similarity * 1000) / 10
|
||||
|
||||
@staticmethod
|
||||
def get_displayname(info, format=True):
|
||||
def get_displayname(info, enable_format=True):
|
||||
data = {
|
||||
"phone number": info.phone if hasattr(info, "phone") else None,
|
||||
"username": info.username,
|
||||
@@ -295,7 +303,7 @@ class Puppet:
|
||||
elif not name:
|
||||
name = info.id
|
||||
|
||||
if not format:
|
||||
if not enable_format:
|
||||
return name
|
||||
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
|
||||
displayname=name)
|
||||
@@ -347,18 +355,18 @@ class Puppet:
|
||||
# region Getters
|
||||
|
||||
@classmethod
|
||||
def get(cls, id, create=True) -> "Optional[Puppet]":
|
||||
def get(cls, tgid, create=True) -> "Optional[Puppet]":
|
||||
try:
|
||||
return cls.cache[id]
|
||||
return cls.cache[tgid]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
puppet = DBPuppet.query.get(id)
|
||||
puppet = DBPuppet.query.get(tgid)
|
||||
if puppet:
|
||||
return cls.from_db(puppet)
|
||||
|
||||
if create:
|
||||
puppet = cls(id)
|
||||
puppet = cls(tgid)
|
||||
cls.db.add(puppet.db_instance)
|
||||
cls.db.commit()
|
||||
return puppet
|
||||
@@ -402,8 +410,8 @@ class Puppet:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_mxid_from_id(cls, id):
|
||||
return f"@{cls.username_template.format(userid=id)}:{cls.hs_domain}"
|
||||
def get_mxid_from_id(cls, tgid):
|
||||
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username) -> "Optional[Puppet]":
|
||||
@@ -437,12 +445,12 @@ class Puppet:
|
||||
# endregion
|
||||
|
||||
|
||||
def init(context):
|
||||
def init(context: "Context") -> List[Awaitable[int]]:
|
||||
global config
|
||||
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
|
||||
Puppet.mx = context.mx
|
||||
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
|
||||
Puppet.hs_domain = config["homeserver"]["domain"]
|
||||
localpart = Puppet.username_template.format(userid="(.+)")
|
||||
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
|
||||
Puppet.mxid_regex = re.compile(
|
||||
f"@{Puppet.username_template.format(userid='(.+)')}:{Puppet.hs_domain}")
|
||||
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
|
||||
|
||||
Reference in New Issue
Block a user