diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py index ad4cb4ef..ed566c6d 100644 --- a/mautrix_telegram/__main__.py +++ b/mautrix_telegram/__main__.py @@ -14,34 +14,33 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Optional import argparse -import sys -import logging -import logging.config import asyncio +import logging.config +import sys -import sqlalchemy as sql from sqlalchemy import orm +import sqlalchemy as sql -from alchemysession import AlchemySessionContainer from mautrix_appservice import AppService +from alchemysession import AlchemySessionContainer -from .base import Base -from .config import Config -from .matrix import MatrixHandler - -from . import __version__ -from .db import init as init_db +from .web.provisioning import ProvisioningAPI +from .web.public import PublicBridgeWebsite from .abstract_user import init as init_abstract_user -from .user import init as init_user, User +from .base import Base from .bot import init as init_bot +from .config import Config +from .context import Context +from .db import init as init_db +from .formatter import init as init_formatter +from .matrix import MatrixHandler from .portal import init as init_portal from .puppet import init as init_puppet -from .formatter import init as init_formatter -from .web.public import PublicBridgeWebsite -from .web.provisioning import ProvisioningAPI -from .context import Context from .sqlstatestore import SQLStateStore +from .user import User, init as init_user +from . import __version__ parser = argparse.ArgumentParser( description="A Matrix-Telegram puppeting bridge.", @@ -68,7 +67,7 @@ if args.generate_registration: sys.exit(0) logging.config.dictConfig(config["logging"]) -log = logging.getLogger("mau.init") +log = logging.getLogger("mau.init") # type: logging.Logger log.debug(f"Initializing mautrix-telegram {__version__}") db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db") @@ -80,7 +79,7 @@ session_container = AlchemySessionContainer(engine=db_engine, session=db_session table_base=Base, table_prefix="telethon_", manage_tables=False) -loop = asyncio.get_event_loop() +loop = asyncio.get_event_loop() # type: asyncio.AbstractEventLoop state_store = SQLStateStore(db_session) appserv = AppService(config["homeserver.address"], config["homeserver.domain"], @@ -89,8 +88,8 @@ appserv = AppService(config["homeserver.address"], config["homeserver.domain"], verify_ssl=config["homeserver.verify_ssl"], state_store=state_store, real_user_content_key="net.maunium.telegram.puppet") -public_website = None -provisioning_api = None +public_website = None # type: Optional[PublicBridgeWebsite] +provisioning_api = None # type: Optional[ProvisioningAPI] if config["appservice.public.enabled"]: public_website = PublicBridgeWebsite(loop) diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py index 0de32c91..49632378 100644 --- a/mautrix_telegram/abstract_user.py +++ b/mautrix_telegram/abstract_user.py @@ -14,26 +14,48 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Tuple, Optional, List, Union, TYPE_CHECKING +from abc import ABC, abstractmethod +import asyncio +import logging import platform -from telethon.tl.types import * -from mautrix_appservice import MatrixRequestError +from sqlalchemy import orm +from telethon.tl.types import Channel, ChannelForbidden, Chat, ChatForbidden, Message, \ + MessageActionChannelMigrateFrom, MessageService, PeerUser, TypeUpdate, \ + UpdateChannelPinnedMessage, UpdateChatAdmins, UpdateChatParticipantAdmin, \ + UpdateChatParticipants, UpdateChatUserTyping, UpdateDeleteChannelMessages, \ + UpdateDeleteMessages, UpdateEditChannelMessage, UpdateEditMessage, UpdateNewChannelMessage, \ + UpdateNewMessage, UpdateReadHistoryOutbox, UpdateShortChatMessage, UpdateShortMessage, \ + UpdateUserName, UpdateUserPhoto, UpdateUserStatus, UpdateUserTyping, User, UserStatusOffline, \ + UserStatusOnline + +from mautrix_appservice import MatrixRequestError, AppService +from alchemysession import AlchemySessionContainer -from .tgclient import MautrixTelegramClient -from .db import Message as DBMessage from . import portal as po, puppet as pu, __version__ +from .db import Message as DBMessage +from .tgclient import MautrixTelegramClient -config = None +if TYPE_CHECKING: + from .context import Context + from .config import Config + +config = None # type: Config # Value updated from config in init() -MAX_DELETIONS = 10 +MAX_DELETIONS = 10 # type: int + +UpdateMessage = Union[UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage, + UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage] +UpdateMessageContent = Union[UpdateShortMessage, UpdateShortChatMessage, Message, MessageService] -class AbstractUser: - session_container = None - loop = None - log = None - db = None - az = None +class AbstractUser(ABC): + session_container = None # type: AlchemySessionContainer + loop = None # type: asyncio.AbstractEventLoop + log = None # type: logging.Logger + db = None # type: orm.Session + az = None # type: AppService def __init__(self): self.puppet_whitelisted = False # type: bool @@ -47,22 +69,22 @@ class AbstractUser: self.is_bot = False # type: bool @property - def connected(self): + def connected(self) -> bool: return self.client and self.client.is_connected() @property - def _proxy_settings(self): - type = config["telegram.proxy.type"].lower() - if type == "disabled": + def _proxy_settings(self) -> Optional[Tuple[int, str, str, str, str, str]]: + proxy_type = config["telegram.proxy.type"].lower() + if proxy_type == "disabled": return None - elif type == "socks4": - type = 1 - elif type == "socks5": - type = 2 - elif type == "http": - type = 3 + elif proxy_type == "socks4": + proxy_type = 1 + elif proxy_type == "socks5": + proxy_type = 2 + elif proxy_type == "http": + proxy_type = 3 - return (type, + return (proxy_type, config["telegram.proxy.address"], config["telegram.proxy.port"], config["telegram.proxy.rdns"], config["telegram.proxy.username"], config["telegram.proxy.password"]) @@ -83,20 +105,30 @@ class AbstractUser: proxy=self._proxy_settings) self.client.add_event_handler(self._update_catch) - async def update(self, update): + @abstractmethod + async def update(self, update: TypeUpdate) -> bool: return False + @abstractmethod async def post_login(self): raise NotImplementedError() - async def _update_catch(self, update): + @abstractmethod + def register_portal(self, portal: po.Portal): + raise NotImplementedError() + + @abstractmethod + def unregister_portal(self, portal: po.Portal): + raise NotImplementedError() + + async def _update_catch(self, update: TypeUpdate): try: if not await self.update(update): await self._update(update) except Exception: self.log.exception("Failed to handle Telegram update") - async def get_dialogs(self, limit=None) -> List[Union[Chat, Channel]]: + async def get_dialogs(self, limit: int = None) -> List[Union[Chat, Channel]]: if self.is_bot: return [] dialogs = await self.client.get_dialogs(limit=limit) @@ -106,18 +138,19 @@ class AbstractUser: and (dialog.entity.deactivated or dialog.entity.left)))] @property - def name(self): + @abstractmethod + def name(self) -> str: raise NotImplementedError() - async def is_logged_in(self): + async def is_logged_in(self) -> bool: return self.client and await self.client.is_user_authorized() - async def has_full_access(self, allow_bot=False): + async def has_full_access(self, allow_bot: bool = False) -> bool: return (self.puppet_whitelisted and (not self.is_bot or allow_bot) and await self.is_logged_in()) - async def start(self, delete_unless_authenticated=False): + async def start(self, delete_unless_authenticated: bool = False) -> "AbstractUser": if not self.client: self._init_client() await self.client.connect() @@ -144,7 +177,7 @@ class AbstractUser: # region Telegram update handling - async def _update(self, update): + async def _update(self, update: TypeUpdate): if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage, UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)): await self.update_message(update) @@ -169,17 +202,19 @@ class AbstractUser: else: self.log.debug("Unhandled update: %s", update) - async def update_pinned_messages(self, update): + @staticmethod + async def update_pinned_messages(update: UpdateChannelPinnedMessage): portal = po.Portal.get_by_tgid(update.channel_id) if portal and portal.mxid: await portal.receive_telegram_pin_id(update.id) - async def update_participants(self, update): + @staticmethod + async def update_participants(update: UpdateChatParticipants): portal = po.Portal.get_by_tgid(update.participants.chat_id) if portal and portal.mxid: await portal.update_telegram_participants(update.participants.participants) - async def update_read_receipt(self, update): + async def update_read_receipt(self, update: UpdateReadHistoryOutbox): if not isinstance(update.peer, PeerUser): self.log.debug("Unexpected read receipt peer: %s", update.peer) return @@ -196,7 +231,7 @@ class AbstractUser: puppet = pu.Puppet.get(update.peer.user_id) await puppet.intent.mark_read(portal.mxid, message.mxid) - async def update_admin(self, update): + async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]): # TODO duplication not checked portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat") if isinstance(update, UpdateChatAdmins): @@ -206,7 +241,7 @@ class AbstractUser: else: self.log.warning("Unexpected admin status update: %s", update) - async def update_typing(self, update): + async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]): if isinstance(update, UpdateUserTyping): portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user") else: @@ -214,7 +249,7 @@ class AbstractUser: sender = pu.Puppet.get(update.user_id) await portal.handle_telegram_typing(sender, update) - async def update_others_info(self, update): + async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]): # TODO duplication not checked puppet = pu.Puppet.get(update.user_id) if isinstance(update, UpdateUserName): @@ -226,7 +261,7 @@ class AbstractUser: else: self.log.warning("Unexpected other user info update: %s", update) - async def update_status(self, update): + async def update_status(self, update: UpdateUserStatus): puppet = pu.Puppet.get(update.user_id) if isinstance(update.status, UserStatusOnline): await puppet.default_mxid_intent.set_presence("online") @@ -236,7 +271,9 @@ class AbstractUser: self.log.warning("Unexpected user status update: %s", update) return - def get_message_details(self, update): + def get_message_details(self, update: UpdateMessage) -> Tuple[UpdateMessageContent, + Optional[pu.Puppet], + Optional[po.Portal]]: if isinstance(update, UpdateShortChatMessage): portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat") sender = pu.Puppet.get(update.from_id) @@ -259,7 +296,7 @@ class AbstractUser: return update, sender, portal @staticmethod - async def _try_redact(portal, message): + async def _try_redact(portal: po.Portal, message: DBMessage): if not portal: return try: @@ -267,7 +304,7 @@ class AbstractUser: except MatrixRequestError: pass - async def delete_message(self, update): + async def delete_message(self, update: UpdateDeleteMessages): if len(update.messages) > MAX_DELETIONS: return @@ -283,7 +320,7 @@ class AbstractUser: await self._try_redact(portal, message) self.db.commit() - async def delete_channel_message(self, update): + async def delete_channel_message(self, update: UpdateDeleteChannelMessages): if len(update.messages) > MAX_DELETIONS: return @@ -299,7 +336,7 @@ class AbstractUser: await self._try_redact(portal, message) self.db.commit() - async def update_message(self, original_update): + async def update_message(self, original_update: UpdateMessage): update, sender, portal = self.get_message_details(original_update) if isinstance(update, MessageService): @@ -325,7 +362,7 @@ class AbstractUser: # endregion -def init(context): +def init(context: "Context"): global config, MAX_DELETIONS AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, _ = context AbstractUser.session_container = context.session_container diff --git a/mautrix_telegram/base.py b/mautrix_telegram/base.py index c64447da..0b62d886 100644 --- a/mautrix_telegram/base.py +++ b/mautrix_telegram/base.py @@ -1,2 +1,2 @@ from sqlalchemy.ext.declarative import declarative_base -Base = declarative_base() +Base = declarative_base() # type: declarative_base diff --git a/mautrix_telegram/bot.py b/mautrix_telegram/bot.py index c05a62aa..51a6a110 100644 --- a/mautrix_telegram/bot.py +++ b/mautrix_telegram/bot.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Pattern, Dict, TYPE_CHECKING import logging import re @@ -27,27 +27,31 @@ from .abstract_user import AbstractUser from .db import BotChat from . import puppet as pu, portal as po, user as u -config = None +if TYPE_CHECKING: + from .config import Config + +config = None # type: Config ReplyFunc = Callable[[str], Awaitable[Message]] class Bot(AbstractUser): - log = logging.getLogger("mau.bot") - mxid_regex = re.compile("@.+:.+") + log = logging.getLogger("mau.bot") # type: logging.Logger + mxid_regex = re.compile("@.+:.+") # type: Pattern def __init__(self, token: str): super().__init__() - self.token = token - self.puppet_whitelisted = True - self.whitelisted = True - self.relaybot_whitelisted = True - self.username = None - self.is_relaybot = True - self.is_bot = True - self.chats = {chat.id: chat.type for chat in BotChat.query.all()} - self.tg_whitelist = [] - self.whitelist_group_admins = config["bridge.relaybot.whitelist_group_admins"] or False + self.token = token # type: str + self.puppet_whitelisted = True # type: bool + self.whitelisted = True # type: bool + self.relaybot_whitelisted = True # type: bool + self.username = None # type: str + self.is_relaybot = True # type: bool + self.is_bot = True # type: bool + self.chats = {chat.id: chat.type for chat in BotChat.query.all()} # type: Dict[int, str] + self.tg_whitelist = [] # type: List[int] + self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"] + or False) # type: bool async def init_permissions(self): whitelist = config["bridge.relaybot.whitelist"] or [] @@ -61,7 +65,7 @@ class Bot(AbstractUser): if isinstance(id, int): self.tg_whitelist.append(id) - async def start(self, delete_unless_authenticated=False): + async def start(self, delete_unless_authenticated: bool = False) -> "Bot": await super().start(delete_unless_authenticated) if not await self.is_logged_in(): await self.client.sign_in(bot_token=self.token) @@ -118,7 +122,7 @@ class Bot(AbstractUser): self.db.delete(existing_chat) self.db.commit() - async def _can_use_commands(self, chat, tgid): + async def _can_use_commands(self, chat: TypePeer, tgid: int) -> bool: if tgid in self.tg_whitelist: return True @@ -138,7 +142,7 @@ class Bot(AbstractUser): if p.user_id == tgid: return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin)) - async def check_can_use_commands(self, event: Message, reply: ReplyFunc): + async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool: if not await self._can_use_commands(event.to_id, event.from_id): await reply("You do not have the permission to use that command.") return False @@ -262,7 +266,7 @@ class Bot(AbstractUser): return "bot" -def init(context): +def init(context) -> Optional[Bot]: global config config = context.config token = config["telegram.bot_token"] diff --git a/mautrix_telegram/commands/clean_rooms.py b/mautrix_telegram/commands/clean_rooms.py index e9031d3b..aac5a54d 100644 --- a/mautrix_telegram/commands/clean_rooms.py +++ b/mautrix_telegram/commands/clean_rooms.py @@ -23,15 +23,14 @@ from .. import puppet as pu, portal as po ManagementRoomList = List[Tuple[str, str]] RoomIDList = List[str] -PortalList = List[po.Portal] -async def _find_rooms(intent: IntentAPI) -> Tuple[ - ManagementRoomList, RoomIDList, PortalList, PortalList]: +async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList, + List["po.Portal"], List["po.Portal"]]: management_rooms = [] # type: ManagementRoomList unidentified_rooms = [] # type: RoomIDList - portals = [] # type: PortalList - empty_portals = [] # type: PortalList + portals = [] # type: List[po.Portal] + empty_portals = [] # type: List[po.Portal] rooms = await intent.get_joined_rooms() for room in rooms: @@ -108,8 +107,8 @@ async def clean_rooms(evt: CommandEvent): async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList, - unidentified_rooms: RoomIDList, portals: PortalList, - empty_portals: PortalList): + unidentified_rooms: RoomIDList, portals: List["po.Portal"], + empty_portals: List["po.Portal"]): command = evt.args[0] rooms_to_clean = [] if command == "clean-recommended": diff --git a/mautrix_telegram/commands/portal.py b/mautrix_telegram/commands/portal.py index 0c88ca74..c2ff2347 100644 --- a/mautrix_telegram/commands/portal.py +++ b/mautrix_telegram/commands/portal.py @@ -222,7 +222,7 @@ async def bridge(evt: CommandEvent): "chat to this room, use `$cmdprefix+sp continue`") -async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: po.Portal): +async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"): if not portal.mxid: await evt.reply("The portal seems to have lost its Matrix room between you" "calling `$cmdprefix+sp bridge` and this command.\n\n" diff --git a/mautrix_telegram/config.py b/mautrix_telegram/config.py index b7766f47..72e61f27 100644 --- a/mautrix_telegram/config.py +++ b/mautrix_telegram/config.py @@ -14,6 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Tuple, Any, Optional from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap import random @@ -24,28 +25,28 @@ yaml.indent(4) class DictWithRecursion: - def __init__(self, data=None): - self._data = data or CommentedMap() + def __init__(self, data: CommentedMap = None): + self._data = data or CommentedMap() # type: CommentedMap - def _recursive_get(self, data, key, default_value): + def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any: if '.' in key: key, next_key = key.split('.', 1) next_data = data.get(key, CommentedMap()) return self._recursive_get(next_data, next_key, default_value) return data.get(key, default_value) - def get(self, key, default_value, allow_recursion=True): + def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any: if allow_recursion and '.' in key: return self._recursive_get(self._data, key, default_value) return self._data.get(key, default_value) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self.get(key, None) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return self[key] is not None - def _recursive_set(self, data, key, value): + def _recursive_set(self, data: CommentedMap, key: str, value: Any): if '.' in key: key, next_key = key.split('.', 1) if key not in data: @@ -55,16 +56,16 @@ class DictWithRecursion: return data[key] = value - def set(self, key, value, allow_recursion=True): + def set(self, key: str, value: Any, allow_recursion: bool = True): if allow_recursion and '.' in key: self._recursive_set(self._data, key, value) return self._data[key] = value - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any): self.set(key, value) - def _recursive_del(self, data, key): + def _recursive_del(self, data: CommentedMap, key: str): if '.' in key: key, next_key = key.split('.', 1) if key not in data: @@ -78,7 +79,7 @@ class DictWithRecursion: except KeyError: pass - def delete(self, key, allow_recursion=True): + def delete(self, key: str, allow_recursion: bool = True): if allow_recursion and '.' in key: self._recursive_del(self._data, key) return @@ -88,23 +89,23 @@ class DictWithRecursion: except KeyError: pass - def __delitem__(self, key): + def __delitem__(self, key: str): self.delete(key) class Config(DictWithRecursion): - def __init__(self, path, registration_path, base_path): + def __init__(self, path: str, registration_path: str, base_path: str): super().__init__() - self.path = path - self.registration_path = registration_path - self.base_path = base_path - self._registration = None + self.path = path # type: str + self.registration_path = registration_path # type: str + self.base_path = base_path # type: str + self._registration = None # type: dict def load(self): with open(self.path, 'r') as stream: self._data = yaml.load(stream) - def load_base(self): + def load_base(self) -> Optional[DictWithRecursion]: try: with open(self.base_path, 'r') as stream: return DictWithRecursion(yaml.load(stream)) @@ -120,7 +121,7 @@ class Config(DictWithRecursion): yaml.dump(self._registration, stream) @staticmethod - def _new_token(): + def _new_token() -> str: return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) def update(self): @@ -246,7 +247,7 @@ class Config(DictWithRecursion): self._data = base._data self.save() - def _get_permissions(self, key): + def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, bool, bool]: level = self["bridge.permissions"].get(key, "") admin = level == "admin" puppeting = level == "full" or admin @@ -254,7 +255,7 @@ class Config(DictWithRecursion): relaybot = level == "relaybot" or user return relaybot, user, puppeting, admin, level - def get_permissions(self, mxid): + def get_permissions(self, mxid: str) -> Tuple[bool, bool, bool, bool, bool]: permissions = self["bridge.permissions"] or {} if mxid in permissions: return self._get_permissions(mxid) diff --git a/mautrix_telegram/context.py b/mautrix_telegram/context.py index 1324e5f1..76f75ded 100644 --- a/mautrix_telegram/context.py +++ b/mautrix_telegram/context.py @@ -14,21 +14,27 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Tuple -import asyncio +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import asyncio + + from sqlalchemy.orm import scoped_session + + from alchemysession import AlchemySessionContainer + from mautrix_appservice import AppService + + from .web import PublicBridgeWebsite, ProvisioningAPI + from .config import Config + from .bot import Bot + from .matrix import MatrixHandler -from sqlalchemy.orm import scoped_session -from alchemysession import AlchemySessionContainer -from mautrix_appservice import AppService class Context: - def __init__(self, az, db, config, loop, bot, mx, session_container, public_website, - provisioning_api): - from .web import PublicBridgeWebsite, ProvisioningAPI - from .config import Config - from .bot import Bot - from .matrix import MatrixHandler - + def __init__(self, az: "AppService", db: "scoped_session", config: "Config", + loop: "asyncio.AbstractEventLoop", bot: "Bot", mx: "MatrixHandler", + session_container: "AlchemySessionContainer", + public_website: "PublicBridgeWebsite", provisioning_api: "ProvisioningAPI"): self.az = az # type: AppService self.db = db # type: scoped_session self.config = config # type: Config diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py index 81bc0598..5a0baf70 100644 --- a/mautrix_telegram/db.py +++ b/mautrix_telegram/db.py @@ -42,6 +42,7 @@ class Portal(Base): about = Column(String, nullable=True) photo_id = Column(String, nullable=True) + class Message(Base): query = None # type: Query __tablename__ = "message" diff --git a/mautrix_telegram/formatter/from_matrix.py b/mautrix_telegram/formatter/from_matrix.py index f98d3ad5..6619ef02 100644 --- a/mautrix_telegram/formatter/from_matrix.py +++ b/mautrix_telegram/formatter/from_matrix.py @@ -14,10 +14,10 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import (Optional, List, Tuple, Type, Callable, Dict, Any, Pattern, Deque, Match, TYPE_CHECKING) from html import unescape from html.parser import HTMLParser from collections import deque -from typing import Optional, List, Tuple, Type, Callable, Dict, Any import math import re import logging @@ -27,37 +27,40 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M MessageEntityItalic, MessageEntityCode, MessageEntityPre, MessageEntityBotCommand, TypeMessageEntity) -from .. import user as u, puppet as pu, portal as po, context as c +from .. import user as u, puppet as pu, portal as po from ..db import Message as DBMessage from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html, trim_reply_fallback_text, html_to_unicode) -log = logging.getLogger("mau.fmt.mx") -should_bridge_plaintext_highlights = False +if TYPE_CHECKING: + from ..context import Context + +log = logging.getLogger("mau.fmt.mx") # type: logging.Logger +should_bridge_plaintext_highlights = False # type: bool class MatrixParser(HTMLParser): - mention_regex = re.compile("https://matrix.to/#/(@.+:.+)") - room_regex = re.compile("https://matrix.to/#/(#.+:.+)") + mention_regex = re.compile("https://matrix.to/#/(@.+:.+)") # type: Pattern + room_regex = re.compile("https://matrix.to/#/(#.+:.+)") # type: Pattern block_tags = ("br", "p", "pre", "blockquote", "ol", "ul", "li", "h1", "h2", "h3", "h4", "h5", "h6", - "div", "hr", "table") + "div", "hr", "table") # type: Tuple[str, ...] def __init__(self): super().__init__() - self.text = "" - self.entities = [] - self._building_entities = {} - self._list_counter = 0 - self._open_tags = deque() - self._open_tags_meta = deque() - self._line_is_new = True - self._list_entry_is_new = False + self.text = "" # type: str + self.entities = [] # type: List[TypeMessageEntity] + self._building_entities = {} # type: Dict[str, TypeMessageEntity] + self._list_counter = 0 # type: int + self._open_tags = deque() # type: Deque[str] + self._open_tags_meta = deque() # type: Deque[Any] + self._line_is_new = True # type: bool + self._list_entry_is_new = False # type: bool def _parse_url(self, url: str, args: Dict[str, Any] ) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]: - mention = self.mention_regex.match(url) + mention = self.mention_regex.match(url) # type: Match if mention: mxid = mention.group(1) user = (pu.Puppet.get_by_mxid(mxid) @@ -72,7 +75,7 @@ class MatrixParser(HTMLParser): else: return None, None - room = self.room_regex.match(url) + room = self.room_regex.match(url) # type: Match if room: username = po.Portal.get_username_from_mx_alias(room.group(1)) portal = po.Portal.find_by_username(username) @@ -92,8 +95,8 @@ class MatrixParser(HTMLParser): self._open_tags_meta.appendleft(0) attrs = dict(attrs) - entity_type = None - args = {} + entity_type = None # type: type(TypeMessageEntity) + args = {} # type: Dict[str, Any] if tag in ("strong", "b"): entity_type = MessageEntityBold elif tag in ("em", "i"): @@ -243,12 +246,12 @@ class MatrixParser(HTMLParser): self._newline(allow_multi=tag == "br") -command_regex = re.compile(r"^!([A-Za-z0-9@]+)") -not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)") -plain_mention_regex = None +command_regex = re.compile(r"^!([A-Za-z0-9@]+)") # type: Pattern +not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)") # type: Pattern +plain_mention_regex = None # type: Pattern -def plain_mention_to_html(match): +def plain_mention_to_html(match: Match) -> str: puppet = pu.Puppet.find_by_displayname(match.group(2)) if puppet: return (f"{match.group(1)}" @@ -351,7 +354,7 @@ def plain_mention_to_text() -> Tuple[List[TypeMessageEntity], Callable[[str], st return entities, replacer -def init_mx(context: c.Context): +def init_mx(context: "Context"): global plain_mention_regex, should_bridge_plaintext_highlights config = context.config dn_template = config.get("bridge.displayname_template", "{displayname} (Telegram)") diff --git a/mautrix_telegram/formatter/from_telegram.py b/mautrix_telegram/formatter/from_telegram.py index 70a13a55..33f8a335 100644 --- a/mautrix_telegram/formatter/from_telegram.py +++ b/mautrix_telegram/formatter/from_telegram.py @@ -14,13 +14,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Optional, List, Tuple, TYPE_CHECKING from html import escape -from typing import Optional, List, Tuple - -try: - from lxml.html.diff import htmldiff -except ImportError: - htmldiff = None # type: function import logging import re @@ -33,16 +28,26 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, from mautrix_appservice import MatrixRequestError from mautrix_appservice.intent_api import IntentAPI -from .. import user as u, puppet as pu, portal as po, context as c +from .. import user as u, puppet as pu, portal as po from ..db import Message as DBMessage from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html, trim_reply_fallback_text, unicode_to_html) -log = logging.getLogger("mau.fmt.tg") -should_highlight_edits = False +if TYPE_CHECKING: + from ..abstract_user import AbstractUser + from ..context import Context + +try: + from lxml.html.diff import htmldiff +except ImportError: + htmldiff = None # type: function -def telegram_reply_to_matrix(evt: Message, source: u.User) -> dict: +log = logging.getLogger("mau.fmt.tg") # type: logging.Logger +should_highlight_edits = False # type: bool + + +def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict: if evt.reply_to_msg_id: space = (evt.to_id.channel_id if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) @@ -78,7 +83,7 @@ async def _add_forward_header(source, text: str, html: Optional[str], if not fwd_from_text: user = await source.client.get_entity(PeerUser(fwd_from.from_id)) if user: - fwd_from_text = pu.Puppet.get_displayname(user, format=False) + fwd_from_text = pu.Puppet.get_displayname(user, False) fwd_from_html = f"{fwd_from_text}" if not fwd_from_text: @@ -110,8 +115,9 @@ def highlight_edits(new_html: str, old_html: str) -> str: return new_html -async def _add_reply_header(source: u.User, text: str, html: str, evt: Message, relates_to: dict, - main_intent: IntentAPI, is_edit: bool) -> Tuple[str, str]: +async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message, + relates_to: dict, main_intent: IntentAPI, is_edit: bool + ) -> Tuple[str, str]: space = (evt.to_id.channel_id if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) else source.tgid) @@ -142,7 +148,7 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message, if is_edit and should_highlight_edits: html = highlight_edits(html or escape(text), r_html_body) - except (ValueError, KeyError, MatrixRequestError) as e: + except (ValueError, KeyError, MatrixRequestError): r_sender_link = "unknown user" r_displayname = "unknown user" r_text_body = "Failed to fetch message" @@ -154,8 +160,9 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message, r_keyword = "In reply to" if not is_edit else "Edit to" r_msg_link = f"{r_keyword}" - html = (f"
{r_msg_link} {r_sender_link}\n{r_html_body}
" - + (html or escape(text))) + html = ( + f"
{r_msg_link} {r_sender_link}\n{r_html_body}
" + + (html or escape(text))) lines = r_text_body.strip().split("\n") text_with_quote = f"> <{r_displayname}> {lines.pop(0)}" @@ -167,7 +174,8 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message, return text_with_quote, html -async def telegram_to_matrix(evt: Message, source: u.User, main_intent: Optional[IntentAPI] = None, +async def telegram_to_matrix(evt: Message, source: "AbstractUser", + main_intent: Optional[IntentAPI] = None, is_edit: bool = False, prefix_text: Optional[str] = None, prefix_html: Optional[str] = None) -> Tuple[str, str, dict]: text = add_surrogates(evt.message) @@ -320,6 +328,6 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool: return False -def init_tg(context: c.Context): +def init_tg(context: "Context"): global should_highlight_edits should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"] diff --git a/mautrix_telegram/formatter/util.py b/mautrix_telegram/formatter/util.py index f464ffe5..2a296614 100644 --- a/mautrix_telegram/formatter/util.py +++ b/mautrix_telegram/formatter/util.py @@ -14,8 +14,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Optional, Pattern from html import escape -from typing import Optional import struct import re @@ -47,7 +47,7 @@ def trim_reply_fallback_text(text: str) -> str: html_reply_fallback_regex = re.compile("^" r"[\s\S]+?" - "") + "") # type: Pattern def trim_reply_fallback_html(html: str) -> str: diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py index 28cf6796..8feed9f4 100644 --- a/mautrix_telegram/matrix.py +++ b/mautrix_telegram/matrix.py @@ -14,26 +14,23 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Dict +from typing import List, Dict, Tuple, Set, Match import logging import asyncio import re from mautrix_appservice import MatrixRequestError, IntentError -from .user import User -from .portal import Portal -from .puppet import Puppet -from .commands import CommandProcessor +from . import user as u, portal as po, puppet as pu, commands as com class MatrixHandler: - log = logging.getLogger("mau.mx") + log = logging.getLogger("mau.mx") # type: logging.Logger def __init__(self, context): self.az, self.db, self.config, _, self.tgbot = context - self.commands = CommandProcessor(context) - self.previously_typing = [] + self.commands = com.CommandProcessor(context) # type: com.CommandProcessor + self.previously_typing = [] # type: List[str] self.az.matrix_event_handler(self.handle_event) @@ -53,68 +50,68 @@ class MatrixHandler: except asyncio.TimeoutError: self.log.exception("TimeoutError when trying to set avatar") - async def handle_puppet_invite(self, room, puppet, inviter): + async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User): intent = puppet.default_mxid_intent - self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room}") + self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}") if not await inviter.is_logged_in(): await intent.error_and_leave( - room, text="Please log in before inviting Telegram puppets.") + room_id, text="Please log in before inviting Telegram puppets.") return - portal = Portal.get_by_mxid(room) + portal = po.Portal.get_by_mxid(room_id) if portal: if portal.peer_type == "user": await intent.error_and_leave( - room, text="You can not invite additional users to private chats.") + room_id, text="You can not invite additional users to private chats.") return await portal.invite_telegram(inviter, puppet) - await intent.join_room(room) + await intent.join_room(room_id) return try: - members = await self.az.intent.get_room_members(room) + members = await self.az.intent.get_room_members(room_id) except MatrixRequestError: members = [] if self.az.bot_mxid not in members: if len(members) > 1: - await intent.error_and_leave(room, text=None, html=( + await intent.error_and_leave(room_id, text=None, html=( f"Please invite " f"the bridge bot " f"first if you want to create a Telegram chat.")) return - await intent.join_room(room) - portal = Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user") + await intent.join_room(room_id) + portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user") if portal.mxid: try: await intent.invite(portal.mxid, inviter.mxid) - await intent.send_notice(room, text=None, html=( + await intent.send_notice(room_id, text=None, html=( "You already have a private chat with me: " f"" "Link to room" "")) - await intent.leave_room(room) + await intent.leave_room(room_id) return except MatrixRequestError: pass - portal.mxid = room + portal.mxid = room_id portal.save() inviter.register_portal(portal) - await intent.send_notice(room, "Portal to private chat created.") + await intent.send_notice(room_id, "po.Portal to private chat created.") else: - await intent.join_room(room) - await intent.send_notice(room, "This puppet will remain inactive until a " - "Telegram chat is created for this room.") + await intent.join_room(room_id) + await intent.send_notice(room_id, "This puppet will remain inactive until a " + "Telegram chat is created for this room.") - async def accept_bot_invite(self, room, inviter): + async def accept_bot_invite(self, room_id: str, inviter: u.User): tries = 0 while tries < 5: try: - await self.az.intent.join_room(room) + await self.az.intent.join_room(room_id) break - except (IntentError, MatrixRequestError) as e: + except (IntentError, MatrixRequestError): tries += 1 wait_for_seconds = (tries + 1) * 10 if tries < 5: - self.log.exception(f"Failed to join room {room} with bridge bot, " + self.log.exception(f"Failed to join room {room_id} with bridge bot, " f"retrying in {wait_for_seconds} seconds...") await asyncio.sleep(wait_for_seconds) else: @@ -123,81 +120,81 @@ class MatrixHandler: if not inviter.whitelisted: await self.az.intent.send_notice( - room, text=None, + room_id, text=None, html="You are not whitelisted to use this bridge.

