From 89ab29ea5f6e827230c1f5fccbfcc70281df0349 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 20 Dec 2021 22:39:09 +0200 Subject: [PATCH] Switch from SQLAlchemy to asyncpg/aiosqlite --- Dockerfile | 5 - alembic/env.py | 4 - mautrix_telegram/__main__.py | 113 +- mautrix_telegram/abstract_user.py | 231 +- mautrix_telegram/bot.py | 70 +- mautrix_telegram/commands/handler.py | 30 +- mautrix_telegram/commands/matrix_auth.py | 10 +- mautrix_telegram/commands/portal/admin.py | 20 +- mautrix_telegram/commands/portal/bridge.py | 16 +- mautrix_telegram/commands/portal/config.py | 2 +- .../commands/portal/create_chat.py | 6 +- mautrix_telegram/commands/portal/misc.py | 21 +- mautrix_telegram/commands/portal/unbridge.py | 18 +- mautrix_telegram/commands/portal/util.py | 14 +- mautrix_telegram/commands/telegram/account.py | 11 +- mautrix_telegram/commands/telegram/auth.py | 28 +- mautrix_telegram/commands/telegram/misc.py | 30 +- mautrix_telegram/config.py | 4 +- mautrix_telegram/context.py | 57 - mautrix_telegram/db/__init__.py | 20 +- mautrix_telegram/db/bot_chat.py | 43 +- mautrix_telegram/db/message.py | 178 +- mautrix_telegram/db/portal.py | 120 +- mautrix_telegram/db/puppet.py | 117 +- mautrix_telegram/db/telegram_file.py | 107 +- mautrix_telegram/db/telethon_session.py | 204 ++ mautrix_telegram/db/upgrade/__init__.py | 5 + .../db/upgrade/v01_initial_revision.py | 300 ++ mautrix_telegram/db/user.py | 159 +- mautrix_telegram/example-config.yaml | 4 + mautrix_telegram/formatter/__init__.py | 7 +- .../formatter/from_matrix/__init__.py | 184 +- .../formatter/from_matrix/parser.py | 75 +- .../formatter/from_matrix/telegram_message.py | 16 +- mautrix_telegram/formatter/from_telegram.py | 86 +- mautrix_telegram/matrix.py | 135 +- mautrix_telegram/portal.py | 2759 +++++++++++++++++ mautrix_telegram/portal/__init__.py | 21 - mautrix_telegram/portal/__init__.pyi | 15 - mautrix_telegram/portal/base.py | 551 ---- mautrix_telegram/portal/matrix.py | 680 ---- mautrix_telegram/portal/metadata.py | 875 ------ mautrix_telegram/portal/telegram.py | 808 ----- mautrix_telegram/puppet.py | 347 +-- mautrix_telegram/scripts/__init__.py | 0 .../scripts/dbms_migrate/__init__.py | 0 .../scripts/dbms_migrate/__main__.py | 88 - .../scripts/telematrix_import/__init__.py | 0 .../scripts/telematrix_import/__main__.py | 125 - .../scripts/telematrix_import/models.py | 44 - mautrix_telegram/user.py | 426 ++- mautrix_telegram/util/__init__.py | 2 + .../{portal => util}/deduplication.py | 15 +- mautrix_telegram/util/file_transfer.py | 17 +- .../{portal => util}/send_lock.py | 2 +- mautrix_telegram/web/common/auth_api.py | 4 +- mautrix_telegram/web/provisioning/__init__.py | 51 +- mautrix_telegram/web/public/__init__.py | 8 +- optional-requirements.txt | 7 - requirements.txt | 12 +- setup.py | 2 +- 61 files changed, 4681 insertions(+), 4628 deletions(-) delete mode 100644 mautrix_telegram/context.py create mode 100644 mautrix_telegram/db/telethon_session.py create mode 100644 mautrix_telegram/db/upgrade/__init__.py create mode 100644 mautrix_telegram/db/upgrade/v01_initial_revision.py create mode 100644 mautrix_telegram/portal.py delete mode 100644 mautrix_telegram/portal/__init__.py delete mode 100644 mautrix_telegram/portal/__init__.pyi delete mode 100644 mautrix_telegram/portal/base.py delete mode 100644 mautrix_telegram/portal/matrix.py delete mode 100644 mautrix_telegram/portal/metadata.py delete mode 100644 mautrix_telegram/portal/telegram.py delete mode 100644 mautrix_telegram/scripts/__init__.py delete mode 100644 mautrix_telegram/scripts/dbms_migrate/__init__.py delete mode 100644 mautrix_telegram/scripts/dbms_migrate/__main__.py delete mode 100644 mautrix_telegram/scripts/telematrix_import/__init__.py delete mode 100644 mautrix_telegram/scripts/telematrix_import/__main__.py delete mode 100644 mautrix_telegram/scripts/telematrix_import/models.py rename mautrix_telegram/{portal => util}/deduplication.py (92%) rename mautrix_telegram/{portal => util}/send_lock.py (97%) diff --git a/Dockerfile b/Dockerfile index c954c630..994a5d37 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,10 +8,6 @@ RUN apk add --no-cache \ py3-pillow \ py3-aiohttp \ py3-magic \ - py3-sqlalchemy \ - py3-telethon-session-sqlalchemy \ - py3-alembic \ - py3-psycopg2 \ py3-ruamel.yaml \ py3-commonmark \ py3-prometheus-client \ @@ -53,7 +49,6 @@ RUN apk add --virtual .build-deps \ python3-dev \ libffi-dev \ build-base \ - && sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \ && pip3 install -r requirements.txt -r optional-requirements.txt \ && apk del .build-deps diff --git a/alembic/env.py b/alembic/env.py index 4f91e8a4..9fbd478f 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -8,9 +8,7 @@ from os.path import abspath, dirname sys.path.insert(0, dirname(dirname(abspath(__file__)))) from mautrix.util.db import Base -import mautrix_telegram.db from mautrix_telegram.config import Config -from alchemysession import AlchemySessionContainer # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -21,8 +19,6 @@ mxtg_config = Config(mxtg_config_path, None, None) mxtg_config.load() config.set_main_option("sqlalchemy.url", mxtg_config["appservice.database"].replace("%", "%%")) -AlchemySessionContainer.create_table_classes(None, "telethon_", Base) - # Interpret the config file for Python logging. # This line sets up loggers basically. fileConfig(config.config_file_name) diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py index cc36d976..9a92366b 100644 --- a/mautrix_telegram/__main__.py +++ b/mautrix_telegram/__main__.py @@ -13,39 +13,27 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Any -import sys +from __future__ import annotations + +from typing import Any from telethon import __version__ as __telethon_version__ -from alchemysession import AlchemySessionContainer from mautrix.types import UserID, RoomID from mautrix.bridge import Bridge -from mautrix.util.db import Base -from mautrix.bridge.state_store.sqlalchemy import SQLBridgeStateStore from .web.provisioning import ProvisioningAPI from .web.public import PublicBridgeWebsite -from .abstract_user import init as init_abstract_user -from .bot import Bot, init as init_bot +from .abstract_user import AbstractUser +from .bot import Bot from .config import Config -from .context import Context -from .db import init as init_db -from .formatter import init as init_formatter +from .db import init as init_db, upgrade_table from .matrix import MatrixHandler -from .portal import Portal, init as init_portal -from .puppet import Puppet, init as init_puppet -from .user import User, init as init_user +from .portal import Portal +from .puppet import Puppet +from .user import User from .version import version, linkified_version -import sqlalchemy as sql -from sqlalchemy.engine.base import Engine - -try: - import prometheus_client as prometheus -except ImportError: - prometheus = None - class TelegramBridge(Bridge): module = "mautrix_telegram" @@ -57,55 +45,46 @@ class TelegramBridge(Bridge): markdown_version = linkified_version config_class = Config matrix_class = MatrixHandler - state_store_class = SQLBridgeStateStore + upgrade_table = upgrade_table - db: 'Engine' config: Config - session_container: AlchemySessionContainer - bot: Bot + bot: Bot | None + public_website: PublicBridgeWebsite | None + provisioning_api: ProvisioningAPI | None def prepare_db(self) -> None: - if not sql: - raise RuntimeError("SQLAlchemy is not installed") - self.db = sql.create_engine(self.config["appservice.database"], - **self.config["appservice.database_opts"]) - Base.metadata.bind = self.db - if not self.db.has_table("alembic_version"): - self.log.critical("alembic_version table not found. " - "Did you forget to `alembic upgrade head`?") - sys.exit(10) - + super().prepare_db() init_db(self.db) - self.session_container = AlchemySessionContainer( - engine=self.db, table_base=Base, session=False, - table_prefix="telethon_", manage_tables=False) - def make_state_store(self) -> None: - self.state_store = self.state_store_class(self.get_puppet, self.get_double_puppet) - - def _prepare_website(self, context: Context) -> None: + def _prepare_website(self) -> None: if self.config["appservice.public.enabled"]: - public_website = PublicBridgeWebsite(self.loop) - self.az.app.add_subapp(self.config["appservice.public.prefix"], public_website.app) - context.public_website = public_website + self.public_website = PublicBridgeWebsite(self.loop) + self.az.app.add_subapp( + self.config["appservice.public.prefix"], self.public_website.app + ) + else: + self.public_website = None if self.config["appservice.provisioning.enabled"]: - provisioning_api = ProvisioningAPI(context) - self.az.app.add_subapp(self.config["appservice.provisioning.prefix"], - provisioning_api.app) - context.provisioning_api = provisioning_api + self.provisioning_api = ProvisioningAPI(self) + self.az.app.add_subapp( + self.config["appservice.provisioning.prefix"], self.provisioning_api.app + ) + else: + self.provisioning_api = None def prepare_bridge(self) -> None: - self.bot = init_bot(self.config) - context = Context(self.az, self.config, self.loop, self.session_container, self, self.bot) - self._prepare_website(context) - self.matrix = context.mx = MatrixHandler(context) - - init_abstract_user(context) - init_formatter(context) - init_portal(context) - self.add_startup_actions(init_puppet(context)) - self.add_startup_actions(init_user(context)) + self._prepare_website() + AbstractUser.init_cls(self) + bot_token: str = self.config["telegram.bot_token"] + if bot_token and not bot_token.lower().startswith("disable"): + self.bot = AbstractUser.relaybot = Bot(bot_token) + else: + self.bot = AbstractUser.relaybot = None + self.matrix = MatrixHandler(self) + Portal.init_cls(self) + self.add_startup_actions(Puppet.init_cls(self)) + self.add_startup_actions(User.init_cls(self)) if self.bot: self.add_startup_actions(self.bot.start()) if self.config["bridge.resend_bridge_info"]: @@ -115,7 +94,7 @@ class TelegramBridge(Bridge): self.config["bridge.resend_bridge_info"] = False self.config.save() self.log.info("Re-sending bridge info state event to all portals") - for portal in Portal.all(): + async for portal in Portal.all(): await portal.update_bridge_info() self.log.info("Finished re-sending bridge info state events") @@ -124,19 +103,19 @@ class TelegramBridge(Bridge): puppet.stop() self.shutdown_actions = (user.stop() for user in User.by_tgid.values()) - async def get_user(self, user_id: UserID, create: bool = True) -> User: - user = User.get_by_mxid(user_id, create=create) + async def get_user(self, user_id: UserID, create: bool = True) -> User | None: + user = await User.get_by_mxid(user_id, create=create) if user: await user.ensure_started() return user - async def get_portal(self, room_id: RoomID) -> Portal: - return Portal.get_by_mxid(room_id) + async def get_portal(self, room_id: RoomID) -> Portal | None: + return await Portal.get_by_mxid(room_id) - async def get_puppet(self, user_id: UserID, create: bool = False) -> Puppet: + async def get_puppet(self, user_id: UserID, create: bool = False) -> Puppet | None: return await Puppet.get_by_mxid(user_id, create=create) - async def get_double_puppet(self, user_id: UserID) -> Puppet: + async def get_double_puppet(self, user_id: UserID) -> Puppet | None: return await Puppet.get_by_custom_mxid(user_id) def is_bridge_ghost(self, user_id: UserID) -> bool: @@ -145,7 +124,7 @@ class TelegramBridge(Bridge): async def count_logged_in_users(self) -> int: return len([user for user in User.by_tgid.values() if user.tgid]) - async def manhole_global_namespace(self, user_id: UserID) -> Dict[str, Any]: + async def manhole_global_namespace(self, user_id: UserID) -> dict[str, Any]: return { **await super().manhole_global_namespace(user_id), "User": User, diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py index f9d9fe7b..f44e3182 100644 --- a/mautrix_telegram/abstract_user.py +++ b/mautrix_telegram/abstract_user.py @@ -13,7 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Tuple, Optional, Union, Dict, Type, Any, TYPE_CHECKING +from __future__ import annotations + +from typing import Type, Any, Union, TYPE_CHECKING from abc import ABC, abstractmethod import platform import asyncio @@ -39,48 +41,50 @@ from mautrix.errors import MatrixError from mautrix.appservice import AppService from mautrix.util.logging import TraceLogger from mautrix.util.opt_prometheus import Histogram, Counter -from alchemysession import AlchemySessionContainer from . import portal as po, puppet as pu, __version__ -from .db import Message as DBMessage +from .db import Message as DBMessage, PgSession from .types import TelegramID from .tgclient import MautrixTelegramClient +from .config import Config if TYPE_CHECKING: - from .context import Context - from .config import Config from .bot import Bot from .__main__ import TelegramBridge -config: Optional['Config'] = None -# Value updated from config in init() -MAX_DELETIONS: int = 10 - UpdateMessage = Union[UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage, UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage] -UpdateMessageContent = Union[UpdateShortMessage, UpdateShortChatMessage, Message, MessageService] -UpdateTyping = Union[UpdateUserTyping, UpdateChatUserTyping, UpdateChannelUserTyping] +UpdateMessageContent = Union[ + UpdateShortMessage, UpdateShortChatMessage, Message, MessageService, MessageEmpty +] -UPDATE_TIME = Histogram("bridge_telegram_update", "Time spent processing Telegram updates", - ("update_type",)) -UPDATE_ERRORS = Counter("bridge_telegram_update_error", - "Number of fatal errors while handling Telegram updates", ("update_type",)) +UPDATE_TIME = Histogram( + name="bridge_telegram_update", + documentation="Time spent processing Telegram updates", + labelnames=("update_type",), +) +UPDATE_ERRORS = Counter( + name="bridge_telegram_update_error", + documentation="Number of fatal errors while handling Telegram updates", + labelnames=("update_type",), +) class AbstractUser(ABC): - session_container: AlchemySessionContainer = None loop: asyncio.AbstractEventLoop = None log: TraceLogger az: AppService bridge: 'TelegramBridge' - relaybot: Optional['Bot'] + config: Config + relaybot: 'Bot' ignore_incoming_bot_events: bool = True + max_deletions: int = 10 - client: Optional[MautrixTelegramClient] - mxid: Optional[UserID] + client: MautrixTelegramClient | None + mxid: UserID | None - tgid: Optional[TelegramID] - username: Optional['str'] + tgid: TelegramID | None + username: str | None is_bot: bool is_relaybot: bool @@ -106,14 +110,14 @@ class AbstractUser(ABC): return self.client and self.client.is_connected() @property - def _proxy_settings(self) -> Tuple[Type[Connection], Optional[Tuple[Any, ...]]]: - proxy_type = config["telegram.proxy.type"].lower() + def _proxy_settings(self) -> tuple[Type[Connection], tuple[Any, ...] | None]: + proxy_type = self.config["telegram.proxy.type"].lower() connection = ConnectionTcpFull - connection_data = (config["telegram.proxy.address"], - config["telegram.proxy.port"], - config["telegram.proxy.rdns"], - config["telegram.proxy.username"], - config["telegram.proxy.password"]) + connection_data = (self.config["telegram.proxy.address"], + self.config["telegram.proxy.port"], + self.config["telegram.proxy.rdns"], + self.config["telegram.proxy.username"], + self.config["telegram.proxy.password"]) if proxy_type == "disabled": connection_data = None elif proxy_type == "socks4": @@ -128,23 +132,32 @@ class AbstractUser(ABC): return connection, connection_data - def _init_client(self) -> None: + @classmethod + def init_cls(cls, bridge: 'TelegramBridge') -> None: + cls.bridge = bridge + cls.config = bridge.config + cls.loop = bridge.loop + cls.az = bridge.az + cls.ignore_incoming_bot_events = cls.config["bridge.relaybot.ignore_own_incoming_events"] + cls.max_deletions = cls.config["bridge.max_telegram_delete"] + + async def _init_client(self) -> None: self.log.debug(f"Initializing client for {self.name}") - session = self.session_container.new_session(self.name) - if config["telegram.server.enabled"]: - session.set_dc(config["telegram.server.dc"], - config["telegram.server.ip"], - config["telegram.server.port"]) + session = await PgSession.get(self.name) + if self.config["telegram.server.enabled"]: + session.set_dc(self.config["telegram.server.dc"], + self.config["telegram.server.ip"], + self.config["telegram.server.port"]) if self.is_relaybot: base_logger = logging.getLogger("telethon.relaybot") else: base_logger = logging.getLogger(f"telethon.{self.tgid or -hash(self.mxid)}") - device = config["telegram.device_info.device_model"] - sysversion = config["telegram.device_info.system_version"] - appversion = config["telegram.device_info.app_version"] + device = self.config["telegram.device_info.device_model"] + sysversion = self.config["telegram.device_info.system_version"] + appversion = self.config["telegram.device_info.app_version"] connection, proxy = self._proxy_settings assert isinstance(session, Session) @@ -152,8 +165,8 @@ class AbstractUser(ABC): self.client = MautrixTelegramClient( session=session, - api_id=config["telegram.api_id"], - api_hash=config["telegram.api_hash"], + api_id=self.config["telegram.api_id"], + api_hash=self.config["telegram.api_hash"], app_version=__version__ if appversion == "auto" else appversion, system_version=(MautrixTelegramClient.__version__ @@ -161,11 +174,11 @@ class AbstractUser(ABC): device_model=(f"{platform.system()} {platform.release()}" if device == "auto" else device), - timeout=config["telegram.connection.timeout"], - connection_retries=config["telegram.connection.retries"], - retry_delay=config["telegram.connection.retry_delay"], - flood_sleep_threshold=config["telegram.connection.flood_sleep_threshold"], - request_retries=config["telegram.connection.request_retries"], + timeout=self.config["telegram.connection.timeout"], + connection_retries=self.config["telegram.connection.retries"], + retry_delay=self.config["telegram.connection.retry_delay"], + flood_sleep_threshold=self.config["telegram.connection.flood_sleep_threshold"], + request_retries=self.config["telegram.connection.request_retries"], connection=connection, proxy=proxy, raise_last_call_error=True, @@ -216,17 +229,17 @@ class AbstractUser(ABC): and (not self.is_bot or allow_bot) and await self.is_logged_in()) - async def start(self, delete_unless_authenticated: bool = False) -> 'AbstractUser': + async def start(self, delete_unless_authenticated: bool = False) -> AbstractUser: if not self.client: - self._init_client() + await self._init_client() await self.client.connect() self.log.debug(f"{'Bot' if self.is_relaybot else self.mxid} connected: {self.connected}") return self - async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser': + async def ensure_started(self, even_if_no_session=False) -> AbstractUser: if self.connected: return self - if even_if_no_session or self.session_container.has_session(self.mxid): + if even_if_no_session or await PgSession.has(self.mxid): self.log.debug("Starting client due to ensure_started" f"(even_if_no_session={even_if_no_session})") await self.start(delete_unless_authenticated=not even_if_no_session) @@ -281,19 +294,20 @@ class AbstractUser(ABC): async def update_notify_settings(self, update: UpdateNotifySettings) -> None: pass - async def update_pinned_messages(self, update: Union[UpdatePinnedMessages, - UpdatePinnedChannelMessages]) -> None: + async def update_pinned_messages( + self, update: UpdatePinnedMessages | UpdatePinnedChannelMessages + ) -> None: if isinstance(update, UpdatePinnedMessages): - portal = po.Portal.get_by_entity(update.peer, receiver_id=self.tgid) + portal = await po.Portal.get_by_entity(update.peer, tg_receiver=self.tgid) else: - portal = po.Portal.get_by_tgid(TelegramID(update.channel_id)) + portal = await po.Portal.get_by_tgid(TelegramID(update.channel_id)) if portal and portal.mxid: await portal.receive_telegram_pin_ids(update.messages, self.tgid, remove=not update.pinned) @staticmethod async def update_participants(update: UpdateChatParticipants) -> None: - portal = po.Portal.get_by_tgid(TelegramID(update.participants.chat_id)) + portal = await po.Portal.get_by_tgid(TelegramID(update.participants.chat_id)) if portal and portal.mxid: await portal.update_power_levels(update.participants.participants) @@ -302,30 +316,36 @@ class AbstractUser(ABC): self.log.debug("Unexpected read receipt peer: %s", update.peer) return - portal = po.Portal.get_by_tgid(TelegramID(update.peer.user_id), self.tgid) + portal = await po.Portal.get_by_tgid( + TelegramID(update.peer.user_id), tg_receiver=self.tgid + ) if not portal or not portal.mxid: return # We check that these are user read receipts, so tg_space is always the user ID. - message = DBMessage.get_one_by_tgid(TelegramID(update.max_id), self.tgid, edit_index=-1) + message = await DBMessage.get_one_by_tgid(TelegramID(update.max_id), self.tgid, + edit_index=-1) if not message: return - puppet = pu.Puppet.get(TelegramID(update.peer.user_id)) + puppet = await pu.Puppet.get_by_tgid(TelegramID(update.peer.user_id)) await puppet.intent.mark_read(portal.mxid, message.mxid) - async def update_own_read_receipt(self, update: Union[UpdateReadHistoryInbox, - UpdateReadChannelInbox]) -> None: - puppet = pu.Puppet.get(self.tgid) + async def update_own_read_receipt( + self, update: UpdateReadHistoryInbox | UpdateReadChannelInbox + ) -> None: + puppet = await pu.Puppet.get_by_tgid(self.tgid) if not puppet.is_real_user: return if isinstance(update, UpdateReadChannelInbox): - portal = po.Portal.get_by_tgid(TelegramID(update.channel_id)) + portal = await po.Portal.get_by_tgid(TelegramID(update.channel_id)) elif isinstance(update.peer, PeerChat): - portal = po.Portal.get_by_tgid(TelegramID(update.peer.chat_id)) + portal = await po.Portal.get_by_tgid(TelegramID(update.peer.chat_id)) elif isinstance(update.peer, PeerUser): - portal = po.Portal.get_by_tgid(TelegramID(update.peer.user_id), self.tgid) + portal = await po.Portal.get_by_tgid( + TelegramID(update.peer.user_id), tg_receiver=self.tgid + ) else: self.log.debug("Unexpected own read receipt peer: %s", update.peer) return @@ -334,7 +354,8 @@ class AbstractUser(ABC): return tg_space = portal.tgid if portal.peer_type == "channel" else self.tgid - message = DBMessage.get_one_by_tgid(TelegramID(update.max_id), tg_space, edit_index=-1) + message = await DBMessage.get_one_by_tgid(TelegramID(update.max_id), tg_space, + edit_index=-1) if not message: return @@ -342,21 +363,25 @@ class AbstractUser(ABC): async def update_admin(self, update: UpdateChatParticipantAdmin) -> None: # TODO duplication not checked - portal = po.Portal.get_by_tgid(TelegramID(update.chat_id)) + portal = await po.Portal.get_by_tgid(TelegramID(update.chat_id)) if not portal or not portal.mxid: return await portal.set_telegram_admin(TelegramID(update.user_id)) - async def update_typing(self, update: UpdateTyping) -> None: + async def update_typing( + self, update: UpdateUserTyping | UpdateChatUserTyping | UpdateChannelUserTyping + ) -> None: sender = None if isinstance(update, UpdateUserTyping): - portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user") - sender = pu.Puppet.get(TelegramID(update.user_id)) + portal = await po.Portal.get_by_tgid( + TelegramID(update.user_id), tg_receiver=self.tgid, peer_type="user" + ) + sender = await pu.Puppet.get_by_tgid(TelegramID(update.user_id)) elif isinstance(update, UpdateChannelUserTyping): - portal = po.Portal.get_by_tgid(TelegramID(update.channel_id)) + portal = await po.Portal.get_by_tgid(TelegramID(update.channel_id)) elif isinstance(update, UpdateChatUserTyping): - portal = po.Portal.get_by_tgid(TelegramID(update.chat_id)) + portal = await po.Portal.get_by_tgid(TelegramID(update.chat_id)) else: return @@ -364,26 +389,25 @@ class AbstractUser(ABC): # Can typing notifications come from non-user peers? if not update.from_id.user_id: return - sender = pu.Puppet.get(TelegramID(update.from_id.user_id)) + sender = await pu.Puppet.get_by_tgid(TelegramID(update.from_id.user_id)) if not sender or not portal or not portal.mxid: return await portal.handle_telegram_typing(sender, update) - async def _handle_entity_updates(self, entities: Dict[int, Union[User, Chat, Channel]] - ) -> None: + async def _handle_entity_updates(self, entities: dict[int, User | Chat | Channel]) -> None: try: users = (entity for entity in entities.values() if isinstance(entity, User)) - puppets = ((pu.Puppet.get(TelegramID(user.id)), user) for user in users) + puppets = ((await pu.Puppet.get_by_tgid(TelegramID(user.id)), user) for user in users) await asyncio.gather(*[puppet.try_update_info(self, info) - for puppet, info in puppets if puppet]) + async for puppet, info in puppets if puppet]) except Exception: self.log.exception("Failed to handle entity updates") - async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]) -> None: + async def update_others_info(self, update: UpdateUserName | UpdateUserPhoto) -> None: # TODO duplication not checked - puppet = pu.Puppet.get(TelegramID(update.user_id)) + puppet = await pu.Puppet.get_by_tgid(TelegramID(update.user_id)) if isinstance(update, UpdateUserName): puppet.username = update.username if await puppet.update_displayname(self, update): @@ -395,7 +419,7 @@ class AbstractUser(ABC): self.log.warning(f"Unexpected other user info update: {type(update)}") async def update_status(self, update: UpdateUserStatus) -> None: - puppet = pu.Puppet.get(TelegramID(update.user_id)) + puppet = await pu.Puppet.get_by_tgid(TelegramID(update.user_id)) if isinstance(update.status, UserStatusOnline): await puppet.default_mxid_intent.set_presence(PresenceState.ONLINE) elif isinstance(update.status, UserStatusOffline): @@ -404,27 +428,29 @@ class AbstractUser(ABC): self.log.warning(f"Unexpected user status update: type({update})") return - def get_message_details(self, update: UpdateMessage) -> Tuple[UpdateMessageContent, - Optional[pu.Puppet], - Optional[po.Portal]]: + async def get_message_details( + self, update: UpdateMessage + ) -> tuple[UpdateMessageContent, pu.Puppet | None, po.Portal | None]: if isinstance(update, UpdateShortChatMessage): portal = po.Portal.get_by_tgid(TelegramID(update.chat_id)) if not portal: self.log.warning(f"Received message in chat with unknown type {update.chat_id}") - sender = pu.Puppet.get(TelegramID(update.from_id)) + sender = await pu.Puppet.get_by_tgid(TelegramID(update.from_id)) elif isinstance(update, UpdateShortMessage): - portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user") - sender = pu.Puppet.get(self.tgid if update.out else update.user_id) + portal = await po.Portal.get_by_tgid( + TelegramID(update.user_id), tg_receiver=self.tgid, peer_type="user" + ) + sender = await pu.Puppet.get_by_tgid(self.tgid if update.out else update.user_id) elif isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage, UpdateEditMessage, UpdateEditChannelMessage)): update = update.message if isinstance(update, MessageEmpty): return update, None, None - portal = po.Portal.get_by_entity(update.peer_id, receiver_id=self.tgid) + portal = await po.Portal.get_by_entity(update.peer_id, tg_receiver=self.tgid) if update.out: - sender = pu.Puppet.get(self.tgid) + sender = await pu.Puppet.get_by_tgid(self.tgid) elif isinstance(update.from_id, PeerUser): - sender = pu.Puppet.get(TelegramID(update.from_id.user_id)) + sender = await pu.Puppet.get_by_tgid(TelegramID(update.from_id.user_id)) else: sender = None else: @@ -435,7 +461,7 @@ class AbstractUser(ABC): @staticmethod async def _try_redact(message: DBMessage) -> None: - portal = po.Portal.get_by_mxid(message.mx_room) + portal = await po.Portal.get_by_mxid(message.mx_room) if not portal: return try: @@ -444,33 +470,33 @@ class AbstractUser(ABC): pass async def delete_message(self, update: UpdateDeleteMessages) -> None: - if len(update.messages) > MAX_DELETIONS: + if len(update.messages) > self.max_deletions: return for message_id in update.messages: - for message in DBMessage.get_all_by_tgid(TelegramID(message_id), self.tgid): + for message in await DBMessage.get_all_by_tgid(TelegramID(message_id), self.tgid): if message.redacted: continue - message.delete() - number_left = DBMessage.count_spaces_by_mxid(message.mxid, message.mx_room) + await message.delete() + number_left = await DBMessage.count_spaces_by_mxid(message.mxid, message.mx_room) if number_left == 0: await self._try_redact(message) async def delete_channel_message(self, update: UpdateDeleteChannelMessages) -> None: - if len(update.messages) > MAX_DELETIONS: + if len(update.messages) > self.max_deletions: return channel_id = TelegramID(update.channel_id) for message_id in update.messages: - for message in DBMessage.get_all_by_tgid(TelegramID(message_id), channel_id): + for message in await DBMessage.get_all_by_tgid(TelegramID(message_id), channel_id): if message.redacted: continue - message.delete() + await message.delete() await self._try_redact(message) async def update_message(self, original_update: UpdateMessage) -> None: - update, sender, portal = self.get_message_details(original_update) + update, sender, portal = await self.get_message_details(original_update) if not portal: return elif portal and not portal.allow_bridging: @@ -479,10 +505,10 @@ class AbstractUser(ABC): if self.is_relaybot: if update.is_private: - if not config["bridge.relaybot.private_chat.invite"]: + if not self.config["bridge.relaybot.private_chat.invite"]: self.log.debug(f"Ignoring private message to bot from {sender.id}") return - elif not portal.mxid and config["bridge.relaybot.ignore_unbridged_group_chat"]: + elif not portal.mxid and self.config["bridge.relaybot.ignore_unbridged_group_chat"]: self.log.debug("Ignoring message received by bot" f" in unbridged chat {portal.tgid_log}") return @@ -492,7 +518,7 @@ class AbstractUser(ABC): self.log.debug(f"Ignoring relaybot-sent message %s to %s", update.id, portal.tgid_log) return - await portal.backfill_lock.wait(update.id) + await portal.backfill_lock.wait(f"update {update.id}") if isinstance(update, MessageService): if isinstance(update.action, MessageActionChannelMigrateFrom): @@ -510,12 +536,3 @@ class AbstractUser(ABC): return await portal.handle_telegram_message(self, sender, update) # endregion - - -def init(context: 'Context') -> None: - global config, MAX_DELETIONS - AbstractUser.az, config, AbstractUser.loop, AbstractUser.relaybot = context.core - AbstractUser.bridge = context.bridge - AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"] - AbstractUser.session_container = context.session_container - MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10) diff --git a/mautrix_telegram/bot.py b/mautrix_telegram/bot.py index 9d3ffb79..ea2e67a7 100644 --- a/mautrix_telegram/bot.py +++ b/mautrix_telegram/bot.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -33,10 +33,6 @@ from .db import BotChat from .types import TelegramID from . import puppet as pu, portal as po, user as u -if TYPE_CHECKING: - from .config import Config - -config: Optional['Config'] = None ReplyFunc = Callable[[str], Awaitable[Message]] @@ -59,12 +55,12 @@ class Bot(AbstractUser): self.puppet_whitelisted = True self.whitelisted = True self.relaybot_whitelisted = True - self.username = None + self.tg_username = None self.is_relaybot = True self.is_bot = True self.chats = {} self.tg_whitelist = [] - self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"] + self.whitelist_group_admins = (self.config["bridge.relaybot.whitelist_group_admins"] or False) self._me_info = None self._me_mxid = None @@ -76,7 +72,7 @@ class Bot(AbstractUser): return self._me_info, self._me_mxid async def init_permissions(self) -> None: - whitelist = config["bridge.relaybot.whitelist"] or [] + whitelist = self.config["bridge.relaybot.whitelist"] or [] for user_id in whitelist: if isinstance(user_id, str): entity = await self.client.get_input_entity(user_id) @@ -88,7 +84,7 @@ class Bot(AbstractUser): self.tg_whitelist.append(user_id) async def start(self, delete_unless_authenticated: bool = False) -> 'Bot': - self.chats = {chat.id: chat.type for chat in BotChat.all()} + self.chats = {chat.id: chat.type for chat in await BotChat.all()} await super().start(delete_unless_authenticated) if not await self.is_logged_in(): await self.client.sign_in(bot_token=self.token) @@ -99,14 +95,14 @@ class Bot(AbstractUser): await self.init_permissions() info = await self.client.get_me() self.tgid = TelegramID(info.id) - self.username = info.username + self.tg_username = info.username self.mxid = pu.Puppet.get_mxid_from_id(self.tgid) chat_ids = [chat_id for chat_id, chat_type in self.chats.items() if chat_type == "chat"] response = await self.client(GetChatsRequest(chat_ids)) for chat in response.chats: if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated: - self.remove_chat(TelegramID(chat.id)) + await self.remove_chat(TelegramID(chat.id)) channel_ids = [InputChannel(chat_id, 0) for chat_id, chat_type in self.chats.items() @@ -115,31 +111,31 @@ class Bot(AbstractUser): try: await self.client(GetChannelsRequest([channel_id])) except (ChannelPrivateError, ChannelInvalidError): - self.remove_chat(TelegramID(channel_id.channel_id)) + await self.remove_chat(TelegramID(channel_id.channel_id)) async def register_portal(self, portal: po.Portal) -> None: - self.add_chat(portal.tgid, portal.peer_type) + await self.add_chat(portal.tgid, portal.peer_type) - async def unregister_portal(self, tgid: int, tg_receiver: int) -> None: - self.remove_chat(tgid) + async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None: + await self.remove_chat(tgid) - def add_chat(self, chat_id: TelegramID, chat_type: str) -> None: + async def add_chat(self, chat_id: TelegramID, chat_type: str) -> None: if chat_id not in self.chats: self.chats[chat_id] = chat_type - BotChat(id=TelegramID(chat_id), type=chat_type).insert() + await BotChat(id=chat_id, type=chat_type).insert() - def remove_chat(self, chat_id: TelegramID) -> None: + async def remove_chat(self, chat_id: TelegramID) -> None: try: del self.chats[chat_id] except KeyError: pass - BotChat.delete_by_id(chat_id) + await BotChat.delete_by_id(chat_id) async def _can_use_commands(self, chat: TypePeer, tgid: TelegramID) -> bool: if tgid in self.tg_whitelist: return True - user = u.User.get_by_tgid(tgid) + user = await u.User.get_by_tgid(tgid) if user and user.is_admin: self.tg_whitelist.append(user.tgid) return True @@ -157,13 +153,14 @@ class Bot(AbstractUser): return False async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool: + # FIXME event.from_id is not int if not await self._can_use_commands(event.to_id, TelegramID(event.from_id)): await reply("You do not have the permission to use that command.") return False return True async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc) -> Message: - if not config["bridge.relaybot.authless_portals"]: + if not self.config["bridge.relaybot.authless_portals"]: return await reply("This bridge doesn't allow portal creation from Telegram.") if not portal.allow_bridging: @@ -187,11 +184,11 @@ class Bot(AbstractUser): "Create one with /portal first.") if mxid_input[0] != '@' or mxid_input.find(':') < 2: return await reply("That doesn't look like a Matrix ID.") - user = await u.User.get_by_mxid(mxid_input).ensure_started() + user = await u.User.get_and_start_by_mxid(mxid_input) if not user.relaybot_whitelisted: return await reply("That user is not whitelisted to use the bridge.") elif await user.is_logged_in(): - displayname = f"@{user.username}" if user.username else user.displayname + displayname = f"@{user.tg_username}" if user.tg_username else user.displayname return await reply("That user seems to be logged in. " f"Just invite [{displayname}](tg://user?id={user.tgid})") else: @@ -214,7 +211,7 @@ class Bot(AbstractUser): def match_command(self, text: str, command: str) -> bool: text = text.lower() command = f"/{command.lower()}" - command_targeted = f"{command}@{self.username.lower()}" + command_targeted = f"{command}@{self.tg_username.lower()}" is_plain_command = text == command or text == command_targeted if is_plain_command: @@ -233,7 +230,7 @@ class Bot(AbstractUser): text = message.message if self.match_command(text, "start"): - pcm = config["bridge.relaybot.private_chat.message"] + pcm = self.config["bridge.relaybot.private_chat.message"] if pcm: await reply(pcm) return @@ -243,7 +240,7 @@ class Bot(AbstractUser): elif message.is_private: return - portal = po.Portal.get_by_entity(message.to_id) + portal = await po.Portal.get_by_entity(message.to_id) is_portal_cmd = self.match_command(text, "portal") is_invite_cmd = self.match_command(text, "invite") @@ -259,7 +256,7 @@ class Bot(AbstractUser): mxid = "" await self.handle_command_invite(portal, reply, mxid_input=UserID(mxid)) - def handle_service_message(self, message: MessageService) -> None: + async def handle_service_message(self, message: MessageService) -> None: to_peer = message.to_id if isinstance(to_peer, PeerChannel): to_id = TelegramID(to_peer.channel_id) @@ -272,18 +269,18 @@ class Bot(AbstractUser): action = message.action if isinstance(action, MessageActionChatAddUser) and self.tgid in action.users: - self.add_chat(to_id, chat_type) + await self.add_chat(to_id, chat_type) elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid: - self.remove_chat(to_id) + await self.remove_chat(to_id) elif isinstance(action, MessageActionChatMigrateTo): - self.remove_chat(to_id) - self.add_chat(TelegramID(action.channel_id), "channel") + await self.remove_chat(to_id) + await self.add_chat(TelegramID(action.channel_id), "channel") async def update(self, update) -> bool: if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)): return False if isinstance(update.message, MessageService): - self.handle_service_message(update.message) + await self.handle_service_message(update.message) return False is_command = (isinstance(update.message, Message) @@ -300,12 +297,3 @@ class Bot(AbstractUser): @property def name(self) -> str: return "bot" - - -def init(cfg: 'Config') -> Optional[Bot]: - global config - config = cfg - token = config["telegram.bot_token"] - if token and not token.lower().startswith("disable"): - return Bot(token) - return None diff --git a/mautrix_telegram/commands/handler.py b/mautrix_telegram/commands/handler.py index 958dc60a..0f09a98c 100644 --- a/mautrix_telegram/commands/handler.py +++ b/mautrix_telegram/commands/handler.py @@ -13,8 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -"""This module contains classes handling commands issued by Matrix users.""" -from typing import Awaitable, Callable, List, Optional, NamedTuple, Any +from __future__ import annotations + +from typing import Awaitable, Callable, NamedTuple, Any, TYPE_CHECKING from telethon.errors import FloodWaitError @@ -25,7 +26,10 @@ from mautrix.bridge.commands import (HelpSection, CommandEvent as BaseCommandEve CommandHandlerFunc, command_handler as base_command_handler) from mautrix.util.format_duration import format_duration -from .. import user as u, context as c, portal as po +from .. import user as u, portal as po + +if TYPE_CHECKING: + from ..__main__ import TelegramBridge class HelpCacheKey(NamedTuple): @@ -48,9 +52,9 @@ class CommandEvent(BaseCommandEvent): sender: u.User portal: po.Portal - def __init__(self, processor: 'CommandProcessor', room_id: RoomID, event_id: EventID, - sender: u.User, command: str, args: List[str], content: MessageEventContent, - portal: Optional['po.Portal'], is_management: bool, has_bridge_bot: bool) -> None: + def __init__(self, processor: CommandProcessor, room_id: RoomID, event_id: EventID, + sender: u.User, command: str, args: list[str], content: MessageEventContent, + portal: po.Portal | None, is_management: bool, has_bridge_bot: bool) -> None: super().__init__(processor, room_id, event_id, sender, command, args, content, portal, is_management, has_bridge_bot) self.bridge = processor.bridge @@ -83,7 +87,7 @@ class CommandHandler(BaseCommandHandler): needs_matrix_puppeting=needs_matrix_puppeting, needs_admin=needs_admin, **kwargs) - async def get_permission_error(self, evt: CommandEvent) -> Optional[str]: + async def get_permission_error(self, evt: CommandEvent) -> str | None: if self.needs_puppeting and not evt.sender.puppet_whitelisted: return "This command requires puppeting privileges." elif self.needs_matrix_puppeting and not evt.sender.matrix_puppet_whitelisted: @@ -96,10 +100,10 @@ class CommandHandler(BaseCommandHandler): (not self.needs_matrix_puppeting or key.matrix_puppet_whitelisted)) -def command_handler(_func: Optional[CommandHandlerFunc] = None, *, needs_auth: bool = True, +def command_handler(_func: CommandHandlerFunc | None = None, *, needs_auth: bool = True, needs_puppeting: bool = True, needs_matrix_puppeting: bool = False, needs_admin: bool = False, management_only: bool = False, - name: Optional[str] = None, help_text: str = "", help_args: str = "", + name: str | None = None, help_text: str = "", help_args: str = "", help_section: HelpSection = None) -> Callable[[CommandHandlerFunc], CommandHandler]: return base_command_handler( @@ -110,10 +114,10 @@ def command_handler(_func: Optional[CommandHandlerFunc] = None, *, needs_auth: b class CommandProcessor(BaseCommandProcessor): - def __init__(self, context: c.Context) -> None: - super().__init__(event_class=CommandEvent, bridge=context.bridge) - self.tgbot = context.bot - self.public_website = context.public_website + def __init__(self, bridge: 'TelegramBridge') -> None: + super().__init__(event_class=CommandEvent, bridge=bridge) + self.tgbot = bridge.bot + self.public_website = bridge.public_website @staticmethod async def _run_handler(handler: Callable[[CommandEvent], Awaitable[Any]], evt: CommandEvent diff --git a/mautrix_telegram/commands/matrix_auth.py b/mautrix_telegram/commands/matrix_auth.py index 6fca9f6b..cc85d693 100644 --- a/mautrix_telegram/commands/matrix_auth.py +++ b/mautrix_telegram/commands/matrix_auth.py @@ -24,7 +24,7 @@ from .. import puppet as pu help_section=SECTION_AUTH, help_text="Revert your Telegram account's Matrix " "puppet to use the default Matrix account.") async def logout_matrix(evt: CommandEvent) -> EventID: - puppet = pu.Puppet.get(evt.sender.tgid) + puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid) if not puppet.is_real_user: return await evt.reply("You are not logged in with your Matrix account.") await puppet.switch_mxid(None, None) @@ -36,7 +36,7 @@ async def logout_matrix(evt: CommandEvent) -> EventID: help_text="Replace your Telegram account's Matrix puppet with your own Matrix " "account.") async def login_matrix(evt: CommandEvent) -> EventID: - puppet = pu.Puppet.get(evt.sender.tgid) + puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid) if puppet.is_real_user: return await evt.reply("You have already logged in with your Matrix account. " "Log out with `$cmdprefix+sp logout-matrix` first.") @@ -71,7 +71,7 @@ async def login_matrix(evt: CommandEvent) -> EventID: help_section=SECTION_AUTH, help_text="Pings the server with the stored matrix authentication.") async def ping_matrix(evt: CommandEvent) -> EventID: - puppet = pu.Puppet.get(evt.sender.tgid) + puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid) if not puppet.is_real_user: return await evt.reply("You are not logged in with your Matrix account.") try: @@ -84,7 +84,7 @@ async def ping_matrix(evt: CommandEvent) -> EventID: @command_handler(needs_auth=True, needs_matrix_puppeting=True, help_section=SECTION_AUTH, help_text="Clear the Matrix sync token stored for your custom puppet.") async def clear_cache_matrix(evt: CommandEvent) -> EventID: - puppet = pu.Puppet.get(evt.sender.tgid) + puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid) if not puppet.is_real_user: return await evt.reply("You are not logged in with your Matrix account.") try: @@ -99,7 +99,7 @@ async def clear_cache_matrix(evt: CommandEvent) -> EventID: async def enter_matrix_token(evt: CommandEvent) -> EventID: evt.sender.command_status = None - puppet = pu.Puppet.get(evt.sender.tgid) + puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid) if puppet.is_real_user: return await evt.reply("You have already logged in with your Matrix account. " "Log out with `$cmdprefix+sp logout-matrix` first.") diff --git a/mautrix_telegram/commands/portal/admin.py b/mautrix_telegram/commands/portal/admin.py index 1020f101..2bff444f 100644 --- a/mautrix_telegram/commands/portal/admin.py +++ b/mautrix_telegram/commands/portal/admin.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -35,12 +35,13 @@ async def clear_db_cache(evt: CommandEvent) -> EventID: po.Portal.by_mxid = {} await evt.reply("Cleared portal cache") elif section == "puppet": - pu.Puppet.cache = {} + pu.Puppet.by_tgid = {} for puppet in pu.Puppet.by_custom_mxid.values(): - puppet.sync_task.cancel() + puppet.stop() pu.Puppet.by_custom_mxid = {} - await asyncio.gather(*[puppet.try_start() for puppet in pu.Puppet.all_with_custom_mxid()], - loop=evt.loop) + await asyncio.gather( + *[puppet.try_start() async for puppet in pu.Puppet.all_with_custom_mxid()] + ) await evt.reply("Cleared puppet cache and restarted custom puppet syncers") elif section == "user": u.User.by_mxid = { @@ -61,15 +62,16 @@ async def reload_user(evt: CommandEvent) -> EventID: mxid = evt.args[0] else: mxid = evt.sender.mxid - user = u.User.get_by_mxid(mxid, create=False) + user = await u.User.get_by_mxid(mxid, create=False) if not user: return await evt.reply("User not found") puppet = await pu.Puppet.get_by_custom_mxid(mxid) if puppet: - puppet.sync_task.cancel() + puppet.stop() await user.stop() - user.delete(delete_db=False) - user = u.User.get_by_mxid(mxid) + del u.User.by_tgid[user.tgid] + del u.User.by_mxid[user.mxid] + user = await u.User.get_by_mxid(mxid) await user.ensure_started() if puppet: await puppet.start() diff --git a/mautrix_telegram/commands/portal/bridge.py b/mautrix_telegram/commands/portal/bridge.py index 33f1a779..6b67ff22 100644 --- a/mautrix_telegram/commands/portal/bridge.py +++ b/mautrix_telegram/commands/portal/bridge.py @@ -13,7 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Tuple, Awaitable +from __future__ import annotations + +from typing import Awaitable import asyncio from telethon.tl.types import ChatForbidden, ChannelForbidden @@ -43,7 +45,7 @@ async def bridge(evt: CommandEvent) -> EventID: room_id = RoomID(evt.args[1]) if len(evt.args) > 1 else evt.room_id that_this = "This" if room_id == evt.room_id else "That" - portal = po.Portal.get_by_mxid(room_id) + portal = await po.Portal.get_by_mxid(room_id) if portal: return await evt.reply(f"{that_this} room is already a portal room.") @@ -64,7 +66,7 @@ async def bridge(evt: CommandEvent) -> EventID: "prefix channel IDs with `-100` and normal group IDs with `-`.\n\n" "Bridging private chats to existing rooms is not allowed.") - portal = po.Portal.get_by_tgid(tgid, peer_type=peer_type) + portal = await po.Portal.get_by_tgid(tgid, peer_type=peer_type) if not portal.allow_bridging: return await evt.reply("This bridge doesn't allow bridging that Telegram chat.\n" "If you're the bridge admin, try " @@ -105,7 +107,7 @@ async def bridge(evt: CommandEvent) -> EventID: async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal" - ) -> Tuple[bool, Optional[Awaitable[None]]]: + ) -> tuple[bool, Awaitable[None] | None]: if not portal.mxid: await evt.reply("The portal seems to have lost its Matrix room between you" "calling `$cmdprefix+sp bridge` and this command.\n\n" @@ -126,10 +128,10 @@ async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Porta return False, None -async def confirm_bridge(evt: CommandEvent) -> Optional[EventID]: +async def confirm_bridge(evt: CommandEvent) -> EventID | None: status = evt.sender.command_status try: - portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"]) + portal = await po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"]) bridge_to_mxid = status["bridge_to_mxid"] except KeyError: evt.sender.command_status = None @@ -162,7 +164,7 @@ async def confirm_bridge(evt: CommandEvent) -> Optional[EventID]: async def _locked_confirm_bridge(evt: CommandEvent, portal: 'po.Portal', room_id: RoomID, - is_logged_in: bool) -> Optional[EventID]: + is_logged_in: bool) -> EventID | None: user = evt.sender if is_logged_in else evt.tgbot try: entity = await user.client.get_entity(portal.peer) diff --git a/mautrix_telegram/commands/portal/config.py b/mautrix_telegram/commands/portal/config.py index 1e8e649e..5a867c7f 100644 --- a/mautrix_telegram/commands/portal/config.py +++ b/mautrix_telegram/commands/portal/config.py @@ -37,7 +37,7 @@ async def config(evt: CommandEvent) -> None: await config_defaults(evt) return - portal = po.Portal.get_by_mxid(evt.room_id) + portal = await po.Portal.get_by_mxid(evt.room_id) if not portal: await evt.reply("This is not a portal room.") return diff --git a/mautrix_telegram/commands/portal/create_chat.py b/mautrix_telegram/commands/portal/create_chat.py index fb75e15f..7c304a68 100644 --- a/mautrix_telegram/commands/portal/create_chat.py +++ b/mautrix_telegram/commands/portal/create_chat.py @@ -32,7 +32,7 @@ async def create(evt: CommandEvent) -> EventID: return await evt.reply( "**Usage:** `$cmdprefix+sp create ['group'/'supergroup'/'channel']`") - if po.Portal.get_by_mxid(evt.room_id): + if await po.Portal.get_by_mxid(evt.room_id): return await evt.reply("This is already a portal room.") if not await user_has_power_level(evt.room_id, evt.az.intent, evt.sender, "bridge"): @@ -50,8 +50,8 @@ async def create(evt: CommandEvent) -> EventID: "group": "chat", }[type] - portal = po.Portal(tgid=TelegramID(0), peer_type=type, mxid=evt.room_id, - title=title, about=about, encrypted=encrypted) + portal = po.Portal(tgid=TelegramID(0), tg_receiver=TelegramID(0), peer_type=type, + mxid=evt.room_id, title=title, about=about, encrypted=encrypted) invites, errors = await portal.get_telegram_users_in_matrix_room(evt.sender) if len(errors) > 0: error_list = "\n".join(f"* [{mxid}](https://matrix.to/#/{mxid})" for mxid in errors) diff --git a/mautrix_telegram/commands/portal/misc.py b/mautrix_telegram/commands/portal/misc.py index b9720684..cd4dddba 100644 --- a/mautrix_telegram/commands/portal/misc.py +++ b/mautrix_telegram/commands/portal/misc.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,7 +13,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, List, Tuple +from __future__ import annotations + from datetime import timedelta, datetime import re @@ -33,7 +34,7 @@ from .util import user_has_power_level help_section=SECTION_MISC, help_text="Fetch Matrix room state to ensure the bridge has up-to-date info.") async def sync_state(evt: CommandEvent) -> EventID: - portal = po.Portal.get_by_mxid(evt.room_id) + portal = await po.Portal.get_by_mxid(evt.room_id) if not portal: return await evt.reply("This is not a portal room.") elif not await user_has_power_level(evt.room_id, evt.az.intent, evt.sender, "bridge"): @@ -46,7 +47,7 @@ async def sync_state(evt: CommandEvent) -> EventID: @command_handler(needs_admin=False, needs_puppeting=False, needs_auth=False, help_section=SECTION_MISC) async def sync_full(evt: CommandEvent) -> EventID: - portal = po.Portal.get_by_mxid(evt.room_id) + portal = await po.Portal.get_by_mxid(evt.room_id) if not portal: return await evt.reply("This is not a portal room.") @@ -73,7 +74,7 @@ async def sync_full(evt: CommandEvent) -> EventID: help_section=SECTION_MISC, help_text="Get the ID of the Telegram chat where this room is bridged.") async def get_id(evt: CommandEvent) -> EventID: - portal = po.Portal.get_by_mxid(evt.room_id) + portal = await po.Portal.get_by_mxid(evt.room_id) if not portal: return await evt.reply("This is not a portal room.") tgid = portal.tgid @@ -92,7 +93,7 @@ invite_link_usage = ("**Usage:** `$cmdprefix+sp invite-link [--uses=] [- " A number suffixed with d(ay), h(our), m(inute) or s(econd)") -def _parse_flag(args: List[str]) -> Tuple[str, str]: +def _parse_flag(args: list[str]) -> tuple[str, str]: arg = args.pop(0).lower() if arg.startswith("--"): value_start = arg.index("=") @@ -116,7 +117,7 @@ def _parse_flag(args: List[str]) -> Tuple[str, str]: delta_regex = re.compile("([0-9]+)(w(?:eek)?|d(?:ay)?|h(?:our)?|m(?:in(?:ute)?)?|s(?:ec(?:ond)?)?)") -def _parse_delta(value: str) -> Optional[timedelta]: +def _parse_delta(value: str) -> timedelta | None: match = delta_regex.fullmatch(value) if not match: return None @@ -159,7 +160,7 @@ async def invite_link(evt: CommandEvent) -> EventID: await evt.reply("Invalid format for expiry time delta") expire = datetime.now() + expire_delta - portal = po.Portal.get_by_mxid(evt.room_id) + portal = await po.Portal.get_by_mxid(evt.room_id) if not portal: return await evt.reply("This is not a portal room.") @@ -178,7 +179,7 @@ async def invite_link(evt: CommandEvent) -> EventID: @command_handler(help_section=SECTION_PORTAL_MANAGEMENT, help_text="Upgrade a normal Telegram group to a supergroup.") async def upgrade(evt: CommandEvent) -> EventID: - portal = po.Portal.get_by_mxid(evt.room_id) + portal = await po.Portal.get_by_mxid(evt.room_id) if not portal: return await evt.reply("This is not a portal room.") elif portal.peer_type == "channel": @@ -203,7 +204,7 @@ async def group_name(evt: CommandEvent) -> EventID: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp group-name `") - portal = po.Portal.get_by_mxid(evt.room_id) + portal = await po.Portal.get_by_mxid(evt.room_id) if not portal: return await evt.reply("This is not a portal room.") elif portal.peer_type != "channel": diff --git a/mautrix_telegram/commands/portal/unbridge.py b/mautrix_telegram/commands/portal/unbridge.py index 8a0a08a6..10f67995 100644 --- a/mautrix_telegram/commands/portal/unbridge.py +++ b/mautrix_telegram/commands/portal/unbridge.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,7 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Callable, Optional +from __future__ import annotations + +from typing import Callable from mautrix.types import RoomID, EventID @@ -22,10 +24,10 @@ from .. import command_handler, CommandEvent, SECTION_PORTAL_MANAGEMENT from .util import user_has_power_level -async def _get_portal_and_check_permission(evt: CommandEvent) -> Optional[po.Portal]: +async def _get_portal_and_check_permission(evt: CommandEvent) -> po.Portal | None: room_id = RoomID(evt.args[0]) if len(evt.args) > 0 else evt.room_id - portal = po.Portal.get_by_mxid(room_id) + portal = await po.Portal.get_by_mxid(room_id) if not portal: that_this = "This" if room_id == evt.room_id else "That" await evt.reply(f"{that_this} is not a portal room.") @@ -44,8 +46,8 @@ async def _get_portal_and_check_permission(evt: CommandEvent) -> Optional[po.Por def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str, - completed_message: str) -> Dict: - async def post_confirm(confirm) -> Optional[EventID]: + completed_message: str) -> dict: + async def post_confirm(confirm) -> EventID | None: confirm.sender.command_status = None if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}": await function() @@ -66,7 +68,7 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c help_text="Remove all users from the current portal room and forget the portal. " "Only works for group chats; to delete a private chat portal, simply " "leave the room.") -async def delete_portal(evt: CommandEvent) -> Optional[EventID]: +async def delete_portal(evt: CommandEvent) -> EventID | None: portal = await _get_portal_and_check_permission(evt) if not portal: return None @@ -87,7 +89,7 @@ async def delete_portal(evt: CommandEvent) -> Optional[EventID]: @command_handler(needs_auth=False, needs_puppeting=False, help_section=SECTION_PORTAL_MANAGEMENT, help_text="Remove puppets from the current portal room and forget the portal.") -async def unbridge(evt: CommandEvent) -> Optional[EventID]: +async def unbridge(evt: CommandEvent) -> EventID | None: portal = await _get_portal_and_check_permission(evt) if not portal: return None diff --git a/mautrix_telegram/commands/portal/util.py b/mautrix_telegram/commands/portal/util.py index c3fbe71f..1ff5488a 100644 --- a/mautrix_telegram/commands/portal/util.py +++ b/mautrix_telegram/commands/portal/util.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Tuple, Optional +from __future__ import annotations from mautrix.errors import MatrixRequestError from mautrix.appservice import IntentAPI @@ -22,16 +22,14 @@ from .. import CommandEvent from ... import user as u -OptStr = Optional[str] - async def get_initial_state( intent: IntentAPI, room_id: RoomID -) -> Tuple[OptStr, OptStr, Optional[PowerLevelStateEventContent], bool]: +) -> tuple[str | None, str | None, PowerLevelStateEventContent | None, bool]: state = await intent.get_state(room_id) - title: OptStr = None - about: OptStr = None - levels: Optional[PowerLevelStateEventContent] = None + title: str | None = None + about: str | None = None + levels: PowerLevelStateEventContent | None = None encrypted: bool = False for event in state: try: diff --git a/mautrix_telegram/commands/telegram/account.py b/mautrix_telegram/commands/telegram/account.py index 52460e99..2f85b666 100644 --- a/mautrix_telegram/commands/telegram/account.py +++ b/mautrix_telegram/commands/telegram/account.py @@ -13,10 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional - from telethon.errors import (UsernameInvalidError, UsernameNotModifiedError, UsernameOccupiedError, - HashInvalidError, AuthKeyError, FirstNameInvalidError, AboutTooLongError) + HashInvalidError, AuthKeyError, FirstNameInvalidError, + AboutTooLongError) from telethon.tl.types import Authorization from telethon.tl.functions.account import (UpdateUsernameRequest, GetAuthorizationsRequest, ResetAuthorizationRequest, UpdateProfileRequest) @@ -48,10 +47,11 @@ async def username(evt: CommandEvent) -> EventID: except UsernameOccupiedError: return await evt.reply("That username is already in use.") await evt.sender.update_info() - if not evt.sender.username: + if not evt.sender.tg_username: await evt.reply("Username removed") else: - await evt.reply(f"Username changed to {evt.sender.username}") + await evt.reply(f"Username changed to {evt.sender.tg_username}") + @command_handler(needs_auth=True, help_section=SECTION_AUTH, @@ -71,6 +71,7 @@ async def about(evt: CommandEvent) -> EventID: return await evt.reply("The provided about section is too long") return await evt.reply("About section updated") + @command_handler(needs_auth=True, help_section=SECTION_AUTH, help_args="<_new displayname_>", help_text="Change your Telegram displayname.") async def displayname(evt: CommandEvent) -> EventID: diff --git a/mautrix_telegram/commands/telegram/auth.py b/mautrix_telegram/commands/telegram/auth.py index 7b8e62d5..c3606629 100644 --- a/mautrix_telegram/commands/telegram/auth.py +++ b/mautrix_telegram/commands/telegram/auth.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,11 +13,13 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Any, Dict, Optional +from __future__ import annotations + +from typing import Any import asyncio import io -from telethon.errors import ( # isort: skip +from telethon.errors import ( AccessTokenExpiredError, AccessTokenInvalidError, FirstNameInvalidError, FloodWaitError, PasswordHashInvalidError, PhoneCodeExpiredError, PhoneCodeInvalidError, PhoneNumberAppSignupForbiddenError, PhoneNumberBannedError, PhoneNumberFloodError, @@ -125,7 +127,7 @@ async def enter_code_register(evt: CommandEvent) -> EventID: async def login_qr(evt: CommandEvent) -> EventID: login_as = evt.sender if len(evt.args) > 0 and evt.sender.is_admin: - login_as = u.User.get_by_mxid(UserID(evt.args[0])) + login_as = await u.User.get_by_mxid(UserID(evt.args[0])) if not qrcode or not QRLogin: return await evt.reply("This bridge instance does not support logging in with a QR code.") if await login_as.is_logged_in(): @@ -133,7 +135,7 @@ async def login_qr(evt: CommandEvent) -> EventID: await login_as.ensure_started(even_if_no_session=True) qr_login = QRLogin(login_as.client, ignored_ids=[]) - qr_event_id: Optional[EventID] = None + qr_event_id: EventID | None = None async def upload_qr() -> None: nonlocal qr_event_id @@ -184,7 +186,7 @@ async def login_qr(evt: CommandEvent) -> EventID: async def login(evt: CommandEvent) -> EventID: override_sender = False if len(evt.args) > 0 and evt.sender.is_admin: - evt.sender = await u.User.get_by_mxid(UserID(evt.args[0])).ensure_started() + evt.sender = await u.User.get_and_start_by_mxid(UserID(evt.args[0])) override_sender = True if await evt.sender.is_logged_in(): return await evt.reply(f"You are already logged in as {evt.sender.human_tg_id}.") @@ -217,7 +219,7 @@ async def login(evt: CommandEvent) -> EventID: return await evt.reply("This bridge instance has been configured to not allow logging in.") -async def _request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, Any] +async def _request_code(evt: CommandEvent, phone_number: str, next_status: dict[str, Any] ) -> EventID: ok = False try: @@ -249,7 +251,7 @@ async def _request_code(evt: CommandEvent, phone_number: str, next_status: Dict[ @command_handler(needs_auth=False) -async def enter_phone_or_token(evt: CommandEvent) -> Optional[EventID]: +async def enter_phone_or_token(evt: CommandEvent) -> EventID | None: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token `") elif not evt.config.get("bridge.allow_matrix_login", True): @@ -273,7 +275,7 @@ async def enter_phone_or_token(evt: CommandEvent) -> Optional[EventID]: @command_handler(needs_auth=False) -async def enter_code(evt: CommandEvent) -> Optional[EventID]: +async def enter_code(evt: CommandEvent) -> EventID | None: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp enter-code `") elif not evt.config.get("bridge.allow_matrix_login", True): @@ -289,7 +291,7 @@ async def enter_code(evt: CommandEvent) -> Optional[EventID]: @command_handler(needs_auth=False) -async def enter_password(evt: CommandEvent) -> Optional[EventID]: +async def enter_password(evt: CommandEvent) -> EventID | None: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp enter-password `") elif not evt.config.get("bridge.allow_matrix_login", True): @@ -309,7 +311,7 @@ async def enter_password(evt: CommandEvent) -> Optional[EventID]: return None -async def _sign_in(evt: CommandEvent, login_as: 'u.User' = None, **sign_in_info) -> EventID: +async def _sign_in(evt: CommandEvent, login_as: u.User = None, **sign_in_info) -> EventID: login_as = login_as or evt.sender try: await login_as.ensure_started(even_if_no_session=True) @@ -330,9 +332,9 @@ async def _sign_in(evt: CommandEvent, login_as: 'u.User' = None, **sign_in_info) "Please send your password here.") -async def _finish_sign_in(evt: CommandEvent, user: User, login_as: 'u.User' = None) -> EventID: +async def _finish_sign_in(evt: CommandEvent, user: User, login_as: u.User = None) -> EventID: login_as = login_as or evt.sender - existing_user = u.User.get_by_tgid(TelegramID(user.id)) + existing_user = await u.User.get_by_tgid(TelegramID(user.id)) if existing_user and existing_user != login_as: await existing_user.log_out() await evt.reply(f"[{existing_user.displayname}]" diff --git a/mautrix_telegram/commands/telegram/misc.py b/mautrix_telegram/commands/telegram/misc.py index 4d615587..03944fd4 100644 --- a/mautrix_telegram/commands/telegram/misc.py +++ b/mautrix_telegram/commands/telegram/misc.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2020 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,7 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Optional, Tuple, cast +from __future__ import annotations + +from typing import cast import codecs import base64 import re @@ -81,7 +83,7 @@ async def search(evt: CommandEvent) -> EventID: "Minimum length of remote query is 5 characters.") return await evt.reply("No results 3:") - reply: List[str] = [] + reply: list[str] = [] if remote: reply += ["**Results from Telegram server:**", ""] else: @@ -114,14 +116,14 @@ async def pm(evt: CommandEvent) -> EventID: return await evt.reply("User not found.") elif not isinstance(user, TLUser): return await evt.reply("That doesn't seem to be a user.") - portal = po.Portal.get_by_entity(user, evt.sender.tgid) + portal = await po.Portal.get_by_entity(user, tg_receiver=evt.sender.tgid) await portal.create_matrix_room(evt.sender, user, [evt.sender.mxid]) displayname, _ = pu.Puppet.get_displayname(user, False) return await evt.reply(f"Created private chat room with {displayname}") async def _join(evt: CommandEvent, identifier: str, link_type: str - ) -> Tuple[Optional[TypeUpdates], Optional[EventID]]: + ) -> tuple[TypeUpdates | None, EventID | None]: if link_type == "joinchat": try: await evt.sender.client(CheckChatInviteRequest(identifier)) @@ -143,7 +145,7 @@ async def _join(evt: CommandEvent, identifier: str, link_type: str @command_handler(help_section=SECTION_CREATING_PORTALS, help_args="<_link_>", help_text="Join a chat with an invite link.") -async def join(evt: CommandEvent) -> Optional[EventID]: +async def join(evt: CommandEvent) -> EventID | None: if len(evt.args) == 0: return await evt.reply("**Usage:** `$cmdprefix+sp join `") @@ -171,7 +173,7 @@ async def join(evt: CommandEvent) -> Optional[EventID]: return None for chat in updates.chats: - portal = po.Portal.get_by_entity(chat) + portal = await po.Portal.get_by_entity(chat) if portal.mxid: await portal.invite_to_matrix([evt.sender.mxid]) return await evt.reply(f"Invited you to portal of {portal.title}") @@ -219,7 +221,7 @@ class MessageIDError(ValueError): async def _parse_encoded_msgid(user: AbstractUser, enc_id: str, type_name: str - ) -> Tuple[TypeInputPeer, Message]: + ) -> tuple[TypeInputPeer, Message]: try: enc_id += (4 - len(enc_id) % 4) * "=" enc_id = base64.b64decode(enc_id) @@ -233,10 +235,10 @@ async def _parse_encoded_msgid(user: AbstractUser, enc_id: str, type_name: str raise MessageIDError(f"Invalid {type_name} ID (format)") from e if peer_type == PEER_TYPE_CHAT: - orig_msg = DBMessage.get_one_by_tgid(msg_id, space) + orig_msg = await DBMessage.get_one_by_tgid(msg_id, space) if not orig_msg: raise MessageIDError(f"Invalid {type_name} ID (original message not found in db)") - new_msg = DBMessage.get_by_mxid(orig_msg.mxid, orig_msg.mx_room, user.tgid) + new_msg = await DBMessage.get_by_mxid(orig_msg.mxid, orig_msg.mx_room, user.tgid) if not new_msg: raise MessageIDError(f"Invalid {type_name} ID (your copy of message not found in db)") msg_id = new_msg.tgid @@ -282,7 +284,7 @@ async def play(evt: CommandEvent) -> EventID: @command_handler(help_section=SECTION_MISC, help_args="<_poll ID_> <_choice number_>", help_text="Vote in a Telegram poll.") -async def vote(evt: CommandEvent) -> EventID: +async def vote(evt: CommandEvent) -> EventID | None: if len(evt.args) < 1: return await evt.reply("**Usage:** `$cmdprefix+sp vote `") elif not await evt.sender.is_logged_in(): @@ -319,7 +321,7 @@ async def vote(evt: CommandEvent) -> EventID: options = [msg.media.poll.answers[int(option) - 1].option for option in evt.args[1:]] try: - resp = await evt.sender.client(SendVoteRequest(peer=peer, msg_id=msg.id, options=options)) + await evt.sender.client(SendVoteRequest(peer=peer, msg_id=msg.id, options=options)) except OptionsTooMuchError: return await evt.reply("You passed too many options.") # TODO use response @@ -332,7 +334,7 @@ async def vote(evt: CommandEvent) -> EventID: async def random(evt: CommandEvent) -> EventID: if not evt.is_portal: return await evt.reply("You can only randomize values in portal rooms") - portal = po.Portal.get_by_mxid(evt.room_id) + portal = await po.Portal.get_by_mxid(evt.room_id) arg = evt.args[0] if len(evt.args) > 0 else "dice" emoticon = { "dart": "\U0001F3AF", @@ -359,7 +361,7 @@ async def backfill(evt: CommandEvent) -> None: limit = int(evt.args[0]) except (ValueError, IndexError): limit = -1 - portal = po.Portal.get_by_mxid(evt.room_id) + portal = await po.Portal.get_by_mxid(evt.room_id) if not evt.config["bridge.backfill.normal_groups"] and portal.peer_type == "chat": await evt.reply("Backfilling normal groups is disabled in the bridge config") return diff --git a/mautrix_telegram/config.py b/mautrix_telegram/config.py index 30a9d691..bebf48b6 100644 --- a/mautrix_telegram/config.py +++ b/mautrix_telegram/config.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2020 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -132,6 +132,8 @@ class Config(BaseBridgeConfig): copy("bridge.pinned_tag") copy("bridge.archive_tag") copy("bridge.tag_only_on_create") + copy("bridge.bridge_matrix_leave") + copy("bridge.kick_on_logout") copy("bridge.backfill.invite_own_puppet") copy("bridge.backfill.takeout_limit") copy("bridge.backfill.initial_limit") diff --git a/mautrix_telegram/context.py b/mautrix_telegram/context.py deleted file mode 100644 index 3ed15e4f..00000000 --- a/mautrix_telegram/context.py +++ /dev/null @@ -1,57 +0,0 @@ -# mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from typing import Optional, Tuple, TYPE_CHECKING -import asyncio - -from alchemysession import AlchemySessionContainer - -from mautrix.appservice import AppService - -if TYPE_CHECKING: - from .web import PublicBridgeWebsite, ProvisioningAPI - from .config import Config - from .bot import Bot - from .matrix import MatrixHandler - from .__main__ import TelegramBridge - - -class Context: - az: AppService - config: 'Config' - loop: asyncio.AbstractEventLoop - bridge: 'TelegramBridge' - bot: Optional['Bot'] - mx: Optional['MatrixHandler'] - session_container: AlchemySessionContainer - public_website: Optional['PublicBridgeWebsite'] - provisioning_api: Optional['ProvisioningAPI'] - - def __init__(self, az: AppService, config: 'Config', loop: asyncio.AbstractEventLoop, - session_container: AlchemySessionContainer, bridge: 'TelegramBridge', - bot: Optional['Bot']) -> None: - self.az = az - self.config = config - self.loop = loop - self.bridge = bridge - self.bot = bot - self.mx = None - self.session_container = session_container - self.public_website = None - self.provisioning_api = None - - @property - def core(self) -> Tuple[AppService, 'Config', asyncio.AbstractEventLoop, Optional['Bot']]: - return self.az, self.config, self.loop, self.bot diff --git a/mautrix_telegram/db/__init__.py b/mautrix_telegram/db/__init__.py index 67c9779a..3837d201 100644 --- a/mautrix_telegram/db/__init__.py +++ b/mautrix_telegram/db/__init__.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,19 +13,23 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from sqlalchemy.engine.base import Engine +from mautrix.util.async_db import Database -from mautrix.client.state_store.sqlalchemy import UserProfile, RoomState +from .upgrade import upgrade_table from .bot_chat import BotChat from .message import Message from .portal import Portal from .puppet import Puppet from .telegram_file import TelegramFile -from .user import User, UserPortal, Contact +from .user import User +from .telethon_session import PgSession -def init(db_engine: Engine) -> None: - for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile, - RoomState, BotChat): - table.bind(db_engine) +def init(db: Database) -> None: + for table in (Portal, Message, User, Puppet, TelegramFile, BotChat, PgSession): + table.db = db + + +__all__ = ["upgrade_table", "init", "Portal", "Message", "User", "Puppet", "TelegramFile", + "BotChat", "PgSession"] diff --git a/mautrix_telegram/db/bot_chat.py b/mautrix_telegram/db/bot_chat.py index d7ecdc52..3ea66d6a 100644 --- a/mautrix_telegram/db/bot_chat.py +++ b/mautrix_telegram/db/bot_chat.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,26 +13,43 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Iterable +from __future__ import annotations -from sqlalchemy import Column, BigInteger, String +from typing import ClassVar, TYPE_CHECKING -from mautrix.util.db import Base +from asyncpg import Record +from attr import dataclass + +from mautrix.util.async_db import Database from ..types import TelegramID +fake_db = Database.create("") if TYPE_CHECKING else None + # Fucking Telegram not telling bots what chats they are in 3:< -class BotChat(Base): - __tablename__ = "bot_chat" - id: TelegramID = Column(BigInteger, primary_key=True) - type: str = Column(String, nullable=False) +@dataclass +class BotChat: + db: ClassVar[Database] = fake_db + + id: TelegramID + type: str @classmethod - def delete_by_id(cls, chat_id: TelegramID) -> None: - with cls.db.begin() as conn: - conn.execute(cls.t.delete().where(cls.c.id == chat_id)) + def _from_row(cls, row: Record | None) -> BotChat | None: + if row is None: + return None + return cls(**row) @classmethod - def all(cls) -> Iterable['BotChat']: - return cls._select_all() + async def delete_by_id(cls, chat_id: TelegramID) -> None: + await cls.db.execute("DELETE FROM bot_chat WHERE id=$1", chat_id) + + @classmethod + async def all(cls) -> list[BotChat]: + rows = await cls.db.fetch("SELECT id, type FROM bot_chat") + return [cls._from_row(row) for row in rows] + + async def insert(self) -> None: + q = "INSERT INTO bot_chat (id, type) VALUES ($1, $2)" + await self.db.execute(q, self.id, self.type) diff --git a/mautrix_telegram/db/message.py b/mautrix_telegram/db/message.py index 16c2a25d..fae5caea 100644 --- a/mautrix_telegram/db/message.py +++ b/mautrix_telegram/db/message.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,96 +13,146 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Iterator, List +from __future__ import annotations -from sqlalchemy import (Column, UniqueConstraint, BigInteger, Integer, String, Boolean, and_, func, - desc, select, false) +from typing import ClassVar, TYPE_CHECKING + +from asyncpg import Record +from attr import dataclass from mautrix.types import RoomID, EventID -from mautrix.util.db import Base +from mautrix.util.async_db import Database from ..types import TelegramID +fake_db = Database.create("") if TYPE_CHECKING else None -class Message(Base): - __tablename__ = "message" - mxid: EventID = Column(String) - mx_room: RoomID = Column(String) - tgid: TelegramID = Column(BigInteger, primary_key=True) - tg_space: TelegramID = Column(BigInteger, primary_key=True) - edit_index: int = Column(Integer, primary_key=True) - redacted: bool = Column(Boolean, server_default=false()) +@dataclass +class Message: + db: ClassVar[Database] = fake_db - __table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room_2"),) + mxid: EventID + mx_room: RoomID + tgid: TelegramID + tg_space: TelegramID + edit_index: int + redacted: bool = False @classmethod - def get_all_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Iterator['Message']: - return cls._select_all(cls.c.tgid == tgid, cls.c.tg_space == tg_space) + def _from_row(cls, row: Record | None) -> Message | None: + if row is None: + return None + return cls(**row) + + columns: ClassVar[str] = "mxid, mx_room, tgid, tg_space, edit_index, redacted" @classmethod - def get_one_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID, edit_index: int = 0 - ) -> Optional['Message']: + async def get_all_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> list[Message]: + q = f"SELECT {cls.columns} FROM message WHERE tgid=$1 AND tg_space=$2" + rows = await cls.db.fetch(q, tgid, tg_space) + return [cls._from_row(row) for row in rows] + + @classmethod + async def get_one_by_tgid( + cls, tgid: TelegramID, tg_space: TelegramID, edit_index: int = 0 + ) -> Message | None: if edit_index < 0: - return cls._one_or_none(cls.db.execute( - cls.t.select() - .where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)) - .order_by(desc(cls.c.edit_index)) - .limit(1).offset(-edit_index - 1))) + q = ( + f"SELECT {cls.columns} FROM message WHERE tgid=$1 AND tg_space=$2 " + f"ORDER BY edit_index DESC LIMIT 1 OFFSET {-edit_index - 1}" + ) + row = await cls.db.fetchrow(q, tgid, tg_space) else: - return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_space == tg_space, - cls.c.edit_index == edit_index) + q = ( + f"SELECT {cls.columns} FROM message" + " WHERE tgid=$1 AND tg_space=$2 AND edit_index=$3" + ) + row = await cls.db.fetchrow(q, tgid, tg_space, edit_index) + return cls._from_row(row) @classmethod - def get_first_by_tgids(cls, tgids: List[TelegramID], tg_space: TelegramID - ) -> Iterator['Message']: - return cls._select_all(cls.c.tgid.in_(tgids), cls.c.tg_space == tg_space, - cls.c.edit_index == 0) + async def get_first_by_tgids( + cls, tgids: list[TelegramID], tg_space: TelegramID + ) -> list[Message]: + if cls.db.scheme == "postgres": + q = ( + f"SELECT {cls.columns} FROM message" + " WHERE tgid=ANY($1) AND tg_space=$2 AND edit_index=0" + ) + rows = await cls.db.fetch(q, tgids, tg_space) + else: + tgid_placeholders = ("?," * len(tgids)).rstrip(",") + q = ( + f"SELECT {cls.columns} FROM message " + f"WHERE tg_space=? AND edit_index=0 AND tgid IN ({tgid_placeholders})" + ) + rows = await cls.db.fetch(q, tg_space, *tgids) + return [cls._from_row(row) for row in rows] @classmethod - def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int: - rows = cls.db.execute(select([func.count(cls.c.tg_space)]) - .where(and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room))) - try: - count, = next(rows) - return count - except StopIteration: - return 0 + async def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int: + return await cls.db.fetchval( + "SELECT COUNT(tg_space) FROM message WHERE mxid=$1 AND mx_room=$2", mxid, mx_room + ) or 0 @classmethod - def find_last(cls, mx_room: RoomID, tg_space: TelegramID) -> Optional['Message']: - return cls._one_or_none(cls.db.execute( - cls._make_simple_select(cls.c.mx_room == mx_room, cls.c.tg_space == tg_space) - .order_by(desc(cls.c.tgid)).limit(1))) + async def find_last(cls, mx_room: RoomID, tg_space: TelegramID) -> Message | None: + q = ( + f"SELECT {cls.columns} FROM message WHERE mx_room=$1 AND tg_space=$2 " + f"ORDER BY tgid DESC LIMIT 1" + ) + return cls._from_row(await cls.db.fetchrow(q, mx_room, tg_space)) @classmethod - def delete_all(cls, mx_room: RoomID) -> None: - cls.db.execute(cls.t.delete().where(cls.c.mx_room == mx_room)) + async def delete_all(cls, mx_room: RoomID) -> None: + await cls.db.execute("DELETE FROM message WHERE mx_room=$1", mx_room) @classmethod - def get_by_mxid(cls, mxid: EventID, mx_room: RoomID, tg_space: TelegramID - ) -> Optional['Message']: - return cls._select_one_or_none(cls.c.mxid == mxid, cls.c.mx_room == mx_room, - cls.c.tg_space == tg_space) + async def get_by_mxid( + cls, mxid: EventID, mx_room: RoomID, tg_space: TelegramID + ) -> Message | None: + q = f"SELECT {cls.columns} FROM message WHERE mxid=$1 AND mx_room=$2 AND tg_space=$3" + return cls._from_row(await cls.db.fetchrow(q, mxid, mx_room, tg_space)) @classmethod - def get_by_mxids(cls, mxids: List[EventID], mx_room: RoomID, tg_space: TelegramID - ) -> Iterator['Message']: - return cls._select_all(cls.c.mxid.in_(mxids), cls.c.mx_room == mx_room, - cls.c.tg_space == tg_space) + async def get_by_mxids( + cls, mxids: list[EventID], mx_room: RoomID, tg_space: TelegramID + ) -> list[Message]: + if cls.db.scheme == "postgres": + q = ( + f"SELECT {cls.columns} FROM message" + " WHERE mxid=ANY($1) AND mx_room=$2 AND tg_space=$3" + ) + rows = await cls.db.fetch(q, mxids, mx_room, tg_space) + else: + mxid_placeholders = ("?," * len(mxids)).rstrip(",") + q = ( + f"SELECT {cls.columns} FROM message " + f"WHERE mx_room=? AND tg_space=? AND mxid IN ({mxid_placeholders})" + ) + rows = await cls.db.fetch(q, mx_room, tg_space, *mxids) + return [cls._from_row(row) for row in rows] @classmethod - def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, s_edit_index: int, - **values) -> None: - with cls.db.begin() as conn: - conn.execute(cls.t.update() - .where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space, - cls.c.edit_index == s_edit_index)) - .values(**values)) + async def replace_temp_mxid(cls, temp_mxid: str, mx_room: RoomID, real_mxid: EventID) -> None: + q = "UPDATE message SET mxid=$1 WHERE mxid=$2 AND mx_room=$3" + await cls.db.execute(q, real_mxid, temp_mxid, mx_room) - @classmethod - def update_by_mxid(cls, s_mxid: EventID, s_mx_room: RoomID, **values) -> None: - with cls.db.begin() as conn: - conn.execute(cls.t.update() - .where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room)) - .values(**values)) + async def insert(self) -> None: + q = ( + "INSERT INTO message (mxid, mx_room, tgid, tg_space, edit_index, redacted) " + "VALUES ($1, $2, $3, $4, $5, $6)" + ) + await self.db.execute( + q, self.mxid, self.mx_room, self.tgid, self.tg_space, self.edit_index, self.redacted + ) + + async def delete(self) -> None: + q = "DELETE FROM message WHERE mxid=$1 AND mx_room=$2 AND tg_space=$3" + await self.db.execute(q, self.mxid, self.mx_room, self.tg_space) + + async def mark_redacted(self) -> None: + self.redacted = True + q = "UPDATE message SET redacted=true WHERE mxid=$1 AND mx_room=$2" + await self.db.execute(q, self.mxid, self.mx_room) diff --git a/mautrix_telegram/db/portal.py b/mautrix_telegram/db/portal.py index a543cbe9..b518863c 100644 --- a/mautrix_telegram/db/portal.py +++ b/mautrix_telegram/db/portal.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,54 +13,116 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Iterable +from __future__ import annotations -from sqlalchemy import Column, BigInteger, String, Boolean, Text, func, sql +from typing import ClassVar, Any, TYPE_CHECKING +import json + +from asyncpg import Record +from attr import dataclass +import attr from mautrix.types import RoomID, ContentURI -from mautrix.util.db import Base +from mautrix.util.async_db import Database from ..types import TelegramID +fake_db = Database.create("") if TYPE_CHECKING else None -class Portal(Base): - __tablename__ = "portal" + +@dataclass +class Portal: + db: ClassVar[Database] = fake_db # Telegram chat information - tgid: TelegramID = Column(BigInteger, primary_key=True) - tg_receiver: TelegramID = Column(BigInteger, primary_key=True) - peer_type: str = Column(String, nullable=False) - megagroup: bool = Column(Boolean) + tgid: TelegramID + tg_receiver: TelegramID + peer_type: str + megagroup: bool # Matrix portal information - mxid: Optional[RoomID] = Column(String, unique=True, nullable=True) - avatar_url: Optional[ContentURI] = Column(String, nullable=True) - encrypted: bool = Column(Boolean, nullable=False, server_default=sql.expression.false()) - - config: str = Column(Text, nullable=True) + mxid: RoomID | None + avatar_url: ContentURI | None + encrypted: bool # Telegram chat metadata - username: str = Column(String, nullable=True) - title: str = Column(String, nullable=True) - about: str = Column(String, nullable=True) - photo_id: str = Column(String, nullable=True) + username: str | None + title: str | None + about: str | None + photo_id: str | None + + local_config: dict[str, Any] = attr.ib(factory=lambda: {}) @classmethod - def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']: - return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver) + def _from_row(cls, row: Record | None) -> Portal | None: + if row is None: + return None + data = {**row} + data["local_config"] = json.loads(data.pop("config", None) or "{}") + return cls(**data) + + columns: ClassVar[str] = ( + "tgid, tg_receiver, peer_type, megagroup, mxid, avatar_url, encrypted, config, " + "username, title, about, photo_id" + ) @classmethod - def find_private_chats(cls, tg_receiver: TelegramID) -> Iterable['Portal']: - yield from cls._select_all(cls.c.tg_receiver == tg_receiver, cls.c.peer_type == "user") + async def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Portal | None: + q = f"SELECT {cls.columns} FROM portal WHERE tgid=$1 AND tg_receiver=$2" + return cls._from_row(await cls.db.fetchrow(q, tgid, tg_receiver)) @classmethod - def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']: - return cls._select_one_or_none(cls.c.mxid == mxid) + async def get_by_mxid(cls, mxid: RoomID) -> Portal | None: + q = f"SELECT {cls.columns} FROM portal WHERE mxid=$1" + return cls._from_row(await cls.db.fetchrow(q, mxid)) @classmethod - def get_by_username(cls, username: str) -> Optional['Portal']: - return cls._select_one_or_none(func.lower(cls.c.username) == username) + async def find_by_username(cls, username: str) -> Portal | None: + q = f"SELECT {cls.columns} FROM portal WHERE lower(username)=$1" + return cls._from_row(await cls.db.fetchrow(q, username.lower())) @classmethod - def all(cls) -> Iterable['Portal']: - yield from cls._select_all() + async def find_private_chats(cls, tg_receiver: TelegramID) -> list[Portal]: + q = f"SELECT {cls.columns} FROM portal WHERE tg_receiver=$1 AND peer_type='user'" + return [cls._from_row(row) for row in await cls.db.fetch(q, tg_receiver)] + + @classmethod + async def all(cls) -> list[Portal]: + rows = await cls.db.fetch(f"SELECT {cls.columns} FROM portal") + return [cls._from_row(row) for row in rows] + + @property + def _values(self): + return (self.tgid, self.tg_receiver, self.peer_type, self.mxid, self.avatar_url, + self.encrypted, self.username, self.title, self.about, self.photo_id, + self.megagroup, json.dumps(self.local_config) if self.local_config else None) + + async def save(self) -> None: + q = ( + "UPDATE portal SET mxid=$4, avatar_url=$5, encrypted=$6, username=$7, title=$8," + " about=$9, photo_id=$10, megagroup=$11, config=$12 " + "WHERE tgid=$1 AND tg_receiver=$2 AND (peer_type=$3 OR true)" + ) + await self.db.execute(q, *self._values) + + async def update_id(self, id: TelegramID, peer_type: str) -> None: + q = ( + "UPDATE portal SET tgid=$1, tg_receiver=$1, peer_type=$2 " + "WHERE tgid=$3 AND tg_receiver=$3" + ) + await self.db.execute(q, id, peer_type, self.tgid) + self.tgid = id + self.tg_receiver = id + self.peer_type = peer_type + + async def insert(self) -> None: + q = ( + "INSERT INTO portal (tgid, tg_receiver, peer_type, mxid, avatar_url, encrypted," + " username, title, about, photo_id, megagroup, config) " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + ) + await self.db.execute(q, *self._values) + + async def delete(self) -> None: + q = "DELETE FROM portal WHERE tgid=$1 AND tg_receiver=$2" + await self.db.execute(q, self.tgid, self.tg_receiver) diff --git a/mautrix_telegram/db/puppet.py b/mautrix_telegram/db/puppet.py index f842a14e..3a73ad30 100644 --- a/mautrix_telegram/db/puppet.py +++ b/mautrix_telegram/db/puppet.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,51 +13,106 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Iterable +from __future__ import annotations -from sqlalchemy import Column, Integer, BigInteger, String, Text, Boolean -from sqlalchemy.sql import expression, func +from typing import ClassVar, TYPE_CHECKING + +from asyncpg import Record +from attr import dataclass +from yarl import URL from mautrix.types import UserID, SyncToken -from mautrix.util.db import Base +from mautrix.util.async_db import Database from ..types import TelegramID +fake_db = Database.create("") if TYPE_CHECKING else None -class Puppet(Base): - __tablename__ = "puppet" - id: TelegramID = Column(BigInteger, primary_key=True) - custom_mxid: UserID = Column(String, nullable=True) - access_token: str = Column(String, nullable=True) - next_batch: SyncToken = Column(String, nullable=True) - base_url: str = Column(Text, nullable=True) - displayname: str = Column(String, nullable=True) - displayname_source: TelegramID = Column(BigInteger, nullable=True) - displayname_contact: bool = Column(Boolean, nullable=False, server_default=expression.true()) - displayname_quality: int = Column(Integer, nullable=False, server_default="0") - username: str = Column(String, nullable=True) - photo_id: str = Column(String, nullable=True) - is_bot: bool = Column(Boolean, nullable=True) - matrix_registered: bool = Column(Boolean, nullable=False, server_default=expression.false()) - disable_updates: bool = Column(Boolean, nullable=False, server_default=expression.false()) +@dataclass +class Puppet: + db: ClassVar[Database] = fake_db + + id: TelegramID + + is_registered: bool + + displayname: str | None + displayname_source: TelegramID | None + displayname_contact: bool + displayname_quality: int + disable_updates: bool + username: str | None + photo_id: str | None + is_bot: bool | None + + custom_mxid: UserID | None + access_token: str | None + next_batch: SyncToken | None + base_url: URL | None @classmethod - def all_with_custom_mxid(cls) -> Iterable['Puppet']: - yield from cls._select_all(cls.c.custom_mxid != None) + def _from_row(cls, row: Record | None) -> Puppet | None: + if row is None: + return None + data = {**row} + base_url = data.pop("base_url", None) + return cls(**data, base_url=URL(base_url) if base_url else None) + + columns: ClassVar[str] = ( + "id, is_registered, displayname, displayname_source, displayname_contact, " + "displayname_quality, disable_updates, username, photo_id, is_bot, " + "custom_mxid, access_token, next_batch, base_url" + ) @classmethod - def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']: - return cls._select_one_or_none(cls.c.id == tgid) + async def all_with_custom_mxid(cls) -> list[Puppet]: + q = f"SELECT {cls.columns} FROM puppet WHERE custom_mxid<>''" + return [cls._from_row(row) for row in await cls.db.fetch(q)] @classmethod - def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']: - return cls._select_one_or_none(cls.c.custom_mxid == mxid) + async def get_by_tgid(cls, tgid: TelegramID) -> Puppet | None: + q = f"SELECT {cls.columns} FROM puppet WHERE id=$1" + return cls._from_row(await cls.db.fetchrow(q, tgid)) @classmethod - def get_by_username(cls, username: str) -> Optional['Puppet']: - return cls._select_one_or_none(func.lower(cls.c.username) == username) + async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None: + q = f"SELECT {cls.columns} FROM puppet WHERE custom_mxid=$1" + return cls._from_row(await cls.db.fetchrow(q, mxid)) @classmethod - def get_by_displayname(cls, displayname: str) -> Optional['Puppet']: - return cls._select_one_or_none(cls.c.displayname == displayname) + async def find_by_username(cls, username: str) -> Puppet | None: + q = f"SELECT {cls.columns} FROM puppet WHERE lower(username)=$1" + return cls._from_row(await cls.db.fetchrow(q, username.lower())) + + @classmethod + async def find_by_displayname(cls, displayname: str) -> Puppet | None: + q = f"SELECT {cls.columns} FROM puppet WHERE displayname=$1" + return cls._from_row(await cls.db.fetchrow(q, displayname)) + + @property + def _values(self): + return (self.id, self.is_registered, self.displayname, self.displayname_source, + self.displayname_contact, self.displayname_quality, self.disable_updates, + self.username, self.photo_id, self.is_bot, self.custom_mxid, self.access_token, + self.next_batch, str(self.base_url) if self.base_url else None) + + async def save(self) -> None: + q = ( + "UPDATE puppet " + "SET is_registered=$2, displayname=$3, displayname_source=$4, displayname_contact=$5," + " displayname_quality=$6, disable_updates=$7, username=$8, photo_id=$9, is_bot=$10," + " custom_mxid=$11, access_token=$12, next_batch=$13, base_url=$14 " + "WHERE id=$1" + ) + await self.db.execute(q, *self._values) + + async def insert(self) -> None: + q = ( + "INSERT INTO puppet (" + " id, is_registered, displayname, displayname_source, displayname_contact," + " displayname_quality, disable_updates, username, photo_id, is_bot," + " custom_mxid, access_token, next_batch, base_url" + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)" + ) + await self.db.execute(q, *self._values) diff --git a/mautrix_telegram/db/telegram_file.py b/mautrix_telegram/db/telegram_file.py index 5ce4acc2..18a5fc3d 100644 --- a/mautrix_telegram/db/telegram_file.py +++ b/mautrix_telegram/db/telegram_file.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,69 +13,62 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, cast, Dict, Any, TYPE_CHECKING +from __future__ import annotations -from sqlalchemy import (Column, ForeignKey, Integer, BigInteger, String, Boolean, Text, - TypeDecorator) +from typing import ClassVar, TYPE_CHECKING + +from attr import dataclass from mautrix.types import ContentURI, EncryptedFile -from mautrix.util.db import Base +from mautrix.util.async_db import Database -if TYPE_CHECKING: - from sqlalchemy.engine.result import RowProxy +fake_db = Database.create("") if TYPE_CHECKING else None -class DBEncryptedFile(TypeDecorator): - impl = Text +@dataclass +class TelegramFile: + db: ClassVar[Database] = fake_db - @property - def python_type(self): - return EncryptedFile - - def process_bind_param(self, value: EncryptedFile, dialect) -> Optional[str]: - if value is not None: - return value.json() - return None - - def process_result_value(self, value: str, dialect) -> Optional[EncryptedFile]: - if value is not None: - return EncryptedFile.parse_json(value) - return None - - def process_literal_param(self, value, dialect): - return value - - -class TelegramFile(Base): - __tablename__ = "telegram_file" - - id: str = Column(String, primary_key=True) - mxc: ContentURI = Column(String) - mime_type: str = Column(String) - was_converted: bool = Column(Boolean) - timestamp: int = Column(BigInteger) - size: Optional[int] = Column(Integer, nullable=True) - width: Optional[int] = Column(Integer, nullable=True) - height: Optional[int] = Column(Integer, nullable=True) - decryption_info: Optional[Dict[str, Any]] = Column(DBEncryptedFile, nullable=True) - thumbnail_id: str = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True) - thumbnail: Optional['TelegramFile'] = None + id: str + mxc: ContentURI + mime_type: str + was_converted: bool + timestamp: int + size: int | None + width: int | None + height: int | None + decryption_info: EncryptedFile | None + thumbnail: TelegramFile | None = None @classmethod - def scan(cls, row: 'RowProxy') -> 'TelegramFile': - telegram_file = cast(TelegramFile, super().scan(row)) - if isinstance(telegram_file.thumbnail, str): - telegram_file.thumbnail = cls.get(telegram_file.thumbnail) - return telegram_file + async def get(cls, loc_id: str, *, _thumbnail: bool = False) -> TelegramFile | None: + q = ( + "SELECT id, mxc, mime_type, was_converted, timestamp, size, width, height, thumbnail," + " decryption_info " + "FROM telegram_file WHERE id=$1" + ) + row = await cls.db.fetchrow(q, loc_id) + if row is None: + return None + data = {**row} + thumbnail_id = data.pop("thumbnail", None) + if _thumbnail: + # Don't allow more than one level of recursion + thumbnail_id = None + decryption_info = data.pop("decryption_info", None) + return cls( + **data, + thumbnail=(await cls.get(thumbnail_id, _thumbnail=True)) if thumbnail_id else None, + decryption_info=EncryptedFile.parse_json(decryption_info) if decryption_info else None, + ) - @classmethod - def get(cls, loc_id: str) -> Optional['TelegramFile']: - return cls._select_one_or_none(cls.c.id == loc_id) - - def insert(self) -> None: - with self.db.begin() as conn: - conn.execute(self.t.insert().values( - id=self.id, mxc=self.mxc, mime_type=self.mime_type, - was_converted=self.was_converted, timestamp=self.timestamp, size=self.size, - width=self.width, height=self.height, decryption_info=self.decryption_info, - thumbnail=self.thumbnail.id if self.thumbnail else self.thumbnail_id)) + async def insert(self) -> None: + q = ( + "INSERT INTO telegram_file (id, mxc, mime_type, was_converted, size, width, height, " + " thumbnail, decryption_info) " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + ) + await self.db.execute(q, self.id, self.mxc, self.mime_type, self.was_converted, self.size, + self.width, self.height, + self.thumbnail.id if self.thumbnail else None, + self.decryption_info.json() if self.decryption_info else None) diff --git a/mautrix_telegram/db/telethon_session.py b/mautrix_telegram/db/telethon_session.py new file mode 100644 index 00000000..cb8f5757 --- /dev/null +++ b/mautrix_telegram/db/telethon_session.py @@ -0,0 +1,204 @@ +# mautrix-telegram - A Matrix-Telegram puppeting bridge +# Copyright (C) 2021 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from typing import ClassVar, TYPE_CHECKING +import datetime +import asyncio + +from telethon.sessions import MemorySession +from telethon.tl.types import updates, PeerUser, PeerChat, PeerChannel +from telethon.crypto import AuthKey +from telethon import utils + +from mautrix.util.async_db import Database + +fake_db = Database.create("") if TYPE_CHECKING else None + + +class PgSession(MemorySession): + db: ClassVar[Database] = fake_db + + session_id: str + _dc_id: int + _server_address: str | None + _port: int | None + _auth_key: AuthKey | None + _takeout_id: int | None + _process_entities_lock: asyncio.Lock + + def __init__( + self, + session_id: str, + dc_id: int = 0, + server_address: str | None = None, + port: int | None = None, + auth_key: AuthKey | None = None, + takeout_id: int | None = None, + ) -> None: + super().__init__() + self.session_id = session_id + self._dc_id = dc_id + self._server_address = server_address + self._port = port + self._auth_key = auth_key + self._takeout_id = takeout_id + self._process_entities_lock = asyncio.Lock() + + def clone(self, to_instance=None) -> MemorySession: + # We don't want to store data of clones + # (which are used for temporarily connecting to different DCs) + return super().clone(MemorySession()) + + @property + def auth_key_bytes(self) -> bytes | None: + return self._auth_key.key if self._auth_key else None + + @classmethod + async def get(cls, session_id: str) -> PgSession: + q = ( + "SELECT session_id, dc_id, server_address, port, auth_key FROM telethon_sessions " + "WHERE session_id=$1" + ) + row = await cls.db.fetchrow(q, session_id) + if row is None: + return cls(session_id) + data = {**row} + auth_key = AuthKey(data.pop("auth_key", None)) + return cls(**data, auth_key=auth_key) + + @classmethod + async def has(cls, session_id: str) -> bool: + q = "SELECT COUNT(*) FROM telethon_sessions WHERE session_id=$1" + count = await cls.db.fetchval(q, session_id) + return count > 0 + + async def save(self) -> None: + q = ( + "INSERT INTO telethon_sessions (session_id, dc_id, server_address, port, auth_key) " + "VALUES ($1, $2, $3, $4, $5) ON CONFLICT (session_id) " + "DO UPDATE SET dc_id=$2, server_address=$3, port=$4, auth_key=$5" + ) + await self.db.execute( + q, self.session_id, self.dc_id, self.server_address, self.port, self.auth_key_bytes + ) + + _tables: ClassVar[tuple[str, ...]] = ( + "telethon_sessions", "telethon_entities", "telethon_sent_files", "telethon_update_state" + ) + + async def delete(self) -> None: + async with self.db.acquire() as conn, conn.transaction(): + for table in self._tables: + await conn.execute(f"DELETE FROM {table} WHERE session_id=$1", self.session_id) + + async def close(self) -> None: + # Nothing to do here, DB connection is global + pass + + async def get_update_state(self, entity_id: int) -> updates.State | None: + q = ( + "SELECT pts, qts, date, seq, unread_count FROM telethon_update_state " + "WHERE session_id=$1 AND entity_id=$2" + ) + row = await self.db.fetchrow(q, self.session_id, entity_id) + if row is None: + return None + date = datetime.datetime.utcfromtimestamp(row["date"]) + return updates.State(row["pts"], row["qts"], date, row["seq"], row["unread_count"]) + + async def set_update_state(self, entity_id: int, row: updates.State) -> None: + q = ( + "INSERT INTO telethon_update_state" + " (session_id, entity_id, pts, qts, date, seq, unread_count) " + "VALUES ($1, $2, $3, $4, $5, $6, $7)" + "ON CONFLICT (session_id, entity_id) DO UPDATE" + " SET pts=$3, qts=$4, date=$5, seq=$6, unread_count=$7" + ) + ts = row.date.timestamp() + await self.db.execute( + q, self.session_id, entity_id, row.pts, row.qts, ts, row.seq, row.unread_count + ) + + def _entity_values_to_row( + self, id: int, hash: int, username: str | None, phone: str | int | None, name: str | None + ) -> tuple[str, int, int, str | None, str | None, str | None]: + return self.session_id, id, hash, username, str(phone) if phone else None, name + + async def process_entities(self, tlo) -> None: + # Postgres likes to deadlock on simultaneous upserts, so just lock the whole thing here + # TODO: make sure postgres doesn't deadlock on upserts when session_id is different + async with self._process_entities_lock: + await self._locked_process_entities(tlo) + + async def _locked_process_entities(self, tlo) -> None: + rows: list[ + tuple[str, int, int, str | None, str | None, str | None] + ] = self._entities_to_rows(tlo) + if not rows: + return + if self.db.scheme == "postgres": + q = ( + "INSERT INTO telethon_entities (session_id, id, hash, username, phone, name) " + "VALUES ($1, unnest($2::bigint[]), unnest($3::bigint[]), " + " unnest($4::text[]), unnest($5::text[]), unnest($6::text[])) " + "ON CONFLICT (session_id, id) DO UPDATE" + " SET hash=excluded.hash, username=excluded.username," + " phone=excluded.phone, name=excluded.name" + ) + _, ids, hashes, usernames, phones, names = zip(*rows) + await self.db.execute(q, self.session_id, ids, hashes, usernames, phones, names) + else: + q = ( + "INSERT INTO telethon_entities (session_id, id, hash, username, phone, name) " + "VALUES ($1, $2, $3, $4, $5, $6) " + "ON CONFLICT (session_id, id) DO UPDATE " + " SET hash=$3, username=$4, phone=$5, name=$6" + ) + await self.db.executemany(q, rows) + + async def _select_entity( + self, constraint: str, *args: str | int | tuple[int, ...] + ) -> tuple[int, int] | None: + row = await self.db.fetchrow( + f"SELECT id, hash FROM telethon_entities WHERE {constraint}", *args + ) + if row is None: + return None + return row["id"], row["hash"] + + async def get_entity_rows_by_phone(self, key: str | int) -> tuple[int, int] | None: + return await self._select_entity("phone=$1", str(key)) + + async def get_entity_rows_by_username(self, key: str) -> tuple[int, int] | None: + return await self._select_entity("username=$1", key) + + async def get_entity_rows_by_name(self, key: str) -> tuple[int, int] | None: + return await self._select_entity("name=$1", key) + + async def get_entity_rows_by_id(self, key: int, exact: bool = True) -> tuple[int, int] | None: + if exact: + return await self._select_entity("id=$1", key) + + ids = ( + utils.get_peer_id(PeerUser(key)), + utils.get_peer_id(PeerChat(key)), + utils.get_peer_id(PeerChannel(key)) + ) + if self.db.scheme == "postgres": + return await self._select_entity("id=ANY($1)", ids) + else: + return await self._select_entity(f"id IN ($1, $2, $3)", *ids) diff --git a/mautrix_telegram/db/upgrade/__init__.py b/mautrix_telegram/db/upgrade/__init__.py new file mode 100644 index 00000000..146e7134 --- /dev/null +++ b/mautrix_telegram/db/upgrade/__init__.py @@ -0,0 +1,5 @@ +from mautrix.util.async_db import UpgradeTable + +upgrade_table = UpgradeTable() + +from . import v01_initial_revision diff --git a/mautrix_telegram/db/upgrade/v01_initial_revision.py b/mautrix_telegram/db/upgrade/v01_initial_revision.py new file mode 100644 index 00000000..8b35d087 --- /dev/null +++ b/mautrix_telegram/db/upgrade/v01_initial_revision.py @@ -0,0 +1,300 @@ +# mautrix-telegram - A Matrix-Telegram puppeting bridge +# Copyright (C) 2021 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from asyncpg import Connection +from . import upgrade_table + +legacy_version_query = "SELECT version_num FROM alembic_version" +last_legacy_version = "bfc0a39bfe02" + + +def table_exists(scheme: str, name: str) -> str: + if scheme == "sqlite": + return f"SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='{name}')" + elif scheme == "postgres": + return f"SELECT EXISTS(SELECT FROM information_schema.tables WHERE table_name='{name}')" + raise RuntimeError("unsupported database scheme") + + +@upgrade_table.register(description="Initial asyncpg revision") +async def upgrade_v1(conn: Connection, scheme: str) -> None: + is_legacy = await conn.fetchval(table_exists(scheme, "alembic_version")) + if is_legacy: + await migrate_legacy_to_v1(conn, scheme) + else: + await create_v1_tables(conn) + + +async def migrate_legacy_to_v1(conn: Connection, scheme: str) -> None: + legacy_version = await conn.fetchval(legacy_version_query) + if legacy_version != last_legacy_version: + raise RuntimeError("Legacy database is not on last version. Please upgrade the old " + "database with alembic or drop it completely first.") + if scheme != "sqlite": + await conn.execute( + """ + ALTER TABLE contact + DROP CONSTRAINT contact_user_fkey, + DROP CONSTRAINT contact_contact_fkey, + ADD CONSTRAINT contact_user_fkey FOREIGN KEY (contact) REFERENCES puppet(id) + ON DELETE CASCADE ON UPDATE CASCADE, + ADD CONSTRAINT contact_contact_fkey FOREIGN KEY ("user") REFERENCES "user"(tgid) + ON DELETE CASCADE ON UPDATE CASCADE + """ + ) + await conn.execute( + """ + ALTER TABLE telethon_sessions + DROP CONSTRAINT telethon_sessions_pkey, + ADD CONSTRAINT telethon_sessions_pkey PRIMARY KEY (session_id) + """ + ) + await conn.execute( + """ + ALTER TABLE telegram_file + DROP CONSTRAINT fk_file_thumbnail, + ADD CONSTRAINT fk_file_thumbnail + FOREIGN KEY (thumbnail) REFERENCES telegram_file(id) + ON UPDATE CASCADE ON DELETE SET NULL + """ + ) + await conn.execute("ALTER TABLE puppet ALTER COLUMN id DROP DEFAULT") + await conn.execute("DROP SEQUENCE puppet_id_seq") + await conn.execute("ALTER TABLE bot_chat ALTER COLUMN id DROP DEFAULT") + await conn.execute("DROP SEQUENCE bot_chat_id_seq") + await conn.execute("ALTER TABLE portal ALTER COLUMN config TYPE jsonb USING config::jsonb") + await conn.execute( + "ALTER TABLE telegram_file ALTER COLUMN decryption_info TYPE jsonb " + "USING decryption_info::jsonb" + ) + await varchar_to_text(conn) + else: + await conn.execute( + """CREATE TABLE telethon_sessions_new ( + session_id TEXT PRIMARY KEY, + dc_id INTEGER, + server_address TEXT, + port INTEGER, + auth_key bytea + )""" + ) + await conn.execute( + """ + INSERT INTO telethon_sessions_new (session_id, dc_id, server_address, port, auth_key) + SELECT session_id, dc_id, server_address, port, auth_key FROM telethon_sessions + """ + ) + await conn.execute("DROP TABLE telethon_sessions") + await conn.execute("ALTER TABLE telethon_sessions_new RENAME TO telethon_sessions") + + await update_state_store(conn, scheme) + await conn.execute('ALTER TABLE "user" ADD COLUMN is_bot BOOLEAN NOT NULL DEFAULT false') + await conn.execute("ALTER TABLE puppet RENAME COLUMN matrix_registered TO is_registered") + await conn.execute("DROP TABLE telethon_version") + await conn.execute("DROP TABLE alembic_version") + + +async def update_state_store(conn: Connection, scheme: str) -> None: + # The Matrix state store already has more or less the correct schema, so set the version + await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)") + await conn.execute("INSERT INTO mx_version (version) VALUES (2)") + if scheme != "sqlite": + # Also add the membership type on postgres + await conn.execute( + "CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')" + ) + await conn.execute( + "ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership " + "USING LOWER(membership)::membership" + ) + else: + # On SQLite there's no custom type, but we still want to lowercase everything + await conn.execute("UPDATE mx_user_profile SET membership=LOWER(membership)") + + +async def varchar_to_text(conn: Connection) -> None: + columns_to_adjust = { + "user": ("mxid", "tg_username", "tg_phone"), + "portal": ( + "peer_type", "mxid", "username", "title", "about", "photo_id", "avatar_url", "config" + ), + "message": ("mxid", "mx_room"), + "puppet": ( + "displayname", "username", "photo_id", + ) + ( + "access_token", "custom_mxid", "next_batch", "base_url" + ), + "bot_chat": ("type",), + "telegram_file": ("id", "mxc", "mime_type", "thumbnail"), + # Phone is a bigint in the old schema, which is safe, but we don't do math on it, + # so let's change it to a string + "telethon_entities": ("session_id", "username", "name", "phone"), + "telethon_sent_files": ("session_id",), + "telethon_sessions": ("session_id", "server_address"), + "telethon_update_state": ("session_id",), + "mx_room_state": ("room_id",), + "mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"), + } + for table, columns in columns_to_adjust.items(): + for column in columns: + await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT') + + +async def create_v1_tables(conn: Connection) -> None: + await conn.execute( + """CREATE TABLE "user" ( + mxid TEXT PRIMARY KEY, + tgid BIGINT UNIQUE, + tg_username TEXT, + tg_phone TEXT, + is_bot BOOLEAN NOT NULL DEFAULT false, + saved_contacts INTEGER NOT NULL DEFAULT 0 + )""" + ) + await conn.execute( + """CREATE TABLE portal ( + tgid BIGINT, + tg_receiver BIGINT, + peer_type TEXT NOT NULL, + mxid TEXT UNIQUE, + avatar_url TEXT, + encrypted BOOLEAN NOT NULL DEFAULT false, + username TEXT, + title TEXT, + about TEXT, + photo_id TEXT, + megagroup BOOLEAN, + config jsonb, + PRIMARY KEY (tgid, tg_receiver) + )""" + ) + await conn.execute( + """CREATE TABLE message ( + mxid TEXT, + mx_room TEXT, + tgid BIGINT NOT NULL, + tg_space BIGINT NOT NULL, + edit_index INTEGER NOT NULL, + redacted BOOLEAN NOT NULL DEFAULT false, + PRIMARY KEY (tgid, tg_space, edit_index), + UNIQUE (mxid, mx_room, tg_space) + )""" + ) + await conn.execute( + """CREATE TABLE puppet ( + id BIGINT PRIMARY KEY, + + is_registered BOOLEAN NOT NULL DEFAULT false, + + displayname TEXT, + displayname_source BIGINT, + displayname_contact BOOLEAN NOT NULL DEFAULT true, + displayname_quality INTEGER NOT NULL DEFAULT 0, + disable_updates BOOLEAN NOT NULL DEFAULT false, + username TEXT, + photo_id TEXT, + is_bot BOOLEAN, + + access_token TEXT, + custom_mxid TEXT, + next_batch TEXT, + base_url TEXT + )""" + ) + await conn.execute( + """CREATE TABLE telegram_file ( + id TEXT PRIMARY KEY, + mxc TEXT NOT NULL, + mime_type TEXT, + was_converted BOOLEAN NOT NULL DEFAULT false, + timestamp BIGINT NOT NULL DEFAULT 0, + size BIGINT, + width INTEGER, + height INTEGER, + thumbnail TEXT, + decryption_info jsonb, + FOREIGN KEY (thumbnail) REFERENCES telegram_file(id) + ON UPDATE CASCADE ON DELETE SET NULL + )""" + ) + await conn.execute( + """CREATE TABLE bot_chat ( + id BIGINT PRIMARY KEY, + type TEXT NOT NULL + )""" + ) + await conn.execute( + """CREATE TABLE user_portal ( + "user" BIGINT, + portal BIGINT, + portal_receiver BIGINT, + PRIMARY KEY ("user", portal, portal_receiver), + FOREIGN KEY ("user") REFERENCES "user"(tgid) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (portal, portal_receiver) REFERENCES portal(tgid, tg_receiver) + ON DELETE CASCADE ON UPDATE CASCADE + )""" + ) + await conn.execute( + """CREATE TABLE contact ( + "user" BIGINT, + contact BIGINT, + PRIMARY KEY ("user", contact), + FOREIGN KEY ("user") REFERENCES "user"(tgid) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (contact) REFERENCES puppet(id) ON DELETE CASCADE ON UPDATE CASCADE + )""" + ) + await conn.execute( + """CREATE TABLE telethon_sessions ( + session_id TEXT PRIMARY KEY, + dc_id INTEGER, + server_address TEXT, + port INTEGER, + auth_key bytea + )""" + ) + await conn.execute( + """CREATE TABLE telethon_entities ( + session_id TEXT, + id BIGINT, + hash BIGINT NOT NULL, + username TEXT, + phone TEXT, + name TEXT, + PRIMARY KEY (session_id, id) + )""" + ) + await conn.execute( + """CREATE TABLE telethon_sent_files ( + session_id TEXT, + md5_digest bytea, + file_size INTEGER, + type INTEGER, + id BIGINT, + hash BIGINT, + PRIMARY KEY (session_id, md5_digest, file_size, type) + )""" + ) + await conn.execute( + """CREATE TABLE telethon_update_state ( + session_id TEXT, + entity_id BIGINT, + pts BIGINT, + qts BIGINT, + date BIGINT, + seq BIGINT, + unread_count INTEGER, + PRIMARY KEY (session_id, entity_id) + )""" + ) diff --git a/mautrix_telegram/db/user.py b/mautrix_telegram/db/user.py index a2b952ae..3274a1b8 100644 --- a/mautrix_telegram/db/user.py +++ b/mautrix_telegram/db/user.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,96 +13,119 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Iterable, Tuple +from __future__ import annotations -from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, BigInteger, Integer, String, func +from typing import Iterable, ClassVar, TYPE_CHECKING + +from asyncpg import Record +from attr import dataclass from mautrix.types import UserID -from mautrix.util.db import Base +from mautrix.util.async_db import Database from ..types import TelegramID +fake_db = Database.create("") if TYPE_CHECKING else None -class User(Base): - __tablename__ = "user" - mxid: UserID = Column(String, primary_key=True) - tgid: Optional[TelegramID] = Column(BigInteger, nullable=True, unique=True) - tg_username: str = Column(String, nullable=True) - tg_phone: str = Column(String, nullable=True) - saved_contacts: int = Column(Integer, default=0, nullable=False) +@dataclass +class User: + db: ClassVar[Database] = fake_db + + mxid: UserID + tgid: TelegramID | None + tg_username: str | None + tg_phone: str | None + is_bot: bool + saved_contacts: int @classmethod - def all_with_tgid(cls) -> Iterable['User']: - return cls._select_all(cls.c.tgid != None) + def _from_row(cls, row: Record | None) -> User | None: + if row is None: + return None + return cls(**row) + + columns: ClassVar[str] = "mxid, tgid, tg_username, tg_phone, is_bot, saved_contacts" @classmethod - def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']: - return cls._select_one_or_none(cls.c.tgid == tgid) + async def get_by_tgid(cls, tgid: TelegramID) -> User | None: + q = f'SELECT {cls.columns} FROM "user" WHERE tgid=$1' + return cls._from_row(await cls.db.fetchrow(q, tgid)) @classmethod - def get_by_mxid(cls, mxid: UserID) -> Optional['User']: - return cls._select_one_or_none(cls.c.mxid == mxid) + async def get_by_mxid(cls, mxid: UserID) -> User | None: + q = f'SELECT {cls.columns} FROM "user" WHERE mxid=$1' + return cls._from_row(await cls.db.fetchrow(q, mxid)) @classmethod - def get_by_username(cls, username: str) -> Optional['User']: - return cls._select_one_or_none(func.lower(cls.c.tg_username) == username) + async def find_by_username(cls, username: str) -> User | None: + q = f'SELECT {cls.columns} FROM "user" WHERE lower(tg_username)=$1' + return cls._from_row(await cls.db.fetchrow(q, username.lower())) + + @classmethod + async def all_with_tgid(cls) -> list[User]: + q = f'SELECT {cls.columns} FROM "user" WHERE tgid IS NOT NULL' + return [cls._from_row(row) for row in await cls.db.fetch(q)] + + async def delete(self) -> None: + await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid) @property - def contacts(self) -> Iterable[TelegramID]: - rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid)) - for row in rows: - user, contact = row - yield contact + def _values(self): + return ( + self.mxid, self.tgid, self.tg_username, self.tg_phone, self.is_bot, self.saved_contacts + ) - @contacts.setter - def contacts(self, puppets: Iterable[TelegramID]) -> None: - with self.db.begin() as conn: - conn.execute(Contact.t.delete().where(Contact.c.user == self.tgid)) - insert_puppets = [{"user": self.tgid, "contact": tgid} for tgid in puppets] - if insert_puppets: - conn.execute(Contact.t.insert(), insert_puppets) + async def save(self) -> None: + q = ( + 'UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, saved_contacts=$6 ' + 'WHERE mxid=$1' + ) + await self.db.execute(q, *self._values) - @property - def portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]: - rows = self.db.execute(UserPortal.t.select().where(UserPortal.c.user == self.tgid)) - for row in rows: - user, portal, portal_receiver = row - yield (portal, portal_receiver) + async def insert(self) -> None: + q = ( + 'INSERT INTO "user" (mxid, tgid, tg_username, tg_phone, is_bot, saved_contacts) ' + 'VALUES ($1, $2, $3, $4, $5, $6)' + ) + await self.db.execute(q, *self._values) - @portals.setter - def portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None: - with self.db.begin() as conn: - conn.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid)) - insert_portals = [{ - "user": self.tgid, - "portal": tgid, - "portal_receiver": tg_receiver - } for tgid, tg_receiver in portals] - if insert_portals: - conn.execute(UserPortal.t.insert(), insert_portals) + async def get_contacts(self) -> list[TelegramID]: + rows = await self.db.fetch('SELECT contact FROM contact WHERE "user"=$1', self.tgid) + return [TelegramID(row["contact"]) for row in rows] - def delete(self) -> None: - super().delete() - self.portals = [] - self.contacts = [] + async def set_contacts(self, puppets: Iterable[TelegramID]) -> None: + columns = ["user", "contact"] + records = [(self.tgid, puppet_id) for puppet_id in puppets] + async with self.db.acquire() as conn, conn.transaction(): + await conn.execute('DELETE FROM contact WHERE "user"=$1', self.tgid) + if self.db.scheme == "postgres": + await conn.copy_records_to_table("contact", records=records, columns=columns) + else: + q = 'INSERT INTO contact ("user", contact) VALUES ($1, $2)' + await conn.executemany(q, records) + async def get_portals(self) -> list[tuple[TelegramID, TelegramID]]: + q = 'SELECT portal, portal_receiver FROM user_portal WHERE "user"=$1' + rows = await self.db.fetch(q, self.tgid) + return [(TelegramID(row["portal"]), TelegramID(row["portal_receiver"])) for row in rows] -class UserPortal(Base): - __tablename__ = "user_portal" + async def set_portals(self, portals: Iterable[tuple[TelegramID, TelegramID]]) -> None: + columns = ["user", "portal", "portal_receiver"] + records = [(self.tgid, tgid, tg_receiver) for tgid, tg_receiver in portals] + async with self.db.acquire() as conn, conn.transaction(): + await conn.execute('DELETE FROM user_portal WHERE "user"=$1', self.tgid) + if self.db.scheme == "postgres": + await conn.copy_records_to_table("user_portal", records=records, columns=columns) + else: + q = 'INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3)' + await conn.executemany(q, records) - user: TelegramID = Column(BigInteger, ForeignKey("user.tgid", onupdate="CASCADE", - ondelete="CASCADE"), primary_key=True) - portal: TelegramID = Column(BigInteger, primary_key=True) - portal_receiver: TelegramID = Column(BigInteger, primary_key=True) + async def register_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None: + q = ('INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3) ' + 'ON CONFLICT ("user", portal, portal_receiver) DO NOTHING') + await self.db.execute(q, self.tgid, tgid, tg_receiver) - __table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"), - ("portal.tgid", "portal.tg_receiver"), - onupdate="CASCADE", ondelete="CASCADE"),) - - -class Contact(Base): - __tablename__ = "contact" - - user: TelegramID = Column(BigInteger, ForeignKey("user.tgid"), primary_key=True) - contact: TelegramID = Column(BigInteger, ForeignKey("puppet.id"), primary_key=True) + async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None: + q = 'DELETE FROM user_portal WHERE "user"=$1 AND portal=$2 AND portal_receiver=$3' + await self.db.execute(q, self.tgid, tgid, tg_receiver) diff --git a/mautrix_telegram/example-config.yaml b/mautrix_telegram/example-config.yaml index 4fab863e..8b31a882 100644 --- a/mautrix_telegram/example-config.yaml +++ b/mautrix_telegram/example-config.yaml @@ -277,6 +277,10 @@ bridge: archive_tag: null # Whether or not mute status and tags should only be bridged when the portal room is created. tag_only_on_create: true + # Should leaving the room on Matrix make the user leave on Telegram? + bridge_matrix_leave: true + # Should the user be kicked out of all portals when logging out of the bridge? + kick_on_logout: true # Settings for backfilling messages from Telegram. backfill: # Whether or not the Telegram ghosts of logged in Matrix users should be diff --git a/mautrix_telegram/formatter/__init__.py b/mautrix_telegram/formatter/__init__.py index 2978ed01..cf46d796 100644 --- a/mautrix_telegram/formatter/__init__.py +++ b/mautrix_telegram/formatter/__init__.py @@ -1,7 +1,2 @@ -from .from_matrix import matrix_reply_to_telegram, matrix_to_telegram, init_mx +from .from_matrix import matrix_reply_to_telegram, matrix_to_telegram from .from_telegram import telegram_reply_to_matrix, telegram_to_matrix -from .. import context as c - - -def init(context: c.Context) -> None: - init_mx(context) diff --git a/mautrix_telegram/formatter/from_matrix/__init__.py b/mautrix_telegram/formatter/from_matrix/__init__.py index 1fb6af1b..3093d684 100644 --- a/mautrix_telegram/formatter/from_matrix/__init__.py +++ b/mautrix_telegram/formatter/from_matrix/__init__.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,39 +13,77 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, List, Tuple, Callable, Pattern, Match, TYPE_CHECKING -import re -import logging +from __future__ import annotations -from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, MessageEntityItalic, - TypeMessageEntity, InputMessageEntityMentionName) +import re + +from telethon.tl.types import MessageEntityItalic, TypeMessageEntity from telethon.helpers import add_surrogate, del_surrogate from telethon import TelegramClient from mautrix.types import RoomID, MessageEventContent -from mautrix.util.logging import TraceLogger -from ... import puppet as pu from ...types import TelegramID from ...db import Message as DBMessage -from .parser import ParsedMessage, parse_html +from .parser import MatrixParser -if TYPE_CHECKING: - from ...context import Context - -log: TraceLogger = logging.getLogger("mau.fmt.mx") -should_bridge_plaintext_highlights: bool = False - -command_regex: Pattern = re.compile(r"^!([A-Za-z0-9@]+)") -not_command_regex: Pattern = re.compile(r"^\\(![A-Za-z0-9@]+)") -plain_mention_regex: Optional[Pattern] = None +command_regex = re.compile(r"^!([A-Za-z0-9@]+)") +not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)") MAX_LENGTH = 4096 CUTOFF_TEXT = " [message cut]" CUT_MAX_LENGTH = MAX_LENGTH - len(CUTOFF_TEXT) -def _cut_long_message(message: str, entities: List[TypeMessageEntity]) -> ParsedMessage: +class FormatError(Exception): + pass + + +async def matrix_reply_to_telegram( + content: MessageEventContent, tg_space: TelegramID, room_id: RoomID | None = None +) -> TelegramID | None: + event_id = content.get_reply_to() + if not event_id: + return + content.trim_reply_fallback() + + message = await DBMessage.get_by_mxid(event_id, room_id, tg_space) + if message: + return message.tgid + return None + + +async def matrix_to_telegram( + client: TelegramClient, *, text: str | None = None, html: str | None = None +) -> tuple[str, list[TypeMessageEntity]]: + if html is not None: + return await _matrix_html_to_telegram(client, html) + elif text is not None: + return _matrix_text_to_telegram(text), [] + else: + raise ValueError("text or html must be provided to convert formatting") + + +async def _matrix_html_to_telegram( + client: TelegramClient, html: str +) -> tuple[str, list[TypeMessageEntity]]: + try: + html = command_regex.sub(r"\1", html) + html = html.replace("\t", " " * 4) + html = not_command_regex.sub(r"\1", html) + + parsed = await MatrixParser(client).parse(add_surrogate(html)) + text = del_surrogate(parsed.text.strip()) + text, entities = _cut_long_message(text, parsed.telegram_entities) + + return text, entities + except Exception as e: + raise FormatError(f"Failed to convert Matrix format: {html}") from e + + +def _cut_long_message( + message: str, entities: list[TypeMessageEntity] +) -> tuple[str, list[TypeMessageEntity]]: if len(message) > MAX_LENGTH: message = message[0:CUT_MAX_LENGTH] + CUTOFF_TEXT new_entities = [] @@ -60,112 +98,8 @@ def _cut_long_message(message: str, entities: List[TypeMessageEntity]) -> Parsed return message, entities -class FormatError(Exception): - pass - - -def matrix_reply_to_telegram(content: MessageEventContent, tg_space: TelegramID, - room_id: Optional[RoomID] = None) -> Optional[TelegramID]: - event_id = content.get_reply_to() - if not event_id: - return - content.trim_reply_fallback() - - message = DBMessage.get_by_mxid(event_id, room_id, tg_space) - if message: - return message.tgid - return None - - -async def matrix_to_telegram(client: TelegramClient, *, text: Optional[str] = None, - html: Optional[str] = None) -> ParsedMessage: - if html is not None: - text, entities = _matrix_html_to_telegram(html) - elif text is not None: - text, entities = _matrix_text_to_telegram(text) - else: - raise ValueError("text or html must be provided to convert formatting") - await _fix_name_mentions(client, entities) - return text, entities - - -def _matrix_html_to_telegram(html: str) -> ParsedMessage: - try: - html = command_regex.sub(r"\1", html) - html = html.replace("\t", " " * 4) - html = not_command_regex.sub(r"\1", html) - if should_bridge_plaintext_highlights: - html = plain_mention_regex.sub(_plain_mention_to_html, html) - - text, entities = parse_html(add_surrogate(html)) - text = del_surrogate(text.strip()) - text, entities = _cut_long_message(text, entities) - - return text, entities - except Exception as e: - raise FormatError(f"Failed to convert Matrix format: {html}") from e - - -def _matrix_text_to_telegram(text: str) -> ParsedMessage: +def _matrix_text_to_telegram(text: str) -> str: text = command_regex.sub(r"/\1", text) text = text.replace("\t", " " * 4) text = not_command_regex.sub(r"\1", text) - if should_bridge_plaintext_highlights: - entities, pmr_replacer = _plain_mention_to_text() - text = plain_mention_regex.sub(pmr_replacer, text) - else: - entities = [] - return text, entities - - -async def _fix_name_mentions(client: TelegramClient, entities: List[TypeMessageEntity]) -> None: - for index in reversed(range(len(entities))): - entity = entities[index] - if isinstance(entity, (MessageEntityMentionName, InputMessageEntityMentionName)): - try: - user = await client.get_input_entity(entity.user_id) - except (ValueError, TypeError) as e: - log.trace(f"Dropping mention of {entity.user_id}: {e}") - del entities[index] - else: - entities[index] = InputMessageEntityMentionName(entity.offset, entity.length, user) - - -def _plain_mention_to_text() -> Tuple[List[TypeMessageEntity], Callable[[Match], str]]: - entities = [] - - def replacer(match: Match) -> str: - puppet = pu.Puppet.find_by_displayname(match.group(2)) - if puppet: - offset = match.start() - length = match.end() - offset - if puppet.username: - entity = MessageEntityMention(offset, length) - text = f"@{puppet.username}" - else: - entity = MessageEntityMentionName(offset, length, user_id=puppet.tgid) - text = puppet.displayname - entities.append(entity) - return text - return "".join(match.groups()) - - return entities, replacer - - -def _plain_mention_to_html(match: Match) -> str: - puppet = pu.Puppet.find_by_displayname(match.group(2)) - if puppet: - return (f"{match.group(1)}" - f"" - f"{puppet.displayname}" - "") - return "".join(match.groups()) - - -def init_mx(context: "Context") -> None: - global plain_mention_regex, should_bridge_plaintext_highlights - config = context.config - dn_template = config["bridge.displayname_template"] - dn_template = re.escape(dn_template).replace(re.escape("{displayname}"), "[^>]+") - plain_mention_regex = re.compile(f"^({dn_template})") - should_bridge_plaintext_highlights = config["bridge.plaintext_highlights"] + return text diff --git a/mautrix_telegram/formatter/from_matrix/parser.py b/mautrix_telegram/formatter/from_matrix/parser.py index cf672210..08e2445d 100644 --- a/mautrix_telegram/formatter/from_matrix/parser.py +++ b/mautrix_telegram/formatter/from_matrix/parser.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,77 +13,80 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Tuple, Optional +from __future__ import annotations -from telethon.tl.types import TypeMessageEntity +import logging + +from telethon import TelegramClient from mautrix.types import UserID, RoomID from mautrix.util.formatter import MatrixParser as BaseMatrixParser, RecursionContext from mautrix.util.formatter.html_reader_htmlparser import read_html, HTMLNode +from mautrix.util.logging import TraceLogger from ... import user as u, puppet as pu, portal as po from .telegram_message import TelegramMessage, TelegramEntityType - -ParsedMessage = Tuple[str, List[TypeMessageEntity]] - - -def parse_html(input_html: str) -> ParsedMessage: - msg = MatrixParser.parse(input_html) - return msg.text, msg.telegram_entities +log: TraceLogger = logging.getLogger("mau.fmt.mx") class MatrixParser(BaseMatrixParser[TelegramMessage]): e = TelegramEntityType fs = TelegramMessage - read_html = read_html + client: TelegramClient - @classmethod - def custom_node_to_fstring(cls, node: HTMLNode, ctx: RecursionContext - ) -> Optional[TelegramMessage]: - msg = cls.tag_aware_parse_node(node, ctx) + def __init__(self, client: TelegramClient) -> None: + self.client = client + self.read_html = read_html + + async def custom_node_to_fstring( + self, node: HTMLNode, ctx: RecursionContext + ) -> TelegramMessage | None: + msg = await self.tag_aware_parse_node(node, ctx) if node.tag == "command": msg.format(TelegramEntityType.COMMAND) return None - @classmethod - def user_pill_to_fstring(cls, msg: TelegramMessage, user_id: UserID) -> TelegramMessage: - user = (pu.Puppet.deprecated_sync_get_by_mxid(user_id) - or u.User.get_by_mxid(user_id, create=False)) + async def user_pill_to_fstring(self, msg: TelegramMessage, user_id: UserID) -> TelegramMessage: + user = (await pu.Puppet.get_by_mxid(user_id) + or await u.User.get_by_mxid(user_id, create=False)) if not user: return msg - if user.username: - return TelegramMessage(f"@{user.username}").format(TelegramEntityType.MENTION) + if user.tg_username: + return TelegramMessage(f"@{user.tg_username}").format(TelegramEntityType.MENTION) elif user.tgid: displayname = user.plain_displayname or msg.text - return TelegramMessage(displayname).format(TelegramEntityType.MENTION_NAME, - user_id=user.tgid) + msg = TelegramMessage(displayname) + try: + input_entity = self.client.get_input_entity(user.tgid) + except (ValueError, TypeError) as e: + log.trace(f"Dropping mention of {user.tgid}: {e}") + else: + msg = msg.format(TelegramEntityType.MENTION_NAME, user_id=input_entity) return msg - @classmethod - def url_to_fstring(cls, msg: TelegramMessage, url: str) -> TelegramMessage: + async def url_to_fstring(self, msg: TelegramMessage, url: str) -> TelegramMessage: if url == msg.text: - return msg.format(cls.e.URL) + return msg.format(self.e.URL) else: - return msg.format(cls.e.INLINE_URL, url=url) + return msg.format(self.e.INLINE_URL, url=url) - @classmethod - def room_pill_to_fstring(cls, msg: TelegramMessage, room_id: RoomID) -> TelegramMessage: + async def room_pill_to_fstring(self, msg: TelegramMessage, room_id: RoomID) -> TelegramMessage: username = po.Portal.get_username_from_mx_alias(room_id) - portal = po.Portal.find_by_username(username) + portal = await po.Portal.find_by_username(username) if portal and portal.username: return TelegramMessage(f"@{portal.username}").format(TelegramEntityType.MENTION) - @classmethod - def header_to_fstring(cls, node: HTMLNode, ctx: RecursionContext) -> TelegramMessage: - children = cls.node_to_fstrings(node, ctx) + async def header_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> TelegramMessage: + children = await self.node_to_fstrings(node, ctx) length = int(node.tag[1]) prefix = "#" * length + " " return TelegramMessage.join(children, "").prepend(prefix).format(TelegramEntityType.BOLD) - @classmethod - def blockquote_to_fstring(cls, node: HTMLNode, ctx: RecursionContext) -> TelegramMessage: - msg = cls.tag_aware_parse_node(node, ctx) + async def blockquote_to_fstring( + self, node: HTMLNode, ctx: RecursionContext + ) -> TelegramMessage: + msg = await self.tag_aware_parse_node(node, ctx) children = msg.trim().split("\n") children = [child.prepend("> ") for child in children] return TelegramMessage.join(children, "\n") diff --git a/mautrix_telegram/formatter/from_matrix/telegram_message.py b/mautrix_telegram/formatter/from_matrix/telegram_message.py index 9ee1b94e..db04fe8c 100644 --- a/mautrix_telegram/formatter/from_matrix/telegram_message.py +++ b/mautrix_telegram/formatter/from_matrix/telegram_message.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,7 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Union, Any, List, Type, Dict +from __future__ import annotations + +from typing import Any, Type from enum import Enum from telethon.tl.types import (MessageEntityMention as Mention, MessageEntityBotCommand as Command, @@ -41,7 +43,7 @@ class TelegramEntityType(Enum): INLINE_CODE = Code BLOCKQUOTE = Blockquote MENTION = Mention - MENTION_NAME = MentionName + MENTION_NAME = InputMentionName COMMAND = Command USER_MENTION = 1 @@ -52,15 +54,15 @@ class TelegramEntityType(Enum): class TelegramEntity(SemiAbstractEntity): internal: TypeMessageEntity - def __init__(self, type: Union[TelegramEntityType, Type[TypeMessageEntity]], - offset: int, length: int, extra_info: Dict[str, Any]) -> None: + def __init__(self, type: TelegramEntityType | Type[TypeMessageEntity], + offset: int, length: int, extra_info: dict[str, Any]) -> None: if isinstance(type, TelegramEntityType): if isinstance(type.value, int): raise ValueError(f"Can't create Entity with non-Telegram EntityType {type}") type = type.value self.internal = type(offset=offset, length=length, **extra_info) - def copy(self) -> Optional['TelegramEntity']: + def copy(self) -> TelegramEntity: extra_info = {} if isinstance(self.internal, Pre): extra_info["language"] = self.internal.language @@ -95,5 +97,5 @@ class TelegramMessage(EntityString[TelegramEntity, TelegramEntityType]): entity_class = TelegramEntity @property - def telegram_entities(self) -> List[TypeMessageEntity]: + def telegram_entities(self) -> list[TypeMessageEntity]: return [entity.internal for entity in self.entities] diff --git a/mautrix_telegram/formatter/from_telegram.py b/mautrix_telegram/formatter/from_telegram.py index 022ff882..2d90d2d7 100644 --- a/mautrix_telegram/formatter/from_telegram.py +++ b/mautrix_telegram/formatter/from_telegram.py @@ -13,7 +13,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Optional, TYPE_CHECKING +from __future__ import annotations + from html import escape import logging import re @@ -29,47 +30,45 @@ from telethon.tl.custom import Message from telethon.errors import RPCError from telethon.helpers import add_surrogate, del_surrogate -from mautrix.errors import MatrixRequestError from mautrix.appservice import IntentAPI from mautrix.types import (TextMessageEventContent, RelatesTo, RelationType, Format, MessageType, - MessageEvent, EventType) + EventType) -from .. import user as u, puppet as pu, portal as po +from .. import user as u, puppet as pu, portal as po, abstract_user as au from ..types import TelegramID from ..db import Message as DBMessage -if TYPE_CHECKING: - from ..abstract_user import AbstractUser - log: logging.Logger = logging.getLogger("mau.fmt.tg") -def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Optional[RelatesTo]: +async def telegram_reply_to_matrix(evt: Message, source: au.AbstractUser) -> RelatesTo | None: if evt.reply_to: space = (evt.peer_id.channel_id if isinstance(evt, Message) and isinstance(evt.peer_id, PeerChannel) else source.tgid) - msg = DBMessage.get_one_by_tgid(TelegramID(evt.reply_to.reply_to_msg_id), space) + msg = await DBMessage.get_one_by_tgid(TelegramID(evt.reply_to.reply_to_msg_id), space) if msg: return RelatesTo(rel_type=RelationType.REPLY, event_id=msg.mxid) return None -async def _add_forward_header(source: 'AbstractUser', content: TextMessageEventContent, +async def _add_forward_header(source: au.AbstractUser, content: TextMessageEventContent, fwd_from: MessageFwdHeader) -> None: if not content.formatted_body or content.format != Format.HTML: content.format = Format.HTML content.formatted_body = escape(content.body) fwd_from_html, fwd_from_text = None, None if isinstance(fwd_from.from_id, PeerUser): - user = u.User.get_by_tgid(TelegramID(fwd_from.from_id.user_id)) + user = await u.User.get_by_tgid(TelegramID(fwd_from.from_id.user_id)) if user: fwd_from_text = user.displayname or user.mxid fwd_from_html = (f"" f"{escape(fwd_from_text)}") if not fwd_from_text: - puppet = pu.Puppet.get(TelegramID(fwd_from.from_id.user_id), create=False) + puppet = await pu.Puppet.get_by_tgid( + TelegramID(fwd_from.from_id.user_id), create=False + ) if puppet and puppet.displayname: fwd_from_text = puppet.displayname or puppet.mxid fwd_from_html = (f"" @@ -86,7 +85,7 @@ async def _add_forward_header(source: 'AbstractUser', content: TextMessageEventC elif isinstance(fwd_from.from_id, (PeerChannel, PeerChat)): from_id = (fwd_from.from_id.chat_id if isinstance(fwd_from.from_id, PeerChat) else fwd_from.from_id.channel_id) - portal = po.Portal.get_by_tgid(TelegramID(from_id)) + portal = await po.Portal.get_by_tgid(TelegramID(from_id)) if portal and portal.title: fwd_from_text = portal.title if portal.alias: @@ -116,13 +115,13 @@ async def _add_forward_header(source: 'AbstractUser', content: TextMessageEventC f"
{content.formatted_body}
") -async def _add_reply_header(source: 'AbstractUser', content: TextMessageEventContent, evt: Message, - main_intent: IntentAPI): +async def _add_reply_header(source: au.AbstractUser, content: TextMessageEventContent, + evt: Message, main_intent: IntentAPI) -> None: space = (evt.peer_id.channel_id if isinstance(evt, Message) and isinstance(evt.peer_id, PeerChannel) else source.tgid) - msg = DBMessage.get_one_by_tgid(TelegramID(evt.reply_to.reply_to_msg_id), space) + msg = await DBMessage.get_one_by_tgid(TelegramID(evt.reply_to.reply_to_msg_id), space) if not msg: return @@ -140,11 +139,11 @@ async def _add_reply_header(source: 'AbstractUser', content: TextMessageEventCon log.exception("Failed to get event to add reply fallback") -async def telegram_to_matrix(evt: Message, source: "AbstractUser", - main_intent: Optional[IntentAPI] = None, - prefix_text: Optional[str] = None, prefix_html: Optional[str] = None, +async def telegram_to_matrix(evt: Message, source: au.AbstractUser, + main_intent: IntentAPI | None = None, + prefix_text: str | None = None, prefix_html: str | None = None, override_text: str = None, - override_entities: List[TypeMessageEntity] = None, + override_entities: list[TypeMessageEntity] = None, no_reply_fallback: bool = False) -> TextMessageEventContent: content = TextMessageEventContent( msgtype=MessageType.TEXT, @@ -153,7 +152,7 @@ async def telegram_to_matrix(evt: Message, source: "AbstractUser", entities = override_entities or evt.entities if entities: content.format = Format.HTML - content.formatted_body = _telegram_entities_to_matrix_catch(content.body, entities) + content.formatted_body = await _telegram_entities_to_matrix_catch(content.body, entities) if prefix_html: if not content.formatted_body: @@ -183,9 +182,9 @@ async def telegram_to_matrix(evt: Message, source: "AbstractUser", return content -def _telegram_entities_to_matrix_catch(text: str, entities: List[TypeMessageEntity]) -> str: +async def _telegram_entities_to_matrix_catch(text: str, entities: list[TypeMessageEntity]) -> str: try: - return _telegram_entities_to_matrix(text, entities) + return await _telegram_entities_to_matrix(text, entities) except Exception: log.exception("Failed to convert Telegram format:\n" "message=%s\n" @@ -194,8 +193,8 @@ def _telegram_entities_to_matrix_catch(text: str, entities: List[TypeMessageEnti return "[failed conversion in _telegram_entities_to_matrix]" -def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity], - offset: int = 0, length: int = None) -> str: +async def _telegram_entities_to_matrix(text: str, entities: list[TypeMessageEntity], + offset: int = 0, length: int = None) -> str: if not entities: return escape(text) if length is None: @@ -212,7 +211,7 @@ def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity], continue skip_entity = False - entity_text = _telegram_entities_to_matrix( + entity_text = await _telegram_entities_to_matrix( text=text[relative_offset:relative_offset + entity.length], entities=entities[i + 1:], offset=entity.offset, length=entity.length) entity_type = type(entity) @@ -234,16 +233,17 @@ def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity], elif entity_type == MessageEntityPre: skip_entity = _parse_pre(html, entity_text, entity.language) elif entity_type == MessageEntityMention: - skip_entity = _parse_mention(html, entity_text) + skip_entity = await _parse_mention(html, entity_text) elif entity_type == MessageEntityMentionName: - skip_entity = _parse_name_mention(html, entity_text, TelegramID(entity.user_id)) + skip_entity = await _parse_name_mention(html, entity_text, TelegramID(entity.user_id)) elif entity_type == MessageEntityEmail: html.append(f"
{entity_text}") elif entity_type in (MessageEntityTextUrl, MessageEntityUrl): - skip_entity = _parse_url(html, entity_text, - entity.url if entity_type == MessageEntityTextUrl else None) + skip_entity = await _parse_url( + html, entity_text, entity.url if entity_type == MessageEntityTextUrl else None + ) elif entity_type == MessageEntityBotCommand: - html.append(f"!{entity_text[1:]}") + html.append(f"{entity_text}") elif entity_type in (MessageEntityHashtag, MessageEntityCashtag, MessageEntityPhone): html.append(f"{entity_text}") else: @@ -254,24 +254,22 @@ def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity], return "".join(html) -def _parse_pre(html: List[str], entity_text: str, language: str) -> bool: +def _parse_pre(html: list[str], entity_text: str, language: str) -> bool: if language: - html.append("
"
-                    f"{entity_text}"
-                    "
") + html.append(f"
{entity_text}
") else: html.append(f"
{entity_text}
") return False -def _parse_mention(html: List[str], entity_text: str) -> bool: +async def _parse_mention(html: list[str], entity_text: str) -> bool: username = entity_text[1:] - user = u.User.find_by_username(username) or pu.Puppet.find_by_username(username) + user = await u.User.find_by_username(username) or await pu.Puppet.find_by_username(username) if user: mxid = user.mxid else: - portal = po.Portal.find_by_username(username) + portal = await po.Portal.find_by_username(username) mxid = portal.alias or portal.mxid if portal else None if mxid: @@ -281,12 +279,12 @@ def _parse_mention(html: List[str], entity_text: str) -> bool: return False -def _parse_name_mention(html: List[str], entity_text: str, user_id: TelegramID) -> bool: - user = u.User.get_by_tgid(user_id) +async def _parse_name_mention(html: list[str], entity_text: str, user_id: TelegramID) -> bool: + user = await u.User.get_by_tgid(user_id) if user: mxid = user.mxid else: - puppet = pu.Puppet.get(user_id, create=False) + puppet = await pu.Puppet.get_by_tgid(user_id, create=False) mxid = puppet.mxid if puppet else None if mxid: html.append(f"{entity_text}") @@ -299,7 +297,7 @@ message_link_regex = re.compile(r"https?://t(?:elegram)?\.(?:me|dog)/" r"([A-Za-z][A-Za-z0-9_]{3,}[A-Za-z0-9])/([0-9]{1,50})") -def _parse_url(html: List[str], entity_text: str, url: str) -> bool: +async def _parse_url(html: list[str], entity_text: str, url: str) -> bool: url = escape(url) if url else entity_text if not url.startswith(("https://", "http://", "ftp://", "magnet://")): url = "http://" + url @@ -309,9 +307,9 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool: group, msgid_str = message_link_match.groups() msgid = int(msgid_str) - portal = po.Portal.find_by_username(group) + portal = await po.Portal.find_by_username(group) if portal: - message = DBMessage.get_one_by_tgid(TelegramID(msgid), portal.tgid) + message = await DBMessage.get_one_by_tgid(TelegramID(msgid), portal.tgid) if message: url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}" diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py index ec7e25b4..137d5c96 100644 --- a/mautrix_telegram/matrix.py +++ b/mautrix_telegram/matrix.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,13 +13,17 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Set, Tuple, Union, Iterable, TYPE_CHECKING +from __future__ import annotations + +from typing import Iterable, TYPE_CHECKING from mautrix.bridge import BaseMatrixHandler from mautrix.types import (Event, EventType, RoomID, UserID, EventID, ReceiptEvent, ReceiptType, ReceiptEventContent, PresenceEvent, PresenceState, TypingEvent, - StateEvent, RedactionEvent, RoomNameStateEventContent, - RoomAvatarStateEventContent, RoomTopicStateEventContent, + StateEvent, RedactionEvent, + RoomNameStateEventContent as NameContent, + RoomAvatarStateEventContent as AvatarContent, + RoomTopicStateEventContent as TopicContent, MemberStateEventContent, TextMessageEventContent, MessageType) from mautrix.errors import MatrixError @@ -27,28 +31,22 @@ from mautrix.errors import MatrixError from . import user as u, portal as po, puppet as pu, commands as com if TYPE_CHECKING: - from .context import Context - from .bot import Bot - -RoomMetaStateEventContent = Union[RoomNameStateEventContent, RoomAvatarStateEventContent, - RoomTopicStateEventContent] + from .__main__ import TelegramBridge class MatrixHandler(BaseMatrixHandler): - bot: 'Bot' - commands: 'com.CommandProcessor' - previously_typing: Dict[RoomID, Set[UserID]] + commands: com.CommandProcessor + _previously_typing: dict[RoomID, set[UserID]] - def __init__(self, context: 'Context') -> None: - prefix, suffix = context.config["bridge.username_template"].format(userid=":").split(":") - homeserver = context.config["homeserver.domain"] + def __init__(self, bridge: 'TelegramBridge') -> None: + prefix, suffix = bridge.config["bridge.username_template"].format(userid=":").split(":") + homeserver = bridge.config["homeserver.domain"] self.user_id_prefix = f"@{prefix}" self.user_id_suffix = f"{suffix}:{homeserver}" - super().__init__(command_processor=com.CommandProcessor(context), bridge=context.bridge) + super().__init__(command_processor=com.CommandProcessor(bridge), bridge=bridge) - self.bot = context.bot - self.previously_typing = {} + self._previously_typing = {} async def handle_puppet_invite(self, room_id: RoomID, puppet: pu.Puppet, inviter: u.User, event_id: EventID) -> None: @@ -58,7 +56,7 @@ class MatrixHandler(BaseMatrixHandler): await intent.error_and_leave( room_id, text="Please log in before inviting Telegram puppets.") return - portal = po.Portal.get_by_mxid(room_id) + portal = await po.Portal.get_by_mxid(room_id) if portal: if portal.peer_type == "user": await intent.error_and_leave( @@ -81,7 +79,9 @@ class MatrixHandler(BaseMatrixHandler): return await intent.join_room(room_id) - portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user") + portal = await po.Portal.get_by_tgid( + puppet.tgid, tg_receiver=inviter.tgid, peer_type="user" + ) if portal.mxid: try: await portal.invite_to_matrix(inviter.mxid) @@ -115,21 +115,21 @@ class MatrixHandler(BaseMatrixHandler): await intent.send_notice(room_id, "This puppet will remain inactive until a " "Telegram chat is created for this room.") - async def handle_invite(self, room_id: RoomID, user_id: UserID, inviter: 'u.User', + async def handle_invite(self, room_id: RoomID, user_id: UserID, inviter: u.User, event_id: EventID) -> None: - user = u.User.get_by_mxid(user_id, create=False) + user = await u.User.get_by_mxid(user_id, create=False) if not user: return await user.ensure_started() - portal = po.Portal.get_by_mxid(room_id) + portal = await po.Portal.get_by_mxid(room_id) if user and await user.has_full_access(allow_bot=True): if portal and portal.allow_bridging: await portal.invite_telegram(inviter, user) async def handle_join(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None: - user = await u.User.get_by_mxid(user_id).ensure_started() + user = await u.User.get_and_start_by_mxid(user_id) - portal = po.Portal.get_by_mxid(room_id) + portal = await po.Portal.get_by_mxid(room_id) if not portal or not portal.allow_bridging: return @@ -147,16 +147,13 @@ class MatrixHandler(BaseMatrixHandler): if await user.is_logged_in() or portal.has_bot: await portal.join_matrix(user, event_id) - async def get_leave_handle_info(self) -> Tuple[po.Portal, u.User]: - pass - async def handle_leave(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None: self.log.debug(f"{user_id} left {room_id}") - portal = po.Portal.get_by_mxid(room_id) + portal = await po.Portal.get_by_mxid(room_id) if not portal or not portal.allow_bridging: return - user = u.User.get_by_mxid(user_id, create=False) + user = await u.User.get_by_mxid(user_id, create=False) if not user: return await user.ensure_started() @@ -166,7 +163,7 @@ class MatrixHandler(BaseMatrixHandler): reason: str, event_id: EventID) -> None: action = "banned" if ban else "kicked" self.log.debug(f"{user_id} was {action} from {room_id} by {sender} for {reason}") - portal = po.Portal.get_by_mxid(room_id) + portal = await po.Portal.get_by_mxid(room_id) if not portal or not portal.allow_bridging: return @@ -176,7 +173,7 @@ class MatrixHandler(BaseMatrixHandler): await portal.unbridge() return - sender = u.User.get_by_mxid(sender, create=False) + sender = await u.User.get_by_mxid(sender, create=False) if not sender: return await sender.ensure_started() @@ -189,7 +186,7 @@ class MatrixHandler(BaseMatrixHandler): await portal.kick_matrix(puppet, sender) return - user = u.User.get_by_mxid(user_id, create=False) + user = await u.User.get_by_mxid(user_id, create=False) if not user: return await user.ensure_started() @@ -211,25 +208,23 @@ class MatrixHandler(BaseMatrixHandler): event_id: EventID) -> None: await self.handle_kick_ban(True, room_id, user_id, banned_by, reason, event_id) - @staticmethod - async def allow_message(user: 'u.User') -> bool: + async def allow_message(self, user: u.User) -> bool: return user.relaybot_whitelisted - @staticmethod - async def allow_command(user: 'u.User') -> bool: + async def allow_command(self, user: u.User) -> bool: return user.whitelisted @staticmethod - async def allow_bridging_message(user: 'u.User', portal: 'po.Portal') -> bool: + async def allow_bridging_message(user: u.User, portal: po.Portal) -> bool: return await user.is_logged_in() or portal.has_bot @staticmethod async def handle_redaction(evt: RedactionEvent) -> None: - sender = await u.User.get_by_mxid(evt.sender).ensure_started() + sender = await u.User.get_and_start_by_mxid(evt.sender) if not sender.relaybot_whitelisted: return - portal = po.Portal.get_by_mxid(evt.room_id) + portal = await po.Portal.get_by_mxid(evt.room_id) if not portal or not portal.allow_bridging: return @@ -237,23 +232,28 @@ class MatrixHandler(BaseMatrixHandler): @staticmethod async def handle_power_levels(evt: StateEvent) -> None: - portal = po.Portal.get_by_mxid(evt.room_id) - sender = await u.User.get_by_mxid(evt.sender).ensure_started() + portal = await po.Portal.get_by_mxid(evt.room_id) + sender = await u.User.get_and_start_by_mxid(evt.sender) if await sender.has_full_access(allow_bot=True) and portal and portal.allow_bridging: await portal.handle_matrix_power_levels(sender, evt.content.users, evt.unsigned.prev_content.users, evt.event_id) @staticmethod - async def handle_room_meta(evt_type: EventType, room_id: RoomID, sender_mxid: UserID, - content: RoomMetaStateEventContent, event_id: EventID) -> None: - portal = po.Portal.get_by_mxid(room_id) - sender = await u.User.get_by_mxid(sender_mxid).ensure_started() + async def handle_room_meta( + evt_type: EventType, + room_id: RoomID, + sender_mxid: UserID, + content: NameContent | AvatarContent | TopicContent, + event_id: EventID + ) -> None: + portal = await po.Portal.get_by_mxid(room_id) + sender = await u.User.get_and_start_by_mxid(sender_mxid) if await sender.has_full_access(allow_bot=True) and portal and portal.allow_bridging: handler, content_type, content_key = { - EventType.ROOM_NAME: (portal.handle_matrix_title, RoomNameStateEventContent, "name"), - EventType.ROOM_TOPIC: (portal.handle_matrix_about, RoomTopicStateEventContent, "topic"), - EventType.ROOM_AVATAR: (portal.handle_matrix_avatar, RoomAvatarStateEventContent, "url"), + EventType.ROOM_NAME: (portal.handle_matrix_title, NameContent, "name"), + EventType.ROOM_TOPIC: (portal.handle_matrix_about, TopicContent, "topic"), + EventType.ROOM_AVATAR: (portal.handle_matrix_avatar, AvatarContent, "url"), }[evt_type] if not isinstance(content, content_type): return @@ -261,10 +261,10 @@ class MatrixHandler(BaseMatrixHandler): @staticmethod async def handle_room_pin(room_id: RoomID, sender_mxid: UserID, - new_events: Set[str], old_events: Set[str], + new_events: set[str], old_events: set[str], event_id: EventID) -> None: - portal = po.Portal.get_by_mxid(room_id) - sender = await u.User.get_by_mxid(sender_mxid).ensure_started() + portal = await po.Portal.get_by_mxid(room_id) + sender = await u.User.get_and_start_by_mxid(sender_mxid) if await sender.has_full_access(allow_bot=True) and portal and portal.allow_bridging: if not new_events: await portal.handle_matrix_unpin_all(sender, event_id) @@ -276,7 +276,7 @@ class MatrixHandler(BaseMatrixHandler): @staticmethod async def handle_room_upgrade(room_id: RoomID, sender: UserID, new_room_id: RoomID, event_id: EventID) -> None: - portal = po.Portal.get_by_mxid(room_id) + portal = await po.Portal.get_by_mxid(room_id) if portal and portal.allow_bridging: await portal.handle_matrix_upgrade(sender, new_room_id, event_id) @@ -287,45 +287,45 @@ class MatrixHandler(BaseMatrixHandler): if profile.displayname == prev_profile.displayname: return - portal = po.Portal.get_by_mxid(room_id) + portal = await po.Portal.get_by_mxid(room_id) if not portal or not portal.has_bot or not portal.allow_bridging: return - user = await u.User.get_by_mxid(user_id).ensure_started() + user = await u.User.get_and_start_by_mxid(user_id) if await user.needs_relaybot(portal): await portal.name_change_matrix(user, profile.displayname, prev_profile.displayname, event_id) @staticmethod - def parse_read_receipts(content: ReceiptEventContent) -> Iterable[Tuple[UserID, EventID]]: + def parse_read_receipts(content: ReceiptEventContent) -> Iterable[tuple[UserID, EventID]]: return ((user_id, event_id) for event_id, receipts in content.items() for user_id in receipts.get(ReceiptType.READ, {})) @staticmethod - async def handle_read_receipts(room_id: RoomID, receipts: Iterable[Tuple[UserID, EventID]] + async def handle_read_receipts(room_id: RoomID, receipts: Iterable[tuple[UserID, EventID]] ) -> None: - portal = po.Portal.get_by_mxid(room_id) + portal = await po.Portal.get_by_mxid(room_id) if not portal or not portal.allow_bridging: return for user_id, event_id in receipts: - user = u.User.get_by_mxid(user_id, check_db=False, create=False) + user = await u.User.get_by_mxid(user_id, check_db=False, create=False) if user and await user.is_logged_in(): await portal.mark_read(user, event_id) @staticmethod async def handle_presence(user_id: UserID, presence: PresenceState) -> None: - user = u.User.get_by_mxid(user_id, check_db=False, create=False) + user = await u.User.get_by_mxid(user_id, check_db=False, create=False) if user and await user.is_logged_in(): await user.set_presence(presence == PresenceState.ONLINE) - async def handle_typing(self, room_id: RoomID, now_typing: Set[UserID]) -> None: - portal = po.Portal.get_by_mxid(room_id) + async def handle_typing(self, room_id: RoomID, now_typing: set[UserID]) -> None: + portal = await po.Portal.get_by_mxid(room_id) if not portal or not portal.allow_bridging: return - previously_typing = self.previously_typing.get(room_id, set()) + previously_typing = self._previously_typing.get(room_id, set()) for user_id in set(previously_typing | now_typing): is_typing = user_id in now_typing @@ -333,14 +333,15 @@ class MatrixHandler(BaseMatrixHandler): if is_typing and was_typing: continue - user = u.User.get_by_mxid(user_id, check_db=False, create=False) + user = await u.User.get_by_mxid(user_id, check_db=False, create=False) if user and await user.is_logged_in(): await portal.set_typing(user, is_typing) - self.previously_typing[room_id] = now_typing + self._previously_typing[room_id] = now_typing - async def handle_ephemeral_event(self, evt: Union[ReceiptEvent, PresenceEvent, TypingEvent] - ) -> None: + async def handle_ephemeral_event( + self, evt: ReceiptEvent | PresenceEvent | TypingEvent + ) -> None: if evt.type == EventType.RECEIPT: await self.handle_read_receipts(evt.room_id, self.parse_read_receipts(evt.content)) elif evt.type == EventType.PRESENCE: diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py new file mode 100644 index 00000000..ba51f587 --- /dev/null +++ b/mautrix_telegram/portal.py @@ -0,0 +1,2759 @@ +# mautrix-telegram - A Matrix-Telegram puppeting bridge +# Copyright (C) 2021 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from typing import (Awaitable, Any, Iterable, NamedTuple, AsyncGenerator, Callable, Union, List, + TYPE_CHECKING, cast) +from datetime import datetime +from string import Template +from html import escape as escape_html +import asyncio +import random +import mimetypes +import codecs +import unicodedata +import base64 + +import magic +from asyncpg import UniqueViolationError +from sqlite3 import IntegrityError + +from telethon.tl.functions.messages import (AddChatUserRequest, CreateChatRequest, + GetFullChatRequest, MigrateChatRequest, + ExportChatInviteRequest, EditChatPhotoRequest, + EditChatTitleRequest, + UpdatePinnedMessageRequest, SetTypingRequest, + EditChatAboutRequest, UnpinAllMessagesRequest) +from telethon.tl.functions.channels import (CreateChannelRequest, GetParticipantsRequest, + InviteToChannelRequest, UpdateUsernameRequest, + EditPhotoRequest, EditTitleRequest, JoinChannelRequest) +from telethon.errors import (ChatAdminRequiredError, ChatNotModifiedError, PhotoExtInvalidError, + MessageIdInvalidError, + PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, RPCError) +from telethon.tl.patched import Message, MessageService +from telethon.tl.types import ( + ChannelFull, Chat, ChatFull, InputPeerChannel, InputPeerChat, PeerChannel, PeerChat, TypePeer, + TypeUserFull, UserFull, TypeInputChannel, Document, InputPhotoFileLocation, + PhotoSizeProgressive, PhotoSizeEmpty, Channel, ChatBannedRights, ChannelParticipantsRecent, + ChannelParticipantsSearch, ChatPhoto, PhotoEmpty, InputChannel, InputUser, Photo, TypeChat, + TypeUser, User, InputPeerPhotoFileLocation, ChatParticipantAdmin, ChannelParticipantAdmin, + ChatParticipantCreator, ChannelParticipantCreator, UserProfilePhoto, UserProfilePhotoEmpty, + InputPeerUser, ChannelParticipantBanned, Poll, DocumentAttributeSticker, + DocumentAttributeVideo, MessageMediaPoll, MessageActionChannelCreate, MessageActionChatAddUser, + MessageActionChatCreate, MessageActionChatDeletePhoto, MessageActionChatDeleteUser, + MessageActionChatEditTitle, MessageActionChatJoinedByLink, MessageActionChatMigrateTo, + MessageActionGameScore, MessageMediaDocument, MessageMediaPhoto, MessageMediaDice, + MessageMediaGame, MessageMediaUnsupported, PeerUser, PhotoCachedSize, TypeChannelParticipant, + TypeChatParticipant, TypeDocumentAttribute, TypeMessageAction, TypePhotoSize, PhotoSize, + UpdateChatUserTyping, UpdateUserTyping, MessageEntityPre, ChatPhotoEmpty, + DocumentAttributeAnimated, UpdateChannelUserTyping, DocumentAttributeFilename, + DocumentAttributeImageSize, GeoPoint, InputChatUploadedPhoto, MessageActionChatEditPhoto, + MessageMediaGeo, SendMessageCancelAction, SendMessageTypingAction, TypeInputPeer, + UpdateNewMessage, InputMediaUploadedDocument, InputMediaUploadedPhoto, TypeMessage, +) + +from mautrix.errors import MatrixRequestError, IntentError, MForbidden +from mautrix.appservice import IntentAPI, DOUBLE_PUPPET_SOURCE_KEY +from mautrix.types import ( + RoomAlias, ContentURI, RoomID, + RoomCreatePreset, Membership, PowerLevelStateEventContent, + RoomTopicStateEventContent, RoomNameStateEventContent, RoomAvatarStateEventContent, + StateEventContent, JoinRule, + EventID, UserID, ImageInfo, ThumbnailInfo, RelatesTo, MessageType, + EventType, MediaMessageEventContent, TextMessageEventContent, + LocationMessageEventContent, Format, MessageEventContent, VideoInfo, +) +from mautrix.util.simple_template import SimpleTemplate +from mautrix.util.simple_lock import SimpleLock +from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus +from mautrix.bridge import BasePortal, NotificationDisabler, async_getter_lock + +from .types import TelegramID +from .db import Portal as DBPortal, Message as DBMessage, TelegramFile as DBTelegramFile +from .util import sane_mimetypes, parallel_transfer_to_telegram +from .tgclient import MautrixTelegramClient +from .config import Config +from . import puppet as p, user as u, abstract_user as au, util, formatter + +try: + from mautrix.crypto.attachments import decrypt_attachment +except ImportError: + decrypt_attachment = None + +if TYPE_CHECKING: + from .bot import Bot + from .__main__ import TelegramBridge + +StateBridge = EventType.find("m.bridge", EventType.Class.STATE) +StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STATE) + +InviteList = Union[UserID, List[UserID]] +TypeParticipant = Union[TypeChatParticipant, TypeChannelParticipant] +UpdateTyping = Union[UpdateUserTyping, UpdateChatUserTyping, UpdateChannelUserTyping] +TypeChatPhoto = Union[ChatPhoto, ChatPhotoEmpty, Photo, PhotoEmpty] +MediaHandler = Callable[['au.AbstractUser', IntentAPI, Message, RelatesTo], Awaitable[EventID]] + + +class DocAttrs(NamedTuple): + name: str | None + mime_type: str | None + is_sticker: bool + sticker_alt: str | None + width: int + height: int + is_gif: bool + + +class Portal(DBPortal, BasePortal): + bot: 'Bot' + config: Config + + # Instance cache + by_mxid: dict[RoomID, Portal] = {} + by_tgid: dict[tuple[TelegramID, TelegramID], Portal] = {} + + # Config cache + filter_mode: str + filter_list: list[int] + + max_initial_member_sync: int + sync_channel_members: bool + sync_matrix_state: bool + public_portals: bool + private_chat_portal_meta: bool + + alias_template: SimpleTemplate[str] + hs_domain: str + + # Instance variables + deleted: bool + + backfill_lock: SimpleLock + backfill_method_lock: asyncio.Lock + backfill_leave: set[IntentAPI] | None + + alias: RoomAlias | None + + dedup: util.PortalDedup + send_lock: util.PortalSendLock + _pin_lock: asyncio.Lock + + _main_intent: IntentAPI | None + _room_create_lock: asyncio.Lock + + def __init__( + self, + tgid: TelegramID, + tg_receiver: TelegramID, + peer_type: str, + megagroup: bool = False, + mxid: RoomID | None = None, + avatar_url: ContentURI | None = None, + encrypted: bool = False, + username: str | None = None, + title: str | None = None, + about: str | None = None, + photo_id: str | None = None, + local_config: dict[str, Any] | None = None, + ) -> None: + super().__init__( + tgid=tgid, + tg_receiver=tg_receiver, + peer_type=peer_type, + megagroup=megagroup, + mxid=mxid, + avatar_url=avatar_url, + encrypted=encrypted, + username=username, + title=title, + about=about, + photo_id=photo_id, + local_config=local_config or {}, + ) + self.log = self.log.getChild(self.tgid_log if self.tgid else self.mxid) + self._main_intent = None + self.deleted = False + self.backfill_lock = SimpleLock("Waiting for backfilling to finish before handling %s", + log=self.log) + self.backfill_method_lock = asyncio.Lock() + self.backfill_leave = None + + self.dedup = util.PortalDedup(self) + self.send_lock = util.PortalSendLock() + self._pin_lock = asyncio.Lock() + self._room_create_lock = asyncio.Lock() + + # region Properties + + @property + def tgid_full(self) -> tuple[TelegramID, TelegramID]: + return self.tgid, self.tg_receiver + + @property + def tgid_log(self) -> str: + if self.tgid == self.tg_receiver: + return str(self.tgid) + return f"{self.tg_receiver}<->{self.tgid}" + + @property + def name(self) -> str: + return self.title + + @property + def alias(self) -> RoomAlias | None: + if not self.username: + return None + return RoomAlias(f"#{self.alias_localpart}:{self.hs_domain}") + + @property + def alias_localpart(self) -> str | None: + if not self.username: + return None + return self.alias_template.format(self.username) + + @property + def peer(self) -> TypePeer | TypeInputPeer: + if self.peer_type == "user": + return PeerUser(user_id=self.tgid) + elif self.peer_type == "chat": + return PeerChat(chat_id=self.tgid) + elif self.peer_type == "channel": + return PeerChannel(channel_id=self.tgid) + + @property + def is_direct(self) -> bool: + return self.peer_type == "user" + + @property + def has_bot(self) -> bool: + return (bool(self.bot) + and (self.bot.is_in_chat(self.tgid) + or (self.peer_type == "user" and self.tg_receiver == self.bot.tgid))) + + @property + def main_intent(self) -> IntentAPI: + if self._main_intent is None: + raise RuntimeError("Portal must be postinit()ed before main_intent can be used") + return self._main_intent + + @property + def allow_bridging(self) -> bool: + if self.peer_type == "user": + return True + elif self.filter_mode == "whitelist": + return self.tgid in self.filter_list + elif self.filter_mode == "blacklist": + return self.tgid not in self.filter_list + return True + + # endregion + + @classmethod + def init_cls(cls, bridge: 'TelegramBridge') -> None: + BasePortal.bridge = bridge + cls.az = bridge.az + cls.config = bridge.config + cls.loop = bridge.loop + cls.matrix = bridge.matrix + cls.bot = bridge.bot + + cls.max_initial_member_sync = cls.config["bridge.max_initial_member_sync"] + cls.sync_channel_members = cls.config["bridge.sync_channel_members"] + cls.sync_matrix_state = cls.config["bridge.sync_matrix_state"] + cls.public_portals = cls.config["bridge.public_portals"] + cls.private_chat_portal_meta = cls.config["bridge.private_chat_portal_meta"] + cls.filter_mode = cls.config["bridge.filter.mode"] + cls.filter_list = cls.config["bridge.filter.list"] + cls.hs_domain = cls.config["homeserver.domain"] + cls.alias_template = SimpleTemplate(cls.config["bridge.alias_template"], "groupname", + prefix="#", suffix=f":{cls.hs_domain}") + NotificationDisabler.puppet_cls = p.Puppet + NotificationDisabler.config_enabled = cls.config["bridge.backfill.disable_notifications"] + util.PortalDedup.dedup_pre_db_check = cls.config["bridge.deduplication.pre_db_check"] + util.PortalDedup.dedup_cache_queue_length = cls.config["bridge.deduplication.cache_queue_length"] + + # region Matrix -> Telegram metadata + + async def get_telegram_users_in_matrix_room( + self, source: u.User + ) -> tuple[list[InputPeerUser], list[UserID]]: + user_tgids = {} + user_mxids = await self.main_intent.get_room_members(self.mxid, (Membership.JOIN, + Membership.INVITE)) + for mxid in user_mxids: + if mxid == self.az.bot_mxid: + continue + mx_user = await u.User.get_by_mxid(mxid, create=False) + if mx_user and mx_user.tgid: + user_tgids[mx_user.tgid] = mxid + puppet_id = p.Puppet.get_id_from_mxid(mxid) + if puppet_id: + user_tgids[puppet_id] = mxid + input_users = [] + errors = [] + for tgid, mxid in user_tgids.items(): + try: + input_users.append(await source.client.get_input_entity(tgid)) + except ValueError as e: + source.log.debug(f"Failed to find the input entity for {tgid} ({mxid}) for " + f"creating a group: {e}") + errors.append(mxid) + return input_users, errors + + async def upgrade_telegram_chat(self, source: u.User) -> None: + if self.peer_type != "chat": + raise ValueError("Only normal group chats are upgradable to supergroups.") + + response = await source.client(MigrateChatRequest(chat_id=self.tgid)) + entity = None + for chat in response.chats: + if isinstance(chat, Channel): + entity = chat + break + if not entity: + raise ValueError("Upgrade may have failed: output channel not found.") + await self._migrate_and_save_telegram(TelegramID(entity.id)) + await self.update_info(source, entity) + + async def _migrate_and_save_telegram(self, new_id: TelegramID) -> None: + try: + del self.by_tgid[self.tgid_full] + except KeyError: + pass + try: + existing = self.by_tgid[(new_id, new_id)] + except KeyError: + existing = None + self.by_tgid[(new_id, new_id)] = self + if existing: + await existing.delete() + old_id = self.tgid + await self.update_id(new_id, "channel") + self.log = self.__class__.log.getChild(self.tgid_log) + self.log.info(f"Telegram chat upgraded from {old_id}") + + async def set_telegram_username(self, source: u.User, username: str) -> None: + if self.peer_type != "channel": + raise ValueError("Only channels and supergroups have usernames.") + await source.client( + UpdateUsernameRequest(await self.get_input_entity(source), username)) + if await self._update_username(username): + await self.save() + + async def create_telegram_chat( + self, source: u.User, invites: list[InputUser], supergroup: bool = False + ) -> None: + if not self.mxid: + raise ValueError("Can't create Telegram chat for portal without Matrix room.") + elif self.tgid: + raise ValueError("Can't create Telegram chat for portal with existing Telegram chat.") + + if len(invites) < 2: + if self.bot is not None: + info, mxid = await self.bot.get_me() + raise ValueError("Not enough Telegram users to create a chat. " + "Invite more Telegram ghost users to the room, such as the " + f"relaybot ([{info.first_name}](https://matrix.to/#/{mxid})).") + raise ValueError("Not enough Telegram users to create a chat. " + "Invite more Telegram ghost users to the room.") + if self.peer_type == "chat": + response = await source.client(CreateChatRequest(title=self.title, users=invites)) + entity = response.chats[0] + elif self.peer_type == "channel": + response = await source.client(CreateChannelRequest(title=self.title, + about=self.about or "", + megagroup=supergroup)) + entity = response.chats[0] + await source.client(InviteToChannelRequest( + channel=await source.client.get_input_entity(entity), + users=invites)) + else: + raise ValueError("Invalid peer type for Telegram chat creation") + + self.tgid = entity.id + self.tg_receiver = self.tgid + await self.postinit() + await self.insert() + await self.update_info(source, entity) + self.log = self.__class__.log.getChild(self.tgid_log) + + if self.bot and self.bot.tgid in invites: + await self.bot.add_chat(self.tgid, self.peer_type) + + levels = await self.main_intent.get_power_levels(self.mxid) + if levels.get_user_level(self.main_intent.mxid) == 100: + levels = self._get_base_power_levels(levels, entity) + await self.main_intent.set_power_levels(self.mxid, levels) + await self.handle_matrix_power_levels(source, levels.users, {}, None) + await self.update_bridge_info() + + async def invite_telegram( + self, source: u.User, puppet: p.Puppet | au.AbstractUser + ) -> None: + if self.peer_type == "chat": + await source.client( + AddChatUserRequest(chat_id=self.tgid, user_id=puppet.tgid, fwd_limit=0)) + elif self.peer_type == "channel": + await source.client(InviteToChannelRequest(channel=self.peer, users=[puppet.tgid])) + # We don't care if there are invites for private chat portals with the relaybot. + elif not self.bot or self.tg_receiver != self.bot.tgid: + raise ValueError("Invalid peer type for Telegram user invite") + + # endregion + # region Telegram -> Matrix metadata + + def _get_invite_content(self, double_puppet: p.Puppet | None) -> dict[str, Any]: + invite_content = {} + if double_puppet: + invite_content["fi.mau.will_auto_accept"] = True + if self.is_direct: + invite_content["is_direct"] = True + return invite_content + + async def invite_to_matrix(self, users: InviteList) -> None: + if isinstance(users, list): + for user in users: + await self.invite_to_matrix(user) + else: + puppet = await p.Puppet.get_by_custom_mxid(users) + await self.main_intent.invite_user(self.mxid, users, check_cache=True, + extra_content=self._get_invite_content(puppet)) + if puppet: + try: + await puppet.intent.ensure_joined(self.mxid) + except Exception: + self.log.exception("Failed to ensure %s is joined to portal", users) + + async def update_matrix_room(self, user: au.AbstractUser, entity: TypeChat | User, + direct: bool = None, puppet: p.Puppet = None, + levels: PowerLevelStateEventContent = None, + users: list[User] = None) -> None: + if direct is None: + direct = self.peer_type == "user" + try: + await self._update_matrix_room(user, entity, direct, puppet, levels, users) + except Exception: + self.log.exception("Fatal error updating Matrix room") + + async def _update_matrix_room(self, user: au.AbstractUser, entity: TypeChat | User, + direct: bool, puppet: p.Puppet = None, + levels: PowerLevelStateEventContent = None, + users: list[User] = None) -> None: + if not direct: + await self.update_info(user, entity) + if not users: + users = await self._get_users(user, entity) + await self._sync_telegram_users(user, users) + await self.update_power_levels(users, levels) + else: + if not puppet: + puppet = await p.Puppet.get_by_tgid(self.tgid) + await puppet.update_info(user, entity) + await puppet.intent_for(self).join_room(self.mxid) + if self.encrypted or self.private_chat_portal_meta: + # The bridge bot needs to join for e2ee, but that messes up the default name + # generation. If/when canonical DMs happen, this might not be necessary anymore. + changed = await self._update_title(puppet.displayname) + changed = await self._update_avatar(user, entity.photo) or changed + if changed: + await self.save() + await self.update_bridge_info() + + puppet = await p.Puppet.get_by_custom_mxid(user.mxid) + if puppet: + try: + did_join = await puppet.intent.ensure_joined(self.mxid) + if isinstance(user, u.User) and did_join and self.peer_type == "user": + await user.update_direct_chats({self.main_intent.mxid: [self.mxid]}) + except Exception: + self.log.exception("Failed to ensure %s is joined to portal", user.mxid) + + if self.sync_matrix_state: + await self.main_intent.get_joined_members(self.mxid) + + async def create_matrix_room(self, user: au.AbstractUser, entity: TypeChat | User = None, + invites: InviteList = None, update_if_exists: bool = True + ) -> RoomID | None: + if self.mxid: + if update_if_exists: + if not entity: + try: + entity = await self.get_entity(user) + except Exception: + self.log.exception(f"Failed to get entity through {user.tgid} for update") + return self.mxid + update = self.update_matrix_room(user, entity, self.peer_type == "user") + self.loop.create_task(update) + await self.invite_to_matrix(invites or []) + return self.mxid + async with self._room_create_lock: + try: + return await self._create_matrix_room(user, entity, invites) + except Exception: + self.log.exception("Fatal error creating Matrix room") + + @property + def bridge_info_state_key(self) -> str: + return f"net.maunium.telegram://telegram/{self.tgid}" + + @property + def bridge_info(self) -> dict[str, Any]: + info = { + "bridgebot": self.az.bot_mxid, + "creator": self.main_intent.mxid, + "protocol": { + "id": "telegram", + "displayname": "Telegram", + "avatar_url": self.config["appservice.bot_avatar"], + "external_url": "https://telegram.org", + }, + "channel": { + "id": str(self.tgid), + "displayname": self.title, + "avatar_url": self.avatar_url, + } + } + if self.username: + info["channel"]["external_url"] = f"https://t.me/{self.username}" + elif self.peer_type == "user": + # TODO this doesn't feel very reliable + puppet = p.Puppet.by_tgid.get(self.tgid, None) + if puppet and puppet.username: + info["channel"]["external_url"] = f"https://t.me/{puppet.username}" + return info + + async def update_bridge_info(self) -> None: + if not self.mxid: + self.log.debug("Not updating bridge info: no Matrix room created") + return + try: + self.log.debug("Updating bridge info...") + await self.main_intent.send_state_event(self.mxid, StateBridge, + self.bridge_info, self.bridge_info_state_key) + # TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec + await self.main_intent.send_state_event(self.mxid, StateHalfShotBridge, + self.bridge_info, self.bridge_info_state_key) + except Exception: + self.log.warning("Failed to update bridge info", exc_info=True) + + async def _create_matrix_room(self, user: au.AbstractUser, entity: TypeChat | User, + invites: InviteList) -> RoomID | None: + if self.mxid: + return self.mxid + elif not self.allow_bridging: + return None + + direct = self.peer_type == "user" + invites = invites or [] + + if not entity: + entity = await self.get_entity(user) + self.log.trace("Fetched data: %s", entity) + + self.log.debug("Creating room") + + try: + self.title = entity.title + except AttributeError: + self.title = None + + if direct and self.tgid == user.tgid: + self.title = "Telegram Saved Messages" + self.about = "Your Telegram cloud storage chat" + + puppet = await p.Puppet.get_by_tgid(self.tgid) if direct else None + if puppet: + await puppet.update_info(user, entity) + self._main_intent = puppet.intent_for(self) if direct else self.az.intent + + if self.peer_type == "channel": + self.megagroup = entity.megagroup + + preset = RoomCreatePreset.PRIVATE + if self.peer_type == "channel" and entity.username: + if self.public_portals: + preset = RoomCreatePreset.PUBLIC + self.username = entity.username + alias = self.alias_localpart + else: + # TODO invite link alias? + alias = None + + if alias: + # TODO? properly handle existing room aliases + await self.main_intent.remove_room_alias(alias) + + power_levels = self._get_base_power_levels(entity=entity) + users = None + if not direct: + users = await self._get_users(user, entity) + if self.has_bot: + extra_invites = self.config["bridge.relaybot.group_chat_invite"] + invites += extra_invites + for invite in extra_invites: + power_levels.users.setdefault(invite, 100) + await self._participants_to_power_levels(users, power_levels) + elif self.bot and self.tg_receiver == self.bot.tgid: + invites = self.config["bridge.relaybot.private_chat.invite"] + for invite in invites: + power_levels.users.setdefault(invite, 100) + self.title = puppet.displayname + + initial_state = [{ + "type": EventType.ROOM_POWER_LEVELS.serialize(), + "content": power_levels.serialize(), + }, { + "type": str(StateBridge), + "state_key": self.bridge_info_state_key, + "content": self.bridge_info, + }, { + # TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec + "type": str(StateHalfShotBridge), + "state_key": self.bridge_info_state_key, + "content": self.bridge_info, + }] + create_invites = [] + if self.config["bridge.encryption.default"] and self.matrix.e2ee: + self.encrypted = True + initial_state.append({ + "type": "m.room.encryption", + "content": {"algorithm": "m.megolm.v1.aes-sha2"}, + }) + if direct: + create_invites.append(self.az.bot_mxid) + if direct and (self.encrypted or self.private_chat_portal_meta): + self.title = puppet.displayname + if self.config["appservice.community_id"]: + initial_state.append({ + "type": "m.room.related_groups", + "content": {"groups": [self.config["appservice.community_id"]]}, + }) + creation_content = {} + if not self.config["bridge.federate_rooms"]: + creation_content["m.federate"] = False + + with self.backfill_lock: + room_id = await self.main_intent.create_room(alias_localpart=alias, preset=preset, + is_direct=direct, invitees=create_invites, + name=self.title, topic=self.about, + initial_state=initial_state, + creation_content=creation_content) + if not room_id: + raise Exception(f"Failed to create room") + + if self.encrypted and self.matrix.e2ee and direct: + try: + await self.az.intent.ensure_joined(room_id) + except Exception: + self.log.warning(f"Failed to add bridge bot to new private chat {room_id}") + + self.mxid = room_id + self.by_mxid[self.mxid] = self + await self.save() + await self.az.state_store.set_power_levels(self.mxid, power_levels) + await user.register_portal(self) + + await self.invite_to_matrix(invites) + + update_room = self.loop.create_task(self.update_matrix_room( + user, entity, direct, puppet, + levels=power_levels, users=users)) + + if self.config["bridge.backfill.initial_limit"] > 0: + self.log.debug("Initial backfill is enabled. Waiting for room members to sync " + "and then starting backfill") + await update_room + + try: + if isinstance(user, u.User): + await self.backfill(user, is_initial=True) + except Exception: + self.log.exception("Failed to backfill new portal") + + return self.mxid + + def _get_base_power_levels(self, levels: PowerLevelStateEventContent = None, + entity: TypeChat = None) -> PowerLevelStateEventContent: + levels = levels or PowerLevelStateEventContent() + if self.peer_type == "user": + overrides = self.config["bridge.initial_power_level_overrides.user"] + levels.ban = overrides.get("ban", 100) + levels.kick = overrides.get("kick", 100) + levels.invite = overrides.get("invite", 100) + levels.redact = overrides.get("redact", 0) + levels.events[EventType.ROOM_NAME] = 0 + levels.events[EventType.ROOM_AVATAR] = 0 + levels.events[EventType.ROOM_TOPIC] = 0 + levels.state_default = overrides.get("state_default", 0) + levels.users_default = overrides.get("users_default", 0) + levels.events_default = overrides.get("events_default", 0) + else: + overrides = self.config["bridge.initial_power_level_overrides.group"] + dbr = entity.default_banned_rights + if not dbr: + self.log.debug(f"default_banned_rights is None in {entity}") + dbr = ChatBannedRights(invite_users=True, change_info=True, pin_messages=True, + send_stickers=False, send_messages=False, until_date=None) + levels.ban = overrides.get("ban", 50) + levels.kick = overrides.get("kick", 50) + levels.redact = overrides.get("redact", 50) + levels.invite = overrides.get("invite", 50 if dbr.invite_users else 0) + levels.events[EventType.ROOM_ENCRYPTION] = 50 if self.matrix.e2ee else 99 + levels.events[EventType.ROOM_TOMBSTONE] = 99 + levels.events[EventType.ROOM_NAME] = 50 if dbr.change_info else 0 + levels.events[EventType.ROOM_AVATAR] = 50 if dbr.change_info else 0 + levels.events[EventType.ROOM_TOPIC] = 50 if dbr.change_info else 0 + levels.events[EventType.ROOM_PINNED_EVENTS] = 50 if dbr.pin_messages else 0 + levels.events[EventType.ROOM_POWER_LEVELS] = 75 + levels.events[EventType.ROOM_HISTORY_VISIBILITY] = 75 + levels.events[EventType.STICKER] = 50 if dbr.send_stickers else levels.events_default + levels.state_default = overrides.get("state_default", 50) + levels.users_default = overrides.get("users_default", 0) + levels.events_default = ( + overrides.get("events_default", + 50 if (self.peer_type == "channel" and not entity.megagroup + or entity.default_banned_rights.send_messages) + else 0)) + for evt_type, value in overrides.get("events", {}).items(): + levels.events[EventType.find(evt_type)] = value + levels.users = overrides.get("users", {}) + if self.main_intent.mxid not in levels.users: + levels.users[self.main_intent.mxid] = 100 + return levels + + @classmethod + def _get_level_from_participant(cls, participant: TypeParticipant, + levels: PowerLevelStateEventContent) -> int: + # TODO use the power level requirements to get better precision in channels + if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)): + return levels.state_default or 50 + elif isinstance(participant, (ChatParticipantCreator, ChannelParticipantCreator)): + return levels.get_user_level(cls.az.bot_mxid) - 5 + return levels.users_default or 0 + + @staticmethod + def _participant_to_power_levels(levels: PowerLevelStateEventContent, + user: u.User | p.Puppet, new_level: int, + bot_level: int) -> bool: + new_level = min(new_level, bot_level) + user_level = levels.get_user_level(user.mxid) + if user_level != new_level and user_level < bot_level: + levels.users[user.mxid] = new_level + return True + return False + + async def _participants_to_power_levels(self, users: list[TypeUser | TypeParticipant], + levels: PowerLevelStateEventContent) -> bool: + bot_level = levels.get_user_level(self.main_intent.mxid) + if bot_level < levels.get_event_level(EventType.ROOM_POWER_LEVELS): + return False + changed = False + admin_power_level = min(75 if self.peer_type == "channel" else 50, bot_level) + if levels.get_event_level(EventType.ROOM_POWER_LEVELS) != admin_power_level: + changed = True + levels.events[EventType.ROOM_POWER_LEVELS] = admin_power_level + + for user in users: + # The User objects we get from TelegramClient.get_participants have a custom + # participant property + participant = getattr(user, "participant", user) + + puppet = await p.Puppet.get_by_tgid(TelegramID(participant.user_id)) + user = await u.User.get_by_tgid(TelegramID(participant.user_id)) + new_level = self._get_level_from_participant(participant, levels) + + if user: + await user.register_portal(self) + changed = self._participant_to_power_levels(levels, user, new_level, + bot_level) or changed + + if puppet: + changed = self._participant_to_power_levels(levels, puppet, new_level, + bot_level) or changed + return changed + + async def update_power_levels(self, users: list[TypeUser | TypeParticipant], + levels: PowerLevelStateEventContent = None) -> None: + if not levels: + levels = await self.main_intent.get_power_levels(self.mxid) + if await self._participants_to_power_levels(users, levels): + await self.main_intent.set_power_levels(self.mxid, levels) + + async def _add_bot_chat(self, bot: User) -> None: + if self.bot and bot.id == self.bot.tgid: + await self.bot.add_chat(self.tgid, self.peer_type) + return + + user = await u.User.get_by_tgid(TelegramID(bot.id)) + if user and user.is_bot: + await user.register_portal(self) + + async def _sync_telegram_users(self, source: au.AbstractUser, users: list[User]) -> None: + allowed_tgids = set() + skip_deleted = self.config["bridge.skip_deleted_members"] + for entity in users: + puppet = await p.Puppet.get_by_tgid(TelegramID(entity.id)) + if entity.bot: + await self._add_bot_chat(entity) + allowed_tgids.add(entity.id) + + await puppet.update_info(source, entity) + if skip_deleted and entity.deleted: + continue + + await puppet.intent_for(self).ensure_joined(self.mxid) + + user = await u.User.get_by_tgid(TelegramID(entity.id)) + if user: + await self.invite_to_matrix(user.mxid) + + # We can't trust the member list if any of the following cases is true: + # * There are close to 10 000 users, because Telegram might not be sending all members. + # * The member sync count is limited, because then we might ignore some members. + # * It's a channel, because non-admins don't have access to the member list. + trust_member_list = ((len(allowed_tgids) < 9900 + if self.max_initial_member_sync < 0 + else len(allowed_tgids) < self.max_initial_member_sync - 10) + and (self.megagroup or self.peer_type != "channel")) + if not trust_member_list: + return + + for user_mxid in await self.main_intent.get_room_members(self.mxid): + if user_mxid == self.az.bot_mxid: + continue + + puppet_id = p.Puppet.get_id_from_mxid(user_mxid) + if puppet_id: + if puppet_id in allowed_tgids: + continue + if self.bot and puppet_id == self.bot.tgid: + await self.bot.remove_chat(self.tgid) + try: + await self.main_intent.kick_user(self.mxid, user_mxid, + "User had left this Telegram chat.") + except MForbidden: + pass + continue + + mx_user = await u.User.get_by_mxid(user_mxid, create=False) + if mx_user: + if mx_user.tgid in allowed_tgids: + continue + if mx_user.is_bot: + await mx_user.unregister_portal(*self.tgid_full) + if not self.has_bot: + try: + await self.main_intent.kick_user(self.mxid, mx_user.mxid, + "You had left this Telegram chat.") + except MForbidden: + pass + + async def _add_telegram_user(self, user_id: TelegramID, source: au.AbstractUser | None = None + ) -> None: + puppet = await p.Puppet.get_by_tgid(user_id) + if source: + entity: User = await source.client.get_entity(PeerUser(user_id)) + await puppet.update_info(source, entity) + await puppet.intent_for(self).ensure_joined(self.mxid) + + user = await u.User.get_by_tgid(user_id) + if user: + await user.register_portal(self) + await self.invite_to_matrix(user.mxid) + + async def _delete_telegram_user(self, user_id: TelegramID, sender: p.Puppet) -> None: + puppet = await p.Puppet.get_by_tgid(user_id) + user = await u.User.get_by_tgid(user_id) + kick_message = (f"Kicked by {sender.displayname}" + if sender and sender.tgid != puppet.tgid + else "Left Telegram chat") + puppet_extra_content = None + if sender.is_real_user: + puppet_extra_content = {DOUBLE_PUPPET_SOURCE_KEY: self.bridge.name} + if sender.tgid != puppet.tgid: + try: + await sender.intent_for(self).kick_user(self.mxid, puppet.mxid, extra_content=puppet_extra_content) + except MForbidden: + await self.main_intent.kick_user(self.mxid, puppet.mxid, kick_message) + else: + await puppet.intent_for(self).leave_room(self.mxid, extra_content=puppet_extra_content) + if user: + await user.unregister_portal(*self.tgid_full) + if sender.tgid != puppet.tgid: + try: + await sender.intent_for(self).kick_user(self.mxid, user.mxid, extra_content=puppet_extra_content) + return + except MForbidden: + pass + try: + await self.main_intent.kick_user(self.mxid, user.mxid, kick_message) + except MForbidden as e: + self.log.warning(f"Failed to kick {user.mxid}: {e}") + + async def update_info(self, user: au.AbstractUser, entity: TypeChat = None) -> None: + if self.peer_type == "user": + self.log.warning("Called update_info() for direct chat portal") + return + + changed = False + self.log.debug("Updating info") + try: + if not entity: + entity = await self.get_entity(user) + self.log.trace("Fetched data: %s", entity) + + if self.peer_type == "channel": + changed = self.megagroup != entity.megagroup or changed + self.megagroup = entity.megagroup + changed = await self._update_username(entity.username) or changed + + if hasattr(entity, "about"): + changed = self._update_about(entity.about) or changed + + changed = await self._update_title(entity.title) or changed + + if isinstance(entity.photo, ChatPhoto): + changed = await self._update_avatar(user, entity.photo) or changed + except Exception: + self.log.exception(f"Failed to update info from source {user.tgid}") + + if changed: + await self.save() + await self.update_bridge_info() + + async def _update_username(self, username: str, save: bool = False) -> bool: + if self.username == username: + return False + + if self.username: + await self.main_intent.remove_room_alias(self.alias_localpart) + self.username = username or None + if self.username: + await self.main_intent.add_room_alias(self.mxid, self.alias_localpart, override=True) + if self.public_portals: + await self.main_intent.set_join_rule(self.mxid, JoinRule.PUBLIC) + else: + await self.main_intent.set_join_rule(self.mxid, JoinRule.INVITE) + + if save: + await self.save() + return True + + async def _try_set_state(self, sender: p.Puppet | None, evt_type: EventType, + content: StateEventContent) -> None: + if sender: + try: + intent = sender.intent_for(self) + if sender.is_real_user: + content[DOUBLE_PUPPET_SOURCE_KEY] = self.bridge.name + await intent.send_state_event(self.mxid, evt_type, content) + except MForbidden: + await self.main_intent.send_state_event(self.mxid, evt_type, content) + else: + await self.main_intent.send_state_event(self.mxid, evt_type, content) + + async def _update_about(self, about: str, sender: p.Puppet | None = None, + save: bool = False) -> bool: + if self.about == about: + return False + + self.about = about + await self._try_set_state(sender, EventType.ROOM_TOPIC, + RoomTopicStateEventContent(topic=self.about)) + if save: + await self.save() + return True + + async def _update_title(self, title: str, sender: p.Puppet | None = None, + save: bool = False) -> bool: + if self.title == title: + return False + + self.title = title + await self._try_set_state(sender, EventType.ROOM_NAME, + RoomNameStateEventContent(name=self.title)) + if save: + await self.save() + return True + + async def _update_avatar(self, user: au.AbstractUser, photo: TypeChatPhoto, + sender: p.Puppet | None = None, save: bool = False) -> bool: + if isinstance(photo, (ChatPhoto, UserProfilePhoto)): + loc = InputPeerPhotoFileLocation( + peer=await self.get_input_entity(user), + photo_id=photo.photo_id, + big=True + ) + photo_id = str(photo.photo_id) + elif isinstance(photo, Photo): + loc, _ = self._get_largest_photo_size(photo) + photo_id = str(loc.id) + elif isinstance(photo, (UserProfilePhotoEmpty, ChatPhotoEmpty, PhotoEmpty, type(None))): + photo_id = "" + loc = None + else: + raise ValueError(f"Unknown photo type {type(photo)}") + if ( + self.peer_type == "user" and not photo_id + and not self.config["bridge.allow_avatar_remove"] + ): + return False + if self.photo_id != photo_id: + if not photo_id: + await self._try_set_state(sender, EventType.ROOM_AVATAR, + RoomAvatarStateEventContent(url=None)) + self.photo_id = "" + self.avatar_url = None + if save: + await self.save() + return True + file = await util.transfer_file_to_matrix(user.client, self.main_intent, loc) + if file: + await self._try_set_state(sender, EventType.ROOM_AVATAR, + RoomAvatarStateEventContent(url=file.mxc)) + self.photo_id = photo_id + self.avatar_url = file.mxc + if save: + await self.save() + return True + return False + + @staticmethod + def _filter_participants(users: list[TypeUser], participants: list[TypeParticipant] + ) -> Iterable[TypeUser]: + participant_map = {part.user_id: part for part in participants + if not isinstance(part, ChannelParticipantBanned)} + for user in users: + try: + user.participant = participant_map[user.id] + except KeyError: + pass + else: + yield user + + async def _get_channel_users(self, user: au.AbstractUser, entity: InputChannel, limit: int + ) -> list[TypeUser]: + if 0 < limit <= 200: + response = await user.client(GetParticipantsRequest( + entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0)) + return list(self._filter_participants(response.users, response.participants)) + elif limit > 200 or limit == -1: + users: list[TypeUser] = [] + offset = 0 + remaining_quota = limit if limit > 0 else 1000000 + query = (ChannelParticipantsSearch("") if limit == -1 + else ChannelParticipantsRecent()) + while True: + if remaining_quota <= 0: + break + response = await user.client(GetParticipantsRequest( + entity, query, offset=offset, limit=min(remaining_quota, 200), hash=0)) + if not response.users: + break + users += self._filter_participants(response.users, response.participants) + offset += len(response.participants) + remaining_quota -= len(response.participants) + return users + + async def _get_users(self, user: au.AbstractUser, + entity: TypeInputPeer | InputUser | TypeChat | TypeUser | InputChannel + ) -> list[TypeUser]: + limit = self.max_initial_member_sync + if self.peer_type == "chat": + chat = await user.client(GetFullChatRequest(chat_id=self.tgid)) + return list( + self._filter_participants(chat.users, chat.full_chat.participants.participants) + )[:limit] + elif self.peer_type == "channel": + if not self.megagroup and not self.sync_channel_members: + return [] + + if limit == 0: + return [] + + try: + return await self._get_channel_users(user, entity, limit) + except ChatAdminRequiredError: + return [] + elif self.peer_type == "user": + return [entity] + else: + raise RuntimeError(f"Unexpected peer type {self.peer_type}") + + # endregion + # region Matrix -> Telegram bridging + + async def _send_delivery_receipt(self, event_id: EventID, room_id: RoomID | None = None + ) -> None: + # TODO maybe check if the bot is in the room rather than assuming based on self.encrypted + if event_id and self.config["bridge.delivery_receipts"] and (self.encrypted + or self.peer_type != "user"): + try: + await self.az.intent.mark_read(room_id or self.mxid, event_id) + except Exception: + self.log.exception("Failed to send delivery receipt for %s", event_id) + + async def _get_state_change_message( + self, event: str, user: u.User, **kwargs: Any + ) -> str | None: + tpl = self.get_config(f"state_event_formats.{event}") + if len(tpl) == 0: + # Empty format means they don't want the message + return None + displayname = await self.get_displayname(user) + + tpl_args = { + "mxid": user.mxid, + "username": user.mxid_localpart, + "displayname": escape_html(displayname), + **kwargs, + } + return Template(tpl).safe_substitute(tpl_args) + + async def _send_state_change_message(self, event: str, user: u.User, event_id: EventID, + **kwargs: Any) -> None: + if not self.has_bot: + return + elif ( + self.peer_type == "user" + and not self.config["bridge.relaybot.private_chat.state_changes"] + ): + return + async with self.send_lock(self.bot.tgid): + message = await self._get_state_change_message(event, user, **kwargs) + if not message: + return + message, entities = await formatter.matrix_to_telegram(self.bot.client, html=message) + response = await self.bot.client.send_message(self.peer, message, + formatting_entities=entities) + space = self.tgid if self.peer_type == "channel" else self.bot.tgid + self.dedup.check(response, (event_id, space)) + + async def name_change_matrix(self, user: u.User, displayname: str, prev_displayname: str, + event_id: EventID) -> None: + await self._send_state_change_message("name_change", user, event_id, + displayname=displayname, + prev_displayname=prev_displayname) + + async def get_displayname(self, user: u.User) -> str: + return await self.main_intent.get_room_displayname(self.mxid, user.mxid) or user.mxid + + def set_typing(self, user: u.User, typing: bool = True, + action: type = SendMessageTypingAction) -> Awaitable[bool]: + return user.client(SetTypingRequest( + self.peer, action() if typing else SendMessageCancelAction())) + + async def mark_read(self, user: u.User, event_id: EventID) -> None: + if user.is_bot: + return + space = self.tgid if self.peer_type == "channel" else user.tgid + message = await DBMessage.get_by_mxid(event_id, self.mxid, space) + if not message: + message = await DBMessage.find_last(self.mxid, space) + if not message: + self.log.debug(f"Dropping Matrix read receipt from {user.mxid}: " + f"target message {event_id} not known and last message" + " in chat not found") + return + else: + self.log.debug(f"Matrix read receipt target {event_id} not known, marking " + f"messages up to most recent ({message.mxid}/{message.tgid}) " + f"as read by {user.mxid}/{user.tgid}") + else: + self.log.debug("Handling Matrix read receipt: marking messages up to " + f"{message.mxid}/{message.tgid} as read by {user.mxid}/{user.tgid}") + await user.client.send_read_acknowledge(self.peer, max_id=message.tgid, + clear_mentions=True) + + async def _preproc_kick_ban(self, user: u.User | p.Puppet, source: u.User + ) -> au.AbstractUser | None: + if user.tgid == source.tgid: + return None + if self.peer_type == "user" and user.tgid == self.tgid: + await self.delete() + return None + if isinstance(user, u.User) and await user.needs_relaybot(self): + if not self.bot: + return None + # TODO kick message + return None + if await source.needs_relaybot(self): + if not self.has_bot: + return None + return self.bot + return source + + async def kick_matrix(self, user: u.User | p.Puppet, source: u.User) -> None: + source = await self._preproc_kick_ban(user, source) + if source is not None: + await source.client.kick_participant(self.peer, user.peer) + + async def ban_matrix(self, user: u.User | p.Puppet, source: u.User): + source = await self._preproc_kick_ban(user, source) + if source is not None: + await source.client.edit_permissions(self.peer, user.peer, view_messages=False) + + async def leave_matrix(self, user: u.User, event_id: EventID) -> None: + if await user.needs_relaybot(self): + await self._send_state_change_message("leave", user, event_id) + return + + if self.peer_type == "user": + await self.main_intent.leave_room(self.mxid) + await self.delete() + try: + del self.by_tgid[self.tgid_full] + del self.by_mxid[self.mxid] + except KeyError: + pass + elif self.config["bridge.kick_on_logout"]: + await user.client.delete_dialog(self.peer) + + async def join_matrix(self, user: u.User, event_id: EventID) -> None: + if await user.needs_relaybot(self): + await self._send_state_change_message("join", user, event_id) + return + + if self.peer_type == "channel" and not user.is_bot: + await user.client(JoinChannelRequest(channel=await self.get_input_entity(user))) + else: + # We'll just assume the user is already in the chat. + pass + + async def _apply_msg_format(self, sender: u.User, content: MessageEventContent + ) -> None: + if not isinstance(content, TextMessageEventContent) or content.format != Format.HTML: + content.format = Format.HTML + content.formatted_body = escape_html(content.body).replace("\n", "
") + + tpl = (self.get_config(f"message_formats.[{content.msgtype.value}]") + or "$sender_displayname: $message") + displayname = await self.get_displayname(sender) + tpl_args = dict(sender_mxid=sender.mxid, + sender_username=sender.mxid_localpart, + sender_displayname=escape_html(displayname), + message=content.formatted_body, + body=content.body, formatted_body=content.formatted_body) + content.formatted_body = Template(tpl).safe_substitute(tpl_args) + + async def _apply_emote_format(self, sender: u.User, + content: TextMessageEventContent) -> None: + if content.format != Format.HTML: + content.format = Format.HTML + content.formatted_body = escape_html(content.body).replace("\n", "
") + + tpl = self.get_config("emote_format") + puppet = await p.Puppet.get_by_tgid(sender.tgid) + content.formatted_body = Template(tpl).safe_substitute( + dict(sender_mxid=sender.mxid, + sender_username=sender.mxid_localpart, + sender_displayname=escape_html(await self.get_displayname(sender)), + mention=f"{puppet.displayname}", + username=sender.tg_username, + displayname=puppet.displayname, + body=content.body, + formatted_body=content.formatted_body)) + content.msgtype = MessageType.TEXT + + async def _pre_process_matrix_message(self, sender: u.User, use_relaybot: bool, + content: MessageEventContent) -> None: + if use_relaybot: + await self._apply_msg_format(sender, content) + elif content.msgtype == MessageType.EMOTE: + await self._apply_emote_format(sender, content) + + async def _handle_matrix_text(self, sender: u.User, logged_in: bool, event_id: EventID, + space: TelegramID, client: MautrixTelegramClient, + content: TextMessageEventContent, reply_to: TelegramID | None + ) -> None: + message, entities = await formatter.matrix_to_telegram(client, text=content.body, + html=content.formatted(Format.HTML)) + sender_id = sender.tgid if logged_in else self.bot.tgid + async with self.send_lock(sender_id): + lp = self.get_config("telegram_link_preview") + if content.get_edit(): + orig_msg = await DBMessage.get_by_mxid(content.get_edit(), self.mxid, space) + if orig_msg: + response = await client.edit_message(self.peer, orig_msg.tgid, message, + formatting_entities=entities, + link_preview=lp) + await self._add_telegram_message_to_db(event_id, space, -1, response) + return + try: + response = await client.send_message(self.peer, message, reply_to=reply_to, + formatting_entities=entities, + link_preview=lp) + except Exception: + raise + else: + sender.send_remote_checkpoint( + MessageSendCheckpointStatus.SUCCESS, + event_id, + self.mxid, + EventType.ROOM_MESSAGE, + message_type=content.msgtype, + ) + await self._add_telegram_message_to_db(event_id, space, 0, response) + await self._send_delivery_receipt(event_id) + + async def _handle_matrix_file(self, sender: u.User, logged_in: bool, event_id: EventID, + space: TelegramID, client: MautrixTelegramClient, + content: MediaMessageEventContent, reply_to: TelegramID, + caption: TextMessageEventContent = None) -> None: + sender_id = sender.tgid if logged_in else self.bot.tgid + mime = content.info.mimetype + if isinstance(content.info, (ImageInfo, VideoInfo)): + w, h = content.info.width, content.info.height + else: + w = h = None + file_name = content["net.maunium.telegram.internal.filename"] + max_image_size = self.config["bridge.image_as_file_size"] * 1000 ** 2 + + if self.config["bridge.parallel_file_transfer"] and content.url: + file_handle, file_size = await parallel_transfer_to_telegram(client, self.main_intent, + content.url, sender_id) + else: + if content.file: + if not decrypt_attachment: + raise Exception(f"Can't bridge encrypted media event {event_id}: " + "encryption dependencies not installed") + file = await self.main_intent.download_media(content.file.url) + file = decrypt_attachment(file, content.file.key.key, + content.file.hashes.get("sha256"), content.file.iv) + else: + file = await self.main_intent.download_media(content.url) + + if content.msgtype == MessageType.STICKER: + if mime != "image/gif": + mime, file, w, h = util.convert_image(file, source_mime=mime, + target_type="webp") + else: + # Remove sticker description + file_name = "sticker.gif" + + file_handle = await client.upload_file(file) + file_size = len(file) + + file_handle.name = file_name + + attributes = [DocumentAttributeFilename(file_name=file_name)] + if w and h: + attributes.append(DocumentAttributeImageSize(w, h)) + + if (mime == "image/png" or mime == "image/jpeg") and file_size < max_image_size: + media = InputMediaUploadedPhoto(file_handle) + else: + media = InputMediaUploadedDocument(file=file_handle, attributes=attributes, + mime_type=mime or "application/octet-stream") + + capt, entities = (await formatter.matrix_to_telegram(client, text=caption.body, + html=caption.formatted(Format.HTML)) + if caption else (None, None)) + + async with self.send_lock(sender_id): + if await self._matrix_document_edit(client, content, space, capt, media, event_id): + return + try: + try: + response = await client.send_media(self.peer, media, reply_to=reply_to, + caption=capt, entities=entities) + except ( + PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, PhotoExtInvalidError + ): + media = InputMediaUploadedDocument(file=media.file, mime_type=mime, + attributes=attributes) + response = await client.send_media(self.peer, media, reply_to=reply_to, + caption=capt, entities=entities) + except Exception: + raise + else: + sender.send_remote_checkpoint( + MessageSendCheckpointStatus.SUCCESS, + event_id, + self.mxid, + EventType.ROOM_MESSAGE, + message_type=content.msgtype, + ) + await self._add_telegram_message_to_db(event_id, space, 0, response) + await self._send_delivery_receipt(event_id) + + async def _matrix_document_edit(self, client: MautrixTelegramClient, + content: MessageEventContent, space: TelegramID, + caption: str, media: Any, event_id: EventID) -> bool: + if content.get_edit(): + orig_msg = await DBMessage.get_by_mxid(content.get_edit(), self.mxid, space) + if orig_msg: + response = await client.edit_message(self.peer, orig_msg.tgid, + caption, file=media) + await self._add_telegram_message_to_db(event_id, space, -1, response) + await self._send_delivery_receipt(event_id) + return True + return False + + async def _handle_matrix_location(self, sender: u.User, logged_in: bool, event_id: EventID, + space: TelegramID, client: MautrixTelegramClient, + content: LocationMessageEventContent, reply_to: TelegramID + ) -> None: + sender_id = sender.tgid if logged_in else self.bot.tgid + try: + lat, long = content.geo_uri[len("geo:"):].split(";")[0].split(",") + lat, long = float(lat), float(long) + except (KeyError, ValueError): + self.log.exception("Failed to parse location") + return None + caption, entities = await formatter.matrix_to_telegram(client, text=content.body) + media = MessageMediaGeo(geo=GeoPoint(lat=lat, long=long, access_hash=0)) + + async with self.send_lock(sender_id): + if await self._matrix_document_edit(client, content, space, caption, media, event_id): + return + try: + response = await client.send_media(self.peer, media, reply_to=reply_to, + caption=caption, entities=entities) + except Exception: + raise + else: + await self._add_telegram_message_to_db(event_id, space, 0, response) + sender.send_remote_checkpoint( + MessageSendCheckpointStatus.SUCCESS, + event_id, + self.mxid, + EventType.ROOM_MESSAGE, + message_type=content.msgtype, + ) + await self._send_delivery_receipt(event_id) + + async def _add_telegram_message_to_db(self, event_id: EventID, space: TelegramID, + edit_index: int, response: TypeMessage) -> None: + self.log.trace("Handled Matrix message: %s", response) + self.dedup.check(response, (event_id, space), force_hash=edit_index != 0) + if edit_index < 0: + prev_edit = await DBMessage.get_one_by_tgid(TelegramID(response.id), space, -1) + edit_index = prev_edit.edit_index + 1 + await DBMessage( + tgid=TelegramID(response.id), + tg_space=space, + mx_room=self.mxid, + mxid=event_id, + edit_index=edit_index).insert() + + async def _send_bridge_error(self, sender: u.User, err: Exception, event_id: EventID, + event_type: EventType, + message_type: MessageType | None = None, + msg: str | None = None) -> None: + sender.send_remote_checkpoint( + MessageSendCheckpointStatus.PERM_FAILURE, + event_id, + self.mxid, + event_type, + message_type=message_type, + error=err, + ) + + if msg and self.config["bridge.delivery_error_reports"]: + await self._send_message( + self.main_intent, TextMessageEventContent(msgtype=MessageType.NOTICE, body=msg) + ) + + async def handle_matrix_message(self, sender: u.User, content: MessageEventContent, + event_id: EventID) -> None: + try: + await self._handle_matrix_message(sender, content, event_id) + except RPCError as e: + self.log.exception(f"RPCError while bridging {event_id}: {e}") + await self._send_bridge_error( + sender, + e, + event_id, + EventType.ROOM_MESSAGE, + message_type=content.msgtype, + msg=f"\u26a0 Your message may not have been bridged: {e}", + ) + raise + except Exception as e: + self.log.exception(f"Failed to bridge {event_id}: {e}") + await self._send_bridge_error( + sender, + e, + event_id, + EventType.ROOM_MESSAGE, + message_type=content.msgtype, + ) + + async def _handle_matrix_message(self, sender: u.User, content: MessageEventContent, + event_id: EventID) -> None: + if not content.body or not content.msgtype: + self.log.debug(f"Ignoring message {event_id} in {self.mxid} without body or msgtype") + return + + logged_in = not await sender.needs_relaybot(self) + client = sender.client if logged_in else self.bot.client + space = (self.tgid if self.peer_type == "channel" # Channels have their own ID space + else (sender.tgid if logged_in else self.bot.tgid)) + reply_to = await formatter.matrix_reply_to_telegram(content, space, room_id=self.mxid) + + media = (MessageType.STICKER, MessageType.IMAGE, MessageType.FILE, MessageType.AUDIO, + MessageType.VIDEO) + + if content.msgtype == MessageType.NOTICE: + bridge_notices = self.get_config("bridge_notices.default") + excepted = sender.mxid in self.get_config("bridge_notices.exceptions") + if not bridge_notices and not excepted: + raise Exception("Notices are not configured to be bridged.") + + if content.msgtype in (MessageType.TEXT, MessageType.EMOTE, MessageType.NOTICE): + await self._pre_process_matrix_message(sender, not logged_in, content) + await self._handle_matrix_text(sender, logged_in, event_id, space, client, content, + reply_to) + elif content.msgtype == MessageType.LOCATION: + await self._pre_process_matrix_message(sender, not logged_in, content) + await self._handle_matrix_location(sender, logged_in, event_id, space, client, content, + reply_to) + elif content.msgtype in media: + content["net.maunium.telegram.internal.filename"] = content.body + try: + caption_content: MessageEventContent = sender.command_status["caption"] + reply_to = reply_to or await formatter.matrix_reply_to_telegram( + caption_content, space, room_id=self.mxid + ) + sender.command_status = None + except (KeyError, TypeError): + caption_content = None if logged_in else TextMessageEventContent(body=content.body) + if caption_content: + caption_content.msgtype = content.msgtype + await self._pre_process_matrix_message(sender, not logged_in, caption_content) + await self._handle_matrix_file(sender, logged_in, event_id, space, client, content, + reply_to, caption_content) + else: + self.log.debug( + f"Didn't handle Matrix event {event_id} due to unknown msgtype {content.msgtype}") + self.log.trace("Unhandled Matrix event content: %s", content) + raise Exception(f"Unhandled msgtype {content.msgtype}") + + async def handle_matrix_unpin_all(self, sender: u.User, pin_event_id: EventID) -> None: + await sender.client(UnpinAllMessagesRequest(peer=self.peer)) + await self._send_delivery_receipt(pin_event_id) + + async def handle_matrix_pin(self, sender: u.User, changes: dict[EventID, bool], + pin_event_id: EventID) -> None: + tg_space = self.tgid if self.peer_type == "channel" else sender.tgid + ids = {msg.mxid: msg.tgid + for msg in await DBMessage.get_by_mxids(list(changes.keys()), + mx_room=self.mxid, tg_space=tg_space)} + for event_id, pinned in changes.items(): + try: + await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=ids[event_id], + unpin=not pinned)) + except (ChatNotModifiedError, MessageIdInvalidError, KeyError): + pass + await self._send_delivery_receipt(pin_event_id) + + async def handle_matrix_deletion(self, deleter: u.User, event_id: EventID, + redaction_event_id: EventID) -> None: + try: + await self._handle_matrix_deletion(deleter, event_id) + except Exception as e: + self.log.debug(str(e)) + await self._send_bridge_error(deleter, e, redaction_event_id, EventType.ROOM_REDACTION) + else: + deleter.send_remote_checkpoint( + MessageSendCheckpointStatus.SUCCESS, + redaction_event_id, + self.mxid, + EventType.ROOM_REDACTION, + ) + await self._send_delivery_receipt(redaction_event_id) + + async def _handle_matrix_deletion(self, deleter: u.User, event_id: EventID) -> None: + real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot + space = self.tgid if self.peer_type == "channel" else real_deleter.tgid + message = await DBMessage.get_by_mxid(event_id, self.mxid, space) + if not message: + raise Exception(f"Ignoring Matrix redaction of unknown event {event_id}") + elif message.redacted: + raise Exception("Ignoring Matrix redaction of already redacted event " + f"{message.mxid} in {message.mx_room}") + elif message.edit_index != 0: + await message.mark_redacted() + raise Exception("Ignoring Matrix redaction of edit event " + f"{message.mxid} in {message.mx_room}") + else: + await message.mark_redacted() + await real_deleter.client.delete_messages(self.peer, [message.tgid]) + + async def _update_telegram_power_level(self, sender: u.User, user_id: TelegramID, + level: int) -> None: + moderator = level >= 50 + admin = level >= 75 + await sender.client.edit_admin(self.peer, user_id, + change_info=moderator, post_messages=moderator, + edit_messages=moderator, delete_messages=moderator, + ban_users=moderator, invite_users=moderator, + pin_messages=moderator, add_admins=admin) + + async def handle_matrix_power_levels( + self, sender: u.User, new_users: dict[UserID, int], + old_users: dict[UserID, int], event_id: EventID | None + ) -> None: + # TODO handle all power level changes and bridge exact admin rights to supergroups/channels + for user, level in new_users.items(): + if not user or user == self.main_intent.mxid or user == sender.mxid: + continue + user_id = p.Puppet.get_id_from_mxid(user) + if not user_id: + mx_user = await u.User.get_by_mxid(user, create=False) + if not mx_user or not mx_user.tgid: + continue + user_id = mx_user.tgid + if not user_id or user_id == sender.tgid: + continue + if user not in old_users or level != old_users[user]: + await self._update_telegram_power_level(sender, user_id, level) + + async def handle_matrix_about(self, sender: u.User, about: str, event_id: EventID) -> None: + if self.peer_type not in ("chat", "channel"): + return + peer = await self.get_input_entity(sender) + await sender.client(EditChatAboutRequest(peer=peer, about=about)) + self.about = about + await self.save() + await self._send_delivery_receipt(event_id) + + async def handle_matrix_title(self, sender: u.User, title: str, event_id: EventID) -> None: + if self.peer_type not in ("chat", "channel"): + return + + if self.peer_type == "chat": + response = await sender.client(EditChatTitleRequest(chat_id=self.tgid, title=title)) + else: + channel = await self.get_input_entity(sender) + response = await sender.client(EditTitleRequest(channel=channel, title=title)) + self.dedup.register_outgoing_actions(response) + self.title = title + await self.save() + await self._send_delivery_receipt(event_id) + await self.update_bridge_info() + + async def handle_matrix_avatar(self, sender: u.User, url: ContentURI, event_id: EventID + ) -> None: + if self.peer_type not in ("chat", "channel"): + # Invalid peer type + return + elif self.avatar_url == url: + return + + self.avatar_url = url + file = await self.main_intent.download_media(url) + mime = magic.from_buffer(file, mime=True) + ext = sane_mimetypes.guess_extension(mime) + uploaded = await sender.client.upload_file(file, file_name=f"avatar{ext}") + photo = InputChatUploadedPhoto(file=uploaded) + + if self.peer_type == "chat": + response = await sender.client(EditChatPhotoRequest(chat_id=self.tgid, photo=photo)) + else: + channel = await self.get_input_entity(sender) + response = await sender.client(EditPhotoRequest(channel=channel, photo=photo)) + self.dedup.register_outgoing_actions(response) + for update in response.updates: + is_photo_update = (isinstance(update, UpdateNewMessage) + and isinstance(update.message, MessageService) + and isinstance(update.message.action, MessageActionChatEditPhoto)) + if is_photo_update: + loc, size = self._get_largest_photo_size(update.message.action.photo) + self.photo_id = str(loc.id) + await self.save() + break + await self._send_delivery_receipt(event_id) + await self.update_bridge_info() + + async def handle_matrix_upgrade(self, sender: UserID, new_room: RoomID, event_id: EventID + ) -> None: + _, server = self.main_intent.parse_user_id(sender) + old_room = self.mxid + await self.migrate_and_save_matrix(new_room) + await self.main_intent.join_room(new_room, servers=[server]) + entity: TypeInputPeer | None = None + user: au.AbstractUser | None = None + if self.bot and self.has_bot: + user = self.bot + entity = await self.get_input_entity(self.bot) + if not entity: + user_mxids = await self.main_intent.get_room_members(self.mxid) + for user_str in user_mxids: + user_id = UserID(user_str) + if user_id == self.az.bot_mxid: + continue + user = await u.User.get_by_mxid(user_id, create=False) + if user and user.tgid: + entity = await self.get_input_entity(user) + if entity: + break + if not entity: + self.log.error("Failed to fully migrate to upgraded Matrix room: " + "no Telegram user found.") + return + await self.update_matrix_room(user, entity, direct=self.peer_type == "user") + self.log.info(f"{sender} upgraded room from {old_room} to {self.mxid}") + await self._send_delivery_receipt(event_id, room_id=old_room) + + async def migrate_and_save_matrix(self, new_id: RoomID) -> None: + try: + del self.by_mxid[self.mxid] + except KeyError: + pass + self.mxid = new_id + self.by_mxid[self.mxid] = self + await self.save() + + async def enable_dm_encryption(self) -> bool: + ok = await super().enable_dm_encryption() + if ok: + try: + puppet = await p.Puppet.get_by_tgid(self.tgid) + await self.main_intent.set_room_name(self.mxid, puppet.displayname) + except Exception: + self.log.warning(f"Failed to set room name", exc_info=True) + return ok + + # endregion + # region Telegram -> Matrix bridging + + async def handle_telegram_typing(self, user: p.Puppet, update: UpdateTyping) -> None: + if user.is_real_user: + # Ignore typing notifications from double puppeted users to avoid echoing + return + is_typing = isinstance(update.action, SendMessageTypingAction) + await user.default_mxid_intent.set_typing(self.mxid, is_typing=is_typing) + + def _get_external_url(self, evt: Message) -> str | None: + if self.peer_type == "channel" and self.username is not None: + return f"https://t.me/{self.username}/{evt.id}" + elif self.peer_type != "user": + return f"https://t.me/c/{self.tgid}/{evt.id}" + return None + + async def _expire_telegram_photo(self, intent: IntentAPI, event_id: EventID, ttl: int) -> None: + try: + content = TextMessageEventContent(msgtype=MessageType.NOTICE, body="Photo has expired") + content.set_edit(event_id) + await asyncio.sleep(ttl) + await self._send_message(intent, content) + except Exception: + self.log.warning("Failed to expire Telegram photo %s", event_id, exc_info=True) + + async def _handle_telegram_photo( + self, source: au.AbstractUser, intent: IntentAPI, evt: Message, relates_to: RelatesTo + ) -> EventID | None: + media: MessageMediaPhoto = evt.media + if media.photo is None and media.ttl_seconds: + return await self._send_message(intent, TextMessageEventContent( + msgtype=MessageType.NOTICE, body="Photo has expired")) + loc, largest_size = self._get_largest_photo_size(media.photo) + if loc is None: + content = TextMessageEventContent(msgtype=MessageType.TEXT, + body="Failed to bridge image", + external_url=self._get_external_url(evt)) + return await self._send_message(intent, content, timestamp=evt.date) + file = await util.transfer_file_to_matrix(source.client, intent, loc, + encrypt=self.encrypted) + if not file: + return None + if self.get_config("inline_images") and (evt.message or evt.fwd_from or evt.reply_to): + content = await formatter.telegram_to_matrix( + evt, source, self.main_intent, + prefix_html=f"Inline Telegram photo
", + prefix_text="Inline image: ") + content.external_url = self._get_external_url(evt) + await intent.set_typing(self.mxid, is_typing=False) + return await self._send_message(intent, content, timestamp=evt.date) + info = ImageInfo( + height=largest_size.h, width=largest_size.w, orientation=0, mimetype=file.mime_type, + size=self._photo_size_key(largest_size)) + ext = sane_mimetypes.guess_extension(file.mime_type) + name = f"disappearing_image{ext}" if media.ttl_seconds else f"image{ext}" + await intent.set_typing(self.mxid, is_typing=False) + content = MediaMessageEventContent(msgtype=MessageType.IMAGE, info=info, + body=name, relates_to=relates_to, + external_url=self._get_external_url(evt)) + if file.decryption_info: + content.file = file.decryption_info + else: + content.url = file.mxc + result = await self._send_message(intent, content, timestamp=evt.date) + if media.ttl_seconds: + self.loop.create_task(self._expire_telegram_photo(intent, result, + media.ttl_seconds)) + if evt.message: + caption_content = await formatter.telegram_to_matrix(evt, source, self.main_intent, + no_reply_fallback=True) + caption_content.external_url = content.external_url + result = await self._send_message(intent, caption_content, timestamp=evt.date) + return result + + @staticmethod + def _parse_telegram_document_attributes(attributes: list[TypeDocumentAttribute]) -> DocAttrs: + name, mime_type, is_sticker, sticker_alt, width, height = None, None, False, None, 0, 0 + is_gif = False + for attr in attributes: + if isinstance(attr, DocumentAttributeFilename): + name = name or attr.file_name + mime_type, _ = mimetypes.guess_type(attr.file_name) + elif isinstance(attr, DocumentAttributeSticker): + is_sticker = True + sticker_alt = attr.alt + elif isinstance(attr, DocumentAttributeAnimated): + is_gif = True + elif isinstance(attr, DocumentAttributeVideo): + width, height = attr.w, attr.h + elif isinstance(attr, DocumentAttributeImageSize): + width, height = attr.w, attr.h + return DocAttrs(name, mime_type, is_sticker, sticker_alt, width, height, is_gif) + + @staticmethod + def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: DocAttrs, + thumb_size: TypePhotoSize) -> tuple[ImageInfo, str]: + document = evt.media.document + name = attrs.name + if attrs.is_sticker: + alt = attrs.sticker_alt + if len(alt) > 0: + try: + name = f"{alt} ({unicodedata.name(alt[0]).lower()})" + except ValueError: + name = alt + + generic_types = ("text/plain", "application/octet-stream") + if file.mime_type in generic_types and document.mime_type not in generic_types: + mime_type = document.mime_type or file.mime_type + elif file.mime_type == "application/ogg": + mime_type = "audio/ogg" + else: + mime_type = file.mime_type or document.mime_type + info = ImageInfo(size=file.size, mimetype=mime_type) + + if attrs.mime_type and not file.was_converted: + file.mime_type = attrs.mime_type or file.mime_type + if file.width and file.height: + info.width, info.height = file.width, file.height + elif attrs.width and attrs.height: + info.width, info.height = attrs.width, attrs.height + + if file.thumbnail: + if file.thumbnail.decryption_info: + info.thumbnail_file = file.thumbnail.decryption_info + else: + info.thumbnail_url = file.thumbnail.mxc + info.thumbnail_info = ThumbnailInfo(mimetype=file.thumbnail.mime_type, + height=file.thumbnail.height or thumb_size.h, + width=file.thumbnail.width or thumb_size.w, + size=file.thumbnail.size) + elif attrs.is_sticker: + # This is a hack for bad clients like Element iOS that require a thumbnail + info.thumbnail_info = ImageInfo.deserialize(info.serialize()) + if file.decryption_info: + info.thumbnail_file = file.decryption_info + else: + info.thumbnail_url = file.mxc + + return info, name + + async def _handle_telegram_document( + self, source: au.AbstractUser, intent: IntentAPI, evt: Message, relates_to: RelatesTo + ) -> EventID | None: + document = evt.media.document + + attrs = self._parse_telegram_document_attributes(document.attributes) + + if document.size > self.config["bridge.max_document_size"] * 1000 ** 2: + name = attrs.name or "" + caption = f"\n{evt.message}" if evt.message else "" + # TODO encrypt + return await intent.send_notice(self.mxid, f"Too large file {name}{caption}") + + thumb_loc, thumb_size = self._get_largest_photo_size(document) + if thumb_size and not isinstance(thumb_size, (PhotoSize, PhotoCachedSize)): + self.log.debug(f"Unsupported thumbnail type {type(thumb_size)}") + thumb_loc = None + thumb_size = None + parallel_id = source.tgid if self.config["bridge.parallel_file_transfer"] else None + file = await util.transfer_file_to_matrix( + source.client, intent, document, thumb_loc, + is_sticker=attrs.is_sticker, + tgs_convert=self.config["bridge.animated_sticker"], + filename=attrs.name, parallel_id=parallel_id, + encrypt=self.encrypted + ) + if not file: + return None + + info, name = self._parse_telegram_document_meta(evt, file, attrs, thumb_size) + + await intent.set_typing(self.mxid, is_typing=False) + + event_type = EventType.ROOM_MESSAGE + # Elements only support images as stickers, so send animated webm stickers as m.video + if attrs.is_sticker and file.mime_type.startswith("image/"): + event_type = EventType.STICKER + # Tell clients to render the stickers as 256x256 if they're bigger + if info.width > 256 or info.height > 256: + if info.width > info.height: + info.height = int(info.height / (info.width / 256)) + info.width = 256 + else: + info.width = int(info.width / (info.height / 256)) + info.height = 256 + if info.thumbnail_info: + info.thumbnail_info.width = info.width + info.thumbnail_info.height = info.height + if attrs.is_gif or (attrs.is_sticker and info.mimetype == "video/webm"): + if attrs.is_gif: + info["fi.mau.telegram.gif"] = True + else: + info["fi.mau.telegram.animated_sticker"] = True + info["fi.mau.loop"] = True + info["fi.mau.autoplay"] = True + info["fi.mau.hide_controls"] = True + info["fi.mau.no_audio"] = True + if not name: + ext = sane_mimetypes.guess_extension(file.mime_type) + name = "unnamed_file" + ext + + content = MediaMessageEventContent( + body=name, info=info, relates_to=relates_to, + external_url=self._get_external_url(evt), + msgtype={ + "video/": MessageType.VIDEO, + "audio/": MessageType.AUDIO, + "image/": MessageType.IMAGE, + }.get(info.mimetype[:6], MessageType.FILE)) + if file.decryption_info: + content.file = file.decryption_info + else: + content.url = file.mxc + res = await self._send_message(intent, content, event_type=event_type, timestamp=evt.date) + if evt.message: + caption_content = await formatter.telegram_to_matrix(evt, source, self.main_intent, + no_reply_fallback=True) + caption_content.external_url = content.external_url + res = await self._send_message(intent, caption_content, timestamp=evt.date) + return res + + def _handle_telegram_location( + self, source: au.AbstractUser, intent: IntentAPI, evt: Message, relates_to: RelatesTo + ) -> Awaitable[EventID]: + long = evt.media.geo.long + lat = evt.media.geo.lat + long_char = "E" if long > 0 else "W" + lat_char = "N" if lat > 0 else "S" + geo = f"{round(lat, 6)},{round(long, 6)}" + + body = f"{round(abs(lat), 4)}° {lat_char}, {round(abs(long), 4)}° {long_char}" + url = f"https://maps.google.com/?q={geo}" + + content = LocationMessageEventContent( + msgtype=MessageType.LOCATION, geo_uri=f"geo:{geo}", + body=f"Location: {body}\n{url}", + relates_to=relates_to, external_url=self._get_external_url(evt)) + content["format"] = str(Format.HTML) + content["formatted_body"] = f"Location: {body}" + + return self._send_message(intent, content, timestamp=evt.date) + + async def _handle_telegram_text( + self, source: au.AbstractUser, intent: IntentAPI, is_bot: bool, evt: Message + ) -> EventID: + self.log.trace(f"Sending {evt.message} to {self.mxid} by {intent.mxid}") + content = await formatter.telegram_to_matrix(evt, source, self.main_intent) + content.external_url = self._get_external_url(evt) + if is_bot and self.get_config("bot_messages_as_notices"): + content.msgtype = MessageType.NOTICE + await intent.set_typing(self.mxid, is_typing=False) + return await self._send_message(intent, content, timestamp=evt.date) + + async def _handle_telegram_unsupported( + self, source: au.AbstractUser, intent: IntentAPI, evt: Message, relates_to: RelatesTo + ) -> EventID: + override_text = ("This message is not supported on your version of Mautrix-Telegram. " + "Please check https://github.com/mautrix/telegram or ask your " + "bridge administrator about possible updates.") + content = await formatter.telegram_to_matrix( + evt, source, self.main_intent, override_text=override_text) + content.msgtype = MessageType.NOTICE + content.external_url = self._get_external_url(evt) + content["net.maunium.telegram.unsupported"] = True + await intent.set_typing(self.mxid, is_typing=False) + return await self._send_message(intent, content, timestamp=evt.date) + + async def _handle_telegram_poll( + self, source: au.AbstractUser, intent: IntentAPI, evt: Message, relates_to: RelatesTo + ) -> EventID: + poll: Poll = evt.media.poll + poll_id = self._encode_msgid(source, evt) + + _n = 0 + + def n() -> int: + nonlocal _n + _n += 1 + return _n + + text_answers = "\n".join(f"{n()}. {answer.text}" for answer in poll.answers) + html_answers = "\n".join(f"
  • {answer.text}
  • " for answer in poll.answers) + content = TextMessageEventContent( + msgtype=MessageType.TEXT, format=Format.HTML, + body=f"Poll: {poll.question}\n{text_answers}\n" + f"Vote with !tg vote {poll_id} ", + formatted_body=f"Poll: {poll.question}
    \n" + f"
      {html_answers}
    \n" + f"Vote with !tg vote {poll_id} <choice number>", + relates_to=relates_to, external_url=self._get_external_url(evt)) + + await intent.set_typing(self.mxid, is_typing=False) + return await self._send_message(intent, content, timestamp=evt.date) + + @staticmethod + def _format_dice(roll: MessageMediaDice) -> str: + if roll.emoticon == "\U0001F3B0": + emojis = { + 0: "\U0001F36B", # "🍫", + 1: "\U0001F352", # "🍒", + 2: "\U0001F34B", # "🍋", + 3: "7\ufe0f\u20e3" # "7️⃣", + } + res = roll.value - 1 + slot1, slot2, slot3 = emojis[res % 4], emojis[res // 4 % 4], emojis[res // 16] + return f"{slot1} {slot2} {slot3} ({roll.value})" + elif roll.emoticon == "\u26BD": + results = { + 1: "miss", + 2: "hit the woodwork", + 3: "goal", # seems to go in through the center + 4: "goal", + 5: "goal 🎉", # seems to go in through the top right corner, includes confetti + } + elif roll.emoticon == "\U0001F3B3": + results = { + 1: "miss", + 2: "1 pin down", + 3: "3 pins down, split", + 4: "4 pins down, split", + 5: "5 pins down", + 6: "strike 🎉", + } + # elif roll.emoticon == "\U0001F3C0": + # results = { + # 2: "rolled off", + # 3: "stuck", + # } + # elif roll.emoticon == "\U0001F3AF": + # results = { + # 1: "bounced off", + # 2: "outer rim", + # + # 6: "bullseye", + # } + else: + return str(roll.value) + return f"{results[roll.value]} ({roll.value})" + + async def _handle_telegram_dice( + self, _: au.AbstractUser, intent: IntentAPI, evt: Message, relates_to: RelatesTo + ) -> EventID: + emoji_text = { + "\U0001F3AF": " Dart throw", + "\U0001F3B2": " Dice roll", + "\U0001F3C0": " Basketball throw", + "\U0001F3B0": " Slot machine", + "\U0001F3B3": " Bowling", + "\u26BD": " Football kick" + } + roll: MessageMediaDice = evt.media + text = f"{roll.emoticon}{emoji_text.get(roll.emoticon, '')} result: {self._format_dice(roll)}" + content = TextMessageEventContent(msgtype=MessageType.TEXT, format=Format.HTML, body=text, + formatted_body=f"

    {text}

    ", relates_to=relates_to, + external_url=self._get_external_url(evt)) + content["net.maunium.telegram.dice"] = {"emoticon": roll.emoticon, "value": roll.value} + await intent.set_typing(self.mxid, is_typing=False) + return await self._send_message(intent, content, timestamp=evt.date) + + @staticmethod + def _int_to_bytes(i: int) -> bytes: + hex_value = f"{i:010x}".encode("utf-8") + return codecs.decode(hex_value, "hex_codec") + + def _encode_msgid(self, source: au.AbstractUser, evt: Message) -> str: + if self.peer_type == "channel": + play_id = (b"c" + + self._int_to_bytes(self.tgid) + + self._int_to_bytes(evt.id)) + elif self.peer_type == "chat": + play_id = (b"g" + + self._int_to_bytes(self.tgid) + + self._int_to_bytes(evt.id) + + self._int_to_bytes(source.tgid)) + elif self.peer_type == "user": + play_id = (b"u" + + self._int_to_bytes(self.tgid) + + self._int_to_bytes(evt.id)) + else: + raise ValueError("Portal has invalid peer type") + return base64.b64encode(play_id).decode("utf-8").rstrip("=") + + async def _handle_telegram_game( + self, source: au.AbstractUser, intent: IntentAPI, evt: Message, relates_to: RelatesTo + ) -> EventID: + game = evt.media.game + play_id = self._encode_msgid(source, evt) + command = f"!tg play {play_id}" + override_text = f"Run {command} in your bridge management room to play {game.title}" + override_entities = [ + MessageEntityPre(offset=len("Run "), length=len(command), language="")] + + content = await formatter.telegram_to_matrix( + evt, source, self.main_intent, + override_text=override_text, override_entities=override_entities) + content.msgtype = MessageType.NOTICE + content.external_url = self._get_external_url(evt) + content["net.maunium.telegram.game"] = play_id + + await intent.set_typing(self.mxid, is_typing=False) + return await self._send_message(intent, content, timestamp=evt.date) + + async def handle_telegram_edit( + self, source: au.AbstractUser, sender: p.Puppet, evt: Message + ) -> None: + if not self.mxid: + self.log.trace("Ignoring edit to %d as chat has no Matrix room", evt.id) + return + elif hasattr(evt, "media") and isinstance(evt.media, MessageMediaGame): + self.log.debug("Ignoring game message edit event") + return + + async with self.send_lock(sender.tgid if sender else None, required=False): + tg_space = self.tgid if self.peer_type == "channel" else source.tgid + + temporary_identifier = EventID( + f"${random.randint(1000000000000, 9999999999999)}TGBRIDGEDITEMP") + duplicate_found = self.dedup.check(evt, (temporary_identifier, tg_space), + force_hash=True) + if duplicate_found: + mxid, other_tg_space = duplicate_found + if tg_space != other_tg_space: + prev_edit_msg = await DBMessage.get_one_by_tgid(TelegramID(evt.id), tg_space, + edit_index=-1) + if not prev_edit_msg: + return + await DBMessage(mxid=mxid, mx_room=self.mxid, tg_space=tg_space, + tgid=TelegramID(evt.id), + edit_index=prev_edit_msg.edit_index + 1 + ).insert() + return + + content = await formatter.telegram_to_matrix(evt, source, self.main_intent, + no_reply_fallback=True) + editing_msg = await DBMessage.get_one_by_tgid(TelegramID(evt.id), tg_space) + if not editing_msg: + self.log.info(f"Didn't find edited message {evt.id}@{tg_space} (src {source.tgid}) " + "in database.") + return + + content.msgtype = (MessageType.NOTICE if (sender and sender.is_bot + and self.get_config("bot_messages_as_notices")) + else MessageType.TEXT) + content.external_url = self._get_external_url(evt) + content.set_edit(editing_msg.mxid) + + intent = sender.intent_for(self) if sender else self.main_intent + await intent.set_typing(self.mxid, is_typing=False) + event_id = await self._send_message(intent, content) + + prev_edit_msg = (await DBMessage.get_one_by_tgid(TelegramID(evt.id), tg_space, -1) + or editing_msg) + await DBMessage(mxid=event_id, mx_room=self.mxid, tg_space=tg_space, + tgid=TelegramID(evt.id), edit_index=prev_edit_msg.edit_index + 1).insert() + await DBMessage.replace_temp_mxid(temporary_identifier, self.mxid, event_id) + + @property + def _takeout_options(self) -> dict[str, bool | int]: + return { + "files": True, + "megagroups": self.megagroup, + "chats": self.peer_type == "chat", + "users": self.peer_type == "user", + "channels": (self.peer_type == "channel" and not self.megagroup), + "max_file_size": min(self.config["bridge.max_document_size"], 2000) * 1024 * 1024 + } + + async def backfill( + self, + source: u.User, + is_initial: bool = False, + limit: int | None = None, + last_id: int | None = None, + ) -> None: + async with self.backfill_method_lock: + await self._locked_backfill(source, is_initial, limit, last_id) + + async def _locked_backfill( + self, + source: u.User, + is_initial: bool = False, + limit: int | None = None, + last_id: int | None = None, + ) -> None: + limit = limit or (self.config["bridge.backfill.initial_limit"] if is_initial + else self.config["bridge.backfill.missed_limit"]) + if limit == 0: + return + if not self.config["bridge.backfill.normal_groups"] and self.peer_type == "chat": + return + last = await DBMessage.find_last(self.mxid, (source.tgid if self.peer_type != "channel" + else self.tgid)) + min_id = last.tgid if last else 0 + if last_id is None: + messages = await source.client.get_messages(self.peer, limit=1) + if not messages: + # The chat seems empty + return + last_id = messages[0].id + if last_id <= min_id: + # Nothing to backfill + return + if limit < 0: + limit = last_id - min_id + self.log.debug(f"Backfilling approximately {last_id - min_id} messages " + f"through {source.mxid}") + elif self.peer_type == "channel": + # This is a channel or supergroup, so we'll backfill messages based on the ID. + # There are some cases, such as deleted messages, where this may backfill less + # messages than the limit. + min_id = max(last_id - limit, min_id) + self.log.debug(f"Backfilling messages after ID {min_id} (last message: {last_id}) " + f"through {source.mxid}") + else: + # Private chats and normal groups don't have their own message ID namespace, + # which means we'll have to fetch messages a different way. + # The _backfill_messages method will detect min_id=None and not use reverse=True + min_id = None + self.log.debug(f"Backfilling up to {limit} messages through {source.mxid}") + with self.backfill_lock: + await self._backfill(source, min_id, limit) + + async def _backfill(self, source: u.User, min_id: int | None, limit: int) -> None: + self.backfill_leave = set() + if ((self.peer_type == "user" and self.tgid != source.tgid + and self.config["bridge.backfill.invite_own_puppet"])): + self.log.debug("Adding %s's default puppet to room for backfilling", source.mxid) + sender = await p.Puppet.get_by_tgid(source.tgid) + await self.main_intent.invite_user(self.mxid, sender.default_mxid) + await sender.default_mxid_intent.join_room_by_id(self.mxid) + self.backfill_leave.add(sender.default_mxid_intent) + + client = source.client + async with NotificationDisabler(self.mxid, source): + if limit > self.config["bridge.backfill.takeout_limit"]: + self.log.debug(f"Opening takeout client for {source.tgid}") + async with client.takeout(**self._takeout_options) as takeout: + count = await self._backfill_messages(source, min_id, limit, takeout) + else: + count = await self._backfill_messages(source, min_id, limit, client) + + for intent in self.backfill_leave: + self.log.trace("Leaving room with %s post-backfill", intent.mxid) + await intent.leave_room(self.mxid) + self.backfill_leave = None + self.log.info("Backfilled %d messages through %s", count, source.mxid) + + async def _backfill_messages( + self, source: u.User, min_id: int | None, limit: int, client: MautrixTelegramClient + ) -> int: + count = 0 + entity = await self.get_input_entity(source) + if min_id is not None: + self.log.debug(f"Iterating all messages starting with {min_id} (approx: {limit})") + messages = client.iter_messages(entity, reverse=True, min_id=min_id) + async for message in messages: + sender = (await p.Puppet.get_by_tgid(TelegramID(message.from_id.user_id)) + if isinstance(message.from_id, PeerUser) else None) + # TODO handle service messages? + await self.handle_telegram_message(source, sender, message) + count += 1 + else: + self.log.debug(f"Fetching up to {limit} most recent messages") + messages = await client.get_messages(entity, limit=limit) + for message in reversed(messages): + sender = (await p.Puppet.get_by_tgid(TelegramID(message.from_id.user_id)) + if isinstance(message.from_id, PeerUser) else None) + await self.handle_telegram_message(source, sender, message) + count += 1 + return count + + async def handle_telegram_message( + self, source: au.AbstractUser, sender: p.Puppet, evt: Message + ) -> None: + if not self.mxid: + self.log.trace("Got telegram message %d, but no room exists, creating...", evt.id) + await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=False) + + if (self.peer_type == "user" and sender and sender.tgid == self.tg_receiver + and not sender.is_real_user and not await self.az.state_store.is_joined(self.mxid, + sender.mxid)): + self.log.debug(f"Ignoring private chat message {evt.id}@{source.tgid} as receiver does" + " not have matrix puppeting and their default puppet isn't in the room") + return + + async with self.send_lock(sender.tgid if sender else None, required=False): + tg_space = self.tgid if self.peer_type == "channel" else source.tgid + + temporary_identifier = EventID( + f"${random.randint(1000000000000, 9999999999999)}TGBRIDGETEMP") + duplicate_found = self.dedup.check(evt, (temporary_identifier, tg_space)) + if duplicate_found: + mxid, other_tg_space = duplicate_found + self.log.debug(f"Ignoring message {evt.id}@{tg_space} (src {source.tgid}) " + f"as it was already handled (in space {other_tg_space})") + if tg_space != other_tg_space: + await DBMessage(tgid=TelegramID(evt.id), mx_room=self.mxid, mxid=mxid, + tg_space=tg_space, edit_index=0).insert() + return + + if self.backfill_lock.locked or (self.dedup.pre_db_check and self.peer_type == "channel"): + msg = await DBMessage.get_one_by_tgid(TelegramID(evt.id), tg_space) + if msg: + self.log.debug(f"Ignoring message {evt.id} (src {source.tgid}) as it was already " + f"handled into {msg.mxid}. This duplicate was catched in the db " + "check. If you get this message often, consider increasing " + "bridge.deduplication.cache_queue_length in the config.") + return + + self.log.trace("Handling Telegram message %s", evt) + + if sender and not sender.displayname: + self.log.debug(f"Telegram user {sender.tgid} sent a message, but doesn't have a " + "displayname, updating info...") + entity = await source.client.get_entity(PeerUser(sender.tgid)) + await sender.update_info(source, entity) + if not sender.displayname: + self.log.debug(f"Telegram user {sender.tgid} doesn't have a displayname even after" + f" updating with data {entity!s}") + + allowed_media = (MessageMediaPhoto, MessageMediaDocument, MessageMediaGeo, + MessageMediaGame, MessageMediaDice, MessageMediaPoll, + MessageMediaUnsupported) + if sender: + intent = sender.intent_for(self) + if ((self.backfill_lock.locked and intent != sender.default_mxid_intent + and self.config["bridge.backfill.invite_own_puppet"])): + intent = sender.default_mxid_intent + self.backfill_leave.add(intent) + else: + intent = self.main_intent + if hasattr(evt, "media") and isinstance(evt.media, allowed_media): + handler: MediaHandler = { + MessageMediaPhoto: self._handle_telegram_photo, + MessageMediaDocument: self._handle_telegram_document, + MessageMediaGeo: self._handle_telegram_location, + MessageMediaPoll: self._handle_telegram_poll, + MessageMediaDice: self._handle_telegram_dice, + MessageMediaUnsupported: self._handle_telegram_unsupported, + MessageMediaGame: self._handle_telegram_game, + }[type(evt.media)] + relates_to = await formatter.telegram_reply_to_matrix(evt, source) + event_id = await handler(source, intent, evt, relates_to) + elif evt.message: + is_bot = sender.is_bot if sender else False + event_id = await self._handle_telegram_text(source, intent, is_bot, evt) + else: + self.log.debug("Unhandled Telegram message %d", evt.id) + return + + if not event_id: + return + + prev_id = self.dedup.update(evt, (event_id, tg_space), (temporary_identifier, tg_space)) + if prev_id: + self.log.debug(f"Sent message {evt.id}@{tg_space} to Matrix as {event_id}. " + f"Temporary dedup identifier was {temporary_identifier}, " + f"but dedup map contained {prev_id[1]} instead! -- " + "This was probably a race condition caused by Telegram sending updates" + "to other clients before responding to the sender. I'll just redact " + "the likely duplicate message now.") + await intent.redact(self.mxid, event_id) + return + + self.log.debug("Handled telegram message %d -> %s", evt.id, event_id) + try: + await DBMessage(tgid=TelegramID(evt.id), mx_room=self.mxid, mxid=event_id, + tg_space=tg_space, edit_index=0).insert() + await DBMessage.replace_temp_mxid(temporary_identifier, self.mxid, event_id) + except (IntegrityError, UniqueViolationError) as e: + self.log.exception(f"{e.__class__.__name__} while saving message mapping. " + "This might mean that an update was handled after it left the " + "dedup cache queue. You can try enabling bridge.deduplication." + "pre_db_check in the config.") + await intent.redact(self.mxid, event_id) + await self._send_delivery_receipt(event_id) + + async def _create_room_on_action( + self, source: au.AbstractUser, action: TypeMessageAction + ) -> bool: + if source.is_relaybot and self.config["bridge.ignore_unbridged_group_chat"]: + return False + create_and_exit = (MessageActionChatCreate, MessageActionChannelCreate) + create_and_continue = (MessageActionChatAddUser, MessageActionChatJoinedByLink) + if isinstance(action, create_and_exit) or isinstance(action, create_and_continue): + await self.create_matrix_room(source, invites=[source.mxid], + update_if_exists=isinstance(action, create_and_exit)) + if not isinstance(action, create_and_continue): + return False + return True + + async def handle_telegram_action( + self, source: au.AbstractUser, sender: p.Puppet, update: MessageService + ) -> None: + action = update.action + should_ignore = ((not self.mxid and not await self._create_room_on_action(source, action)) + or self.dedup.check_action(update)) + if should_ignore or not self.mxid: + return + if isinstance(action, MessageActionChatEditTitle): + await self._update_title(action.title, sender=sender, save=True) + await self.update_bridge_info() + elif isinstance(action, MessageActionChatEditPhoto): + await self._update_avatar(source, action.photo, sender=sender, save=True) + await self.update_bridge_info() + elif isinstance(action, MessageActionChatDeletePhoto): + await self._update_avatar(source, ChatPhotoEmpty(), sender=sender, save=True) + await self.update_bridge_info() + elif isinstance(action, MessageActionChatAddUser): + for user_id in action.users: + await self._add_telegram_user(TelegramID(user_id), source) + elif isinstance(action, MessageActionChatJoinedByLink): + await self._add_telegram_user(sender.id, source) + elif isinstance(action, MessageActionChatDeleteUser): + await self._delete_telegram_user(TelegramID(action.user_id), sender) + elif isinstance(action, MessageActionChatMigrateTo): + await self._migrate_and_save_telegram(TelegramID(action.channel_id)) + # TODO encrypt + await sender.intent_for(self).send_emote(self.mxid, + "upgraded this group to a supergroup.") + await self.update_bridge_info() + elif isinstance(action, MessageActionGameScore): + # TODO handle game score + pass + else: + self.log.trace("Unhandled Telegram action in %s: %s", self.title, action) + + async def set_telegram_admin(self, user_id: TelegramID) -> None: + puppet = await p.Puppet.get_by_tgid(user_id) + user = await u.User.get_by_tgid(user_id) + + levels = await self.main_intent.get_power_levels(self.mxid) + if user: + levels.users[user.mxid] = 50 + if puppet: + levels.users[puppet.mxid] = 50 + await self.main_intent.set_power_levels(self.mxid, levels) + + async def receive_telegram_pin_ids(self, msg_ids: list[TelegramID], receiver: TelegramID, + remove: bool) -> None: + async with self._pin_lock: + tg_space = receiver if self.peer_type != "channel" else self.tgid + previously_pinned = await self.main_intent.get_pinned_messages(self.mxid) + currently_pinned_dict = {event_id: True for event_id in previously_pinned} + for message in await DBMessage.get_first_by_tgids(msg_ids, tg_space): + if remove: + currently_pinned_dict.pop(message.mxid, None) + else: + currently_pinned_dict[message.mxid] = True + currently_pinned = list(currently_pinned_dict.keys()) + if currently_pinned != previously_pinned: + await self.main_intent.set_pinned_messages(self.mxid, currently_pinned) + + async def set_telegram_admins_enabled(self, enabled: bool) -> None: + level = 50 if enabled else 10 + levels = await self.main_intent.get_power_levels(self.mxid) + levels.invite = level + levels.events[EventType.ROOM_NAME] = level + levels.events[EventType.ROOM_AVATAR] = level + await self.main_intent.set_power_levels(self.mxid, levels) + + # endregion + # region Miscellaneous getters + + def get_config(self, key: str) -> Any: + local = util.recursive_get(self.local_config, key) + if local is not None: + return local + return self.config[f"bridge.{key}"] + + @staticmethod + def _photo_size_key(photo: TypePhotoSize) -> int: + if isinstance(photo, PhotoSize): + return photo.size + elif isinstance(photo, PhotoSizeProgressive): + return max(photo.sizes) + elif isinstance(photo, PhotoSizeEmpty): + return 0 + else: + return len(photo.bytes) + + @classmethod + def _get_largest_photo_size( + cls, photo: Photo | Document + ) -> tuple[InputPhotoFileLocation | None, TypePhotoSize | None]: + if not photo or isinstance(photo, PhotoEmpty) or (isinstance(photo, Document) + and not photo.thumbs): + return None, None + + largest = max(photo.thumbs if isinstance(photo, Document) else photo.sizes, + key=cls._photo_size_key) + return InputPhotoFileLocation( + id=photo.id, + access_hash=photo.access_hash, + file_reference=photo.file_reference, + thumb_size=largest.type, + ), largest + + async def can_user_perform(self, user: u.User, event: str) -> bool: + if user.is_admin: + return True + if not self.mxid: + # No room for anybody to perform actions in + return False + try: + await self.main_intent.get_power_levels(self.mxid) + except MatrixRequestError: + return False + evt_type = EventType.find(f"net.maunium.telegram.{event}", t_class=EventType.Class.STATE) + return await self.main_intent.state_store.has_power_level(self.mxid, user.mxid, evt_type) + + def get_input_entity( + self, user: au.AbstractUser + ) -> Awaitable[TypeInputPeer | TypeInputChannel]: + return user.client.get_input_entity(self.peer) + + async def get_entity(self, user: au.AbstractUser) -> TypeChat: + try: + return await user.client.get_entity(self.peer) + except ValueError: + if user.is_bot: + self.log.warning(f"Could not find entity with bot {user.tgid}. Failing...") + raise + self.log.warning(f"Could not find entity with user {user.tgid}. " + "falling back to get_dialogs.") + async for dialog in user.client.iter_dialogs(): + if dialog.entity.id == self.tgid: + return dialog.entity + raise + + async def get_invite_link( + self, user: u.User, uses: int | None = None, expire: datetime | None = None + ) -> str: + if self.peer_type == "user": + raise ValueError("You can't invite users to private chats.") + if self.username: + return f"https://t.me/{self.username}" + link = await user.client(ExportChatInviteRequest(peer=await self.get_input_entity(user), + expire_date=expire, usage_limit=uses)) + return link.link + + # endregion + # region Matrix room cleanup + + async def get_authenticated_matrix_users(self) -> list[UserID]: + try: + members = await self.main_intent.get_room_members(self.mxid) + except MatrixRequestError: + return [] + authenticated: list[UserID] = [] + has_bot = self.has_bot + for member in members: + if p.Puppet.get_id_from_mxid(member) or member == self.az.bot_mxid: + continue + user = await u.User.get_and_start_by_mxid(member) + authenticated_through_bot = has_bot and user.relaybot_whitelisted + if authenticated_through_bot or await user.has_full_access(allow_bot=True): + authenticated.append(user.mxid) + return authenticated + + async def cleanup_portal(self, message: str, puppets_only: bool = False, delete: bool = True + ) -> None: + if self.username: + try: + await self.main_intent.remove_room_alias(self.alias_localpart) + except (MatrixRequestError, IntentError): + self.log.warning("Failed to remove alias when cleaning up room", exc_info=True) + await self.cleanup_room(self.main_intent, self.mxid, message, puppets_only) + if delete: + await self.delete() + + async def delete(self) -> None: + try: + del self.by_tgid[self.tgid_full] + except KeyError: + pass + try: + del self.by_mxid[self.mxid] + except KeyError: + pass + await super().delete() + await DBMessage.delete_all(self.mxid) + self.deleted = True + + # endregion + # region Class instance lookup + + async def postinit(self) -> None: + puppet = await p.Puppet.get_by_tgid(self.tgid) if self.is_direct else None + self._main_intent = puppet.intent_for(self) if self.is_direct else self.az.intent + + if self.tgid: + self.by_tgid[self.tgid_full] = self + if self.mxid: + self.by_mxid[self.mxid] = self + + @classmethod + async def all(cls) -> AsyncGenerator[Portal, None]: + portals = await super().all() + portal: cls + for portal in portals: + try: + yield cls.by_tgid[portal.tgid_full] + except KeyError: + await portal.postinit() + yield portal + + @classmethod + async def find_private_chats(cls, tg_receiver: TelegramID) -> AsyncGenerator[Portal, None]: + portals = await super().find_private_chats(tg_receiver) + portal: cls + for portal in portals: + try: + yield cls.by_tgid[portal.tgid_full] + except KeyError: + await portal.postinit() + yield portal + + @classmethod + @async_getter_lock + async def get_by_mxid(cls, mxid: RoomID) -> Portal | None: + try: + return cls.by_mxid[mxid] + except KeyError: + pass + + portal = cast(cls, await super().get_by_mxid(mxid)) + if portal: + await portal.postinit() + return portal + + return None + + @classmethod + def get_username_from_mx_alias(cls, alias: str) -> str | None: + return cls.alias_template.parse(alias) + + @classmethod + async def find_by_username(cls, username: str) -> Portal | None: + if not username: + return None + + username = username.lower() + + for _, portal in cls.by_tgid.items(): + if portal.username and portal.username.lower() == username: + return portal + + portal = cast(cls, await super().find_by_username(username)) + if portal: + try: + return cls.by_tgid[portal.tgid_full] + except KeyError: + await portal.postinit() + return portal + + return None + + @classmethod + @async_getter_lock + async def get_by_tgid( + cls, tgid: TelegramID, *, tg_receiver: TelegramID | None = None, peer_type: str = None + ) -> Portal | None: + if peer_type == "user" and tg_receiver is None: + raise ValueError("tg_receiver is required when peer_type is \"user\"") + tg_receiver = tg_receiver or tgid + tgid_full = (tgid, tg_receiver) + try: + return cls.by_tgid[tgid_full] + except KeyError: + pass + + portal = cast(cls, await super().get_by_tgid(tgid, tg_receiver)) + if portal: + await portal.postinit() + return portal + + if peer_type: + cls.log.info(f"Creating portal for {peer_type} {tgid} (receiver {tg_receiver})") + # TODO enable this for non-release builds + # (or add better wrong peer type error handling) + # if peer_type == "chat": + # import traceback + # cls.log.info("Chat portal stack trace:\n" + "".join(traceback.format_stack())) + portal = cls(tgid, peer_type=peer_type, tg_receiver=tg_receiver) + await portal.postinit() + await portal.insert() + return portal + + return None + + @classmethod + async def get_by_entity( + cls, + entity: TypeChat | TypePeer | TypeUser | TypeUserFull | TypeInputPeer, + tg_receiver: TelegramID | None = None, + create: bool = True, + ) -> Portal | None: + entity_type = type(entity) + if entity_type in (Chat, ChatFull): + type_name = "chat" + entity_id = entity.id + elif entity_type in (PeerChat, InputPeerChat): + type_name = "chat" + entity_id = entity.chat_id + elif entity_type in (Channel, ChannelFull): + type_name = "channel" + entity_id = entity.id + elif entity_type in (PeerChannel, InputPeerChannel, InputChannel): + type_name = "channel" + entity_id = entity.channel_id + elif entity_type in (User, UserFull): + type_name = "user" + entity_id = entity.id + elif entity_type in (PeerUser, InputPeerUser, InputUser): + type_name = "user" + entity_id = entity.user_id + else: + raise ValueError(f"Unknown entity type {entity_type.__name__}") + return await cls.get_by_tgid( + TelegramID(entity_id), + tg_receiver=tg_receiver if type_name == "user" else entity_id, + peer_type=type_name if create else None, + ) + + # endregion diff --git a/mautrix_telegram/portal/__init__.py b/mautrix_telegram/portal/__init__.py deleted file mode 100644 index 800f93d2..00000000 --- a/mautrix_telegram/portal/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from .base import BasePortal, init as init_base -from .matrix import PortalMatrix, init as init_matrix -from .metadata import PortalMetadata, init as init_metadata -from .telegram import PortalTelegram, init as init_telegram -from .deduplication import init as init_dedup -from ..context import Context - - -class Portal(PortalMatrix, PortalTelegram, PortalMetadata): - pass - - -def init(context: Context) -> None: - init_base(context) - init_dedup(context) - init_metadata(context) - init_telegram(context) - init_matrix(context) - - -__all__ = ["Portal", "init"] diff --git a/mautrix_telegram/portal/__init__.pyi b/mautrix_telegram/portal/__init__.pyi deleted file mode 100644 index c1564548..00000000 --- a/mautrix_telegram/portal/__init__.pyi +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Union -from .base import BasePortal -from .matrix import PortalMatrix -from .metadata import PortalMetadata -from .telegram import PortalTelegram -from ..context import Context - -Portal = Union[BasePortal, PortalMatrix, PortalMetadata, PortalTelegram] - - -def init(context: Context) -> None: - pass - - -__all__ = ["Portal", "init"] diff --git a/mautrix_telegram/portal/base.py b/mautrix_telegram/portal/base.py deleted file mode 100644 index 0f1ec276..00000000 --- a/mautrix_telegram/portal/base.py +++ /dev/null @@ -1,551 +0,0 @@ -# mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2020 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, Set, Iterable, TYPE_CHECKING -from abc import ABC, abstractmethod -from datetime import datetime -import asyncio -import logging -import json - -from telethon.tl.functions.messages import ExportChatInviteRequest -from telethon.tl.types import (Channel, ChannelFull, Chat, ChatFull, InputChannel, - InputPeerChannel, InputPeerChat, InputPeerUser, InputUser, - PeerChannel, PeerChat, PeerUser, TypeChat, TypeInputPeer, TypePeer, - TypeUser, TypeUserFull, User, UserFull, TypeInputChannel, Photo, - Document, TypePhotoSize, PhotoSize, InputPhotoFileLocation, - TypeChatParticipant, TypeChannelParticipant, PhotoEmpty, ChatPhoto, - ChatPhotoEmpty, PhotoSizeProgressive, PhotoSizeEmpty) - -from mautrix.errors import MatrixRequestError, IntentError -from mautrix.appservice import AppService, IntentAPI -from mautrix.types import (RoomID, RoomAlias, UserID, EventID, EventType, - PowerLevelStateEventContent, ContentURI) -from mautrix.util.simple_template import SimpleTemplate -from mautrix.util.simple_lock import SimpleLock -from mautrix.util.logging import TraceLogger -from mautrix.bridge import BasePortal as MautrixBasePortal - -from ..types import TelegramID -from ..context import Context -from ..db import Portal as DBPortal, Message as DBMessage -from .. import puppet as p, user as u, util -from .deduplication import PortalDedup -from .send_lock import PortalSendLock - -if TYPE_CHECKING: - from ..bot import Bot - from ..abstract_user import AbstractUser - from ..config import Config - from ..matrix import MatrixHandler - from . import Portal - -TypeParticipant = Union[TypeChatParticipant, TypeChannelParticipant] -TypeChatPhoto = Union[ChatPhoto, ChatPhotoEmpty, Photo, PhotoEmpty] -InviteList = Union[UserID, List[UserID]] - -config: Optional['Config'] = None - - -class BasePortal(MautrixBasePortal, ABC): - base_log: TraceLogger = logging.getLogger("mau.portal") - az: AppService = None - bot: 'Bot' = None - loop: asyncio.AbstractEventLoop = None - matrix: 'MatrixHandler' = None - - # Config cache - filter_mode: str = None - filter_list: List[int] = None - - max_initial_member_sync: int = -1 - sync_channel_members: bool = True - sync_matrix_state: bool = True - public_portals: bool = False - private_chat_portal_meta: bool = False - - alias_template: SimpleTemplate[str] - hs_domain: str - - # Instance cache - by_mxid: Dict[RoomID, 'Portal'] = {} - by_tgid: Dict[Tuple[TelegramID, TelegramID], 'Portal'] = {} - - mxid: Optional[RoomID] - tgid: TelegramID - tg_receiver: TelegramID - peer_type: str - username: str - megagroup: bool - title: Optional[str] - about: Optional[str] - photo_id: Optional[str] - local_config: Dict[str, Any] - avatar_url: Optional[ContentURI] - encrypted: bool - deleted: bool - backfill_lock: SimpleLock - backfill_method_lock: asyncio.Lock - backfill_leave: Optional[Set[IntentAPI]] - log: TraceLogger - - alias: Optional[RoomAlias] - - dedup: PortalDedup - send_lock: PortalSendLock - _pin_lock: asyncio.Lock - - _db_instance: DBPortal - _main_intent: Optional[IntentAPI] - _room_create_lock: asyncio.Lock - - def __init__(self, tgid: TelegramID, peer_type: str, tg_receiver: Optional[TelegramID] = None, - mxid: Optional[RoomID] = None, username: Optional[str] = None, - megagroup: Optional[bool] = False, title: Optional[str] = None, - about: Optional[str] = None, photo_id: Optional[str] = None, - local_config: Optional[str] = None, avatar_url: Optional[ContentURI] = None, - encrypted: Optional[bool] = False, db_instance: DBPortal = None) -> None: - self.mxid = mxid - self.tgid = tgid - self.tg_receiver = tg_receiver or tgid - self.peer_type = peer_type - self.username = username - self.megagroup = megagroup - self.title = title - self.about = about - self.photo_id = photo_id - self.local_config = json.loads(local_config or "{}") - self.avatar_url = avatar_url - self.encrypted = encrypted - self._db_instance = db_instance - self._main_intent = None - self.deleted = False - self.log = self.base_log.getChild(self.tgid_log if self.tgid else self.mxid) - self.backfill_lock = SimpleLock("Waiting for backfilling to finish before handling %s", - log=self.log) - self.backfill_method_lock = asyncio.Lock() - self.backfill_leave = None - - self.dedup = PortalDedup(self) - self.send_lock = PortalSendLock() - self._pin_lock = asyncio.Lock() - - if tgid: - self.by_tgid[self.tgid_full] = self - if mxid: - self.by_mxid[mxid] = self - - # region Properties - - @property - def tgid_full(self) -> Tuple[TelegramID, TelegramID]: - return self.tgid, self.tg_receiver - - @property - def tgid_log(self) -> str: - if self.tgid == self.tg_receiver: - return str(self.tgid) - return f"{self.tg_receiver}<->{self.tgid}" - - @property - def name(self) -> str: - return self.title - - @property - def alias(self) -> Optional[RoomAlias]: - if not self.username: - return None - return RoomAlias(f"#{self.alias_localpart}:{self.hs_domain}") - - @property - def alias_localpart(self) -> Optional[str]: - if not self.username: - return None - return self.alias_template.format(self.username) - - @property - def peer(self) -> Union[TypePeer, TypeInputPeer]: - if self.peer_type == "user": - return PeerUser(user_id=self.tgid) - elif self.peer_type == "chat": - return PeerChat(chat_id=self.tgid) - elif self.peer_type == "channel": - return PeerChannel(channel_id=self.tgid) - - @property - def is_direct(self) -> bool: - return self.peer_type == "user" - - @property - def has_bot(self) -> bool: - return (bool(self.bot) - and (self.bot.is_in_chat(self.tgid) - or (self.peer_type == "user" and self.tg_receiver == self.bot.tgid))) - - @property - def main_intent(self) -> IntentAPI: - if not self._main_intent: - direct = self.peer_type == "user" - puppet = p.Puppet.get(self.tgid) if direct else None - self._main_intent = puppet.intent_for(self) if direct else self.az.intent - return self._main_intent - - @property - def allow_bridging(self) -> bool: - if self.peer_type == "user": - return True - elif self.filter_mode == "whitelist": - return self.tgid in self.filter_list - elif self.filter_mode == "blacklist": - return self.tgid not in self.filter_list - return True - - # endregion - # region Miscellaneous getters - - def get_config(self, key: str) -> Any: - local = util.recursive_get(self.local_config, key) - if local is not None: - return local - return config[f"bridge.{key}"] - - @staticmethod - def _photo_size_key(photo: TypePhotoSize) -> int: - if isinstance(photo, PhotoSize): - return photo.size - elif isinstance(photo, PhotoSizeProgressive): - return max(photo.sizes) - elif isinstance(photo, PhotoSizeEmpty): - return 0 - else: - return len(photo.bytes) - - @classmethod - def _get_largest_photo_size(cls, photo: Union[Photo, Document] - ) -> Tuple[Optional[InputPhotoFileLocation], - Optional[TypePhotoSize]]: - if not photo or isinstance(photo, PhotoEmpty) or (isinstance(photo, Document) - and not photo.thumbs): - return None, None - - largest = max(photo.thumbs if isinstance(photo, Document) else photo.sizes, - key=cls._photo_size_key) - return InputPhotoFileLocation( - id=photo.id, - access_hash=photo.access_hash, - file_reference=photo.file_reference, - thumb_size=largest.type, - ), largest - - async def can_user_perform(self, user: 'u.User', event: str) -> bool: - if user.is_admin: - return True - if not self.mxid: - # No room for anybody to perform actions in - return False - try: - await self.main_intent.get_power_levels(self.mxid) - except MatrixRequestError: - return False - evt_type = EventType.find(f"net.maunium.telegram.{event}", t_class=EventType.Class.STATE) - return await self.main_intent.state_store.has_power_level(self.mxid, user.mxid, evt_type) - - def get_input_entity(self, user: 'AbstractUser' - ) -> Awaitable[Union[TypeInputPeer, TypeInputChannel]]: - return user.client.get_input_entity(self.peer) - - async def get_entity(self, user: 'AbstractUser') -> TypeChat: - try: - return await user.client.get_entity(self.peer) - except ValueError: - if user.is_bot: - self.log.warning(f"Could not find entity with bot {user.tgid}. Failing...") - raise - self.log.warning(f"Could not find entity with user {user.tgid}. " - "falling back to get_dialogs.") - async for dialog in user.client.iter_dialogs(): - if dialog.entity.id == self.tgid: - return dialog.entity - raise - - async def get_invite_link(self, user: 'u.User', uses: Optional[int] = None, - expire: Optional[datetime] = None) -> str: - if self.peer_type == "user": - raise ValueError("You can't invite users to private chats.") - if self.username: - return f"https://t.me/{self.username}" - link = await user.client(ExportChatInviteRequest(peer=await self.get_input_entity(user), - expire_date=expire, usage_limit=uses)) - return link.link - - # endregion - # region Matrix room cleanup - - async def get_authenticated_matrix_users(self) -> List[UserID]: - try: - members = await self.main_intent.get_room_members(self.mxid) - except MatrixRequestError: - return [] - authenticated: List[UserID] = [] - has_bot = self.has_bot - for member in members: - if p.Puppet.get_id_from_mxid(member) or member == self.az.bot_mxid: - continue - user = await u.User.get_by_mxid(member).ensure_started() - authenticated_through_bot = has_bot and user.relaybot_whitelisted - if authenticated_through_bot or await user.has_full_access(allow_bot=True): - authenticated.append(user.mxid) - return authenticated - - async def cleanup_portal(self, message: str, puppets_only: bool = False, delete: bool = True - ) -> None: - if self.username: - try: - await self.main_intent.remove_room_alias(self.alias_localpart) - except (MatrixRequestError, IntentError): - self.log.warning("Failed to remove alias when cleaning up room", exc_info=True) - await self.cleanup_room(self.main_intent, self.mxid, message, puppets_only) - if delete: - await self.delete() - - # endregion - # region Database conversion - - @property - def db_instance(self) -> DBPortal: - if not self._db_instance: - self._db_instance = self.new_db_instance() - return self._db_instance - - def new_db_instance(self) -> DBPortal: - return DBPortal(tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type, - mxid=self.mxid, username=self.username, megagroup=self.megagroup, - title=self.title, about=self.about, photo_id=self.photo_id, - config=json.dumps(self.local_config), avatar_url=self.avatar_url, - encrypted=self.encrypted) - - async def save(self) -> None: - self.db_instance.edit(mxid=self.mxid, username=self.username, title=self.title, - about=self.about, photo_id=self.photo_id, megagroup=self.megagroup, - config=json.dumps(self.local_config), avatar_url=self.avatar_url, - encrypted=self.encrypted) - - async def delete(self) -> None: - self.delete_sync() - - def delete_sync(self) -> None: - try: - del self.by_tgid[self.tgid_full] - except KeyError: - pass - try: - del self.by_mxid[self.mxid] - except KeyError: - pass - if self._db_instance: - self._db_instance.delete() - DBMessage.delete_all(self.mxid) - self.deleted = True - - @classmethod - def from_db(cls, db_portal: DBPortal) -> 'Portal': - return cls(tgid=db_portal.tgid, tg_receiver=db_portal.tg_receiver, - peer_type=db_portal.peer_type, mxid=db_portal.mxid, username=db_portal.username, - megagroup=db_portal.megagroup, title=db_portal.title, about=db_portal.about, - photo_id=db_portal.photo_id, local_config=db_portal.config, - avatar_url=db_portal.avatar_url, encrypted=db_portal.encrypted, - db_instance=db_portal) - - # endregion - # region Class instance lookup - - @classmethod - def all(cls) -> Iterable['Portal']: - for db_portal in DBPortal.all(): - try: - yield cls.by_tgid[(db_portal.tgid, db_portal.tg_receiver)] - except KeyError: - yield cls.from_db(db_portal) - - @classmethod - def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']: - try: - return cls.by_mxid[mxid] - except KeyError: - pass - - portal = DBPortal.get_by_mxid(mxid) - if portal: - return cls.from_db(portal) - - return None - - @classmethod - def get_username_from_mx_alias(cls, alias: str) -> Optional[str]: - return cls.alias_template.parse(alias) - - @classmethod - def find_by_username(cls, username: str) -> Optional['Portal']: - if not username: - return None - - username = username.lower() - - for _, portal in cls.by_tgid.items(): - if portal.username and portal.username.lower() == username: - return portal - - dbportal = DBPortal.get_by_username(username) - if dbportal: - return cls.from_db(dbportal) - - return None - - @classmethod - def get_by_tgid(cls, tgid: TelegramID, tg_receiver: Optional[TelegramID] = None, - peer_type: str = None) -> Optional['Portal']: - if peer_type == "user" and tg_receiver is None: - raise ValueError("tg_receiver is required when peer_type is \"user\"") - tg_receiver = tg_receiver or tgid - tgid_full = (tgid, tg_receiver) - try: - return cls.by_tgid[tgid_full] - except KeyError: - pass - - db_portal = DBPortal.get_by_tgid(tgid, tg_receiver) - if db_portal: - return cls.from_db(db_portal) - - if peer_type: - cls.log.info(f"Creating portal for {peer_type} {tgid} (receiver {tg_receiver})") - # TODO enable this for non-release builds - # (or add better wrong peer type error handling) - # if peer_type == "chat": - # import traceback - # cls.log.info("Chat portal stack trace:\n" + "".join(traceback.format_stack())) - portal = cls(tgid, peer_type=peer_type, tg_receiver=tg_receiver) - portal.db_instance.insert() - return portal - - return None - - @classmethod - def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull, - TypeInputPeer], - receiver_id: Optional[TelegramID] = None, create: bool = True - ) -> Optional['Portal']: - entity_type = type(entity) - if entity_type in (Chat, ChatFull): - type_name = "chat" - entity_id = entity.id - elif entity_type in (PeerChat, InputPeerChat): - type_name = "chat" - entity_id = entity.chat_id - elif entity_type in (Channel, ChannelFull): - type_name = "channel" - entity_id = entity.id - elif entity_type in (PeerChannel, InputPeerChannel, InputChannel): - type_name = "channel" - entity_id = entity.channel_id - elif entity_type in (User, UserFull): - type_name = "user" - entity_id = entity.id - elif entity_type in (PeerUser, InputPeerUser, InputUser): - type_name = "user" - entity_id = entity.user_id - else: - raise ValueError(f"Unknown entity type {entity_type.__name__}") - return cls.get_by_tgid(TelegramID(entity_id), - receiver_id if type_name == "user" else entity_id, - type_name if create else None) - - # endregion - # region Abstract methods (cross-called in matrix/metadata/telegram classes) - - @abstractmethod - async def update_matrix_room(self, user: 'AbstractUser', entity: Union[TypeChat, User], - direct: bool, puppet: p.Puppet = None, - levels: PowerLevelStateEventContent = None, - users: List[User] = None) -> None: - pass - - @abstractmethod - async def create_matrix_room(self, user: 'AbstractUser', entity: TypeChat = None, - invites: InviteList = None, update_if_exists: bool = True - ) -> Optional[RoomID]: - pass - - @abstractmethod - async def _add_telegram_user(self, user_id: TelegramID, source: Optional['AbstractUser'] = None - ) -> None: - pass - - @abstractmethod - async def _delete_telegram_user(self, user_id: TelegramID, sender: p.Puppet) -> None: - pass - - @abstractmethod - async def _update_title(self, title: str, sender: Optional['p.Puppet'] = None, - save: bool = False) -> bool: - pass - - @abstractmethod - async def _update_avatar(self, user: 'AbstractUser', photo: Union[TypeChatPhoto], - sender: Optional['p.Puppet'] = None, save: bool = False) -> bool: - pass - - @abstractmethod - def _migrate_and_save_telegram(self, new_id: TelegramID) -> None: - pass - - @abstractmethod - async def update_bridge_info(self) -> None: - pass - - @abstractmethod - def handle_matrix_power_levels(self, sender: 'u.User', new_levels: Dict[UserID, int], - old_levels: Dict[UserID, int], event_id: Optional[EventID] - ) -> Awaitable[None]: - pass - - @abstractmethod - def backfill(self, source: 'AbstractUser', is_initial: bool = False, - limit: Optional[int] = None, last_id: Optional[int] = None) -> Awaitable[None]: - pass - - @abstractmethod - async def _send_delivery_receipt(self, event_id: EventID, room_id: Optional[RoomID] = None - ) -> None: - pass - - # endregion - - -def init(context: Context) -> None: - global config - BasePortal.az, config, BasePortal.loop, BasePortal.bot = context.core - BasePortal.matrix = context.mx - MautrixBasePortal.bridge = context.bridge - BasePortal.max_initial_member_sync = config["bridge.max_initial_member_sync"] - BasePortal.sync_channel_members = config["bridge.sync_channel_members"] - BasePortal.sync_matrix_state = config["bridge.sync_matrix_state"] - BasePortal.public_portals = config["bridge.public_portals"] - BasePortal.private_chat_portal_meta = config["bridge.private_chat_portal_meta"] - BasePortal.filter_mode = config["bridge.filter.mode"] - BasePortal.filter_list = config["bridge.filter.list"] - BasePortal.hs_domain = config["homeserver.domain"] - BasePortal.alias_template = SimpleTemplate(config["bridge.alias_template"], "groupname", - prefix="#", suffix=f":{BasePortal.hs_domain}") diff --git a/mautrix_telegram/portal/matrix.py b/mautrix_telegram/portal/matrix.py deleted file mode 100644 index bf29826e..00000000 --- a/mautrix_telegram/portal/matrix.py +++ /dev/null @@ -1,680 +0,0 @@ -# mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2020 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from typing import Awaitable, Dict, Optional, Union, Any, TYPE_CHECKING -from html import escape as escape_html -from string import Template -from abc import ABC - -import magic - -from telethon.tl.functions.messages import (EditChatPhotoRequest, EditChatTitleRequest, - UpdatePinnedMessageRequest, SetTypingRequest, - EditChatAboutRequest, UnpinAllMessagesRequest) -from telethon.tl.functions.channels import EditPhotoRequest, EditTitleRequest, JoinChannelRequest -from telethon.errors import (ChatNotModifiedError, PhotoExtInvalidError, MessageIdInvalidError, - PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, RPCError) -from telethon.tl.patched import Message, MessageService -from telethon.tl.types import (DocumentAttributeFilename, DocumentAttributeImageSize, GeoPoint, - InputChatUploadedPhoto, MessageActionChatEditPhoto, MessageMediaGeo, - SendMessageCancelAction, SendMessageTypingAction, TypeInputPeer, - UpdateNewMessage, InputMediaUploadedDocument, - InputMediaUploadedPhoto) - -from mautrix.types import (EventID, EventType, RoomID, UserID, ContentURI, MessageType, - MessageEventContent, TextMessageEventContent, MediaMessageEventContent, - Format, LocationMessageEventContent, ImageInfo, VideoInfo) -from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus - -from ..types import TelegramID -from ..db import Message as DBMessage -from ..util import sane_mimetypes, parallel_transfer_to_telegram -from ..context import Context -from .. import puppet as p, user as u, formatter, util -from .base import BasePortal - -if TYPE_CHECKING: - from ..abstract_user import AbstractUser - from ..tgclient import MautrixTelegramClient - from ..config import Config - -try: - from mautrix.crypto.attachments import decrypt_attachment -except ImportError: - decrypt_attachment = None - -TypeMessage = Union[Message, MessageService] - -config: Optional['Config'] = None - - -class PortalMatrix(BasePortal, ABC): - async def _get_state_change_message(self, event: str, user: 'u.User', **kwargs: Any - ) -> Optional[str]: - tpl = self.get_config(f"state_event_formats.{event}") - if len(tpl) == 0: - # Empty format means they don't want the message - return None - displayname = await self.get_displayname(user) - - tpl_args = { - "mxid": user.mxid, - "username": user.mxid_localpart, - "displayname": escape_html(displayname), - **kwargs, - } - return Template(tpl).safe_substitute(tpl_args) - - async def _send_state_change_message(self, event: str, user: 'u.User', event_id: EventID, - **kwargs: Any) -> None: - if not self.has_bot: - return - elif self.peer_type == "user" and not config["bridge.relaybot.private_chat.state_changes"]: - return - async with self.send_lock(self.bot.tgid): - message = await self._get_state_change_message(event, user, **kwargs) - if not message: - return - message, entities = await formatter.matrix_to_telegram(self.bot.client, html=message) - response = await self.bot.client.send_message(self.peer, message, - formatting_entities=entities) - space = self.tgid if self.peer_type == "channel" else self.bot.tgid - self.dedup.check(response, (event_id, space)) - - async def name_change_matrix(self, user: 'u.User', displayname: str, prev_displayname: str, - event_id: EventID) -> None: - await self._send_state_change_message("name_change", user, event_id, - displayname=displayname, - prev_displayname=prev_displayname) - - async def get_displayname(self, user: 'u.User') -> str: - return await self.main_intent.get_room_displayname(self.mxid, user.mxid) or user.mxid - - def set_typing(self, user: 'u.User', typing: bool = True, - action: type = SendMessageTypingAction) -> Awaitable[bool]: - return user.client(SetTypingRequest( - self.peer, action() if typing else SendMessageCancelAction())) - - async def mark_read(self, user: 'u.User', event_id: EventID) -> None: - if user.is_bot: - return - space = self.tgid if self.peer_type == "channel" else user.tgid - message = DBMessage.get_by_mxid(event_id, self.mxid, space) - if not message: - message = DBMessage.find_last(self.mxid, space) - if not message: - self.log.debug(f"Dropping Matrix read receipt from {user.mxid}: " - f"target message {event_id} not known and last message" - " in chat not found") - return - else: - self.log.debug(f"Matrix read receipt target {event_id} not known, marking " - f"messages up to most recent ({message.mxid}/{message.tgid}) " - f"as read by {user.mxid}/{user.tgid}") - else: - self.log.debug("Handling Matrix read receipt: marking messages up to " - f"{message.mxid}/{message.tgid} as read by {user.mxid}/{user.tgid}") - await user.client.send_read_acknowledge(self.peer, max_id=message.tgid, - clear_mentions=True) - - async def _preproc_kick_ban(self, user: Union['u.User', 'p.Puppet'], source: 'u.User' - ) -> Optional['AbstractUser']: - if user.tgid == source.tgid: - return None - if self.peer_type == "user" and user.tgid == self.tgid: - await self.delete() - return None - if isinstance(user, u.User) and await user.needs_relaybot(self): - if not self.bot: - return None - # TODO kick message - return None - if await source.needs_relaybot(self): - if not self.has_bot: - return None - return self.bot - return source - - async def kick_matrix(self, user: Union['u.User', 'p.Puppet'], source: 'u.User') -> None: - source = await self._preproc_kick_ban(user, source) - if source is not None: - await source.client.kick_participant(self.peer, user.peer) - - async def ban_matrix(self, user: Union['u.User', 'p.Puppet'], source: 'u.User'): - source = await self._preproc_kick_ban(user, source) - if source is not None: - await source.client.edit_permissions(self.peer, user.peer, view_messages=False) - - async def leave_matrix(self, user: 'u.User', event_id: EventID) -> None: - if await user.needs_relaybot(self): - await self._send_state_change_message("leave", user, event_id) - return - - if self.peer_type == "user": - await self.main_intent.leave_room(self.mxid) - await self.delete() - try: - del self.by_tgid[self.tgid_full] - del self.by_mxid[self.mxid] - except KeyError: - pass - else: - await user.client.delete_dialog(self.peer) - - async def join_matrix(self, user: 'u.User', event_id: EventID) -> None: - if await user.needs_relaybot(self): - await self._send_state_change_message("join", user, event_id) - return - - if self.peer_type == "channel" and not user.is_bot: - await user.client(JoinChannelRequest(channel=await self.get_input_entity(user))) - else: - # We'll just assume the user is already in the chat. - pass - - async def _apply_msg_format(self, sender: 'u.User', content: MessageEventContent - ) -> None: - if not isinstance(content, TextMessageEventContent) or content.format != Format.HTML: - content.format = Format.HTML - content.formatted_body = escape_html(content.body).replace("\n", "
    ") - - tpl = (self.get_config(f"message_formats.[{content.msgtype.value}]") - or "$sender_displayname: $message") - displayname = await self.get_displayname(sender) - tpl_args = dict(sender_mxid=sender.mxid, - sender_username=sender.mxid_localpart, - sender_displayname=escape_html(displayname), - message=content.formatted_body, - body=content.body, formatted_body=content.formatted_body) - content.formatted_body = Template(tpl).safe_substitute(tpl_args) - - async def _apply_emote_format(self, sender: 'u.User', - content: TextMessageEventContent) -> None: - if content.format != Format.HTML: - content.format = Format.HTML - content.formatted_body = escape_html(content.body).replace("\n", "
    ") - - tpl = self.get_config("emote_format") - puppet = p.Puppet.get(sender.tgid) - content.formatted_body = Template(tpl).safe_substitute( - dict(sender_mxid=sender.mxid, - sender_username=sender.mxid_localpart, - sender_displayname=escape_html(await self.get_displayname(sender)), - mention=f"{puppet.displayname}", - username=sender.username, - displayname=puppet.displayname, - body=content.body, - formatted_body=content.formatted_body)) - content.msgtype = MessageType.TEXT - - async def _pre_process_matrix_message(self, sender: 'u.User', use_relaybot: bool, - content: MessageEventContent) -> None: - if use_relaybot: - await self._apply_msg_format(sender, content) - elif content.msgtype == MessageType.EMOTE: - await self._apply_emote_format(sender, content) - - async def _handle_matrix_text(self, sender: 'u.User', logged_in: bool, event_id: EventID, - space: TelegramID, client: 'MautrixTelegramClient', - content: TextMessageEventContent, reply_to: Optional[TelegramID] - ) -> None: - message, entities = await formatter.matrix_to_telegram(client, text=content.body, - html=content.formatted(Format.HTML)) - sender_id = sender.tgid if logged_in else self.bot.tgid - async with self.send_lock(sender_id): - lp = self.get_config("telegram_link_preview") - if content.get_edit(): - orig_msg = DBMessage.get_by_mxid(content.get_edit(), self.mxid, space) - if orig_msg: - response = await client.edit_message(self.peer, orig_msg.tgid, message, - formatting_entities=entities, - link_preview=lp) - self._add_telegram_message_to_db(event_id, space, -1, response) - return - try: - response = await client.send_message(self.peer, message, reply_to=reply_to, - formatting_entities=entities, - link_preview=lp) - except Exception: - raise - else: - sender.send_remote_checkpoint( - MessageSendCheckpointStatus.SUCCESS, - event_id, - self.mxid, - EventType.ROOM_MESSAGE, - message_type=content.msgtype, - ) - self._add_telegram_message_to_db(event_id, space, 0, response) - await self._send_delivery_receipt(event_id) - - async def _handle_matrix_file(self, sender: 'u.User', logged_in: bool, event_id: EventID, - space: TelegramID, client: 'MautrixTelegramClient', - content: MediaMessageEventContent, reply_to: TelegramID, - caption: TextMessageEventContent = None) -> None: - sender_id = sender.tgid if logged_in else self.bot.tgid - mime = content.info.mimetype - if isinstance(content.info, (ImageInfo, VideoInfo)): - w, h = content.info.width, content.info.height - else: - w = h = None - file_name = content["net.maunium.telegram.internal.filename"] - max_image_size = config["bridge.image_as_file_size"] * 1000 ** 2 - - if config["bridge.parallel_file_transfer"] and content.url: - file_handle, file_size = await parallel_transfer_to_telegram(client, self.main_intent, - content.url, sender_id) - else: - if content.file: - if not decrypt_attachment: - raise Exception(f"Can't bridge encrypted media event {event_id}: " - "encryption dependencies not installed") - file = await self.main_intent.download_media(content.file.url) - file = decrypt_attachment(file, content.file.key.key, - content.file.hashes.get("sha256"), content.file.iv) - else: - file = await self.main_intent.download_media(content.url) - - if content.msgtype == MessageType.STICKER: - if mime != "image/gif": - mime, file, w, h = util.convert_image(file, source_mime=mime, - target_type="webp") - else: - # Remove sticker description - file_name = "sticker.gif" - - file_handle = await client.upload_file(file) - file_size = len(file) - - file_handle.name = file_name - - attributes = [DocumentAttributeFilename(file_name=file_name)] - if w and h: - attributes.append(DocumentAttributeImageSize(w, h)) - - if (mime == "image/png" or mime == "image/jpeg") and file_size < max_image_size: - media = InputMediaUploadedPhoto(file_handle) - else: - media = InputMediaUploadedDocument(file=file_handle, attributes=attributes, - mime_type=mime or "application/octet-stream") - - capt, entities = (await formatter.matrix_to_telegram(client, text=caption.body, - html=caption.formatted(Format.HTML)) - if caption else (None, None)) - - async with self.send_lock(sender_id): - if await self._matrix_document_edit(client, content, space, capt, media, event_id): - return - try: - try: - response = await client.send_media(self.peer, media, reply_to=reply_to, - caption=capt, entities=entities) - except (PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, PhotoExtInvalidError): - media = InputMediaUploadedDocument(file=media.file, mime_type=mime, - attributes=attributes) - response = await client.send_media(self.peer, media, reply_to=reply_to, - caption=capt, entities=entities) - except Exception: - raise - else: - sender.send_remote_checkpoint( - MessageSendCheckpointStatus.SUCCESS, - event_id, - self.mxid, - EventType.ROOM_MESSAGE, - message_type=content.msgtype, - ) - self._add_telegram_message_to_db(event_id, space, 0, response) - await self._send_delivery_receipt(event_id) - - async def _matrix_document_edit(self, client: 'MautrixTelegramClient', - content: MessageEventContent, space: TelegramID, - caption: str, media: Any, event_id: EventID) -> bool: - if content.get_edit(): - orig_msg = DBMessage.get_by_mxid(content.get_edit(), self.mxid, space) - if orig_msg: - response = await client.edit_message(self.peer, orig_msg.tgid, - caption, file=media) - self._add_telegram_message_to_db(event_id, space, -1, response) - await self._send_delivery_receipt(event_id) - return True - return False - - async def _handle_matrix_location(self, sender: 'u.User', logged_in: bool, event_id: EventID, - space: TelegramID, client: 'MautrixTelegramClient', - content: LocationMessageEventContent, reply_to: TelegramID - ) -> None: - sender_id = sender.tgid if logged_in else self.bot.tgid - try: - lat, long = content.geo_uri[len("geo:"):].split(";")[0].split(",") - lat, long = float(lat), float(long) - except (KeyError, ValueError): - self.log.exception("Failed to parse location") - return None - caption, entities = await formatter.matrix_to_telegram(client, text=content.body) - media = MessageMediaGeo(geo=GeoPoint(lat=lat, long=long, access_hash=0)) - - async with self.send_lock(sender_id): - if await self._matrix_document_edit(client, content, space, caption, media, event_id): - return - try: - response = await client.send_media(self.peer, media, reply_to=reply_to, - caption=caption, entities=entities) - except Exception: - raise - else: - self._add_telegram_message_to_db(event_id, space, 0, response) - sender.send_remote_checkpoint( - MessageSendCheckpointStatus.SUCCESS, - event_id, - self.mxid, - EventType.ROOM_MESSAGE, - message_type=content.msgtype, - ) - await self._send_delivery_receipt(event_id) - - def _add_telegram_message_to_db(self, event_id: EventID, space: TelegramID, - edit_index: int, response: TypeMessage) -> None: - self.log.trace("Handled Matrix message: %s", response) - self.dedup.check(response, (event_id, space), force_hash=edit_index != 0) - if edit_index < 0: - prev_edit = DBMessage.get_one_by_tgid(TelegramID(response.id), space, -1) - edit_index = prev_edit.edit_index + 1 - DBMessage( - tgid=TelegramID(response.id), - tg_space=space, - mx_room=self.mxid, - mxid=event_id, - edit_index=edit_index).insert() - - async def _send_bridge_error(self, sender: 'u.User', err: Exception, event_id: EventID, - event_type: EventType, - message_type: Optional[MessageType] = None, - msg: Optional[str] = None, confirmed: bool = False) -> None: - sender.send_remote_checkpoint( - MessageSendCheckpointStatus.PERM_FAILURE, - event_id, - self.mxid, - event_type, - message_type=message_type, - error=err, - ) - - if config["bridge.delivery_error_reports"]: - await self._send_message(self.main_intent, - TextMessageEventContent(msgtype=MessageType.NOTICE, body=msg)) - - async def handle_matrix_message(self, sender: 'u.User', content: MessageEventContent, - event_id: EventID) -> None: - try: - await self._handle_matrix_message(sender, content, event_id) - except RPCError as e: - self.log.exception(f"RPCError while bridging {event_id}: {e}") - await self._send_bridge_error( - sender, - e, - event_id, - EventType.ROOM_MESSAGE, - message_type=content.msgtype, - msg=f"\u26a0 Your message may not have been bridged: {e}", - ) - raise - except Exception as e: - self.log.exception(f"Failed to bridge {event_id}: {e}") - await self._send_bridge_error( - sender, - e, - event_id, - EventType.ROOM_MESSAGE, - message_type=content.msgtype, - ) - - async def _handle_matrix_message(self, sender: 'u.User', content: MessageEventContent, - event_id: EventID) -> None: - if not content.body or not content.msgtype: - self.log.debug(f"Ignoring message {event_id} in {self.mxid} without body or msgtype") - return - - logged_in = not await sender.needs_relaybot(self) - client = sender.client if logged_in else self.bot.client - space = (self.tgid if self.peer_type == "channel" # Channels have their own ID space - else (sender.tgid if logged_in else self.bot.tgid)) - reply_to = formatter.matrix_reply_to_telegram(content, space, room_id=self.mxid) - - media = (MessageType.STICKER, MessageType.IMAGE, MessageType.FILE, MessageType.AUDIO, - MessageType.VIDEO) - - if content.msgtype == MessageType.NOTICE: - bridge_notices = self.get_config("bridge_notices.default") - excepted = sender.mxid in self.get_config("bridge_notices.exceptions") - if not bridge_notices and not excepted: - raise Exception("Notices are not configured to be bridged.") - - if content.msgtype in (MessageType.TEXT, MessageType.EMOTE, MessageType.NOTICE): - await self._pre_process_matrix_message(sender, not logged_in, content) - await self._handle_matrix_text(sender, logged_in, event_id, space, client, content, - reply_to) - elif content.msgtype == MessageType.LOCATION: - await self._pre_process_matrix_message(sender, not logged_in, content) - await self._handle_matrix_location(sender, logged_in, event_id, space, client, content, - reply_to) - elif content.msgtype in media: - content["net.maunium.telegram.internal.filename"] = content.body - try: - caption_content: MessageEventContent = sender.command_status["caption"] - reply_to = reply_to or formatter.matrix_reply_to_telegram(caption_content, space, - room_id=self.mxid) - sender.command_status = None - except (KeyError, TypeError): - caption_content = None if logged_in else TextMessageEventContent(body=content.body) - if caption_content: - caption_content.msgtype = content.msgtype - await self._pre_process_matrix_message(sender, not logged_in, caption_content) - await self._handle_matrix_file(sender, logged_in, event_id, space, client, content, - reply_to, caption_content) - else: - self.log.debug(f"Didn't handle Matrix event {event_id} due to unknown msgtype {content.msgtype}") - self.log.trace("Unhandled Matrix event content: %s", content) - raise Exception(f"Unhandled msgtype {content.msgtype}") - - async def handle_matrix_unpin_all(self, sender: 'u.User', pin_event_id: EventID) -> None: - await sender.client(UnpinAllMessagesRequest(peer=self.peer)) - await self._send_delivery_receipt(pin_event_id) - - async def handle_matrix_pin(self, sender: 'u.User', changes: Dict[EventID, bool], - pin_event_id: EventID) -> None: - tg_space = self.tgid if self.peer_type == "channel" else sender.tgid - ids = {msg.mxid: msg.tgid - for msg in DBMessage.get_by_mxids(list(changes.keys()), - mx_room=self.mxid, tg_space=tg_space)} - for event_id, pinned in changes.items(): - try: - await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=ids[event_id], - unpin=not pinned)) - except (ChatNotModifiedError, MessageIdInvalidError, KeyError): - pass - await self._send_delivery_receipt(pin_event_id) - - async def handle_matrix_deletion(self, deleter: 'u.User', event_id: EventID, - redaction_event_id: EventID) -> None: - try: - await self._handle_matrix_deletion(deleter, event_id) - except Exception as e: - self.log.debug(str(e)) - await self._send_bridge_error(deleter, e, redaction_event_id, EventType.ROOM_REDACTION) - else: - deleter.send_remote_checkpoint( - MessageSendCheckpointStatus.SUCCESS, - redaction_event_id, - self.mxid, - EventType.ROOM_REDACTION, - ) - await self._send_delivery_receipt(redaction_event_id) - - async def _handle_matrix_deletion(self, deleter: 'u.User', event_id: EventID) -> None: - real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot - space = self.tgid if self.peer_type == "channel" else real_deleter.tgid - message = DBMessage.get_by_mxid(event_id, self.mxid, space) - if not message: - raise Exception(f"Ignoring Matrix redaction of unknown event {event_id}") - elif message.redacted: - raise Exception("Ignoring Matrix redaction of already redacted event " - f"{message.mxid} in {message.mx_room}") - elif message.edit_index != 0: - message.edit(redacted=True) - raise Exception("Ignoring Matrix redaction of edit event " - f"{message.mxid} in {message.mx_room}") - else: - message.edit(redacted=True) - await real_deleter.client.delete_messages(self.peer, [message.tgid]) - - async def _update_telegram_power_level(self, sender: 'u.User', user_id: TelegramID, - level: int) -> None: - moderator = level >= 50 - admin = level >= 75 - await sender.client.edit_admin(self.peer, user_id, - change_info=moderator, post_messages=moderator, - edit_messages=moderator, delete_messages=moderator, - ban_users=moderator, invite_users=moderator, - pin_messages=moderator, add_admins=admin) - - async def handle_matrix_power_levels(self, sender: 'u.User', new_users: Dict[UserID, int], - old_users: Dict[UserID, int], event_id: Optional[EventID] - ) -> None: - # TODO handle all power level changes and bridge exact admin rights to supergroups/channels - for user, level in new_users.items(): - if not user or user == self.main_intent.mxid or user == sender.mxid: - continue - user_id = p.Puppet.get_id_from_mxid(user) - if not user_id: - mx_user = u.User.get_by_mxid(user, create=False) - if not mx_user or not mx_user.tgid: - continue - user_id = mx_user.tgid - if not user_id or user_id == sender.tgid: - continue - if user not in old_users or level != old_users[user]: - await self._update_telegram_power_level(sender, user_id, level) - - async def handle_matrix_about(self, sender: 'u.User', about: str, event_id: EventID) -> None: - if self.peer_type not in ("chat", "channel"): - return - peer = await self.get_input_entity(sender) - await sender.client(EditChatAboutRequest(peer=peer, about=about)) - self.about = about - await self.save() - await self._send_delivery_receipt(event_id) - - async def handle_matrix_title(self, sender: 'u.User', title: str, event_id: EventID) -> None: - if self.peer_type not in ("chat", "channel"): - return - - if self.peer_type == "chat": - response = await sender.client(EditChatTitleRequest(chat_id=self.tgid, title=title)) - else: - channel = await self.get_input_entity(sender) - response = await sender.client(EditTitleRequest(channel=channel, title=title)) - self.dedup.register_outgoing_actions(response) - self.title = title - await self.save() - await self._send_delivery_receipt(event_id) - await self.update_bridge_info() - - async def handle_matrix_avatar(self, sender: 'u.User', url: ContentURI, event_id: EventID - ) -> None: - if self.peer_type not in ("chat", "channel"): - # Invalid peer type - return - elif self.avatar_url == url: - return - - self.avatar_url = url - file = await self.main_intent.download_media(url) - mime = magic.from_buffer(file, mime=True) - ext = sane_mimetypes.guess_extension(mime) - uploaded = await sender.client.upload_file(file, file_name=f"avatar{ext}") - photo = InputChatUploadedPhoto(file=uploaded) - - if self.peer_type == "chat": - response = await sender.client(EditChatPhotoRequest(chat_id=self.tgid, photo=photo)) - else: - channel = await self.get_input_entity(sender) - response = await sender.client(EditPhotoRequest(channel=channel, photo=photo)) - self.dedup.register_outgoing_actions(response) - for update in response.updates: - is_photo_update = (isinstance(update, UpdateNewMessage) - and isinstance(update.message, MessageService) - and isinstance(update.message.action, MessageActionChatEditPhoto)) - if is_photo_update: - loc, size = self._get_largest_photo_size(update.message.action.photo) - self.photo_id = f"{size.location.volume_id}-{size.location.local_id}" - await self.save() - break - await self._send_delivery_receipt(event_id) - await self.update_bridge_info() - - async def handle_matrix_upgrade(self, sender: UserID, new_room: RoomID, event_id: EventID - ) -> None: - _, server = self.main_intent.parse_user_id(sender) - old_room = self.mxid - self.migrate_and_save_matrix(new_room) - await self.main_intent.join_room(new_room, servers=[server]) - entity: Optional[TypeInputPeer] = None - user: Optional[AbstractUser] = None - if self.bot and self.has_bot: - user = self.bot - entity = await self.get_input_entity(self.bot) - if not entity: - user_mxids = await self.main_intent.get_room_members(self.mxid) - for user_str in user_mxids: - user_id = UserID(user_str) - if user_id == self.az.bot_mxid: - continue - user = u.User.get_by_mxid(user_id, create=False) - if user and user.tgid: - entity = await self.get_input_entity(user) - if entity: - break - if not entity: - self.log.error("Failed to fully migrate to upgraded Matrix room: " - "no Telegram user found.") - return - await self.update_matrix_room(user, entity, direct=self.peer_type == "user") - self.log.info(f"{sender} upgraded room from {old_room} to {self.mxid}") - await self._send_delivery_receipt(event_id, room_id=old_room) - - def migrate_and_save_matrix(self, new_id: RoomID) -> None: - try: - del self.by_mxid[self.mxid] - except KeyError: - pass - self.mxid = new_id - self.db_instance.edit(mxid=self.mxid) - self.by_mxid[self.mxid] = self - - async def enable_dm_encryption(self) -> bool: - ok = await super().enable_dm_encryption() - if ok: - try: - puppet = p.Puppet.get(self.tgid) - await self.main_intent.set_room_name(self.mxid, puppet.displayname) - except Exception: - self.log.warning(f"Failed to set room name", exc_info=True) - return ok - - -def init(context: Context) -> None: - global config - config = context.config diff --git a/mautrix_telegram/portal/metadata.py b/mautrix_telegram/portal/metadata.py deleted file mode 100644 index 34484d61..00000000 --- a/mautrix_telegram/portal/metadata.py +++ /dev/null @@ -1,875 +0,0 @@ -# mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2020 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from typing import List, Optional, Iterable, Union, Dict, Any, Tuple, TYPE_CHECKING -from abc import ABC -import asyncio - -from telethon.tl.functions.messages import (AddChatUserRequest, CreateChatRequest, - GetFullChatRequest, MigrateChatRequest) -from telethon.tl.functions.channels import (CreateChannelRequest, GetParticipantsRequest, - InviteToChannelRequest, UpdateUsernameRequest) -from telethon.errors import ChatAdminRequiredError -from telethon.tl.types import ( - Channel, ChatBannedRights, ChannelParticipantsRecent, ChannelParticipantsSearch, ChatPhoto, - PhotoEmpty, InputChannel, InputUser, ChatPhotoEmpty, PeerUser, Photo, TypeChat, TypeInputPeer, - TypeUser, User, InputPeerPhotoFileLocation, ChatParticipantAdmin, ChannelParticipantAdmin, - ChatParticipantCreator, ChannelParticipantCreator, UserProfilePhoto, UserProfilePhotoEmpty, - InputPeerUser, ChannelParticipantBanned) - -from mautrix.errors import MForbidden -from mautrix.types import (RoomID, UserID, RoomCreatePreset, EventType, Membership, - PowerLevelStateEventContent, RoomTopicStateEventContent, - RoomNameStateEventContent, RoomAvatarStateEventContent, - StateEventContent, EventID, JoinRule) -from mautrix.appservice import DOUBLE_PUPPET_SOURCE_KEY - -from ..types import TelegramID -from ..context import Context -from .. import puppet as p, user as u, util -from .base import BasePortal, InviteList, TypeParticipant, TypeChatPhoto - -if TYPE_CHECKING: - from ..abstract_user import AbstractUser - from ..config import Config - -config: Optional['Config'] = None - -StateBridge = EventType.find("m.bridge", EventType.Class.STATE) -StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STATE) - - -class PortalMetadata(BasePortal, ABC): - _room_create_lock: asyncio.Lock - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._room_create_lock = asyncio.Lock() - - # region Matrix -> Telegram - - async def get_telegram_users_in_matrix_room(self, source: 'u.User' - ) -> Tuple[List[InputPeerUser], List[UserID]]: - user_tgids = {} - user_mxids = await self.main_intent.get_room_members(self.mxid, (Membership.JOIN, - Membership.INVITE)) - for mxid in user_mxids: - if mxid == self.az.bot_mxid: - continue - mx_user = u.User.get_by_mxid(mxid, create=False) - if mx_user and mx_user.tgid: - user_tgids[mx_user.tgid] = mxid - puppet_id = p.Puppet.get_id_from_mxid(mxid) - if puppet_id: - user_tgids[puppet_id] = mxid - input_users = [] - errors = [] - for tgid, mxid in user_tgids.items(): - try: - input_users.append(await source.client.get_input_entity(tgid)) - except ValueError as e: - source.log.debug(f"Failed to find the input entity for {tgid} ({mxid}) for " - f"creating a group: {e}") - errors.append(mxid) - return input_users, errors - - async def upgrade_telegram_chat(self, source: 'u.User') -> None: - if self.peer_type != "chat": - raise ValueError("Only normal group chats are upgradable to supergroups.") - - response = await source.client(MigrateChatRequest(chat_id=self.tgid)) - entity = None - for chat in response.chats: - if isinstance(chat, Channel): - entity = chat - break - if not entity: - raise ValueError("Upgrade may have failed: output channel not found.") - self.peer_type = "channel" - self._migrate_and_save_telegram(TelegramID(entity.id)) - await self.update_info(source, entity) - - def _migrate_and_save_telegram(self, new_id: TelegramID) -> None: - try: - del self.by_tgid[self.tgid_full] - except KeyError: - pass - try: - existing = self.by_tgid[(new_id, new_id)] - existing.delete_sync() - except KeyError: - pass - self.db_instance.edit(tgid=new_id, tg_receiver=new_id, peer_type=self.peer_type) - old_id = self.tgid - self.tgid = new_id - self.tg_receiver = new_id - self.by_tgid[self.tgid_full] = self - self.log = self.base_log.getChild(self.tgid_log) - self.log.info(f"Telegram chat upgraded from {old_id}") - - async def set_telegram_username(self, source: 'u.User', username: str) -> None: - if self.peer_type != "channel": - raise ValueError("Only channels and supergroups have usernames.") - await source.client( - UpdateUsernameRequest(await self.get_input_entity(source), username)) - if await self._update_username(username): - await self.save() - - async def create_telegram_chat(self, source: 'u.User', invites: List[InputUser], - supergroup: bool = False) -> None: - if not self.mxid: - raise ValueError("Can't create Telegram chat for portal without Matrix room.") - elif self.tgid: - raise ValueError("Can't create Telegram chat for portal with existing Telegram chat.") - - if len(invites) < 2: - if self.bot is not None: - info, mxid = await self.bot.get_me() - raise ValueError("Not enough Telegram users to create a chat. " - "Invite more Telegram ghost users to the room, such as the " - f"relaybot ([{info.first_name}](https://matrix.to/#/{mxid})).") - raise ValueError("Not enough Telegram users to create a chat. " - "Invite more Telegram ghost users to the room.") - if self.peer_type == "chat": - response = await source.client(CreateChatRequest(title=self.title, users=invites)) - entity = response.chats[0] - elif self.peer_type == "channel": - response = await source.client(CreateChannelRequest(title=self.title, - about=self.about or "", - megagroup=supergroup)) - entity = response.chats[0] - await source.client(InviteToChannelRequest( - channel=await source.client.get_input_entity(entity), - users=invites)) - else: - raise ValueError("Invalid peer type for Telegram chat creation") - - self.tgid = entity.id - self.tg_receiver = self.tgid - self.by_tgid[self.tgid_full] = self - await self.update_info(source, entity) - self.db_instance.insert() - self.log = self.base_log.getChild(self.tgid_log) - - if self.bot and self.bot.tgid in invites: - self.bot.add_chat(self.tgid, self.peer_type) - - levels = await self.main_intent.get_power_levels(self.mxid) - if levels.get_user_level(self.main_intent.mxid) == 100: - levels = self._get_base_power_levels(levels, entity) - await self.main_intent.set_power_levels(self.mxid, levels) - await self.handle_matrix_power_levels(source, levels.users, {}, None) - await self.update_bridge_info() - - async def invite_telegram(self, source: 'u.User', - puppet: Union[p.Puppet, 'AbstractUser']) -> None: - if self.peer_type == "chat": - await source.client( - AddChatUserRequest(chat_id=self.tgid, user_id=puppet.tgid, fwd_limit=0)) - elif self.peer_type == "channel": - await source.client(InviteToChannelRequest(channel=self.peer, users=[puppet.tgid])) - # We don't care if there are invites for private chat portals with the relaybot. - elif not self.bot or self.tg_receiver != self.bot.tgid: - raise ValueError("Invalid peer type for Telegram user invite") - - # endregion - # region Telegram -> Matrix - - def _get_invite_content(self, double_puppet: Optional['p.Puppet']) -> Dict[str, Any]: - invite_content = {} - if double_puppet: - invite_content["fi.mau.will_auto_accept"] = True - if self.is_direct: - invite_content["is_direct"] = True - return invite_content - - async def invite_to_matrix(self, users: InviteList) -> None: - if isinstance(users, list): - for user in users: - await self.invite_to_matrix(user) - else: - puppet = await p.Puppet.get_by_custom_mxid(users) - await self.main_intent.invite_user(self.mxid, users, check_cache=True, - extra_content=self._get_invite_content(puppet)) - if puppet: - try: - await puppet.intent.ensure_joined(self.mxid) - except Exception: - self.log.exception("Failed to ensure %s is joined to portal", users) - - async def update_matrix_room(self, user: 'AbstractUser', entity: Union[TypeChat, User], - direct: bool = None, puppet: p.Puppet = None, - levels: PowerLevelStateEventContent = None, - users: List[User] = None) -> None: - if direct is None: - direct = self.peer_type == "user" - try: - await self._update_matrix_room(user, entity, direct, puppet, levels, users) - except Exception: - self.log.exception("Fatal error updating Matrix room") - - async def _update_matrix_room(self, user: 'AbstractUser', entity: Union[TypeChat, User], - direct: bool, puppet: p.Puppet = None, - levels: PowerLevelStateEventContent = None, - users: List[User] = None) -> None: - if not direct: - await self.update_info(user, entity) - if not users: - users = await self._get_users(user, entity) - await self._sync_telegram_users(user, users) - await self.update_power_levels(users, levels) - else: - if not puppet: - puppet = p.Puppet.get(self.tgid) - await puppet.update_info(user, entity) - await puppet.intent_for(self).join_room(self.mxid) - if self.encrypted or self.private_chat_portal_meta: - # The bridge bot needs to join for e2ee, but that messes up the default name - # generation. If/when canonical DMs happen, this might not be necessary anymore. - changed = await self._update_title(puppet.displayname) - changed = await self._update_avatar(user, entity.photo) or changed - if changed: - await self.save() - await self.update_bridge_info() - - puppet = await p.Puppet.get_by_custom_mxid(user.mxid) - if puppet: - try: - did_join = await puppet.intent.ensure_joined(self.mxid) - if isinstance(user, u.User) and did_join and self.peer_type == "user": - await user.update_direct_chats({self.main_intent.mxid: [self.mxid]}) - except Exception: - self.log.exception("Failed to ensure %s is joined to portal", user.mxid) - - if self.sync_matrix_state: - await self.main_intent.get_joined_members(self.mxid) - - async def create_matrix_room(self, user: 'AbstractUser', entity: Union[TypeChat, User] = None, - invites: InviteList = None, update_if_exists: bool = True - ) -> Optional[RoomID]: - if self.mxid: - if update_if_exists: - if not entity: - try: - entity = await self.get_entity(user) - except Exception: - self.log.exception(f"Failed to get entity through {user.tgid} for update") - return self.mxid - update = self.update_matrix_room(user, entity, self.peer_type == "user") - self.loop.create_task(update) - await self.invite_to_matrix(invites or []) - return self.mxid - async with self._room_create_lock: - try: - return await self._create_matrix_room(user, entity, invites) - except Exception: - self.log.exception("Fatal error creating Matrix room") - - @property - def bridge_info_state_key(self) -> str: - return f"net.maunium.telegram://telegram/{self.tgid}" - - @property - def bridge_info(self) -> Dict[str, Any]: - info = { - "bridgebot": self.az.bot_mxid, - "creator": self.main_intent.mxid, - "protocol": { - "id": "telegram", - "displayname": "Telegram", - "avatar_url": config["appservice.bot_avatar"], - "external_url": "https://telegram.org", - }, - "channel": { - "id": str(self.tgid), - "displayname": self.title, - "avatar_url": self.avatar_url, - } - } - if self.username: - info["channel"]["external_url"] = f"https://t.me/{self.username}" - elif self.peer_type == "user": - puppet = p.Puppet.get(self.tgid) - if puppet and puppet.username: - info["channel"]["external_url"] = f"https://t.me/{puppet.username}" - return info - - async def update_bridge_info(self) -> None: - if not self.mxid: - self.log.debug("Not updating bridge info: no Matrix room created") - return - try: - self.log.debug("Updating bridge info...") - await self.main_intent.send_state_event(self.mxid, StateBridge, - self.bridge_info, self.bridge_info_state_key) - # TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec - await self.main_intent.send_state_event(self.mxid, StateHalfShotBridge, - self.bridge_info, self.bridge_info_state_key) - except Exception: - self.log.warning("Failed to update bridge info", exc_info=True) - - async def _create_matrix_room(self, user: 'AbstractUser', entity: Union[TypeChat, User], - invites: InviteList) -> Optional[RoomID]: - if self.mxid: - return self.mxid - elif not self.allow_bridging: - return None - - direct = self.peer_type == "user" - invites = invites or [] - - if not entity: - entity = await self.get_entity(user) - self.log.trace("Fetched data: %s", entity) - - self.log.debug("Creating room") - - try: - self.title = entity.title - except AttributeError: - self.title = None - - if direct and self.tgid == user.tgid: - self.title = "Telegram Saved Messages" - self.about = "Your Telegram cloud storage chat" - - puppet = p.Puppet.get(self.tgid) if direct else None - if puppet: - await puppet.update_info(user, entity) - self._main_intent = puppet.intent_for(self) if direct else self.az.intent - - if self.peer_type == "channel": - self.megagroup = entity.megagroup - - preset = RoomCreatePreset.PRIVATE - if self.peer_type == "channel" and entity.username: - if self.public_portals: - preset = RoomCreatePreset.PUBLIC - self.username = entity.username - alias = self.alias_localpart - else: - # TODO invite link alias? - alias = None - - if alias: - # TODO? properly handle existing room aliases - await self.main_intent.remove_room_alias(alias) - - power_levels = self._get_base_power_levels(entity=entity) - users = None - if not direct: - users = await self._get_users(user, entity) - if self.has_bot: - extra_invites = config["bridge.relaybot.group_chat_invite"] - invites += extra_invites - for invite in extra_invites: - power_levels.users.setdefault(invite, 100) - await self._participants_to_power_levels(users, power_levels) - elif self.bot and self.tg_receiver == self.bot.tgid: - invites = config["bridge.relaybot.private_chat.invite"] - for invite in invites: - power_levels.users.setdefault(invite, 100) - self.title = puppet.displayname - - initial_state = [{ - "type": EventType.ROOM_POWER_LEVELS.serialize(), - "content": power_levels.serialize(), - }, { - "type": str(StateBridge), - "state_key": self.bridge_info_state_key, - "content": self.bridge_info, - }, { - # TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec - "type": str(StateHalfShotBridge), - "state_key": self.bridge_info_state_key, - "content": self.bridge_info, - }] - create_invites = [] - if config["bridge.encryption.default"] and self.matrix.e2ee: - self.encrypted = True - initial_state.append({ - "type": "m.room.encryption", - "content": {"algorithm": "m.megolm.v1.aes-sha2"}, - }) - if direct: - create_invites.append(self.az.bot_mxid) - if direct and (self.encrypted or self.private_chat_portal_meta): - self.title = puppet.displayname - if config["appservice.community_id"]: - initial_state.append({ - "type": "m.room.related_groups", - "content": {"groups": [config["appservice.community_id"]]}, - }) - creation_content = {} - if not config["bridge.federate_rooms"]: - creation_content["m.federate"] = False - - with self.backfill_lock: - room_id = await self.main_intent.create_room(alias_localpart=alias, preset=preset, - is_direct=direct, invitees=create_invites, - name=self.title, topic=self.about, - initial_state=initial_state, - creation_content=creation_content) - if not room_id: - raise Exception(f"Failed to create room") - - if self.encrypted and self.matrix.e2ee and direct: - try: - await self.az.intent.ensure_joined(room_id) - except Exception: - self.log.warning(f"Failed to add bridge bot to new private chat {room_id}") - - self.mxid = room_id - self.by_mxid[self.mxid] = self - await self.save() - await self.az.state_store.set_power_levels(self.mxid, power_levels) - await user.register_portal(self) - - await self.invite_to_matrix(invites) - - update_room = self.loop.create_task(self.update_matrix_room( - user, entity, direct, puppet, - levels=power_levels, users=users)) - - if config["bridge.backfill.initial_limit"] > 0: - self.log.debug("Initial backfill is enabled. Waiting for room members to sync " - "and then starting backfill") - await update_room - - try: - await self.backfill(user, is_initial=True) - except Exception: - self.log.exception("Failed to backfill new portal") - - return self.mxid - - def _get_base_power_levels(self, levels: PowerLevelStateEventContent = None, - entity: TypeChat = None) -> PowerLevelStateEventContent: - levels = levels or PowerLevelStateEventContent() - if self.peer_type == "user": - overrides = config["bridge.initial_power_level_overrides.user"] - levels.ban = overrides.get("ban", 100) - levels.kick = overrides.get("kick", 100) - levels.invite = overrides.get("invite", 100) - levels.redact = overrides.get("redact", 0) - levels.events[EventType.ROOM_NAME] = 0 - levels.events[EventType.ROOM_AVATAR] = 0 - levels.events[EventType.ROOM_TOPIC] = 0 - levels.state_default = overrides.get("state_default", 0) - levels.users_default = overrides.get("users_default", 0) - levels.events_default = overrides.get("events_default", 0) - else: - overrides = config["bridge.initial_power_level_overrides.group"] - dbr = entity.default_banned_rights - if not dbr: - self.log.debug(f"default_banned_rights is None in {entity}") - dbr = ChatBannedRights(invite_users=True, change_info=True, pin_messages=True, - send_stickers=False, send_messages=False, until_date=None) - levels.ban = overrides.get("ban", 50) - levels.kick = overrides.get("kick", 50) - levels.redact = overrides.get("redact", 50) - levels.invite = overrides.get("invite", 50 if dbr.invite_users else 0) - levels.events[EventType.ROOM_ENCRYPTION] = 50 if self.matrix.e2ee else 99 - levels.events[EventType.ROOM_TOMBSTONE] = 99 - levels.events[EventType.ROOM_NAME] = 50 if dbr.change_info else 0 - levels.events[EventType.ROOM_AVATAR] = 50 if dbr.change_info else 0 - levels.events[EventType.ROOM_TOPIC] = 50 if dbr.change_info else 0 - levels.events[EventType.ROOM_PINNED_EVENTS] = 50 if dbr.pin_messages else 0 - levels.events[EventType.ROOM_POWER_LEVELS] = 75 - levels.events[EventType.ROOM_HISTORY_VISIBILITY] = 75 - levels.events[EventType.STICKER] = 50 if dbr.send_stickers else levels.events_default - levels.state_default = overrides.get("state_default", 50) - levels.users_default = overrides.get("users_default", 0) - levels.events_default = ( - overrides.get("events_default", - 50 if (self.peer_type == "channel" and not entity.megagroup - or entity.default_banned_rights.send_messages) - else 0)) - for evt_type, value in overrides.get("events", {}).items(): - levels.events[EventType.find(evt_type)] = value - levels.users = overrides.get("users", {}) - if self.main_intent.mxid not in levels.users: - levels.users[self.main_intent.mxid] = 100 - return levels - - @classmethod - def _get_level_from_participant(cls, participant: TypeParticipant, - levels: PowerLevelStateEventContent) -> int: - # TODO use the power level requirements to get better precision in channels - if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)): - return levels.state_default or 50 - elif isinstance(participant, (ChatParticipantCreator, ChannelParticipantCreator)): - return levels.get_user_level(cls.az.bot_mxid) - 5 - return levels.users_default or 0 - - @staticmethod - def _participant_to_power_levels(levels: PowerLevelStateEventContent, - user: Union['u.User', p.Puppet], new_level: int, - bot_level: int) -> bool: - new_level = min(new_level, bot_level) - user_level = levels.get_user_level(user.mxid) - if user_level != new_level and user_level < bot_level: - levels.users[user.mxid] = new_level - return True - return False - - async def _participants_to_power_levels(self, users: List[Union[TypeUser, TypeParticipant]], - levels: PowerLevelStateEventContent) -> bool: - bot_level = levels.get_user_level(self.main_intent.mxid) - if bot_level < levels.get_event_level(EventType.ROOM_POWER_LEVELS): - return False - changed = False - admin_power_level = min(75 if self.peer_type == "channel" else 50, bot_level) - if levels.get_event_level(EventType.ROOM_POWER_LEVELS) != admin_power_level: - changed = True - levels.events[EventType.ROOM_POWER_LEVELS] = admin_power_level - - for user in users: - # The User objects we get from TelegramClient.get_participants have a custom - # participant property - participant = getattr(user, "participant", user) - - puppet = p.Puppet.get(TelegramID(participant.user_id)) - user = u.User.get_by_tgid(TelegramID(participant.user_id)) - new_level = self._get_level_from_participant(participant, levels) - - if user: - await user.register_portal(self) - changed = self._participant_to_power_levels(levels, user, new_level, - bot_level) or changed - - if puppet: - changed = self._participant_to_power_levels(levels, puppet, new_level, - bot_level) or changed - return changed - - async def update_power_levels(self, users: List[Union[TypeUser, TypeParticipant]], - levels: PowerLevelStateEventContent = None) -> None: - if not levels: - levels = await self.main_intent.get_power_levels(self.mxid) - if await self._participants_to_power_levels(users, levels): - await self.main_intent.set_power_levels(self.mxid, levels) - - async def _add_bot_chat(self, bot: User) -> None: - if self.bot and bot.id == self.bot.tgid: - self.bot.add_chat(self.tgid, self.peer_type) - return - - user = u.User.get_by_tgid(TelegramID(bot.id)) - if user and user.is_bot: - await user.register_portal(self) - - async def _sync_telegram_users(self, source: 'AbstractUser', users: List[User]) -> None: - allowed_tgids = set() - skip_deleted = config["bridge.skip_deleted_members"] - for entity in users: - puppet = p.Puppet.get(TelegramID(entity.id)) - if entity.bot: - await self._add_bot_chat(entity) - allowed_tgids.add(entity.id) - - await puppet.update_info(source, entity) - if skip_deleted and entity.deleted: - continue - - await puppet.intent_for(self).ensure_joined(self.mxid) - - user = u.User.get_by_tgid(TelegramID(entity.id)) - if user: - await self.invite_to_matrix(user.mxid) - - # We can't trust the member list if any of the following cases is true: - # * There are close to 10 000 users, because Telegram might not be sending all members. - # * The member sync count is limited, because then we might ignore some members. - # * It's a channel, because non-admins don't have access to the member list. - trust_member_list = ((len(allowed_tgids) < 9900 - if self.max_initial_member_sync < 0 - else len(allowed_tgids) < self.max_initial_member_sync - 10) - and (self.megagroup or self.peer_type != "channel")) - if not trust_member_list: - return - - for user_mxid in await self.main_intent.get_room_members(self.mxid): - if user_mxid == self.az.bot_mxid: - continue - - puppet_id = p.Puppet.get_id_from_mxid(user_mxid) - if puppet_id: - if puppet_id in allowed_tgids: - continue - if self.bot and puppet_id == self.bot.tgid: - self.bot.remove_chat(self.tgid) - try: - await self.main_intent.kick_user(self.mxid, user_mxid, - "User had left this Telegram chat.") - except MForbidden: - pass - continue - - mx_user = u.User.get_by_mxid(user_mxid, create=False) - if mx_user: - if mx_user.tgid in allowed_tgids: - continue - if mx_user.is_bot: - await mx_user.unregister_portal(*self.tgid_full) - if not self.has_bot: - try: - await self.main_intent.kick_user(self.mxid, mx_user.mxid, - "You had left this Telegram chat.") - except MForbidden: - pass - - async def _add_telegram_user(self, user_id: TelegramID, source: Optional['AbstractUser'] = None - ) -> None: - puppet = p.Puppet.get(user_id) - if source: - entity: User = await source.client.get_entity(PeerUser(user_id)) - await puppet.update_info(source, entity) - await puppet.intent_for(self).ensure_joined(self.mxid) - - user = u.User.get_by_tgid(user_id) - if user: - await user.register_portal(self) - await self.invite_to_matrix(user.mxid) - - async def _delete_telegram_user(self, user_id: TelegramID, sender: p.Puppet) -> None: - puppet = p.Puppet.get(user_id) - user = u.User.get_by_tgid(user_id) - kick_message = (f"Kicked by {sender.displayname}" - if sender and sender.tgid != puppet.tgid - else "Left Telegram chat") - if sender.tgid != puppet.tgid: - try: - await sender.intent_for(self).kick_user(self.mxid, puppet.mxid) - except MForbidden: - await self.main_intent.kick_user(self.mxid, puppet.mxid, kick_message) - else: - await puppet.intent_for(self).leave_room(self.mxid) - if user: - await user.unregister_portal(*self.tgid_full) - if sender.tgid != puppet.tgid: - try: - await sender.intent_for(self).kick_user(self.mxid, puppet.mxid) - return - except MForbidden: - pass - try: - await self.main_intent.kick_user(self.mxid, user.mxid, kick_message) - except MForbidden as e: - self.log.warning(f"Failed to kick {user.mxid}: {e}") - - async def update_info(self, user: 'AbstractUser', entity: TypeChat = None) -> None: - if self.peer_type == "user": - self.log.warning("Called update_info() for direct chat portal") - return - - changed = False - self.log.debug("Updating info") - try: - if not entity: - entity = await self.get_entity(user) - self.log.trace("Fetched data: %s", entity) - - if self.peer_type == "channel": - changed = self.megagroup != entity.megagroup or changed - self.megagroup = entity.megagroup - changed = await self._update_username(entity.username) or changed - - if hasattr(entity, "about"): - changed = self._update_about(entity.about) or changed - - changed = await self._update_title(entity.title) or changed - - if isinstance(entity.photo, ChatPhoto): - changed = await self._update_avatar(user, entity.photo) or changed - except Exception: - self.log.exception(f"Failed to update info from source {user.tgid}") - - if changed: - await self.save() - await self.update_bridge_info() - - async def _update_username(self, username: str, save: bool = False) -> bool: - if self.username == username: - return False - - if self.username: - await self.main_intent.remove_room_alias(self.alias_localpart) - self.username = username or None - if self.username: - await self.main_intent.add_room_alias(self.mxid, self.alias_localpart, override=True) - if self.public_portals: - await self.main_intent.set_join_rule(self.mxid, JoinRule.PUBLIC) - else: - await self.main_intent.set_join_rule(self.mxid, JoinRule.INVITE) - - if save: - await self.save() - return True - - async def _try_set_state(self, sender: Optional['p.Puppet'], evt_type: EventType, - content: StateEventContent) -> None: - if sender: - try: - intent = sender.intent_for(self) - if sender.is_real_user: - content[DOUBLE_PUPPET_SOURCE_KEY] = self.bridge.name - await intent.send_state_event(self.mxid, evt_type, content) - except MForbidden: - await self.main_intent.send_state_event(self.mxid, evt_type, content) - else: - await self.main_intent.send_state_event(self.mxid, evt_type, content) - - async def _update_about(self, about: str, sender: Optional['p.Puppet'] = None, - save: bool = False) -> bool: - if self.about == about: - return False - - self.about = about - await self._try_set_state(sender, EventType.ROOM_TOPIC, - RoomTopicStateEventContent(topic=self.about)) - if save: - await self.save() - return True - - async def _update_title(self, title: str, sender: Optional['p.Puppet'] = None, - save: bool = False) -> bool: - if self.title == title: - return False - - self.title = title - await self._try_set_state(sender, EventType.ROOM_NAME, - RoomNameStateEventContent(name=self.title)) - if save: - await self.save() - return True - - async def _update_avatar(self, user: 'AbstractUser', photo: TypeChatPhoto, - sender: Optional['p.Puppet'] = None, save: bool = False) -> bool: - if isinstance(photo, (ChatPhoto, UserProfilePhoto)): - loc = InputPeerPhotoFileLocation( - peer=await self.get_input_entity(user), - photo_id=photo.photo_id, - big=True - ) - photo_id = str(photo.photo_id) - elif isinstance(photo, Photo): - loc, _ = self._get_largest_photo_size(photo) - photo_id = str(loc.id) - elif isinstance(photo, (UserProfilePhotoEmpty, ChatPhotoEmpty, PhotoEmpty, type(None))): - photo_id = "" - loc = None - else: - raise ValueError(f"Unknown photo type {type(photo)}") - if self.peer_type == "user" and not photo_id and not config["bridge.allow_avatar_remove"]: - return False - if self.photo_id != photo_id: - if not photo_id: - await self._try_set_state(sender, EventType.ROOM_AVATAR, - RoomAvatarStateEventContent(url=None)) - self.photo_id = "" - self.avatar_url = None - if save: - await self.save() - return True - file = await util.transfer_file_to_matrix(user.client, self.main_intent, loc) - if file: - await self._try_set_state(sender, EventType.ROOM_AVATAR, - RoomAvatarStateEventContent(url=file.mxc)) - self.photo_id = photo_id - self.avatar_url = file.mxc - if save: - await self.save() - return True - return False - - @staticmethod - def _filter_participants(users: List[TypeUser], participants: List[TypeParticipant] - ) -> Iterable[TypeUser]: - participant_map = {part.user_id: part for part in participants - if not isinstance(part, ChannelParticipantBanned)} - for user in users: - try: - user.participant = participant_map[user.id] - except KeyError: - pass - else: - yield user - - async def _get_channel_users(self, user: 'AbstractUser', entity: InputChannel, limit: int - ) -> List[TypeUser]: - if 0 < limit <= 200: - response = await user.client(GetParticipantsRequest( - entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0)) - return list(self._filter_participants(response.users, response.participants)) - elif limit > 200 or limit == -1: - users: List[TypeUser] = [] - offset = 0 - remaining_quota = limit if limit > 0 else 1000000 - query = (ChannelParticipantsSearch("") if limit == -1 - else ChannelParticipantsRecent()) - while True: - if remaining_quota <= 0: - break - response = await user.client(GetParticipantsRequest( - entity, query, offset=offset, limit=min(remaining_quota, 200), hash=0)) - if not response.users: - break - users += self._filter_participants(response.users, response.participants) - offset += len(response.participants) - remaining_quota -= len(response.participants) - return users - - async def _get_users(self, user: 'AbstractUser', - entity: Union[TypeInputPeer, InputUser, TypeChat, TypeUser, InputChannel] - ) -> List[TypeUser]: - limit = self.max_initial_member_sync - if self.peer_type == "chat": - chat = await user.client(GetFullChatRequest(chat_id=self.tgid)) - return list( - self._filter_participants(chat.users, chat.full_chat.participants.participants) - )[:limit] - elif self.peer_type == "channel": - if not self.megagroup and not self.sync_channel_members: - return [] - - if limit == 0: - return [] - - try: - return await self._get_channel_users(user, entity, limit) - except ChatAdminRequiredError: - return [] - elif self.peer_type == "user": - return [entity] - else: - raise RuntimeError(f"Unexpected peer type {self.peer_type}") - - # endregion - - async def _send_delivery_receipt(self, event_id: EventID, room_id: Optional[RoomID] = None - ) -> None: - # TODO maybe check if the bot is in the room rather than assuming based on self.encrypted - if event_id and config["bridge.delivery_receipts"] and (self.encrypted - or self.peer_type != "user"): - try: - await self.az.intent.mark_read(room_id or self.mxid, event_id) - except Exception: - self.log.exception("Failed to send delivery receipt for %s", event_id) - - -def init(context: Context) -> None: - global config - config = context.config diff --git a/mautrix_telegram/portal/telegram.py b/mautrix_telegram/portal/telegram.py deleted file mode 100644 index 63eb770f..00000000 --- a/mautrix_telegram/portal/telegram.py +++ /dev/null @@ -1,808 +0,0 @@ -# mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2020 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from typing import Awaitable, Dict, List, Optional, Tuple, Union, NamedTuple, TYPE_CHECKING -from abc import ABC -import random -import mimetypes -import codecs -import unicodedata -import base64 -import asyncio - -from sqlalchemy.exc import IntegrityError - -from telethon.tl.patched import Message, MessageService -from telethon.tl.types import ( - Poll, DocumentAttributeFilename, DocumentAttributeSticker, DocumentAttributeVideo, - MessageMediaPoll, MessageActionChannelCreate, MessageActionChatAddUser, - MessageActionChatCreate, MessageActionChatDeletePhoto, MessageActionChatDeleteUser, - MessageActionChatEditPhoto, MessageActionChatEditTitle, MessageActionChatJoinedByLink, - MessageActionChatMigrateTo, MessageActionGameScore, MessageMediaDocument, MessageMediaGeo, - MessageMediaPhoto, MessageMediaDice, MessageMediaGame, MessageMediaUnsupported, PeerUser, - PhotoCachedSize, TypeChannelParticipant, TypeChatParticipant, TypeDocumentAttribute, - TypeMessageAction, TypePhotoSize, PhotoSize, UpdateChatUserTyping, UpdateUserTyping, - MessageEntityPre, ChatPhotoEmpty, DocumentAttributeImageSize, DocumentAttributeAnimated, - UpdateChannelUserTyping, SendMessageTypingAction) - -from mautrix.appservice import IntentAPI -from mautrix.types import (EventID, UserID, ImageInfo, ThumbnailInfo, RelatesTo, MessageType, - EventType, MediaMessageEventContent, TextMessageEventContent, - LocationMessageEventContent, Format) -from mautrix.bridge import NotificationDisabler - -from ..types import TelegramID -from ..db import Message as DBMessage, TelegramFile as DBTelegramFile -from ..util import sane_mimetypes -from ..context import Context -from ..tgclient import TelegramClient -from .. import puppet as p, user as u, formatter, util -from .base import BasePortal - -if TYPE_CHECKING: - from ..abstract_user import AbstractUser - from ..config import Config - -InviteList = Union[UserID, List[UserID]] -TypeParticipant = Union[TypeChatParticipant, TypeChannelParticipant] -UpdateTyping = Union[UpdateUserTyping, UpdateChatUserTyping, UpdateChannelUserTyping] -DocAttrs = NamedTuple("DocAttrs", name=Optional[str], mime_type=Optional[str], is_sticker=bool, - sticker_alt=Optional[str], width=int, height=int, is_gif=bool) - -config: Optional['Config'] = None - - -class PortalTelegram(BasePortal, ABC): - async def handle_telegram_typing(self, user: p.Puppet, update: UpdateTyping) -> None: - if user.is_real_user: - # Ignore typing notifications from double puppeted users to avoid echoing - return - is_typing = isinstance(update.action, SendMessageTypingAction) - await user.default_mxid_intent.set_typing(self.mxid, is_typing=is_typing) - - def _get_external_url(self, evt: Message) -> Optional[str]: - if self.peer_type == "channel" and self.username is not None: - return f"https://t.me/{self.username}/{evt.id}" - elif self.peer_type != "user": - return f"https://t.me/c/{self.tgid}/{evt.id}" - return None - - async def _expire_telegram_photo(self, intent: IntentAPI, event_id: EventID, ttl: int) -> None: - try: - content = TextMessageEventContent(msgtype=MessageType.NOTICE, body="Photo has expired") - content.set_edit(event_id) - await asyncio.sleep(ttl) - await self._send_message(intent, content) - except Exception: - self.log.warning("Failed to expire Telegram photo %s", event_id, exc_info=True) - - async def handle_telegram_photo(self, source: 'AbstractUser', intent: IntentAPI, evt: Message, - relates_to: RelatesTo = None) -> Optional[EventID]: - media: MessageMediaPhoto = evt.media - if media.photo is None and media.ttl_seconds: - return await self._send_message(intent, TextMessageEventContent( - msgtype=MessageType.NOTICE, body="Photo has expired")) - loc, largest_size = self._get_largest_photo_size(media.photo) - if loc is None: - content = TextMessageEventContent(msgtype=MessageType.TEXT, - body="Failed to bridge image", - external_url=self._get_external_url(evt)) - return await self._send_message(intent, content, timestamp=evt.date) - file = await util.transfer_file_to_matrix(source.client, intent, loc, - encrypt=self.encrypted) - if not file: - return None - if self.get_config("inline_images") and (evt.message or evt.fwd_from or evt.reply_to): - content = await formatter.telegram_to_matrix( - evt, source, self.main_intent, - prefix_html=f"Inline Telegram photo
    ", - prefix_text="Inline image: ") - content.external_url = self._get_external_url(evt) - await intent.set_typing(self.mxid, is_typing=False) - return await self._send_message(intent, content, timestamp=evt.date) - info = ImageInfo( - height=largest_size.h, width=largest_size.w, orientation=0, mimetype=file.mime_type, - size=self._photo_size_key(largest_size)) - ext = sane_mimetypes.guess_extension(file.mime_type) - name = f"disappearing_image{ext}" if media.ttl_seconds else f"image{ext}" - await intent.set_typing(self.mxid, is_typing=False) - content = MediaMessageEventContent(msgtype=MessageType.IMAGE, info=info, - body=name, relates_to=relates_to, - external_url=self._get_external_url(evt)) - if file.decryption_info: - content.file = file.decryption_info - else: - content.url = file.mxc - result = await self._send_message(intent, content, timestamp=evt.date) - if media.ttl_seconds: - self.loop.create_task(self._expire_telegram_photo(intent, result, - media.ttl_seconds)) - if evt.message: - caption_content = await formatter.telegram_to_matrix(evt, source, self.main_intent, - no_reply_fallback=True) - caption_content.external_url = content.external_url - result = await self._send_message(intent, caption_content, timestamp=evt.date) - return result - - @staticmethod - def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> DocAttrs: - name, mime_type, is_sticker, sticker_alt, width, height = None, None, False, None, 0, 0 - is_gif = False - for attr in attributes: - if isinstance(attr, DocumentAttributeFilename): - name = name or attr.file_name - mime_type, _ = mimetypes.guess_type(attr.file_name) - elif isinstance(attr, DocumentAttributeSticker): - is_sticker = True - sticker_alt = attr.alt - elif isinstance(attr, DocumentAttributeAnimated): - is_gif = True - elif isinstance(attr, DocumentAttributeVideo): - width, height = attr.w, attr.h - elif isinstance(attr, DocumentAttributeImageSize): - width, height = attr.w, attr.h - return DocAttrs(name, mime_type, is_sticker, sticker_alt, width, height, is_gif) - - @staticmethod - def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: DocAttrs, - thumb_size: TypePhotoSize) -> Tuple[ImageInfo, str]: - document = evt.media.document - name = attrs.name - if attrs.is_sticker: - alt = attrs.sticker_alt - if len(alt) > 0: - try: - name = f"{alt} ({unicodedata.name(alt[0]).lower()})" - except ValueError: - name = alt - - generic_types = ("text/plain", "application/octet-stream") - if file.mime_type in generic_types and document.mime_type not in generic_types: - mime_type = document.mime_type or file.mime_type - elif file.mime_type == 'application/ogg': - mime_type = 'audio/ogg' - else: - mime_type = file.mime_type or document.mime_type - info = ImageInfo(size=file.size, mimetype=mime_type) - - if attrs.mime_type and not file.was_converted: - file.mime_type = attrs.mime_type or file.mime_type - if file.width and file.height: - info.width, info.height = file.width, file.height - elif attrs.width and attrs.height: - info.width, info.height = attrs.width, attrs.height - - if file.thumbnail: - if file.thumbnail.decryption_info: - info.thumbnail_file = file.thumbnail.decryption_info - else: - info.thumbnail_url = file.thumbnail.mxc - info.thumbnail_info = ThumbnailInfo(mimetype=file.thumbnail.mime_type, - height=file.thumbnail.height or thumb_size.h, - width=file.thumbnail.width or thumb_size.w, - size=file.thumbnail.size) - elif attrs.is_sticker: - # This is a hack for bad clients like Element iOS that require a thumbnail - info.thumbnail_info = ImageInfo.deserialize(info.serialize()) - if file.decryption_info: - info.thumbnail_file = file.decryption_info - else: - info.thumbnail_url = file.mxc - - return info, name - - async def handle_telegram_document(self, source: 'AbstractUser', intent: IntentAPI, - evt: Message, relates_to: RelatesTo = None - ) -> Optional[EventID]: - document = evt.media.document - - attrs = self._parse_telegram_document_attributes(document.attributes) - - if document.size > config["bridge.max_document_size"] * 1000 ** 2: - name = attrs.name or "" - caption = f"\n{evt.message}" if evt.message else "" - # TODO encrypt - return await intent.send_notice(self.mxid, f"Too large file {name}{caption}") - - thumb_loc, thumb_size = self._get_largest_photo_size(document) - if thumb_size and not isinstance(thumb_size, (PhotoSize, PhotoCachedSize)): - self.log.debug(f"Unsupported thumbnail type {type(thumb_size)}") - thumb_loc = None - thumb_size = None - parallel_id = source.tgid if config["bridge.parallel_file_transfer"] else None - file = await util.transfer_file_to_matrix(source.client, intent, document, thumb_loc, - is_sticker=attrs.is_sticker, - tgs_convert=config["bridge.animated_sticker"], - filename=attrs.name, parallel_id=parallel_id, - encrypt=self.encrypted) - if not file: - return None - - info, name = self._parse_telegram_document_meta(evt, file, attrs, thumb_size) - - await intent.set_typing(self.mxid, is_typing=False) - - event_type = EventType.ROOM_MESSAGE - # Elements only support images as stickers, so send animated webm stickers as m.video - if attrs.is_sticker and file.mime_type.startswith("image/"): - event_type = EventType.STICKER - # Tell clients to render the stickers as 256x256 if they're bigger - if info.width > 256 or info.height > 256: - if info.width > info.height: - info.height = int(info.height / (info.width / 256)) - info.width = 256 - else: - info.width = int(info.width / (info.height / 256)) - info.height = 256 - if info.thumbnail_info: - info.thumbnail_info.width = info.width - info.thumbnail_info.height = info.height - if attrs.is_gif or (attrs.is_sticker and info.mimetype == "video/webm"): - if attrs.is_gif: - info["fi.mau.telegram.gif"] = True - else: - info["fi.mau.telegram.animated_sticker"] = True - info["fi.mau.loop"] = True - info["fi.mau.autoplay"] = True - info["fi.mau.hide_controls"] = True - info["fi.mau.no_audio"] = True - if not name: - ext = sane_mimetypes.guess_extension(file.mime_type) - name = "unnamed_file" + ext - - content = MediaMessageEventContent( - body=name, info=info, relates_to=relates_to, - external_url=self._get_external_url(evt), - msgtype={ - "video/": MessageType.VIDEO, - "audio/": MessageType.AUDIO, - "image/": MessageType.IMAGE, - }.get(info.mimetype[:6], MessageType.FILE)) - if file.decryption_info: - content.file = file.decryption_info - else: - content.url = file.mxc - res = await self._send_message(intent, content, event_type=event_type, timestamp=evt.date) - if evt.message: - caption_content = await formatter.telegram_to_matrix(evt, source, self.main_intent, - no_reply_fallback=True) - caption_content.external_url = content.external_url - res = await self._send_message(intent, caption_content, timestamp=evt.date) - return res - - def handle_telegram_location(self, source: 'AbstractUser', intent: IntentAPI, evt: Message, - relates_to: RelatesTo = None) -> Awaitable[EventID]: - long = evt.media.geo.long - lat = evt.media.geo.lat - long_char = "E" if long > 0 else "W" - lat_char = "N" if lat > 0 else "S" - geo = f"{round(lat, 6)},{round(long, 6)}" - - body = f"{round(abs(lat), 4)}° {lat_char}, {round(abs(long), 4)}° {long_char}" - url = f"https://maps.google.com/?q={geo}" - - content = LocationMessageEventContent( - msgtype=MessageType.LOCATION, geo_uri=f"geo:{geo}", - body=f"Location: {body}\n{url}", - relates_to=relates_to, external_url=self._get_external_url(evt)) - content["format"] = str(Format.HTML) - content["formatted_body"] = f"Location: {body}" - - return self._send_message(intent, content, timestamp=evt.date) - - async def handle_telegram_text(self, source: 'AbstractUser', intent: IntentAPI, is_bot: bool, - evt: Message) -> EventID: - self.log.trace(f"Sending {evt.message} to {self.mxid} by {intent.mxid}") - content = await formatter.telegram_to_matrix(evt, source, self.main_intent) - content.external_url = self._get_external_url(evt) - if is_bot and self.get_config("bot_messages_as_notices"): - content.msgtype = MessageType.NOTICE - await intent.set_typing(self.mxid, is_typing=False) - return await self._send_message(intent, content, timestamp=evt.date) - - async def handle_telegram_unsupported(self, source: 'AbstractUser', intent: IntentAPI, - evt: Message, relates_to: RelatesTo = None) -> EventID: - override_text = ("This message is not supported on your version of Mautrix-Telegram. " - "Please check https://github.com/mautrix/telegram or ask your " - "bridge administrator about possible updates.") - content = await formatter.telegram_to_matrix( - evt, source, self.main_intent, override_text=override_text) - content.msgtype = MessageType.NOTICE - content.external_url = self._get_external_url(evt) - content["net.maunium.telegram.unsupported"] = True - await intent.set_typing(self.mxid, is_typing=False) - return await self._send_message(intent, content, timestamp=evt.date) - - async def handle_telegram_poll(self, source: 'AbstractUser', intent: IntentAPI, evt: Message, - relates_to: RelatesTo) -> EventID: - poll: Poll = evt.media.poll - poll_id = self._encode_msgid(source, evt) - - _n = 0 - - def n() -> int: - nonlocal _n - _n += 1 - return _n - - text_answers = "\n".join(f"{n()}. {answer.text}" for answer in poll.answers) - html_answers = "\n".join(f"
  • {answer.text}
  • " for answer in poll.answers) - content = TextMessageEventContent( - msgtype=MessageType.TEXT, format=Format.HTML, - body=f"Poll: {poll.question}\n{text_answers}\n" - f"Vote with !tg vote {poll_id} ", - formatted_body=f"Poll: {poll.question}
    \n" - f"
      {html_answers}
    \n" - f"Vote with !tg vote {poll_id} <choice number>", - relates_to=relates_to, external_url=self._get_external_url(evt)) - - await intent.set_typing(self.mxid, is_typing=False) - return await self._send_message(intent, content, timestamp=evt.date) - - @staticmethod - def _format_dice(roll: MessageMediaDice) -> str: - if roll.emoticon == "\U0001F3B0": - emojis = { - 0: "\U0001F36B", # "🍫", - 1: "\U0001F352", # "🍒", - 2: "\U0001F34B", # "🍋", - 3: "7\ufe0f\u20e3" # "7️⃣", - } - res = roll.value - 1 - slot1, slot2, slot3 = emojis[res % 4], emojis[res // 4 % 4], emojis[res // 16] - return f"{slot1} {slot2} {slot3} ({roll.value})" - elif roll.emoticon == "\u26BD": - results = { - 1: "miss", - 2: "hit the woodwork", - 3: "goal", # seems to go in through the center - 4: "goal", - 5: "goal 🎉", # seems to go in through the top right corner, includes confetti - } - elif roll.emoticon == "\U0001F3B3": - results = { - 1: "miss", - 2: "1 pin down", - 3: "3 pins down, split", - 4: "4 pins down, split", - 5: "5 pins down", - 6: "strike 🎉", - } - # elif roll.emoticon == "\U0001F3C0": - # results = { - # 2: "rolled off", - # 3: "stuck", - # } - # elif roll.emoticon == "\U0001F3AF": - # results = { - # 1: "bounced off", - # 2: "outer rim", - # - # 6: "bullseye", - # } - else: - return str(roll.value) - return f"{results[roll.value]} ({roll.value})" - - async def handle_telegram_dice(self, source: 'AbstractUser', intent: IntentAPI, evt: Message, - relates_to: RelatesTo) -> EventID: - emoji_text = { - "\U0001F3AF": " Dart throw", - "\U0001F3B2": " Dice roll", - "\U0001F3C0": " Basketball throw", - "\U0001F3B0": " Slot machine", - "\U0001F3B3": " Bowling", - "\u26BD": " Football kick" - } - roll: MessageMediaDice = evt.media - text = f"{roll.emoticon}{emoji_text.get(roll.emoticon, '')} result: {self._format_dice(roll)}" - content = TextMessageEventContent(msgtype=MessageType.TEXT, format=Format.HTML, body=text, - formatted_body=f"

    {text}

    ", relates_to=relates_to, - external_url=self._get_external_url(evt)) - content["net.maunium.telegram.dice"] = {"emoticon": roll.emoticon, "value": roll.value} - await intent.set_typing(self.mxid, is_typing=False) - return await self._send_message(intent, content, timestamp=evt.date) - - @staticmethod - def _int_to_bytes(i: int) -> bytes: - hex_value = f"{i:010x}".encode("utf-8") - return codecs.decode(hex_value, "hex_codec") - - def _encode_msgid(self, source: 'AbstractUser', evt: Message) -> str: - if self.peer_type == "channel": - play_id = (b"c" - + self._int_to_bytes(self.tgid) - + self._int_to_bytes(evt.id)) - elif self.peer_type == "chat": - play_id = (b"g" - + self._int_to_bytes(self.tgid) - + self._int_to_bytes(evt.id) - + self._int_to_bytes(source.tgid)) - elif self.peer_type == "user": - play_id = (b"u" - + self._int_to_bytes(self.tgid) - + self._int_to_bytes(evt.id)) - else: - raise ValueError("Portal has invalid peer type") - return base64.b64encode(play_id).decode("utf-8").rstrip("=") - - async def handle_telegram_game(self, source: 'AbstractUser', intent: IntentAPI, - evt: Message, relates_to: RelatesTo = None) -> EventID: - game = evt.media.game - play_id = self._encode_msgid(source, evt) - command = f"!tg play {play_id}" - override_text = f"Run {command} in your bridge management room to play {game.title}" - override_entities = [ - MessageEntityPre(offset=len("Run "), length=len(command), language="")] - - content = await formatter.telegram_to_matrix( - evt, source, self.main_intent, - override_text=override_text, override_entities=override_entities) - content.msgtype = MessageType.NOTICE - content.external_url = self._get_external_url(evt) - content["net.maunium.telegram.game"] = play_id - - await intent.set_typing(self.mxid, is_typing=False) - return await self._send_message(intent, content, timestamp=evt.date) - - async def handle_telegram_edit(self, source: 'AbstractUser', sender: p.Puppet, evt: Message - ) -> None: - if not self.mxid: - self.log.trace("Ignoring edit to %d as chat has no Matrix room", evt.id) - return - elif hasattr(evt, "media") and isinstance(evt.media, MessageMediaGame): - self.log.debug("Ignoring game message edit event") - return - - async with self.send_lock(sender.tgid if sender else None, required=False): - tg_space = self.tgid if self.peer_type == "channel" else source.tgid - - temporary_identifier = EventID( - f"${random.randint(1000000000000, 9999999999999)}TGBRIDGEDITEMP") - duplicate_found = self.dedup.check(evt, (temporary_identifier, tg_space), - force_hash=True) - if duplicate_found: - mxid, other_tg_space = duplicate_found - if tg_space != other_tg_space: - prev_edit_msg = DBMessage.get_one_by_tgid(TelegramID(evt.id), tg_space, -1) - if not prev_edit_msg: - return - DBMessage(mxid=mxid, mx_room=self.mxid, tg_space=tg_space, - tgid=TelegramID(evt.id), edit_index=prev_edit_msg.edit_index + 1 - ).insert() - return - - content = await formatter.telegram_to_matrix(evt, source, self.main_intent, - no_reply_fallback=True) - editing_msg = DBMessage.get_one_by_tgid(TelegramID(evt.id), tg_space) - if not editing_msg: - self.log.info(f"Didn't find edited message {evt.id}@{tg_space} (src {source.tgid}) " - "in database.") - return - - content.msgtype = (MessageType.NOTICE if (sender and sender.is_bot - and self.get_config("bot_messages_as_notices")) - else MessageType.TEXT) - content.external_url = self._get_external_url(evt) - content.set_edit(editing_msg.mxid) - - intent = sender.intent_for(self) if sender else self.main_intent - await intent.set_typing(self.mxid, is_typing=False) - event_id = await self._send_message(intent, content) - - prev_edit_msg = DBMessage.get_one_by_tgid(TelegramID(evt.id), tg_space, -1) or editing_msg - DBMessage(mxid=event_id, mx_room=self.mxid, tg_space=tg_space, tgid=TelegramID(evt.id), - edit_index=prev_edit_msg.edit_index + 1).insert() - DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=event_id) - - @property - def _takeout_options(self) -> Dict[str, Union[bool, int]]: - return { - "files": True, - "megagroups": self.megagroup, - "chats": self.peer_type == "chat", - "users": self.peer_type == "user", - "channels": (self.peer_type == "channel" and not self.megagroup), - "max_file_size": min(config["bridge.max_document_size"], 2000) * 1024 * 1024 - } - - async def backfill(self, source: 'u.User', is_initial: bool = False, - limit: Optional[int] = None, last_id: Optional[int] = None) -> None: - async with self.backfill_method_lock: - await self._locked_backfill(source, is_initial, limit, last_id) - - async def _locked_backfill(self, source: 'u.User', is_initial: bool = False, - limit: Optional[int] = None, last_id: Optional[int] = None) -> None: - limit = limit or (config["bridge.backfill.initial_limit"] if is_initial - else config["bridge.backfill.missed_limit"]) - if limit == 0: - return - if not config["bridge.backfill.normal_groups"] and self.peer_type == "chat": - return - last = DBMessage.find_last(self.mxid, (source.tgid if self.peer_type != "channel" - else self.tgid)) - min_id = last.tgid if last else 0 - if last_id is None: - messages = await source.client.get_messages(self.peer, limit=1) - if not messages: - # The chat seems empty - return - last_id = messages[0].id - if last_id <= min_id: - # Nothing to backfill - return - if limit < 0: - limit = last_id - min_id - self.log.debug(f"Backfilling approximately {last_id - min_id} messages " - f"through {source.mxid}") - elif self.peer_type == "channel": - # This is a channel or supergroup, so we'll backfill messages based on the ID. - # There are some cases, such as deleted messages, where this may backfill less - # messages than the limit. - min_id = max(last_id - limit, min_id) - self.log.debug(f"Backfilling messages after ID {min_id} (last message: {last_id}) " - f"through {source.mxid}") - else: - # Private chats and normal groups don't have their own message ID namespace, - # which means we'll have to fetch messages a different way. - # The _backfill_messages method will detect min_id=None and not use reverse=True - min_id = None - self.log.debug(f"Backfilling up to {limit} messages through {source.mxid}") - with self.backfill_lock: - await self._backfill(source, min_id, limit) - - async def _backfill(self, source: 'u.User', min_id: Optional[int], limit: int) -> None: - self.backfill_leave = set() - if ((self.peer_type == "user" and self.tgid != source.tgid - and config["bridge.backfill.invite_own_puppet"])): - self.log.debug("Adding %s's default puppet to room for backfilling", source.mxid) - sender = p.Puppet.get(source.tgid) - await self.main_intent.invite_user(self.mxid, sender.default_mxid) - await sender.default_mxid_intent.join_room_by_id(self.mxid) - self.backfill_leave.add(sender.default_mxid_intent) - - client = source.client - async with NotificationDisabler(self.mxid, source): - if limit > config["bridge.backfill.takeout_limit"]: - self.log.debug(f"Opening takeout client for {source.tgid}") - async with client.takeout(**self._takeout_options) as takeout: - count = await self._backfill_messages(source, min_id, limit, takeout) - else: - count = await self._backfill_messages(source, min_id, limit, client) - - for intent in self.backfill_leave: - self.log.trace("Leaving room with %s post-backfill", intent.mxid) - await intent.leave_room(self.mxid) - self.backfill_leave = None - self.log.info("Backfilled %d messages through %s", count, source.mxid) - - async def _backfill_messages(self, source: 'AbstractUser', min_id: Optional[int], limit: int, - client: TelegramClient) -> int: - count = 0 - entity = await self.get_input_entity(source) - if min_id is not None: - self.log.debug(f"Iterating all messages starting with {min_id} (approx: {limit})") - messages = client.iter_messages(entity, reverse=True, min_id=min_id) - async for message in messages: - sender = (p.Puppet.get(message.from_id.user_id) - if isinstance(message.from_id, PeerUser) else None) - # TODO handle service messages? - await self.handle_telegram_message(source, sender, message) - count += 1 - else: - self.log.debug(f"Fetching up to {limit} most recent messages") - messages = await client.get_messages(entity, limit=limit) - for message in reversed(messages): - sender = (p.Puppet.get(TelegramID(message.from_id.user_id)) - if isinstance(message.from_id, PeerUser) else None) - await self.handle_telegram_message(source, sender, message) - count += 1 - return count - - async def handle_telegram_message(self, source: 'AbstractUser', sender: p.Puppet, - evt: Message) -> None: - if not self.mxid: - self.log.trace("Got telegram message %d, but no room exists, creating...", evt.id) - await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=False) - - if (self.peer_type == "user" and sender and sender.tgid == self.tg_receiver - and not sender.is_real_user and not await self.az.state_store.is_joined(self.mxid, - sender.mxid)): - self.log.debug(f"Ignoring private chat message {evt.id}@{source.tgid} as receiver does" - " not have matrix puppeting and their default puppet isn't in the room") - return - - async with self.send_lock(sender.tgid if sender else None, required=False): - tg_space = self.tgid if self.peer_type == "channel" else source.tgid - - temporary_identifier = EventID( - f"${random.randint(1000000000000, 9999999999999)}TGBRIDGETEMP") - duplicate_found = self.dedup.check(evt, (temporary_identifier, tg_space)) - if duplicate_found: - mxid, other_tg_space = duplicate_found - self.log.debug(f"Ignoring message {evt.id}@{tg_space} (src {source.tgid}) " - f"as it was already handled (in space {other_tg_space})") - if tg_space != other_tg_space: - DBMessage(tgid=TelegramID(evt.id), mx_room=self.mxid, mxid=mxid, - tg_space=tg_space, edit_index=0).insert() - return - - if self.backfill_lock.locked or (self.dedup.pre_db_check and self.peer_type == "channel"): - msg = DBMessage.get_one_by_tgid(TelegramID(evt.id), tg_space) - if msg: - self.log.debug(f"Ignoring message {evt.id} (src {source.tgid}) as it was already " - f"handled into {msg.mxid}. This duplicate was catched in the db " - "check. If you get this message often, consider increasing " - "bridge.deduplication.cache_queue_length in the config.") - return - - self.log.trace("Handling Telegram message %s", evt) - - if sender and not sender.displayname: - self.log.debug(f"Telegram user {sender.tgid} sent a message, but doesn't have a " - "displayname, updating info...") - entity = await source.client.get_entity(PeerUser(sender.tgid)) - await sender.update_info(source, entity) - if not sender.displayname: - self.log.debug(f"Telegram user {sender.tgid} doesn't have a displayname even after" - f" updating with data {entity!s}") - - allowed_media = (MessageMediaPhoto, MessageMediaDocument, MessageMediaGeo, - MessageMediaGame, MessageMediaDice, MessageMediaPoll, - MessageMediaUnsupported) - media = evt.media if hasattr(evt, "media") and isinstance(evt.media, - allowed_media) else None - if sender: - intent = sender.intent_for(self) - if ((self.backfill_lock.locked and intent != sender.default_mxid_intent - and config["bridge.backfill.invite_own_puppet"])): - intent = sender.default_mxid_intent - self.backfill_leave.add(intent) - else: - intent = self.main_intent - if not media and evt.message: - is_bot = sender.is_bot if sender else False - event_id = await self.handle_telegram_text(source, intent, is_bot, evt) - elif media: - event_id = await { - MessageMediaPhoto: self.handle_telegram_photo, - MessageMediaDocument: self.handle_telegram_document, - MessageMediaGeo: self.handle_telegram_location, - MessageMediaPoll: self.handle_telegram_poll, - MessageMediaDice: self.handle_telegram_dice, - MessageMediaUnsupported: self.handle_telegram_unsupported, - MessageMediaGame: self.handle_telegram_game, - }[type(media)](source, intent, evt, - relates_to=formatter.telegram_reply_to_matrix(evt, source)) - else: - self.log.debug("Unhandled Telegram message %d", evt.id) - return - - if not event_id: - return - - prev_id = self.dedup.update(evt, (event_id, tg_space), (temporary_identifier, tg_space)) - if prev_id: - self.log.debug(f"Sent message {evt.id}@{tg_space} to Matrix as {event_id}. " - f"Temporary dedup identifier was {temporary_identifier}, " - f"but dedup map contained {prev_id[1]} instead! -- " - "This was probably a race condition caused by Telegram sending updates" - "to other clients before responding to the sender. I'll just redact " - "the likely duplicate message now.") - await intent.redact(self.mxid, event_id) - return - - self.log.debug("Handled telegram message %d -> %s", evt.id, event_id) - try: - DBMessage(tgid=TelegramID(evt.id), mx_room=self.mxid, mxid=event_id, - tg_space=tg_space, edit_index=0).insert() - DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=event_id) - except IntegrityError as e: - self.log.exception(f"{e.__class__.__name__} while saving message mapping. " - "This might mean that an update was handled after it left the " - "dedup cache queue. You can try enabling bridge.deduplication." - "pre_db_check in the config.") - await intent.redact(self.mxid, event_id) - await self._send_delivery_receipt(event_id) - - async def _create_room_on_action(self, source: 'AbstractUser', - action: TypeMessageAction) -> bool: - if source.is_relaybot and config["bridge.ignore_unbridged_group_chat"]: - return False - create_and_exit = (MessageActionChatCreate, MessageActionChannelCreate) - create_and_continue = (MessageActionChatAddUser, MessageActionChatJoinedByLink) - if isinstance(action, create_and_exit) or isinstance(action, create_and_continue): - await self.create_matrix_room(source, invites=[source.mxid], - update_if_exists=isinstance(action, create_and_exit)) - if not isinstance(action, create_and_continue): - return False - return True - - async def handle_telegram_action(self, source: 'AbstractUser', sender: p.Puppet, - update: MessageService) -> None: - action = update.action - should_ignore = ((not self.mxid and not await self._create_room_on_action(source, action)) - or self.dedup.check_action(update)) - if should_ignore or not self.mxid: - return - if isinstance(action, MessageActionChatEditTitle): - await self._update_title(action.title, sender=sender, save=True) - await self.update_bridge_info() - elif isinstance(action, MessageActionChatEditPhoto): - await self._update_avatar(source, action.photo, sender=sender, save=True) - await self.update_bridge_info() - elif isinstance(action, MessageActionChatDeletePhoto): - await self._update_avatar(source, ChatPhotoEmpty(), sender=sender, save=True) - await self.update_bridge_info() - elif isinstance(action, MessageActionChatAddUser): - for user_id in action.users: - await self._add_telegram_user(TelegramID(user_id), source) - elif isinstance(action, MessageActionChatJoinedByLink): - await self._add_telegram_user(sender.id, source) - elif isinstance(action, MessageActionChatDeleteUser): - await self._delete_telegram_user(TelegramID(action.user_id), sender) - elif isinstance(action, MessageActionChatMigrateTo): - self.peer_type = "channel" - self._migrate_and_save_telegram(TelegramID(action.channel_id)) - # TODO encrypt - await sender.intent_for(self).send_emote(self.mxid, - "upgraded this group to a supergroup.") - await self.update_bridge_info() - elif isinstance(action, MessageActionGameScore): - # TODO handle game score - pass - else: - self.log.trace("Unhandled Telegram action in %s: %s", self.title, action) - - async def set_telegram_admin(self, user_id: TelegramID) -> None: - puppet = p.Puppet.get(user_id) - user = u.User.get_by_tgid(user_id) - - levels = await self.main_intent.get_power_levels(self.mxid) - if user: - levels.users[user.mxid] = 50 - if puppet: - levels.users[puppet.mxid] = 50 - await self.main_intent.set_power_levels(self.mxid, levels) - - async def receive_telegram_pin_ids(self, msg_ids: List[TelegramID], receiver: TelegramID, - remove: bool) -> None: - async with self._pin_lock: - tg_space = receiver if self.peer_type != "channel" else self.tgid - previously_pinned = await self.main_intent.get_pinned_messages(self.mxid) - currently_pinned_dict = {event_id: True for event_id in previously_pinned} - for message in DBMessage.get_first_by_tgids(msg_ids, tg_space): - if remove: - currently_pinned_dict.pop(message.mxid, None) - else: - currently_pinned_dict[message.mxid] = True - currently_pinned = list(currently_pinned_dict.keys()) - if currently_pinned != previously_pinned: - await self.main_intent.set_pinned_messages(self.mxid, currently_pinned) - - async def set_telegram_admins_enabled(self, enabled: bool) -> None: - level = 50 if enabled else 10 - levels = await self.main_intent.get_power_levels(self.mxid) - levels.invite = level - levels.events[EventType.ROOM_NAME] = level - levels.events[EventType.ROOM_AVATAR] = level - await self.main_intent.set_power_levels(self.mxid, levels) - - -def init(context: Context) -> None: - global config - config = context.config - NotificationDisabler.puppet_cls = p.Puppet - NotificationDisabler.config_enabled = config["bridge.backfill.disable_notifications"] diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py index 9ebe92ef..45d3be1c 100644 --- a/mautrix_telegram/puppet.py +++ b/mautrix_telegram/puppet.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2020 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,111 +13,79 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Any, Dict, Iterable, Optional, Union, Tuple, TYPE_CHECKING +from __future__ import annotations + +from typing import Awaitable, AsyncGenerator, AsyncIterable, TYPE_CHECKING, cast from difflib import SequenceMatcher import unicodedata -import asyncio -import logging from telethon.tl.types import (UserProfilePhoto, User, UpdateUserName, PeerUser, TypeInputPeer, InputPeerPhotoFileLocation, UserProfilePhotoEmpty, TypeInputUser) from yarl import URL -from mautrix.appservice import AppService, IntentAPI -from mautrix.errors import MatrixRequestError, MatrixError -from mautrix.bridge import BasePuppet +from mautrix.appservice import IntentAPI +from mautrix.errors import MatrixError +from mautrix.bridge import BasePuppet, async_getter_lock from mautrix.types import UserID, SyncToken, RoomID, ContentURI from mautrix.util.simple_template import SimpleTemplate -from mautrix.util.logging import TraceLogger +from .config import Config from .types import TelegramID from .db import Puppet as DBPuppet -from . import util, portal as p +from . import util, portal as p, abstract_user as au if TYPE_CHECKING: - from .matrix import MatrixHandler - from .config import Config - from .context import Context - from .abstract_user import AbstractUser - -config: Optional['Config'] = None + from .__main__ import TelegramBridge -class Puppet(BasePuppet): - log: TraceLogger = logging.getLogger("mau.puppet") - az: AppService - mx: 'MatrixHandler' - loop: asyncio.AbstractEventLoop +class Puppet(DBPuppet, BasePuppet): + config: Config hs_domain: str mxid_template: SimpleTemplate[TelegramID] displayname_template: SimpleTemplate[str] - cache: Dict[TelegramID, 'Puppet'] = {} - by_custom_mxid: Dict[UserID, 'Puppet'] = {} + by_tgid: dict[TelegramID, Puppet] = {} + by_custom_mxid: dict[UserID, Puppet] = {} - id: TelegramID - access_token: Optional[str] - custom_mxid: Optional[UserID] - _next_batch: Optional[SyncToken] - base_url: Optional[URL] - default_mxid: UserID + def __init__( + self, + id: TelegramID, + is_registered: bool = False, + displayname: str | None = None, + displayname_source: TelegramID | None = None, + displayname_contact: bool = True, + displayname_quality: int = 0, + disable_updates: bool = False, + username: str | None = None, + photo_id: str | None = None, + is_bot: bool = False, + custom_mxid: UserID | None = None, + access_token: str | None = None, + next_batch: SyncToken | None = None, + base_url: str | None = None + ) -> None: + super().__init__( + id=id, + is_registered=is_registered, + displayname=displayname, + displayname_source=displayname_source, + displayname_contact=displayname_contact, + displayname_quality=displayname_quality, + disable_updates=disable_updates, + username=username, + photo_id=photo_id, + is_bot=is_bot, + custom_mxid=custom_mxid, + access_token=access_token, + next_batch=next_batch, + base_url=base_url, + ) - username: Optional[str] - displayname: Optional[str] - displayname_source: Optional[TelegramID] - displayname_contact: bool - displayname_quality: int - photo_id: Optional[str] - is_bot: bool - is_registered: bool - disable_updates: bool - - default_mxid_intent: IntentAPI - intent: IntentAPI - - sync_task: Optional[asyncio.Future] - - _db_instance: Optional[DBPuppet] - - def __init__(self, - id: TelegramID, - access_token: Optional[str] = None, - custom_mxid: Optional[UserID] = None, - next_batch: Optional[SyncToken] = None, - base_url: Optional[str] = None, - username: Optional[str] = None, - displayname: Optional[str] = None, - displayname_source: Optional[TelegramID] = None, - displayname_contact: bool = True, - displayname_quality: int = 0, - photo_id: Optional[str] = None, - is_bot: bool = False, - is_registered: bool = False, - disable_updates: bool = False, - db_instance: Optional[DBPuppet] = None) -> None: - self.id = id - self.access_token = access_token - self.custom_mxid = custom_mxid - self._next_batch = next_batch - self.base_url = URL(base_url) if base_url else None self.default_mxid = self.get_mxid_from_id(self.id) - - self.username = username - self.displayname = displayname - self.displayname_source = displayname_source - self.displayname_contact = displayname_contact - self.displayname_quality = displayname_quality - self.photo_id = photo_id - self.is_bot = is_bot - self.is_registered = is_registered - self.disable_updates = disable_updates - self._db_instance = db_instance - self.default_mxid_intent = self.az.intent.user(self.default_mxid) self.intent = self._fresh_intent() - self.sync_task = None - self.cache[id] = self + self.by_tgid[id] = self if self.custom_mxid: self.by_custom_mxid[self.custom_mxid] = self @@ -128,76 +96,59 @@ class Puppet(BasePuppet): return self.id @property - def peer(self) -> PeerUser: - return PeerUser(user_id=self.tgid) + def tg_username(self) -> str | None: + return self.username @property - def next_batch(self) -> SyncToken: - return self._next_batch - - @next_batch.setter - def next_batch(self, value: SyncToken) -> None: - self._next_batch = value - self.db_instance.edit(next_batch=self._next_batch) - - @staticmethod - async def is_logged_in() -> bool: - """ Is True if the puppet is logged in. """ - return True + def peer(self) -> PeerUser: + return PeerUser(user_id=self.tgid) @property def plain_displayname(self) -> str: return self.displayname_template.parse(self.displayname) or self.displayname - def get_input_entity(self, user: 'AbstractUser' - ) -> Awaitable[Union[TypeInputPeer, TypeInputUser]]: + def get_input_entity(self, user: au.AbstractUser) -> Awaitable[TypeInputPeer | TypeInputUser]: return user.client.get_input_entity(self.peer) - def intent_for(self, portal: 'p.Portal') -> IntentAPI: + def intent_for(self, portal: p.Portal) -> IntentAPI: if portal.tgid == self.tgid: return self.default_mxid_intent return self.intent - # region DB conversion - - @property - def db_instance(self) -> DBPuppet: - if not self._db_instance: - self._db_instance = self.new_db_instance() - return self._db_instance - - @property - def _fields(self) -> Dict[str, Any]: - return dict(access_token=self.access_token, next_batch=self._next_batch, - custom_mxid=self.custom_mxid, username=self.username, is_bot=self.is_bot, - displayname=self.displayname, displayname_source=self.displayname_source, - displayname_contact=self.displayname_contact, - displayname_quality=self.displayname_quality, photo_id=self.photo_id, - matrix_registered=self.is_registered, disable_updates=self.disable_updates, - base_url=str(self.base_url) if self.base_url else None) - - def new_db_instance(self) -> DBPuppet: - return DBPuppet(id=self.id, **self._fields) - - async def save(self) -> None: - self.db_instance.edit(**self._fields) - @classmethod - def from_db(cls, db_puppet: DBPuppet) -> 'Puppet': - return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid, - db_puppet.next_batch, db_puppet.base_url, db_puppet.username, - db_puppet.displayname, db_puppet.displayname_source, - db_puppet.displayname_contact, db_puppet.displayname_quality, - db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered, - db_puppet.disable_updates, db_instance=db_puppet) + def init_cls(cls, bridge: 'TelegramBridge') -> AsyncIterable[Awaitable[None]]: + cls.config = bridge.config + cls.loop = bridge.loop + cls.mx = bridge.matrix + cls.az = bridge.az + cls.hs_domain = cls.config["homeserver.domain"] + mxid_tpl = SimpleTemplate( + cls.config["bridge.username_template"], + "userid", + prefix="@", + suffix=f":{Puppet.hs_domain}", + type=int, + ) + cls.mxid_template = cast(SimpleTemplate[TelegramID], mxid_tpl) + cls.displayname_template = SimpleTemplate( + cls.config["bridge.displayname_template"], "displayname" + ) + cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"] + cls.homeserver_url_map = {server: URL(url) for server, url + in cls.config["bridge.double_puppet_server_map"].items()} + cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"] + cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret + in cls.config["bridge.login_shared_secret_map"].items()} + cls.login_device_name = "Telegram Bridge" + + return (puppet.try_start() async for puppet in cls.all_with_custom_mxid()) - # endregion # region Info updating def similarity(self, query: str) -> int: username_similarity = (SequenceMatcher(None, self.username, query).ratio() if self.username else 0) - displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio() + displayname_similarity = (SequenceMatcher(None, self.plain_displayname, query).ratio() if self.displayname else 0) similarity = max(username_similarity, displayname_similarity) return int(round(similarity * 100)) @@ -211,11 +162,11 @@ class Puppet(BasePuppet): "\u200c\u200d\u200e\u200f\ufe0f") allowed_other_format = ("\u200d", "\u200c") name = "".join(c for c in name.strip(whitespace) if unicodedata.category(c) != 'Cf' - or c in allowed_other_format) + or c in allowed_other_format) return name @classmethod - def get_displayname(cls, info: User, enable_format: bool = True) -> Tuple[str, int]: + def get_displayname(cls, info: User, enable_format: bool = True) -> tuple[str, int]: fn = cls._filter_name(info.first_name) ln = cls._filter_name(info.last_name) data = { @@ -226,7 +177,7 @@ class Puppet(BasePuppet): "first name": fn, "last name": ln, } - preferences = config["bridge.displayname_preference"] + preferences = cls.config["bridge.displayname_preference"] name = None quality = 99 for preference in preferences: @@ -244,13 +195,13 @@ class Puppet(BasePuppet): return (cls.displayname_template.format_full(name) if enable_format else name), quality - async def try_update_info(self, source: 'AbstractUser', info: User) -> None: + async def try_update_info(self, source: au.AbstractUser, info: User) -> None: try: await self.update_info(source, info) except Exception: source.log.exception(f"Failed to update info of {self.tgid}") - async def update_info(self, source: 'AbstractUser', info: User) -> None: + async def update_info(self, source: au.AbstractUser, info: User) -> None: changed = False if self.username != info.username: self.username = info.username @@ -268,7 +219,7 @@ class Puppet(BasePuppet): if changed: await self.save() - async def update_displayname(self, source: 'AbstractUser', info: Union[User, UpdateUserName] + async def update_displayname(self, source: au.AbstractUser, info: User | UpdateUserName ) -> bool: if self.disable_updates: return False @@ -306,7 +257,7 @@ class Puppet(BasePuppet): self.displayname_quality = quality try: await self.default_mxid_intent.set_displayname( - displayname[:config["bridge.displayname_max_length"]]) + displayname[:self.config["bridge.displayname_max_length"]]) except MatrixError: self.log.exception("Failed to set displayname") self.displayname = "" @@ -318,8 +269,8 @@ class Puppet(BasePuppet): return True return False - async def update_avatar(self, source: 'AbstractUser', - photo: Union[UserProfilePhoto, UserProfilePhotoEmpty]) -> bool: + async def update_avatar(self, source: au.AbstractUser, + photo: UserProfilePhoto | UserProfilePhotoEmpty) -> bool: if self.disable_updates: return False @@ -330,7 +281,7 @@ class Puppet(BasePuppet): else: self.log.warning(f"Unknown user profile photo type: {type(photo)}") return False - if not photo_id and not config["bridge.allow_avatar_remove"]: + if not photo_id and not self.config["bridge.allow_avatar_remove"]: return False if self.photo_id != photo_id: if not photo_id: @@ -359,72 +310,73 @@ class Puppet(BasePuppet): return False async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool: - portal: p.Portal = p.Portal.get_by_mxid(room_id) + portal: p.Portal = await p.Portal.get_by_mxid(room_id) return portal and not portal.backfill_lock.locked and portal.peer_type != "user" # endregion # region Getters + def _add_to_cache(self) -> None: + self.by_tgid[self.id] = self + if self.custom_mxid: + self.by_custom_mxid[self.custom_mxid] = self + @classmethod - def get(cls, tgid: TelegramID, create: bool = True) -> Optional['Puppet']: + @async_getter_lock + async def get_by_tgid(cls, tgid: TelegramID, *, create: bool = True) -> Puppet | None: + if tgid is None: + return None + try: - return cls.cache[tgid] + return cls.by_tgid[tgid] except KeyError: pass - puppet = DBPuppet.get_by_tgid(tgid) + puppet = cast(cls, await super().get_by_tgid(tgid)) if puppet: - return cls.from_db(puppet) + puppet._add_to_cache() + return puppet if create: puppet = cls(tgid) - puppet.db_instance.insert() + await puppet.insert() + puppet._add_to_cache() return puppet return None @classmethod - def deprecated_sync_get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']: - tgid = cls.get_id_from_mxid(mxid) - if tgid: - return cls.get(tgid, create) - - return None + def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Awaitable[Puppet | None]: + return cls.get_by_tgid(cls.get_id_from_mxid(mxid), create=create) @classmethod - async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']: - return cls.deprecated_sync_get_by_mxid(mxid, create) - - @classmethod - def deprecated_sync_get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']: - if not mxid: - raise ValueError("Matrix ID can't be empty") - + @async_getter_lock + async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None: try: return cls.by_custom_mxid[mxid] except KeyError: pass - puppet = DBPuppet.get_by_custom_mxid(mxid) + puppet = cast(cls, await super().get_by_custom_mxid(mxid)) if puppet: - puppet = cls.from_db(puppet) + puppet._add_to_cache() return puppet return None @classmethod - async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']: - return cls.deprecated_sync_get_by_custom_mxid(mxid) + async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]: + puppets = await super().all_with_custom_mxid() + puppet: cls + for puppet in puppets: + try: + yield cls.by_tgid[puppet.tgid] + except KeyError: + puppet._add_to_cache() + yield puppet @classmethod - def all_with_custom_mxid(cls) -> Iterable['Puppet']: - return (cls.by_custom_mxid[puppet.custom_mxid] - if puppet.custom_mxid in cls.by_custom_mxid - else cls.from_db(puppet) - for puppet in DBPuppet.all_with_custom_mxid()) - - @classmethod - def get_id_from_mxid(cls, mxid: UserID) -> Optional[TelegramID]: + def get_id_from_mxid(cls, mxid: UserID) -> TelegramID | None: return cls.mxid_template.parse(mxid) @classmethod @@ -432,56 +384,43 @@ class Puppet(BasePuppet): return UserID(cls.mxid_template.format_full(tgid)) @classmethod - def find_by_username(cls, username: str) -> Optional['Puppet']: + async def find_by_username(cls, username: str) -> Puppet | None: if not username: return None username = username.lower() - for _, puppet in cls.cache.items(): + for _, puppet in cls.by_tgid.items(): if puppet.username and puppet.username.lower() == username: return puppet - dbpuppet = DBPuppet.get_by_username(username) - if dbpuppet: - return cls.from_db(dbpuppet) + puppet = cast(cls, await super().find_by_username(username)) + if puppet: + try: + return cls.by_tgid[puppet.tgid] + except KeyError: + puppet._add_to_cache() + return puppet return None @classmethod - def find_by_displayname(cls, displayname: str) -> Optional['Puppet']: + async def find_by_displayname(cls, displayname: str) -> Puppet | None: if not displayname: return None - for _, puppet in cls.cache.items(): + for _, puppet in cls.by_tgid.items(): if puppet.displayname and puppet.displayname == displayname: return puppet - dbpuppet = DBPuppet.get_by_displayname(displayname) - if dbpuppet: - return cls.from_db(dbpuppet) + puppet = cast(cls, await super().find_by_displayname(displayname)) + if puppet: + try: + return cls.by_tgid[puppet.tgid] + except KeyError: + puppet._add_to_cache() + return puppet return None + # endregion - - -def init(context: 'Context') -> Iterable[Awaitable[Any]]: - global config - Puppet.az, config, Puppet.loop, _ = context.core - Puppet.mx = context.mx - Puppet.hs_domain = config["homeserver"]["domain"] - - Puppet.mxid_template = SimpleTemplate(config["bridge.username_template"], "userid", - prefix="@", suffix=f":{Puppet.hs_domain}", type=int) - Puppet.displayname_template = SimpleTemplate(config["bridge.displayname_template"], - "displayname") - - Puppet.sync_with_custom_puppets = config["bridge.sync_with_custom_puppets"] - Puppet.homeserver_url_map = {server: URL(url) for server, url - in config["bridge.double_puppet_server_map"].items()} - Puppet.allow_discover_url = config["bridge.double_puppet_allow_discovery"] - Puppet.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret - in config["bridge.login_shared_secret_map"].items()} - Puppet.login_device_name = "Telegram Bridge" - - return (puppet.try_start() for puppet in Puppet.all_with_custom_mxid()) diff --git a/mautrix_telegram/scripts/__init__.py b/mautrix_telegram/scripts/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mautrix_telegram/scripts/dbms_migrate/__init__.py b/mautrix_telegram/scripts/dbms_migrate/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mautrix_telegram/scripts/dbms_migrate/__main__.py b/mautrix_telegram/scripts/dbms_migrate/__main__.py deleted file mode 100644 index 23261c0e..00000000 --- a/mautrix_telegram/scripts/dbms_migrate/__main__.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Union -import argparse - -from sqlalchemy import orm -from sqlalchemy.ext.declarative import declarative_base -import sqlalchemy as sql - -from alchemysession import AlchemySessionContainer - -parser = argparse.ArgumentParser(description="mautrix-telegram dbms migration script", - prog="python -m mautrix_telegram.scripts.dbms_migrate") -parser.add_argument("-f", "--from-url", type=str, required=True, metavar="", - help="the old database path") -parser.add_argument("-t", "--to-url", type=str, required=True, metavar="", - help="the new database path") -parser.add_argument("-v", "--verbose", action="store_true", help="Verbose logs while migrating") -args = parser.parse_args() -verbose = args.verbose or False - - -def log(message, end="\n"): - if verbose: - print(message, end=end, flush=True) - - -def connect(to): - from mautrix.util.db import Base - from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile - from mautrix_telegram.db import (Portal, Message, UserPortal, User, Contact, Puppet, BotChat, - TelegramFile) - - db_engine = sql.create_engine(to) - db_factory = orm.sessionmaker(bind=db_engine) - db_session: Union[orm.Session, orm.scoped_session] = orm.scoped_session(db_factory) - Base.metadata.bind = db_engine - - new_base = declarative_base() - new_base.metadata.bind = db_engine - session_container = AlchemySessionContainer(engine=db_engine, session=db_session, - table_base=new_base, table_prefix="telethon_", - manage_tables=False) - - return db_session, { - "Version": session_container.Version, - "Session": session_container.Session, - "Entity": session_container.Entity, - "SentFile": session_container.SentFile, - "UpdateState": session_container.UpdateState, - "Portal": Portal, - "Message": Message, - "Puppet": Puppet, - "User": User, - "UserPortal": UserPortal, - "RoomState": RoomState, - "UserProfile": UserProfile, - "Contact": Contact, - "BotChat": BotChat, - "TelegramFile": TelegramFile, - } - - -log("Connecting to old database") -session, tables = connect(args.from_url) - -data = {} -for name, table in tables.items(): - log("Reading table {name}...".format(name=name), end=" ") - data[name] = session.query(table).all() - log("Done!") - -log("Connecting to new database") -session, tables = connect(args.to_url) - -for name, table in tables.items(): - log("Writing table {name}".format(name=name), end="") - length = len(data[name]) - n = 0 - for row in data[name]: - session.merge(row) - n += 5 - if n >= length: - log(".", end="") - n = 0 - log(" Done!") - -log("Committing changes to database...", end=" ") -session.commit() -log("Done!") diff --git a/mautrix_telegram/scripts/telematrix_import/__init__.py b/mautrix_telegram/scripts/telematrix_import/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mautrix_telegram/scripts/telematrix_import/__main__.py b/mautrix_telegram/scripts/telematrix_import/__main__.py deleted file mode 100644 index b6ef97c0..00000000 --- a/mautrix_telegram/scripts/telematrix_import/__main__.py +++ /dev/null @@ -1,125 +0,0 @@ -# mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from typing import Dict -import argparse - -from sqlalchemy import orm -import sqlalchemy as sql - -from mautrix.util.db import Base - -from mautrix_telegram.db import Portal, Message, Puppet, BotChat -from mautrix_telegram.config import Config - -from .models import ChatLink, TgUser, MatrixUser, Message as TMMessage, Base as TelematrixBase - -parser = argparse.ArgumentParser( - description="mautrix-telegram telematrix import script", - prog="python -m mautrix_telegram.scripts.telematrix_import") -parser.add_argument("-c", "--config", type=str, default="config.yaml", - metavar="", help="the path to your mautrix-telegram config file") -parser.add_argument("-b", "--bot-id", type=int, required=True, - metavar="", help="the telegram user ID of your relay bot") -parser.add_argument("-t", "--telematrix-database", type=str, default="sqlite:///database.db", - metavar="", help="your telematrix database URL") -args = parser.parse_args() - -config = Config(args.config, None, None) -config.load() - -mxtg_db_engine = sql.create_engine(config["appservice.database"]) -mxtg = orm.sessionmaker(bind=mxtg_db_engine)() -Base.metadata.bind = mxtg_db_engine - -telematrix_db_engine = sql.create_engine(args.telematrix_database) -telematrix = orm.sessionmaker(bind=telematrix_db_engine)() -TelematrixBase.metadata.bind = telematrix_db_engine - -chat_links = telematrix.query(ChatLink).all() -tg_users = telematrix.query(TgUser).all() -mx_users = telematrix.query(MatrixUser).all() -tm_messages = telematrix.query(TMMessage).all() - -telematrix.close() -telematrix_db_engine.dispose() - -portals_by_tgid: Dict[int, Portal] = {} -portals_by_mxid: Dict[str, Portal] = {} -chats: Dict[int, BotChat] = {} -messages: Dict[str, Message] = {} -puppets: Dict[int, Puppet] = {} - -for chat_link in chat_links: - if type(chat_link.tg_room) is str: - print(f"Expected tg_room to be a number, got a string. Ignoring {chat_link.tg_room}") - continue - if chat_link.tg_room >= 0: - print(f"Unexpected unprefixed telegram chat ID: {chat_link.tg_room}, ignoring...") - continue - tgid = str(chat_link.tg_room) - if tgid.startswith("-100"): - tgid = int(tgid[4:]) - peer_type = "channel" - megagroup = True - else: - tgid = -chat_link.tg_room - peer_type = "chat" - megagroup = False - - portal = Portal(tgid=tgid, tg_receiver=tgid, peer_type=peer_type, megagroup=megagroup, - mxid=chat_link.matrix_room) - chats[tgid] = BotChat(id=tgid, type=peer_type) - if chat_link.tg_room in portals_by_tgid: - print(f"Warning: Ignoring bridge from {portal.tgid} to {portal.mxid} " - f"in favor of {portals_by_tgid[portal.tgid].mxid}") - continue - elif chat_link.matrix_room in portals_by_mxid: - print(f"Warning: Ignoring bridge from {portal.mxid} to {portal.tgid} " - f"in favor of {portals_by_mxid[portal.mxid].tgid}") - continue - portals_by_tgid[portal.tgid] = portal - portals_by_mxid[portal.mxid] = portal - -for tm_msg in tm_messages: - try: - portal = portals_by_tgid[tm_msg.tg_group_id] - except KeyError: - print(f"Found message entry {tm_msg.tg_message_id} in unlinked chat {tm_msg.tg_group_id}," - " ignoring...") - continue - if tm_msg.matrix_room_id != portal.mxid: - print(f"Found message entry {tm_msg.tg_message_id} with " - f"mismatching matrix room ID {tm_msg.matrix_room_id} (expected {portal.mxid})") - continue - tg_space = portal.tgid if portal.peer_type == "channel" else args.bot_id - message = Message(mxid=tm_msg.matrix_event_id, mx_room=tm_msg.matrix_room_id, - tgid=tm_msg.tg_message_id, tg_space=tg_space) - messages[tm_msg.matrix_event_id] = message - -for user in tg_users: - puppets[user.tg_id] = Puppet(id=user.tg_id, displayname=user.name, - displayname_source=args.bot_id) - -for k, v in portals_by_tgid.items(): - mxtg.add(v) -for k, v in chats.items(): - mxtg.add(v) -for k, v in messages.items(): - mxtg.add(v) -for k, v in puppets.items(): - mxtg.add(v) - -mxtg.commit() diff --git a/mautrix_telegram/scripts/telematrix_import/models.py b/mautrix_telegram/scripts/telematrix_import/models.py deleted file mode 100644 index ef7c3b42..00000000 --- a/mautrix_telegram/scripts/telematrix_import/models.py +++ /dev/null @@ -1,44 +0,0 @@ -import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base - -Base = declarative_base() - - -class ChatLink(Base): - __tablename__ = "chat_link" - - id = sa.Column(sa.Integer, primary_key=True) - matrix_room = sa.Column(sa.String) - tg_room = sa.Column(sa.BigInteger) - active = sa.Column(sa.Boolean) - - -class TgUser(Base): - __tablename__ = "tg_user" - - id = sa.Column(sa.Integer, primary_key=True) - tg_id = sa.Column(sa.BigInteger) - name = sa.Column(sa.String) - profile_pic_id = sa.Column(sa.String, nullable=True) - - -class MatrixUser(Base): - __tablename__ = "matrix_user" - - id = sa.Column(sa.Integer, primary_key=True) - matrix_id = sa.Column(sa.String) - name = sa.Column(sa.String) - - -class Message(Base): - """Describes a message in a room bridged between Telegram and Matrix""" - __tablename__ = "message" - - id = sa.Column(sa.Integer, primary_key=True) - tg_group_id = sa.Column(sa.BigInteger) - tg_message_id = sa.Column(sa.BigInteger) - - matrix_room_id = sa.Column(sa.String) - matrix_event_id = sa.Column(sa.String) - - displayname = sa.Column(sa.String) diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index be20574d..9b106695 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -13,10 +13,10 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import (Awaitable, Dict, List, Iterable, NamedTuple, Optional, Tuple, Any, cast, - TYPE_CHECKING) +from __future__ import annotations + +from typing import Awaitable, AsyncIterable, NamedTuple, AsyncGenerator, TYPE_CHECKING, cast from datetime import datetime, timezone -import logging import asyncio from telethon.tl.types import (TypeUpdate, UpdateNewMessage, UpdateNewChannelMessage, @@ -35,21 +35,17 @@ from mautrix.client import Client from mautrix.errors import MatrixRequestError, MNotFound from mautrix.types import UserID, RoomID, PushRuleScope, PushRuleKind, PushActionType, RoomTagInfo from mautrix.appservice import DOUBLE_PUPPET_SOURCE_KEY -from mautrix.bridge import BaseUser +from mautrix.bridge import BaseUser, async_getter_lock from mautrix.util.bridge_state import BridgeState, BridgeStateEvent -from mautrix.util.logging import TraceLogger from mautrix.util.opt_prometheus import Gauge from .types import TelegramID -from .db import User as DBUser, Portal as DBPortal, Message as DBMessage +from .db import User as DBUser, Message as DBMessage, PgSession from .abstract_user import AbstractUser from . import portal as po, puppet as pu if TYPE_CHECKING: - from .config import Config - from .context import Context - -config: Optional['Config'] = None + from .__main__ import TelegramBridge SearchResult = NamedTuple('SearchResult', puppet='pu.Puppet', similarity=int) @@ -64,54 +60,46 @@ BridgeState.human_readable_errors.update({ }) -class User(AbstractUser, BaseUser): - log: TraceLogger = logging.getLogger("mau.user") - by_mxid: Dict[str, 'User'] = {} - by_tgid: Dict[int, 'User'] = {} +class User(DBUser, AbstractUser, BaseUser): + by_mxid: dict[str, User] = {} + by_tgid: dict[int, User] = {} - phone: Optional[str] - contacts: List['pu.Puppet'] - saved_contacts: int - portals: Dict[Tuple[TelegramID, TelegramID], 'po.Portal'] - command_status: Optional[Dict[str, Any]] + _portals_cache: dict[tuple[TelegramID, TelegramID], po.Portal] | None - _db_instance: Optional[DBUser] _ensure_started_lock: asyncio.Lock - _track_connection_task: Optional[asyncio.Task] + _track_connection_task: asyncio.Task | None + _is_backfilling: bool - def __init__(self, mxid: UserID, tgid: Optional[TelegramID] = None, - username: Optional[str] = None, phone: Optional[str] = None, - db_contacts: Optional[Iterable[TelegramID]] = None, - saved_contacts: int = 0, is_bot: bool = False, - db_portals: Optional[Iterable[Tuple[TelegramID, TelegramID]]] = None, - db_instance: Optional[DBUser] = None) -> None: + def __init__( + self, + mxid: UserID, + tgid: TelegramID | None = None, + tg_username: str | None = None, + tg_phone: str | None = None, + is_bot: bool = False, + saved_contacts: int = 0, + ) -> None: + super().__init__( + mxid=mxid, + tgid=tgid, + tg_username=tg_username, + tg_phone=tg_phone, + is_bot=is_bot, + saved_contacts=saved_contacts, + ) AbstractUser.__init__(self) - self.mxid = mxid BaseUser.__init__(self) - self.tgid = tgid - self.is_bot = is_bot - self.username = username - self.phone = phone - self.contacts = [] - self.saved_contacts = saved_contacts - self.db_contacts = db_contacts - self.portals = {} - self.db_portals = db_portals or [] - self._db_instance = db_instance self._ensure_started_lock = asyncio.Lock() self._track_connection_task = None self._is_backfilling = False + self._portals_cache = None (self.relaybot_whitelisted, self.whitelisted, self.puppet_whitelisted, self.matrix_puppet_whitelisted, self.is_admin, - self.permissions) = config.get_permissions(self.mxid) - - self.by_mxid[mxid] = self - if tgid: - self.by_tgid[tgid] = self + self.permissions) = self.config.get_permissions(self.mxid) @property def name(self) -> str: @@ -124,7 +112,7 @@ class User(AbstractUser, BaseUser): @property def human_tg_id(self) -> str: - return f"@{self.username}" if self.username else f"+{self.phone}" or None + return f"@{self.tg_username}" if self.tg_username else f"+{self.tg_phone}" or None # TODO replace with proper displayname getting everywhere @property @@ -135,65 +123,15 @@ class User(AbstractUser, BaseUser): def plain_displayname(self) -> str: return self.displayname - @property - def db_contacts(self) -> Iterable[TelegramID]: - return (puppet.id - for puppet in self.contacts - if puppet) - - @db_contacts.setter - def db_contacts(self, contacts: Iterable[TelegramID]) -> None: - self.contacts = [pu.Puppet.get(entry) for entry in contacts] if contacts else [] - - @property - def db_portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]: - return (portal.tgid_full - for portal in self.portals.values() - if portal and not portal.deleted) - - @db_portals.setter - def db_portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None: - self.portals = { - tgid_full: po.Portal.get_by_tgid(*tgid_full) - for tgid_full in portals - } if portals else {} - - # region Database conversion - - @property - def db_instance(self) -> DBUser: - if not self._db_instance: - self._db_instance = self.new_db_instance() - return self._db_instance - - def new_db_instance(self) -> DBUser: - return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username, - saved_contacts=self.saved_contacts, portals=self.db_portals) - - async def save(self, contacts: bool = False, portals: bool = False) -> None: - self.db_instance.edit(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone, - saved_contacts=self.saved_contacts) - if contacts: - self.db_instance.contacts = self.db_contacts - if portals: - self.db_instance.portals = self.db_portals - - def delete(self, delete_db: bool = True) -> None: - try: - del self.by_mxid[self.mxid] - del self.by_tgid[self.tgid] - except KeyError: - pass - if delete_db and self._db_instance: - self._db_instance.delete() - @classmethod - def from_db(cls, db_user: DBUser) -> 'User': - return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.tg_phone, - db_user.contacts, db_user.saved_contacts, False, db_user.portals, - db_instance=db_user) + def init_cls(cls, bridge: 'TelegramBridge') -> AsyncIterable[Awaitable[User]]: + cls.config = bridge.config + cls.bridge = bridge + cls.az = bridge.az + cls.loop = bridge.loop + + return (user.try_ensure_started() async for user in cls.all_with_tgid()) - # endregion # region Telegram connection management async def try_ensure_started(self) -> None: @@ -202,19 +140,19 @@ class User(AbstractUser, BaseUser): except Exception: self.log.exception("Exception in ensure_started") else: - if not self.client and not self.session_container.has_session(self.mxid): + if not self.client and not await PgSession.has(self.mxid): self.log.warning("Didn't start user: no session stored") if self.tgid: await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error="tg-no-auth") - async def ensure_started(self, even_if_no_session=False) -> 'User': + async def ensure_started(self, even_if_no_session=False) -> User: if not self.puppet_whitelisted or self.connected: return self async with self._ensure_started_lock: return cast(User, await super().ensure_started(even_if_no_session)) - async def start(self, delete_unless_authenticated: bool = False) -> 'User': + async def start(self, delete_unless_authenticated: bool = False) -> User: try: await super().start() except AuthKeyDuplicatedError: @@ -222,7 +160,7 @@ class User(AbstractUser, BaseUser): await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error="tg-auth-key-duplicated") await self.client.disconnect() - self.client.session.delete() + await self.client.session.delete() self.client = None if not delete_unless_authenticated: # The caller wants the client to be connected, so restart the connection. @@ -257,7 +195,7 @@ class User(AbstractUser, BaseUser): if delete_unless_authenticated: self.log.debug(f"Unauthenticated user {self.name} start()ed, deleting session...") await self.client.disconnect() - self.client.session.delete() + await self.client.session.delete() return self @property @@ -283,7 +221,7 @@ class User(AbstractUser, BaseUser): state.remote_id = str(self.tgid) state.remote_name = self.human_tg_id - async def get_bridge_states(self) -> List[BridgeState]: + async def get_bridge_states(self) -> list[BridgeState]: if not self.tgid: return [] if self._is_connected and await self.is_logged_in(): @@ -295,10 +233,10 @@ class User(AbstractUser, BaseUser): ttl = 240 return [BridgeState(state_event=state_event, ttl=ttl)] - async def get_puppet(self) -> Optional['pu.Puppet']: + async def get_puppet(self) -> pu.Puppet | None: if not self.tgid: return None - return pu.Puppet.get(self.tgid) + return await pu.Puppet.get_by_tgid(self.tgid) async def stop(self) -> None: if self._track_connection_task: @@ -308,7 +246,7 @@ class User(AbstractUser, BaseUser): self._track_metric(METRIC_CONNECTED, False) async def post_login(self, info: TLUser = None, first_login: bool = False) -> None: - if config["metrics.enabled"] and not self._track_connection_task: + if self.config["metrics.enabled"] and not self._track_connection_task: self._track_connection_task = self.loop.create_task(self._track_connection()) try: @@ -320,14 +258,14 @@ class User(AbstractUser, BaseUser): self._track_metric(METRIC_LOGGED_IN, True) try: - puppet = pu.Puppet.get(self.tgid) + puppet = await pu.Puppet.get_by_tgid(self.tgid) if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid): self.log.info(f"Automatically enabling custom puppet") await puppet.switch_mxid(access_token="auto", mxid=self.mxid) except Exception: self.log.exception("Failed to automatically enable custom puppet") - if not self.is_bot and config["bridge.startup_sync"]: + if not self.is_bot and self.config["bridge.startup_sync"]: try: self._is_backfilling = True await self.sync_dialogs() @@ -342,11 +280,13 @@ class User(AbstractUser, BaseUser): return False if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)): - portal = po.Portal.get_by_entity(update.message.peer_id, receiver_id=self.tgid) + portal = await po.Portal.get_by_entity(update.message.peer_id, tg_receiver=self.tgid) elif isinstance(update, UpdateShortChatMessage): - portal = po.Portal.get_by_tgid(TelegramID(update.chat_id)) + portal = await po.Portal.get_by_tgid(TelegramID(update.chat_id)) elif isinstance(update, UpdateShortMessage): - portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user") + portal = await po.Portal.get_by_tgid( + TelegramID(update.user_id), tg_receiver=self.tgid, peer_type="user" + ) else: return False @@ -364,7 +304,7 @@ class User(AbstractUser, BaseUser): if not self.is_bot: await self.client(UpdateStatusRequest(offline=not online)) - async def get_me(self) -> Optional[TLUser]: + async def get_me(self) -> TLUser | None: try: return (await self.client(GetUsersRequest([InputUserSelf()])))[0] except UnauthorizedError as e: @@ -384,11 +324,11 @@ class User(AbstractUser, BaseUser): if self.is_bot != info.bot: self.is_bot = info.bot changed = True - if self.username != info.username: - self.username = info.username + if self.tg_username != info.username: + self.tg_username = info.username changed = True - if self.phone != info.phone: - self.phone = info.phone + if self.tg_phone != info.phone: + self.tg_phone = info.phone changed = True if self.tgid != info.id: self.tgid = TelegramID(info.id) @@ -396,11 +336,11 @@ class User(AbstractUser, BaseUser): if changed: await self.save() - async def log_out(self) -> bool: - puppet = pu.Puppet.get(self.tgid) - if puppet.is_real_user: - await puppet.switch_mxid(None, None) - for _, portal in self.portals.items(): + async def kick_from_portals(self) -> None: + if not self.config["bridge.kick_on_logout"]: + return + portals = await self.get_cached_portals() + for _, portal in portals.values(): if not portal or portal.deleted or not portal.mxid or portal.has_bot: continue if portal.peer_type == "user": @@ -411,9 +351,15 @@ class User(AbstractUser, BaseUser): "Logged out of Telegram.") except MatrixRequestError: pass - self.portals = {} - self.contacts = [] - await self.save(portals=True, contacts=True) + + async def log_out(self) -> bool: + puppet = await pu.Puppet.get_by_tgid(self.tgid) + if puppet.is_real_user: + await puppet.switch_mxid(None, None) + try: + await self.kick_from_portals() + except Exception: + self.log.exception("Failed to kick user from portals on logout") await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT) if self.tgid: try: @@ -421,51 +367,54 @@ class User(AbstractUser, BaseUser): except KeyError: pass self.tgid = None - await self.save() ok = await self.client.log_out() - self.client.session.delete() - self.delete() + await self.client.session.delete() + await self.delete() + self.by_mxid.pop(self.mxid, None) await self.stop() self._track_metric(METRIC_LOGGED_IN, False) return ok - def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45 - ) -> List[SearchResult]: - results: List[SearchResult] = [] - for contact in self.contacts: + async def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45 + ) -> list[SearchResult]: + results: list[SearchResult] = [] + for contact_id in await self.get_contacts(): + contact = await pu.Puppet.get_by_tgid(contact_id, create=False) + if not contact: + continue similarity = contact.similarity(query) if similarity >= min_similarity: results.append(SearchResult(contact, similarity)) results.sort(key=lambda tup: tup[1], reverse=True) return results[0:max_results] - async def _search_remote(self, query: str, max_results: int = 5) -> List[SearchResult]: + async def _search_remote(self, query: str, max_results: int = 5) -> list[SearchResult]: if len(query) < 5: return [] server_results = await self.client(SearchRequest(q=query, limit=max_results)) - results: List[SearchResult] = [] + results: list[SearchResult] = [] for user in server_results.users: - puppet = pu.Puppet.get(user.id) + puppet = await pu.Puppet.get_by_tgid(user.id) await puppet.update_info(self, user) results.append(SearchResult(puppet, puppet.similarity(query))) results.sort(key=lambda tup: tup[1], reverse=True) return results[0:max_results] async def search(self, query: str, force_remote: bool = False - ) -> Tuple[List[SearchResult], bool]: + ) -> tuple[list[SearchResult], bool]: if force_remote: return await self._search_remote(query), True - results = self._search_local(query) + results = await self._search_local(query) if results: return results, False return await self._search_remote(query), True - async def get_direct_chats(self) -> Dict[UserID, List[RoomID]]: + async def get_direct_chats(self) -> dict[UserID, list[RoomID]]: return { pu.Puppet.get_mxid_from_id(portal.tgid): [portal.mxid] - for portal in DBPortal.find_private_chats(self.tgid) + async for portal in po.Portal.find_private_chats(self.tgid) if portal.mxid } @@ -478,12 +427,14 @@ class User(AbstractUser, BaseUser): tag_info = RoomTagInfo(order=0.5) tag_info[DOUBLE_PUPPET_SOURCE_KEY] = self.bridge.name await puppet.intent.set_room_tag(portal.mxid, tag, tag_info) - elif not active and tag_info and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name: + elif ( + not active and tag_info + and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name + ): await puppet.intent.remove_room_tag(portal.mxid, tag) - @staticmethod - async def _mute_room(puppet: pu.Puppet, portal: po.Portal, mute_until: datetime) -> None: - if not config["bridge.mute_bridging"] or not portal or not portal.mxid: + async def _mute_room(cls, puppet: pu.Puppet, portal: po.Portal, mute_until: datetime) -> None: + if not cls.config["bridge.mute_bridging"] or not portal or not portal.mxid: return now = datetime.utcnow().replace(tzinfo=timezone.utc) if mute_until is not None and mute_until > now: @@ -497,29 +448,31 @@ class User(AbstractUser, BaseUser): pass async def update_folder_peers(self, update: UpdateFolderPeers) -> None: - if config["bridge.tag_only_on_create"]: + if self.config["bridge.tag_only_on_create"]: return puppet = await pu.Puppet.get_by_custom_mxid(self.mxid) if not puppet or not puppet.is_real_user: return for peer in update.folder_peers: - portal = po.Portal.get_by_entity(peer.peer, receiver_id=self.tgid, create=False) - await self._tag_room(puppet, portal, config["bridge.archive_tag"], + portal = await po.Portal.get_by_entity(peer.peer, tg_receiver=self.tgid, create=False) + await self._tag_room(puppet, portal, self.config["bridge.archive_tag"], peer.folder_id == 1) async def update_pinned_dialogs(self, update: UpdatePinnedDialogs) -> None: - if config["bridge.tag_only_on_create"]: + if self.config["bridge.tag_only_on_create"]: return puppet = await pu.Puppet.get_by_custom_mxid(self.mxid) if not puppet or not puppet.is_real_user: return # TODO bridge unpinning properly for pinned in update.order: - portal = po.Portal.get_by_entity(pinned.peer, receiver_id=self.tgid, create=False) - await self._tag_room(puppet, portal, config["bridge.pinned_tag"], True) + portal = await po.Portal.get_by_entity( + pinned.peer, tg_receiver=self.tgid, create=False + ) + await self._tag_room(puppet, portal, self.config["bridge.pinned_tag"], True) async def update_notify_settings(self, update: UpdateNotifySettings) -> None: - if config["bridge.tag_only_on_create"]: + if self.config["bridge.tag_only_on_create"]: return elif not isinstance(update.peer, NotifyPeer): # TODO handle global notification setting changes? @@ -527,11 +480,13 @@ class User(AbstractUser, BaseUser): puppet = await pu.Puppet.get_by_custom_mxid(self.mxid) if not puppet or not puppet.is_real_user: return - portal = po.Portal.get_by_entity(update.peer.peer, receiver_id=self.tgid, create=False) + portal = await po.Portal.get_by_entity( + update.peer.peer, tg_receiver=self.tgid, create=False + ) await self._mute_room(puppet, portal, update.notify_settings.mute_until) async def _sync_dialog(self, portal: po.Portal, dialog: Dialog, should_create: bool, - puppet: Optional[pu.Puppet]) -> None: + puppet: pu.Puppet | None) -> None: was_created = False if portal.mxid: try: @@ -553,29 +508,41 @@ class User(AbstractUser, BaseUser): if dialog.unread_count == 0: # This is usually more reliable than finding a specific message # e.g. if the last read message is a service message that isn't in the message db - last_read = DBMessage.find_last(portal.mxid, tg_space) + last_read = await DBMessage.find_last(portal.mxid, tg_space) else: - last_read = DBMessage.get_one_by_tgid(portal.tgid, tg_space, - dialog.dialog.read_inbox_max_id) + last_read = await DBMessage.get_one_by_tgid(portal.tgid, tg_space, + dialog.dialog.read_inbox_max_id) if last_read: await puppet.intent.mark_read(last_read.mx_room, last_read.mxid) - if was_created or not config["bridge.tag_only_on_create"]: + if was_created or not self.config["bridge.tag_only_on_create"]: await self._mute_room(puppet, portal, dialog.dialog.notify_settings.mute_until) - await self._tag_room(puppet, portal, config["bridge.pinned_tag"], dialog.pinned) - await self._tag_room(puppet, portal, config["bridge.archive_tag"], dialog.archived) + await self._tag_room(puppet, portal, self.config["bridge.pinned_tag"], + dialog.pinned) + await self._tag_room(puppet, portal, self.config["bridge.archive_tag"], + dialog.archived) + + async def get_cached_portals(self) -> dict[tuple[TelegramID, TelegramID], po.Portal]: + if self._portals_cache is None: + self._portals_cache = { + (tgid, tg_receiver): await po.Portal.get_by_tgid(tgid, tg_receiver=tg_receiver) + for tgid, tg_receiver in await self.get_portals() + } + return self._portals_cache async def sync_dialogs(self) -> None: if self.is_bot: return creators = [] - update_limit = config["bridge.sync_update_limit"] or None - create_limit = config["bridge.sync_create_limit"] + update_limit = self.config["bridge.sync_update_limit"] or None + create_limit = self.config["bridge.sync_create_limit"] index = 0 self.log.debug(f"Syncing dialogs (update_limit={update_limit}, " f"create_limit={create_limit})") await self.push_bridge_state(BridgeStateEvent.BACKFILLING) puppet = await pu.Puppet.get_by_custom_mxid(self.mxid) dialog: Dialog + old_portal_cache = await self.get_cached_portals() + new_portal_cache = old_portal_cache.copy() async for dialog in self.client.iter_dialogs(limit=update_limit, ignore_migrated=True, archived=False): entity = dialog.entity @@ -585,125 +552,152 @@ class User(AbstractUser, BaseUser): elif isinstance(entity, Chat) and (entity.deactivated or entity.left): self.log.warning(f"Ignoring deactivated or left chat {entity} while syncing") continue - elif isinstance(entity, TLUser) and not config["bridge.sync_direct_chats"]: + elif isinstance(entity, TLUser) and not self.config["bridge.sync_direct_chats"]: self.log.trace(f"Ignoring user {entity.id} while syncing") continue - portal = po.Portal.get_by_entity(entity, receiver_id=self.tgid) - self.portals[portal.tgid_full] = portal + portal = await po.Portal.get_by_entity(entity, tg_receiver=self.tgid) + new_portal_cache[portal.tgid_full] = portal coro = self._sync_dialog(portal=portal, dialog=dialog, puppet=puppet, should_create=not create_limit or index < create_limit) creators.append(self.loop.create_task(coro)) index += 1 - await self.save(portals=True) + if new_portal_cache.keys() != old_portal_cache.keys(): + await self.set_portals(new_portal_cache.keys()) + self._portals_cache = new_portal_cache await asyncio.gather(*creators) await self.update_direct_chats() self.log.debug("Dialog syncing complete") async def register_portal(self, portal: po.Portal) -> None: self.log.trace(f"Registering portal {portal.tgid_full}") - try: - if self.portals[portal.tgid_full] == portal: + if self._portals_cache is not None: + if self._portals_cache.get(portal.tgid_full) == portal: return - except KeyError: - pass - self.portals[portal.tgid_full] = portal - await self.save(portals=True) + self._portals_cache[portal.tgid_full] = portal + await super().register_portal(portal.tgid, portal.tg_receiver) async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None: self.log.trace(f"Unregistering portal {(tgid, tg_receiver)}") - try: - del self.portals[(tgid, tg_receiver)] - await self.save(portals=True) - except KeyError: - pass + if self._portals_cache is not None: + self._portals_cache.pop((tgid, tg_receiver), None) + await super().unregister_portal(tgid, tg_receiver) async def needs_relaybot(self, portal: po.Portal) -> bool: return not await self.is_logged_in() or ( - (portal.has_bot or self.is_bot) and portal.tgid_full not in self.portals) + (portal.has_bot or self.is_bot) + and portal.tgid_full not in await self.get_cached_portals() + ) - def _hash_contacts(self) -> int: + @staticmethod + def _hash_contacts(count: int, ids: list[TelegramID]) -> int: acc = 0 - for contact in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]): + for contact in sorted([count] + ids): acc = (acc * 20261 + contact) & 0xffffffff return acc & 0x7fffffff async def sync_contacts(self) -> None: - response = await self.client(GetContactsRequest(hash=self._hash_contacts())) + existing_contacts = await self.get_contacts() + contact_hash = self._hash_contacts(self.saved_contacts, existing_contacts) + response = await self.client(GetContactsRequest(hash=contact_hash)) if isinstance(response, ContactsNotModified): return self.log.debug(f"Updating contacts of {self.name}...") - self.contacts = [] - self.saved_contacts = response.saved_count + if self.saved_contacts != response.saved_count: + self.saved_contacts = response.saved_count + await self.save() + await self.set_contacts(user.id for user in response.users) for user in response.users: - puppet = pu.Puppet.get(user.id) + puppet = await pu.Puppet.get_by_tgid(user.id) await puppet.update_info(self, user) - self.contacts.append(puppet) - await self.save(contacts=True) # endregion # region Class instance lookup - @classmethod - def get_by_mxid(cls, mxid: UserID, create: bool = True, check_db: bool = True - ) -> Optional['User']: - if not mxid: - raise ValueError("Matrix ID can't be empty") + def _add_to_cache(self) -> None: + self.by_mxid[self.mxid] = self + if self.tgid: + self.by_tgid[self.tgid] = self + @classmethod + async def get_and_start_by_mxid(cls, mxid: UserID, even_if_no_session: bool = False) -> User: + user = await cls.get_by_mxid(mxid, create=True) + await user.ensure_started(even_if_no_session=even_if_no_session) + return user + + @classmethod + async def all_with_tgid(cls) -> AsyncGenerator[User, None]: + users = await super().all_with_tgid() + user: cls + for user in users: + try: + yield cls.by_mxid[user.mxid] + except KeyError: + user._add_to_cache() + yield user + + @classmethod + @async_getter_lock + async def get_by_mxid( + cls, mxid: UserID, *, check_db: bool = True, create: bool = True + ) -> User | None: + if not mxid or pu.Puppet.get_id_from_mxid(mxid) or mxid == cls.az.bot_mxid: + return None try: return cls.by_mxid[mxid] except KeyError: pass - if check_db: - user = DBUser.get_by_mxid(mxid) - if user: - user = cls.from_db(user) - return user + if not check_db: + return None + + user = cast(cls, await super().get_by_mxid(mxid)) + if user is not None: + user._add_to_cache() + return user if create: + cls.log.debug(f"Creating user instance for {mxid}") user = cls(mxid) - user.db_instance.insert() + await user.insert() + user._add_to_cache() return user return None @classmethod - def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']: + @async_getter_lock + async def get_by_tgid(cls, tgid: TelegramID) -> User | None: try: return cls.by_tgid[tgid] except KeyError: pass - user = DBUser.get_by_tgid(tgid) - if user: - user = cls.from_db(user) + user = cast(cls, await super().get_by_tgid(tgid)) + if user is not None: + user._add_to_cache() return user return None @classmethod - def find_by_username(cls, username: str) -> Optional['User']: + async def find_by_username(cls, username: str) -> User | None: if not username: return None username = username.lower() for _, user in cls.by_tgid.items(): - if user.username and user.username.lower() == username: + if user.tg_username and user.tg_username.lower() == username: return user - puppet = DBUser.get_by_username(username) - if puppet: - return cls.from_db(puppet) + user = cast(cls, await super().find_by_username(username)) + if user: + try: + return cls.by_mxid[user.mxid] + except KeyError: + user._add_to_cache() + return user return None + # endregion - - -def init(context: 'Context') -> Iterable[Awaitable['User']]: - global config - config = context.config - User.bridge = context.bridge - - return (User.from_db(db_user).try_ensure_started() - for db_user in DBUser.all_with_tgid()) diff --git a/mautrix_telegram/util/__init__.py b/mautrix_telegram/util/__init__.py index 023ef0de..1deae593 100644 --- a/mautrix_telegram/util/__init__.py +++ b/mautrix_telegram/util/__init__.py @@ -2,3 +2,5 @@ from .file_transfer import transfer_file_to_matrix, convert_image from .parallel_file_transfer import parallel_transfer_to_telegram from .recursive_dict import recursive_del, recursive_set, recursive_get from .color_log import ColorFormatter +from .send_lock import PortalSendLock +from .deduplication import PortalDedup diff --git a/mautrix_telegram/portal/deduplication.py b/mautrix_telegram/util/deduplication.py similarity index 92% rename from mautrix_telegram/portal/deduplication.py rename to mautrix_telegram/util/deduplication.py index 82faafc9..d2ff8dfd 100644 --- a/mautrix_telegram/portal/deduplication.py +++ b/mautrix_telegram/util/deduplication.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -24,11 +24,10 @@ from telethon.tl.types import (MessageMediaContact, MessageMediaDocument, Messag from mautrix.types import EventID -from ..context import Context from ..types import TelegramID if TYPE_CHECKING: - from .base import BasePortal + from ..portal import Portal DedupMXID = Tuple[EventID, TelegramID] @@ -40,9 +39,9 @@ class PortalDedup: _dedup: Deque[str] _dedup_mxid: Dict[str, DedupMXID] _dedup_action: Deque[str] - _portal: 'BasePortal' + _portal: 'Portal' - def __init__(self, portal: 'BasePortal') -> None: + def __init__(self, portal: 'Portal') -> None: self._dedup = deque() self._dedup_mxid = {} self._dedup_action = deque() @@ -125,9 +124,3 @@ class PortalDedup: and isinstance(update.message, MessageService)) if check_dedup: self.check(update.message) - - -def init(context: Context) -> None: - cfg = context.config - PortalDedup.dedup_pre_db_check = cfg["bridge.deduplication.pre_db_check"] - PortalDedup.dedup_cache_queue_length = cfg["bridge.deduplication.cache_queue_length"] diff --git a/mautrix_telegram/util/file_transfer.py b/mautrix_telegram/util/file_transfer.py index 2ba76271..b40a245a 100644 --- a/mautrix_telegram/util/file_transfer.py +++ b/mautrix_telegram/util/file_transfer.py @@ -21,7 +21,8 @@ import asyncio import tempfile import magic -from sqlalchemy.exc import IntegrityError, InvalidRequestError +from asyncpg import UniqueViolationError +from sqlite3 import IntegrityError from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation, TypePhotoSize, PhotoSize, PhotoCachedSize, InputPhotoFileLocation, @@ -123,7 +124,7 @@ async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: In if custom_data: loc_id += "-mau_custom_thumbnail" - db_file = DBTelegramFile.get(loc_id) + db_file = await DBTelegramFile.get(loc_id) if db_file: return db_file @@ -154,8 +155,8 @@ async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: In was_converted=False, timestamp=int(time.time()), size=len(file), width=width, height=height, decryption_info=decryption_info) try: - db_file.insert() - except (IntegrityError, InvalidRequestError) as e: + await db_file.insert() + except (UniqueViolationError, IntegrityError) as e: log.exception(f"{e.__class__.__name__} while saving transferred file thumbnail data. " "This was probably caused by two simultaneous transfers of the same file, " "and might (but probably won't) cause problems with thumbnails or something.") @@ -176,7 +177,7 @@ async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentA if not location_id: return None - db_file = DBTelegramFile.get(location_id) + db_file = await DBTelegramFile.get(location_id) if db_file: return db_file @@ -197,7 +198,7 @@ async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, inten tgs_convert: Optional[dict], filename: Optional[str], encrypt: bool, parallel_id: Optional[int] ) -> Optional[DBTelegramFile]: - db_file = DBTelegramFile.get(loc_id) + db_file = await DBTelegramFile.get(loc_id) if db_file: return db_file @@ -263,8 +264,8 @@ async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, inten width=converted_anim.width, height=converted_anim.height) try: - db_file.insert() - except (IntegrityError, InvalidRequestError) as e: + await db_file.insert() + except (UniqueViolationError, IntegrityError) as e: log.exception(f"{e.__class__.__name__} while saving transferred file data. " "This was probably caused by two simultaneous transfers of the same file, " "and should not cause any problems.") diff --git a/mautrix_telegram/portal/send_lock.py b/mautrix_telegram/util/send_lock.py similarity index 97% rename from mautrix_telegram/portal/send_lock.py rename to mautrix_telegram/util/send_lock.py index c760f44b..e8fc7b79 100644 --- a/mautrix_telegram/portal/send_lock.py +++ b/mautrix_telegram/util/send_lock.py @@ -1,5 +1,5 @@ # mautrix-telegram - A Matrix-Telegram puppeting bridge -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by diff --git a/mautrix_telegram/web/common/auth_api.py b/mautrix_telegram/web/common/auth_api.py index c020dd54..f35eabc5 100644 --- a/mautrix_telegram/web/common/auth_api.py +++ b/mautrix_telegram/web/common/auth_api.py @@ -52,7 +52,7 @@ class AuthAPI(abc.ABC): raise NotImplementedError() async def post_matrix_token(self, user: User, token: str) -> web.Response: - puppet = Puppet.get(user.tgid) + puppet = await Puppet.get_by_tgid(user.tgid) if puppet.is_real_user: return self.get_mx_login_response(state="already-logged-in", status=409, error="You have already logged in with your Matrix " @@ -116,7 +116,7 @@ class AuthAPI(abc.ABC): error="Internal server error while requesting code.") async def postprocess_login(self, user: User, user_info) -> None: - existing_user = User.get_by_tgid(user_info.id) + existing_user = await User.get_by_tgid(user_info.id) if existing_user and existing_user != user: await existing_user.log_out() asyncio.ensure_future(user.post_login(user_info, first_login=True), loop=self.loop) diff --git a/mautrix_telegram/web/provisioning/__init__.py b/mautrix_telegram/web/provisioning/__init__.py index 18e6b522..7538334c 100644 --- a/mautrix_telegram/web/provisioning/__init__.py +++ b/mautrix_telegram/web/provisioning/__init__.py @@ -21,10 +21,7 @@ import json from aiohttp import web from telethon.utils import get_peer_id, resolve_id -from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat, InputUserSelf -from telethon.tl.functions.users import GetUsersRequest -from telethon.errors import (UserDeactivatedError, UserDeactivatedBanError, SessionRevokedError, - UnauthorizedError) +from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat from mautrix.appservice import AppService from mautrix.errors import MatrixRequestError, IntentError @@ -37,23 +34,23 @@ from ...commands.portal.util import user_has_power_level, get_initial_state from ..common import AuthAPI if TYPE_CHECKING: - from ...context import Context + from ...__main__ import TelegramBridge class ProvisioningAPI(AuthAPI): log: logging.Logger = logging.getLogger("mau.web.provisioning") secret: str az: AppService - context: 'Context' + bridge: 'TelegramBridge' app: web.Application - def __init__(self, context: "Context") -> None: - super().__init__(context.loop) - self.secret = context.config["appservice.provisioning.shared_secret"] - self.az = context.az - self.context = context + def __init__(self, bridge: "TelegramBridge") -> None: + super().__init__(bridge.loop) + self.secret = bridge.config["appservice.provisioning.shared_secret"] + self.az = bridge.az + self.bridge = bridge - self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware]) + self.app = web.Application(loop=bridge.loop, middlewares=[self.error_middleware]) portal_prefix = "/portal/{mxid:![^/]+}" self.app.router.add_route("GET", f"{portal_prefix}", self.get_portal_by_mxid) @@ -81,7 +78,7 @@ class ProvisioningAPI(AuthAPI): return err mxid = request.match_info["mxid"] - portal = Portal.get_by_mxid(mxid) + portal = await Portal.get_by_mxid(mxid) if not portal: return self.get_error_response(404, "portal_not_found", "Portal with given Matrix ID not found.") @@ -97,7 +94,7 @@ class ProvisioningAPI(AuthAPI): except ValueError: return self.get_error_response(400, "tgid_invalid", "Given chat ID is not valid.") - portal = Portal.get_by_tgid(tgid) + portal = await Portal.get_by_tgid(tgid) if not portal: return self.get_error_response(404, "portal_not_found", "Portal to given Telegram chat not found.") @@ -122,7 +119,7 @@ class ProvisioningAPI(AuthAPI): return err room_id = request.match_info["mxid"] - if Portal.get_by_mxid(room_id): + if await Portal.get_by_mxid(room_id): return self.get_error_response(409, "room_already_bridged", "Room is already bridged to another Telegram chat.") @@ -145,12 +142,12 @@ class ProvisioningAPI(AuthAPI): "You do not have the permissions to bridge that room.") is_logged_in = user is not None and await user.is_logged_in() - acting_user = user if is_logged_in else self.context.bot + acting_user = user if is_logged_in else self.bridge.bot if not acting_user: return self.get_login_response(status=403, errcode="not_logged_in", error="You are not logged in and there is no relay bot.") - portal = Portal.get_by_tgid(tgid, peer_type=peer_type) + portal = await Portal.get_by_tgid(tgid, peer_type=peer_type) if portal.mxid == room_id: return self.get_error_response(200, "bridge_exists", "Telegram chat is already bridged to that Matrix room.") @@ -204,7 +201,7 @@ class ProvisioningAPI(AuthAPI): return self.get_error_response(400, "json_invalid", "Invalid JSON.") room_id = request.match_info["mxid"] - if Portal.get_by_mxid(room_id): + if await Portal.get_by_mxid(room_id): return self.get_error_response(409, "room_already_bridged", "Room is already bridged to another Telegram chat.") @@ -245,7 +242,7 @@ class ProvisioningAPI(AuthAPI): }[type] portal = Portal(tgid=TelegramID(0), mxid=room_id, title=title, about=about, peer_type=type, - encrypted=encrypted) + encrypted=encrypted, tg_receiver=TelegramID(0)) try: await portal.create_telegram_chat(user, supergroup=supergroup) except ValueError as e: @@ -261,7 +258,7 @@ class ProvisioningAPI(AuthAPI): if err is not None: return err - portal = Portal.get_by_mxid(request.match_info["mxid"]) + portal = await Portal.get_by_mxid(request.match_info["mxid"]) if not portal or not portal.tgid: return self.get_error_response(404, "portal_not_found", "Room is not a portal.") @@ -302,10 +299,10 @@ class ProvisioningAPI(AuthAPI): await user.update_info(me) user_data = { "id": user.tgid, - "username": user.username, + "username": user.tg_username, "first_name": me.first_name, "last_name": me.last_name, - "phone": me.phone, + "phone": user.tg_phone, "is_bot": user.is_bot, } return web.json_response({ @@ -328,7 +325,7 @@ class ProvisioningAPI(AuthAPI): return web.json_response([{ "id": get_peer_id(chat.peer), "title": chat.title, - } for chat in user.portals.values() if chat.tgid]) + } for chat in (await user.get_cached_portals()).values() if chat.tgid]) async def send_bot_token(self, request: web.Request) -> web.Response: data, user, err = await self.get_user_request_info(request) @@ -365,8 +362,8 @@ class ProvisioningAPI(AuthAPI): async def bridge_info(self, request: web.Request) -> web.Response: return web.json_response({ - "relaybot_username": (self.context.bot.username - if self.context.bot is not None else None), + "relaybot_username": (self.bridge.bot.tg_username + if self.bridge.bot is not None else None), }, status=200) @staticmethod @@ -441,14 +438,14 @@ class ProvisioningAPI(AuthAPI): return None, self.get_login_response(error="User ID not given.", errcode="mxid_empty", status=400) - user = await User.get_by_mxid(mxid).ensure_started(even_if_no_session=True) + user = await User.get_and_start_by_mxid(mxid, even_if_no_session=True) if require_puppeting and not user.puppet_whitelisted: return user, self.get_login_response(error="You are not whitelisted.", errcode="mxid_not_whitelisted", status=403) if expect_logged_in is not None: logged_in = await user.is_logged_in() if not expect_logged_in and logged_in: - return user, self.get_login_response(username=user.username, phone=user.phone, + return user, self.get_login_response(username=user.tg_username, phone=user.tg_phone, status=409, error="You are already logged in.", errcode="already_logged_in") diff --git a/mautrix_telegram/web/public/__init__.py b/mautrix_telegram/web/public/__init__.py index 71f7de88..62342c3d 100644 --- a/mautrix_telegram/web/public/__init__.py +++ b/mautrix_telegram/web/public/__init__.py @@ -77,7 +77,7 @@ class PublicBridgeWebsite(AuthAPI): mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login") if not mxid: return self.get_login_response(status=401, state="invalid-token") - user = User.get_by_mxid(mxid, create=False) if mxid else None + user = await User.get_by_mxid(mxid, create=False) if mxid else None if not user: return self.get_login_response(mxid=mxid, state=state) @@ -95,7 +95,7 @@ class PublicBridgeWebsite(AuthAPI): endpoint="/matrix-login") if not mxid: return self.get_mx_login_response(status=401, state="invalid-token") - user = User.get_by_mxid(mxid, create=False) if mxid else None + user = await User.get_by_mxid(mxid, create=False) if mxid else None if not user: return self.get_mx_login_response(mxid=mxid) @@ -107,7 +107,7 @@ class PublicBridgeWebsite(AuthAPI): return self.get_mx_login_response(mxid=user.mxid, status=403, error="You are not logged in to Telegram.") - puppet = Puppet.get(user.tgid) + puppet = await Puppet.get_by_tgid(user.tgid) if puppet.is_real_user: return self.get_mx_login_response(state="already-logged-in", status=409) @@ -136,7 +136,7 @@ class PublicBridgeWebsite(AuthAPI): data = await request.post() - user = await User.get_by_mxid(mxid).ensure_started() + user = await User.get_and_start_by_mxid(mxid) if not user.puppet_whitelisted: return self.get_mx_login_response(mxid=user.mxid, error="You are not whitelisted.", status=403) diff --git a/optional-requirements.txt b/optional-requirements.txt index d8fd7fa8..cd07d1d3 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -17,13 +17,6 @@ moviepy>=1,<2 #/metrics prometheus_client>=0.6,<0.13 -#/postgres -psycopg2-binary>=2,<3 -asyncpg>=0.20,<0.26 - -#/sqlite -aiosqlite>=0.17,<0.18 - #/e2be python-olm>=3,<4 pycryptodome>=3,<4 diff --git a/requirements.txt b/requirements.txt index 92b6de2b..79833fb2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,11 @@ -SQLAlchemy>=1.2,<1.4 -alembic>=1,<2 ruamel.yaml>=0.15.35,<0.18 python-magic>=0.4,<0.5 commonmark>=0.8,<0.10 aiohttp>=3,<4 yarl>=1,<2 -mautrix>=0.13.3,<0.14 -telethon>=1.24,<1.25 -telethon-session-sqlalchemy>=0.2.14,<0.3 -# Temporarily always depend on aiosqlite to prevent breaking old installs -# Will be removed in v0.12 (after which you need to choose the [sqlite] optional dependency) +mautrix==0.14.0rc1 +#telethon>=1.24,<1.25 +# Fork to make session storage async +tulir-telethon==1.25.0a1 +asyncpg>=0.20,<0.26 aiosqlite>=0.17,<0.18 diff --git a/setup.py b/setup.py index 6c326a24..9dc183f1 100644 --- a/setup.py +++ b/setup.py @@ -61,9 +61,9 @@ setuptools.setup( "Framework :: AsyncIO", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], package_data={"mautrix_telegram": [ "web/public/*.mako", "web/public/*.png", "web/public/*.css",