Add more type hints
This commit is contained in:
+58
-54
@@ -14,42 +14,51 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, Awaitable, Optional
|
||||
from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from telethon.tl.types import *
|
||||
from telethon.tl.types import User as TLUser
|
||||
from telethon.tl.types.contacts import ContactsNotModified
|
||||
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
|
||||
from telethon.tl.functions.account import UpdateStatusRequest
|
||||
from mautrix_appservice import MatrixRequestError
|
||||
|
||||
from .db import User as DBUser, Contact as DBContact
|
||||
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
|
||||
from .abstract_user import AbstractUser
|
||||
from . import portal as po, puppet as pu
|
||||
|
||||
config = None
|
||||
if TYPE_CHECKING:
|
||||
from .config import Config
|
||||
from .context import Context
|
||||
|
||||
config = None # type: Config
|
||||
|
||||
SearchResults = List[Tuple["pu.Puppet", int]]
|
||||
|
||||
|
||||
class User(AbstractUser):
|
||||
log = logging.getLogger("mau.user")
|
||||
by_mxid = {}
|
||||
by_tgid = {}
|
||||
log = logging.getLogger("mau.user") # type: logging.Logger
|
||||
by_mxid = {} # type: Dict[str, User]
|
||||
by_tgid = {} # type: Dict[int, User]
|
||||
|
||||
def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0,
|
||||
is_bot=False, db_portals=None, db_instance=None):
|
||||
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
|
||||
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
|
||||
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
|
||||
db_instance: Optional[DBUser] = None):
|
||||
super().__init__()
|
||||
self.mxid = mxid # type: str
|
||||
self.tgid = tgid # type: int
|
||||
self.is_bot = is_bot # type: bool
|
||||
self.username = username # type: str
|
||||
self.contacts = []
|
||||
self.saved_contacts = saved_contacts
|
||||
self.db_contacts = db_contacts
|
||||
self.portals = {} # type: Dict[str, po.Portal]
|
||||
self.db_portals = db_portals
|
||||
self._db_instance = db_instance
|
||||
self.contacts = [] # type: List[pu.Puppet]
|
||||
self.saved_contacts = saved_contacts # type: int
|
||||
self.db_contacts = db_contacts # type: List[DBContact]
|
||||
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
|
||||
self.db_portals = db_portals # type: List[DBPortal]
|
||||
self._db_instance = db_instance # type: DBUser
|
||||
|
||||
self.command_status = None # type: dict
|
||||
|
||||
@@ -64,53 +73,47 @@ class User(AbstractUser):
|
||||
self.by_tgid[tgid] = self
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return self.mxid
|
||||
|
||||
@property
|
||||
def mxid_localpart(self):
|
||||
match = re.compile("@(.+):(.+)").match(self.mxid)
|
||||
def mxid_localpart(self) -> str:
|
||||
match = re.compile("@(.+):(.+)").match(self.mxid) # type: Match
|
||||
return match.group(1)
|
||||
|
||||
# TODO replace with proper displayname getting everywhere
|
||||
@property
|
||||
def displayname(self):
|
||||
def displayname(self) -> str:
|
||||
return self.mxid_localpart
|
||||
|
||||
@property
|
||||
def db_contacts(self):
|
||||
def db_contacts(self) -> List[DBContact]:
|
||||
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
|
||||
for puppet in self.contacts]
|
||||
|
||||
@db_contacts.setter
|
||||
def db_contacts(self, contacts):
|
||||
if contacts:
|
||||
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts]
|
||||
else:
|
||||
self.contacts = []
|
||||
def db_contacts(self, contacts: List[DBContact]):
|
||||
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
|
||||
|
||||
@property
|
||||
def db_portals(self):
|
||||
def db_portals(self) -> List[DBPortal]:
|
||||
return [portal.db_instance for portal in self.portals.values()]
|
||||
|
||||
@db_portals.setter
|
||||
def db_portals(self, portals):
|
||||
if portals:
|
||||
self.portals = {(portal.tgid, portal.tg_receiver):
|
||||
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
|
||||
for portal in portals}
|
||||
else:
|
||||
self.portals = {}
|
||||
def db_portals(self, portals: List[DBPortal]):
|
||||
self.portals = {(portal.tgid, portal.tg_receiver):
|
||||
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
|
||||
for portal in portals} if portals else {}
|
||||
|
||||
# region Database conversion
|
||||
|
||||
@property
|
||||
def db_instance(self):
|
||||
def db_instance(self) -> DBUser:
|
||||
if not self._db_instance:
|
||||
self._db_instance = self.new_db_instance()
|
||||
return self._db_instance
|
||||
|
||||
def new_db_instance(self):
|
||||
def new_db_instance(self) -> DBUser:
|
||||
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
|
||||
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
|
||||
portals=self.db_portals)
|
||||
@@ -134,14 +137,14 @@ class User(AbstractUser):
|
||||
self.db.commit()
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_user):
|
||||
def from_db(cls, db_user: DBUser) -> "User":
|
||||
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
|
||||
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
|
||||
|
||||
# endregion
|
||||
# region Telegram connection management
|
||||
|
||||
async def start(self, delete_unless_authenticated=False):
|
||||
async def start(self, delete_unless_authenticated: bool = False) -> "User":
|
||||
await super().start()
|
||||
if await self.is_logged_in():
|
||||
self.log.debug(f"Ensuring post_login() for {self.name}")
|
||||
@@ -152,7 +155,7 @@ class User(AbstractUser):
|
||||
self.client.session.delete()
|
||||
return self
|
||||
|
||||
async def post_login(self, info=None):
|
||||
async def post_login(self, info: TLUser = None):
|
||||
try:
|
||||
await self.update_info(info)
|
||||
if not self.is_bot:
|
||||
@@ -163,7 +166,7 @@ class User(AbstractUser):
|
||||
except Exception:
|
||||
self.log.exception("Failed to run post-login functions for %s", self.mxid)
|
||||
|
||||
async def update(self, update):
|
||||
async def update(self, update: TypeUpdate):
|
||||
if not self.is_bot:
|
||||
return
|
||||
|
||||
@@ -186,7 +189,7 @@ class User(AbstractUser):
|
||||
# endregion
|
||||
# region Telegram actions that need custom methods
|
||||
|
||||
def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]":
|
||||
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
|
||||
return super().ensure_started(even_if_no_session)
|
||||
|
||||
def set_presence(self, online: bool = True):
|
||||
@@ -194,7 +197,7 @@ class User(AbstractUser):
|
||||
return
|
||||
return self.client(UpdateStatusRequest(offline=not online))
|
||||
|
||||
async def update_info(self, info: User = None):
|
||||
async def update_info(self, info: TLUser = None):
|
||||
info = info or await self.client.get_me()
|
||||
changed = False
|
||||
if self.is_bot != info.bot:
|
||||
@@ -233,8 +236,9 @@ class User(AbstractUser):
|
||||
self.delete()
|
||||
return True
|
||||
|
||||
def _search_local(self, query, max_results=5, min_similarity=45):
|
||||
results = []
|
||||
def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
|
||||
) -> SearchResults:
|
||||
results = [] # type: SearchResults
|
||||
for contact in self.contacts:
|
||||
similarity = contact.similarity(query)
|
||||
if similarity >= min_similarity:
|
||||
@@ -242,11 +246,11 @@ class User(AbstractUser):
|
||||
results.sort(key=lambda tup: tup[1], reverse=True)
|
||||
return results[0:max_results]
|
||||
|
||||
async def _search_remote(self, query, max_results=5):
|
||||
async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults:
|
||||
if len(query) < 5:
|
||||
return []
|
||||
server_results = await self.client(SearchRequest(q=query, limit=max_results))
|
||||
results = []
|
||||
results = [] # type: SearchResults
|
||||
for user in server_results.users:
|
||||
puppet = pu.Puppet.get(user.id)
|
||||
await puppet.update_info(self, user)
|
||||
@@ -254,7 +258,7 @@ class User(AbstractUser):
|
||||
results.sort(key=lambda tup: tup[1], reverse=True)
|
||||
return results[0:max_results]
|
||||
|
||||
async def search(self, query, force_remote=False):
|
||||
async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]:
|
||||
if force_remote:
|
||||
return await self._search_remote(query), True
|
||||
|
||||
@@ -264,7 +268,7 @@ class User(AbstractUser):
|
||||
|
||||
return await self._search_remote(query), True
|
||||
|
||||
async def sync_dialogs(self, synchronous_create=False):
|
||||
async def sync_dialogs(self, synchronous_create: bool = False):
|
||||
creators = []
|
||||
for entity in await self.get_dialogs(limit=30):
|
||||
portal = po.Portal.get_by_entity(entity)
|
||||
@@ -275,7 +279,7 @@ class User(AbstractUser):
|
||||
self.save()
|
||||
await asyncio.gather(*creators, loop=self.loop)
|
||||
|
||||
def register_portal(self, portal):
|
||||
def register_portal(self, portal: po.Portal):
|
||||
try:
|
||||
if self.portals[portal.tgid_full] == portal:
|
||||
return
|
||||
@@ -284,18 +288,18 @@ class User(AbstractUser):
|
||||
self.portals[portal.tgid_full] = portal
|
||||
self.save()
|
||||
|
||||
def unregister_portal(self, portal):
|
||||
def unregister_portal(self, portal: po.Portal):
|
||||
try:
|
||||
del self.portals[portal.tgid_full]
|
||||
self.save()
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
async def needs_relaybot(self, portal):
|
||||
async def needs_relaybot(self, portal: po.Portal) -> bool:
|
||||
return not await self.is_logged_in() or (
|
||||
self.is_bot and portal.tgid_full not in self.portals)
|
||||
|
||||
def _hash_contacts(self):
|
||||
def _hash_contacts(self) -> int:
|
||||
acc = 0
|
||||
for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):
|
||||
acc = (acc * 20261 + id) & 0xffffffff
|
||||
@@ -318,7 +322,7 @@ class User(AbstractUser):
|
||||
# region Class instance lookup
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid, create=True) -> "Optional[User]":
|
||||
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
|
||||
if not mxid:
|
||||
raise ValueError("Matrix ID can't be empty")
|
||||
|
||||
@@ -341,7 +345,7 @@ class User(AbstractUser):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid) -> "Optional[User]":
|
||||
def get_by_tgid(cls, tgid: int) -> "Optional[User]":
|
||||
try:
|
||||
return cls.by_tgid[tgid]
|
||||
except KeyError:
|
||||
@@ -355,7 +359,7 @@ class User(AbstractUser):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username) -> "Optional[User]":
|
||||
def find_by_username(cls, username: str) -> "Optional[User]":
|
||||
if not username:
|
||||
return None
|
||||
|
||||
@@ -371,7 +375,7 @@ class User(AbstractUser):
|
||||
# endregion
|
||||
|
||||
|
||||
def init(context):
|
||||
def init(context: "Context") -> List[Awaitable[User]]:
|
||||
global config
|
||||
config = context.config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user