" "If you are the owner of this bridge, see the " "bridge.permissions section in your config file.") - await self.az.intent.leave_room(room) + await self.az.intent.leave_room(room_id) - async def handle_invite(self, room, user, inviter): - self.log.debug(f"{inviter} invited {user} to {room}") - inviter = await User.get_by_mxid(inviter).ensure_started() - if user == self.az.bot_mxid: - return await self.accept_bot_invite(room, inviter) + async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str): + self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}") + inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started() + if user_id == self.az.bot_mxid: + return await self.accept_bot_invite(room_id, inviter) elif not inviter.whitelisted: return - puppet = Puppet.get_by_mxid(user) + puppet = pu.Puppet.get_by_mxid(user_id) if puppet: - await self.handle_puppet_invite(room, puppet, inviter) + await self.handle_puppet_invite(room_id, puppet, inviter) return - user = User.get_by_mxid(user, create=False) + user = u.User.get_by_mxid(user_id, create=False) if not user: return await user.ensure_started() - portal = Portal.get_by_mxid(room) + portal = po.Portal.get_by_mxid(room_id) if user and await user.has_full_access(allow_bot=True) and portal: await portal.invite_telegram(inviter, user) return # The rest can probably be ignored - async def handle_join(self, room, user, event_id): - user = await User.get_by_mxid(user).ensure_started() + async def handle_join(self, room_id: str, user_id: str, event_id: str): + user = await u.User.get_by_mxid(user_id).ensure_started() - portal = Portal.get_by_mxid(room) + portal = po.Portal.get_by_mxid(room_id) if not portal: return if not user.relaybot_whitelisted: - await portal.main_intent.kick(room, user.mxid, + await portal.main_intent.kick(room_id, user.mxid, "You are not whitelisted on this Telegram bridge.") return elif not await user.is_logged_in() and not portal.has_bot: - await portal.main_intent.kick(room, user.mxid, + await portal.main_intent.kick(room_id, user.mxid, "This chat does not have a bot relaying " "messages for unauthenticated users.") return - self.log.debug(f"{user} joined {room}") + self.log.debug(f"{user} joined {room_id}") if await user.is_logged_in() or portal.has_bot: await portal.join_matrix(user, event_id) - async def handle_part(self, room, user, sender, event_id): - self.log.debug(f"{user} left {room}") + async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str): + self.log.debug(f"{user_id} left {room_id}") - sender = User.get_by_mxid(sender, create=False) + sender = u.User.get_by_mxid(sender_mxid, create=False) if not sender: return await sender.ensure_started() - portal = Portal.get_by_mxid(room) + portal = po.Portal.get_by_mxid(room_id) if not portal: return - puppet = Puppet.get_by_mxid(user) + puppet = pu.Puppet.get_by_mxid(user_id) if sender and puppet: await portal.leave_matrix(puppet, sender, event_id) - user = User.get_by_mxid(user, create=False) + user = u.User.get_by_mxid(user_id, create=False) if not user: return await user.ensure_started() if await user.is_logged_in() or portal.has_bot: await portal.leave_matrix(user, sender, event_id) - def is_command(self, message): + def is_command(self, message: dict) -> Tuple[bool, str]: text = message.get("body", "") prefix = self.config["bridge.command_prefix"] is_command = text.startswith(prefix) @@ -207,14 +204,14 @@ class MatrixHandler: async def handle_message(self, room, sender, message, event_id): is_command, text = self.is_command(message) - sender = await User.get_by_mxid(sender).ensure_started() + sender = await u.User.get_by_mxid(sender).ensure_started() if not sender.relaybot_whitelisted: self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:" - " User is not whitelisted.") + " u.User is not whitelisted.") return self.log.debug(f"Received Matrix event \"{message}\" from {sender} in {room}") - portal = Portal.get_by_mxid(room) + portal = po.Portal.get_by_mxid(room) if not is_command and portal and (await sender.is_logged_in() or portal.has_bot): await portal.handle_matrix_message(sender, message, event_id) return @@ -239,39 +236,44 @@ class MatrixHandler: await self.commands.handle(room, sender, command, args, is_management, is_portal=portal is not None) - async def handle_redaction(self, room, sender, event_id): - sender = await User.get_by_mxid(sender).ensure_started() + @staticmethod + async def handle_redaction(room_id: str, sender_mxid: str, event_id: str): + sender = await u.User.get_by_mxid(sender_mxid).ensure_started() if not sender.relaybot_whitelisted: return - portal = Portal.get_by_mxid(room) + portal = po.Portal.get_by_mxid(room_id) if not portal: return await portal.handle_matrix_deletion(sender, event_id) - async def handle_power_levels(self, room, sender, new, old): - portal = Portal.get_by_mxid(room) - sender = await User.get_by_mxid(sender).ensure_started() + @staticmethod + async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict): + portal = po.Portal.get_by_mxid(room_id) + sender = await u.User.get_by_mxid(sender_mxid).ensure_started() if await sender.has_full_access(allow_bot=True) and portal: await portal.handle_matrix_power_levels(sender, new["users"], old["users"]) - async def handle_room_meta(self, type, room, sender, content): - portal = Portal.get_by_mxid(room) - sender = await User.get_by_mxid(sender).ensure_started() + @staticmethod + async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, content: dict): + portal = po.Portal.get_by_mxid(room_id) + sender = await u.User.get_by_mxid(sender_mxid).ensure_started() if await sender.has_full_access(allow_bot=True) and portal: handler, content_key = { "m.room.name": (portal.handle_matrix_title, "name"), "m.room.topic": (portal.handle_matrix_about, "topic"), "m.room.avatar": (portal.handle_matrix_avatar, "url"), - }[type] + }[evt_type] if content_key not in content: return await handler(sender, content[content_key]) - async def handle_room_pin(self, room, sender, new_events, old_events): - portal = Portal.get_by_mxid(room) - sender = await User.get_by_mxid(sender).ensure_started() + @staticmethod + async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str], + old_events: Set[str]): + portal = po.Portal.get_by_mxid(room_id) + sender = await u.User.get_by_mxid(sender_mxid).ensure_started() if await sender.has_full_access(allow_bot=True) and portal: events = new_events - old_events if len(events) > 0: @@ -281,12 +283,14 @@ class MatrixHandler: # All pinned events removed, remove pinned event in Telegram. await portal.handle_matrix_pin(sender, None) - async def handle_name_change(self, room, user, displayname, prev_displayname, event_id): - portal = Portal.get_by_mxid(room) + @staticmethod + async def handle_name_change(room_id: str, user_id: str, displayname: str, + prev_displayname: str, event_id: str): + portal = po.Portal.get_by_mxid(room_id) if not portal or not portal.has_bot: return - user = await User.get_by_mxid(user).ensure_started() + user = await u.User.get_by_mxid(user_id).ensure_started() if await user.needs_relaybot(portal): await portal.name_change_matrix(user, displayname, prev_displayname, event_id) @@ -296,25 +300,27 @@ class MatrixHandler: for event_id, receipts in content.items() for user_id in receipts.get("m.read", {})} - async def handle_read_receipts(self, room_id: str, receipts: Dict[str, str]): - portal = Portal.get_by_mxid(room_id) + @staticmethod + async def handle_read_receipts(room_id: str, receipts: Dict[str, str]): + portal = po.Portal.get_by_mxid(room_id) if not portal: return for user_id, event_id in receipts.items(): - user = await User.get_by_mxid(user_id).ensure_started() + user = await u.User.get_by_mxid(user_id).ensure_started() if not await user.is_logged_in(): continue await portal.mark_read(user, event_id) - async def handle_presence(self, user: str, presence: str): - user = await User.get_by_mxid(user).ensure_started() + @staticmethod + async def handle_presence(user_id: str, presence: str): + user = await u.User.get_by_mxid(user_id).ensure_started() if not await user.is_logged_in(): return await user.set_presence(presence == "online") async def handle_typing(self, room_id: str, now_typing: List[str]): - portal = Portal.get_by_mxid(room_id) + portal = po.Portal.get_by_mxid(room_id) if not portal: return @@ -324,7 +330,7 @@ class MatrixHandler: if is_typing and was_typing: continue - user = await User.get_by_mxid(user_id).ensure_started() + user = await u.User.get_by_mxid(user_id).ensure_started() if not await user.is_logged_in(): continue @@ -332,38 +338,38 @@ class MatrixHandler: self.previously_typing = now_typing - def filter_matrix_event(self, event): + def filter_matrix_event(self, event: dict): sender = event.get("sender", None) if not sender: return False return (sender == self.az.bot_mxid - or Puppet.get_id_from_mxid(sender) is not None) + or pu.Puppet.get_id_from_mxid(sender) is not None) - async def try_handle_event(self, evt): + async def try_handle_event(self, evt: dict): try: await self.handle_event(evt) except Exception: self.log.exception("Error handling manually received Matrix event") - async def handle_event(self, evt): + async def handle_event(self, evt: dict): if self.filter_matrix_event(evt): return self.log.debug("Received event: %s", evt) - type = evt.get("type", "m.unknown") - room_id = evt.get("room_id", None) - event_id = evt.get("event_id", None) - sender = evt.get("sender", None) - content = evt.get("content", {}) - if type == "m.room.member": - state_key = evt["state_key"] - prev_content = evt.get("unsigned", {}).get("prev_content", {}) - membership = content.get("membership", "") - prev_membership = prev_content.get("membership", "leave") + evt_type = evt.get("type", "m.unknown") # type: str + room_id = evt.get("room_id", None) # type: str + event_id = evt.get("event_id", None) # type: str + sender = evt.get("sender", None) # type: str + content = evt.get("content", {}) # type: dict + if evt_type == "m.room.member": + state_key = evt["state_key"] # type: str + prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict + membership = content.get("membership", "") # type: str + prev_membership = prev_content.get("membership", "leave") # type: str if membership == prev_membership: - match = re.compile("@(.+):(.+)").match(state_key) - localpart = match.group(1) - displayname = content.get("displayname", localpart) - prev_displayname = prev_content.get("displayname", localpart) + match = re.compile("@(.+):(.+)").match(state_key) # type: Match + localpart = match.group(1) # type: str + displayname = content.get("displayname", localpart) # type: str + prev_displayname = prev_content.get("displayname", localpart) # type: str if displayname != prev_displayname: await self.handle_name_change(room_id, state_key, displayname, prev_displayname, event_id) @@ -373,26 +379,26 @@ class MatrixHandler: await self.handle_part(room_id, state_key, sender, event_id) elif membership == "join": await self.handle_join(room_id, state_key, event_id) - elif type in ("m.room.message", "m.sticker"): - if type != "m.room.message": - content["msgtype"] = type + elif evt_type in ("m.room.message", "m.sticker"): + if evt_type != "m.room.message": + content["msgtype"] = evt_type await self.handle_message(room_id, sender, content, event_id) - elif type == "m.room.redaction": + elif evt_type == "m.room.redaction": await self.handle_redaction(room_id, sender, evt["redacts"]) - elif type == "m.room.power_levels": + elif evt_type == "m.room.power_levels": await self.handle_power_levels(room_id, sender, evt["content"], evt["prev_content"]) - elif type in ("m.room.name", "m.room.avatar", "m.room.topic"): - await self.handle_room_meta(type, room_id, sender, evt["content"]) - elif type == "m.room.pinned_events": + elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"): + await self.handle_room_meta(evt_type, room_id, sender, evt["content"]) + elif evt_type == "m.room.pinned_events": new_events = set(evt["content"]["pinned"]) try: old_events = set(evt["unsigned"]["prev_content"]["pinned"]) except KeyError: old_events = set() await self.handle_room_pin(room_id, sender, new_events, old_events) - elif type == "m.receipt": + elif evt_type == "m.receipt": await self.handle_read_receipts(room_id, self.parse_read_receipts(content)) - elif type == "m.presence": + elif evt_type == "m.presence": await self.handle_presence(sender, content.get("presence", "offline")) - elif type == "m.typing": + elif evt_type == "m.typing": await self.handle_typing(room_id, content.get("user_ids", [])) diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index 9331c9bc..a4f80776 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -14,6 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Pattern, Dict, Tuple, Awaitable, TYPE_CHECKING from collections import deque from datetime import datetime from string import Template @@ -24,65 +25,82 @@ import mimetypes import unicodedata import hashlib import logging +import re import magic +from sqlalchemy import orm from sqlalchemy.exc import IntegrityError, InvalidRequestError from sqlalchemy.orm.exc import FlushError from telethon.tl.functions.messages import * from telethon.tl.functions.channels import * -from telethon.tl.functions.messages import ReadHistoryRequest +from telethon.tl.functions.messages import ReadHistoryRequest as ReadMessageHistoryRequest from telethon.tl.functions.channels import ReadHistoryRequest as ReadChannelHistoryRequest -from telethon.errors import * +from telethon.errors import ChatAdminRequiredError, ChatNotModifiedError from telethon.tl.types import * -from mautrix_appservice import MatrixRequestError, IntentError +from mautrix_appservice import MatrixRequestError, IntentError, AppService, IntentAPI -from .db import Portal as DBPortal, Message as DBMessage +from .context import Context +from .db import Portal as DBPortal, Message as DBMessage, TelegramFile as DBTelegramFile from . import puppet as p, user as u, formatter, util +if TYPE_CHECKING: + from .bot import Bot + from .abstract_user import AbstractUser + from .config import Config + from .tgclient import MautrixTelegramClient + mimetypes.init() -config = None +config = None # type: Config + +TypeMessage = Union[Message, MessageService] +TypeParticipant = Union[TypeChatParticipant, TypeChannelParticipant] +DedupMXID = Tuple[str, int] +InviteList = Union[str, List[str]] class Portal: - log = logging.getLogger("mau.portal") - db = None - az = None - bot = None - loop = None - filter_mode = None - filter_list = None - bridge_notices = False - alias_template = None - mx_alias_regex = None - hs_domain = None - by_mxid = {} - by_tgid = {} + log = logging.getLogger("mau.portal") # type: logging.Logger + db = None # type: orm.Session + az = None # type: AppService + bot = None # type: Bot + loop = None # type: asyncio.AbstractEventLoop + filter_mode = None # type: str + filter_list = None # type: List[str] + bridge_notices = False # type: bool + alias_template = None # type: str + mx_alias_regex = None # type: Pattern + hs_domain = None # type: str + by_mxid = {} # type: Dict[str, Portal] + by_tgid = {} # type: Dict[Tuple[int, int], Portal] - def __init__(self, tgid, peer_type, tg_receiver=None, mxid=None, username=None, - megagroup=False, title=None, about=None, photo_id=None, db_instance=None): - self.mxid = mxid - self.tgid = tgid - self.tg_receiver = tg_receiver or tgid - self.peer_type = peer_type - self.username = username - self.megagroup = megagroup - self.title = title - self.about = about - self.photo_id = photo_id - self._db_instance = db_instance + def __init__(self, tgid: int, peer_type: str, tg_receiver: Optional[int] = None, + mxid: Optional[str] = None, username: Optional[str] = None, + megagroup: Optional[bool] = False, title: Optional[str] = None, + about: Optional[str] = None, photo_id: Optional[str] = None, + db_instance: DBPortal = None): + self.mxid = mxid # type: str + self.tgid = tgid # type: int + self.tg_receiver = tg_receiver or tgid # type: int + self.peer_type = peer_type # type: str + self.username = username # type: str + self.megagroup = megagroup # type: bool + self.title = title # type: str + self.about = about # type: str + self.photo_id = photo_id # type: str + self._db_instance = db_instance # type: DBPortal - self._main_intent = None - self._room_create_lock = asyncio.Lock() - self._temp_pinned_message_id = None - self._temp_pinned_message_sender = None + self._main_intent = None # type: IntentAPI + self._room_create_lock = asyncio.Lock() # type: asyncio.Lock + self._temp_pinned_message_id = None # type: Optional[int] + self._temp_pinned_message_sender = None # type: Optional[p.Puppet] - self._dedup = deque() - self._dedup_mxid = {} - self._dedup_action = deque() + self._dedup = deque() # type: deque + self._dedup_mxid = {} # type: Dict[str, DedupMXID] + self._dedup_action = deque() # type: deque - self._send_locks = {} + self._send_locks = {} # type: Dict[int, asyncio.Lock] if tgid: self.by_tgid[self.tgid_full] = self @@ -92,17 +110,17 @@ class Portal: # region Propegrties @property - def tgid_full(self): + def tgid_full(self) -> Tuple[int, int]: return self.tgid, self.tg_receiver @property - def tgid_log(self): + def tgid_log(self) -> str: if self.tgid == self.tg_receiver: - return self.tgid + return str(self.tgid) return f"{self.tg_receiver}<->{self.tgid}" @property - def peer(self): + def peer(self) -> TypePeer: if self.peer_type == "user": return PeerUser(user_id=self.tgid) elif self.peer_type == "chat": @@ -111,11 +129,11 @@ class Portal: return PeerChannel(channel_id=self.tgid) @property - def has_bot(self): + def has_bot(self) -> bool: return self.bot and self.bot.is_in_chat(self.tgid) @property - def main_intent(self): + def main_intent(self) -> IntentAPI: if not self._main_intent: direct = self.peer_type == "user" puppet = p.Puppet.get(self.tgid) if direct else None @@ -125,7 +143,7 @@ class Portal: # endregion # region Filtering - def allow_bridging(self, tgid=None): + def allow_bridging(self, tgid: Optional[int] = None) -> bool: tgid = tgid or self.tgid if self.peer_type == "user": return True @@ -139,7 +157,7 @@ class Portal: # region Deduplication @staticmethod - def _hash_event(event): + def _hash_event(event: TypeMessage) -> str: # Non-channel messages are unique per-user (wtf telegram), so we have no other choice than # to deduplicate based on a hash of the message content. @@ -165,48 +183,54 @@ class Portal: .encode("utf-8") ).hexdigest() - def is_duplicate_action(self, event): - hash = self._hash_event(event) if self.peer_type != "channel" else event.id - if hash in self._dedup_action: + def is_duplicate_action(self, event: TypeMessage) -> bool: + evt_hash = self._hash_event(event) if self.peer_type != "channel" else event.id + if evt_hash in self._dedup_action: return True - self._dedup_action.append(hash) + self._dedup_action.append(evt_hash) if len(self._dedup_action) > 20: self._dedup_action.popleft() return False - def update_duplicate(self, event, mxid=None, expected_mxid=None, force_hash=False): - hash = self._hash_event(event) if self.peer_type != "channel" or force_hash else event.id + def update_duplicate(self, event: TypeMessage, mxid: DedupMXID = None, + expected_mxid: Optional[DedupMXID] = None, force_hash: bool = False + ) -> Optional[DedupMXID]: + evt_hash = self._hash_event( + event) if self.peer_type != "channel" or force_hash else event.id try: - found_mxid = self._dedup_mxid[hash] + found_mxid = self._dedup_mxid[evt_hash] except KeyError: - return 0, "None" + return "None", 0 if found_mxid != expected_mxid: return found_mxid - self._dedup_mxid[hash] = mxid + self._dedup_mxid[evt_hash] = mxid return None - def is_duplicate(self, event, mxid=None, force_hash=False): - hash = self._hash_event(event) if self.peer_type != "channel" or force_hash else event.id - if hash in self._dedup: - return self._dedup_mxid[hash] + def is_duplicate(self, event: TypeMessage, mxid: DedupMXID = None, force_hash: bool = False + ) -> Optional[DedupMXID]: + evt_hash = (self._hash_event(event) + if self.peer_type != "channel" or force_hash + else event.id) + if evt_hash in self._dedup: + return self._dedup_mxid[evt_hash] - self._dedup_mxid[hash] = mxid - self._dedup.append(hash) + self._dedup_mxid[evt_hash] = mxid + self._dedup.append(evt_hash) if len(self._dedup) > 20: del self._dedup_mxid[self._dedup.popleft()] return None - def get_input_entity(self, user): + def get_input_entity(self, user: u.User) -> Awaitable[TypeInputPeer]: return user.client.get_input_entity(self.peer) # endregion # region Matrix room info updating - async def invite_to_matrix(self, users): + async def invite_to_matrix(self, users: InviteList): if isinstance(users, str): await self.main_intent.invite(self.mxid, users, check_cache=True) elif isinstance(users, list): @@ -215,8 +239,10 @@ class Portal: else: raise ValueError("Invalid invite identifier given to invite_matrix()") - async def update_matrix_room(self, user, entity, direct, puppet=None, - levels=None, users=None, participants=None): + async def update_matrix_room(self, user: "AbstractUser", entity: TypeChat, direct: bool, + puppet: p.Puppet = None, levels: dict = None, + users: List[User] = None, + participants: List[TypeParticipant] = None): if not direct: await self.update_info(user, entity) if not users or not participants: @@ -229,8 +255,9 @@ class Portal: await puppet.update_info(user, entity) await puppet.intent.join_room(self.mxid) - async def create_matrix_room(self, user, entity=None, invites=None, update_if_exists=True, - synchronous=False): + async def create_matrix_room(self, user: "AbstractUser", entity: TypeChat = None, + invites: InviteList = None, update_if_exists: bool = True, + synchronous: bool = False) -> Optional[str]: if self.mxid: if update_if_exists: if not entity: @@ -245,7 +272,8 @@ class Portal: async with self._room_create_lock: return await self._create_matrix_room(user, entity, invites) - async def _create_matrix_room(self, user, entity, invites): + async def _create_matrix_room(self, user: "AbstractUser", entity: TypeChat, invites: InviteList + ) -> Optional[str]: direct = self.peer_type == "user" if self.mxid: @@ -310,7 +338,7 @@ class Portal: participants=participants), loop=self.loop) - def _get_base_power_levels(self, levels=None, entity=None): + def _get_base_power_levels(self, levels: dict = None, entity: TypeChat = None) -> dict: levels = levels or {} power_level_requirement = (0 if self.peer_type == "chat" and not entity.admins_enabled else 50) @@ -336,27 +364,27 @@ class Portal: return levels @property - def alias(self): + def alias(self) -> Optional[str]: if not self.username: return None return f"#{self._get_alias_localpart()}:{self.hs_domain}" - def _get_alias_localpart(self, username=None): + def _get_alias_localpart(self, username: Optional[str] = None) -> Optional[str]: username = username or self.username if not username: return None return self.alias_template.format(groupname=username) - def add_bot_chat(self, entity): - if self.bot and entity.id == self.bot.tgid: + def add_bot_chat(self, bot: User): + if self.bot and bot.id == self.bot.tgid: self.bot.add_chat(self.tgid, self.peer_type) return - user = u.User.get_by_tgid(entity.id) + user = u.User.get_by_tgid(bot.id) if user and user.is_bot: user.register_portal(self) - async def sync_telegram_users(self, source, users): + async def sync_telegram_users(self, source: "AbstractUser", users: List[User]): allowed_tgids = set() for entity in users: puppet = p.Puppet.get(entity.id) @@ -398,7 +426,7 @@ class Portal: "You had left this Telegram chat.") continue - async def add_telegram_user(self, user_id, source=None): + async def add_telegram_user(self, user_id: int, source: Optional["AbstractUser"] = None): puppet = p.Puppet.get(user_id) if source: entity = await source.client.get_entity(PeerUser(user_id)) @@ -410,7 +438,7 @@ class Portal: user.register_portal(self) await self.invite_to_matrix(user.mxid) - async def delete_telegram_user(self, user_id, sender): + async def delete_telegram_user(self, user_id: int, sender: p.Puppet): puppet = p.Puppet.get(user_id) user = u.User.get_by_tgid(user_id) kick_message = (f"Kicked by {sender.displayname}" @@ -424,7 +452,7 @@ class Portal: user.unregister_portal(self) await self.main_intent.kick(self.mxid, user.mxid, kick_message) - async def update_info(self, user, entity=None): + async def update_info(self, user: "AbstractUser", entity: TypeChat = None): if self.peer_type == "user": self.log.warning(f"Called update_info() for direct chat portal {self.tgid_log}") return @@ -448,7 +476,7 @@ class Portal: if changed: self.save() - async def update_username(self, username, save=False): + async def update_username(self, username: str, save: bool = False) -> bool: if self.username != username: if self.username: await self.main_intent.remove_room_alias(self._get_alias_localpart()) @@ -465,7 +493,7 @@ class Portal: return True return False - async def update_about(self, about, save=False): + async def update_about(self, about: str, save: bool = False) -> bool: if self.about != about: self.about = about await self.main_intent.set_room_topic(self.mxid, self.about) @@ -474,7 +502,7 @@ class Portal: return True return False - async def update_title(self, title, save=False): + async def update_title(self, title: str, save: bool = False) -> bool: if self.title != title: self.title = title await self.main_intent.set_room_name(self.mxid, self.title) @@ -484,17 +512,18 @@ class Portal: return False @staticmethod - def _get_largest_photo_size(photo): + def _get_largest_photo_size(photo: Photo) -> TypePhotoSize: return max(photo.sizes, key=(lambda photo2: ( len(photo2.bytes) if isinstance(photo2, PhotoCachedSize) else photo2.size))) - async def remove_avatar(self, user, save=False): + async def remove_avatar(self, _: "AbstractUser", save: bool = False): await self.main_intent.set_room_avatar(self.mxid, None) self.photo_id = None if save: self.save() - async def update_avatar(self, user, photo, save=False): + async def update_avatar(self, user: "AbstractUser", photo: FileLocation, + save: bool = False) -> bool: photo_id = f"{photo.volume_id}-{photo.local_id}" if self.photo_id != photo_id: file = await util.transfer_file_to_matrix(self.db, user.client, self.main_intent, @@ -507,7 +536,9 @@ class Portal: return True return False - async def _get_users(self, user, entity): + async def _get_users(self, user: "AbstractUser", entity: Union[TypeInputPeer, InputUser, + TypeChat, TypeUser] + ) -> Tuple[List[TypeUser], List[TypeParticipant]]: if self.peer_type == "chat": chat = await user.client(GetFullChatRequest(chat_id=self.tgid)) return chat.users, chat.full_chat.participants.participants @@ -544,7 +575,7 @@ class Portal: elif self.peer_type == "user": return [entity], [] - async def get_invite_link(self, user): + async def get_invite_link(self, user: u.User) -> str: if self.peer_type == "user": raise ValueError("You can't invite users to private chats.") elif self.peer_type == "chat": @@ -562,7 +593,7 @@ class Portal: return link.link - async def get_authenticated_matrix_users(self): + async def get_authenticated_matrix_users(self) -> List[u.User]: try: members = await self.main_intent.get_room_members(self.mxid) except MatrixRequestError: @@ -573,13 +604,14 @@ class Portal: if p.Puppet.get_id_from_mxid(member) or member == self.main_intent.mxid: continue user = await u.User.get_by_mxid(member).ensure_started() - if (has_bot and user.relaybot_whitelisted) or await user.has_full_access( - allow_bot=True): + authenticated_through_bot = has_bot and user.relaybot_whitelisted + if authenticated_through_bot or await user.has_full_access(allow_bot=True): authenticated.append(user) return authenticated @staticmethod - async def cleanup_room(intent, room_id, message="Portal deleted", puppets_only=False): + async def cleanup_room(intent: IntentAPI, room_id: str, message: str = "Portal deleted", + puppets_only: bool = False): try: members = await intent.get_room_members(room_id) except MatrixRequestError: @@ -608,7 +640,7 @@ class Portal: # region Matrix event handling @staticmethod - def _get_file_meta(body, mime): + def _get_file_meta(body: str, mime: str) -> str: try: current_extension = body[body.rindex("."):] if mimetypes.types_map[current_extension] == mime: @@ -620,7 +652,8 @@ class Portal: else: return "" - async def _get_state_change_message(self, event, user, arguments=None): + async def _get_state_change_message(self, event: str, user: u.User, + arguments: Optional[dict] = None) -> Optional[dict]: tpl = config[f"bridge.state_event_formats.{event}"] if len(tpl) == 0: # Empty format means they don't want the message @@ -637,7 +670,8 @@ class Portal: "formatted_body": message, } - async def name_change_matrix(self, user, displayname, prev_displayname, event_id): + async def name_change_matrix(self, user: u.User, displayname: str, prev_displayname: str, + event_id: str): async with self.require_send_lock(self.bot.tgid): message = await self._get_state_change_message( "name_change", user, @@ -650,15 +684,15 @@ class Portal: space = self.tgid if self.peer_type == "channel" else self.bot.tgid self.is_duplicate(response, (event_id, space)) - async def get_displayname(self, user): + async def get_displayname(self, user: u.User) -> str: return (await self.main_intent.get_displayname(self.mxid, user.mxid) or user.mxid_localpart) - def set_typing(self, user, typing=True, action=SendMessageTypingAction): - return user.client( - SetTypingRequest(self.peer, action() if typing else SendMessageCancelAction())) + def set_typing(self, user: u.User, typing: bool = True, action=SendMessageTypingAction): + return user.client(SetTypingRequest( + self.peer, action() if typing else SendMessageCancelAction())) - async def mark_read(self, user, event_id): + async def mark_read(self, user: u.User, event_id: str): if user.is_bot: return space = self.tgid if self.peer_type == "channel" else user.tgid @@ -671,9 +705,9 @@ class Portal: await user.client(ReadChannelHistoryRequest( channel=await self.get_input_entity(user), max_id=message.tgid)) else: - await user.client(ReadHistoryRequest(peer=self.peer, max_id=message.tgid)) + await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid)) - async def leave_matrix(self, user, source, event_id): + async def leave_matrix(self, user: u.User, source: u.User, event_id: str): if await user.needs_relaybot(self): async with self.require_send_lock(self.bot.tgid): message = await self._get_state_change_message("leave", user) @@ -709,7 +743,7 @@ class Portal: channel = await self.get_input_entity(user) await user.client(LeaveChannelRequest(channel=channel)) - async def join_matrix(self, user, event_id): + async def join_matrix(self, user: u.User, event_id: str): if await user.needs_relaybot(self): async with self.require_send_lock(self.bot.tgid): message = await self._get_state_change_message("join", user) @@ -728,7 +762,7 @@ class Portal: # We'll just assume the user is already in the chat. pass - async def _apply_msg_format(self, sender, msgtype, message): + async def _apply_msg_format(self, sender: u.User, msgtype: str, message: dict): if "formatted_body" not in message: message["format"] = "org.matrix.custom.html" message["formatted_body"] = escape_html(message.get("body", "")) @@ -743,7 +777,7 @@ class Portal: message=body) message["formatted_body"] = Template(tpl).safe_substitute(tpl_args) - async def _preprocess_matrix_message(self, sender, use_relaybot, message): + async def _pre_process_matrix_message(self, sender: u.User, use_relaybot: bool, message: dict): msgtype = message.get("msgtype", "m.text") if msgtype == "m.emote": await self._apply_msg_format(sender, msgtype, message) @@ -751,7 +785,8 @@ class Portal: elif use_relaybot: await self._apply_msg_format(sender, msgtype, message) - def _matrix_event_to_entities(self, event): + @staticmethod + def _matrix_event_to_entities(event: dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]: try: if event.get("format", None) == "org.matrix.custom.html": message, entities = formatter.matrix_to_telegram(event["formatted_body"]) @@ -761,32 +796,33 @@ class Portal: message, entities = None, None return message, entities - def require_send_lock(self, id): - if id is None: - return None + def require_send_lock(self, user_id: int) -> asyncio.Lock: + if user_id is None: + raise ValueError("Required send lock for none id") try: - return self._send_locks[id] + return self._send_locks[user_id] except KeyError: - self._send_locks[id] = asyncio.Lock() - return self._send_locks[id] + self._send_locks[user_id] = asyncio.Lock() + return self._send_locks[user_id] - def optional_send_lock(self, id): - if id is None: + def optional_send_lock(self, user_id: int) -> Optional[asyncio.Lock]: + if user_id is None: return None try: - return self._send_locks[id] + return self._send_locks[user_id] except KeyError: return None - async def _handle_matrix_text(self, sender_id, event_id, space, client, message, reply_to): + async def _handle_matrix_text(self, sender_id: int, event_id: str, space: int, + client: "MautrixTelegramClient", message: dict, reply_to: int): lock = self.require_send_lock(sender_id) async with lock: response = await client.send_message(self.peer, message, reply_to=reply_to, parse_mode=self._matrix_event_to_entities) self._add_telegram_message_to_db(event_id, space, response) - async def _handle_matrix_file(self, type, sender_id, event_id, space, client, message, - reply_to): + async def _handle_matrix_file(self, msgtype: str, sender_id: int, event_id: str, space: int, + client: "MautrixTelegramClient", message: dict, reply_to: int): file = await self.main_intent.download_file(message["url"]) info = message.get("info", {}) @@ -794,7 +830,7 @@ class Portal: w, h = None, None - if type == "m.sticker": + if msgtype == "m.sticker": if mime != "image/gif": mime, file, w, h = util.convert_image(file, source_mime=mime, target_type="webp") else: @@ -812,14 +848,16 @@ class Portal: caption = message["body"] if message["body"] != file_name else None - media = await client.upload_file(file, mime, attributes, file_name) + media = await client.upload_file_direct(file, mime, attributes, file_name) lock = self.require_send_lock(sender_id) async with lock: response = await client.send_media(self.peer, media, reply_to=reply_to, caption=caption) self._add_telegram_message_to_db(event_id, space, response) - async def _handle_matrix_location(self, sender_id, event_id, space, client, message, reply_to): + async def _handle_matrix_location(self, sender_id: int, event_id: str, space: int, + client: "MautrixTelegramClient", message: dict, + reply_to: int): try: lat, long = message["geo_uri"][len("geo:"):].split(",") lat, long = float(lat), float(long) @@ -827,7 +865,7 @@ class Portal: self.log.exception("Failed to parse location") return None message, entities = self._matrix_event_to_entities(message) - media = MessageMediaGeo(geo=GeoPoint(lat, long)) + media = MessageMediaGeo(geo=GeoPoint(lat, long, access_hash=0)) lock = self.require_send_lock(sender_id) async with lock: @@ -835,7 +873,7 @@ class Portal: caption=message, entities=entities) self._add_telegram_message_to_db(event_id, space, response) - def _add_telegram_message_to_db(self, event_id, space, response): + def _add_telegram_message_to_db(self, event_id: str, space: int, response: TypeMessage): self.log.debug("Handled Matrix message: %s", response) self.is_duplicate(response, (event_id, space)) self.db.add(DBMessage( @@ -859,21 +897,21 @@ class Portal: reply_to = formatter.matrix_reply_to_telegram(message, space, room_id=self.mxid) message["mxtg_filename"] = message["body"] - await self._preprocess_matrix_message(sender, not logged_in, message) - type = message["msgtype"] + await self._pre_process_matrix_message(sender, not logged_in, message) + msgtype = message["msgtype"] - if type == "m.text" or (self.bridge_notices and type == "m.notice"): + if msgtype == "m.text" or (self.bridge_notices and msgtype == "m.notice"): await self._handle_matrix_text(sender_id, event_id, space, client, message, reply_to) - elif type == "m.location": + elif msgtype == "m.location": await self._handle_matrix_location(sender_id, event_id, space, client, message, reply_to) - elif type in ("m.sticker", "m.image", "m.file", "m.audio", "m.video"): - await self._handle_matrix_file(type, sender_id, event_id, space, client, message, + elif msgtype in ("m.sticker", "m.image", "m.file", "m.audio", "m.video"): + await self._handle_matrix_file(msgtype, sender_id, event_id, space, client, message, reply_to) else: self.log.debug(f"Unhandled Matrix event: {message}") - async def handle_matrix_pin(self, sender, pinned_message): + async def handle_matrix_pin(self, sender: u.User, pinned_message: Optional[str]): if self.peer_type != "channel": return try: @@ -887,7 +925,7 @@ class Portal: except ChatNotModifiedError: pass - async def handle_matrix_deletion(self, deleter, event_id): + async def handle_matrix_deletion(self, deleter: u.User, event_id: str): deleter = deleter if not await deleter.needs_relaybot(self) else self.bot space = self.tgid if self.peer_type == "channel" else deleter.tgid message = DBMessage.query.filter(DBMessage.mxid == event_id, @@ -897,7 +935,7 @@ class Portal: return await deleter.client.delete_messages(self.peer, [message.tgid]) - async def _update_telegram_power_level(self, sender, user_id, level): + async def _update_telegram_power_level(self, sender: u.User, user_id: int, level: int): if self.peer_type == "chat": await sender.client(EditChatAdminRequest( chat_id=self.tgid, user_id=user_id, is_admin=level >= 50)) @@ -913,7 +951,8 @@ class Portal: EditAdminRequest(channel=await self.get_input_entity(sender), user_id=user_id, admin_rights=rights)) - async def handle_matrix_power_levels(self, sender, new_users, old_users): + async def handle_matrix_power_levels(self, sender: u.User, new_users: Dict[str, int], + old_users: Dict[str, int]): # TODO handle all power level changes and bridge exact admin rights to supergroups/channels for user, level in new_users.items(): if not user or user == self.main_intent.mxid or user == sender.mxid: @@ -929,7 +968,7 @@ class Portal: if user not in old_users or level != old_users[user]: await self._update_telegram_power_level(sender, user_id, level) - async def handle_matrix_about(self, sender, about): + async def handle_matrix_about(self, sender: u.User, about: str): if self.peer_type not in {"channel"}: return channel = await self.get_input_entity(sender) @@ -937,7 +976,7 @@ class Portal: self.about = about self.save() - async def handle_matrix_title(self, sender, title): + async def handle_matrix_title(self, sender: u.User, title: str): if self.peer_type not in {"chat", "channel"}: return @@ -950,7 +989,7 @@ class Portal: self.title = title self.save() - async def handle_matrix_avatar(self, sender, url): + async def handle_matrix_avatar(self, sender: u.User, url: str): if self.peer_type not in {"chat", "channel"}: # Invalid peer type return @@ -958,7 +997,7 @@ class Portal: file = await self.main_intent.download_file(url) mime = magic.from_buffer(file, mime=True) ext = mimetypes.guess_extension(mime) - uploaded = await sender.client.upload_file(file, file_name=f"avatar{ext}") + uploaded = await sender.client.upload_file_direct(file, file_name=f"avatar{ext}") photo = InputChatUploadedPhoto(file=uploaded) if self.peer_type == "chat": @@ -977,7 +1016,7 @@ class Portal: self.save() break - def _register_outgoing_actions_for_dedup(self, response): + def _register_outgoing_actions_for_dedup(self, response: TypeUpdates): for update in response.updates: check_dedup = (isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)) and isinstance(update.message, MessageService)) @@ -987,7 +1026,7 @@ class Portal: # endregion # region Telegram chat info updating - async def _get_telegram_users_in_matrix_room(self): + async def _get_telegram_users_in_matrix_room(self) -> List[int]: user_tgids = set() user_mxids = await self.main_intent.get_room_members(self.mxid, ("join", "invite")) for user in user_mxids: @@ -1001,13 +1040,13 @@ class Portal: user_tgids.add(puppet_id) return list(user_tgids) - async def upgrade_telegram_chat(self, source): + async def upgrade_telegram_chat(self, source: u.User): if self.peer_type != "chat": raise ValueError("Only normal group chats are upgradable to supergroups.") - updates = await source.client(MigrateChatRequest(chat_id=self.tgid)) + response = await source.client(MigrateChatRequest(chat_id=self.tgid)) entity = None - for chat in updates.chats: + for chat in response.chats: if isinstance(chat, Channel): entity = chat break @@ -1017,7 +1056,7 @@ class Portal: self.migrate_and_save(entity.id) await self.update_info(source, entity) - async def set_telegram_username(self, source, username): + async def set_telegram_username(self, source: u.User, username: str): if self.peer_type != "channel": raise ValueError("Only channels and supergroups have usernames.") await source.client( @@ -1025,7 +1064,7 @@ class Portal: if await self.update_username(username): self.save() - async def create_telegram_chat(self, source, supergroup=False): + async def create_telegram_chat(self, source: u.User, supergroup: bool = False): if not self.mxid: raise ValueError("Can't create Telegram chat for portal without Matrix room.") elif self.tgid: @@ -1036,13 +1075,13 @@ class Portal: raise ValueError("Not enough Telegram users to create a chat") if self.peer_type == "chat": - updates = await source.client(CreateChatRequest(title=self.title, users=invites)) - entity = updates.chats[0] + response = await source.client(CreateChatRequest(title=self.title, users=invites)) + entity = response.chats[0] elif self.peer_type == "channel": - updates = await source.client(CreateChannelRequest(title=self.title, - about=self.about or "", - megagroup=supergroup)) - entity = updates.chats[0] + response = await source.client(CreateChannelRequest(title=self.title, + about=self.about or "", + megagroup=supergroup)) + entity = response.chats[0] await source.client(InviteToChannelRequest( channel=await source.client.get_input_entity(entity), users=invites)) @@ -1066,7 +1105,7 @@ class Portal: await self.main_intent.set_power_levels(self.mxid, levels) await self.handle_matrix_power_levels(source, levels["users"], {}) - async def invite_telegram(self, source, puppet): + async def invite_telegram(self, source: u.User, puppet: Union[p.Puppet, "AbstractUser"]): if self.peer_type == "chat": await source.client( AddChatUserRequest(chat_id=self.tgid, user_id=puppet.tgid, fwd_limit=0)) @@ -1078,16 +1117,18 @@ class Portal: # endregion # region Telegram event handling - async def handle_telegram_typing(self, user, event): + async def handle_telegram_typing(self, user: p.Puppet, + _: Union[UpdateUserTyping, UpdateChatUserTyping]): if self.mxid: await user.intent.set_typing(self.mxid, is_typing=True) - def get_external_url(self, evt: Message): + def get_external_url(self, evt: Message) -> Optional[str]: if self.peer_type == "channel" and self.username is not None: return f"https://t.me/{self.username}/{evt.id}" return None - async def handle_telegram_photo(self, source: u.User, intent, evt: Message, relates_to=None): + async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message, + relates_to=None): largest_size = self._get_largest_photo_size(evt.media.photo) file = await util.transfer_file_to_matrix(self.db, source.client, intent, largest_size.location) @@ -1117,7 +1158,7 @@ class Portal: external_url=self.get_external_url(evt)) @staticmethod - def _parse_telegram_document_attributes(attributes): + def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> dict: attrs = { "name": None, "mime_type": None, @@ -1138,7 +1179,8 @@ class Portal: return attrs @staticmethod - def _parse_telegram_document_meta(evt, file, attrs): + def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: dict + ) -> Tuple[dict, str]: document = evt.media.document name = evt.message or attrs["name"] if attrs["is_sticker"]: @@ -1170,7 +1212,9 @@ class Portal: return info, name - async def handle_telegram_document(self, source, intent, evt: Message, relates_to=None): + async def handle_telegram_document(self, source: "AbstractUser", intent: IntentAPI, + evt: Message, + relates_to: dict = None) -> Optional[dict]: document = evt.media.document attrs = self._parse_telegram_document_attributes(document.attributes) @@ -1207,7 +1251,8 @@ class Portal: kwargs["file_type"] = "m.file" return await intent.send_file(**kwargs) - def handle_telegram_location(self, source, intent, evt, relates_to=None): + def handle_telegram_location(self, _: "AbstractUser", intent: IntentAPI, evt: Message, + relates_to: dict = None) -> Awaitable[dict]: location = evt.media.geo long = location.long lat = location.lat @@ -1234,7 +1279,8 @@ class Portal: "m.relates_to": relates_to or None, }, timestamp=evt.date, external_url=self.get_external_url(evt)) - async def handle_telegram_text(self, source, intent, is_bot, evt): + async def handle_telegram_text(self, source: "AbstractUser", intent: IntentAPI, is_bot: bool, + evt: Message) -> dict: self.log.debug(f"Sending {evt.message} to {self.mxid} by {intent.mxid}") text, html, relates_to = await formatter.telegram_to_matrix(evt, source, self.main_intent) await intent.set_typing(self.mxid, is_typing=False) @@ -1243,7 +1289,7 @@ class Portal: msgtype=msgtype, timestamp=evt.date, external_url=self.get_external_url(evt)) - async def handle_telegram_edit(self, source, sender, evt): + async def handle_telegram_edit(self, source: "AbstractUser", sender: p.Puppet, evt: Message): if not self.mxid: return elif not config["bridge.edits_as_replies"]: @@ -1290,7 +1336,7 @@ class Portal: .update({"mxid": mxid}) self.db.commit() - async def handle_telegram_message(self, source, sender, evt): + async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet, evt: Message): if not self.mxid: await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=False) @@ -1373,19 +1419,21 @@ class Portal: self.db.rollback() await intent.redact(self.mxid, mxid) - async def _create_room_on_action(self, source, action): + async def _create_room_on_action(self, source: "AbstractUser", + action: TypeMessageAction) -> bool: if source.is_relaybot: return False create_and_exit = (MessageActionChatCreate, MessageActionChannelCreate) create_and_continue = (MessageActionChatAddUser, MessageActionChatJoinedByLink) - if isinstance(action, create_and_exit + create_and_continue): + if isinstance(action, create_and_exit) or isinstance(action, create_and_continue): await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=isinstance(action, create_and_exit)) if not isinstance(action, create_and_continue): return False return True - async def handle_telegram_action(self, source, sender, update): + async def handle_telegram_action(self, source: "AbstractUser", sender: p.Puppet, + update: MessageService): action = update.action should_ignore = ((not self.mxid and not await self._create_room_on_action(source, action)) or self.is_duplicate_action(update)) @@ -1415,7 +1463,7 @@ class Portal: else: self.log.debug("Unhandled Telegram action in %s: %s", self.title, action) - async def set_telegram_admin(self, user_id): + async def set_telegram_admin(self, user_id: int): puppet = p.Puppet.get(user_id) user = await u.User.get_by_tgid(user_id) @@ -1426,7 +1474,7 @@ class Portal: levels["users"][puppet.mxid] = 50 await self.main_intent.set_power_levels(self.mxid, levels) - async def receive_telegram_pin_sender(self, sender): + async def receive_telegram_pin_sender(self, sender: p.Puppet): self._temp_pinned_message_sender = sender if self._temp_pinned_message_id: await self.update_telegram_pin() @@ -1434,25 +1482,25 @@ class Portal: async def update_telegram_pin(self): intent = (self._temp_pinned_message_sender.intent if self._temp_pinned_message_sender else self.main_intent) - id = self._temp_pinned_message_id + msg_id = self._temp_pinned_message_id self._temp_pinned_message_id = None self._temp_pinned_message_sender = None - message = DBMessage.query.get((id, self.tgid)) + message = DBMessage.query.get((msg_id, self.tgid)) if message: await intent.set_pinned_messages(self.mxid, [message.mxid]) else: await intent.set_pinned_messages(self.mxid, []) - async def receive_telegram_pin_id(self, id): - if id == 0: + async def receive_telegram_pin_id(self, msg_id: int): + if msg_id == 0: return await self.update_telegram_pin() - self._temp_pinned_message_id = id + self._temp_pinned_message_id = msg_id if self._temp_pinned_message_sender: await self.update_telegram_pin() @staticmethod - def _get_level_from_participant(participant, _): + def _get_level_from_participant(participant: TypeParticipant, _) -> int: # TODO use the power level requirements to get better precision in channels if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)): return 50 @@ -1461,7 +1509,8 @@ class Portal: return 0 @staticmethod - def _participant_to_power_levels(levels, user, new_level, bot_level): + def _participant_to_power_levels(levels: dict, user: Union[u.User, p.Puppet], new_level: int, + bot_level: int) -> bool: new_level = min(new_level, bot_level) default_level = levels["users_default"] if "users_default" in levels else 0 try: @@ -1473,7 +1522,7 @@ class Portal: return True return False - def _get_bot_level(self, levels): + def _get_bot_level(self, levels: dict) -> int: try: return levels["users"][self.main_intent.mxid] except KeyError: @@ -1483,7 +1532,7 @@ class Portal: return 0 @staticmethod - def _get_powerlevel_level(levels): + def _get_powerlevel_level(levels: dict) -> int: try: return levels["events"]["m.room.power_levels"] except KeyError: @@ -1492,7 +1541,8 @@ class Portal: except KeyError: return 50 - def _participants_to_power_levels(self, participants, levels): + def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: dict + ) -> bool: bot_level = self._get_bot_level(levels) if bot_level < self._get_powerlevel_level(levels): return False @@ -1517,13 +1567,14 @@ class Portal: bot_level) or changed return changed - async def update_telegram_participants(self, participants, levels=None): + async def update_telegram_participants(self, participants: List[TypeParticipant], + levels: dict = None): if not levels: levels = await self.main_intent.get_power_levels(self.mxid) if self._participants_to_power_levels(participants, levels): await self.main_intent.set_power_levels(self.mxid, levels) - async def set_telegram_admins_enabled(self, enabled): + async def set_telegram_admins_enabled(self, enabled: bool): level = 50 if enabled else 10 levels = await self.main_intent.get_power_levels(self.mxid) levels["invite"] = level @@ -1535,17 +1586,17 @@ class Portal: # region Database conversion @property - def db_instance(self): + def db_instance(self) -> DBPortal: if not self._db_instance: self._db_instance = self.new_db_instance() return self._db_instance - def new_db_instance(self): + def new_db_instance(self) -> DBPortal: return DBPortal(tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type, mxid=self.mxid, username=self.username, megagroup=self.megagroup, title=self.title, about=self.about, photo_id=self.photo_id) - def migrate_and_save(self, new_id): + def migrate_and_save(self, new_id: int): existing = DBPortal.query.get(self.tgid_full) if existing: self.db.delete(existing) @@ -1580,7 +1631,7 @@ class Portal: self.db.commit() @classmethod - def from_db(cls, db_portal): + def from_db(cls, db_portal: DBPortal) -> "Portal": return Portal(tgid=db_portal.tgid, tg_receiver=db_portal.tg_receiver, peer_type=db_portal.peer_type, mxid=db_portal.mxid, username=db_portal.username, megagroup=db_portal.megagroup, @@ -1591,7 +1642,7 @@ class Portal: # region Class instance lookup @classmethod - def get_by_mxid(cls, mxid): + def get_by_mxid(cls, mxid: str) -> Optional["Portal"]: try: return cls.by_mxid[mxid] except KeyError: @@ -1604,14 +1655,14 @@ class Portal: return None @classmethod - def get_username_from_mx_alias(cls, alias): + def get_username_from_mx_alias(cls, alias: str) -> Optional[str]: match = cls.mx_alias_regex.match(alias) if match: return match.group(1) return None @classmethod - def find_by_username(cls, username): + def find_by_username(cls, username: str) -> Optional["Portal"]: if not username: return None @@ -1626,7 +1677,8 @@ class Portal: return None @classmethod - def get_by_tgid(cls, tgid, tg_receiver=None, peer_type=None): + def get_by_tgid(cls, tgid: int, tg_receiver: int = None, peer_type: str = None + ) -> Optional["Portal"]: tg_receiver = tg_receiver or tgid tgid_full = (tgid, tg_receiver) try: @@ -1647,36 +1699,37 @@ class Portal: return None @classmethod - def get_by_entity(cls, entity, receiver_id=None, create=True): + def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull, TypeInputPeer], + receiver_id: int = None, create: bool = True) -> Optional["Portal"]: entity_type = type(entity) if entity_type in {Chat, ChatFull}: type_name = "chat" - id = entity.id + entity_id = entity.id elif entity_type in {PeerChat, InputPeerChat}: type_name = "chat" - id = entity.chat_id + entity_id = entity.chat_id elif entity_type in {Channel, ChannelFull}: type_name = "channel" - id = entity.id + entity_id = entity.id elif entity_type in {PeerChannel, InputPeerChannel, InputChannel}: type_name = "channel" - id = entity.channel_id + entity_id = entity.channel_id elif entity_type in {User, UserFull}: type_name = "user" - id = entity.id + entity_id = entity.id elif entity_type in {PeerUser, InputPeerUser, InputUser}: type_name = "user" - id = entity.user_id + entity_id = entity.user_id else: raise ValueError(f"Unknown entity type {entity_type.__name__}") - return cls.get_by_tgid(id, - receiver_id if type_name == "user" else id, + return cls.get_by_tgid(entity_id, + receiver_id if type_name == "user" else entity_id, type_name if create else None) # endregion -def init(context): +def init(context: Context): global config Portal.az, Portal.db, config, Portal.loop, Portal.bot = context Portal.bridge_notices = config["bridge.bridge_notices"] @@ -1684,5 +1737,5 @@ def init(context): Portal.filter_list = config["bridge.filter.list"] Portal.alias_template = config.get("bridge.alias_template", "telegram_{groupname}") Portal.hs_domain = config["homeserver.domain"] - localpart = Portal.alias_template.format(groupname="(.+)") - Portal.mx_alias_regex = re.compile(f"#{localpart}:{Portal.hs_domain}") + Portal.mx_alias_regex = re.compile( + f"#{Portal.alias_template.format(groupname='(.+)')}:{Portal.hs_domain}") diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py index 20b6af8a..f5642bc2 100644 --- a/mautrix_telegram/puppet.py +++ b/mautrix_telegram/puppet.py @@ -14,32 +14,39 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING from difflib import SequenceMatcher -from typing import Optional, Awaitable import re import logging import asyncio +from sqlalchemy import orm + from telethon.tl.types import UserProfilePhoto from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError from .db import Puppet as DBPuppet -from . import util, matrix +from . import util -config = None +if TYPE_CHECKING: + from .matrix import MatrixHandler + from .config import Config + from .context import Context + +config = None # type: Config class Puppet: - log = logging.getLogger("mau.puppet") - db = None + log = logging.getLogger("mau.puppet") # type: logging.Logger + db = None # type: orm.Session az = None # type: AppService - mx = None # type: matrix.MatrixHandler + mx = None # type: MatrixHandler loop = None # type: asyncio.AbstractEventLoop - mxid_regex = None - username_template = None - hs_domain = None - cache = {} - by_custom_mxid = {} + mxid_regex = None # type: Pattern + username_template = None # type: str + hs_domain = None # type: str + cache = {} # type: Dict[str, Puppet] + by_custom_mxid = {} # type: Dict[str, Puppet] def __init__(self, id=None, access_token=None, custom_mxid=None, username=None, displayname=None, displayname_source=None, photo_id=None, is_bot=None, @@ -71,7 +78,8 @@ class Puppet: def tgid(self): return self.id - async def is_logged_in(self): + @staticmethod + async def is_logged_in(): return True # region Custom puppet management @@ -154,12 +162,12 @@ class Puppet: def filter_events(self, events): new_events = [] for event in events: - type = event.get("type", None) + evt_type = event.get("type", None) event.setdefault("content", {}) - if type == "m.typing": + if evt_type == "m.typing": is_typing = self.custom_mxid in event["content"].get("user_ids", []) event["content"]["user_ids"] = [self.custom_mxid] if is_typing else [] - elif type == "m.receipt": + elif evt_type == "m.receipt": val = None evt = None for event_id in event["content"]: @@ -273,7 +281,7 @@ class Puppet: return round(similarity * 1000) / 10 @staticmethod - def get_displayname(info, format=True): + def get_displayname(info, enable_format=True): data = { "phone number": info.phone if hasattr(info, "phone") else None, "username": info.username, @@ -295,7 +303,7 @@ class Puppet: elif not name: name = info.id - if not format: + if not enable_format: return name return config.get("bridge.displayname_template", "{displayname} (Telegram)").format( displayname=name) @@ -347,18 +355,18 @@ class Puppet: # region Getters @classmethod - def get(cls, id, create=True) -> "Optional[Puppet]": + def get(cls, tgid, create=True) -> "Optional[Puppet]": try: - return cls.cache[id] + return cls.cache[tgid] except KeyError: pass - puppet = DBPuppet.query.get(id) + puppet = DBPuppet.query.get(tgid) if puppet: return cls.from_db(puppet) if create: - puppet = cls(id) + puppet = cls(tgid) cls.db.add(puppet.db_instance) cls.db.commit() return puppet @@ -402,8 +410,8 @@ class Puppet: return None @classmethod - def get_mxid_from_id(cls, id): - return f"@{cls.username_template.format(userid=id)}:{cls.hs_domain}" + def get_mxid_from_id(cls, tgid): + return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}" @classmethod def find_by_username(cls, username) -> "Optional[Puppet]": @@ -437,12 +445,12 @@ class Puppet: # endregion -def init(context): +def init(context: "Context") -> List[Awaitable[int]]: global config Puppet.az, Puppet.db, config, Puppet.loop, _ = context Puppet.mx = context.mx Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}") Puppet.hs_domain = config["homeserver"]["domain"] - localpart = Puppet.username_template.format(userid="(.+)") - Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}") + Puppet.mxid_regex = re.compile( + f"@{Puppet.username_template.format(userid='(.+)')}:{Puppet.hs_domain}") return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()] diff --git a/mautrix_telegram/sqlstatestore.py b/mautrix_telegram/sqlstatestore.py index 63b030d2..68e9fd9d 100644 --- a/mautrix_telegram/sqlstatestore.py +++ b/mautrix_telegram/sqlstatestore.py @@ -16,6 +16,8 @@ # along with this program. If not, see . from typing import Dict, Tuple +from sqlalchemy import orm + from mautrix_appservice import StateStore from . import puppet as pu @@ -25,15 +27,17 @@ from .db import RoomState, UserProfile class SQLStateStore(StateStore): def __init__(self, db): super().__init__() - self.db = db + self.db = db # type: orm.Session self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile] self.room_state_cache = {} # type: Dict[str, RoomState] - def is_registered(self, user: str) -> bool: + @staticmethod + def is_registered(user: str) -> bool: puppet = pu.Puppet.get_by_mxid(user) return puppet.is_registered if puppet else False - def registered(self, user: str): + @staticmethod + def registered(user: str): puppet = pu.Puppet.get_by_mxid(user) if puppet: puppet.is_registered = True diff --git a/mautrix_telegram/tgclient.py b/mautrix_telegram/tgclient.py index 302515d8..4534524e 100644 --- a/mautrix_telegram/tgclient.py +++ b/mautrix_telegram/tgclient.py @@ -17,10 +17,14 @@ from telethon import TelegramClient, utils from telethon.tl.functions.messages import SendMediaRequest from telethon.tl.types import * +from telethon.tl import custom class MautrixTelegramClient(TelegramClient): - async def upload_file(self, file, mime_type=None, attributes=None, file_name=None): + async def upload_file_direct(self, file: bytes, mime_type: str = None, + attributes: List[TypeDocumentAttribute] = None, + file_name: str = None + ) -> Union[InputMediaUploadedDocument, InputMediaUploadedPhoto]: file_handle = await super().upload_file(file, file_name=file_name, use_cache=False) if mime_type == "image/png" or mime_type == "image/jpeg": @@ -34,7 +38,10 @@ class MautrixTelegramClient(TelegramClient): mime_type=mime_type or "application/octet-stream", attributes=list(attr_dict.values())) - async def send_media(self, entity, media, caption=None, entities=None, reply_to=None): + async def send_media(self, entity: Union[TypeInputPeer, TypePeer], + media: Union[TypeInputMedia, TypeMessageMedia], + caption: str = None, entities: List[TypeMessageEntity] = None, + reply_to: int = None) -> Optional[custom.Message]: entity = await self.get_input_entity(entity) reply_to = utils.get_message_id(reply_to) request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [], diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index 8e229f94..c2bdf780 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -14,42 +14,51 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Awaitable, Optional +from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING import logging import asyncio import re from telethon.tl.types import * +from telethon.tl.types import User as TLUser from telethon.tl.types.contacts import ContactsNotModified from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest from telethon.tl.functions.account import UpdateStatusRequest from mautrix_appservice import MatrixRequestError -from .db import User as DBUser, Contact as DBContact +from .db import User as DBUser, Contact as DBContact, Portal as DBPortal from .abstract_user import AbstractUser from . import portal as po, puppet as pu -config = None +if TYPE_CHECKING: + from .config import Config + from .context import Context + +config = None # type: Config + +SearchResults = List[Tuple["pu.Puppet", int]] class User(AbstractUser): - log = logging.getLogger("mau.user") - by_mxid = {} - by_tgid = {} + log = logging.getLogger("mau.user") # type: logging.Logger + by_mxid = {} # type: Dict[str, User] + by_tgid = {} # type: Dict[int, User] - def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0, - is_bot=False, db_portals=None, db_instance=None): + def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None, + db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0, + is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None, + db_instance: Optional[DBUser] = None): super().__init__() self.mxid = mxid # type: str self.tgid = tgid # type: int self.is_bot = is_bot # type: bool self.username = username # type: str - self.contacts = [] - self.saved_contacts = saved_contacts - self.db_contacts = db_contacts - self.portals = {} # type: Dict[str, po.Portal] - self.db_portals = db_portals - self._db_instance = db_instance + self.contacts = [] # type: List[pu.Puppet] + self.saved_contacts = saved_contacts # type: int + self.db_contacts = db_contacts # type: List[DBContact] + self.portals = {} # type: Dict[Tuple[int, int], po.Portal] + self.db_portals = db_portals # type: List[DBPortal] + self._db_instance = db_instance # type: DBUser self.command_status = None # type: dict @@ -64,53 +73,47 @@ class User(AbstractUser): self.by_tgid[tgid] = self @property - def name(self): + def name(self) -> str: return self.mxid @property - def mxid_localpart(self): - match = re.compile("@(.+):(.+)").match(self.mxid) + def mxid_localpart(self) -> str: + match = re.compile("@(.+):(.+)").match(self.mxid) # type: Match return match.group(1) # TODO replace with proper displayname getting everywhere @property - def displayname(self): + def displayname(self) -> str: return self.mxid_localpart @property - def db_contacts(self): + def db_contacts(self) -> List[DBContact]: return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id)) for puppet in self.contacts] @db_contacts.setter - def db_contacts(self, contacts): - if contacts: - self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] - else: - self.contacts = [] + def db_contacts(self, contacts: List[DBContact]): + self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else [] @property - def db_portals(self): + def db_portals(self) -> List[DBPortal]: return [portal.db_instance for portal in self.portals.values()] @db_portals.setter - def db_portals(self, portals): - if portals: - self.portals = {(portal.tgid, portal.tg_receiver): - po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver) - for portal in portals} - else: - self.portals = {} + def db_portals(self, portals: List[DBPortal]): + self.portals = {(portal.tgid, portal.tg_receiver): + po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver) + for portal in portals} if portals else {} # region Database conversion @property - def db_instance(self): + def db_instance(self) -> DBUser: if not self._db_instance: self._db_instance = self.new_db_instance() return self._db_instance - def new_db_instance(self): + def new_db_instance(self) -> DBUser: return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username, contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0, portals=self.db_portals) @@ -134,14 +137,14 @@ class User(AbstractUser): self.db.commit() @classmethod - def from_db(cls, db_user): + def from_db(cls, db_user: DBUser) -> "User": return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts, False, db_user.saved_contacts, db_user.portals, db_instance=db_user) # endregion # region Telegram connection management - async def start(self, delete_unless_authenticated=False): + async def start(self, delete_unless_authenticated: bool = False) -> "User": await super().start() if await self.is_logged_in(): self.log.debug(f"Ensuring post_login() for {self.name}") @@ -152,7 +155,7 @@ class User(AbstractUser): self.client.session.delete() return self - async def post_login(self, info=None): + async def post_login(self, info: TLUser = None): try: await self.update_info(info) if not self.is_bot: @@ -163,7 +166,7 @@ class User(AbstractUser): except Exception: self.log.exception("Failed to run post-login functions for %s", self.mxid) - async def update(self, update): + async def update(self, update: TypeUpdate): if not self.is_bot: return @@ -186,7 +189,7 @@ class User(AbstractUser): # endregion # region Telegram actions that need custom methods - def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]": + def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]": return super().ensure_started(even_if_no_session) def set_presence(self, online: bool = True): @@ -194,7 +197,7 @@ class User(AbstractUser): return return self.client(UpdateStatusRequest(offline=not online)) - async def update_info(self, info: User = None): + async def update_info(self, info: TLUser = None): info = info or await self.client.get_me() changed = False if self.is_bot != info.bot: @@ -233,8 +236,9 @@ class User(AbstractUser): self.delete() return True - def _search_local(self, query, max_results=5, min_similarity=45): - results = [] + def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45 + ) -> SearchResults: + results = [] # type: SearchResults for contact in self.contacts: similarity = contact.similarity(query) if similarity >= min_similarity: @@ -242,11 +246,11 @@ class User(AbstractUser): results.sort(key=lambda tup: tup[1], reverse=True) return results[0:max_results] - async def _search_remote(self, query, max_results=5): + async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults: if len(query) < 5: return [] server_results = await self.client(SearchRequest(q=query, limit=max_results)) - results = [] + results = [] # type: SearchResults for user in server_results.users: puppet = pu.Puppet.get(user.id) await puppet.update_info(self, user) @@ -254,7 +258,7 @@ class User(AbstractUser): results.sort(key=lambda tup: tup[1], reverse=True) return results[0:max_results] - async def search(self, query, force_remote=False): + async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]: if force_remote: return await self._search_remote(query), True @@ -264,7 +268,7 @@ class User(AbstractUser): return await self._search_remote(query), True - async def sync_dialogs(self, synchronous_create=False): + async def sync_dialogs(self, synchronous_create: bool = False): creators = [] for entity in await self.get_dialogs(limit=30): portal = po.Portal.get_by_entity(entity) @@ -275,7 +279,7 @@ class User(AbstractUser): self.save() await asyncio.gather(*creators, loop=self.loop) - def register_portal(self, portal): + def register_portal(self, portal: po.Portal): try: if self.portals[portal.tgid_full] == portal: return @@ -284,18 +288,18 @@ class User(AbstractUser): self.portals[portal.tgid_full] = portal self.save() - def unregister_portal(self, portal): + def unregister_portal(self, portal: po.Portal): try: del self.portals[portal.tgid_full] self.save() except KeyError: pass - async def needs_relaybot(self, portal): + async def needs_relaybot(self, portal: po.Portal) -> bool: return not await self.is_logged_in() or ( self.is_bot and portal.tgid_full not in self.portals) - def _hash_contacts(self): + def _hash_contacts(self) -> int: acc = 0 for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]): acc = (acc * 20261 + id) & 0xffffffff @@ -318,7 +322,7 @@ class User(AbstractUser): # region Class instance lookup @classmethod - def get_by_mxid(cls, mxid, create=True) -> "Optional[User]": + def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]": if not mxid: raise ValueError("Matrix ID can't be empty") @@ -341,7 +345,7 @@ class User(AbstractUser): return None @classmethod - def get_by_tgid(cls, tgid) -> "Optional[User]": + def get_by_tgid(cls, tgid: int) -> "Optional[User]": try: return cls.by_tgid[tgid] except KeyError: @@ -355,7 +359,7 @@ class User(AbstractUser): return None @classmethod - def find_by_username(cls, username) -> "Optional[User]": + def find_by_username(cls, username: str) -> "Optional[User]": if not username: return None @@ -371,7 +375,7 @@ class User(AbstractUser): # endregion -def init(context): +def init(context: "Context") -> List[Awaitable[User]]: global config config = context.config diff --git a/mautrix_telegram/util/file_transfer.py b/mautrix_telegram/util/file_transfer.py index e927cd77..d950b2a0 100644 --- a/mautrix_telegram/util/file_transfer.py +++ b/mautrix_telegram/util/file_transfer.py @@ -14,15 +14,25 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Optional, Tuple, Union, Dict from io import BytesIO import time import logging import asyncio import magic +from sqlalchemy import orm from sqlalchemy.exc import IntegrityError, InvalidRequestError from sqlalchemy.orm.exc import FlushError +from telethon.tl.types import (Document, FileLocation, InputFileLocation, + InputDocumentFileLocation, PhotoSize, PhotoCachedSize) +from telethon.errors import * +from mautrix_appservice import IntentAPI + +from ..tgclient import MautrixTelegramClient +from ..db import TelegramFile as DBTelegramFile + try: from PIL import Image except ImportError: @@ -36,20 +46,18 @@ try: except ImportError: VideoFileClip = random = string = os = mimetypes = None -from telethon.tl.types import (Document, FileLocation, InputFileLocation, - InputDocumentFileLocation, PhotoSize, PhotoCachedSize) -from telethon.errors import * +log = logging.getLogger("mau.util") # type: logging.Logger -from ..db import TelegramFile as DBTelegramFile - -log = logging.getLogger("mau.util") +TypeLocation = Union[Document, InputDocumentFileLocation, FileLocation, InputFileLocation] -def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=None): +def convert_image(file: bytes, source_mime: str = "image/webp", target_type: str = "png", + thumbnail_to: Optional[Tuple[int, int]] = None + ) -> Tuple[str, bytes, Optional[int], Optional[int]]: if not Image: return source_mime, file, None, None try: - image = Image.open(BytesIO(file)).convert("RGBA") + image = Image.open(BytesIO(file)).convert("RGBA") # type: Image.Image if thumbnail_to: image.thumbnail(thumbnail_to, Image.ANTIALIAS) new_file = BytesIO() @@ -61,13 +69,14 @@ def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_t return source_mime, file, None, None -def _temp_file_name(ext): +def _temp_file_name(ext: str) -> str: return ("/tmp/mxtg-video-" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + ext) -def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024, 720)): +def _read_video_thumbnail(data: bytes, video_ext: str = "mp4", frame_ext: str = "png", + max_size: Tuple[int, int] = (1024, 720)) -> Tuple[bytes, int, int]: # We don't have any way to read the video from memory, so save it to disk. temp_file = _temp_file_name(video_ext) with open(temp_file, "wb") as file: @@ -90,21 +99,21 @@ def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024 return thumbnail_file.getvalue(), w, h -def _location_to_id(location): +def _location_to_id(location: TypeLocation) -> str: if isinstance(location, (Document, InputDocumentFileLocation)): return f"{location.id}-{location.version}" elif isinstance(location, (FileLocation, InputFileLocation)): return f"{location.volume_id}-{location.local_id}" - else: - return None -async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mime): +async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, + thumbnail_loc: TypeLocation, video: bytes, + mime: str) -> Optional[DBTelegramFile]: if not Image or not VideoFileClip: return None - id = _location_to_id(thumbnail_loc) - if not id: + loc_id = _location_to_id(thumbnail_loc) + if not loc_id: return None video_ext = mimetypes.guess_extension(mime) @@ -121,36 +130,40 @@ async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mim content_uri = await intent.upload_file(file, mime_type) - return DBTelegramFile(id=id, mxc=content_uri, mime_type=mime_type, + return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type, was_converted=False, timestamp=int(time.time()), size=len(file), width=width, height=height) -transfer_locks = {} -transfer_locks_lock = asyncio.Lock() +transfer_locks = {} # type: Dict[str, asyncio.Lock] -async def transfer_file_to_matrix(db, client, intent, location, thumbnail=None, is_sticker=False): - id = _location_to_id(location) - if not id: +async def transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient, intent: IntentAPI, + location: TypeLocation, thumbnail: Optional[TypeLocation] = None, + is_sticker: bool = False) -> Optional[DBTelegramFile]: + location_id = _location_to_id(location) + if not location_id: return None - db_file = DBTelegramFile.query.get(id) + db_file = DBTelegramFile.query.get(location_id) if db_file: return db_file - async with transfer_locks_lock: - try: - lock = transfer_locks[id] - except KeyError: - lock = asyncio.Lock() - transfer_locks[id] = lock + try: + lock = transfer_locks[location_id] + except KeyError: + lock = asyncio.Lock() + transfer_locks[location_id] = lock async with lock: - return await _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker) + return await _unlocked_transfer_file_to_matrix(db, client, intent, location_id, location, + thumbnail, is_sticker) -async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker): - db_file = DBTelegramFile.query.get(id) +async def _unlocked_transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient, + intent: IntentAPI, loc_id: str, location: TypeLocation, + thumbnail: Optional[TypeLocation], + is_sticker: bool) -> Optional[DBTelegramFile]: + db_file = DBTelegramFile.query.get(loc_id) if db_file: return db_file @@ -167,15 +180,16 @@ async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, th image_converted = False if mime_type == "image/webp": - new_mime_type, file, width, height = convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=( - 256, 256) if is_sticker else None) + new_mime_type, file, width, height = convert_image( + file, source_mime="image/webp", target_type="png", + thumbnail_to=(256, 256) if is_sticker else None) image_converted = new_mime_type != mime_type mime_type = new_mime_type thumbnail = None content_uri = await intent.upload_file(file, mime_type) - db_file = DBTelegramFile(id=id, mxc=content_uri, + db_file = DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type, was_converted=image_converted, timestamp=int(time.time()), size=len(file), width=width, height=height) diff --git a/mautrix_telegram/util/format_duration.py b/mautrix_telegram/util/format_duration.py index c873e9e5..9402b83e 100644 --- a/mautrix_telegram/util/format_duration.py +++ b/mautrix_telegram/util/format_duration.py @@ -16,10 +16,12 @@ # along with this program. If not, see . -def format_duration(seconds): - def pluralize(count, singular): return singular if count == 1 else singular + "s" +def format_duration(seconds: int) -> str: + def pluralize(count, singular): + return singular if count == 1 else singular + "s" - def include(count, word): return f"{count} {pluralize(count, word)}" if count > 0 else "" + def include(count, word): + return f"{count} {pluralize(count, word)}" if count > 0 else "" minutes, seconds = divmod(seconds, 60) hours, minutes = divmod(minutes, 60)