Add more type hints

This commit is contained in:
Tulir Asokan
2018-07-25 10:40:31 -04:00
parent ae334b9a04
commit dbfb980bde
20 changed files with 751 additions and 595 deletions
+58 -54
View File
@@ -14,42 +14,51 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Awaitable, Optional
from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING
import logging
import asyncio
import re
from telethon.tl.types import *
from telethon.tl.types import User as TLUser
from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError
from .db import User as DBUser, Contact as DBContact
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
from .abstract_user import AbstractUser
from . import portal as po, puppet as pu
config = None
if TYPE_CHECKING:
from .config import Config
from .context import Context
config = None # type: Config
SearchResults = List[Tuple["pu.Puppet", int]]
class User(AbstractUser):
log = logging.getLogger("mau.user")
by_mxid = {}
by_tgid = {}
log = logging.getLogger("mau.user") # type: logging.Logger
by_mxid = {} # type: Dict[str, User]
by_tgid = {} # type: Dict[int, User]
def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0,
is_bot=False, db_portals=None, db_instance=None):
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
db_instance: Optional[DBUser] = None):
super().__init__()
self.mxid = mxid # type: str
self.tgid = tgid # type: int
self.is_bot = is_bot # type: bool
self.username = username # type: str
self.contacts = []
self.saved_contacts = saved_contacts
self.db_contacts = db_contacts
self.portals = {} # type: Dict[str, po.Portal]
self.db_portals = db_portals
self._db_instance = db_instance
self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
self.db_portals = db_portals # type: List[DBPortal]
self._db_instance = db_instance # type: DBUser
self.command_status = None # type: dict
@@ -64,53 +73,47 @@ class User(AbstractUser):
self.by_tgid[tgid] = self
@property
def name(self):
def name(self) -> str:
return self.mxid
@property
def mxid_localpart(self):
match = re.compile("@(.+):(.+)").match(self.mxid)
def mxid_localpart(self) -> str:
match = re.compile("@(.+):(.+)").match(self.mxid) # type: Match
return match.group(1)
# TODO replace with proper displayname getting everywhere
@property
def displayname(self):
def displayname(self) -> str:
return self.mxid_localpart
@property
def db_contacts(self):
def db_contacts(self) -> List[DBContact]:
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
for puppet in self.contacts]
@db_contacts.setter
def db_contacts(self, contacts):
if contacts:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts]
else:
self.contacts = []
def db_contacts(self, contacts: List[DBContact]):
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
@property
def db_portals(self):
def db_portals(self) -> List[DBPortal]:
return [portal.db_instance for portal in self.portals.values()]
@db_portals.setter
def db_portals(self, portals):
if portals:
self.portals = {(portal.tgid, portal.tg_receiver):
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
for portal in portals}
else:
self.portals = {}
def db_portals(self, portals: List[DBPortal]):
self.portals = {(portal.tgid, portal.tg_receiver):
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
for portal in portals} if portals else {}
# region Database conversion
@property
def db_instance(self):
def db_instance(self) -> DBUser:
if not self._db_instance:
self._db_instance = self.new_db_instance()
return self._db_instance
def new_db_instance(self):
def new_db_instance(self) -> DBUser:
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
portals=self.db_portals)
@@ -134,14 +137,14 @@ class User(AbstractUser):
self.db.commit()
@classmethod
def from_db(cls, db_user):
def from_db(cls, db_user: DBUser) -> "User":
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
# endregion
# region Telegram connection management
async def start(self, delete_unless_authenticated=False):
async def start(self, delete_unless_authenticated: bool = False) -> "User":
await super().start()
if await self.is_logged_in():
self.log.debug(f"Ensuring post_login() for {self.name}")
@@ -152,7 +155,7 @@ class User(AbstractUser):
self.client.session.delete()
return self
async def post_login(self, info=None):
async def post_login(self, info: TLUser = None):
try:
await self.update_info(info)
if not self.is_bot:
@@ -163,7 +166,7 @@ class User(AbstractUser):
except Exception:
self.log.exception("Failed to run post-login functions for %s", self.mxid)
async def update(self, update):
async def update(self, update: TypeUpdate):
if not self.is_bot:
return
@@ -186,7 +189,7 @@ class User(AbstractUser):
# endregion
# region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]":
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
return super().ensure_started(even_if_no_session)
def set_presence(self, online: bool = True):
@@ -194,7 +197,7 @@ class User(AbstractUser):
return
return self.client(UpdateStatusRequest(offline=not online))
async def update_info(self, info: User = None):
async def update_info(self, info: TLUser = None):
info = info or await self.client.get_me()
changed = False
if self.is_bot != info.bot:
@@ -233,8 +236,9 @@ class User(AbstractUser):
self.delete()
return True
def _search_local(self, query, max_results=5, min_similarity=45):
results = []
def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
) -> SearchResults:
results = [] # type: SearchResults
for contact in self.contacts:
similarity = contact.similarity(query)
if similarity >= min_similarity:
@@ -242,11 +246,11 @@ class User(AbstractUser):
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
async def _search_remote(self, query, max_results=5):
async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults:
if len(query) < 5:
return []
server_results = await self.client(SearchRequest(q=query, limit=max_results))
results = []
results = [] # type: SearchResults
for user in server_results.users:
puppet = pu.Puppet.get(user.id)
await puppet.update_info(self, user)
@@ -254,7 +258,7 @@ class User(AbstractUser):
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
async def search(self, query, force_remote=False):
async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]:
if force_remote:
return await self._search_remote(query), True
@@ -264,7 +268,7 @@ class User(AbstractUser):
return await self._search_remote(query), True
async def sync_dialogs(self, synchronous_create=False):
async def sync_dialogs(self, synchronous_create: bool = False):
creators = []
for entity in await self.get_dialogs(limit=30):
portal = po.Portal.get_by_entity(entity)
@@ -275,7 +279,7 @@ class User(AbstractUser):
self.save()
await asyncio.gather(*creators, loop=self.loop)
def register_portal(self, portal):
def register_portal(self, portal: po.Portal):
try:
if self.portals[portal.tgid_full] == portal:
return
@@ -284,18 +288,18 @@ class User(AbstractUser):
self.portals[portal.tgid_full] = portal
self.save()
def unregister_portal(self, portal):
def unregister_portal(self, portal: po.Portal):
try:
del self.portals[portal.tgid_full]
self.save()
except KeyError:
pass
async def needs_relaybot(self, portal):
async def needs_relaybot(self, portal: po.Portal) -> bool:
return not await self.is_logged_in() or (
self.is_bot and portal.tgid_full not in self.portals)
def _hash_contacts(self):
def _hash_contacts(self) -> int:
acc = 0
for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):
acc = (acc * 20261 + id) & 0xffffffff
@@ -318,7 +322,7 @@ class User(AbstractUser):
# region Class instance lookup
@classmethod
def get_by_mxid(cls, mxid, create=True) -> "Optional[User]":
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -341,7 +345,7 @@ class User(AbstractUser):
return None
@classmethod
def get_by_tgid(cls, tgid) -> "Optional[User]":
def get_by_tgid(cls, tgid: int) -> "Optional[User]":
try:
return cls.by_tgid[tgid]
except KeyError:
@@ -355,7 +359,7 @@ class User(AbstractUser):
return None
@classmethod
def find_by_username(cls, username) -> "Optional[User]":
def find_by_username(cls, username: str) -> "Optional[User]":
if not username:
return None
@@ -371,7 +375,7 @@ class User(AbstractUser):
# endregion
def init(context):
def init(context: "Context") -> List[Awaitable[User]]:
global config
config = context.config