diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py
index ad4cb4ef..ed566c6d 100644
--- a/mautrix_telegram/__main__.py
+++ b/mautrix_telegram/__main__.py
@@ -14,34 +14,33 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Optional
import argparse
-import sys
-import logging
-import logging.config
import asyncio
+import logging.config
+import sys
-import sqlalchemy as sql
from sqlalchemy import orm
+import sqlalchemy as sql
-from alchemysession import AlchemySessionContainer
from mautrix_appservice import AppService
+from alchemysession import AlchemySessionContainer
-from .base import Base
-from .config import Config
-from .matrix import MatrixHandler
-
-from . import __version__
-from .db import init as init_db
+from .web.provisioning import ProvisioningAPI
+from .web.public import PublicBridgeWebsite
from .abstract_user import init as init_abstract_user
-from .user import init as init_user, User
+from .base import Base
from .bot import init as init_bot
+from .config import Config
+from .context import Context
+from .db import init as init_db
+from .formatter import init as init_formatter
+from .matrix import MatrixHandler
from .portal import init as init_portal
from .puppet import init as init_puppet
-from .formatter import init as init_formatter
-from .web.public import PublicBridgeWebsite
-from .web.provisioning import ProvisioningAPI
-from .context import Context
from .sqlstatestore import SQLStateStore
+from .user import User, init as init_user
+from . import __version__
parser = argparse.ArgumentParser(
description="A Matrix-Telegram puppeting bridge.",
@@ -68,7 +67,7 @@ if args.generate_registration:
sys.exit(0)
logging.config.dictConfig(config["logging"])
-log = logging.getLogger("mau.init")
+log = logging.getLogger("mau.init") # type: logging.Logger
log.debug(f"Initializing mautrix-telegram {__version__}")
db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db")
@@ -80,7 +79,7 @@ session_container = AlchemySessionContainer(engine=db_engine, session=db_session
table_base=Base, table_prefix="telethon_",
manage_tables=False)
-loop = asyncio.get_event_loop()
+loop = asyncio.get_event_loop() # type: asyncio.AbstractEventLoop
state_store = SQLStateStore(db_session)
appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
@@ -89,8 +88,8 @@ appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
verify_ssl=config["homeserver.verify_ssl"], state_store=state_store,
real_user_content_key="net.maunium.telegram.puppet")
-public_website = None
-provisioning_api = None
+public_website = None # type: Optional[PublicBridgeWebsite]
+provisioning_api = None # type: Optional[ProvisioningAPI]
if config["appservice.public.enabled"]:
public_website = PublicBridgeWebsite(loop)
diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py
index 0de32c91..49632378 100644
--- a/mautrix_telegram/abstract_user.py
+++ b/mautrix_telegram/abstract_user.py
@@ -14,26 +14,48 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Tuple, Optional, List, Union, TYPE_CHECKING
+from abc import ABC, abstractmethod
+import asyncio
+import logging
import platform
-from telethon.tl.types import *
-from mautrix_appservice import MatrixRequestError
+from sqlalchemy import orm
+from telethon.tl.types import Channel, ChannelForbidden, Chat, ChatForbidden, Message, \
+ MessageActionChannelMigrateFrom, MessageService, PeerUser, TypeUpdate, \
+ UpdateChannelPinnedMessage, UpdateChatAdmins, UpdateChatParticipantAdmin, \
+ UpdateChatParticipants, UpdateChatUserTyping, UpdateDeleteChannelMessages, \
+ UpdateDeleteMessages, UpdateEditChannelMessage, UpdateEditMessage, UpdateNewChannelMessage, \
+ UpdateNewMessage, UpdateReadHistoryOutbox, UpdateShortChatMessage, UpdateShortMessage, \
+ UpdateUserName, UpdateUserPhoto, UpdateUserStatus, UpdateUserTyping, User, UserStatusOffline, \
+ UserStatusOnline
+
+from mautrix_appservice import MatrixRequestError, AppService
+from alchemysession import AlchemySessionContainer
-from .tgclient import MautrixTelegramClient
-from .db import Message as DBMessage
from . import portal as po, puppet as pu, __version__
+from .db import Message as DBMessage
+from .tgclient import MautrixTelegramClient
-config = None
+if TYPE_CHECKING:
+ from .context import Context
+ from .config import Config
+
+config = None # type: Config
# Value updated from config in init()
-MAX_DELETIONS = 10
+MAX_DELETIONS = 10 # type: int
+
+UpdateMessage = Union[UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
+ UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage]
+UpdateMessageContent = Union[UpdateShortMessage, UpdateShortChatMessage, Message, MessageService]
-class AbstractUser:
- session_container = None
- loop = None
- log = None
- db = None
- az = None
+class AbstractUser(ABC):
+ session_container = None # type: AlchemySessionContainer
+ loop = None # type: asyncio.AbstractEventLoop
+ log = None # type: logging.Logger
+ db = None # type: orm.Session
+ az = None # type: AppService
def __init__(self):
self.puppet_whitelisted = False # type: bool
@@ -47,22 +69,22 @@ class AbstractUser:
self.is_bot = False # type: bool
@property
- def connected(self):
+ def connected(self) -> bool:
return self.client and self.client.is_connected()
@property
- def _proxy_settings(self):
- type = config["telegram.proxy.type"].lower()
- if type == "disabled":
+ def _proxy_settings(self) -> Optional[Tuple[int, str, str, str, str, str]]:
+ proxy_type = config["telegram.proxy.type"].lower()
+ if proxy_type == "disabled":
return None
- elif type == "socks4":
- type = 1
- elif type == "socks5":
- type = 2
- elif type == "http":
- type = 3
+ elif proxy_type == "socks4":
+ proxy_type = 1
+ elif proxy_type == "socks5":
+ proxy_type = 2
+ elif proxy_type == "http":
+ proxy_type = 3
- return (type,
+ return (proxy_type,
config["telegram.proxy.address"], config["telegram.proxy.port"],
config["telegram.proxy.rdns"],
config["telegram.proxy.username"], config["telegram.proxy.password"])
@@ -83,20 +105,30 @@ class AbstractUser:
proxy=self._proxy_settings)
self.client.add_event_handler(self._update_catch)
- async def update(self, update):
+ @abstractmethod
+ async def update(self, update: TypeUpdate) -> bool:
return False
+ @abstractmethod
async def post_login(self):
raise NotImplementedError()
- async def _update_catch(self, update):
+ @abstractmethod
+ def register_portal(self, portal: po.Portal):
+ raise NotImplementedError()
+
+ @abstractmethod
+ def unregister_portal(self, portal: po.Portal):
+ raise NotImplementedError()
+
+ async def _update_catch(self, update: TypeUpdate):
try:
if not await self.update(update):
await self._update(update)
except Exception:
self.log.exception("Failed to handle Telegram update")
- async def get_dialogs(self, limit=None) -> List[Union[Chat, Channel]]:
+ async def get_dialogs(self, limit: int = None) -> List[Union[Chat, Channel]]:
if self.is_bot:
return []
dialogs = await self.client.get_dialogs(limit=limit)
@@ -106,18 +138,19 @@ class AbstractUser:
and (dialog.entity.deactivated or dialog.entity.left)))]
@property
- def name(self):
+ @abstractmethod
+ def name(self) -> str:
raise NotImplementedError()
- async def is_logged_in(self):
+ async def is_logged_in(self) -> bool:
return self.client and await self.client.is_user_authorized()
- async def has_full_access(self, allow_bot=False):
+ async def has_full_access(self, allow_bot: bool = False) -> bool:
return (self.puppet_whitelisted
and (not self.is_bot or allow_bot)
and await self.is_logged_in())
- async def start(self, delete_unless_authenticated=False):
+ async def start(self, delete_unless_authenticated: bool = False) -> "AbstractUser":
if not self.client:
self._init_client()
await self.client.connect()
@@ -144,7 +177,7 @@ class AbstractUser:
# region Telegram update handling
- async def _update(self, update):
+ async def _update(self, update: TypeUpdate):
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)):
await self.update_message(update)
@@ -169,17 +202,19 @@ class AbstractUser:
else:
self.log.debug("Unhandled update: %s", update)
- async def update_pinned_messages(self, update):
+ @staticmethod
+ async def update_pinned_messages(update: UpdateChannelPinnedMessage):
portal = po.Portal.get_by_tgid(update.channel_id)
if portal and portal.mxid:
await portal.receive_telegram_pin_id(update.id)
- async def update_participants(self, update):
+ @staticmethod
+ async def update_participants(update: UpdateChatParticipants):
portal = po.Portal.get_by_tgid(update.participants.chat_id)
if portal and portal.mxid:
await portal.update_telegram_participants(update.participants.participants)
- async def update_read_receipt(self, update):
+ async def update_read_receipt(self, update: UpdateReadHistoryOutbox):
if not isinstance(update.peer, PeerUser):
self.log.debug("Unexpected read receipt peer: %s", update.peer)
return
@@ -196,7 +231,7 @@ class AbstractUser:
puppet = pu.Puppet.get(update.peer.user_id)
await puppet.intent.mark_read(portal.mxid, message.mxid)
- async def update_admin(self, update):
+ async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]):
# TODO duplication not checked
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
if isinstance(update, UpdateChatAdmins):
@@ -206,7 +241,7 @@ class AbstractUser:
else:
self.log.warning("Unexpected admin status update: %s", update)
- async def update_typing(self, update):
+ async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]):
if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else:
@@ -214,7 +249,7 @@ class AbstractUser:
sender = pu.Puppet.get(update.user_id)
await portal.handle_telegram_typing(sender, update)
- async def update_others_info(self, update):
+ async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]):
# TODO duplication not checked
puppet = pu.Puppet.get(update.user_id)
if isinstance(update, UpdateUserName):
@@ -226,7 +261,7 @@ class AbstractUser:
else:
self.log.warning("Unexpected other user info update: %s", update)
- async def update_status(self, update):
+ async def update_status(self, update: UpdateUserStatus):
puppet = pu.Puppet.get(update.user_id)
if isinstance(update.status, UserStatusOnline):
await puppet.default_mxid_intent.set_presence("online")
@@ -236,7 +271,9 @@ class AbstractUser:
self.log.warning("Unexpected user status update: %s", update)
return
- def get_message_details(self, update):
+ def get_message_details(self, update: UpdateMessage) -> Tuple[UpdateMessageContent,
+ Optional[pu.Puppet],
+ Optional[po.Portal]]:
if isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
sender = pu.Puppet.get(update.from_id)
@@ -259,7 +296,7 @@ class AbstractUser:
return update, sender, portal
@staticmethod
- async def _try_redact(portal, message):
+ async def _try_redact(portal: po.Portal, message: DBMessage):
if not portal:
return
try:
@@ -267,7 +304,7 @@ class AbstractUser:
except MatrixRequestError:
pass
- async def delete_message(self, update):
+ async def delete_message(self, update: UpdateDeleteMessages):
if len(update.messages) > MAX_DELETIONS:
return
@@ -283,7 +320,7 @@ class AbstractUser:
await self._try_redact(portal, message)
self.db.commit()
- async def delete_channel_message(self, update):
+ async def delete_channel_message(self, update: UpdateDeleteChannelMessages):
if len(update.messages) > MAX_DELETIONS:
return
@@ -299,7 +336,7 @@ class AbstractUser:
await self._try_redact(portal, message)
self.db.commit()
- async def update_message(self, original_update):
+ async def update_message(self, original_update: UpdateMessage):
update, sender, portal = self.get_message_details(original_update)
if isinstance(update, MessageService):
@@ -325,7 +362,7 @@ class AbstractUser:
# endregion
-def init(context):
+def init(context: "Context"):
global config, MAX_DELETIONS
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, _ = context
AbstractUser.session_container = context.session_container
diff --git a/mautrix_telegram/base.py b/mautrix_telegram/base.py
index c64447da..0b62d886 100644
--- a/mautrix_telegram/base.py
+++ b/mautrix_telegram/base.py
@@ -1,2 +1,2 @@
from sqlalchemy.ext.declarative import declarative_base
-Base = declarative_base()
+Base = declarative_base() # type: declarative_base
diff --git a/mautrix_telegram/bot.py b/mautrix_telegram/bot.py
index c05a62aa..51a6a110 100644
--- a/mautrix_telegram/bot.py
+++ b/mautrix_telegram/bot.py
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Awaitable, Callable
+from typing import Awaitable, Callable, Pattern, Dict, TYPE_CHECKING
import logging
import re
@@ -27,27 +27,31 @@ from .abstract_user import AbstractUser
from .db import BotChat
from . import puppet as pu, portal as po, user as u
-config = None
+if TYPE_CHECKING:
+ from .config import Config
+
+config = None # type: Config
ReplyFunc = Callable[[str], Awaitable[Message]]
class Bot(AbstractUser):
- log = logging.getLogger("mau.bot")
- mxid_regex = re.compile("@.+:.+")
+ log = logging.getLogger("mau.bot") # type: logging.Logger
+ mxid_regex = re.compile("@.+:.+") # type: Pattern
def __init__(self, token: str):
super().__init__()
- self.token = token
- self.puppet_whitelisted = True
- self.whitelisted = True
- self.relaybot_whitelisted = True
- self.username = None
- self.is_relaybot = True
- self.is_bot = True
- self.chats = {chat.id: chat.type for chat in BotChat.query.all()}
- self.tg_whitelist = []
- self.whitelist_group_admins = config["bridge.relaybot.whitelist_group_admins"] or False
+ self.token = token # type: str
+ self.puppet_whitelisted = True # type: bool
+ self.whitelisted = True # type: bool
+ self.relaybot_whitelisted = True # type: bool
+ self.username = None # type: str
+ self.is_relaybot = True # type: bool
+ self.is_bot = True # type: bool
+ self.chats = {chat.id: chat.type for chat in BotChat.query.all()} # type: Dict[int, str]
+ self.tg_whitelist = [] # type: List[int]
+ self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
+ or False) # type: bool
async def init_permissions(self):
whitelist = config["bridge.relaybot.whitelist"] or []
@@ -61,7 +65,7 @@ class Bot(AbstractUser):
if isinstance(id, int):
self.tg_whitelist.append(id)
- async def start(self, delete_unless_authenticated=False):
+ async def start(self, delete_unless_authenticated: bool = False) -> "Bot":
await super().start(delete_unless_authenticated)
if not await self.is_logged_in():
await self.client.sign_in(bot_token=self.token)
@@ -118,7 +122,7 @@ class Bot(AbstractUser):
self.db.delete(existing_chat)
self.db.commit()
- async def _can_use_commands(self, chat, tgid):
+ async def _can_use_commands(self, chat: TypePeer, tgid: int) -> bool:
if tgid in self.tg_whitelist:
return True
@@ -138,7 +142,7 @@ class Bot(AbstractUser):
if p.user_id == tgid:
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
- async def check_can_use_commands(self, event: Message, reply: ReplyFunc):
+ async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
if not await self._can_use_commands(event.to_id, event.from_id):
await reply("You do not have the permission to use that command.")
return False
@@ -262,7 +266,7 @@ class Bot(AbstractUser):
return "bot"
-def init(context):
+def init(context) -> Optional[Bot]:
global config
config = context.config
token = config["telegram.bot_token"]
diff --git a/mautrix_telegram/commands/clean_rooms.py b/mautrix_telegram/commands/clean_rooms.py
index e9031d3b..aac5a54d 100644
--- a/mautrix_telegram/commands/clean_rooms.py
+++ b/mautrix_telegram/commands/clean_rooms.py
@@ -23,15 +23,14 @@ from .. import puppet as pu, portal as po
ManagementRoomList = List[Tuple[str, str]]
RoomIDList = List[str]
-PortalList = List[po.Portal]
-async def _find_rooms(intent: IntentAPI) -> Tuple[
- ManagementRoomList, RoomIDList, PortalList, PortalList]:
+async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList,
+ List["po.Portal"], List["po.Portal"]]:
management_rooms = [] # type: ManagementRoomList
unidentified_rooms = [] # type: RoomIDList
- portals = [] # type: PortalList
- empty_portals = [] # type: PortalList
+ portals = [] # type: List[po.Portal]
+ empty_portals = [] # type: List[po.Portal]
rooms = await intent.get_joined_rooms()
for room in rooms:
@@ -108,8 +107,8 @@ async def clean_rooms(evt: CommandEvent):
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
- unidentified_rooms: RoomIDList, portals: PortalList,
- empty_portals: PortalList):
+ unidentified_rooms: RoomIDList, portals: List["po.Portal"],
+ empty_portals: List["po.Portal"]):
command = evt.args[0]
rooms_to_clean = []
if command == "clean-recommended":
diff --git a/mautrix_telegram/commands/portal.py b/mautrix_telegram/commands/portal.py
index 0c88ca74..c2ff2347 100644
--- a/mautrix_telegram/commands/portal.py
+++ b/mautrix_telegram/commands/portal.py
@@ -222,7 +222,7 @@ async def bridge(evt: CommandEvent):
"chat to this room, use `$cmdprefix+sp continue`")
-async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: po.Portal):
+async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"):
if not portal.mxid:
await evt.reply("The portal seems to have lost its Matrix room between you"
"calling `$cmdprefix+sp bridge` and this command.\n\n"
diff --git a/mautrix_telegram/config.py b/mautrix_telegram/config.py
index b7766f47..72e61f27 100644
--- a/mautrix_telegram/config.py
+++ b/mautrix_telegram/config.py
@@ -14,6 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Tuple, Any, Optional
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
import random
@@ -24,28 +25,28 @@ yaml.indent(4)
class DictWithRecursion:
- def __init__(self, data=None):
- self._data = data or CommentedMap()
+ def __init__(self, data: CommentedMap = None):
+ self._data = data or CommentedMap() # type: CommentedMap
- def _recursive_get(self, data, key, default_value):
+ def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
if '.' in key:
key, next_key = key.split('.', 1)
next_data = data.get(key, CommentedMap())
return self._recursive_get(next_data, next_key, default_value)
return data.get(key, default_value)
- def get(self, key, default_value, allow_recursion=True):
+ def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any:
if allow_recursion and '.' in key:
return self._recursive_get(self._data, key, default_value)
return self._data.get(key, default_value)
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> Any:
return self.get(key, None)
- def __contains__(self, key):
+ def __contains__(self, key: str) -> bool:
return self[key] is not None
- def _recursive_set(self, data, key, value):
+ def _recursive_set(self, data: CommentedMap, key: str, value: Any):
if '.' in key:
key, next_key = key.split('.', 1)
if key not in data:
@@ -55,16 +56,16 @@ class DictWithRecursion:
return
data[key] = value
- def set(self, key, value, allow_recursion=True):
+ def set(self, key: str, value: Any, allow_recursion: bool = True):
if allow_recursion and '.' in key:
self._recursive_set(self._data, key, value)
return
self._data[key] = value
- def __setitem__(self, key, value):
+ def __setitem__(self, key: str, value: Any):
self.set(key, value)
- def _recursive_del(self, data, key):
+ def _recursive_del(self, data: CommentedMap, key: str):
if '.' in key:
key, next_key = key.split('.', 1)
if key not in data:
@@ -78,7 +79,7 @@ class DictWithRecursion:
except KeyError:
pass
- def delete(self, key, allow_recursion=True):
+ def delete(self, key: str, allow_recursion: bool = True):
if allow_recursion and '.' in key:
self._recursive_del(self._data, key)
return
@@ -88,23 +89,23 @@ class DictWithRecursion:
except KeyError:
pass
- def __delitem__(self, key):
+ def __delitem__(self, key: str):
self.delete(key)
class Config(DictWithRecursion):
- def __init__(self, path, registration_path, base_path):
+ def __init__(self, path: str, registration_path: str, base_path: str):
super().__init__()
- self.path = path
- self.registration_path = registration_path
- self.base_path = base_path
- self._registration = None
+ self.path = path # type: str
+ self.registration_path = registration_path # type: str
+ self.base_path = base_path # type: str
+ self._registration = None # type: dict
def load(self):
with open(self.path, 'r') as stream:
self._data = yaml.load(stream)
- def load_base(self):
+ def load_base(self) -> Optional[DictWithRecursion]:
try:
with open(self.base_path, 'r') as stream:
return DictWithRecursion(yaml.load(stream))
@@ -120,7 +121,7 @@ class Config(DictWithRecursion):
yaml.dump(self._registration, stream)
@staticmethod
- def _new_token():
+ def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
def update(self):
@@ -246,7 +247,7 @@ class Config(DictWithRecursion):
self._data = base._data
self.save()
- def _get_permissions(self, key):
+ def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, bool, bool]:
level = self["bridge.permissions"].get(key, "")
admin = level == "admin"
puppeting = level == "full" or admin
@@ -254,7 +255,7 @@ class Config(DictWithRecursion):
relaybot = level == "relaybot" or user
return relaybot, user, puppeting, admin, level
- def get_permissions(self, mxid):
+ def get_permissions(self, mxid: str) -> Tuple[bool, bool, bool, bool, bool]:
permissions = self["bridge.permissions"] or {}
if mxid in permissions:
return self._get_permissions(mxid)
diff --git a/mautrix_telegram/context.py b/mautrix_telegram/context.py
index 1324e5f1..76f75ded 100644
--- a/mautrix_telegram/context.py
+++ b/mautrix_telegram/context.py
@@ -14,21 +14,27 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Tuple
-import asyncio
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ import asyncio
+
+ from sqlalchemy.orm import scoped_session
+
+ from alchemysession import AlchemySessionContainer
+ from mautrix_appservice import AppService
+
+ from .web import PublicBridgeWebsite, ProvisioningAPI
+ from .config import Config
+ from .bot import Bot
+ from .matrix import MatrixHandler
-from sqlalchemy.orm import scoped_session
-from alchemysession import AlchemySessionContainer
-from mautrix_appservice import AppService
class Context:
- def __init__(self, az, db, config, loop, bot, mx, session_container, public_website,
- provisioning_api):
- from .web import PublicBridgeWebsite, ProvisioningAPI
- from .config import Config
- from .bot import Bot
- from .matrix import MatrixHandler
-
+ def __init__(self, az: "AppService", db: "scoped_session", config: "Config",
+ loop: "asyncio.AbstractEventLoop", bot: "Bot", mx: "MatrixHandler",
+ session_container: "AlchemySessionContainer",
+ public_website: "PublicBridgeWebsite", provisioning_api: "ProvisioningAPI"):
self.az = az # type: AppService
self.db = db # type: scoped_session
self.config = config # type: Config
diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py
index 81bc0598..5a0baf70 100644
--- a/mautrix_telegram/db.py
+++ b/mautrix_telegram/db.py
@@ -42,6 +42,7 @@ class Portal(Base):
about = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
+
class Message(Base):
query = None # type: Query
__tablename__ = "message"
diff --git a/mautrix_telegram/formatter/from_matrix.py b/mautrix_telegram/formatter/from_matrix.py
index f98d3ad5..6619ef02 100644
--- a/mautrix_telegram/formatter/from_matrix.py
+++ b/mautrix_telegram/formatter/from_matrix.py
@@ -14,10 +14,10 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import (Optional, List, Tuple, Type, Callable, Dict, Any, Pattern, Deque, Match, TYPE_CHECKING)
from html import unescape
from html.parser import HTMLParser
from collections import deque
-from typing import Optional, List, Tuple, Type, Callable, Dict, Any
import math
import re
import logging
@@ -27,37 +27,40 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M
MessageEntityItalic, MessageEntityCode, MessageEntityPre,
MessageEntityBotCommand, TypeMessageEntity)
-from .. import user as u, puppet as pu, portal as po, context as c
+from .. import user as u, puppet as pu, portal as po
from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
trim_reply_fallback_text, html_to_unicode)
-log = logging.getLogger("mau.fmt.mx")
-should_bridge_plaintext_highlights = False
+if TYPE_CHECKING:
+ from ..context import Context
+
+log = logging.getLogger("mau.fmt.mx") # type: logging.Logger
+should_bridge_plaintext_highlights = False # type: bool
class MatrixParser(HTMLParser):
- mention_regex = re.compile("https://matrix.to/#/(@.+:.+)")
- room_regex = re.compile("https://matrix.to/#/(#.+:.+)")
+ mention_regex = re.compile("https://matrix.to/#/(@.+:.+)") # type: Pattern
+ room_regex = re.compile("https://matrix.to/#/(#.+:.+)") # type: Pattern
block_tags = ("br", "p", "pre", "blockquote",
"ol", "ul", "li",
"h1", "h2", "h3", "h4", "h5", "h6",
- "div", "hr", "table")
+ "div", "hr", "table") # type: Tuple[str, ...]
def __init__(self):
super().__init__()
- self.text = ""
- self.entities = []
- self._building_entities = {}
- self._list_counter = 0
- self._open_tags = deque()
- self._open_tags_meta = deque()
- self._line_is_new = True
- self._list_entry_is_new = False
+ self.text = "" # type: str
+ self.entities = [] # type: List[TypeMessageEntity]
+ self._building_entities = {} # type: Dict[str, TypeMessageEntity]
+ self._list_counter = 0 # type: int
+ self._open_tags = deque() # type: Deque[str]
+ self._open_tags_meta = deque() # type: Deque[Any]
+ self._line_is_new = True # type: bool
+ self._list_entry_is_new = False # type: bool
def _parse_url(self, url: str, args: Dict[str, Any]
) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]:
- mention = self.mention_regex.match(url)
+ mention = self.mention_regex.match(url) # type: Match
if mention:
mxid = mention.group(1)
user = (pu.Puppet.get_by_mxid(mxid)
@@ -72,7 +75,7 @@ class MatrixParser(HTMLParser):
else:
return None, None
- room = self.room_regex.match(url)
+ room = self.room_regex.match(url) # type: Match
if room:
username = po.Portal.get_username_from_mx_alias(room.group(1))
portal = po.Portal.find_by_username(username)
@@ -92,8 +95,8 @@ class MatrixParser(HTMLParser):
self._open_tags_meta.appendleft(0)
attrs = dict(attrs)
- entity_type = None
- args = {}
+ entity_type = None # type: type(TypeMessageEntity)
+ args = {} # type: Dict[str, Any]
if tag in ("strong", "b"):
entity_type = MessageEntityBold
elif tag in ("em", "i"):
@@ -243,12 +246,12 @@ class MatrixParser(HTMLParser):
self._newline(allow_multi=tag == "br")
-command_regex = re.compile(r"^!([A-Za-z0-9@]+)")
-not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)")
-plain_mention_regex = None
+command_regex = re.compile(r"^!([A-Za-z0-9@]+)") # type: Pattern
+not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)") # type: Pattern
+plain_mention_regex = None # type: Pattern
-def plain_mention_to_html(match):
+def plain_mention_to_html(match: Match) -> str:
puppet = pu.Puppet.find_by_displayname(match.group(2))
if puppet:
return (f"{match.group(1)}"
@@ -351,7 +354,7 @@ def plain_mention_to_text() -> Tuple[List[TypeMessageEntity], Callable[[str], st
return entities, replacer
-def init_mx(context: c.Context):
+def init_mx(context: "Context"):
global plain_mention_regex, should_bridge_plaintext_highlights
config = context.config
dn_template = config.get("bridge.displayname_template", "{displayname} (Telegram)")
diff --git a/mautrix_telegram/formatter/from_telegram.py b/mautrix_telegram/formatter/from_telegram.py
index 70a13a55..33f8a335 100644
--- a/mautrix_telegram/formatter/from_telegram.py
+++ b/mautrix_telegram/formatter/from_telegram.py
@@ -14,13 +14,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Optional, List, Tuple, TYPE_CHECKING
from html import escape
-from typing import Optional, List, Tuple
-
-try:
- from lxml.html.diff import htmldiff
-except ImportError:
- htmldiff = None # type: function
import logging
import re
@@ -33,16 +28,26 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
from mautrix_appservice import MatrixRequestError
from mautrix_appservice.intent_api import IntentAPI
-from .. import user as u, puppet as pu, portal as po, context as c
+from .. import user as u, puppet as pu, portal as po
from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
trim_reply_fallback_text, unicode_to_html)
-log = logging.getLogger("mau.fmt.tg")
-should_highlight_edits = False
+if TYPE_CHECKING:
+ from ..abstract_user import AbstractUser
+ from ..context import Context
+
+try:
+ from lxml.html.diff import htmldiff
+except ImportError:
+ htmldiff = None # type: function
-def telegram_reply_to_matrix(evt: Message, source: u.User) -> dict:
+log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
+should_highlight_edits = False # type: bool
+
+
+def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict:
if evt.reply_to_msg_id:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
@@ -78,7 +83,7 @@ async def _add_forward_header(source, text: str, html: Optional[str],
if not fwd_from_text:
user = await source.client.get_entity(PeerUser(fwd_from.from_id))
if user:
- fwd_from_text = pu.Puppet.get_displayname(user, format=False)
+ fwd_from_text = pu.Puppet.get_displayname(user, False)
fwd_from_html = f"{fwd_from_text}"
if not fwd_from_text:
@@ -110,8 +115,9 @@ def highlight_edits(new_html: str, old_html: str) -> str:
return new_html
-async def _add_reply_header(source: u.User, text: str, html: str, evt: Message, relates_to: dict,
- main_intent: IntentAPI, is_edit: bool) -> Tuple[str, str]:
+async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
+ relates_to: dict, main_intent: IntentAPI, is_edit: bool
+ ) -> Tuple[str, str]:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
else source.tgid)
@@ -142,7 +148,7 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
if is_edit and should_highlight_edits:
html = highlight_edits(html or escape(text), r_html_body)
- except (ValueError, KeyError, MatrixRequestError) as e:
+ except (ValueError, KeyError, MatrixRequestError):
r_sender_link = "unknown user"
r_displayname = "unknown user"
r_text_body = "Failed to fetch message"
@@ -154,8 +160,9 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
r_keyword = "In reply to" if not is_edit else "Edit to"
r_msg_link = f"{r_keyword}"
- html = (f"{r_msg_link} {r_sender_link}\n{r_html_body}
"
- + (html or escape(text)))
+ html = (
+ f"{r_msg_link} {r_sender_link}\n{r_html_body}
"
+ + (html or escape(text)))
lines = r_text_body.strip().split("\n")
text_with_quote = f"> <{r_displayname}> {lines.pop(0)}"
@@ -167,7 +174,8 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
return text_with_quote, html
-async def telegram_to_matrix(evt: Message, source: u.User, main_intent: Optional[IntentAPI] = None,
+async def telegram_to_matrix(evt: Message, source: "AbstractUser",
+ main_intent: Optional[IntentAPI] = None,
is_edit: bool = False, prefix_text: Optional[str] = None,
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]:
text = add_surrogates(evt.message)
@@ -320,6 +328,6 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
return False
-def init_tg(context: c.Context):
+def init_tg(context: "Context"):
global should_highlight_edits
should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"]
diff --git a/mautrix_telegram/formatter/util.py b/mautrix_telegram/formatter/util.py
index f464ffe5..2a296614 100644
--- a/mautrix_telegram/formatter/util.py
+++ b/mautrix_telegram/formatter/util.py
@@ -14,8 +14,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Optional, Pattern
from html import escape
-from typing import Optional
import struct
import re
@@ -47,7 +47,7 @@ def trim_reply_fallback_text(text: str) -> str:
html_reply_fallback_regex = re.compile("^"
r"[\s\S]+?"
- "")
+ "") # type: Pattern
def trim_reply_fallback_html(html: str) -> str:
diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py
index 28cf6796..8feed9f4 100644
--- a/mautrix_telegram/matrix.py
+++ b/mautrix_telegram/matrix.py
@@ -14,26 +14,23 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import List, Dict
+from typing import List, Dict, Tuple, Set, Match
import logging
import asyncio
import re
from mautrix_appservice import MatrixRequestError, IntentError
-from .user import User
-from .portal import Portal
-from .puppet import Puppet
-from .commands import CommandProcessor
+from . import user as u, portal as po, puppet as pu, commands as com
class MatrixHandler:
- log = logging.getLogger("mau.mx")
+ log = logging.getLogger("mau.mx") # type: logging.Logger
def __init__(self, context):
self.az, self.db, self.config, _, self.tgbot = context
- self.commands = CommandProcessor(context)
- self.previously_typing = []
+ self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
+ self.previously_typing = [] # type: List[str]
self.az.matrix_event_handler(self.handle_event)
@@ -53,68 +50,68 @@ class MatrixHandler:
except asyncio.TimeoutError:
self.log.exception("TimeoutError when trying to set avatar")
- async def handle_puppet_invite(self, room, puppet, inviter):
+ async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User):
intent = puppet.default_mxid_intent
- self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room}")
+ self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in():
await intent.error_and_leave(
- room, text="Please log in before inviting Telegram puppets.")
+ room_id, text="Please log in before inviting Telegram puppets.")
return
- portal = Portal.get_by_mxid(room)
+ portal = po.Portal.get_by_mxid(room_id)
if portal:
if portal.peer_type == "user":
await intent.error_and_leave(
- room, text="You can not invite additional users to private chats.")
+ room_id, text="You can not invite additional users to private chats.")
return
await portal.invite_telegram(inviter, puppet)
- await intent.join_room(room)
+ await intent.join_room(room_id)
return
try:
- members = await self.az.intent.get_room_members(room)
+ members = await self.az.intent.get_room_members(room_id)
except MatrixRequestError:
members = []
if self.az.bot_mxid not in members:
if len(members) > 1:
- await intent.error_and_leave(room, text=None, html=(
+ await intent.error_and_leave(room_id, text=None, html=(
f"Please invite "
f"the bridge bot "
f"first if you want to create a Telegram chat."))
return
- await intent.join_room(room)
- portal = Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
+ await intent.join_room(room_id)
+ portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
if portal.mxid:
try:
await intent.invite(portal.mxid, inviter.mxid)
- await intent.send_notice(room, text=None, html=(
+ await intent.send_notice(room_id, text=None, html=(
"You already have a private chat with me: "
f""
"Link to room"
""))
- await intent.leave_room(room)
+ await intent.leave_room(room_id)
return
except MatrixRequestError:
pass
- portal.mxid = room
+ portal.mxid = room_id
portal.save()
inviter.register_portal(portal)
- await intent.send_notice(room, "Portal to private chat created.")
+ await intent.send_notice(room_id, "po.Portal to private chat created.")
else:
- await intent.join_room(room)
- await intent.send_notice(room, "This puppet will remain inactive until a "
- "Telegram chat is created for this room.")
+ await intent.join_room(room_id)
+ await intent.send_notice(room_id, "This puppet will remain inactive until a "
+ "Telegram chat is created for this room.")
- async def accept_bot_invite(self, room, inviter):
+ async def accept_bot_invite(self, room_id: str, inviter: u.User):
tries = 0
while tries < 5:
try:
- await self.az.intent.join_room(room)
+ await self.az.intent.join_room(room_id)
break
- except (IntentError, MatrixRequestError) as e:
+ except (IntentError, MatrixRequestError):
tries += 1
wait_for_seconds = (tries + 1) * 10
if tries < 5:
- self.log.exception(f"Failed to join room {room} with bridge bot, "
+ self.log.exception(f"Failed to join room {room_id} with bridge bot, "
f"retrying in {wait_for_seconds} seconds...")
await asyncio.sleep(wait_for_seconds)
else:
@@ -123,81 +120,81 @@ class MatrixHandler:
if not inviter.whitelisted:
await self.az.intent.send_notice(
- room, text=None,
+ room_id, text=None,
html="You are not whitelisted to use this bridge.
"
"If you are the owner of this bridge, see the "
"bridge.permissions section in your config file.")
- await self.az.intent.leave_room(room)
+ await self.az.intent.leave_room(room_id)
- async def handle_invite(self, room, user, inviter):
- self.log.debug(f"{inviter} invited {user} to {room}")
- inviter = await User.get_by_mxid(inviter).ensure_started()
- if user == self.az.bot_mxid:
- return await self.accept_bot_invite(room, inviter)
+ async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str):
+ self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
+ inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started()
+ if user_id == self.az.bot_mxid:
+ return await self.accept_bot_invite(room_id, inviter)
elif not inviter.whitelisted:
return
- puppet = Puppet.get_by_mxid(user)
+ puppet = pu.Puppet.get_by_mxid(user_id)
if puppet:
- await self.handle_puppet_invite(room, puppet, inviter)
+ await self.handle_puppet_invite(room_id, puppet, inviter)
return
- user = User.get_by_mxid(user, create=False)
+ user = u.User.get_by_mxid(user_id, create=False)
if not user:
return
await user.ensure_started()
- portal = Portal.get_by_mxid(room)
+ portal = po.Portal.get_by_mxid(room_id)
if user and await user.has_full_access(allow_bot=True) and portal:
await portal.invite_telegram(inviter, user)
return
# The rest can probably be ignored
- async def handle_join(self, room, user, event_id):
- user = await User.get_by_mxid(user).ensure_started()
+ async def handle_join(self, room_id: str, user_id: str, event_id: str):
+ user = await u.User.get_by_mxid(user_id).ensure_started()
- portal = Portal.get_by_mxid(room)
+ portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
if not user.relaybot_whitelisted:
- await portal.main_intent.kick(room, user.mxid,
+ await portal.main_intent.kick(room_id, user.mxid,
"You are not whitelisted on this Telegram bridge.")
return
elif not await user.is_logged_in() and not portal.has_bot:
- await portal.main_intent.kick(room, user.mxid,
+ await portal.main_intent.kick(room_id, user.mxid,
"This chat does not have a bot relaying "
"messages for unauthenticated users.")
return
- self.log.debug(f"{user} joined {room}")
+ self.log.debug(f"{user} joined {room_id}")
if await user.is_logged_in() or portal.has_bot:
await portal.join_matrix(user, event_id)
- async def handle_part(self, room, user, sender, event_id):
- self.log.debug(f"{user} left {room}")
+ async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str):
+ self.log.debug(f"{user_id} left {room_id}")
- sender = User.get_by_mxid(sender, create=False)
+ sender = u.User.get_by_mxid(sender_mxid, create=False)
if not sender:
return
await sender.ensure_started()
- portal = Portal.get_by_mxid(room)
+ portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
- puppet = Puppet.get_by_mxid(user)
+ puppet = pu.Puppet.get_by_mxid(user_id)
if sender and puppet:
await portal.leave_matrix(puppet, sender, event_id)
- user = User.get_by_mxid(user, create=False)
+ user = u.User.get_by_mxid(user_id, create=False)
if not user:
return
await user.ensure_started()
if await user.is_logged_in() or portal.has_bot:
await portal.leave_matrix(user, sender, event_id)
- def is_command(self, message):
+ def is_command(self, message: dict) -> Tuple[bool, str]:
text = message.get("body", "")
prefix = self.config["bridge.command_prefix"]
is_command = text.startswith(prefix)
@@ -207,14 +204,14 @@ class MatrixHandler:
async def handle_message(self, room, sender, message, event_id):
is_command, text = self.is_command(message)
- sender = await User.get_by_mxid(sender).ensure_started()
+ sender = await u.User.get_by_mxid(sender).ensure_started()
if not sender.relaybot_whitelisted:
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
- " User is not whitelisted.")
+ " u.User is not whitelisted.")
return
self.log.debug(f"Received Matrix event \"{message}\" from {sender} in {room}")
- portal = Portal.get_by_mxid(room)
+ portal = po.Portal.get_by_mxid(room)
if not is_command and portal and (await sender.is_logged_in() or portal.has_bot):
await portal.handle_matrix_message(sender, message, event_id)
return
@@ -239,39 +236,44 @@ class MatrixHandler:
await self.commands.handle(room, sender, command, args, is_management,
is_portal=portal is not None)
- async def handle_redaction(self, room, sender, event_id):
- sender = await User.get_by_mxid(sender).ensure_started()
+ @staticmethod
+ async def handle_redaction(room_id: str, sender_mxid: str, event_id: str):
+ sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if not sender.relaybot_whitelisted:
return
- portal = Portal.get_by_mxid(room)
+ portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
await portal.handle_matrix_deletion(sender, event_id)
- async def handle_power_levels(self, room, sender, new, old):
- portal = Portal.get_by_mxid(room)
- sender = await User.get_by_mxid(sender).ensure_started()
+ @staticmethod
+ async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict):
+ portal = po.Portal.get_by_mxid(room_id)
+ sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
- async def handle_room_meta(self, type, room, sender, content):
- portal = Portal.get_by_mxid(room)
- sender = await User.get_by_mxid(sender).ensure_started()
+ @staticmethod
+ async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, content: dict):
+ portal = po.Portal.get_by_mxid(room_id)
+ sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
handler, content_key = {
"m.room.name": (portal.handle_matrix_title, "name"),
"m.room.topic": (portal.handle_matrix_about, "topic"),
"m.room.avatar": (portal.handle_matrix_avatar, "url"),
- }[type]
+ }[evt_type]
if content_key not in content:
return
await handler(sender, content[content_key])
- async def handle_room_pin(self, room, sender, new_events, old_events):
- portal = Portal.get_by_mxid(room)
- sender = await User.get_by_mxid(sender).ensure_started()
+ @staticmethod
+ async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str],
+ old_events: Set[str]):
+ portal = po.Portal.get_by_mxid(room_id)
+ sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
events = new_events - old_events
if len(events) > 0:
@@ -281,12 +283,14 @@ class MatrixHandler:
# All pinned events removed, remove pinned event in Telegram.
await portal.handle_matrix_pin(sender, None)
- async def handle_name_change(self, room, user, displayname, prev_displayname, event_id):
- portal = Portal.get_by_mxid(room)
+ @staticmethod
+ async def handle_name_change(room_id: str, user_id: str, displayname: str,
+ prev_displayname: str, event_id: str):
+ portal = po.Portal.get_by_mxid(room_id)
if not portal or not portal.has_bot:
return
- user = await User.get_by_mxid(user).ensure_started()
+ user = await u.User.get_by_mxid(user_id).ensure_started()
if await user.needs_relaybot(portal):
await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
@@ -296,25 +300,27 @@ class MatrixHandler:
for event_id, receipts in content.items()
for user_id in receipts.get("m.read", {})}
- async def handle_read_receipts(self, room_id: str, receipts: Dict[str, str]):
- portal = Portal.get_by_mxid(room_id)
+ @staticmethod
+ async def handle_read_receipts(room_id: str, receipts: Dict[str, str]):
+ portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
for user_id, event_id in receipts.items():
- user = await User.get_by_mxid(user_id).ensure_started()
+ user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in():
continue
await portal.mark_read(user, event_id)
- async def handle_presence(self, user: str, presence: str):
- user = await User.get_by_mxid(user).ensure_started()
+ @staticmethod
+ async def handle_presence(user_id: str, presence: str):
+ user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in():
return
await user.set_presence(presence == "online")
async def handle_typing(self, room_id: str, now_typing: List[str]):
- portal = Portal.get_by_mxid(room_id)
+ portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
@@ -324,7 +330,7 @@ class MatrixHandler:
if is_typing and was_typing:
continue
- user = await User.get_by_mxid(user_id).ensure_started()
+ user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in():
continue
@@ -332,38 +338,38 @@ class MatrixHandler:
self.previously_typing = now_typing
- def filter_matrix_event(self, event):
+ def filter_matrix_event(self, event: dict):
sender = event.get("sender", None)
if not sender:
return False
return (sender == self.az.bot_mxid
- or Puppet.get_id_from_mxid(sender) is not None)
+ or pu.Puppet.get_id_from_mxid(sender) is not None)
- async def try_handle_event(self, evt):
+ async def try_handle_event(self, evt: dict):
try:
await self.handle_event(evt)
except Exception:
self.log.exception("Error handling manually received Matrix event")
- async def handle_event(self, evt):
+ async def handle_event(self, evt: dict):
if self.filter_matrix_event(evt):
return
self.log.debug("Received event: %s", evt)
- type = evt.get("type", "m.unknown")
- room_id = evt.get("room_id", None)
- event_id = evt.get("event_id", None)
- sender = evt.get("sender", None)
- content = evt.get("content", {})
- if type == "m.room.member":
- state_key = evt["state_key"]
- prev_content = evt.get("unsigned", {}).get("prev_content", {})
- membership = content.get("membership", "")
- prev_membership = prev_content.get("membership", "leave")
+ evt_type = evt.get("type", "m.unknown") # type: str
+ room_id = evt.get("room_id", None) # type: str
+ event_id = evt.get("event_id", None) # type: str
+ sender = evt.get("sender", None) # type: str
+ content = evt.get("content", {}) # type: dict
+ if evt_type == "m.room.member":
+ state_key = evt["state_key"] # type: str
+ prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
+ membership = content.get("membership", "") # type: str
+ prev_membership = prev_content.get("membership", "leave") # type: str
if membership == prev_membership:
- match = re.compile("@(.+):(.+)").match(state_key)
- localpart = match.group(1)
- displayname = content.get("displayname", localpart)
- prev_displayname = prev_content.get("displayname", localpart)
+ match = re.compile("@(.+):(.+)").match(state_key) # type: Match
+ localpart = match.group(1) # type: str
+ displayname = content.get("displayname", localpart) # type: str
+ prev_displayname = prev_content.get("displayname", localpart) # type: str
if displayname != prev_displayname:
await self.handle_name_change(room_id, state_key, displayname,
prev_displayname, event_id)
@@ -373,26 +379,26 @@ class MatrixHandler:
await self.handle_part(room_id, state_key, sender, event_id)
elif membership == "join":
await self.handle_join(room_id, state_key, event_id)
- elif type in ("m.room.message", "m.sticker"):
- if type != "m.room.message":
- content["msgtype"] = type
+ elif evt_type in ("m.room.message", "m.sticker"):
+ if evt_type != "m.room.message":
+ content["msgtype"] = evt_type
await self.handle_message(room_id, sender, content, event_id)
- elif type == "m.room.redaction":
+ elif evt_type == "m.room.redaction":
await self.handle_redaction(room_id, sender, evt["redacts"])
- elif type == "m.room.power_levels":
+ elif evt_type == "m.room.power_levels":
await self.handle_power_levels(room_id, sender, evt["content"], evt["prev_content"])
- elif type in ("m.room.name", "m.room.avatar", "m.room.topic"):
- await self.handle_room_meta(type, room_id, sender, evt["content"])
- elif type == "m.room.pinned_events":
+ elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
+ await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
+ elif evt_type == "m.room.pinned_events":
new_events = set(evt["content"]["pinned"])
try:
old_events = set(evt["unsigned"]["prev_content"]["pinned"])
except KeyError:
old_events = set()
await self.handle_room_pin(room_id, sender, new_events, old_events)
- elif type == "m.receipt":
+ elif evt_type == "m.receipt":
await self.handle_read_receipts(room_id, self.parse_read_receipts(content))
- elif type == "m.presence":
+ elif evt_type == "m.presence":
await self.handle_presence(sender, content.get("presence", "offline"))
- elif type == "m.typing":
+ elif evt_type == "m.typing":
await self.handle_typing(room_id, content.get("user_ids", []))
diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py
index 9331c9bc..a4f80776 100644
--- a/mautrix_telegram/portal.py
+++ b/mautrix_telegram/portal.py
@@ -14,6 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Pattern, Dict, Tuple, Awaitable, TYPE_CHECKING
from collections import deque
from datetime import datetime
from string import Template
@@ -24,65 +25,82 @@ import mimetypes
import unicodedata
import hashlib
import logging
+import re
import magic
+from sqlalchemy import orm
from sqlalchemy.exc import IntegrityError, InvalidRequestError
from sqlalchemy.orm.exc import FlushError
from telethon.tl.functions.messages import *
from telethon.tl.functions.channels import *
-from telethon.tl.functions.messages import ReadHistoryRequest
+from telethon.tl.functions.messages import ReadHistoryRequest as ReadMessageHistoryRequest
from telethon.tl.functions.channels import ReadHistoryRequest as ReadChannelHistoryRequest
-from telethon.errors import *
+from telethon.errors import ChatAdminRequiredError, ChatNotModifiedError
from telethon.tl.types import *
-from mautrix_appservice import MatrixRequestError, IntentError
+from mautrix_appservice import MatrixRequestError, IntentError, AppService, IntentAPI
-from .db import Portal as DBPortal, Message as DBMessage
+from .context import Context
+from .db import Portal as DBPortal, Message as DBMessage, TelegramFile as DBTelegramFile
from . import puppet as p, user as u, formatter, util
+if TYPE_CHECKING:
+ from .bot import Bot
+ from .abstract_user import AbstractUser
+ from .config import Config
+ from .tgclient import MautrixTelegramClient
+
mimetypes.init()
-config = None
+config = None # type: Config
+
+TypeMessage = Union[Message, MessageService]
+TypeParticipant = Union[TypeChatParticipant, TypeChannelParticipant]
+DedupMXID = Tuple[str, int]
+InviteList = Union[str, List[str]]
class Portal:
- log = logging.getLogger("mau.portal")
- db = None
- az = None
- bot = None
- loop = None
- filter_mode = None
- filter_list = None
- bridge_notices = False
- alias_template = None
- mx_alias_regex = None
- hs_domain = None
- by_mxid = {}
- by_tgid = {}
+ log = logging.getLogger("mau.portal") # type: logging.Logger
+ db = None # type: orm.Session
+ az = None # type: AppService
+ bot = None # type: Bot
+ loop = None # type: asyncio.AbstractEventLoop
+ filter_mode = None # type: str
+ filter_list = None # type: List[str]
+ bridge_notices = False # type: bool
+ alias_template = None # type: str
+ mx_alias_regex = None # type: Pattern
+ hs_domain = None # type: str
+ by_mxid = {} # type: Dict[str, Portal]
+ by_tgid = {} # type: Dict[Tuple[int, int], Portal]
- def __init__(self, tgid, peer_type, tg_receiver=None, mxid=None, username=None,
- megagroup=False, title=None, about=None, photo_id=None, db_instance=None):
- self.mxid = mxid
- self.tgid = tgid
- self.tg_receiver = tg_receiver or tgid
- self.peer_type = peer_type
- self.username = username
- self.megagroup = megagroup
- self.title = title
- self.about = about
- self.photo_id = photo_id
- self._db_instance = db_instance
+ def __init__(self, tgid: int, peer_type: str, tg_receiver: Optional[int] = None,
+ mxid: Optional[str] = None, username: Optional[str] = None,
+ megagroup: Optional[bool] = False, title: Optional[str] = None,
+ about: Optional[str] = None, photo_id: Optional[str] = None,
+ db_instance: DBPortal = None):
+ self.mxid = mxid # type: str
+ self.tgid = tgid # type: int
+ self.tg_receiver = tg_receiver or tgid # type: int
+ self.peer_type = peer_type # type: str
+ self.username = username # type: str
+ self.megagroup = megagroup # type: bool
+ self.title = title # type: str
+ self.about = about # type: str
+ self.photo_id = photo_id # type: str
+ self._db_instance = db_instance # type: DBPortal
- self._main_intent = None
- self._room_create_lock = asyncio.Lock()
- self._temp_pinned_message_id = None
- self._temp_pinned_message_sender = None
+ self._main_intent = None # type: IntentAPI
+ self._room_create_lock = asyncio.Lock() # type: asyncio.Lock
+ self._temp_pinned_message_id = None # type: Optional[int]
+ self._temp_pinned_message_sender = None # type: Optional[p.Puppet]
- self._dedup = deque()
- self._dedup_mxid = {}
- self._dedup_action = deque()
+ self._dedup = deque() # type: deque
+ self._dedup_mxid = {} # type: Dict[str, DedupMXID]
+ self._dedup_action = deque() # type: deque
- self._send_locks = {}
+ self._send_locks = {} # type: Dict[int, asyncio.Lock]
if tgid:
self.by_tgid[self.tgid_full] = self
@@ -92,17 +110,17 @@ class Portal:
# region Propegrties
@property
- def tgid_full(self):
+ def tgid_full(self) -> Tuple[int, int]:
return self.tgid, self.tg_receiver
@property
- def tgid_log(self):
+ def tgid_log(self) -> str:
if self.tgid == self.tg_receiver:
- return self.tgid
+ return str(self.tgid)
return f"{self.tg_receiver}<->{self.tgid}"
@property
- def peer(self):
+ def peer(self) -> TypePeer:
if self.peer_type == "user":
return PeerUser(user_id=self.tgid)
elif self.peer_type == "chat":
@@ -111,11 +129,11 @@ class Portal:
return PeerChannel(channel_id=self.tgid)
@property
- def has_bot(self):
+ def has_bot(self) -> bool:
return self.bot and self.bot.is_in_chat(self.tgid)
@property
- def main_intent(self):
+ def main_intent(self) -> IntentAPI:
if not self._main_intent:
direct = self.peer_type == "user"
puppet = p.Puppet.get(self.tgid) if direct else None
@@ -125,7 +143,7 @@ class Portal:
# endregion
# region Filtering
- def allow_bridging(self, tgid=None):
+ def allow_bridging(self, tgid: Optional[int] = None) -> bool:
tgid = tgid or self.tgid
if self.peer_type == "user":
return True
@@ -139,7 +157,7 @@ class Portal:
# region Deduplication
@staticmethod
- def _hash_event(event):
+ def _hash_event(event: TypeMessage) -> str:
# Non-channel messages are unique per-user (wtf telegram), so we have no other choice than
# to deduplicate based on a hash of the message content.
@@ -165,48 +183,54 @@ class Portal:
.encode("utf-8")
).hexdigest()
- def is_duplicate_action(self, event):
- hash = self._hash_event(event) if self.peer_type != "channel" else event.id
- if hash in self._dedup_action:
+ def is_duplicate_action(self, event: TypeMessage) -> bool:
+ evt_hash = self._hash_event(event) if self.peer_type != "channel" else event.id
+ if evt_hash in self._dedup_action:
return True
- self._dedup_action.append(hash)
+ self._dedup_action.append(evt_hash)
if len(self._dedup_action) > 20:
self._dedup_action.popleft()
return False
- def update_duplicate(self, event, mxid=None, expected_mxid=None, force_hash=False):
- hash = self._hash_event(event) if self.peer_type != "channel" or force_hash else event.id
+ def update_duplicate(self, event: TypeMessage, mxid: DedupMXID = None,
+ expected_mxid: Optional[DedupMXID] = None, force_hash: bool = False
+ ) -> Optional[DedupMXID]:
+ evt_hash = self._hash_event(
+ event) if self.peer_type != "channel" or force_hash else event.id
try:
- found_mxid = self._dedup_mxid[hash]
+ found_mxid = self._dedup_mxid[evt_hash]
except KeyError:
- return 0, "None"
+ return "None", 0
if found_mxid != expected_mxid:
return found_mxid
- self._dedup_mxid[hash] = mxid
+ self._dedup_mxid[evt_hash] = mxid
return None
- def is_duplicate(self, event, mxid=None, force_hash=False):
- hash = self._hash_event(event) if self.peer_type != "channel" or force_hash else event.id
- if hash in self._dedup:
- return self._dedup_mxid[hash]
+ def is_duplicate(self, event: TypeMessage, mxid: DedupMXID = None, force_hash: bool = False
+ ) -> Optional[DedupMXID]:
+ evt_hash = (self._hash_event(event)
+ if self.peer_type != "channel" or force_hash
+ else event.id)
+ if evt_hash in self._dedup:
+ return self._dedup_mxid[evt_hash]
- self._dedup_mxid[hash] = mxid
- self._dedup.append(hash)
+ self._dedup_mxid[evt_hash] = mxid
+ self._dedup.append(evt_hash)
if len(self._dedup) > 20:
del self._dedup_mxid[self._dedup.popleft()]
return None
- def get_input_entity(self, user):
+ def get_input_entity(self, user: u.User) -> Awaitable[TypeInputPeer]:
return user.client.get_input_entity(self.peer)
# endregion
# region Matrix room info updating
- async def invite_to_matrix(self, users):
+ async def invite_to_matrix(self, users: InviteList):
if isinstance(users, str):
await self.main_intent.invite(self.mxid, users, check_cache=True)
elif isinstance(users, list):
@@ -215,8 +239,10 @@ class Portal:
else:
raise ValueError("Invalid invite identifier given to invite_matrix()")
- async def update_matrix_room(self, user, entity, direct, puppet=None,
- levels=None, users=None, participants=None):
+ async def update_matrix_room(self, user: "AbstractUser", entity: TypeChat, direct: bool,
+ puppet: p.Puppet = None, levels: dict = None,
+ users: List[User] = None,
+ participants: List[TypeParticipant] = None):
if not direct:
await self.update_info(user, entity)
if not users or not participants:
@@ -229,8 +255,9 @@ class Portal:
await puppet.update_info(user, entity)
await puppet.intent.join_room(self.mxid)
- async def create_matrix_room(self, user, entity=None, invites=None, update_if_exists=True,
- synchronous=False):
+ async def create_matrix_room(self, user: "AbstractUser", entity: TypeChat = None,
+ invites: InviteList = None, update_if_exists: bool = True,
+ synchronous: bool = False) -> Optional[str]:
if self.mxid:
if update_if_exists:
if not entity:
@@ -245,7 +272,8 @@ class Portal:
async with self._room_create_lock:
return await self._create_matrix_room(user, entity, invites)
- async def _create_matrix_room(self, user, entity, invites):
+ async def _create_matrix_room(self, user: "AbstractUser", entity: TypeChat, invites: InviteList
+ ) -> Optional[str]:
direct = self.peer_type == "user"
if self.mxid:
@@ -310,7 +338,7 @@ class Portal:
participants=participants),
loop=self.loop)
- def _get_base_power_levels(self, levels=None, entity=None):
+ def _get_base_power_levels(self, levels: dict = None, entity: TypeChat = None) -> dict:
levels = levels or {}
power_level_requirement = (0 if self.peer_type == "chat" and not entity.admins_enabled
else 50)
@@ -336,27 +364,27 @@ class Portal:
return levels
@property
- def alias(self):
+ def alias(self) -> Optional[str]:
if not self.username:
return None
return f"#{self._get_alias_localpart()}:{self.hs_domain}"
- def _get_alias_localpart(self, username=None):
+ def _get_alias_localpart(self, username: Optional[str] = None) -> Optional[str]:
username = username or self.username
if not username:
return None
return self.alias_template.format(groupname=username)
- def add_bot_chat(self, entity):
- if self.bot and entity.id == self.bot.tgid:
+ def add_bot_chat(self, bot: User):
+ if self.bot and bot.id == self.bot.tgid:
self.bot.add_chat(self.tgid, self.peer_type)
return
- user = u.User.get_by_tgid(entity.id)
+ user = u.User.get_by_tgid(bot.id)
if user and user.is_bot:
user.register_portal(self)
- async def sync_telegram_users(self, source, users):
+ async def sync_telegram_users(self, source: "AbstractUser", users: List[User]):
allowed_tgids = set()
for entity in users:
puppet = p.Puppet.get(entity.id)
@@ -398,7 +426,7 @@ class Portal:
"You had left this Telegram chat.")
continue
- async def add_telegram_user(self, user_id, source=None):
+ async def add_telegram_user(self, user_id: int, source: Optional["AbstractUser"] = None):
puppet = p.Puppet.get(user_id)
if source:
entity = await source.client.get_entity(PeerUser(user_id))
@@ -410,7 +438,7 @@ class Portal:
user.register_portal(self)
await self.invite_to_matrix(user.mxid)
- async def delete_telegram_user(self, user_id, sender):
+ async def delete_telegram_user(self, user_id: int, sender: p.Puppet):
puppet = p.Puppet.get(user_id)
user = u.User.get_by_tgid(user_id)
kick_message = (f"Kicked by {sender.displayname}"
@@ -424,7 +452,7 @@ class Portal:
user.unregister_portal(self)
await self.main_intent.kick(self.mxid, user.mxid, kick_message)
- async def update_info(self, user, entity=None):
+ async def update_info(self, user: "AbstractUser", entity: TypeChat = None):
if self.peer_type == "user":
self.log.warning(f"Called update_info() for direct chat portal {self.tgid_log}")
return
@@ -448,7 +476,7 @@ class Portal:
if changed:
self.save()
- async def update_username(self, username, save=False):
+ async def update_username(self, username: str, save: bool = False) -> bool:
if self.username != username:
if self.username:
await self.main_intent.remove_room_alias(self._get_alias_localpart())
@@ -465,7 +493,7 @@ class Portal:
return True
return False
- async def update_about(self, about, save=False):
+ async def update_about(self, about: str, save: bool = False) -> bool:
if self.about != about:
self.about = about
await self.main_intent.set_room_topic(self.mxid, self.about)
@@ -474,7 +502,7 @@ class Portal:
return True
return False
- async def update_title(self, title, save=False):
+ async def update_title(self, title: str, save: bool = False) -> bool:
if self.title != title:
self.title = title
await self.main_intent.set_room_name(self.mxid, self.title)
@@ -484,17 +512,18 @@ class Portal:
return False
@staticmethod
- def _get_largest_photo_size(photo):
+ def _get_largest_photo_size(photo: Photo) -> TypePhotoSize:
return max(photo.sizes, key=(lambda photo2: (
len(photo2.bytes) if isinstance(photo2, PhotoCachedSize) else photo2.size)))
- async def remove_avatar(self, user, save=False):
+ async def remove_avatar(self, _: "AbstractUser", save: bool = False):
await self.main_intent.set_room_avatar(self.mxid, None)
self.photo_id = None
if save:
self.save()
- async def update_avatar(self, user, photo, save=False):
+ async def update_avatar(self, user: "AbstractUser", photo: FileLocation,
+ save: bool = False) -> bool:
photo_id = f"{photo.volume_id}-{photo.local_id}"
if self.photo_id != photo_id:
file = await util.transfer_file_to_matrix(self.db, user.client, self.main_intent,
@@ -507,7 +536,9 @@ class Portal:
return True
return False
- async def _get_users(self, user, entity):
+ async def _get_users(self, user: "AbstractUser", entity: Union[TypeInputPeer, InputUser,
+ TypeChat, TypeUser]
+ ) -> Tuple[List[TypeUser], List[TypeParticipant]]:
if self.peer_type == "chat":
chat = await user.client(GetFullChatRequest(chat_id=self.tgid))
return chat.users, chat.full_chat.participants.participants
@@ -544,7 +575,7 @@ class Portal:
elif self.peer_type == "user":
return [entity], []
- async def get_invite_link(self, user):
+ async def get_invite_link(self, user: u.User) -> str:
if self.peer_type == "user":
raise ValueError("You can't invite users to private chats.")
elif self.peer_type == "chat":
@@ -562,7 +593,7 @@ class Portal:
return link.link
- async def get_authenticated_matrix_users(self):
+ async def get_authenticated_matrix_users(self) -> List[u.User]:
try:
members = await self.main_intent.get_room_members(self.mxid)
except MatrixRequestError:
@@ -573,13 +604,14 @@ class Portal:
if p.Puppet.get_id_from_mxid(member) or member == self.main_intent.mxid:
continue
user = await u.User.get_by_mxid(member).ensure_started()
- if (has_bot and user.relaybot_whitelisted) or await user.has_full_access(
- allow_bot=True):
+ authenticated_through_bot = has_bot and user.relaybot_whitelisted
+ if authenticated_through_bot or await user.has_full_access(allow_bot=True):
authenticated.append(user)
return authenticated
@staticmethod
- async def cleanup_room(intent, room_id, message="Portal deleted", puppets_only=False):
+ async def cleanup_room(intent: IntentAPI, room_id: str, message: str = "Portal deleted",
+ puppets_only: bool = False):
try:
members = await intent.get_room_members(room_id)
except MatrixRequestError:
@@ -608,7 +640,7 @@ class Portal:
# region Matrix event handling
@staticmethod
- def _get_file_meta(body, mime):
+ def _get_file_meta(body: str, mime: str) -> str:
try:
current_extension = body[body.rindex("."):]
if mimetypes.types_map[current_extension] == mime:
@@ -620,7 +652,8 @@ class Portal:
else:
return ""
- async def _get_state_change_message(self, event, user, arguments=None):
+ async def _get_state_change_message(self, event: str, user: u.User,
+ arguments: Optional[dict] = None) -> Optional[dict]:
tpl = config[f"bridge.state_event_formats.{event}"]
if len(tpl) == 0:
# Empty format means they don't want the message
@@ -637,7 +670,8 @@ class Portal:
"formatted_body": message,
}
- async def name_change_matrix(self, user, displayname, prev_displayname, event_id):
+ async def name_change_matrix(self, user: u.User, displayname: str, prev_displayname: str,
+ event_id: str):
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message(
"name_change", user,
@@ -650,15 +684,15 @@ class Portal:
space = self.tgid if self.peer_type == "channel" else self.bot.tgid
self.is_duplicate(response, (event_id, space))
- async def get_displayname(self, user):
+ async def get_displayname(self, user: u.User) -> str:
return (await self.main_intent.get_displayname(self.mxid, user.mxid)
or user.mxid_localpart)
- def set_typing(self, user, typing=True, action=SendMessageTypingAction):
- return user.client(
- SetTypingRequest(self.peer, action() if typing else SendMessageCancelAction()))
+ def set_typing(self, user: u.User, typing: bool = True, action=SendMessageTypingAction):
+ return user.client(SetTypingRequest(
+ self.peer, action() if typing else SendMessageCancelAction()))
- async def mark_read(self, user, event_id):
+ async def mark_read(self, user: u.User, event_id: str):
if user.is_bot:
return
space = self.tgid if self.peer_type == "channel" else user.tgid
@@ -671,9 +705,9 @@ class Portal:
await user.client(ReadChannelHistoryRequest(
channel=await self.get_input_entity(user), max_id=message.tgid))
else:
- await user.client(ReadHistoryRequest(peer=self.peer, max_id=message.tgid))
+ await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid))
- async def leave_matrix(self, user, source, event_id):
+ async def leave_matrix(self, user: u.User, source: u.User, event_id: str):
if await user.needs_relaybot(self):
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message("leave", user)
@@ -709,7 +743,7 @@ class Portal:
channel = await self.get_input_entity(user)
await user.client(LeaveChannelRequest(channel=channel))
- async def join_matrix(self, user, event_id):
+ async def join_matrix(self, user: u.User, event_id: str):
if await user.needs_relaybot(self):
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message("join", user)
@@ -728,7 +762,7 @@ class Portal:
# We'll just assume the user is already in the chat.
pass
- async def _apply_msg_format(self, sender, msgtype, message):
+ async def _apply_msg_format(self, sender: u.User, msgtype: str, message: dict):
if "formatted_body" not in message:
message["format"] = "org.matrix.custom.html"
message["formatted_body"] = escape_html(message.get("body", ""))
@@ -743,7 +777,7 @@ class Portal:
message=body)
message["formatted_body"] = Template(tpl).safe_substitute(tpl_args)
- async def _preprocess_matrix_message(self, sender, use_relaybot, message):
+ async def _pre_process_matrix_message(self, sender: u.User, use_relaybot: bool, message: dict):
msgtype = message.get("msgtype", "m.text")
if msgtype == "m.emote":
await self._apply_msg_format(sender, msgtype, message)
@@ -751,7 +785,8 @@ class Portal:
elif use_relaybot:
await self._apply_msg_format(sender, msgtype, message)
- def _matrix_event_to_entities(self, event):
+ @staticmethod
+ def _matrix_event_to_entities(event: dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]:
try:
if event.get("format", None) == "org.matrix.custom.html":
message, entities = formatter.matrix_to_telegram(event["formatted_body"])
@@ -761,32 +796,33 @@ class Portal:
message, entities = None, None
return message, entities
- def require_send_lock(self, id):
- if id is None:
- return None
+ def require_send_lock(self, user_id: int) -> asyncio.Lock:
+ if user_id is None:
+ raise ValueError("Required send lock for none id")
try:
- return self._send_locks[id]
+ return self._send_locks[user_id]
except KeyError:
- self._send_locks[id] = asyncio.Lock()
- return self._send_locks[id]
+ self._send_locks[user_id] = asyncio.Lock()
+ return self._send_locks[user_id]
- def optional_send_lock(self, id):
- if id is None:
+ def optional_send_lock(self, user_id: int) -> Optional[asyncio.Lock]:
+ if user_id is None:
return None
try:
- return self._send_locks[id]
+ return self._send_locks[user_id]
except KeyError:
return None
- async def _handle_matrix_text(self, sender_id, event_id, space, client, message, reply_to):
+ async def _handle_matrix_text(self, sender_id: int, event_id: str, space: int,
+ client: "MautrixTelegramClient", message: dict, reply_to: int):
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_message(self.peer, message, reply_to=reply_to,
parse_mode=self._matrix_event_to_entities)
self._add_telegram_message_to_db(event_id, space, response)
- async def _handle_matrix_file(self, type, sender_id, event_id, space, client, message,
- reply_to):
+ async def _handle_matrix_file(self, msgtype: str, sender_id: int, event_id: str, space: int,
+ client: "MautrixTelegramClient", message: dict, reply_to: int):
file = await self.main_intent.download_file(message["url"])
info = message.get("info", {})
@@ -794,7 +830,7 @@ class Portal:
w, h = None, None
- if type == "m.sticker":
+ if msgtype == "m.sticker":
if mime != "image/gif":
mime, file, w, h = util.convert_image(file, source_mime=mime, target_type="webp")
else:
@@ -812,14 +848,16 @@ class Portal:
caption = message["body"] if message["body"] != file_name else None
- media = await client.upload_file(file, mime, attributes, file_name)
+ media = await client.upload_file_direct(file, mime, attributes, file_name)
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_media(self.peer, media, reply_to=reply_to,
caption=caption)
self._add_telegram_message_to_db(event_id, space, response)
- async def _handle_matrix_location(self, sender_id, event_id, space, client, message, reply_to):
+ async def _handle_matrix_location(self, sender_id: int, event_id: str, space: int,
+ client: "MautrixTelegramClient", message: dict,
+ reply_to: int):
try:
lat, long = message["geo_uri"][len("geo:"):].split(",")
lat, long = float(lat), float(long)
@@ -827,7 +865,7 @@ class Portal:
self.log.exception("Failed to parse location")
return None
message, entities = self._matrix_event_to_entities(message)
- media = MessageMediaGeo(geo=GeoPoint(lat, long))
+ media = MessageMediaGeo(geo=GeoPoint(lat, long, access_hash=0))
lock = self.require_send_lock(sender_id)
async with lock:
@@ -835,7 +873,7 @@ class Portal:
caption=message, entities=entities)
self._add_telegram_message_to_db(event_id, space, response)
- def _add_telegram_message_to_db(self, event_id, space, response):
+ def _add_telegram_message_to_db(self, event_id: str, space: int, response: TypeMessage):
self.log.debug("Handled Matrix message: %s", response)
self.is_duplicate(response, (event_id, space))
self.db.add(DBMessage(
@@ -859,21 +897,21 @@ class Portal:
reply_to = formatter.matrix_reply_to_telegram(message, space, room_id=self.mxid)
message["mxtg_filename"] = message["body"]
- await self._preprocess_matrix_message(sender, not logged_in, message)
- type = message["msgtype"]
+ await self._pre_process_matrix_message(sender, not logged_in, message)
+ msgtype = message["msgtype"]
- if type == "m.text" or (self.bridge_notices and type == "m.notice"):
+ if msgtype == "m.text" or (self.bridge_notices and msgtype == "m.notice"):
await self._handle_matrix_text(sender_id, event_id, space, client, message, reply_to)
- elif type == "m.location":
+ elif msgtype == "m.location":
await self._handle_matrix_location(sender_id, event_id, space, client, message,
reply_to)
- elif type in ("m.sticker", "m.image", "m.file", "m.audio", "m.video"):
- await self._handle_matrix_file(type, sender_id, event_id, space, client, message,
+ elif msgtype in ("m.sticker", "m.image", "m.file", "m.audio", "m.video"):
+ await self._handle_matrix_file(msgtype, sender_id, event_id, space, client, message,
reply_to)
else:
self.log.debug(f"Unhandled Matrix event: {message}")
- async def handle_matrix_pin(self, sender, pinned_message):
+ async def handle_matrix_pin(self, sender: u.User, pinned_message: Optional[str]):
if self.peer_type != "channel":
return
try:
@@ -887,7 +925,7 @@ class Portal:
except ChatNotModifiedError:
pass
- async def handle_matrix_deletion(self, deleter, event_id):
+ async def handle_matrix_deletion(self, deleter: u.User, event_id: str):
deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
space = self.tgid if self.peer_type == "channel" else deleter.tgid
message = DBMessage.query.filter(DBMessage.mxid == event_id,
@@ -897,7 +935,7 @@ class Portal:
return
await deleter.client.delete_messages(self.peer, [message.tgid])
- async def _update_telegram_power_level(self, sender, user_id, level):
+ async def _update_telegram_power_level(self, sender: u.User, user_id: int, level: int):
if self.peer_type == "chat":
await sender.client(EditChatAdminRequest(
chat_id=self.tgid, user_id=user_id, is_admin=level >= 50))
@@ -913,7 +951,8 @@ class Portal:
EditAdminRequest(channel=await self.get_input_entity(sender),
user_id=user_id, admin_rights=rights))
- async def handle_matrix_power_levels(self, sender, new_users, old_users):
+ async def handle_matrix_power_levels(self, sender: u.User, new_users: Dict[str, int],
+ old_users: Dict[str, int]):
# TODO handle all power level changes and bridge exact admin rights to supergroups/channels
for user, level in new_users.items():
if not user or user == self.main_intent.mxid or user == sender.mxid:
@@ -929,7 +968,7 @@ class Portal:
if user not in old_users or level != old_users[user]:
await self._update_telegram_power_level(sender, user_id, level)
- async def handle_matrix_about(self, sender, about):
+ async def handle_matrix_about(self, sender: u.User, about: str):
if self.peer_type not in {"channel"}:
return
channel = await self.get_input_entity(sender)
@@ -937,7 +976,7 @@ class Portal:
self.about = about
self.save()
- async def handle_matrix_title(self, sender, title):
+ async def handle_matrix_title(self, sender: u.User, title: str):
if self.peer_type not in {"chat", "channel"}:
return
@@ -950,7 +989,7 @@ class Portal:
self.title = title
self.save()
- async def handle_matrix_avatar(self, sender, url):
+ async def handle_matrix_avatar(self, sender: u.User, url: str):
if self.peer_type not in {"chat", "channel"}:
# Invalid peer type
return
@@ -958,7 +997,7 @@ class Portal:
file = await self.main_intent.download_file(url)
mime = magic.from_buffer(file, mime=True)
ext = mimetypes.guess_extension(mime)
- uploaded = await sender.client.upload_file(file, file_name=f"avatar{ext}")
+ uploaded = await sender.client.upload_file_direct(file, file_name=f"avatar{ext}")
photo = InputChatUploadedPhoto(file=uploaded)
if self.peer_type == "chat":
@@ -977,7 +1016,7 @@ class Portal:
self.save()
break
- def _register_outgoing_actions_for_dedup(self, response):
+ def _register_outgoing_actions_for_dedup(self, response: TypeUpdates):
for update in response.updates:
check_dedup = (isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage))
and isinstance(update.message, MessageService))
@@ -987,7 +1026,7 @@ class Portal:
# endregion
# region Telegram chat info updating
- async def _get_telegram_users_in_matrix_room(self):
+ async def _get_telegram_users_in_matrix_room(self) -> List[int]:
user_tgids = set()
user_mxids = await self.main_intent.get_room_members(self.mxid, ("join", "invite"))
for user in user_mxids:
@@ -1001,13 +1040,13 @@ class Portal:
user_tgids.add(puppet_id)
return list(user_tgids)
- async def upgrade_telegram_chat(self, source):
+ async def upgrade_telegram_chat(self, source: u.User):
if self.peer_type != "chat":
raise ValueError("Only normal group chats are upgradable to supergroups.")
- updates = await source.client(MigrateChatRequest(chat_id=self.tgid))
+ response = await source.client(MigrateChatRequest(chat_id=self.tgid))
entity = None
- for chat in updates.chats:
+ for chat in response.chats:
if isinstance(chat, Channel):
entity = chat
break
@@ -1017,7 +1056,7 @@ class Portal:
self.migrate_and_save(entity.id)
await self.update_info(source, entity)
- async def set_telegram_username(self, source, username):
+ async def set_telegram_username(self, source: u.User, username: str):
if self.peer_type != "channel":
raise ValueError("Only channels and supergroups have usernames.")
await source.client(
@@ -1025,7 +1064,7 @@ class Portal:
if await self.update_username(username):
self.save()
- async def create_telegram_chat(self, source, supergroup=False):
+ async def create_telegram_chat(self, source: u.User, supergroup: bool = False):
if not self.mxid:
raise ValueError("Can't create Telegram chat for portal without Matrix room.")
elif self.tgid:
@@ -1036,13 +1075,13 @@ class Portal:
raise ValueError("Not enough Telegram users to create a chat")
if self.peer_type == "chat":
- updates = await source.client(CreateChatRequest(title=self.title, users=invites))
- entity = updates.chats[0]
+ response = await source.client(CreateChatRequest(title=self.title, users=invites))
+ entity = response.chats[0]
elif self.peer_type == "channel":
- updates = await source.client(CreateChannelRequest(title=self.title,
- about=self.about or "",
- megagroup=supergroup))
- entity = updates.chats[0]
+ response = await source.client(CreateChannelRequest(title=self.title,
+ about=self.about or "",
+ megagroup=supergroup))
+ entity = response.chats[0]
await source.client(InviteToChannelRequest(
channel=await source.client.get_input_entity(entity),
users=invites))
@@ -1066,7 +1105,7 @@ class Portal:
await self.main_intent.set_power_levels(self.mxid, levels)
await self.handle_matrix_power_levels(source, levels["users"], {})
- async def invite_telegram(self, source, puppet):
+ async def invite_telegram(self, source: u.User, puppet: Union[p.Puppet, "AbstractUser"]):
if self.peer_type == "chat":
await source.client(
AddChatUserRequest(chat_id=self.tgid, user_id=puppet.tgid, fwd_limit=0))
@@ -1078,16 +1117,18 @@ class Portal:
# endregion
# region Telegram event handling
- async def handle_telegram_typing(self, user, event):
+ async def handle_telegram_typing(self, user: p.Puppet,
+ _: Union[UpdateUserTyping, UpdateChatUserTyping]):
if self.mxid:
await user.intent.set_typing(self.mxid, is_typing=True)
- def get_external_url(self, evt: Message):
+ def get_external_url(self, evt: Message) -> Optional[str]:
if self.peer_type == "channel" and self.username is not None:
return f"https://t.me/{self.username}/{evt.id}"
return None
- async def handle_telegram_photo(self, source: u.User, intent, evt: Message, relates_to=None):
+ async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message,
+ relates_to=None):
largest_size = self._get_largest_photo_size(evt.media.photo)
file = await util.transfer_file_to_matrix(self.db, source.client, intent,
largest_size.location)
@@ -1117,7 +1158,7 @@ class Portal:
external_url=self.get_external_url(evt))
@staticmethod
- def _parse_telegram_document_attributes(attributes):
+ def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> dict:
attrs = {
"name": None,
"mime_type": None,
@@ -1138,7 +1179,8 @@ class Portal:
return attrs
@staticmethod
- def _parse_telegram_document_meta(evt, file, attrs):
+ def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: dict
+ ) -> Tuple[dict, str]:
document = evt.media.document
name = evt.message or attrs["name"]
if attrs["is_sticker"]:
@@ -1170,7 +1212,9 @@ class Portal:
return info, name
- async def handle_telegram_document(self, source, intent, evt: Message, relates_to=None):
+ async def handle_telegram_document(self, source: "AbstractUser", intent: IntentAPI,
+ evt: Message,
+ relates_to: dict = None) -> Optional[dict]:
document = evt.media.document
attrs = self._parse_telegram_document_attributes(document.attributes)
@@ -1207,7 +1251,8 @@ class Portal:
kwargs["file_type"] = "m.file"
return await intent.send_file(**kwargs)
- def handle_telegram_location(self, source, intent, evt, relates_to=None):
+ def handle_telegram_location(self, _: "AbstractUser", intent: IntentAPI, evt: Message,
+ relates_to: dict = None) -> Awaitable[dict]:
location = evt.media.geo
long = location.long
lat = location.lat
@@ -1234,7 +1279,8 @@ class Portal:
"m.relates_to": relates_to or None,
}, timestamp=evt.date, external_url=self.get_external_url(evt))
- async def handle_telegram_text(self, source, intent, is_bot, evt):
+ async def handle_telegram_text(self, source: "AbstractUser", intent: IntentAPI, is_bot: bool,
+ evt: Message) -> dict:
self.log.debug(f"Sending {evt.message} to {self.mxid} by {intent.mxid}")
text, html, relates_to = await formatter.telegram_to_matrix(evt, source, self.main_intent)
await intent.set_typing(self.mxid, is_typing=False)
@@ -1243,7 +1289,7 @@ class Portal:
msgtype=msgtype, timestamp=evt.date,
external_url=self.get_external_url(evt))
- async def handle_telegram_edit(self, source, sender, evt):
+ async def handle_telegram_edit(self, source: "AbstractUser", sender: p.Puppet, evt: Message):
if not self.mxid:
return
elif not config["bridge.edits_as_replies"]:
@@ -1290,7 +1336,7 @@ class Portal:
.update({"mxid": mxid})
self.db.commit()
- async def handle_telegram_message(self, source, sender, evt):
+ async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet, evt: Message):
if not self.mxid:
await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=False)
@@ -1373,19 +1419,21 @@ class Portal:
self.db.rollback()
await intent.redact(self.mxid, mxid)
- async def _create_room_on_action(self, source, action):
+ async def _create_room_on_action(self, source: "AbstractUser",
+ action: TypeMessageAction) -> bool:
if source.is_relaybot:
return False
create_and_exit = (MessageActionChatCreate, MessageActionChannelCreate)
create_and_continue = (MessageActionChatAddUser, MessageActionChatJoinedByLink)
- if isinstance(action, create_and_exit + create_and_continue):
+ if isinstance(action, create_and_exit) or isinstance(action, create_and_continue):
await self.create_matrix_room(source, invites=[source.mxid],
update_if_exists=isinstance(action, create_and_exit))
if not isinstance(action, create_and_continue):
return False
return True
- async def handle_telegram_action(self, source, sender, update):
+ async def handle_telegram_action(self, source: "AbstractUser", sender: p.Puppet,
+ update: MessageService):
action = update.action
should_ignore = ((not self.mxid and not await self._create_room_on_action(source, action))
or self.is_duplicate_action(update))
@@ -1415,7 +1463,7 @@ class Portal:
else:
self.log.debug("Unhandled Telegram action in %s: %s", self.title, action)
- async def set_telegram_admin(self, user_id):
+ async def set_telegram_admin(self, user_id: int):
puppet = p.Puppet.get(user_id)
user = await u.User.get_by_tgid(user_id)
@@ -1426,7 +1474,7 @@ class Portal:
levels["users"][puppet.mxid] = 50
await self.main_intent.set_power_levels(self.mxid, levels)
- async def receive_telegram_pin_sender(self, sender):
+ async def receive_telegram_pin_sender(self, sender: p.Puppet):
self._temp_pinned_message_sender = sender
if self._temp_pinned_message_id:
await self.update_telegram_pin()
@@ -1434,25 +1482,25 @@ class Portal:
async def update_telegram_pin(self):
intent = (self._temp_pinned_message_sender.intent
if self._temp_pinned_message_sender else self.main_intent)
- id = self._temp_pinned_message_id
+ msg_id = self._temp_pinned_message_id
self._temp_pinned_message_id = None
self._temp_pinned_message_sender = None
- message = DBMessage.query.get((id, self.tgid))
+ message = DBMessage.query.get((msg_id, self.tgid))
if message:
await intent.set_pinned_messages(self.mxid, [message.mxid])
else:
await intent.set_pinned_messages(self.mxid, [])
- async def receive_telegram_pin_id(self, id):
- if id == 0:
+ async def receive_telegram_pin_id(self, msg_id: int):
+ if msg_id == 0:
return await self.update_telegram_pin()
- self._temp_pinned_message_id = id
+ self._temp_pinned_message_id = msg_id
if self._temp_pinned_message_sender:
await self.update_telegram_pin()
@staticmethod
- def _get_level_from_participant(participant, _):
+ def _get_level_from_participant(participant: TypeParticipant, _) -> int:
# TODO use the power level requirements to get better precision in channels
if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)):
return 50
@@ -1461,7 +1509,8 @@ class Portal:
return 0
@staticmethod
- def _participant_to_power_levels(levels, user, new_level, bot_level):
+ def _participant_to_power_levels(levels: dict, user: Union[u.User, p.Puppet], new_level: int,
+ bot_level: int) -> bool:
new_level = min(new_level, bot_level)
default_level = levels["users_default"] if "users_default" in levels else 0
try:
@@ -1473,7 +1522,7 @@ class Portal:
return True
return False
- def _get_bot_level(self, levels):
+ def _get_bot_level(self, levels: dict) -> int:
try:
return levels["users"][self.main_intent.mxid]
except KeyError:
@@ -1483,7 +1532,7 @@ class Portal:
return 0
@staticmethod
- def _get_powerlevel_level(levels):
+ def _get_powerlevel_level(levels: dict) -> int:
try:
return levels["events"]["m.room.power_levels"]
except KeyError:
@@ -1492,7 +1541,8 @@ class Portal:
except KeyError:
return 50
- def _participants_to_power_levels(self, participants, levels):
+ def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: dict
+ ) -> bool:
bot_level = self._get_bot_level(levels)
if bot_level < self._get_powerlevel_level(levels):
return False
@@ -1517,13 +1567,14 @@ class Portal:
bot_level) or changed
return changed
- async def update_telegram_participants(self, participants, levels=None):
+ async def update_telegram_participants(self, participants: List[TypeParticipant],
+ levels: dict = None):
if not levels:
levels = await self.main_intent.get_power_levels(self.mxid)
if self._participants_to_power_levels(participants, levels):
await self.main_intent.set_power_levels(self.mxid, levels)
- async def set_telegram_admins_enabled(self, enabled):
+ async def set_telegram_admins_enabled(self, enabled: bool):
level = 50 if enabled else 10
levels = await self.main_intent.get_power_levels(self.mxid)
levels["invite"] = level
@@ -1535,17 +1586,17 @@ class Portal:
# region Database conversion
@property
- def db_instance(self):
+ def db_instance(self) -> DBPortal:
if not self._db_instance:
self._db_instance = self.new_db_instance()
return self._db_instance
- def new_db_instance(self):
+ def new_db_instance(self) -> DBPortal:
return DBPortal(tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
mxid=self.mxid, username=self.username, megagroup=self.megagroup,
title=self.title, about=self.about, photo_id=self.photo_id)
- def migrate_and_save(self, new_id):
+ def migrate_and_save(self, new_id: int):
existing = DBPortal.query.get(self.tgid_full)
if existing:
self.db.delete(existing)
@@ -1580,7 +1631,7 @@ class Portal:
self.db.commit()
@classmethod
- def from_db(cls, db_portal):
+ def from_db(cls, db_portal: DBPortal) -> "Portal":
return Portal(tgid=db_portal.tgid, tg_receiver=db_portal.tg_receiver,
peer_type=db_portal.peer_type, mxid=db_portal.mxid,
username=db_portal.username, megagroup=db_portal.megagroup,
@@ -1591,7 +1642,7 @@ class Portal:
# region Class instance lookup
@classmethod
- def get_by_mxid(cls, mxid):
+ def get_by_mxid(cls, mxid: str) -> Optional["Portal"]:
try:
return cls.by_mxid[mxid]
except KeyError:
@@ -1604,14 +1655,14 @@ class Portal:
return None
@classmethod
- def get_username_from_mx_alias(cls, alias):
+ def get_username_from_mx_alias(cls, alias: str) -> Optional[str]:
match = cls.mx_alias_regex.match(alias)
if match:
return match.group(1)
return None
@classmethod
- def find_by_username(cls, username):
+ def find_by_username(cls, username: str) -> Optional["Portal"]:
if not username:
return None
@@ -1626,7 +1677,8 @@ class Portal:
return None
@classmethod
- def get_by_tgid(cls, tgid, tg_receiver=None, peer_type=None):
+ def get_by_tgid(cls, tgid: int, tg_receiver: int = None, peer_type: str = None
+ ) -> Optional["Portal"]:
tg_receiver = tg_receiver or tgid
tgid_full = (tgid, tg_receiver)
try:
@@ -1647,36 +1699,37 @@ class Portal:
return None
@classmethod
- def get_by_entity(cls, entity, receiver_id=None, create=True):
+ def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull, TypeInputPeer],
+ receiver_id: int = None, create: bool = True) -> Optional["Portal"]:
entity_type = type(entity)
if entity_type in {Chat, ChatFull}:
type_name = "chat"
- id = entity.id
+ entity_id = entity.id
elif entity_type in {PeerChat, InputPeerChat}:
type_name = "chat"
- id = entity.chat_id
+ entity_id = entity.chat_id
elif entity_type in {Channel, ChannelFull}:
type_name = "channel"
- id = entity.id
+ entity_id = entity.id
elif entity_type in {PeerChannel, InputPeerChannel, InputChannel}:
type_name = "channel"
- id = entity.channel_id
+ entity_id = entity.channel_id
elif entity_type in {User, UserFull}:
type_name = "user"
- id = entity.id
+ entity_id = entity.id
elif entity_type in {PeerUser, InputPeerUser, InputUser}:
type_name = "user"
- id = entity.user_id
+ entity_id = entity.user_id
else:
raise ValueError(f"Unknown entity type {entity_type.__name__}")
- return cls.get_by_tgid(id,
- receiver_id if type_name == "user" else id,
+ return cls.get_by_tgid(entity_id,
+ receiver_id if type_name == "user" else entity_id,
type_name if create else None)
# endregion
-def init(context):
+def init(context: Context):
global config
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context
Portal.bridge_notices = config["bridge.bridge_notices"]
@@ -1684,5 +1737,5 @@ def init(context):
Portal.filter_list = config["bridge.filter.list"]
Portal.alias_template = config.get("bridge.alias_template", "telegram_{groupname}")
Portal.hs_domain = config["homeserver.domain"]
- localpart = Portal.alias_template.format(groupname="(.+)")
- Portal.mx_alias_regex = re.compile(f"#{localpart}:{Portal.hs_domain}")
+ Portal.mx_alias_regex = re.compile(
+ f"#{Portal.alias_template.format(groupname='(.+)')}:{Portal.hs_domain}")
diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py
index 20b6af8a..f5642bc2 100644
--- a/mautrix_telegram/puppet.py
+++ b/mautrix_telegram/puppet.py
@@ -14,32 +14,39 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
from difflib import SequenceMatcher
-from typing import Optional, Awaitable
import re
import logging
import asyncio
+from sqlalchemy import orm
+
from telethon.tl.types import UserProfilePhoto
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
from .db import Puppet as DBPuppet
-from . import util, matrix
+from . import util
-config = None
+if TYPE_CHECKING:
+ from .matrix import MatrixHandler
+ from .config import Config
+ from .context import Context
+
+config = None # type: Config
class Puppet:
- log = logging.getLogger("mau.puppet")
- db = None
+ log = logging.getLogger("mau.puppet") # type: logging.Logger
+ db = None # type: orm.Session
az = None # type: AppService
- mx = None # type: matrix.MatrixHandler
+ mx = None # type: MatrixHandler
loop = None # type: asyncio.AbstractEventLoop
- mxid_regex = None
- username_template = None
- hs_domain = None
- cache = {}
- by_custom_mxid = {}
+ mxid_regex = None # type: Pattern
+ username_template = None # type: str
+ hs_domain = None # type: str
+ cache = {} # type: Dict[str, Puppet]
+ by_custom_mxid = {} # type: Dict[str, Puppet]
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
@@ -71,7 +78,8 @@ class Puppet:
def tgid(self):
return self.id
- async def is_logged_in(self):
+ @staticmethod
+ async def is_logged_in():
return True
# region Custom puppet management
@@ -154,12 +162,12 @@ class Puppet:
def filter_events(self, events):
new_events = []
for event in events:
- type = event.get("type", None)
+ evt_type = event.get("type", None)
event.setdefault("content", {})
- if type == "m.typing":
+ if evt_type == "m.typing":
is_typing = self.custom_mxid in event["content"].get("user_ids", [])
event["content"]["user_ids"] = [self.custom_mxid] if is_typing else []
- elif type == "m.receipt":
+ elif evt_type == "m.receipt":
val = None
evt = None
for event_id in event["content"]:
@@ -273,7 +281,7 @@ class Puppet:
return round(similarity * 1000) / 10
@staticmethod
- def get_displayname(info, format=True):
+ def get_displayname(info, enable_format=True):
data = {
"phone number": info.phone if hasattr(info, "phone") else None,
"username": info.username,
@@ -295,7 +303,7 @@ class Puppet:
elif not name:
name = info.id
- if not format:
+ if not enable_format:
return name
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
displayname=name)
@@ -347,18 +355,18 @@ class Puppet:
# region Getters
@classmethod
- def get(cls, id, create=True) -> "Optional[Puppet]":
+ def get(cls, tgid, create=True) -> "Optional[Puppet]":
try:
- return cls.cache[id]
+ return cls.cache[tgid]
except KeyError:
pass
- puppet = DBPuppet.query.get(id)
+ puppet = DBPuppet.query.get(tgid)
if puppet:
return cls.from_db(puppet)
if create:
- puppet = cls(id)
+ puppet = cls(tgid)
cls.db.add(puppet.db_instance)
cls.db.commit()
return puppet
@@ -402,8 +410,8 @@ class Puppet:
return None
@classmethod
- def get_mxid_from_id(cls, id):
- return f"@{cls.username_template.format(userid=id)}:{cls.hs_domain}"
+ def get_mxid_from_id(cls, tgid):
+ return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
@classmethod
def find_by_username(cls, username) -> "Optional[Puppet]":
@@ -437,12 +445,12 @@ class Puppet:
# endregion
-def init(context):
+def init(context: "Context") -> List[Awaitable[int]]:
global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"]
- localpart = Puppet.username_template.format(userid="(.+)")
- Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
+ Puppet.mxid_regex = re.compile(
+ f"@{Puppet.username_template.format(userid='(.+)')}:{Puppet.hs_domain}")
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
diff --git a/mautrix_telegram/sqlstatestore.py b/mautrix_telegram/sqlstatestore.py
index 63b030d2..68e9fd9d 100644
--- a/mautrix_telegram/sqlstatestore.py
+++ b/mautrix_telegram/sqlstatestore.py
@@ -16,6 +16,8 @@
# along with this program. If not, see .
from typing import Dict, Tuple
+from sqlalchemy import orm
+
from mautrix_appservice import StateStore
from . import puppet as pu
@@ -25,15 +27,17 @@ from .db import RoomState, UserProfile
class SQLStateStore(StateStore):
def __init__(self, db):
super().__init__()
- self.db = db
+ self.db = db # type: orm.Session
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
self.room_state_cache = {} # type: Dict[str, RoomState]
- def is_registered(self, user: str) -> bool:
+ @staticmethod
+ def is_registered(user: str) -> bool:
puppet = pu.Puppet.get_by_mxid(user)
return puppet.is_registered if puppet else False
- def registered(self, user: str):
+ @staticmethod
+ def registered(user: str):
puppet = pu.Puppet.get_by_mxid(user)
if puppet:
puppet.is_registered = True
diff --git a/mautrix_telegram/tgclient.py b/mautrix_telegram/tgclient.py
index 302515d8..4534524e 100644
--- a/mautrix_telegram/tgclient.py
+++ b/mautrix_telegram/tgclient.py
@@ -17,10 +17,14 @@
from telethon import TelegramClient, utils
from telethon.tl.functions.messages import SendMediaRequest
from telethon.tl.types import *
+from telethon.tl import custom
class MautrixTelegramClient(TelegramClient):
- async def upload_file(self, file, mime_type=None, attributes=None, file_name=None):
+ async def upload_file_direct(self, file: bytes, mime_type: str = None,
+ attributes: List[TypeDocumentAttribute] = None,
+ file_name: str = None
+ ) -> Union[InputMediaUploadedDocument, InputMediaUploadedPhoto]:
file_handle = await super().upload_file(file, file_name=file_name, use_cache=False)
if mime_type == "image/png" or mime_type == "image/jpeg":
@@ -34,7 +38,10 @@ class MautrixTelegramClient(TelegramClient):
mime_type=mime_type or "application/octet-stream",
attributes=list(attr_dict.values()))
- async def send_media(self, entity, media, caption=None, entities=None, reply_to=None):
+ async def send_media(self, entity: Union[TypeInputPeer, TypePeer],
+ media: Union[TypeInputMedia, TypeMessageMedia],
+ caption: str = None, entities: List[TypeMessageEntity] = None,
+ reply_to: int = None) -> Optional[custom.Message]:
entity = await self.get_input_entity(entity)
reply_to = utils.get_message_id(reply_to)
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [],
diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py
index 8e229f94..c2bdf780 100644
--- a/mautrix_telegram/user.py
+++ b/mautrix_telegram/user.py
@@ -14,42 +14,51 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, Awaitable, Optional
+from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING
import logging
import asyncio
import re
from telethon.tl.types import *
+from telethon.tl.types import User as TLUser
from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError
-from .db import User as DBUser, Contact as DBContact
+from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
from .abstract_user import AbstractUser
from . import portal as po, puppet as pu
-config = None
+if TYPE_CHECKING:
+ from .config import Config
+ from .context import Context
+
+config = None # type: Config
+
+SearchResults = List[Tuple["pu.Puppet", int]]
class User(AbstractUser):
- log = logging.getLogger("mau.user")
- by_mxid = {}
- by_tgid = {}
+ log = logging.getLogger("mau.user") # type: logging.Logger
+ by_mxid = {} # type: Dict[str, User]
+ by_tgid = {} # type: Dict[int, User]
- def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0,
- is_bot=False, db_portals=None, db_instance=None):
+ def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
+ db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
+ is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
+ db_instance: Optional[DBUser] = None):
super().__init__()
self.mxid = mxid # type: str
self.tgid = tgid # type: int
self.is_bot = is_bot # type: bool
self.username = username # type: str
- self.contacts = []
- self.saved_contacts = saved_contacts
- self.db_contacts = db_contacts
- self.portals = {} # type: Dict[str, po.Portal]
- self.db_portals = db_portals
- self._db_instance = db_instance
+ self.contacts = [] # type: List[pu.Puppet]
+ self.saved_contacts = saved_contacts # type: int
+ self.db_contacts = db_contacts # type: List[DBContact]
+ self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
+ self.db_portals = db_portals # type: List[DBPortal]
+ self._db_instance = db_instance # type: DBUser
self.command_status = None # type: dict
@@ -64,53 +73,47 @@ class User(AbstractUser):
self.by_tgid[tgid] = self
@property
- def name(self):
+ def name(self) -> str:
return self.mxid
@property
- def mxid_localpart(self):
- match = re.compile("@(.+):(.+)").match(self.mxid)
+ def mxid_localpart(self) -> str:
+ match = re.compile("@(.+):(.+)").match(self.mxid) # type: Match
return match.group(1)
# TODO replace with proper displayname getting everywhere
@property
- def displayname(self):
+ def displayname(self) -> str:
return self.mxid_localpart
@property
- def db_contacts(self):
+ def db_contacts(self) -> List[DBContact]:
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
for puppet in self.contacts]
@db_contacts.setter
- def db_contacts(self, contacts):
- if contacts:
- self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts]
- else:
- self.contacts = []
+ def db_contacts(self, contacts: List[DBContact]):
+ self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
@property
- def db_portals(self):
+ def db_portals(self) -> List[DBPortal]:
return [portal.db_instance for portal in self.portals.values()]
@db_portals.setter
- def db_portals(self, portals):
- if portals:
- self.portals = {(portal.tgid, portal.tg_receiver):
- po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
- for portal in portals}
- else:
- self.portals = {}
+ def db_portals(self, portals: List[DBPortal]):
+ self.portals = {(portal.tgid, portal.tg_receiver):
+ po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
+ for portal in portals} if portals else {}
# region Database conversion
@property
- def db_instance(self):
+ def db_instance(self) -> DBUser:
if not self._db_instance:
self._db_instance = self.new_db_instance()
return self._db_instance
- def new_db_instance(self):
+ def new_db_instance(self) -> DBUser:
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
portals=self.db_portals)
@@ -134,14 +137,14 @@ class User(AbstractUser):
self.db.commit()
@classmethod
- def from_db(cls, db_user):
+ def from_db(cls, db_user: DBUser) -> "User":
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
# endregion
# region Telegram connection management
- async def start(self, delete_unless_authenticated=False):
+ async def start(self, delete_unless_authenticated: bool = False) -> "User":
await super().start()
if await self.is_logged_in():
self.log.debug(f"Ensuring post_login() for {self.name}")
@@ -152,7 +155,7 @@ class User(AbstractUser):
self.client.session.delete()
return self
- async def post_login(self, info=None):
+ async def post_login(self, info: TLUser = None):
try:
await self.update_info(info)
if not self.is_bot:
@@ -163,7 +166,7 @@ class User(AbstractUser):
except Exception:
self.log.exception("Failed to run post-login functions for %s", self.mxid)
- async def update(self, update):
+ async def update(self, update: TypeUpdate):
if not self.is_bot:
return
@@ -186,7 +189,7 @@ class User(AbstractUser):
# endregion
# region Telegram actions that need custom methods
- def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]":
+ def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
return super().ensure_started(even_if_no_session)
def set_presence(self, online: bool = True):
@@ -194,7 +197,7 @@ class User(AbstractUser):
return
return self.client(UpdateStatusRequest(offline=not online))
- async def update_info(self, info: User = None):
+ async def update_info(self, info: TLUser = None):
info = info or await self.client.get_me()
changed = False
if self.is_bot != info.bot:
@@ -233,8 +236,9 @@ class User(AbstractUser):
self.delete()
return True
- def _search_local(self, query, max_results=5, min_similarity=45):
- results = []
+ def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
+ ) -> SearchResults:
+ results = [] # type: SearchResults
for contact in self.contacts:
similarity = contact.similarity(query)
if similarity >= min_similarity:
@@ -242,11 +246,11 @@ class User(AbstractUser):
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
- async def _search_remote(self, query, max_results=5):
+ async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults:
if len(query) < 5:
return []
server_results = await self.client(SearchRequest(q=query, limit=max_results))
- results = []
+ results = [] # type: SearchResults
for user in server_results.users:
puppet = pu.Puppet.get(user.id)
await puppet.update_info(self, user)
@@ -254,7 +258,7 @@ class User(AbstractUser):
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
- async def search(self, query, force_remote=False):
+ async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]:
if force_remote:
return await self._search_remote(query), True
@@ -264,7 +268,7 @@ class User(AbstractUser):
return await self._search_remote(query), True
- async def sync_dialogs(self, synchronous_create=False):
+ async def sync_dialogs(self, synchronous_create: bool = False):
creators = []
for entity in await self.get_dialogs(limit=30):
portal = po.Portal.get_by_entity(entity)
@@ -275,7 +279,7 @@ class User(AbstractUser):
self.save()
await asyncio.gather(*creators, loop=self.loop)
- def register_portal(self, portal):
+ def register_portal(self, portal: po.Portal):
try:
if self.portals[portal.tgid_full] == portal:
return
@@ -284,18 +288,18 @@ class User(AbstractUser):
self.portals[portal.tgid_full] = portal
self.save()
- def unregister_portal(self, portal):
+ def unregister_portal(self, portal: po.Portal):
try:
del self.portals[portal.tgid_full]
self.save()
except KeyError:
pass
- async def needs_relaybot(self, portal):
+ async def needs_relaybot(self, portal: po.Portal) -> bool:
return not await self.is_logged_in() or (
self.is_bot and portal.tgid_full not in self.portals)
- def _hash_contacts(self):
+ def _hash_contacts(self) -> int:
acc = 0
for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):
acc = (acc * 20261 + id) & 0xffffffff
@@ -318,7 +322,7 @@ class User(AbstractUser):
# region Class instance lookup
@classmethod
- def get_by_mxid(cls, mxid, create=True) -> "Optional[User]":
+ def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -341,7 +345,7 @@ class User(AbstractUser):
return None
@classmethod
- def get_by_tgid(cls, tgid) -> "Optional[User]":
+ def get_by_tgid(cls, tgid: int) -> "Optional[User]":
try:
return cls.by_tgid[tgid]
except KeyError:
@@ -355,7 +359,7 @@ class User(AbstractUser):
return None
@classmethod
- def find_by_username(cls, username) -> "Optional[User]":
+ def find_by_username(cls, username: str) -> "Optional[User]":
if not username:
return None
@@ -371,7 +375,7 @@ class User(AbstractUser):
# endregion
-def init(context):
+def init(context: "Context") -> List[Awaitable[User]]:
global config
config = context.config
diff --git a/mautrix_telegram/util/file_transfer.py b/mautrix_telegram/util/file_transfer.py
index e927cd77..d950b2a0 100644
--- a/mautrix_telegram/util/file_transfer.py
+++ b/mautrix_telegram/util/file_transfer.py
@@ -14,15 +14,25 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Optional, Tuple, Union, Dict
from io import BytesIO
import time
import logging
import asyncio
import magic
+from sqlalchemy import orm
from sqlalchemy.exc import IntegrityError, InvalidRequestError
from sqlalchemy.orm.exc import FlushError
+from telethon.tl.types import (Document, FileLocation, InputFileLocation,
+ InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
+from telethon.errors import *
+from mautrix_appservice import IntentAPI
+
+from ..tgclient import MautrixTelegramClient
+from ..db import TelegramFile as DBTelegramFile
+
try:
from PIL import Image
except ImportError:
@@ -36,20 +46,18 @@ try:
except ImportError:
VideoFileClip = random = string = os = mimetypes = None
-from telethon.tl.types import (Document, FileLocation, InputFileLocation,
- InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
-from telethon.errors import *
+log = logging.getLogger("mau.util") # type: logging.Logger
-from ..db import TelegramFile as DBTelegramFile
-
-log = logging.getLogger("mau.util")
+TypeLocation = Union[Document, InputDocumentFileLocation, FileLocation, InputFileLocation]
-def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=None):
+def convert_image(file: bytes, source_mime: str = "image/webp", target_type: str = "png",
+ thumbnail_to: Optional[Tuple[int, int]] = None
+ ) -> Tuple[str, bytes, Optional[int], Optional[int]]:
if not Image:
return source_mime, file, None, None
try:
- image = Image.open(BytesIO(file)).convert("RGBA")
+ image = Image.open(BytesIO(file)).convert("RGBA") # type: Image.Image
if thumbnail_to:
image.thumbnail(thumbnail_to, Image.ANTIALIAS)
new_file = BytesIO()
@@ -61,13 +69,14 @@ def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_t
return source_mime, file, None, None
-def _temp_file_name(ext):
+def _temp_file_name(ext: str) -> str:
return ("/tmp/mxtg-video-"
+ "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
+ ext)
-def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024, 720)):
+def _read_video_thumbnail(data: bytes, video_ext: str = "mp4", frame_ext: str = "png",
+ max_size: Tuple[int, int] = (1024, 720)) -> Tuple[bytes, int, int]:
# We don't have any way to read the video from memory, so save it to disk.
temp_file = _temp_file_name(video_ext)
with open(temp_file, "wb") as file:
@@ -90,21 +99,21 @@ def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024
return thumbnail_file.getvalue(), w, h
-def _location_to_id(location):
+def _location_to_id(location: TypeLocation) -> str:
if isinstance(location, (Document, InputDocumentFileLocation)):
return f"{location.id}-{location.version}"
elif isinstance(location, (FileLocation, InputFileLocation)):
return f"{location.volume_id}-{location.local_id}"
- else:
- return None
-async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mime):
+async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
+ thumbnail_loc: TypeLocation, video: bytes,
+ mime: str) -> Optional[DBTelegramFile]:
if not Image or not VideoFileClip:
return None
- id = _location_to_id(thumbnail_loc)
- if not id:
+ loc_id = _location_to_id(thumbnail_loc)
+ if not loc_id:
return None
video_ext = mimetypes.guess_extension(mime)
@@ -121,36 +130,40 @@ async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mim
content_uri = await intent.upload_file(file, mime_type)
- return DBTelegramFile(id=id, mxc=content_uri, mime_type=mime_type,
+ return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
was_converted=False, timestamp=int(time.time()), size=len(file),
width=width, height=height)
-transfer_locks = {}
-transfer_locks_lock = asyncio.Lock()
+transfer_locks = {} # type: Dict[str, asyncio.Lock]
-async def transfer_file_to_matrix(db, client, intent, location, thumbnail=None, is_sticker=False):
- id = _location_to_id(location)
- if not id:
+async def transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient, intent: IntentAPI,
+ location: TypeLocation, thumbnail: Optional[TypeLocation] = None,
+ is_sticker: bool = False) -> Optional[DBTelegramFile]:
+ location_id = _location_to_id(location)
+ if not location_id:
return None
- db_file = DBTelegramFile.query.get(id)
+ db_file = DBTelegramFile.query.get(location_id)
if db_file:
return db_file
- async with transfer_locks_lock:
- try:
- lock = transfer_locks[id]
- except KeyError:
- lock = asyncio.Lock()
- transfer_locks[id] = lock
+ try:
+ lock = transfer_locks[location_id]
+ except KeyError:
+ lock = asyncio.Lock()
+ transfer_locks[location_id] = lock
async with lock:
- return await _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker)
+ return await _unlocked_transfer_file_to_matrix(db, client, intent, location_id, location,
+ thumbnail, is_sticker)
-async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker):
- db_file = DBTelegramFile.query.get(id)
+async def _unlocked_transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient,
+ intent: IntentAPI, loc_id: str, location: TypeLocation,
+ thumbnail: Optional[TypeLocation],
+ is_sticker: bool) -> Optional[DBTelegramFile]:
+ db_file = DBTelegramFile.query.get(loc_id)
if db_file:
return db_file
@@ -167,15 +180,16 @@ async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, th
image_converted = False
if mime_type == "image/webp":
- new_mime_type, file, width, height = convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=(
- 256, 256) if is_sticker else None)
+ new_mime_type, file, width, height = convert_image(
+ file, source_mime="image/webp", target_type="png",
+ thumbnail_to=(256, 256) if is_sticker else None)
image_converted = new_mime_type != mime_type
mime_type = new_mime_type
thumbnail = None
content_uri = await intent.upload_file(file, mime_type)
- db_file = DBTelegramFile(id=id, mxc=content_uri,
+ db_file = DBTelegramFile(id=loc_id, mxc=content_uri,
mime_type=mime_type, was_converted=image_converted,
timestamp=int(time.time()), size=len(file),
width=width, height=height)
diff --git a/mautrix_telegram/util/format_duration.py b/mautrix_telegram/util/format_duration.py
index c873e9e5..9402b83e 100644
--- a/mautrix_telegram/util/format_duration.py
+++ b/mautrix_telegram/util/format_duration.py
@@ -16,10 +16,12 @@
# along with this program. If not, see .
-def format_duration(seconds):
- def pluralize(count, singular): return singular if count == 1 else singular + "s"
+def format_duration(seconds: int) -> str:
+ def pluralize(count, singular):
+ return singular if count == 1 else singular + "s"
- def include(count, word): return f"{count} {pluralize(count, word)}" if count > 0 else ""
+ def include(count, word):
+ return f"{count} {pluralize(count, word)}" if count > 0 else ""
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)