Add more type hints

This commit is contained in:
Tulir Asokan
2018-07-25 10:40:31 -04:00
parent ae334b9a04
commit dbfb980bde
20 changed files with 751 additions and 595 deletions
+34 -26
View File
@@ -14,32 +14,39 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
from difflib import SequenceMatcher
from typing import Optional, Awaitable
import re
import logging
import asyncio
from sqlalchemy import orm
from telethon.tl.types import UserProfilePhoto
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
from .db import Puppet as DBPuppet
from . import util, matrix
from . import util
config = None
if TYPE_CHECKING:
from .matrix import MatrixHandler
from .config import Config
from .context import Context
config = None # type: Config
class Puppet:
log = logging.getLogger("mau.puppet")
db = None
log = logging.getLogger("mau.puppet") # type: logging.Logger
db = None # type: orm.Session
az = None # type: AppService
mx = None # type: matrix.MatrixHandler
mx = None # type: MatrixHandler
loop = None # type: asyncio.AbstractEventLoop
mxid_regex = None
username_template = None
hs_domain = None
cache = {}
by_custom_mxid = {}
mxid_regex = None # type: Pattern
username_template = None # type: str
hs_domain = None # type: str
cache = {} # type: Dict[str, Puppet]
by_custom_mxid = {} # type: Dict[str, Puppet]
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
@@ -71,7 +78,8 @@ class Puppet:
def tgid(self):
return self.id
async def is_logged_in(self):
@staticmethod
async def is_logged_in():
return True
# region Custom puppet management
@@ -154,12 +162,12 @@ class Puppet:
def filter_events(self, events):
new_events = []
for event in events:
type = event.get("type", None)
evt_type = event.get("type", None)
event.setdefault("content", {})
if type == "m.typing":
if evt_type == "m.typing":
is_typing = self.custom_mxid in event["content"].get("user_ids", [])
event["content"]["user_ids"] = [self.custom_mxid] if is_typing else []
elif type == "m.receipt":
elif evt_type == "m.receipt":
val = None
evt = None
for event_id in event["content"]:
@@ -273,7 +281,7 @@ class Puppet:
return round(similarity * 1000) / 10
@staticmethod
def get_displayname(info, format=True):
def get_displayname(info, enable_format=True):
data = {
"phone number": info.phone if hasattr(info, "phone") else None,
"username": info.username,
@@ -295,7 +303,7 @@ class Puppet:
elif not name:
name = info.id
if not format:
if not enable_format:
return name
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
displayname=name)
@@ -347,18 +355,18 @@ class Puppet:
# region Getters
@classmethod
def get(cls, id, create=True) -> "Optional[Puppet]":
def get(cls, tgid, create=True) -> "Optional[Puppet]":
try:
return cls.cache[id]
return cls.cache[tgid]
except KeyError:
pass
puppet = DBPuppet.query.get(id)
puppet = DBPuppet.query.get(tgid)
if puppet:
return cls.from_db(puppet)
if create:
puppet = cls(id)
puppet = cls(tgid)
cls.db.add(puppet.db_instance)
cls.db.commit()
return puppet
@@ -402,8 +410,8 @@ class Puppet:
return None
@classmethod
def get_mxid_from_id(cls, id):
return f"@{cls.username_template.format(userid=id)}:{cls.hs_domain}"
def get_mxid_from_id(cls, tgid):
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
@classmethod
def find_by_username(cls, username) -> "Optional[Puppet]":
@@ -437,12 +445,12 @@ class Puppet:
# endregion
def init(context):
def init(context: "Context") -> List[Awaitable[int]]:
global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"]
localpart = Puppet.username_template.format(userid="(.+)")
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
Puppet.mxid_regex = re.compile(
f"@{Puppet.username_template.format(userid='(.+)')}:{Puppet.hs_domain}")
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]