From d4e39569416c4e6eaed40165e1d005429cb09ba5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jul 2019 21:36:21 +0300 Subject: [PATCH] Even more migrations to mautrix-python --- mautrix_telegram/commands/clean_rooms.py | 34 +- mautrix_telegram/commands/handler.py | 23 +- mautrix_telegram/context.py | 5 +- mautrix_telegram/db/__init__.py | 2 - mautrix_telegram/db/base.py | 9 +- mautrix_telegram/formatter/from_telegram.py | 2 +- mautrix_telegram/formatter/util.py | 7 +- mautrix_telegram/matrix.py | 4 +- mautrix_telegram/portal.py | 4 +- mautrix_telegram/puppet.py | 310 +++++------------- .../scripts/dbms_migrate/__main__.py | 13 +- mautrix_telegram/tgclient.py | 2 +- mautrix_telegram/types.py | 2 +- mautrix_telegram/user.py | 13 +- mautrix_telegram/util/__init__.py | 5 +- mautrix_telegram/util/file_transfer.py | 3 +- mautrix_telegram/util/recursive_dict.py | 9 +- mautrix_telegram/web/common/auth_api.py | 24 +- mautrix_telegram/web/provisioning/__init__.py | 52 ++- mautrix_telegram/web/public/__init__.py | 29 +- 20 files changed, 215 insertions(+), 337 deletions(-) diff --git a/mautrix_telegram/commands/clean_rooms.py b/mautrix_telegram/commands/clean_rooms.py index 81d7c4e8..1501bc2f 100644 --- a/mautrix_telegram/commands/clean_rooms.py +++ b/mautrix_telegram/commands/clean_rooms.py @@ -13,41 +13,41 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, List, NewType, Optional, Tuple, Union +from typing import Dict, List, NamedTuple, Optional, Tuple, Union -from mautrix_appservice import MatrixRequestError, IntentAPI +from mautrix.appservice import IntentAPI +from mautrix.errors import MatrixRequestError +from mautrix.types import RoomID, UserID -from ..types import MatrixRoomID, MatrixUserID from . import command_handler, CommandEvent, SECTION_ADMIN from .. import puppet as pu, portal as po -ManagementRoom = NewType('ManagementRoom', Tuple[MatrixRoomID, MatrixUserID]) +ManagementRoom = NamedTuple('ManagementRoom', room_id=RoomID, user_id=UserID) -async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[MatrixRoomID], +async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[RoomID], List['po.Portal'], List['po.Portal']]: - management_rooms = [] # type: List[ManagementRoom] - unidentified_rooms = [] # type: List[MatrixRoomID] - portals = [] # type: List[po.Portal] - empty_portals = [] # type: List[po.Portal] + management_rooms: List[ManagementRoom] = [] + unidentified_rooms: List[RoomID] = [] + portals: List[po.Portal] = [] + empty_portals: List[po.Portal] = [] rooms = await intent.get_joined_rooms() - for room_str in rooms: - room = MatrixRoomID(room_str) - portal = po.Portal.get_by_mxid(room) + for room_id in rooms: + portal = po.Portal.get_by_mxid(room_id) if not portal: try: - members = await intent.get_room_members(room) + members = await intent.get_room_members(room_id) except MatrixRequestError: members = [] if len(members) == 2: - other_member = MatrixUserID(members[0] if members[0] != intent.mxid else members[1]) + other_member = members[0] if members[0] != intent.mxid else members[1] if pu.Puppet.get_id_from_mxid(other_member): - unidentified_rooms.append(room) + unidentified_rooms.append(room_id) else: - management_rooms.append(ManagementRoom((room, other_member))) + management_rooms.append(ManagementRoom(room_id, other_member)) else: - unidentified_rooms.append(room) + unidentified_rooms.append(room_id) else: members = await portal.get_authenticated_matrix_users() if len(members) == 0: diff --git a/mautrix_telegram/commands/handler.py b/mautrix_telegram/commands/handler.py index 63553082..1647ea6e 100644 --- a/mautrix_telegram/commands/handler.py +++ b/mautrix_telegram/commands/handler.py @@ -22,11 +22,12 @@ import commonmark from telethon.errors import FloodWaitError -from ..types import MatrixRoomID, MatrixEventID +from mautrix.types import RoomID, EventID + from ..util import format_duration from .. import user as u, context as c -command_handlers = {} # type: Dict[str, CommandHandler] +command_handlers: Dict[str, 'CommandHandler'] = {} HelpSection = NamedTuple('HelpSection', [('name', str), ('order', int), ('description', str)]) @@ -82,7 +83,7 @@ class CommandEvent: is a portal. """ - def __init__(self, processor: 'CommandProcessor', room: MatrixRoomID, event: MatrixEventID, + def __init__(self, processor: 'CommandProcessor', room: RoomID, event: EventID, sender: u.User, command: str, args: List[str], is_management: bool, is_portal: bool) -> None: self.az = processor.az @@ -101,7 +102,7 @@ class CommandEvent: self.is_portal = is_portal def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True - ) -> Awaitable[Dict]: + ) -> Awaitable[EventID]: """Write a reply to the room in which the command was issued. Replaces occurences of "$cmdprefix" in the message with the command @@ -178,7 +179,7 @@ class CommandHandler: help_section: Section of the help in which this command will appear. """ - def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool, + def __init__(self, handler: Callable[[CommandEvent], Awaitable[EventID]], needs_auth: bool, needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool, management_only: bool, name: str, help_text: str, help_args: str, help_section: HelpSection) -> None: @@ -255,7 +256,7 @@ class CommandHandler: (not self.needs_admin or is_admin) and (not self.needs_auth or is_logged_in)) - async def __call__(self, evt: CommandEvent) -> Dict: + async def __call__(self, evt: CommandEvent) -> EventID: """Executes the command if evt was issued with proper rights. Args: @@ -283,14 +284,14 @@ class CommandHandler: return f"**{self.name}** {self._help_args} - {self._help_text}" -def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] = None, *, +def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[EventID]]] = None, *, needs_auth: bool = True, needs_puppeting: bool = True, needs_matrix_puppeting: bool = False, needs_admin: bool = False, management_only: bool = False, name: Optional[str] = None, help_text: str = "", help_args: str = "", help_section: HelpSection = None - ) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[Dict]]]], + ) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[EventID]]]], CommandHandler]: - def decorator(func: Callable[[CommandEvent], Awaitable[Optional[Dict]]]) -> CommandHandler: + def decorator(func: Callable[[CommandEvent], Awaitable[Optional[EventID]]]) -> CommandHandler: actual_name = name or func.__name__.replace("_", "-") handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting, needs_admin, management_only, actual_name, help_text, help_args, @@ -310,9 +311,9 @@ class CommandProcessor: self.public_website = context.public_website self.command_prefix = self.config["bridge.command_prefix"] - async def handle(self, room: MatrixRoomID, event_id: MatrixEventID, sender: u.User, + async def handle(self, room: RoomID, event_id: EventID, sender: u.User, command: str, args: List[str], is_management: bool, is_portal: bool - ) -> Optional[Dict]: + ) -> Optional[EventID]: """Handles the raw commands issued by a user to the Matrix bot. If the command is not known, it might be a followup command and is diff --git a/mautrix_telegram/context.py b/mautrix_telegram/context.py index e1a1de2f..4566de3f 100644 --- a/mautrix_telegram/context.py +++ b/mautrix_telegram/context.py @@ -14,14 +14,13 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import Optional, Tuple, TYPE_CHECKING - import asyncio from alchemysession import AlchemySessionContainer -from mautrix_appservice import AppService + +from mautrix.appservice import AppService if TYPE_CHECKING: - from .web import PublicBridgeWebsite, ProvisioningAPI from .config import Config from .bot import Bot diff --git a/mautrix_telegram/db/__init__.py b/mautrix_telegram/db/__init__.py index 3afdb48d..d2544765 100644 --- a/mautrix_telegram/db/__init__.py +++ b/mautrix_telegram/db/__init__.py @@ -18,10 +18,8 @@ from .bot_chat import BotChat from .message import Message from .portal import Portal from .puppet import Puppet -from .room_state import RoomState from .telegram_file import TelegramFile from .user import User, UserPortal, Contact -from .user_profile import UserProfile def init(db_engine) -> None: diff --git a/mautrix_telegram/db/base.py b/mautrix_telegram/db/base.py index edf6726e..266a8b53 100644 --- a/mautrix_telegram/db/base.py +++ b/mautrix_telegram/db/base.py @@ -23,10 +23,10 @@ from sqlalchemy.ext.declarative import declarative_base class BaseBase: - db = None # type: Engine - t = None # type: Table - __table__ = None # type: Table - c = None # type: ImmutableColumnCollection + db: Engine = None + t: Table = None + __table__: Table = None + c: ImmutableColumnCollection = None @classmethod @abstractmethod @@ -54,4 +54,5 @@ class BaseBase: with self.db.begin() as conn: conn.execute(self.t.delete().where(self._edit_identity)) + Base = declarative_base(cls=BaseBase) diff --git a/mautrix_telegram/formatter/from_telegram.py b/mautrix_telegram/formatter/from_telegram.py index dbca951f..b3273249 100644 --- a/mautrix_telegram/formatter/from_telegram.py +++ b/mautrix_telegram/formatter/from_telegram.py @@ -38,7 +38,7 @@ from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html, if TYPE_CHECKING: from ..abstract_user import AbstractUser -log = logging.getLogger("mau.fmt.tg") # type: logging.Logger +log: logging.Logger = logging.getLogger("mau.fmt.tg") def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict: diff --git a/mautrix_telegram/formatter/util.py b/mautrix_telegram/formatter/util.py index 67973c61..6c6dfb10 100644 --- a/mautrix_telegram/formatter/util.py +++ b/mautrix_telegram/formatter/util.py @@ -14,7 +14,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import Optional, Pattern -from html import escape import struct import re @@ -47,9 +46,9 @@ def trim_reply_fallback_text(text: str) -> str: return "\n".join(lines) -html_reply_fallback_regex = re.compile("^" - r"[\s\S]+?" - "") # type: Pattern +html_reply_fallback_regex: Pattern = re.compile("^" + r"[\s\S]+?" + "") def trim_reply_fallback_html(html: str) -> str: diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py index 68319f3b..76519a39 100644 --- a/mautrix_telegram/matrix.py +++ b/mautrix_telegram/matrix.py @@ -13,9 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Match, Optional, Set, Tuple, Union, Iterable, TYPE_CHECKING -import time -import re +from typing import Dict, Set, Tuple, Union, Iterable, TYPE_CHECKING from mautrix.bridge import BaseMatrixHandler from mautrix.types import (Event, EventType, RoomID, UserID, EventID, ReceiptEvent, ReceiptType, diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index b11ec5ee..4c1f5f13 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -80,7 +80,7 @@ if TYPE_CHECKING: from .config import Config from .tgclient import MautrixTelegramClient -config: Optional[Config] = None +config: Optional['Config'] = None TypeMessage = Union[Message, MessageService] TypeParticipant = Union[TypeChatParticipant, TypeChannelParticipant] @@ -91,7 +91,7 @@ InviteList = Union[UserID, List[UserID]] class Portal: base_log: logging.Logger = logging.getLogger("mau.portal") az: AppService = None - bot: Bot = None + bot: 'Bot' = None loop: asyncio.AbstractEventLoop = None # Config cache diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py index 85d7f326..8aaa2ad2 100644 --- a/mautrix_telegram/puppet.py +++ b/mautrix_telegram/puppet.py @@ -13,19 +13,21 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Any, Dict, List, Iterable, Optional, Pattern, Union, TYPE_CHECKING +from typing import Awaitable, Any, Dict, Iterable, Optional, Union, TYPE_CHECKING from difflib import SequenceMatcher -from enum import Enum -from aiohttp import ServerDisconnectedError import asyncio import logging import re from telethon.tl.types import (UserProfilePhoto, User, UpdateUserName, PeerUser, TypeInputPeer, InputPeerPhotoFileLocation, UserProfilePhotoEmpty) -from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError -from .types import MatrixUserID, TelegramID +from mautrix.appservice import AppService, IntentAPI +from mautrix.errors import MatrixRequestError +from mautrix.bridge import CustomPuppetMixin +from mautrix.types import UserID + +from .types import TelegramID from .db import Puppet as DBPuppet from . import util @@ -35,26 +37,48 @@ if TYPE_CHECKING: from .context import Context from .abstract_user import AbstractUser -PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken') - -config = None # type: Config +config: Optional['Config'] = None -class Puppet: - log = logging.getLogger("mau.puppet") # type: logging.Logger - az = None # type: AppService - mx = None # type: MatrixHandler - loop = None # type: asyncio.AbstractEventLoop - mxid_regex = None # type: Pattern - username_template = None # type: str - hs_domain = None # type: str - cache = {} # type: Dict[TelegramID, Puppet] - by_custom_mxid = {} # type: Dict[str, Puppet] +class Puppet(CustomPuppetMixin): + log: logging.Logger = logging.getLogger("mau.puppet") + az: AppService + mx: 'MatrixHandler' + loop: asyncio.AbstractEventLoop + username_template: str + hs_domain: str + _mxid_prefix: str + _mxid_suffix: str + _displayname_prefix: str + _displayname_suffix: str + + cache: Dict[TelegramID, 'Puppet'] = {} + by_custom_mxid: Dict[UserID, 'Puppet'] = {} + + id: TelegramID + access_token: Optional[str] + custom_mxid: Optional[UserID] + default_mxid: UserID + + username: Optional[str] + displayname: Optional[str] + displayname_source: Optional[TelegramID] + photo_id: Optional[str] + is_bot: bool + is_registered: bool + disable_updates: bool + + default_mxid_intent: IntentAPI + intent: IntentAPI + + sync_task: Optional[asyncio.Future] + + _db_instance: Optional[DBPuppet] def __init__(self, id: TelegramID, access_token: Optional[str] = None, - custom_mxid: Optional[MatrixUserID] = None, + custom_mxid: Optional[UserID] = None, username: Optional[str] = None, displayname: Optional[str] = None, displayname_source: Optional[TelegramID] = None, @@ -63,41 +87,32 @@ class Puppet: is_registered: bool = False, disable_updates: bool = False, db_instance: Optional[DBPuppet] = None) -> None: - self.id = id # type: TelegramID - self.access_token = access_token # type: Optional[str] - self.custom_mxid = custom_mxid # type: Optional[MatrixUserID] - self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserID + self.id = id + self.access_token = access_token + self.custom_mxid = custom_mxid + self.default_mxid = self.get_mxid_from_id(self.id) - self.username = username # type: Optional[str] - self.displayname = displayname # type: Optional[str] - self.displayname_source = displayname_source # type: Optional[TelegramID] - self.photo_id = photo_id # type: Optional[str] - self.is_bot = is_bot # type: bool - self.is_registered = is_registered # type: bool - self.disable_updates = disable_updates # type: bool - self._db_instance = db_instance # type: Optional[DBPuppet] + self.username = username + self.displayname = displayname + self.displayname_source = displayname_source + self.photo_id = photo_id + self.is_bot = is_bot + self.is_registered = is_registered + self.disable_updates = disable_updates + self._db_instance = db_instance self.default_mxid_intent = self.az.intent.user(self.default_mxid) - self.intent = self._fresh_intent() # type: IntentAPI - self.sync_task = None # type: Optional[asyncio.Future] + self.intent = self._fresh_intent() + self.sync_task = None self.cache[id] = self if self.custom_mxid: self.by_custom_mxid[self.custom_mxid] = self - @property - def mxid(self) -> MatrixUserID: - return self.custom_mxid or self.default_mxid - @property def tgid(self) -> TelegramID: return self.id - @property - def is_real_user(self) -> bool: - """ Is True when the puppet is a real Matrix user. """ - return bool(self.custom_mxid and self.access_token) - @staticmethod async def is_logged_in() -> bool: """ Is True if the puppet is logged in. """ @@ -105,175 +120,15 @@ class Puppet: @property def plain_displayname(self) -> str: - tpl = config["bridge.displayname_template"] - if tpl == "{displayname}": - # Template has no extra stuff, no need to parse. - return self.displayname - regex = re.compile("^" + re.escape(tpl).replace(re.escape("{displayname}"), "(.+?)") + "$") - match = regex.match(self.displayname) - return match.group(1) or self.displayname + prefix = self._mxid_prefix + suffix = self._mxid_suffix + if self.displayname[:len(prefix)] == prefix and self.displayname[-len(suffix):] == suffix: + return self.displayname[len(prefix):-len(suffix)] + return self.displayname def get_input_entity(self, user: 'AbstractUser') -> Awaitable[TypeInputPeer]: return user.client.get_input_entity(PeerUser(user_id=self.tgid)) - # region Custom puppet management - def _fresh_intent(self) -> IntentAPI: - return (self.az.intent.user(self.custom_mxid, self.access_token) - if self.is_real_user else self.default_mxid_intent) - - async def switch_mxid(self, access_token: Optional[str], - mxid: Optional[MatrixUserID]) -> PuppetError: - prev_mxid = self.custom_mxid - self.custom_mxid = mxid - self.access_token = access_token - self.intent = self._fresh_intent() - - err = await self.init_custom_mxid() - if err != PuppetError.Success: - return err - - try: - del self.by_custom_mxid[prev_mxid] # type: ignore - except KeyError: - pass - if self.mxid != self.default_mxid: - self.by_custom_mxid[self.mxid] = self - await self.leave_rooms_with_default_user() - self.save() - return PuppetError.Success - - async def init_custom_mxid(self) -> PuppetError: - if not self.is_real_user: - return PuppetError.Success - - mxid = await self.intent.whoami() - if not mxid or mxid != self.custom_mxid: - self.custom_mxid = None - self.access_token = None - self.intent = self._fresh_intent() - if mxid != self.custom_mxid: - return PuppetError.OnlyLoginSelf - return PuppetError.InvalidAccessToken - if config["bridge.sync_with_custom_puppets"]: - self.sync_task = asyncio.ensure_future(self.sync(), loop=self.loop) - return PuppetError.Success - - async def leave_rooms_with_default_user(self) -> None: - for room_id in await self.default_mxid_intent.get_joined_rooms(): - try: - await self.default_mxid_intent.leave_room(room_id) - await self.intent.ensure_joined(room_id) - except (IntentError, MatrixRequestError): - pass - - def create_sync_filter(self) -> Awaitable[str]: - return self.intent.client.create_filter(self.custom_mxid, { - "room": { - "include_leave": False, - "state": { - "types": [] - }, - "timeline": { - "types": [], - }, - "ephemeral": { - "types": ["m.typing", "m.receipt"], - }, - "account_data": { - "types": [] - } - }, - "account_data": { - "types": [], - }, - "presence": { - "types": ["m.presence"], - "senders": [self.custom_mxid], - }, - }) - - def filter_events(self, events: List[Dict]) -> List: - new_events = [] - for event in events: - evt_type = event.get("type", None) - event.setdefault("content", {}) - if evt_type == "m.typing": - is_typing = self.custom_mxid in event["content"].get("user_ids", []) - event["content"]["user_ids"] = [self.custom_mxid] if is_typing else [] - elif evt_type == "m.receipt": - val = None - evt = None - for event_id in event["content"]: - try: - val = event["content"][event_id]["m.read"][self.custom_mxid] - evt = event_id - break - except KeyError: - pass - if val and evt: - event["content"] = {evt: {"m.read": { - self.custom_mxid: val - }}} - else: - continue - new_events.append(event) - return new_events - - def handle_sync(self, presence: List, ephemeral: Dict) -> None: - presence_events = [self.mx.try_handle_ephemeral_event(event) for event in presence] - - for room_id, events in ephemeral.items(): - for event in events: - event["room_id"] = room_id - - ephemeral_events = [self.mx.try_handle_ephemeral_event(event) - for events in ephemeral.values() - for event in self.filter_events(events)] - - events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]] - coro = asyncio.gather(*events, loop=self.loop) - asyncio.ensure_future(coro, loop=self.loop) - - async def sync(self) -> None: - try: - await self._sync() - except asyncio.CancelledError: - self.log.info("Syncing cancelled") - except Exception: - self.log.exception("Fatal error syncing") - - async def _sync(self) -> None: - if not self.is_real_user: - self.log.warning("Called sync() for non-custom puppet.") - return - custom_mxid = self.custom_mxid - access_token_at_start = self.access_token - errors = 0 - next_batch = None - filter_id = await self.create_sync_filter() - self.log.debug(f"Starting syncer for {custom_mxid} with sync filter {filter_id}.") - while access_token_at_start == self.access_token: - try: - sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch, - set_presence="offline") # type: Dict - errors = 0 - if next_batch is not None: - presence = sync_resp.get("presence", {}).get("events", []) # type: List - ephemeral = {room: data.get("ephemeral", {}).get("events", []) - for room, data - in sync_resp.get("rooms", {}).get("join", {}).items() - } # type: Dict - self.handle_sync(presence, ephemeral) - next_batch = sync_resp.get("next_batch", None) - except (MatrixRequestError, ServerDisconnectedError) as e: - wait = min(errors, 11) ** 2 - self.log.warning(f"Syncer for {custom_mxid} errored: {e}. " - f"Waiting for {wait} seconds...") - errors += 1 - await asyncio.sleep(wait) - self.log.debug(f"Syncer for custom puppet {custom_mxid} stopped.") - - # endregion # region DB conversion @property @@ -378,7 +233,7 @@ class Puppet: self.displayname = displayname self.displayname_source = source.tgid try: - await self.default_mxid_intent.set_display_name(displayname[:100]) + await self.default_mxid_intent.set_displayname(displayname[:100]) except MatrixRequestError: self.log.exception("Failed to set displayname") self.displayname = "" @@ -402,7 +257,7 @@ class Puppet: if not photo_id: self.photo_id = "" try: - await self.default_mxid_intent.set_avatar("") + await self.default_mxid_intent.set_avatar_url("") except MatrixRequestError: self.log.exception("Failed to set avatar") self.photo_id = "" @@ -418,7 +273,7 @@ class Puppet: if file: self.photo_id = photo_id try: - await self.default_mxid_intent.set_avatar(file.mxc) + await self.default_mxid_intent.set_avatar_url(file.mxc) except MatrixRequestError: self.log.exception("Failed to set avatar") self.photo_id = "" @@ -447,7 +302,7 @@ class Puppet: return None @classmethod - def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['Puppet']: + def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']: tgid = cls.get_id_from_mxid(mxid) if tgid: return cls.get(tgid, create) @@ -455,7 +310,7 @@ class Puppet: return None @classmethod - def get_by_custom_mxid(cls, mxid: MatrixUserID) -> Optional['Puppet']: + def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']: if not mxid: raise ValueError("Matrix ID can't be empty") @@ -479,15 +334,16 @@ class Puppet: for puppet in DBPuppet.all_with_custom_mxid()) @classmethod - def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]: - match = cls.mxid_regex.match(mxid) - if match: - return TelegramID(int(match.group(1))) + def get_id_from_mxid(cls, mxid: UserID) -> Optional[TelegramID]: + prefix = cls._mxid_prefix + suffix = cls._mxid_suffix + if mxid[:len(prefix)] == prefix and mxid[-len(suffix):] == suffix: + return TelegramID(int(mxid[len(prefix):-len(suffix)])) return None @classmethod - def get_mxid_from_id(cls, tgid: TelegramID) -> MatrixUserID: - return MatrixUserID(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}") + def get_mxid_from_id(cls, tgid: TelegramID) -> UserID: + return UserID(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}") @classmethod def find_by_username(cls, username: str) -> Optional['Puppet']: @@ -525,8 +381,18 @@ def init(context: 'Context') -> Iterable[Awaitable[Any]]: global config Puppet.az, config, Puppet.loop, _ = context.core Puppet.mx = context.mx - Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}") Puppet.hs_domain = config["homeserver"]["domain"] - Puppet.mxid_regex = re.compile( - f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}") - return (puppet.init_custom_mxid() for puppet in Puppet.all_with_custom_mxid()) + + Puppet.username_template = config["bridge.username_template"] + index = Puppet.username_template.index("{userid}") + length = len("{userid}") + Puppet._mxid_prefix = f"@{Puppet.username_template[:index]}" + Puppet._mxid_suffix = f"{Puppet.username_template[index + length:]}:{Puppet.hs_domain}" + + displayname_template = config["bridge.displayname_template"] + index = displayname_template.index("{displayname}") + length = len("{displayname}") + Puppet._displayname_prefix = displayname_template[:index] + Puppet._displayname_suffix = displayname_template[index+length:] + + return (puppet.start() for puppet in Puppet.all_with_custom_mxid()) diff --git a/mautrix_telegram/scripts/dbms_migrate/__main__.py b/mautrix_telegram/scripts/dbms_migrate/__main__.py index e9edfffd..a6494b88 100644 --- a/mautrix_telegram/scripts/dbms_migrate/__main__.py +++ b/mautrix_telegram/scripts/dbms_migrate/__main__.py @@ -1,7 +1,8 @@ -import argparse -import sqlalchemy as sql +from typing import Union from sqlalchemy import orm from sqlalchemy.ext.declarative import declarative_base +import sqlalchemy as sql +import argparse from alchemysession import AlchemySessionContainer @@ -24,11 +25,12 @@ def log(message, end="\n"): def connect(to): import mautrix_telegram.db.base as base base.Base = declarative_base(cls=base.BaseBase) - from mautrix_telegram.db import (Portal, Message, UserPortal, User, RoomState, UserProfile, - Contact, Puppet, BotChat, TelegramFile) + from mautrix_telegram.db import (Portal, Message, UserPortal, User, Contact, Puppet, BotChat, + TelegramFile) + from mautrix.bridge.db import RoomState, UserProfile db_engine = sql.create_engine(to) db_factory = orm.sessionmaker(bind=db_engine) - db_session = orm.scoped_session(db_factory) # type: orm.Session + db_session: Union[orm.Session, orm.scoped_session] = orm.scoped_session(db_factory) base.Base.metadata.bind = db_engine session_container = AlchemySessionContainer(engine=db_engine, session=db_session, table_base=base.Base, table_prefix="telethon_", @@ -52,6 +54,7 @@ def connect(to): "TelegramFile": TelegramFile, } + log("Connecting to old database") session, tables = connect(args.from_url) diff --git a/mautrix_telegram/tgclient.py b/mautrix_telegram/tgclient.py index 2f49bcd6..670d5da6 100644 --- a/mautrix_telegram/tgclient.py +++ b/mautrix_telegram/tgclient.py @@ -31,7 +31,7 @@ class MautrixTelegramClient(TelegramClient): attributes: List[TypeDocumentAttribute] = None, file_name: str = None, max_image_size: float = 10 * 1000 ** 2, ) -> Union[InputMediaUploadedDocument, InputMediaUploadedPhoto]: - file_handle = await super().upload_file(file, file_name=file_name, use_cache=False) + file_handle = await super().upload_file(file, file_name=file_name) if (mime_type == "image/png" or mime_type == "image/jpeg") and len(file) < max_image_size: return InputMediaUploadedPhoto(file_handle) diff --git a/mautrix_telegram/types.py b/mautrix_telegram/types.py index a1871b48..f5cb2145 100644 --- a/mautrix_telegram/types.py +++ b/mautrix_telegram/types.py @@ -1,3 +1,3 @@ -from typing import Dict, NewType +from typing import NewType TelegramID = NewType('TelegramID', int) diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index eacea399..3d184e73 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -25,9 +25,11 @@ from telethon.tl.types import ( from telethon.tl.types.contacts import ContactsNotModified from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest from telethon.tl.functions.account import UpdateStatusRequest -from mautrix_appservice import MatrixRequestError -from .types import MatrixUserID, TelegramID +from mautrix.errors import MatrixRequestError +from mautrix.types import UserID + +from .types import TelegramID from .db import User as DBUser from .abstract_user import AbstractUser from . import portal as po, puppet as pu @@ -54,7 +56,7 @@ class User(AbstractUser): _db_instance: Optional[DBUser] - def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None, + def __init__(self, mxid: UserID, tgid: Optional[TelegramID] = None, username: Optional[str] = None, phone: Optional[str] = None, db_contacts: Optional[Iterable[TelegramID]] = None, saved_contacts: int = 0, is_bot: bool = False, @@ -250,7 +252,8 @@ class User(AbstractUser): if not portal or portal.deleted or not portal.mxid or portal.has_bot: continue try: - await portal.main_intent.kick(portal.mxid, self.mxid, "Logged out of Telegram.") + await portal.main_intent.kick_user(portal.mxid, self.mxid, + "Logged out of Telegram.") except MatrixRequestError: pass self.portals = {} @@ -356,7 +359,7 @@ class User(AbstractUser): # region Class instance lookup @classmethod - def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['User']: + def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']: if not mxid: raise ValueError("Matrix ID can't be empty") diff --git a/mautrix_telegram/util/__init__.py b/mautrix_telegram/util/__init__.py index 2ba35c28..0e8b20ab 100644 --- a/mautrix_telegram/util/__init__.py +++ b/mautrix_telegram/util/__init__.py @@ -1,7 +1,10 @@ +from asyncio import Future + from .file_transfer import transfer_file_to_matrix, convert_image from .format_duration import format_duration from .signed_token import sign_token, verify_token from .recursive_dict import recursive_del, recursive_set, recursive_get -def ignore_coro(coro): + +def ignore_coro(_: Future) -> None: pass diff --git a/mautrix_telegram/util/file_transfer.py b/mautrix_telegram/util/file_transfer.py index 3aebaa70..c4c01930 100644 --- a/mautrix_telegram/util/file_transfer.py +++ b/mautrix_telegram/util/file_transfer.py @@ -27,7 +27,8 @@ from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLoc InputPeerPhotoFileLocation) from telethon.errors import (AuthBytesInvalidError, AuthKeyInvalidError, LocationInvalidError, SecurityError, FileIdInvalidError) -from mautrix_appservice import IntentAPI + +from mautrix.appservice import IntentAPI from ..tgclient import MautrixTelegramClient from ..db import TelegramFile as DBTelegramFile diff --git a/mautrix_telegram/util/recursive_dict.py b/mautrix_telegram/util/recursive_dict.py index d5f51c80..6fb0b7e2 100644 --- a/mautrix_telegram/util/recursive_dict.py +++ b/mautrix_telegram/util/recursive_dict.py @@ -14,11 +14,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import Dict, Any -from ..config import DictWithRecursion + +from mautrix.util.config import RecursiveDict def recursive_set(data: Dict[str, Any], key: str, value: Any) -> bool: - key, next_key = DictWithRecursion._parse_key(key) + key, next_key = RecursiveDict.parse_key(key) if next_key is not None: if key not in data: data[key] = {} @@ -31,7 +32,7 @@ def recursive_set(data: Dict[str, Any], key: str, value: Any) -> bool: def recursive_get(data: Dict[str, Any], key: str) -> Any: - key, next_key = DictWithRecursion._parse_key(key) + key, next_key = RecursiveDict.parse_key(key) if next_key is not None: next_data = data.get(key, None) if not next_data: @@ -41,7 +42,7 @@ def recursive_get(data: Dict[str, Any], key: str) -> Any: def recursive_del(data: Dict[str, any], key: str) -> bool: - key, next_key = DictWithRecursion._parse_key(key) + key, next_key = RecursiveDict.parse_key(key) if next_key is not None: if key not in data: return False diff --git a/mautrix_telegram/web/common/auth_api.py b/mautrix_telegram/web/common/auth_api.py index e61693dd..ae353bad 100644 --- a/mautrix_telegram/web/common/auth_api.py +++ b/mautrix_telegram/web/common/auth_api.py @@ -13,27 +13,30 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from abc import abstractmethod from typing import Optional - -from aiohttp import web +from abc import abstractmethod import abc import asyncio import logging +from aiohttp import web + from telethon.errors import * +from mautrix.bridge import OnlyLoginSelf, InvalidAccessToken + from ...commands.telegram.auth import enter_password from ...util import format_duration, ignore_coro -from ...puppet import Puppet, PuppetError +from ...puppet import Puppet from ...user import User class AuthAPI(abc.ABC): - log = logging.getLogger("mau.web.auth") # type: logging.Logger + log: logging.Logger = logging.getLogger("mau.web.auth") + loop: asyncio.AbstractEventLoop def __init__(self, loop: asyncio.AbstractEventLoop): - self.loop = loop # type: asyncio.AbstractEventLoop + self.loop = loop @abstractmethod def get_login_response(self, status: int = 200, state: str = "", username: str = "", @@ -55,15 +58,14 @@ class AuthAPI(abc.ABC): error="You have already logged in with your Matrix " "account.", errcode="already-logged-in") - resp = await puppet.switch_mxid(token.strip(), user.mxid) - if resp == PuppetError.OnlyLoginSelf: + try: + await puppet.switch_mxid(token.strip(), user.mxid) + except OnlyLoginSelf: return self.get_mx_login_response(status=403, errcode="only-login-self", error="You can only log in as your own Matrix user.") - elif resp == PuppetError.InvalidAccessToken: + except InvalidAccessToken: return self.get_mx_login_response(status=401, errcode="invalid-access-token", error="Failed to verify access token.") - assert resp == PuppetError.Success, "Encountered an unhandled PuppetError." - return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in") async def post_matrix_password(self, user: User, password: str) -> web.Response: diff --git a/mautrix_telegram/web/provisioning/__init__.py b/mautrix_telegram/web/provisioning/__init__.py index d60cfcdb..4c345937 100644 --- a/mautrix_telegram/web/provisioning/__init__.py +++ b/mautrix_telegram/web/provisioning/__init__.py @@ -13,17 +13,21 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from aiohttp import web from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING import asyncio import logging import json +from aiohttp import web + from telethon.utils import get_peer_id, resolve_id from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat -from mautrix_appservice import AppService, MatrixRequestError, IntentError -from ...types import MatrixUserID, TelegramID +from mautrix.appservice import AppService +from mautrix.errors import MatrixRequestError, IntentError +from mautrix.types import UserID + +from ...types import TelegramID from ...user import User from ...portal import Portal from ...util import ignore_coro @@ -35,16 +39,19 @@ if TYPE_CHECKING: class ProvisioningAPI(AuthAPI): - log = logging.getLogger("mau.web.provisioning") # type: logging.Logger + log: logging.Logger = logging.getLogger("mau.web.provisioning") + secret: str + az: AppService + context: 'Context' + app: web.Application def __init__(self, context: "Context") -> None: super().__init__(context.loop) - self.secret = context.config["appservice.provisioning.shared_secret"] # type: str - self.az = context.az # type: AppService - self.context = context # type: Context + self.secret = context.config["appservice.provisioning.shared_secret"] + self.az = context.az + self.context = context - self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware] - ) # type: web.Application + self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware]) portal_prefix = "/portal/{mxid:![^/]+}" self.app.router.add_route("GET", f"{portal_prefix}", self.get_portal_by_mxid) @@ -76,18 +83,7 @@ class ProvisioningAPI(AuthAPI): if not portal: return self.get_error_response(404, "portal_not_found", "Portal with given Matrix ID not found.") - user, _ = await self.get_user(request.query.get("user_id", None), expect_logged_in=None, - require_puppeting=False) - return web.json_response({ - "mxid": portal.mxid, - "chat_id": get_peer_id(portal.peer), - "peer_type": portal.peer_type, - "title": portal.title, - "about": portal.about, - "username": portal.username, - "megagroup": portal.megagroup, - "can_unbridge": (await portal.can_user_perform(user, "unbridge")) if user else False, - }) + return await self._get_portal_response(UserID(request.query.get("user_id", "")), portal) async def get_portal_by_tgid(self, request: web.Request) -> web.Response: err = self.check_authorization(request) @@ -103,8 +99,10 @@ class ProvisioningAPI(AuthAPI): if not portal: return self.get_error_response(404, "portal_not_found", "Portal to given Telegram chat not found.") - user, _ = await self.get_user(request.query.get("user_id", None), expect_logged_in=None, - require_puppeting=False) + return await self._get_portal_response(UserID(request.query.get("user_id", "")), portal) + + async def _get_portal_response(self, user_id: UserID, portal: Portal) -> web.Response: + user, _ = await self.get_user(user_id, expect_logged_in=None, require_puppeting=False) return web.json_response({ "mxid": portal.mxid, "chat_id": get_peer_id(portal.peer), @@ -364,7 +362,8 @@ class ProvisioningAPI(AuthAPI): async def bridge_info(self, request: web.Request) -> web.Response: return web.json_response({ - "relaybot_username": self.context.bot.username if self.context.bot is not None else None, + "relaybot_username": (self.context.bot.username + if self.context.bot is not None else None), }, status=200) @staticmethod @@ -430,7 +429,7 @@ class ProvisioningAPI(AuthAPI): except json.JSONDecodeError: return None - async def get_user(self, mxid: MatrixUserID, expect_logged_in: Optional[bool] = False, + async def get_user(self, mxid: Optional[UserID], expect_logged_in: Optional[bool] = False, require_puppeting: bool = True, require_user: bool = True ) -> Tuple[Optional[User], Optional[web.Response]]: if not mxid: @@ -459,8 +458,7 @@ class ProvisioningAPI(AuthAPI): expect_logged_in: Optional[bool] = False, require_puppeting: bool = False, want_data: bool = True, - ) -> (Tuple[Optional[Dict], - Optional[User], + ) -> (Tuple[Optional[Dict], Optional[User], Optional[web.Response]]): err = self.check_authorization(request) if err is not None: diff --git a/mautrix_telegram/web/public/__init__.py b/mautrix_telegram/web/public/__init__.py index 36f60925..31dc07e3 100644 --- a/mautrix_telegram/web/public/__init__.py +++ b/mautrix_telegram/web/public/__init__.py @@ -14,16 +14,18 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import Optional -from aiohttp import web -from mako.template import Template -import pkg_resources import asyncio import logging import random import string import time -from ...types import MatrixUserID +from mako.template import Template +from aiohttp import web +import pkg_resources + +from mautrix.types import UserID + from ...util import sign_token, verify_token from ...user import User from ...puppet import Puppet @@ -31,20 +33,23 @@ from ..common import AuthAPI class PublicBridgeWebsite(AuthAPI): - log = logging.getLogger("mau.web.public") # type: logging.Logger + log: logging.Logger = logging.getLogger("mau.web.public") + secret_key: str + login: Template + mx_login: Template + app: web.Application def __init__(self, loop: asyncio.AbstractEventLoop): super().__init__(loop) - self.secret_key = "".join( - random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) # type: str + self.secret_key = "".join(random.choices(string.ascii_lowercase + string.digits, k=64)) self.login = Template(pkg_resources.resource_string( - "mautrix_telegram", "web/public/login.html.mako")) # type: Template + "mautrix_telegram", "web/public/login.html.mako")) self.mx_login = Template(pkg_resources.resource_string( - "mautrix_telegram", "web/public/matrix-login.html.mako")) # type: Template + "mautrix_telegram", "web/public/matrix-login.html.mako")) - self.app = web.Application(loop=loop) # type: web.Application + self.app = web.Application(loop=loop) self.app.router.add_route("GET", "/login", self.get_login) self.app.router.add_route("POST", "/login", self.post_login) self.app.router.add_route("GET", "/matrix-login", self.get_matrix_login) @@ -59,11 +64,11 @@ class PublicBridgeWebsite(AuthAPI): "expiry": int(time.time()) + expires_in, }) - def verify_token(self, token: str, endpoint: str = "/login") -> Optional[MatrixUserID]: + def verify_token(self, token: str, endpoint: str = "/login") -> Optional[UserID]: token = verify_token(self.secret_key, token) if token and (token.get("expiry", 0) > int(time.time()) and token.get("endpoint", None) == endpoint): - return MatrixUserID(token.get("mxid", None)) + return UserID(token.get("mxid", None)) return None async def get_login(self, request: web.Request) -> web.Response: