diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py index 97656dc6..dde8aaa3 100644 --- a/mautrix_telegram/__main__.py +++ b/mautrix_telegram/__main__.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional +from typing import Coroutine, List, Optional import argparse import asyncio import logging.config @@ -115,7 +115,7 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st startup_actions = (init_puppet(context) + init_user(context) + [start, - context.mx.init_as_bot()]) + context.mx.init_as_bot()]) # type: List[Coroutine] if context.bot: startup_actions.append(context.bot.start()) diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py index 300a4e29..968c017d 100644 --- a/mautrix_telegram/abstract_user.py +++ b/mautrix_telegram/abstract_user.py @@ -38,6 +38,7 @@ from .db import Message as DBMessage from .tgclient import MautrixTelegramClient if TYPE_CHECKING: + from .types import TelegramId from .context import Context from .config import Config from .bot import Bot @@ -67,10 +68,11 @@ class AbstractUser(ABC): self.whitelisted = False # type: bool self.relaybot_whitelisted = False # type: bool self.client = None # type: MautrixTelegramClient - self.tgid = None # type: int + self.tgid = None # type: TelegramId self.mxid = None # type: str self.is_relaybot = False # type: bool self.is_bot = False # type: bool + self.relaybot = None # type: Optional[Bot] @property def connected(self) -> bool: @@ -372,7 +374,7 @@ class AbstractUser(ABC): def init(context: "Context") -> None: global config, MAX_DELETIONS - AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context + AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"] AbstractUser.session_container = context.session_container MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10) diff --git a/mautrix_telegram/bot.py b/mautrix_telegram/bot.py index 3a8dcc6b..c2409cff 100644 --- a/mautrix_telegram/bot.py +++ b/mautrix_telegram/bot.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Callable, Dict, Optional, Pattern, TYPE_CHECKING +from typing import Awaitable, Callable, Dict, List, Optional, Pattern, TYPE_CHECKING import logging import re @@ -27,12 +27,14 @@ from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest from telethon.errors import ChannelInvalidError, ChannelPrivateError +from .types import MatrixUserId from .abstract_user import AbstractUser from .db import BotChat from . import puppet as pu, portal as po, user as u if TYPE_CHECKING: from .config import Config + from .context import Context config = None # type: Config @@ -145,6 +147,7 @@ class Bot(AbstractUser): for p in participants: if p.user_id == tgid: return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin)) + return False async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool: if not await self._can_use_commands(event.to_id, event.from_id): @@ -168,15 +171,16 @@ class Bot(AbstractUser): return await reply( "Portal is not public. Use `/invite ` to get an invite.") - async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, mxid: str) -> None: - if len(mxid) == 0: + async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, + mxid_input: MatrixUserId) -> Message: + if len(mxid_input) == 0: return await reply("Usage: `/invite `") elif not portal.mxid: return await reply("Portal does not have Matrix room. " "Create one with /portal first.") - if not self.mxid_regex.match(mxid): + if not self.mxid_regex.match(mxid_input): return await reply("That doesn't look like a Matrix ID.") - user = await u.User.get_by_mxid(mxid).ensure_started() + user = await u.User.get_by_mxid(MatrixUserId(mxid_input)).ensure_started() if not user.relaybot_whitelisted: return await reply("That user is not whitelisted to use the bridge.") elif await user.is_logged_in(): @@ -187,7 +191,7 @@ class Bot(AbstractUser): await portal.main_intent.invite(portal.mxid, user.mxid) return await reply(f"Invited `{user.mxid}` to the portal.") - def handle_command_id(self, message: Message, reply: ReplyFunc) -> None: + def handle_command_id(self, message: Message, reply: ReplyFunc) -> Awaitable[Message]: # Provide the prefixed ID to the user so that the user wouldn't need to specify whether the # chat is a normal group or a supergroup/channel when using the ID. if isinstance(message.to_id, PeerChannel): @@ -210,7 +214,7 @@ class Bot(AbstractUser): return False async def handle_command(self, message: Message) -> None: - def reply(reply_text) -> None: + def reply(reply_text: str) -> Awaitable[Message]: return self.client.send_message(message.to_id, reply_text, reply_to=message.id) text = message.message @@ -231,7 +235,7 @@ class Bot(AbstractUser): mxid = text[text.index(" ") + 1:] except ValueError: mxid = "" - await self.handle_command_invite(portal, reply, mxid=mxid) + await self.handle_command_invite(portal, reply, mxid_input=mxid) def handle_service_message(self, message: MessageService) -> None: to_id = message.to_id @@ -250,11 +254,12 @@ class Bot(AbstractUser): elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid: self.remove_chat(to_id) - async def update(self, update) -> None: + async def update(self, update) -> bool: if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)): - return + return False if isinstance(update.message, MessageService): - return self.handle_service_message(update.message) + self.handle_service_message(update.message) + return False is_command = (isinstance(update.message, Message) and update.message.entities and len(update.message.entities) > 0 @@ -270,7 +275,7 @@ class Bot(AbstractUser): return "bot" -def init(context) -> Optional[Bot]: +def init(context: 'Context') -> Optional[Bot]: global config config = context.config token = config["telegram.bot_token"] diff --git a/mautrix_telegram/commands/auth.py b/mautrix_telegram/commands/auth.py index f2514dcb..12151e4f 100644 --- a/mautrix_telegram/commands/auth.py +++ b/mautrix_telegram/commands/auth.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict +from typing import Any, Awaitable, Dict, Optional import asyncio from telethon.errors import ( @@ -31,7 +31,7 @@ from ..util import format_duration @command_handler(needs_auth=False, help_section=SECTION_AUTH, help_text="Check if you're logged into Telegram.") -async def ping(evt: CommandEvent) -> None: +async def ping(evt: CommandEvent) -> Optional[Dict]: me = await evt.sender.client.get_me() if await evt.sender.is_logged_in() else None if me: return await evt.reply(f"You're logged in as @{me.username}") @@ -42,7 +42,7 @@ async def ping(evt: CommandEvent) -> None: @command_handler(needs_auth=False, needs_puppeting=False, help_section=SECTION_AUTH, help_text="Get the info of the message relay Telegram bot.") -async def ping_bot(evt: CommandEvent) -> None: +async def ping_bot(evt: CommandEvent) -> Optional[Dict]: if not evt.tgbot: return await evt.reply("Telegram message relay bot not configured.") bot_info = await evt.tgbot.client.get_me() @@ -57,19 +57,19 @@ async def ping_bot(evt: CommandEvent) -> None: help_section=SECTION_AUTH, help_text="Revert your Telegram account's Matrix puppet to use the default Matrix " "account.") -async def logout_matrix(evt: CommandEvent) -> None: +async def logout_matrix(evt: CommandEvent) -> Optional[Dict]: puppet = pu.Puppet.get(evt.sender.tgid) if not puppet.is_real_user: return await evt.reply("You are not logged in with your Matrix account.") await puppet.switch_mxid(None, None) - await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.") + return await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.") @command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True, help_section=SECTION_AUTH, help_text="Replace your Telegram account's Matrix puppet with your own Matrix " "account") -async def login_matrix(evt: CommandEvent) -> None: +async def login_matrix(evt: CommandEvent) -> Optional[Dict]: puppet = pu.Puppet.get(evt.sender.tgid) if puppet.is_real_user: return await evt.reply("You have already logged in with your Matrix account. " @@ -100,7 +100,7 @@ async def login_matrix(evt: CommandEvent) -> None: return await evt.reply("This bridge instance has been configured to not allow logging in.") -async def enter_matrix_token(evt: CommandEvent) -> None: +async def enter_matrix_token(evt: CommandEvent) -> Dict: evt.sender.command_status = None puppet = pu.Puppet.get(evt.sender.tgid) @@ -109,10 +109,11 @@ async def enter_matrix_token(evt: CommandEvent) -> None: "Log out with `$cmdprefix+sp logout-matrix` first.") resp = await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid) - if resp == 2: + if resp == pu.PuppetError.OnlyLoginSelf: return await evt.reply("You can only log in as your own Matrix user.") - elif resp == 1: + elif resp == pu.PuppetError.InvalidAccessToken: return await evt.reply("Failed to verify access token.") + assert resp == pu.PuppetError.Success, "Encountered an unhandled PuppetError." return await evt.reply( f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.") @@ -121,7 +122,7 @@ async def enter_matrix_token(evt: CommandEvent) -> None: help_section=SECTION_AUTH, help_args="<_phone_> <_full name_>", help_text="Register to Telegram") -async def register(evt: CommandEvent) -> None: +async def register(evt: CommandEvent) -> Optional[Dict]: if await evt.sender.is_logged_in(): return await evt.reply("You are already logged in.") elif len(evt.args) < 1: @@ -138,9 +139,10 @@ async def register(evt: CommandEvent) -> None: "action": "Register", "full_name": full_name, }) + return None -async def enter_code_register(evt: CommandEvent) -> None: +async def enter_code_register(evt: CommandEvent) -> Dict: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp `") try: @@ -169,7 +171,7 @@ async def enter_code_register(evt: CommandEvent) -> None: @command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH, help_text="Get instructions on how to log in.") -async def login(evt: CommandEvent) -> None: +async def login(evt: CommandEvent) -> Optional[Dict]: if await evt.sender.is_logged_in(): return await evt.reply("You are already logged in.") @@ -200,7 +202,8 @@ async def login(evt: CommandEvent) -> None: return await evt.reply("This bridge instance has been configured to not allow logging in.") -async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, str]) -> None: +async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, Any] + ) -> Dict: ok = False try: await evt.sender.ensure_started(even_if_no_session=True) @@ -232,7 +235,7 @@ async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[s @command_handler(needs_auth=False) -async def enter_phone_or_token(evt: CommandEvent) -> None: +async def enter_phone_or_token(evt: CommandEvent) -> Optional[Dict]: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token `") elif not evt.config.get("bridge.allow_matrix_login", True): @@ -252,10 +255,11 @@ async def enter_phone_or_token(evt: CommandEvent) -> None: "next": enter_code, "action": "Login", }) + return None @command_handler(needs_auth=False) -async def enter_code(evt: CommandEvent) -> None: +async def enter_code(evt: CommandEvent) -> Optional[Dict]: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp enter-code `") elif not evt.config.get("bridge.allow_matrix_login", True): @@ -267,10 +271,11 @@ async def enter_code(evt: CommandEvent) -> None: evt.log.exception("Error sending phone code") return await evt.reply("Unhandled exception while sending code. " "Check console for more details.") + return None @command_handler(needs_auth=False) -async def enter_password(evt: CommandEvent) -> None: +async def enter_password(evt: CommandEvent) -> Optional[Dict]: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp enter-password `") elif not evt.config.get("bridge.allow_matrix_login", True): @@ -286,9 +291,10 @@ async def enter_password(evt: CommandEvent) -> None: evt.log.exception("Error sending password") return await evt.reply("Unhandled exception while sending password. " "Check console for more details.") + return None -async def sign_in(evt: CommandEvent, **sign_in_info) -> None: +async def sign_in(evt: CommandEvent, **sign_in_info) -> Dict: try: await evt.sender.ensure_started(even_if_no_session=True) user = await evt.sender.client.sign_in(**sign_in_info) @@ -313,7 +319,7 @@ async def sign_in(evt: CommandEvent, **sign_in_info) -> None: @command_handler(needs_auth=True, help_section=SECTION_AUTH, help_text="Log out from Telegram.") -async def logout(evt: CommandEvent) -> None: +async def logout(evt: CommandEvent) -> Optional[Dict]: if await evt.sender.log_out(): return await evt.reply("Logged out successfully.") return await evt.reply("Failed to log out.") diff --git a/mautrix_telegram/commands/clean_rooms.py b/mautrix_telegram/commands/clean_rooms.py index dce36902..2368c163 100644 --- a/mautrix_telegram/commands/clean_rooms.py +++ b/mautrix_telegram/commands/clean_rooms.py @@ -14,21 +14,21 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Tuple, List +from typing import Dict, List, NewType, Optional, Tuple, Union from mautrix_appservice import MatrixRequestError, IntentAPI +from ..types import MatrixRoomId, MatrixUserId from . import command_handler, CommandEvent, SECTION_ADMIN from .. import puppet as pu, portal as po -ManagementRoomList = List[Tuple[str, str]] -RoomIDList = List[str] +ManagementRoom = NewType('ManagementRoom', Tuple[MatrixRoomId, MatrixUserId]) -async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList, +async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[MatrixRoomId], List["po.Portal"], List["po.Portal"]]: - management_rooms = [] # type: ManagementRoomList - unidentified_rooms = [] # type: RoomIDList + management_rooms = [] # type: List[ManagementRoom] + unidentified_rooms = [] # type: List[MatrixRoomId] portals = [] # type: List[po.Portal] empty_portals = [] # type: List[po.Portal] @@ -45,7 +45,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList if pu.Puppet.get_id_from_mxid(other_member): unidentified_rooms.append(room) else: - management_rooms.append((room, other_member)) + management_rooms.append(ManagementRoom((room, other_member))) else: unidentified_rooms.append(room) else: @@ -61,7 +61,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList @command_handler(needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms", help_section=SECTION_ADMIN, help_text="Clean up unused portal/management rooms.") -async def clean_rooms(evt: CommandEvent) -> None: +async def clean_rooms(evt: CommandEvent) -> Optional[Dict]: management_rooms, unidentified_rooms, portals, empty_portals = await _find_rooms(evt.az.intent) reply = ["#### Management rooms (M)"] @@ -106,13 +106,14 @@ async def clean_rooms(evt: CommandEvent) -> None: return await evt.reply("\n".join(reply)) -async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList, - unidentified_rooms: RoomIDList, portals: List["po.Portal"], +async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom], + unidentified_rooms: List[MatrixRoomId], portals: List["po.Portal"], empty_portals: List["po.Portal"]) -> None: command = evt.args[0] - rooms_to_clean = [] + rooms_to_clean = [] # type: List[Union[po.Portal, MatrixRoomId]] if command == "clean-recommended": - rooms_to_clean = empty_portals + unidentified_rooms + rooms_to_clean += empty_portals + rooms_to_clean += unidentified_rooms elif command == "clean-groups": if len(evt.args) < 2: return await evt.reply("**Usage:** `$cmdprefix+sp clean-groups [M][A][U][I]") @@ -158,7 +159,7 @@ async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList, "`$cmdprefix+sp confirm-clean`.") -async def execute_room_cleanup(evt, rooms_to_clean) -> None: +async def execute_room_cleanup(evt, rooms_to_clean: List[Union[po.Portal, MatrixRoomId]]) -> None: if len(evt.args) > 0 and evt.args[0] == "confirm-clean": await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. " "This might take a while.") @@ -167,7 +168,7 @@ async def execute_room_cleanup(evt, rooms_to_clean) -> None: if isinstance(room, po.Portal): await room.cleanup_and_delete() cleaned += 1 - elif isinstance(room, str): + elif isinstance(room, str): # str is aliased by MatrixRoomId await po.Portal.cleanup_room(evt.az.intent, room, message="Room deleted") cleaned += 1 evt.sender.command_status = None diff --git a/mautrix_telegram/commands/handler.py b/mautrix_telegram/commands/handler.py index 2c28bef6..34f5b047 100644 --- a/mautrix_telegram/commands/handler.py +++ b/mautrix_telegram/commands/handler.py @@ -14,19 +14,20 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Dict, Callable, Optional +from typing import Any, Awaitable, Callable, Coroutine, Dict, List, NamedTuple, Optional, Union from collections import namedtuple import markdown import logging from telethon.errors import FloodWaitError +from ..types import MatrixRoomId from ..util import format_duration from .. import user as u, context as c command_handlers = {} # type: Dict[str, CommandHandler] -HelpSection = namedtuple("HelpSection", "name order description") +HelpSection = NamedTuple('HelpSection', [('name', str), ('order', int), ('description', str)]) SECTION_GENERAL = HelpSection("General", 0, "") SECTION_AUTH = HelpSection("Authentication", 10, "") @@ -37,8 +38,8 @@ SECTION_ADMIN = HelpSection("Administration", 50, "") class CommandEvent: - def __init__(self, processor: "CommandProcessor", room: str, sender: u.User, command: str, - args: List[str], is_management: bool, is_portal: bool) -> None: + def __init__(self, processor: 'CommandProcessor', room: MatrixRoomId, sender: u.User, + command: str, args: List[str], is_management: bool, is_portal: bool) -> None: self.az = processor.az self.log = processor.log self.loop = processor.loop @@ -53,7 +54,8 @@ class CommandEvent: self.is_management = is_management self.is_portal = is_portal - def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True) -> None: + def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True + ) -> Awaitable[Dict]: message = message.replace("$cmdprefix+sp ", "" if self.is_management else f"{self.command_prefix} ") message = message.replace("$cmdprefix", self.command_prefix) @@ -66,7 +68,7 @@ class CommandEvent: class CommandHandler: - def __init__(self, handler: Callable[[CommandEvent], None], needs_auth: bool, + def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool, needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool, management_only: bool, name: str, help_text: str, help_args: str, help_section: HelpSection) -> None: @@ -103,7 +105,8 @@ class CommandHandler: (not self.needs_admin or is_admin) and (not self.needs_auth or is_logged_in)) - async def __call__(self, evt: CommandEvent) -> None: + async def __call__(self, evt: CommandEvent + ) -> Dict: error = await self.get_permission_error(evt) if error is not None: return await evt.reply(error) @@ -118,13 +121,21 @@ class CommandHandler: return f"**{self.name}** {self._help_args} - {self._help_text}" -def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, needs_auth=True, - needs_puppeting=True, needs_matrix_puppeting=False, needs_admin=False, - management_only=False, name=None, help_text="", help_args="", - help_section=None) -> None: +def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] = None, *, + needs_auth: bool = True, + needs_puppeting: bool = True, + needs_matrix_puppeting: bool = False, + needs_admin: bool = False, + management_only: bool = False, + name: Optional[str] = None, + help_text: str = "", + help_args: str = "", + help_section: HelpSection = None + ) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[Dict]]]], + CommandHandler]: input_name = name - def decorator(func: Callable[[CommandEvent], None]) -> None: + def decorator(func: Callable[[CommandEvent], Awaitable[Optional[Dict]]]) -> CommandHandler: name = input_name or func.__name__.replace("_", "-") handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting, needs_admin, management_only, name, help_text, help_args, @@ -139,26 +150,26 @@ class CommandProcessor: log = logging.getLogger("mau.commands") def __init__(self, context: c.Context) -> None: - self.az, self.db, self.config, self.loop, self.tgbot = context + self.az, self.db, self.config, self.loop, self.tgbot = context.core self.public_website = context.public_website self.command_prefix = self.config["bridge.command_prefix"] - async def handle(self, room: str, sender: u.User, command: str, args: List[str], - is_management: bool, is_portal: bool) -> None: + async def handle(self, room: MatrixRoomId, sender: u.User, command: str, args: List[str], + is_management: bool, is_portal: bool) -> Optional[Dict]: evt = CommandEvent(self, room, sender, command, args, is_management, is_portal) orig_command = command command = command.lower() try: - command = command_handlers[command] + command_handler = command_handlers[command] except KeyError: if sender.command_status and "next" in sender.command_status: args.insert(0, orig_command) evt.command = "" command = sender.command_status["next"] else: - command = command_handlers["unknown-command"] + command_handler = command_handlers["unknown-command"] try: - await command(evt) + await command_handler(evt) except FloodWaitError as e: return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}") except Exception: @@ -166,3 +177,4 @@ class CommandProcessor: f"{evt.command} {' '.join(args)} from {sender.mxid}") return await evt.reply("Unhandled error while handling command. " "Check logs for more details.") + return None diff --git a/mautrix_telegram/commands/meta.py b/mautrix_telegram/commands/meta.py index 8478a636..5920dbee 100644 --- a/mautrix_telegram/commands/meta.py +++ b/mautrix_telegram/commands/meta.py @@ -14,46 +14,49 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Dict, List, Optional, Tuple + from . import command_handler, CommandEvent, _command_handlers, SECTION_GENERAL +from .handler import HelpSection @command_handler(needs_auth=False, needs_puppeting=False, help_section=SECTION_GENERAL, help_text="Cancel an ongoing action (such as login)") -def cancel(evt: CommandEvent) -> None: +async def cancel(evt: CommandEvent) -> Optional[Dict]: if evt.sender.command_status: action = evt.sender.command_status["action"] evt.sender.command_status = None - return evt.reply(f"{action} cancelled.") + return await evt.reply(f"{action} cancelled.") else: - return evt.reply("No ongoing command.") + return await evt.reply("No ongoing command.") @command_handler(needs_auth=False, needs_puppeting=False) -def unknown_command(evt: CommandEvent) -> None: - return evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.") +async def unknown_command(evt: CommandEvent) -> Optional[Dict]: + return await evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.") -help_cache = {} +help_cache = {} # type: Dict[Tuple[bool, bool, bool, bool, bool], str] -async def _get_help_text(evt: CommandEvent) -> None: +async def _get_help_text(evt: CommandEvent) -> str: cache_key = (evt.is_management, evt.sender.puppet_whitelisted, evt.sender.matrix_puppet_whitelisted, evt.sender.is_admin, await evt.sender.is_logged_in()) if cache_key not in help_cache: - help = {} + help_sections = {} # type: Dict[HelpSection, List[str]] for handler in _command_handlers.values(): if handler.has_help and handler.has_permission(*cache_key): - help.setdefault(handler.help_section, []) - help[handler.help_section].append(handler.help + " ") - help = sorted(help.items(), key=lambda item: item[0].order) - help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help] + help_sections.setdefault(handler.help_section, []) + help_sections[handler.help_section].append(handler.help + " ") + help_sorted = sorted(help_sections.items(), key=lambda item: item[0].order) + help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help_sorted] help_cache[cache_key] = "\n".join(help) return help_cache[cache_key] -def _get_management_status(evt: CommandEvent) -> None: +def _get_management_status(evt: CommandEvent) -> str: if evt.is_management: return "This is a management room: prefixing commands with `$cmdprefix` is not required." elif evt.is_portal: @@ -65,5 +68,5 @@ def _get_management_status(evt: CommandEvent) -> None: @command_handler(needs_auth=False, needs_puppeting=False, help_section=SECTION_GENERAL, help_text="Show this help message.") -async def help(evt: CommandEvent) -> None: +async def help(evt: CommandEvent) -> Optional[Dict]: return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt)) diff --git a/mautrix_telegram/commands/portal.py b/mautrix_telegram/commands/portal.py index db0a8421..fa48f765 100644 --- a/mautrix_telegram/commands/portal.py +++ b/mautrix_telegram/commands/portal.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Callable +from typing import Awaitable, Dict, Callable, Coroutine, Optional, Tuple, Union, cast import asyncio from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError, @@ -22,6 +22,7 @@ from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError, from telethon.tl.types import ChatForbidden, ChannelForbidden from mautrix_appservice import MatrixRequestError, IntentAPI +from ..types import MatrixRoomId, TelegramId from .. import portal as po, user as u from . import (command_handler, CommandEvent, SECTION_ADMIN, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT) @@ -31,7 +32,7 @@ from . import (command_handler, CommandEvent, help_section=SECTION_ADMIN, help_args="<_level_> [_mxid_]", help_text="Set a temporary power level without affecting Telegram.") -async def set_power_level(evt: CommandEvent) -> None: +async def set_power_level(evt: CommandEvent) -> Dict: try: level = int(evt.args[0]) except KeyError: @@ -46,11 +47,12 @@ async def set_power_level(evt: CommandEvent) -> None: except MatrixRequestError: evt.log.exception("Failed to set power level.") return await evt.reply("Failed to set power level.") + return {} @command_handler(help_section=SECTION_PORTAL_MANAGEMENT, help_text="Get a Telegram invite link to the current chat.") -async def invite_link(evt: CommandEvent) -> None: +async def invite_link(evt: CommandEvent) -> Dict: portal = po.Portal.get_by_mxid(evt.room_id) if not portal: return await evt.reply("This is not a portal room.") @@ -68,7 +70,7 @@ async def invite_link(evt: CommandEvent) -> None: async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50 - ) -> None: + ) -> bool: if sender.is_admin: return True # Make sure the state store contains the power levels. @@ -82,8 +84,9 @@ async def user_has_power_level(room: str, intent, sender: u.User, event: str, de async def _get_portal_and_check_permission(evt: CommandEvent, permission: str, - action: Optional[str] = None) -> None: - room_id = evt.args[0] if len(evt.args) > 0 else evt.room_id + action: Optional[str] = None + ) -> Tuple[Union[Dict, po.Portal], bool]: + room_id = MatrixRoomId(evt.args[0]) if len(evt.args) > 0 else evt.room_id portal = po.Portal.get_by_mxid(room_id) if not portal: @@ -97,8 +100,8 @@ async def _get_portal_and_check_permission(evt: CommandEvent, permission: str, def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str, - completed_message: str) -> None: - async def post_confirm(confirm) -> None: + completed_message: str) -> Dict: + async def post_confirm(confirm) -> Optional[Dict]: confirm.sender.command_status = None if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}": await function() @@ -106,6 +109,7 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c return await confirm.reply(completed_message) else: return await confirm.reply(f"{action} cancelled.") + return None return { "next": post_confirm, @@ -118,10 +122,11 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c help_text="Remove all users from the current portal room and forget the portal. " "Only works for group chats; to delete a private chat portal, simply " "leave the room.") -async def delete_portal(evt: CommandEvent) -> None: - portal, ok = await _get_portal_and_check_permission(evt, "unbridge") +async def delete_portal(evt: CommandEvent) -> Optional[Dict]: + result, ok = await _get_portal_and_check_permission(evt, "unbridge") if not ok: - return + return None + portal = cast('po.Portal', result) evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid, portal.cleanup_and_delete, "delete", @@ -139,10 +144,11 @@ async def delete_portal(evt: CommandEvent) -> None: @command_handler(needs_auth=False, needs_puppeting=False, help_section=SECTION_PORTAL_MANAGEMENT, help_text="Remove puppets from the current portal room and forget the portal.") -async def unbridge(evt: CommandEvent) -> None: - portal, ok = await _get_portal_and_check_permission(evt, "unbridge") +async def unbridge(evt: CommandEvent) -> Optional[Dict]: + result, ok = await _get_portal_and_check_permission(evt, "unbridge") if not ok: - return + return None + portal = cast('po.Portal', result) evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid, portal.unbridge, "unbridge", @@ -158,11 +164,11 @@ async def unbridge(evt: CommandEvent) -> None: help_text="Bridge the current Matrix room to the Telegram chat with the given " "ID. The ID must be the prefixed version that you get with the `/id` " "command of the Telegram-side bot.") -async def bridge(evt: CommandEvent) -> None: +async def bridge(evt: CommandEvent) -> Dict: if len(evt.args) == 0: return await evt.reply("**Usage:** " "`$cmdprefix+sp bridge [Matrix room ID]`") - room_id = evt.args[1] if len(evt.args) > 1 else evt.room_id + room_id = MatrixRoomId(evt.args[1]) if len(evt.args) > 1 else evt.room_id that_this = "This" if room_id == evt.room_id else "That" portal = po.Portal.get_by_mxid(room_id) @@ -173,12 +179,12 @@ async def bridge(evt: CommandEvent) -> None: return await evt.reply(f"You do not have the permissions to bridge {that_this} room.") # The /id bot command provides the prefixed ID, so we assume - tgid = evt.args[0] - if tgid.startswith("-100"): - tgid = int(tgid[4:]) + tgid_str = evt.args[0] + if tgid_str.startswith("-100"): + tgid = TelegramId(int(tgid_str[4:])) peer_type = "channel" - elif tgid.startswith("-"): - tgid = -int(tgid) + elif tgid_str.startswith("-"): + tgid = TelegramId(-int(tgid_str)) peer_type = "chat" else: return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n" @@ -224,7 +230,8 @@ async def bridge(evt: CommandEvent) -> None: "chat to this room, use `$cmdprefix+sp continue`") -async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal") -> None: +async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal" + ) -> Tuple[bool, Coroutine[None, None, None]]: if not portal.mxid: await evt.reply("The portal seems to have lost its Matrix room between you" "calling `$cmdprefix+sp bridge` and this command.\n\n" @@ -247,7 +254,7 @@ async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Porta return False, None -async def confirm_bridge(evt: CommandEvent) -> None: +async def confirm_bridge(evt: CommandEvent) -> Optional[Dict]: status = evt.sender.command_status try: portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"]) @@ -260,7 +267,7 @@ async def confirm_bridge(evt: CommandEvent) -> None: if "mxid" in status: ok, coro = await cleanup_old_portal_while_bridging(evt, portal) if not ok: - return + return None elif coro: asyncio.ensure_future(coro, loop=evt.loop) await evt.reply("Cleaning up previous portal room...") @@ -304,7 +311,7 @@ async def confirm_bridge(evt: CommandEvent) -> None: return await evt.reply("Bridging complete. Portal synchronization should begin momentarily.") -async def get_initial_state(intent: IntentAPI, room_id: str) -> None: +async def get_initial_state(intent: IntentAPI, room_id: str) -> Tuple[str, str, Dict]: state = await intent.get_room_state(room_id) title = None about = None @@ -330,7 +337,7 @@ async def get_initial_state(intent: IntentAPI, room_id: str) -> None: help_text="Create a Telegram chat of the given type for the current Matrix room. " "The type is either `group`, `supergroup` or `channel` (defaults to " "`group`).") -async def create(evt: CommandEvent) -> None: +async def create(evt: CommandEvent) -> Dict: type = evt.args[0] if len(evt.args) > 0 else "group" if type not in {"chat", "group", "supergroup", "channel"}: return await evt.reply( @@ -365,7 +372,7 @@ async def create(evt: CommandEvent) -> None: @command_handler(help_section=SECTION_PORTAL_MANAGEMENT, help_text="Upgrade a normal Telegram group to a supergroup.") -async def upgrade(evt: CommandEvent) -> None: +async def upgrade(evt: CommandEvent) -> Dict: portal = po.Portal.get_by_mxid(evt.room_id) if not portal: return await evt.reply("This is not a portal room.") @@ -387,7 +394,7 @@ async def upgrade(evt: CommandEvent) -> None: help_args="<_name_|`-`>", help_text="Change the username of a supergroup/channel. " "To disable, use a dash (`-`) as the name.") -async def group_name(evt: CommandEvent) -> None: +async def group_name(evt: CommandEvent) -> Dict: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp group-name `") @@ -423,7 +430,7 @@ async def group_name(evt: CommandEvent) -> None: help_args="<`whitelist`|`blacklist`>", help_text="Change whether the bridge will allow or disallow bridging rooms by " "default.") -async def filter_mode(evt: CommandEvent) -> None: +async def filter_mode(evt: CommandEvent) -> Dict: try: mode = evt.args[0] if mode not in ("whitelist", "blacklist"): @@ -448,19 +455,19 @@ async def filter_mode(evt: CommandEvent) -> None: help_section=SECTION_ADMIN, help_args="<`whitelist`|`blacklist`> <_chat ID_>", help_text="Allow or disallow bridging a specific chat.") -async def filter(evt: CommandEvent) -> None: +async def filter(evt: CommandEvent) -> Optional[Dict]: try: action = evt.args[0] if action not in ("whitelist", "blacklist", "add", "remove"): raise ValueError() - id = evt.args[1] - if id.startswith("-100"): - id = int(id[4:]) - elif id.startswith("-"): - id = int(id[1:]) + id_str = evt.args[1] + if id_str.startswith("-100"): + id = int(id_str[4:]) + elif id_str.startswith("-"): + id = int(id_str[1:]) else: - id = int(id) + id = int(id_str) except (IndexError, ValueError): return await evt.reply("**Usage:** `$cmdprefix+sp filter `") @@ -490,3 +497,4 @@ async def filter(evt: CommandEvent) -> None: list.remove(id) save() return await evt.reply(f"Chat ID removed from {mode}.") + return None diff --git a/mautrix_telegram/commands/telegram.py b/mautrix_telegram/commands/telegram.py index 3faa3e56..2f968742 100644 --- a/mautrix_telegram/commands/telegram.py +++ b/mautrix_telegram/commands/telegram.py @@ -14,10 +14,13 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Awaitable, Dict, List, Optional, Tuple +import re from telethon.errors import ( InviteHashInvalidError, InviteHashExpiredError, UserAlreadyParticipantError) from telethon.tl.types import User as TLUser +from telethon.tl.types import TypeUpdates from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest from telethon.tl.functions.channels import JoinChannelRequest @@ -28,7 +31,7 @@ from . import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORT @command_handler(help_section=SECTION_MISC, help_args="[_-r|--remote_] <_query_>", help_text="Search your contacts or the Telegram servers for users.") -async def search(evt: CommandEvent) -> None: +async def search(evt: CommandEvent) -> Optional[Dict]: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] `") @@ -49,7 +52,7 @@ async def search(evt: CommandEvent) -> None: "Minimum length of remote query is 5 characters.") return await evt.reply("No results 3:") - reply = [] + reply = [] # type: List[str] if remote: reply += ["**Results from Telegram server:**", ""] else: @@ -70,7 +73,7 @@ async def search(evt: CommandEvent) -> None: "either the internal user ID, the username or the phone number. " "**N.B.** The phone numbers you start chats with must already be in " "your contacts.") -async def private_message(evt: CommandEvent) -> None: +async def private_message(evt: CommandEvent) -> Optional[Dict]: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp pm `") @@ -89,7 +92,7 @@ async def private_message(evt: CommandEvent) -> None: f"{pu.Puppet.get_displayname(user, False)}") -async def _join(evt: CommandEvent, arg: str) -> None: +async def _join(evt: CommandEvent, arg: str) -> Tuple[TypeUpdates, Dict]: if arg.startswith("joinchat/"): invite_hash = arg[len("joinchat/"):] try: @@ -112,7 +115,7 @@ async def _join(evt: CommandEvent, arg: str) -> None: @command_handler(help_section=SECTION_CREATING_PORTALS, help_args="<_link_>", help_text="Join a chat with an invite link.") -async def join(evt: CommandEvent) -> None: +async def join(evt: CommandEvent) -> Optional[Dict]: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp join `") @@ -123,7 +126,7 @@ async def join(evt: CommandEvent) -> None: updates, _ = await _join(evt, arg.group(1)) if not updates: - return + return None for chat in updates.chats: portal = po.Portal.get_by_entity(chat) @@ -134,12 +137,13 @@ async def join(evt: CommandEvent) -> None: await evt.reply(f"Creating room for {chat.title}... This might take a while.") await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid]) return await evt.reply(f"Created room for {portal.title}") + return None @command_handler(help_section=SECTION_MISC, help_args="[`chats`|`contacts`|`me`]", help_text="Synchronize your chat portals, contacts and/or own info.") -async def sync(evt: CommandEvent) -> None: +async def sync(evt: CommandEvent) -> Optional[Dict]: if len(evt.args) > 0: sync_only = evt.args[0] if sync_only not in ("chats", "contacts", "me"): diff --git a/mautrix_telegram/config.py b/mautrix_telegram/config.py index c40b910b..7f0029d3 100644 --- a/mautrix_telegram/config.py +++ b/mautrix_telegram/config.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Tuple, Any, Optional +from typing import Any, Dict, Optional, Tuple from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap import random @@ -25,7 +25,7 @@ yaml.indent(4) class DictWithRecursion: - def __init__(self, data: CommentedMap = None) -> None: + def __init__(self, data: Optional[CommentedMap] = None) -> None: self._data = data or CommentedMap() # type: CommentedMap def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any: @@ -99,7 +99,7 @@ class Config(DictWithRecursion): self.path = path # type: str self.registration_path = registration_path # type: str self.base_path = base_path # type: str - self._registration = None # type: dict + self._registration = None # type: Optional[Dict] def load(self) -> None: with open(self.path, 'r') as stream: diff --git a/mautrix_telegram/context.py b/mautrix_telegram/context.py index ac48e239..4330c102 100644 --- a/mautrix_telegram/context.py +++ b/mautrix_telegram/context.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import TYPE_CHECKING, Optional +from typing import Generator, Optional, Tuple, Union, TYPE_CHECKING if TYPE_CHECKING: import asyncio @@ -44,9 +44,7 @@ class Context: self.public_website = None # type: PublicBridgeWebsite self.provisioning_api = None # type: ProvisioningAPI - def __iter__(self) -> None: - yield self.az - yield self.db - yield self.config - yield self.loop - yield self.bot + @property + def core(self) -> Tuple['AppService', 'scoped_session', 'Config', + 'asyncio.AbstractEventLoop', Optional['Bot']]: + return (self.az, self.db, self.config, self.loop, self.bot) diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py index 13aa12a4..751ca76f 100644 --- a/mautrix_telegram/db.py +++ b/mautrix_telegram/db.py @@ -14,6 +14,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Dict + from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, BigInteger, String, Boolean, Text) from sqlalchemy.sql import expression @@ -88,20 +90,20 @@ class RoomState(Base): room_id = Column(String, primary_key=True) _power_levels_text = Column("power_levels", Text, nullable=True) - _power_levels_json = None + _power_levels_json = {} # type: Dict @property - def has_power_levels(self) -> None: + def has_power_levels(self) -> bool: return bool(self._power_levels_text) @property - def power_levels(self) -> None: + def power_levels(self) -> Dict: if not self._power_levels_json and self._power_levels_text: self._power_levels_json = json.loads(self._power_levels_text) - return self._power_levels_json or {} + return self._power_levels_json @power_levels.setter - def power_levels(self, val) -> None: + def power_levels(self, val: Dict) -> None: self._power_levels_json = val self._power_levels_text = json.dumps(val) @@ -116,7 +118,7 @@ class UserProfile(Base): displayname = Column(String, nullable=True) avatar_url = Column(String, nullable=True) - def dict(self) -> None: + def dict(self) -> Dict[str, Column]: return { "membership": self.membership, "displayname": self.displayname, diff --git a/mautrix_telegram/formatter/from_matrix/parser_htmlparser.py b/mautrix_telegram/formatter/from_matrix/parser_htmlparser.py index 4de0cba1..ad085fe9 100644 --- a/mautrix_telegram/formatter/from_matrix/parser_htmlparser.py +++ b/mautrix_telegram/formatter/from_matrix/parser_htmlparser.py @@ -80,12 +80,12 @@ class MatrixParser(HTMLParser, MatrixParserCommon): args["url"] = url return MessageEntityTextUrl, None - def handle_starttag(self, tag: str, attrs: List[Tuple[str, str]]): + def handle_starttag(self, tag: str, attrs_list: List[Tuple[str, str]]): self._open_tags.appendleft(tag) self._open_tags_meta.appendleft(0) - attrs = dict(attrs) - entity_type = None # type: type(TypeMessageEntity) + attrs = dict(attrs_list) + entity_type = None # type: Optional[Type[TypeMessageEntity]] args = {} # type: Dict[str, Any] if tag in ("strong", "b"): entity_type = MessageEntityBold diff --git a/mautrix_telegram/formatter/from_matrix/parser_lxml.py b/mautrix_telegram/formatter/from_matrix/parser_lxml.py index 0a997db3..25997d08 100644 --- a/mautrix_telegram/formatter/from_matrix/parser_lxml.py +++ b/mautrix_telegram/formatter/from_matrix/parser_lxml.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, List, Tuple, Union, Callable +from typing import Callable, List, Optional, Sequence, Tuple, Type, Union from lxml import html from telethon.tl.types import (MessageEntityMention as Mention, @@ -83,7 +83,7 @@ def offset_length_multiply(amount: int): class TelegramMessage: - def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None): + def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None) -> None: self.text = text # type: str self.entities = entities or [] # type: List[TypeMessageEntity] @@ -120,7 +120,7 @@ class TelegramMessage: self.text = msg.text + self.text return self - def format(self, entity_type: type(TypeMessageEntity), offset: int = None, length: int = None, + def format(self, entity_type: Type[TypeMessageEntity], offset: int = None, length: int = None, **kwargs) -> "TelegramMessage": self.entities.append(entity_type(offset=offset or 0, length=length if length is not None else len(self.text), @@ -158,7 +158,8 @@ class TelegramMessage: return output @staticmethod - def join(items: List[Union[str, "TelegramMessage"]], separator: str = " ") -> "TelegramMessage": + def join(items: Sequence[Union[str, "TelegramMessage"]], + separator: str = " ") -> "TelegramMessage": main = TelegramMessage() for msg in items: if isinstance(msg, str): diff --git a/mautrix_telegram/formatter/from_telegram.py b/mautrix_telegram/formatter/from_telegram.py index 5afb1a58..f9e246e0 100644 --- a/mautrix_telegram/formatter/from_telegram.py +++ b/mautrix_telegram/formatter/from_telegram.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, List, Tuple, TYPE_CHECKING +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from html import escape import logging import re @@ -28,6 +28,7 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, from mautrix_appservice import MatrixRequestError from mautrix_appservice.intent_api import IntentAPI +from ..types import TelegramId from .. import user as u, puppet as pu, portal as po from ..db import Message as DBMessage from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html, @@ -40,14 +41,14 @@ if TYPE_CHECKING: try: from lxml.html.diff import htmldiff except ImportError: - htmldiff = None # type: function + htmldiff = None # type: ignore log = logging.getLogger("mau.fmt.tg") # type: logging.Logger should_highlight_edits = False # type: bool -def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict: +def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict: if evt.reply_to_msg_id: space = (evt.to_id.channel_id if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) @@ -116,7 +117,7 @@ def highlight_edits(new_html: str, old_html: str) -> str: async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message, - relates_to: dict, main_intent: IntentAPI, is_edit: bool + relates_to: Dict, main_intent: IntentAPI, is_edit: bool ) -> Tuple[str, str]: space = (evt.to_id.channel_id if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) @@ -177,10 +178,10 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M async def telegram_to_matrix(evt: Message, source: "AbstractUser", main_intent: Optional[IntentAPI] = None, is_edit: bool = False, prefix_text: Optional[str] = None, - prefix_html: Optional[str] = None) -> Tuple[str, str, dict]: + prefix_html: Optional[str] = None) -> Tuple[str, str, Dict]: text = add_surrogates(evt.message) html = _telegram_entities_to_matrix_catch(text, evt.entities) if evt.entities else None - relates_to = {} + relates_to = {} # type: Dict if prefix_html: html = prefix_html + (html or escape(text)) @@ -217,6 +218,7 @@ def _telegram_entities_to_matrix_catch(text: str, entities: List[TypeMessageEnti "message=%s\n" "entities=%s", text, entities) + return "[failed conversion in _telegram_entities_to_matrix]" def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -> str: @@ -290,7 +292,7 @@ def _parse_mention(html: List[str], entity_text: str) -> bool: return False -def _parse_name_mention(html: List[str], entity_text: str, user_id: int) -> bool: +def _parse_name_mention(html: List[str], entity_text: str, user_id: TelegramId) -> bool: user = u.User.get_by_tgid(user_id) if user: mxid = user.mxid @@ -315,8 +317,8 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool: message_link_match = message_link_regex.match(url) if message_link_match: - group, msgid = message_link_match.groups() - msgid = int(msgid) + group, msgid_str = message_link_match.groups() + msgid = int(msgid_str) portal = po.Portal.find_by_username(group) if portal: diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py index 70bafb47..362ba678 100644 --- a/mautrix_telegram/matrix.py +++ b/mautrix_telegram/matrix.py @@ -14,23 +14,31 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Dict, Tuple, Set, Match +from typing import Dict, List, Match, Optional, Set, Tuple, TYPE_CHECKING import logging import asyncio import re from mautrix_appservice import MatrixRequestError, IntentError +from .types import MatrixEvent, MatrixEventId, MatrixRoomId, MatrixUserId from . import user as u, portal as po, puppet as pu, commands as com +if TYPE_CHECKING: + from mautrix_appservice import AppService + from .context import Context + from sqlalchemy.orm import scoped_session + from .config import Config + from .bot import Bot + class MatrixHandler: log = logging.getLogger("mau.mx") # type: logging.Logger - def __init__(self, context) -> None: - self.az, self.db, self.config, _, self.tgbot = context + def __init__(self, context: 'Context') -> None: + self.az, self.db, self.config, _, self.tgbot = context.core self.commands = com.CommandProcessor(context) # type: com.CommandProcessor - self.previously_typing = [] # type: List[str] + self.previously_typing = [] # type: List[MatrixUserId] self.az.matrix_event_handler(self.handle_event) @@ -50,7 +58,8 @@ class MatrixHandler: except asyncio.TimeoutError: self.log.exception("TimeoutError when trying to set avatar") - async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User) -> None: + async def handle_puppet_invite(self, room_id: MatrixRoomId, puppet: pu.Puppet, inviter: u.User + ) -> None: intent = puppet.default_mxid_intent self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}") if not await inviter.is_logged_in(): @@ -80,6 +89,7 @@ class MatrixHandler: await intent.join_room(room_id) portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user") + # TODO: if portal is None: if portal.mxid: try: await intent.invite(portal.mxid, inviter.mxid) @@ -95,13 +105,13 @@ class MatrixHandler: portal.mxid = room_id portal.save() inviter.register_portal(portal) - await intent.send_notice(room_id, "po.Portal to private chat created.") + await intent.send_notice(room_id, "Portal to private chat created.") else: await intent.join_room(room_id) await intent.send_notice(room_id, "This puppet will remain inactive until a " "Telegram chat is created for this room.") - async def accept_bot_invite(self, room_id: str, inviter: u.User) -> None: + async def accept_bot_invite(self, room_id: MatrixRoomId, inviter: u.User) -> None: tries = 0 while tries < 5: try: @@ -126,9 +136,13 @@ class MatrixHandler: "bridge.permissions section in your config file.") await self.az.intent.leave_room(room_id) - async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str) -> None: + async def handle_invite(self, room_id: MatrixRoomId, user_id: MatrixUserId, + inviter_mxid: MatrixUserId) -> None: self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}") - inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started() + inviter = u.User.get_by_mxid(inviter_mxid) + if inviter is None: + self.log.exception("Failed to find user with Matrix ID {inviter_mxid}") + await inviter.ensure_started() if user_id == self.az.bot_mxid: return await self.accept_bot_invite(room_id, inviter) elif not inviter.whitelisted: @@ -150,7 +164,8 @@ class MatrixHandler: # The rest can probably be ignored - async def handle_join(self, room_id: str, user_id: str, event_id: str) -> None: + async def handle_join(self, room_id: MatrixRoomId, user_id: MatrixUserId, + event_id: MatrixEventId) -> None: user = await u.User.get_by_mxid(user_id).ensure_started() portal = po.Portal.get_by_mxid(room_id) @@ -171,7 +186,8 @@ class MatrixHandler: if await user.is_logged_in() or portal.has_bot: await portal.join_matrix(user, event_id) - async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str) -> None: + async def handle_part(self, room_id: MatrixRoomId, user_id: MatrixUserId, + sender_mxid: MatrixUserId, event_id: MatrixEventId) -> None: self.log.debug(f"{user_id} left {room_id}") sender = u.User.get_by_mxid(sender_mxid, create=False) @@ -185,6 +201,7 @@ class MatrixHandler: puppet = pu.Puppet.get_by_mxid(user_id) if sender and puppet: + # TODO: Puppet should probably be an AbstractUser await portal.leave_matrix(puppet, sender, event_id) user = u.User.get_by_mxid(user_id, create=False) @@ -194,7 +211,7 @@ class MatrixHandler: if await user.is_logged_in() or portal.has_bot: await portal.leave_matrix(user, sender, event_id) - def is_command(self, message: dict) -> Tuple[bool, str]: + def is_command(self, message: Dict) -> Tuple[bool, str]: text = message.get("body", "") prefix = self.config["bridge.command_prefix"] is_command = text.startswith(prefix) @@ -202,9 +219,10 @@ class MatrixHandler: text = text[len(prefix) + 1:] return is_command, text - async def handle_message(self, room, sender, message, event_id) -> None: + async def handle_message(self, room: MatrixRoomId, sender_id: MatrixUserId, message: Dict, + event_id: MatrixEventId) -> None: is_command, text = self.is_command(message) - sender = await u.User.get_by_mxid(sender).ensure_started() + sender = await u.User.get_by_mxid(sender_id).ensure_started() if not sender.relaybot_whitelisted: self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:" " u.User is not whitelisted.") @@ -237,7 +255,8 @@ class MatrixHandler: is_portal=portal is not None) @staticmethod - async def handle_redaction(room_id: str, sender_mxid: str, event_id: str) -> None: + async def handle_redaction(room_id: MatrixRoomId, sender_mxid: MatrixUserId, + event_id: MatrixEventId) -> None: sender = await u.User.get_by_mxid(sender_mxid).ensure_started() if not sender.relaybot_whitelisted: return @@ -249,14 +268,15 @@ class MatrixHandler: await portal.handle_matrix_deletion(sender, event_id) @staticmethod - async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict) -> None: + async def handle_power_levels(room_id: MatrixRoomId, sender_mxid: MatrixUserId, + new: Dict, old: Dict) -> None: portal = po.Portal.get_by_mxid(room_id) sender = await u.User.get_by_mxid(sender_mxid).ensure_started() if await sender.has_full_access(allow_bot=True) and portal: await portal.handle_matrix_power_levels(sender, new["users"], old["users"]) @staticmethod - async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, + async def handle_room_meta(evt_type: str, room_id: MatrixRoomId, sender_mxid: MatrixUserId, content: dict) -> None: portal = po.Portal.get_by_mxid(room_id) sender = await u.User.get_by_mxid(sender_mxid).ensure_started() @@ -271,8 +291,8 @@ class MatrixHandler: await handler(sender, content[content_key]) @staticmethod - async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str], - old_events: Set[str]) -> None: + async def handle_room_pin(room_id: MatrixRoomId, sender_mxid: MatrixUserId, + new_events: Set[str], old_events: Set[str]) -> None: portal = po.Portal.get_by_mxid(room_id) sender = await u.User.get_by_mxid(sender_mxid).ensure_started() if await sender.has_full_access(allow_bot=True) and portal: @@ -285,8 +305,8 @@ class MatrixHandler: await portal.handle_matrix_pin(sender, None) @staticmethod - async def handle_name_change(room_id: str, user_id: str, displayname: str, - prev_displayname: str, event_id: str) -> None: + async def handle_name_change(room_id: MatrixRoomId, user_id: MatrixUserId, displayname: str, + prev_displayname: str, event_id: MatrixEventId) -> None: portal = po.Portal.get_by_mxid(room_id) if not portal or not portal.has_bot: return @@ -296,13 +316,14 @@ class MatrixHandler: await portal.name_change_matrix(user, displayname, prev_displayname, event_id) @staticmethod - def parse_read_receipts(content: dict) -> Dict[str, str]: + def parse_read_receipts(content: Dict) -> Dict[MatrixUserId, MatrixEventId]: return {user_id: event_id for event_id, receipts in content.items() for user_id in receipts.get("m.read", {})} @staticmethod - async def handle_read_receipts(room_id: str, receipts: Dict[str, str]) -> None: + async def handle_read_receipts(room_id: MatrixRoomId, + receipts: Dict[MatrixUserId, MatrixEventId]) -> None: portal = po.Portal.get_by_mxid(room_id) if not portal: return @@ -314,13 +335,13 @@ class MatrixHandler: await portal.mark_read(user, event_id) @staticmethod - async def handle_presence(user_id: str, presence: str) -> None: + async def handle_presence(user_id: MatrixUserId, presence: str) -> None: user = await u.User.get_by_mxid(user_id).ensure_started() if not await user.is_logged_in(): return - await user.set_presence(presence == "online") + user.set_presence(presence == "online") - async def handle_typing(self, room_id: str, now_typing: List[str]) -> None: + async def handle_typing(self, room_id: MatrixRoomId, now_typing: List[MatrixUserId]) -> None: portal = po.Portal.get_by_mxid(room_id) if not portal: return @@ -335,35 +356,35 @@ class MatrixHandler: if not await user.is_logged_in(): continue - await portal.set_typing(user, is_typing) + portal.set_typing(user, is_typing) self.previously_typing = now_typing - def filter_matrix_event(self, event: dict) -> None: + def filter_matrix_event(self, event: MatrixEvent) -> bool: sender = event.get("sender", None) if not sender: return False return (sender == self.az.bot_mxid or pu.Puppet.get_id_from_mxid(sender) is not None) - async def try_handle_event(self, evt: dict) -> None: + async def try_handle_event(self, evt: MatrixEvent) -> None: try: await self.handle_event(evt) except Exception: self.log.exception("Error handling manually received Matrix event") - async def handle_event(self, evt: dict) -> None: + async def handle_event(self, evt: MatrixEvent) -> None: if self.filter_matrix_event(evt): return self.log.debug("Received event: %s", evt) evt_type = evt.get("type", "m.unknown") # type: str - room_id = evt.get("room_id", None) # type: str - event_id = evt.get("event_id", None) # type: str - sender = evt.get("sender", None) # type: str - content = evt.get("content", {}) # type: dict + room_id = evt.get("room_id", None) # type: Optional[MatrixRoomId] + event_id = evt.get("event_id", None) # type: Optional[MatrixEventId] + sender = evt.get("sender", None) # type: Optional[MatrixUserId] + content = evt.get("content", {}) # type: Dict if evt_type == "m.room.member": - state_key = evt["state_key"] # type: str - prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict + state_key = evt["state_key"] # type: MatrixUserId + prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: Dict membership = content.get("membership", "") # type: str prev_membership = prev_content.get("membership", "leave") # type: str if membership == prev_membership: @@ -387,7 +408,7 @@ class MatrixHandler: elif evt_type == "m.room.redaction": await self.handle_redaction(room_id, sender, evt["redacts"]) elif evt_type == "m.room.power_levels": - prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict + prev_content = evt.get("unsigned", {}).get("prev_content", {}) await self.handle_power_levels(room_id, sender, evt["content"], prev_content) elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"): await self.handle_room_meta(evt_type, room_id, sender, evt["content"]) diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index ff63e0d5..d96dfd98 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, TYPE_CHECKING +from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, cast, TYPE_CHECKING from collections import deque from datetime import datetime from string import Template @@ -62,7 +62,7 @@ from telethon.tl.types import ( UserFull) from mautrix_appservice import MatrixRequestError, IntentError, AppService, IntentAPI - +from .types import MatrixEventId, MatrixRoomId, MatrixUserId, TelegramId from .context import Context from .db import Portal as DBPortal, Message as DBMessage, TelegramFile as DBTelegramFile from . import puppet as p, user as u, formatter, util @@ -105,18 +105,18 @@ class Portal: by_mxid = {} # type: Dict[str, Portal] by_tgid = {} # type: Dict[Tuple[int, int], Portal] - def __init__(self, tgid: int, peer_type: str, tg_receiver: Optional[int] = None, - mxid: Optional[str] = None, username: Optional[str] = None, + def __init__(self, tgid: TelegramId, peer_type: str, tg_receiver: Optional[int] = None, + mxid: Optional[MatrixRoomId] = None, username: Optional[str] = None, megagroup: Optional[bool] = False, title: Optional[str] = None, about: Optional[str] = None, photo_id: Optional[str] = None, db_instance: DBPortal = None) -> None: - self.mxid = mxid # type: str - self.tgid = tgid # type: int + self.mxid = mxid # type: Optional[MatrixRoomId] + self.tgid = tgid # type: TelegramId self.tg_receiver = tg_receiver or tgid # type: int self.peer_type = peer_type # type: str self.username = username # type: str self.megagroup = megagroup # type: bool - self.title = title # type: str + self.title = title # type: Optional[str] self.about = about # type: str self.photo_id = photo_id # type: str self._db_instance = db_instance # type: DBPortal @@ -161,7 +161,7 @@ class Portal: @property def has_bot(self) -> bool: - return self.bot and self.bot.is_in_chat(self.tgid) + return bool(self.bot and self.bot.is_in_chat(self.tgid)) @property def main_intent(self) -> IntentAPI: @@ -270,8 +270,8 @@ class Portal: else: raise ValueError("Invalid invite identifier given to invite_matrix()") - async def update_matrix_room(self, user: "AbstractUser", entity: TypeChat, direct: bool, - puppet: p.Puppet = None, levels: dict = None, + async def update_matrix_room(self, user: 'AbstractUser', entity: TypeChat, direct: bool, + puppet: p.Puppet = None, levels: Dict = None, users: List[User] = None, participants: List[TypeParticipant] = None) -> None: if not direct: @@ -303,8 +303,8 @@ class Portal: async with self._room_create_lock: return await self._create_matrix_room(user, entity, invites) - async def _create_matrix_room(self, user: "AbstractUser", entity: TypeChat, invites: InviteList - ) -> Optional[str]: + async def _create_matrix_room(self, user: 'AbstractUser', entity: TypeChat, invites: InviteList + ) -> Optional[MatrixRoomId]: direct = self.peer_type == "user" if self.mxid: @@ -369,6 +369,8 @@ class Portal: participants=participants), loop=self.loop) + return self.mxid + def _get_base_power_levels(self, levels: dict = None, entity: TypeChat = None) -> dict: levels = levels or {} power_level_requirement = (0 if self.peer_type == "chat" and not entity.admins_enabled @@ -437,18 +439,19 @@ class Portal: and config["bridge.max_initial_member_sync"] == -1 and (self.megagroup or self.peer_type != "channel")) if trust_member_list: - joined_mxids = await self.main_intent.get_room_members(self.mxid) - for user in joined_mxids: - if user == self.az.bot_mxid: + joined_mxids = cast(List[MatrixUserId], + await self.main_intent.get_room_members(self.mxid)) + for user_mxid in joined_mxids: + if user_mxid == self.az.bot_mxid: continue - puppet_id = p.Puppet.get_id_from_mxid(user) + puppet_id = p.Puppet.get_id_from_mxid(user_mxid) if puppet_id and puppet_id not in allowed_tgids: if self.bot and puppet_id == self.bot.tgid: self.bot.remove_chat(self.tgid) - await self.main_intent.kick(self.mxid, user, + await self.main_intent.kick(self.mxid, user_mxid, "User had left this Telegram chat.") continue - mx_user = u.User.get_by_mxid(user, create=False) + mx_user = u.User.get_by_mxid(user_mxid, create=False) if mx_user and mx_user.is_bot and mx_user.tgid not in allowed_tgids: mx_user.unregister_portal(self) @@ -457,7 +460,7 @@ class Portal: "You had left this Telegram chat.") continue - async def add_telegram_user(self, user_id: int, source: Optional["AbstractUser"] = None + async def add_telegram_user(self, user_id: TelegramId, source: Optional['AbstractUser'] = None ) -> None: puppet = p.Puppet.get(user_id) if source: @@ -470,7 +473,7 @@ class Portal: user.register_portal(self) await self.invite_to_matrix(user.mxid) - async def delete_telegram_user(self, user_id: int, sender: p.Puppet) -> None: + async def delete_telegram_user(self, user_id: TelegramId, sender: p.Puppet) -> None: puppet = p.Puppet.get(user_id) user = u.User.get_by_tgid(user_id) kick_message = (f"Kicked by {sender.displayname}" @@ -568,8 +571,9 @@ class Portal: return True return False - async def _get_users(self, user: "AbstractUser", entity: Union[TypeInputPeer, InputUser, - TypeChat, TypeUser] + async def _get_users(self, + user: 'AbstractUser', + entity: Union[TypeInputPeer, InputUser, TypeChat, TypeUser] ) -> Tuple[List[TypeUser], List[TypeParticipant]]: if self.peer_type == "chat": chat = await user.client(GetFullChatRequest(chat_id=self.tgid)) @@ -588,7 +592,7 @@ class Portal: entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0)) return response.users, response.participants elif limit > 200 or limit == -1: - users, participants = [], [] + users, participants = [], [] # type: Tuple[List[TypeUser], List[TypeParticipant]] offset = 0 remaining_quota = limit if limit > 0 else 1000000 query = (ChannelParticipantsSearch("") if limit == -1 @@ -609,6 +613,7 @@ class Portal: return [], [] elif self.peer_type == "user": return [entity], [] + return [], [] async def get_invite_link(self, user: 'u.User') -> str: if self.peer_type == "user": @@ -688,7 +693,7 @@ class Portal: return "" async def _get_state_change_message(self, event: str, user: 'u.User', - arguments: Optional[dict] = None) -> Optional[dict]: + arguments: Optional[Dict] = None) -> Optional[Dict]: tpl = config[f"bridge.state_event_formats.{event}"] if len(tpl) == 0: # Empty format means they don't want the message @@ -724,11 +729,11 @@ class Portal: or user.mxid_localpart) def set_typing(self, user: 'u.User', typing: bool = True, - action=SendMessageTypingAction) -> None: + action: type = SendMessageTypingAction) -> bool: return user.client(SetTypingRequest( self.peer, action() if typing else SendMessageCancelAction())) - async def mark_read(self, user: 'u.User', event_id: str) -> None: + async def mark_read(self, user: 'u.User', event_id: MatrixEventId) -> None: if user.is_bot: return space = self.tgid if self.peer_type == "channel" else user.tgid @@ -743,7 +748,8 @@ class Portal: else: await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid)) - async def leave_matrix(self, user: 'u.User', source: 'u.User', event_id: str) -> None: + async def leave_matrix(self, user: 'u.User', source: 'u.User', event_id: MatrixEventId + ) -> None: if await user.needs_relaybot(self): async with self.require_send_lock(self.bot.tgid): message = await self._get_state_change_message("leave", user) @@ -798,7 +804,7 @@ class Portal: # We'll just assume the user is already in the chat. pass - async def _apply_msg_format(self, sender: 'u.User', msgtype: str, message: dict) -> None: + async def _apply_msg_format(self, sender: 'u.User', msgtype: str, message: Dict) -> None: if "formatted_body" not in message: message["format"] = "org.matrix.custom.html" message["formatted_body"] = escape_html(message.get("body", "")) @@ -823,7 +829,7 @@ class Portal: await self._apply_msg_format(sender, msgtype, message) @staticmethod - def _matrix_event_to_entities(event: dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]: + def _matrix_event_to_entities(event: Dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]: try: if event.get("format", None) == "org.matrix.custom.html": message, entities = formatter.matrix_to_telegram(event.get("formatted_body", "")) @@ -851,7 +857,8 @@ class Portal: return None async def _handle_matrix_text(self, sender_id: int, event_id: str, space: int, - client: "MautrixTelegramClient", message: dict, reply_to: int) -> None: + client: 'MautrixTelegramClient', message: Dict, reply_to: int + ) -> None: lock = self.require_send_lock(sender_id) async with lock: response = await client.send_message(self.peer, message, reply_to=reply_to, @@ -859,7 +866,8 @@ class Portal: self._add_telegram_message_to_db(event_id, space, response) async def _handle_matrix_file(self, msgtype: str, sender_id: int, event_id: str, space: int, - client: "MautrixTelegramClient", message: dict, reply_to: int) -> None: + client: 'MautrixTelegramClient', message: dict, reply_to: int + ) -> None: file = await self.main_intent.download_file(message["url"]) info = message.get("info", {}) @@ -893,7 +901,7 @@ class Portal: self._add_telegram_message_to_db(event_id, space, response) async def _handle_matrix_location(self, sender_id: int, event_id: str, space: int, - client: "MautrixTelegramClient", message: dict, + client: 'MautrixTelegramClient', message: Dict, reply_to: int) -> None: try: lat, long = message["geo_uri"][len("geo:"):].split(",") @@ -901,13 +909,13 @@ class Portal: except (KeyError, ValueError): self.log.exception("Failed to parse location") return None - message, entities = self._matrix_event_to_entities(message) + caption, entities = self._matrix_event_to_entities(message) media = MessageMediaGeo(geo=GeoPoint(lat, long, access_hash=0)) lock = self.require_send_lock(sender_id) async with lock: response = await client.send_media(self.peer, media, reply_to=reply_to, - caption=message, entities=entities) + caption=caption, entities=entities) self._add_telegram_message_to_db(event_id, space, response) def _add_telegram_message_to_db(self, event_id: str, space: int, @@ -963,17 +971,18 @@ class Portal: except ChatNotModifiedError: pass - async def handle_matrix_deletion(self, deleter: 'u.User', event_id: str) -> None: - deleter = deleter if not await deleter.needs_relaybot(self) else self.bot - space = self.tgid if self.peer_type == "channel" else deleter.tgid + async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventId) -> None: + real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot + space = self.tgid if self.peer_type == "channel" else real_deleter.tgid message = DBMessage.query.filter(DBMessage.mxid == event_id, DBMessage.tg_space == space, DBMessage.mx_room == self.mxid).one_or_none() if not message: return - await deleter.client.delete_messages(self.peer, [message.tgid]) + await real_deleter.client.delete_messages(self.peer, [message.tgid]) - async def _update_telegram_power_level(self, sender: 'u.User', user_id: int, level: int) -> None: + async def _update_telegram_power_level(self, sender: 'u.User', user_id: TelegramId, + level: int) -> None: if self.peer_type == "chat": await sender.client(EditChatAdminRequest( chat_id=self.tgid, user_id=user_id, is_admin=level >= 50)) @@ -989,7 +998,8 @@ class Portal: EditAdminRequest(channel=await self.get_input_entity(sender), user_id=user_id, admin_rights=rights)) - async def handle_matrix_power_levels(self, sender: 'u.User', new_users: Dict[str, int], + async def handle_matrix_power_levels(self, sender: 'u.User', + new_users: Dict[MatrixUserId, int], old_users: Dict[str, int]) -> None: # TODO handle all power level changes and bridge exact admin rights to supergroups/channels for user, level in new_users.items(): @@ -1167,7 +1177,7 @@ class Portal: return None async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message, - relates_to=None) -> None: + relates_to: Dict = {}) -> None: largest_size = self._get_largest_photo_size(evt.media.photo) file = await util.transfer_file_to_matrix(self.db, source.client, intent, largest_size.location) @@ -1197,7 +1207,7 @@ class Portal: external_url=self.get_external_url(evt)) @staticmethod - def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> dict: + def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> Dict: attrs = { "name": None, "mime_type": None, @@ -1205,7 +1215,7 @@ class Portal: "sticker_alt": None, "width": None, "height": None, - } + } # type: Dict for attr in attributes: if isinstance(attr, DocumentAttributeFilename): attrs["name"] = attrs["name"] or attr.file_name @@ -1218,8 +1228,8 @@ class Portal: return attrs @staticmethod - def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: dict - ) -> Tuple[dict, str]: + def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: Dict + ) -> Tuple[Dict, str]: document = evt.media.document name = evt.message or attrs["name"] if attrs["is_sticker"]: @@ -1253,7 +1263,7 @@ class Portal: async def handle_telegram_document(self, source: "AbstractUser", intent: IntentAPI, evt: Message, - relates_to: dict = None) -> Optional[dict]: + relates_to: dict = None) -> Optional[Dict]: document = evt.media.document attrs = self._parse_telegram_document_attributes(document.attributes) @@ -1521,9 +1531,9 @@ class Portal: else: self.log.debug("Unhandled Telegram action in %s: %s", self.title, action) - async def set_telegram_admin(self, user_id: int) -> None: + async def set_telegram_admin(self, user_id: TelegramId) -> None: puppet = p.Puppet.get(user_id) - user = await u.User.get_by_tgid(user_id) + user = u.User.get_by_tgid(user_id) levels = await self.main_intent.get_power_levels(self.mxid) if user: @@ -1558,7 +1568,7 @@ class Portal: await self.update_telegram_pin() @staticmethod - def _get_level_from_participant(participant: TypeParticipant, _) -> int: + def _get_level_from_participant(participant: TypeParticipant, _: Dict) -> int: # TODO use the power level requirements to get better precision in channels if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)): return 50 @@ -1599,7 +1609,7 @@ class Portal: except KeyError: return 50 - def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: dict + def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: Dict ) -> bool: bot_level = self._get_bot_level(levels) if bot_level < self._get_powerlevel_level(levels): @@ -1654,7 +1664,7 @@ class Portal: mxid=self.mxid, username=self.username, megagroup=self.megagroup, title=self.title, about=self.about, photo_id=self.photo_id) - def migrate_and_save(self, new_id: int) -> None: + def migrate_and_save(self, new_id: TelegramId) -> None: existing = DBPortal.query.get(self.tgid_full) if existing: self.db.delete(existing) @@ -1701,7 +1711,7 @@ class Portal: # region Class instance lookup @classmethod - def get_by_mxid(cls, mxid: str) -> Optional["Portal"]: + def get_by_mxid(cls, mxid: MatrixRoomId) -> Optional['Portal']: try: return cls.by_mxid[mxid] except KeyError: @@ -1721,7 +1731,7 @@ class Portal: return None @classmethod - def find_by_username(cls, username: str) -> Optional["Portal"]: + def find_by_username(cls, username: str) -> Optional['Portal']: if not username: return None @@ -1729,15 +1739,15 @@ class Portal: if portal.username and portal.username.lower() == username.lower(): return portal - portal = DBPortal.query.filter(DBPortal.username == username).one_or_none() - if portal: - return cls.from_db(portal) + dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none() + if dbportal: + return cls.from_db(dbportal) return None @classmethod - def get_by_tgid(cls, tgid: int, tg_receiver: int = None, peer_type: str = None - ) -> Optional["Portal"]: + def get_by_tgid(cls, tgid: TelegramId, tg_receiver: Optional[TelegramId] = None, + peer_type: str = None) -> Optional['Portal']: tg_receiver = tg_receiver or tgid tgid_full = (tgid, tg_receiver) try: @@ -1758,8 +1768,10 @@ class Portal: return None @classmethod - def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull, TypeInputPeer], - receiver_id: int = None, create: bool = True) -> Optional["Portal"]: + def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull, + TypeInputPeer], + receiver_id: Optional[TelegramId] = None, create: bool = True + ) -> Optional['Portal']: entity_type = type(entity) if entity_type in {Chat, ChatFull}: type_name = "chat" @@ -1790,7 +1802,7 @@ class Portal: def init(context: Context) -> None: global config - Portal.az, Portal.db, config, Portal.loop, Portal.bot = context + Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core Portal.bridge_notices = config["bridge.bridge_notices"] Portal.filter_mode = config["bridge.filter.mode"] Portal.filter_list = config["bridge.filter.list"] diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py index 4e7a34d6..c377d298 100644 --- a/mautrix_telegram/puppet.py +++ b/mautrix_telegram/puppet.py @@ -14,17 +14,19 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING +from typing import Awaitable, Coroutine, Dict, List, NewType, Optional, Pattern, TYPE_CHECKING from difflib import SequenceMatcher import re import logging import asyncio +from enum import Enum from sqlalchemy import orm -from telethon.tl.types import UserProfilePhoto +from telethon.tl.types import UserProfilePhoto, User, FileLocation from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError +from .types import MatrixUserId, TelegramId from .db import Puppet as DBPuppet from . import util @@ -32,6 +34,11 @@ if TYPE_CHECKING: from .matrix import MatrixHandler from .config import Config from .context import Context + from . import user as u + from .abstract_user import AbstractUser + + +PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken') config = None # type: Config @@ -45,85 +52,98 @@ class Puppet: mxid_regex = None # type: Pattern username_template = None # type: str hs_domain = None # type: str - cache = {} # type: Dict[str, Puppet] + cache = {} # type: Dict[TelegramId, Puppet] by_custom_mxid = {} # type: Dict[str, Puppet] - def __init__(self, id=None, access_token=None, custom_mxid=None, username=None, - displayname=None, displayname_source=None, photo_id=None, is_bot=None, - is_registered=False, db_instance=None) -> None: - self.id = id - self.access_token = access_token - self.custom_mxid = custom_mxid - self.is_real_user = self.custom_mxid and self.access_token - self.default_mxid = self.get_mxid_from_id(self.id) - self.mxid = self.custom_mxid or self.default_mxid + def __init__(self, + id: TelegramId, + access_token: Optional[str] = None, + custom_mxid: Optional[MatrixUserId] = None, + username: Optional[str] = None, + displayname: Optional[str] = None, + displayname_source: Optional[TelegramId] = None, + photo_id: Optional[str] = None, + is_bot: bool = False, + is_registered: bool = False, + db_instance: Optional[DBPuppet] = None) -> None: + self.id = id # type: TelegramId + self.access_token = access_token # type: Optional[str] + self.custom_mxid = custom_mxid # type: Optional[MatrixUserId] + self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserId - self.username = username - self.displayname = displayname - self.displayname_source = displayname_source - self.photo_id = photo_id - self.is_bot = is_bot - self.is_registered = is_registered - self._db_instance = db_instance + self.username = username # type: Optional[str] + self.displayname = displayname # type: Optional[str] + self.displayname_source = displayname_source # type: Optional[TelegramId] + self.photo_id = photo_id # type: Optional[str] + self.is_bot = is_bot # type: bool + self.is_registered = is_registered # type: bool + self._db_instance = db_instance # type: Optional[DBPuppet] self.default_mxid_intent = self.az.intent.user(self.default_mxid) - self.intent = None # type: IntentAPI - self.refresh_intents() + self.intent = self._fresh_intent() # type: IntentAPI self.cache[id] = self if self.custom_mxid: self.by_custom_mxid[self.custom_mxid] = self @property - def tgid(self) -> None: + def mxid(self): + return self.custom_mxid or self.default_mxid + + @property + def tgid(self) -> TelegramId: return self.id + @property + def is_real_user(self) -> bool: + """ Is True when the puppet is a real Matrix user. """ + return bool(self.custom_mxid and self.access_token) + @staticmethod - async def is_logged_in() -> None: + async def is_logged_in() -> bool: + """ Is True if the puppet is logged in. """ return True # region Custom puppet management - def refresh_intents(self) -> None: - self.is_real_user = self.custom_mxid and self.access_token - self.intent = (self.az.intent.user(self.custom_mxid, self.access_token) - if self.is_real_user else self.default_mxid_intent) + def _fresh_intent(self) -> IntentAPI: + return (self.az.intent.user(self.custom_mxid, self.access_token) + if self.is_real_user else self.default_mxid_intent) - async def switch_mxid(self, access_token, mxid) -> None: + async def switch_mxid(self, access_token: str, mxid: MatrixUserId) -> PuppetError: prev_mxid = self.custom_mxid self.custom_mxid = mxid self.access_token = access_token - self.refresh_intents() + self.intent = self._fresh_intent() err = await self.init_custom_mxid() - if err != 0: + if err != PuppetError.Success: return err try: - del self.by_custom_mxid[prev_mxid] + del self.by_custom_mxid[prev_mxid] # type: ignore except KeyError: pass - self.mxid = self.custom_mxid or self.default_mxid if self.mxid != self.default_mxid: self.by_custom_mxid[self.mxid] = self await self.leave_rooms_with_default_user() self.save() - return 0 + return PuppetError.Success - async def init_custom_mxid(self) -> None: + async def init_custom_mxid(self) -> PuppetError: if not self.is_real_user: - return 0 + return PuppetError.Success mxid = await self.intent.whoami() if not mxid or mxid != self.custom_mxid: self.custom_mxid = None self.access_token = None - self.refresh_intents() + self.intent = self._fresh_intent() if mxid != self.custom_mxid: - return 2 - return 1 + return PuppetError.OnlyLoginSelf + return PuppetError.InvalidAccessToken if config["bridge.sync_with_custom_puppets"]: asyncio.ensure_future(self.sync(), loop=self.loop) - return 0 + return PuppetError.Success async def leave_rooms_with_default_user(self) -> None: for room_id in await self.default_mxid_intent.get_joined_rooms(): @@ -159,7 +179,7 @@ class Puppet: }, }) - def filter_events(self, events) -> None: + def filter_events(self, events: List[Dict]) -> List: new_events = [] for event in events: evt_type = event.get("type", None) @@ -186,18 +206,18 @@ class Puppet: new_events.append(event) return new_events - def handle_sync(self, presence, ephemeral) -> None: - presence = [self.mx.try_handle_event(event) for event in presence] + def handle_sync(self, presence: List, ephemeral: Dict) -> None: + presence_events = [self.mx.try_handle_event(event) for event in presence] for room_id, events in ephemeral.items(): for event in events: event["room_id"] = room_id - ephemeral = [self.mx.try_handle_event(event) - for events in ephemeral.values() - for event in self.filter_events(events)] + ephemeral_events = [self.mx.try_handle_event(event) + for events in ephemeral.values() + for event in self.filter_events(events)] - events = ephemeral + presence + events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]] coro = asyncio.gather(*events, loop=self.loop) asyncio.ensure_future(coro, loop=self.loop) @@ -220,13 +240,14 @@ class Puppet: while access_token_at_start == self.access_token: try: sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch, - set_presence="offline") + set_presence="offline") # type: Dict errors = 0 if next_batch is not None: - presence = sync_resp.get("presence", {}).get("events", []) + presence = sync_resp.get("presence", {}).get("events", []) # type: List ephemeral = {room: data.get("ephemeral", {}).get("events", []) for room, data - in sync_resp.get("rooms", {}).get("join", {}).items()} + in sync_resp.get("rooms", {}).get("join", {}).items() + } # type: Dict self.handle_sync(presence, ephemeral) next_batch = sync_resp.get("next_batch", None) except MatrixRequestError as e: @@ -241,19 +262,19 @@ class Puppet: # region DB conversion @property - def db_instance(self) -> None: + def db_instance(self) -> DBPuppet: if not self._db_instance: self._db_instance = self.new_db_instance() return self._db_instance - def new_db_instance(self) -> None: + def new_db_instance(self) -> DBPuppet: return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid, username=self.username, displayname=self.displayname, displayname_source=self.displayname_source, photo_id=self.photo_id, is_bot=self.is_bot, matrix_registered=self.is_registered) @classmethod - def from_db(cls, db_puppet) -> None: + def from_db(cls, db_puppet: DBPuppet) -> 'Puppet': return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid, db_puppet.username, db_puppet.displayname, db_puppet.displayname_source, db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered, @@ -272,16 +293,16 @@ class Puppet: # endregion # region Info updating - def similarity(self, query) -> None: + def similarity(self, query: str) -> int: username_similarity = (SequenceMatcher(None, self.username, query).ratio() if self.username else 0) displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio() if self.displayname else 0) similarity = max(username_similarity, displayname_similarity) - return round(similarity * 1000) / 10 + return int(round(similarity * 1000) / 10) @staticmethod - def get_displayname(info, enable_format=True) -> None: + def get_displayname(info: User, enable_format: bool = True) -> str: data = { "phone number": info.phone if hasattr(info, "phone") else None, "username": info.username, @@ -308,7 +329,7 @@ class Puppet: return config.get("bridge.displayname_template", "{displayname} (Telegram)").format( displayname=name) - async def update_info(self, source, info) -> None: + async def update_info(self, source: 'AbstractUser', info: User) -> None: changed = False if self.username != info.username: self.username = info.username @@ -323,24 +344,26 @@ class Puppet: if changed: self.save() - async def update_displayname(self, source, info) -> None: + async def update_displayname(self, source: 'AbstractUser', info: User) -> bool: ignore_source = (not source.is_relaybot and self.displayname_source is not None and self.displayname_source != source.tgid) if ignore_source: - return + return False displayname = self.get_displayname(info) if displayname != self.displayname: await self.default_mxid_intent.set_display_name(displayname) self.displayname = displayname - self.displayname_source = source.tgid + self.displayname_source = TelegramId(source.tgid) return True elif source.is_relaybot or self.displayname_source is None: - self.displayname_source = source.tgid + self.displayname_source = TelegramId(source.tgid) return True + else: + return False - async def update_avatar(self, source, photo) -> None: + async def update_avatar(self, source: 'AbstractUser', photo: FileLocation) -> bool: photo_id = f"{photo.volume_id}-{photo.local_id}" if self.photo_id != photo_id: file = await util.transfer_file_to_matrix(self.db, source.client, @@ -355,7 +378,7 @@ class Puppet: # region Getters @classmethod - def get(cls, tgid, create=True) -> "Optional[Puppet]": + def get(cls, tgid: TelegramId, create: bool = True) -> Optional['Puppet']: try: return cls.cache[tgid] except KeyError: @@ -374,12 +397,15 @@ class Puppet: return None @classmethod - def get_by_mxid(cls, mxid, create=True) -> "Optional[Puppet]": + def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['Puppet']: tgid = cls.get_id_from_mxid(mxid) - return cls.get(tgid, create) if tgid else None + if tgid: + return cls.get(tgid, create) + + return None @classmethod - def get_by_custom_mxid(cls, mxid) -> None: + def get_by_custom_mxid(cls, mxid: MatrixUserId) -> Optional['Puppet']: if not mxid: raise ValueError("Matrix ID can't be empty") @@ -396,25 +422,25 @@ class Puppet: return None @classmethod - def get_all_with_custom_mxid(cls) -> None: + def get_all_with_custom_mxid(cls) -> List['Puppet']: return [cls.by_custom_mxid[puppet.mxid] if puppet.custom_mxid in cls.by_custom_mxid else cls.from_db(puppet) for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()] @classmethod - def get_id_from_mxid(cls, mxid) -> None: + def get_id_from_mxid(cls, mxid: MatrixUserId) -> Optional[TelegramId]: match = cls.mxid_regex.match(mxid) if match: - return int(match.group(1)) + return TelegramId(int(match.group(1))) return None @classmethod - def get_mxid_from_id(cls, tgid) -> None: - return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}" + def get_mxid_from_id(cls, tgid: TelegramId) -> MatrixUserId: + return MatrixUserId(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}") @classmethod - def find_by_username(cls, username) -> "Optional[Puppet]": + def find_by_username(cls, username: str) -> Optional['Puppet']: if not username: return None @@ -422,14 +448,14 @@ class Puppet: if puppet.username and puppet.username.lower() == username.lower(): return puppet - puppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none() - if puppet: - return cls.from_db(puppet) + dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none() + if dbpuppet: + return cls.from_db(dbpuppet) return None @classmethod - def find_by_displayname(cls, displayname) -> "Optional[Puppet]": + def find_by_displayname(cls, displayname: str) -> Optional['Puppet']: if not displayname: return None @@ -437,17 +463,17 @@ class Puppet: if puppet.displayname and puppet.displayname == displayname: return puppet - puppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none() - if puppet: - return cls.from_db(puppet) + dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none() + if dbpuppet: + return cls.from_db(dbpuppet) return None # endregion -def init(context: "Context") -> List[Awaitable[int]]: +def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError] global config - Puppet.az, Puppet.db, config, Puppet.loop, _ = context + Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core Puppet.mx = context.mx Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}") Puppet.hs_domain = config["homeserver"]["domain"] diff --git a/mautrix_telegram/scripts/telematrix_import/__main__.py b/mautrix_telegram/scripts/telematrix_import/__main__.py index 2de531c7..119c7689 100644 --- a/mautrix_telegram/scripts/telematrix_import/__main__.py +++ b/mautrix_telegram/scripts/telematrix_import/__main__.py @@ -40,7 +40,7 @@ telematrix_db_engine.dispose() portals = {} chats = {} messages = {} -puppets = {} +puppets = {} # Dict[int, Puppet] for chat_link in chat_links: if type(chat_link.tg_room) is str: diff --git a/mautrix_telegram/sqlstatestore.py b/mautrix_telegram/sqlstatestore.py index c60af89c..5cf96b58 100644 --- a/mautrix_telegram/sqlstatestore.py +++ b/mautrix_telegram/sqlstatestore.py @@ -20,37 +20,39 @@ from sqlalchemy import orm from mautrix_appservice import StateStore +from .types import MatrixUserId, MatrixRoomId from . import puppet as pu from .db import RoomState, UserProfile class SQLStateStore(StateStore): - def __init__(self, db) -> None: + def __init__(self, db: orm.Session) -> None: super().__init__() self.db = db # type: orm.Session self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile] self.room_state_cache = {} # type: Dict[str, RoomState] @staticmethod - def is_registered(user: str) -> bool: + def is_registered(user: MatrixUserId) -> bool: puppet = pu.Puppet.get_by_mxid(user) return puppet.is_registered if puppet else False @staticmethod - def registered(user: str) -> None: + def registered(user: MatrixUserId) -> None: puppet = pu.Puppet.get_by_mxid(user) if puppet: puppet.is_registered = True puppet.save() - def update_state(self, event: dict) -> None: + def update_state(self, event: Dict) -> None: event_type = event["type"] if event_type == "m.room.power_levels": self.set_power_levels(event["room_id"], event["content"]) elif event_type == "m.room.member": self.set_member(event["room_id"], event["state_key"], event["content"]) - def _get_user_profile(self, room_id: str, user_id: str, create: bool = True) -> UserProfile: + def _get_user_profile(self, room_id: MatrixRoomId, user_id: MatrixUserId, create: bool = True + ) -> UserProfile: key = (room_id, user_id) try: return self.profile_cache[key] @@ -67,22 +69,22 @@ class SQLStateStore(StateStore): self.profile_cache[key] = profile return profile - def get_member(self, room: str, user: str) -> dict: + def get_member(self, room: MatrixRoomId, user: MatrixUserId) -> Dict: return self._get_user_profile(room, user).dict() - def set_member(self, room: str, user: str, member: dict) -> None: + def set_member(self, room: MatrixRoomId, user: MatrixUserId, member: Dict) -> None: profile = self._get_user_profile(room, user) profile.membership = member.get("membership", profile.membership or "leave") profile.displayname = member.get("displayname", profile.displayname) profile.avatar_url = member.get("avatar_url", profile.avatar_url) self.db.commit() - def set_membership(self, room: str, user: str, membership: str) -> None: + def set_membership(self, room: MatrixRoomId, user: MatrixUserId, membership: str) -> None: self.set_member(room, user, { "membership": membership, }) - def _get_room_state(self, room_id: str, create: bool = True) -> RoomState: + def _get_room_state(self, room_id: MatrixRoomId, create: bool = True) -> RoomState: try: return self.room_state_cache[room_id] except KeyError: @@ -96,13 +98,13 @@ class SQLStateStore(StateStore): self.room_state_cache[room_id] = room return room - def has_power_levels(self, room: str) -> bool: + def has_power_levels(self, room: MatrixRoomId) -> bool: return self._get_room_state(room).has_power_levels - def get_power_levels(self, room: str) -> dict: + def get_power_levels(self, room: MatrixRoomId) -> Dict: return self._get_room_state(room).power_levels - def set_power_level(self, room: str, user: str, level: int) -> None: + def set_power_level(self, room: MatrixRoomId, user: MatrixUserId, level: int) -> None: room_state = self._get_room_state(room) power_levels = room_state.power_levels if not power_levels: @@ -114,7 +116,7 @@ class SQLStateStore(StateStore): room_state.power_levels = power_levels self.db.commit() - def set_power_levels(self, room: str, content: dict) -> None: + def set_power_levels(self, room: MatrixRoomId, content: Dict) -> None: state = self._get_room_state(room) state.power_levels = content self.db.commit() diff --git a/mautrix_telegram/types.py b/mautrix_telegram/types.py new file mode 100644 index 00000000..50a0a9ca --- /dev/null +++ b/mautrix_telegram/types.py @@ -0,0 +1,10 @@ +from typing import Dict, NewType + +# MatrixId = NewType('MatrixId', str) +MatrixUserId = NewType('MatrixUserId', str) +MatrixRoomId = NewType('MatrixRoomId', str) +MatrixEventId = NewType('MatrixEventId', str) + +MatrixEvent = NewType('MatrixEvent', Dict) + +TelegramId = NewType('TelegramId', int) diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index 37cb15ad..19375939 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Dict, List, Match, Optional, Tuple, TYPE_CHECKING +from typing import Coroutine, Dict, List, Match, Optional, Tuple, cast, TYPE_CHECKING import logging import asyncio import re @@ -28,6 +28,7 @@ from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest from telethon.tl.functions.account import UpdateStatusRequest from mautrix_appservice import MatrixRequestError +from .types import MatrixUserId, TelegramId from .db import User as DBUser, Contact as DBContact, Portal as DBPortal from .abstract_user import AbstractUser from . import portal as po, puppet as pu @@ -46,23 +47,23 @@ class User(AbstractUser): by_mxid = {} # type: Dict[str, User] by_tgid = {} # type: Dict[int, User] - def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None, - db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0, - is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None, + def __init__(self, mxid: MatrixUserId, tgid: Optional[TelegramId] = None, + username: Optional[str] = None, db_contacts: Optional[List[DBContact]] = None, + saved_contacts: int = 0, is_bot: bool = False, db_portals: List[DBPortal] = [], db_instance: Optional[DBUser] = None) -> None: super().__init__() - self.mxid = mxid # type: str - self.tgid = tgid # type: int + self.mxid = mxid # type: MatrixUserId + self.tgid = tgid # type: TelegramId self.is_bot = is_bot # type: bool self.username = username # type: str self.contacts = [] # type: List[pu.Puppet] self.saved_contacts = saved_contacts # type: int self.db_contacts = db_contacts # type: List[DBContact] self.portals = {} # type: Dict[Tuple[int, int], po.Portal] - self.db_portals = db_portals # type: List[DBPortal] - self._db_instance = db_instance # type: DBUser + self.db_portals = db_portals or [] # type: List[DBPortal] + self._db_instance = db_instance # type: Optional[DBUser] - self.command_status = None # type: dict + self.command_status = None # type: Dict (self.relaybot_whitelisted, self.whitelisted, @@ -169,9 +170,9 @@ class User(AbstractUser): except Exception: self.log.exception("Failed to run post-login functions for %s", self.mxid) - async def update(self, update: TypeUpdate) -> None: + async def update(self, update: TypeUpdate) -> bool: if not self.is_bot: - return + return False if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)): message = update.message @@ -185,19 +186,22 @@ class User(AbstractUser): elif isinstance(update, UpdateShortMessage): portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user") else: - return + return False - self.register_portal(portal) + if portal: + self.register_portal(portal) + + return True # endregion # region Telegram actions that need custom methods - def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]": - return super().ensure_started(even_if_no_session) + def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']: + return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session)) - def set_presence(self, online: bool = True) -> None: + def set_presence(self, online: bool = True) -> bool: if self.is_bot: - return + return False return self.client(UpdateStatusRequest(offline=not online)) async def update_info(self, info: TLUser = None) -> None: @@ -215,7 +219,7 @@ class User(AbstractUser): if changed: self.save() - async def log_out(self) -> None: + async def log_out(self) -> bool: puppet = pu.Puppet.get(self.tgid) if puppet.is_real_user: await puppet.switch_mxid(None, None) @@ -328,7 +332,7 @@ class User(AbstractUser): # region Class instance lookup @classmethod - def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]": + def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['User']: if not mxid: raise ValueError("Matrix ID can't be empty") @@ -351,7 +355,7 @@ class User(AbstractUser): return None @classmethod - def get_by_tgid(cls, tgid: int) -> "Optional[User]": + def get_by_tgid(cls, tgid: int) -> Optional['User']: try: return cls.by_tgid[tgid] except KeyError: @@ -365,7 +369,7 @@ class User(AbstractUser): return None @classmethod - def find_by_username(cls, username: str) -> "Optional[User]": + def find_by_username(cls, username: str) -> Optional['User']: if not username: return None @@ -381,7 +385,7 @@ class User(AbstractUser): # endregion -def init(context: "Context") -> List[Awaitable[User]]: +def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser] global config config = context.config diff --git a/mautrix_telegram/util/format_duration.py b/mautrix_telegram/util/format_duration.py index 076eb720..44d16550 100644 --- a/mautrix_telegram/util/format_duration.py +++ b/mautrix_telegram/util/format_duration.py @@ -17,10 +17,10 @@ def format_duration(seconds: int) -> str: - def pluralize(count, singular) -> None: + def pluralize(count: int, singular: str) -> str: return singular if count == 1 else singular + "s" - def include(count, word) -> None: + def include(count: int, word: str) -> str: return f"{count} {pluralize(count, word)}" if count > 0 else "" minutes, seconds = divmod(seconds, 60) diff --git a/mautrix_telegram/util/signed_token.py b/mautrix_telegram/util/signed_token.py index 13281012..febb2aa4 100644 --- a/mautrix_telegram/util/signed_token.py +++ b/mautrix_telegram/util/signed_token.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional +from typing import Dict, Optional import json import base64 import hashlib @@ -28,13 +28,13 @@ def _get_checksum(key: str, payload: bytes) -> str: return checksum -def sign_token(key: str, payload: dict) -> str: - payload = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")) - checksum = _get_checksum(key, payload) - return f"{checksum}:{payload.decode('utf-8')}" +def sign_token(key: str, payload: Dict) -> str: + payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")) + checksum = _get_checksum(key, payload_b64) + return f"{checksum}:{payload_b64.decode('utf-8')}" -def verify_token(key: str, data: str) -> Optional[dict]: +def verify_token(key: str, data: str) -> Optional[Dict]: if not data: return None diff --git a/mautrix_telegram/web/common/auth_api.py b/mautrix_telegram/web/common/auth_api.py index 24fa74e9..b293f368 100644 --- a/mautrix_telegram/web/common/auth_api.py +++ b/mautrix_telegram/web/common/auth_api.py @@ -23,7 +23,7 @@ from telethon.errors import * from ...commands.auth import enter_password from ...util import format_duration -from ...puppet import Puppet +from ...puppet import Puppet, PuppetError from ...user import User @@ -51,12 +51,13 @@ class AuthAPI(abc.ABC): "account.", errcode="already-logged-in") resp = await puppet.switch_mxid(token, user.mxid) - if resp == 2: + if resp == PuppetError.OnlyLoginSelf: return self.get_mx_login_response(status=403, errcode="only-login-self", error="You can only log in as your own Matrix user.") - elif resp == 1: + elif resp == PuppetError.InvalidAccessToken: return self.get_mx_login_response(status=401, errcode="invalid-access-token", error="Failed to verify access token.") + assert resp == PuppetError.Success, "Encountered an unhandled PuppetError." return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in") diff --git a/mautrix_telegram/web/provisioning/__init__.py b/mautrix_telegram/web/provisioning/__init__.py index 04aa499a..ad4f3635 100644 --- a/mautrix_telegram/web/provisioning/__init__.py +++ b/mautrix_telegram/web/provisioning/__init__.py @@ -15,7 +15,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from aiohttp import web -from typing import Tuple, Optional, Callable, Awaitable, TYPE_CHECKING +from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING import asyncio import logging import json @@ -24,6 +24,7 @@ from telethon.utils import get_peer_id, resolve_id from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat from mautrix_appservice import AppService, MatrixRequestError, IntentError +from ...types import MatrixUserId from ...user import User from ...portal import Portal from ...commands.portal import user_has_power_level, get_initial_state @@ -36,7 +37,7 @@ if TYPE_CHECKING: class ProvisioningAPI(AuthAPI): log = logging.getLogger("mau.web.provisioning") - def __init__(self, context: "Context"): + def __init__(self, context: "Context") -> None: super().__init__(context.loop) self.secret = context.config["appservice.provisioning.shared_secret"] self.az = context.az # type: AppService @@ -411,7 +412,7 @@ class ProvisioningAPI(AuthAPI): except json.JSONDecodeError: return None - async def get_user(self, mxid: str, expect_logged_in: Optional[bool] = False, + async def get_user(self, mxid: MatrixUserId, expect_logged_in: Optional[bool] = False, require_puppeting: bool = True, require_user: bool = True ) -> Tuple[Optional[User], Optional[web.Response]]: if not mxid: @@ -439,7 +440,7 @@ class ProvisioningAPI(AuthAPI): expect_logged_in: Optional[bool] = False, require_puppeting: bool = False, want_data: bool = True, - ) -> (Tuple[Optional[dict], + ) -> (Tuple[Optional[Dict], Optional[User], Optional[web.Response]]): err = self.check_authorization(request)