diff --git a/mautrix_appservice/appservice.py b/mautrix_appservice/appservice.py index d852a25e..e54288b9 100644 --- a/mautrix_appservice/appservice.py +++ b/mautrix_appservice/appservice.py @@ -47,8 +47,8 @@ class AppService: self.log = (logging.getLogger(log) if isinstance(log, str) else log or logging.getLogger("mautrix_appservice")) - self.query_user = query_user or (lambda user: None) - self.query_alias = query_alias or (lambda alias: None) + self.query_user = query_user or self.default_query_handler + self.query_alias = query_alias or self.default_query_handler self.event_handlers = [] @@ -60,6 +60,9 @@ class AppService: self.matrix_event_handler(self.update_state_store) + async def default_query_handler(self, param): + return None + @property def http_session(self): if self._http_session is None: @@ -80,10 +83,10 @@ class AppService: def run(self, host="127.0.0.1", port=8080): self._http_session = aiohttp.ClientSession(loop=self.loop) self._intent = HTTPAPI(base_url=self.server, domain=self.domain, bot_mxid=self.bot_mxid, - token=self.as_token, log=self.log, - state_store=self.state_store).bot_intent() + token=self.as_token, log=self.log, state_store=self.state_store, + client_session=self._http_session).bot_intent() - yield partial(aiohttp.web.run_app, self.app, host=host, port=port) + yield self.loop.create_server(self.app.make_handler(), host, port) self._intent = None self._http_session.close() @@ -107,7 +110,7 @@ class AppService: user_id = request.match_info["userId"] try: - response = self.query_user(user_id) + response = await self.query_user(user_id) except Exception: self.log.exception("Exception in user query handler") return web.Response(status=500) @@ -123,7 +126,7 @@ class AppService: alias = request.match_info["alias"] try: - response = self.query_alias(alias) + response = await self.query_alias(alias) except Exception: self.log.exception("Exception in alias query handler") return web.Response(status=500) @@ -154,7 +157,7 @@ class AppService: return web.json_response({}) - def update_state_store(self, event): + async def update_state_store(self, event): event_type = event["type"] if event_type == "m.room.power_levels": self.state_store.set_power_levels(event["room_id"], event["content"]) @@ -163,12 +166,15 @@ class AppService: event["content"]["membership"]) def handle_matrix_event(self, event): - for handler in self.event_handlers: + async def try_handle(handler): try: - handler(event) + await handler(event) except Exception: self.log.exception("Exception in Matrix event handler") + for handler in self.event_handlers: + asyncio.ensure_future(try_handle(handler)) + def matrix_event_handler(self, func): self.event_handlers.append(func) return func diff --git a/mautrix_appservice/intent_api.py b/mautrix_appservice/intent_api.py index b9c61121..56915cf7 100644 --- a/mautrix_appservice/intent_api.py +++ b/mautrix_appservice/intent_api.py @@ -19,14 +19,15 @@ import json import magic import urllib.request -from matrix_client.api import MatrixHttpApi from matrix_client.errors import MatrixRequestError +from .temp_async_api import AsyncHTTPAPI -class HTTPAPI(MatrixHttpApi): + +class HTTPAPI(AsyncHTTPAPI): def __init__(self, base_url, domain=None, bot_mxid=None, token=None, identity=None, log=None, - state_store=None): - super().__init__(base_url, token, identity) + state_store=None, client_session=None): + super().__init__(base_url, client_session, token, identity) self.domain = domain self.bot_mxid = bot_mxid self.intent_log = log.getChild("intent") @@ -53,7 +54,8 @@ class HTTPAPI(MatrixHttpApi): api_path="/_matrix/client/r0"): if not query_params: query_params = {} - query_params["user_id"] = self.identity + if self.identity: + query_params["user_id"] = self.identity log_content = content if not isinstance(content, bytes) else f"<{len(content)} bytes>" self.log.debug("%s %s %s", method, path, log_content) return super()._send(method, path, content, query_params, headers or {}, api_path=api_path) @@ -104,6 +106,7 @@ class ChildHTTPAPI(HTTPAPI): self.log = parent.log self.domain = parent.domain self.parent = parent + self.client_session = parent.client_session @property def txn_id(self): @@ -127,6 +130,7 @@ def matrix_error_code(err): except Exception: return err.content + def matrix_error_data(err): try: data = json.loads(err.content) @@ -135,8 +139,6 @@ def matrix_error_data(err): return err.content - - class IntentAPI: mxid_regex = re.compile("@(.+):(.+)") @@ -162,51 +164,51 @@ class IntentAPI: # region User actions - def get_joined_rooms(self): - self.ensure_registered() - response = self.client._send("GET", "/joined_rooms") + async def get_joined_rooms(self): + await self.ensure_registered() + response = await self.client._send("GET", "/joined_rooms") return response["joined_rooms"] - def set_display_name(self, name): - self.ensure_registered() - return self.client.set_display_name(self.mxid, name) + async def set_display_name(self, name): + await self.ensure_registered() + return await self.client.set_display_name(self.mxid, name) - def set_presence(self, status="online"): - self.ensure_registered() - return self.client.set_presence(status) + async def set_presence(self, status="online"): + await self.ensure_registered() + return await self.client.set_presence(status) - def set_avatar(self, url): - self.ensure_registered() - return self.client.set_avatar_url(self.mxid, url) + async def set_avatar(self, url): + await self.ensure_registered() + return await self.client.set_avatar_url(self.mxid, url) - def upload_file(self, data, mime_type=None): - self.ensure_registered() + async def upload_file(self, data, mime_type=None): + await self.ensure_registered() mime_type = mime_type or magic.from_buffer(data, mime=True) - return self.client.media_upload(data, mime_type) + return await self.client.media_upload(data, mime_type) - def download_file(self, url): - self.ensure_registered() + async def download_file(self, url): + await self.ensure_registered() url = self.client.get_download_url(url) - response = urllib.request.urlopen(url) - return response.read() + async with self.client.client_session.get(url) as response: + return await response.read() # endregion # region Room actions - def create_room(self, alias=None, is_public=False, name=None, topic=None, is_direct=False, + async def create_room(self, alias=None, is_public=False, name=None, topic=None, is_direct=False, invitees=(), initial_state=None): - self.ensure_registered() - return self.client.create_room(alias, is_public, name, topic, is_direct, invitees, + await self.ensure_registered() + return await self.client.create_room(alias, is_public, name, topic, is_direct, invitees, initial_state or {}) - def invite(self, room_id, user_id, check_cache=False): - self.ensure_joined(room_id) + async def invite(self, room_id, user_id, check_cache=False): + await self.ensure_joined(room_id) try: ok_states = {"invite", "join"} do_invite = (not check_cache or self.state_store.get_membership(room_id, user_id) not in ok_states) if do_invite: - response = self.client.invite_user(room_id, user_id) + response = await self.client.invite_user(room_id, user_id) self.state_store.invited(room_id, user_id) return response except MatrixRequestError as e: @@ -224,38 +226,38 @@ class IntentAPI: content["info"] = info return self.send_state_event(room_id, "m.room.avatar", content) - def add_room_alias(self, room_id, alias): - self.ensure_registered() - self.client.set_room_alias(room_id, f"#{alias}:{self.client.domain}") + async def add_room_alias(self, room_id, alias): + await self.ensure_registered() + return await self.client.set_room_alias(room_id, f"#{alias}:{self.client.domain}") - def remove_room_alias(self, alias): - self.ensure_registered() - self.client.remove_room_alias(f"#{alias}:{self.client.domain}") + async def remove_room_alias(self, alias): + await self.ensure_registered() + return await self.client.remove_room_alias(f"#{alias}:{self.client.domain}") - def set_room_name(self, room_id, name): - self.ensure_joined(room_id) + async def set_room_name(self, room_id, name): + await self.ensure_joined(room_id) self._ensure_has_power_level_for(room_id, "m.room.name") - return self.client.set_room_name(room_id, name) + return await self.client.set_room_name(room_id, name) - def get_power_levels(self, room_id, ignore_cache=False): - self.ensure_joined(room_id) + async def get_power_levels(self, room_id, ignore_cache=False): + await self.ensure_joined(room_id) if not ignore_cache: try: return self.state_store.get_power_levels(room_id) except KeyError: pass - levels = self.client.get_power_levels(room_id) + levels = await self.client.get_power_levels(room_id) self.state_store.set_power_levels(room_id, levels) return levels - def set_power_levels(self, room_id, content): - response = self.send_state_event(room_id, "m.room.power_levels", content) + async def set_power_levels(self, room_id, content): + response = await self.send_state_event(room_id, "m.room.power_levels", content) self.state_store.set_power_levels(room_id, content) return response - def get_pinned_messages(self, room_id): - self.ensure_joined(room_id) - response = self.client._send("GET", f"/rooms/{room_id}/state/m.room.pinned_events") + async def get_pinned_messages(self, room_id): + await self.ensure_joined(room_id) + response = await self.client._send("GET", f"/rooms/{room_id}/state/m.room.pinned_events") return response["content"]["pinned"] def set_pinned_messages(self, room_id, events): @@ -263,29 +265,29 @@ class IntentAPI: "pinned": events }) - def pin_message(self, room_id, event_id): - events = self.get_pinned_messages(room_id) + async def pin_message(self, room_id, event_id): + events = await self.get_pinned_messages(room_id) if event_id not in events: events.append(event_id) - self.set_pinned_messages(room_id, events) + await self.set_pinned_messages(room_id, events) - def unpin_message(self, room_id, event_id): - events = self.get_pinned_messages(room_id) + async def unpin_message(self, room_id, event_id): + events = await self.get_pinned_messages(room_id) if event_id in events: events.remove(event_id) - self.set_pinned_messages(room_id, events) + await self.set_pinned_messages(room_id, events) - def get_event(self, room_id, event_id): - self.ensure_joined(room_id) - return self.client._send("GET", f"/rooms/{room_id}/event/{event_id}") + async def get_event(self, room_id, event_id): + await self.ensure_joined(room_id) + return await self.client._send("GET", f"/rooms/{room_id}/event/{event_id}") - def set_typing(self, room_id, is_typing=True, timeout=5000): - self.ensure_joined(room_id) - return self.client.set_typing(room_id, is_typing, timeout) + async def set_typing(self, room_id, is_typing=True, timeout=5000): + await self.ensure_joined(room_id) + return await self.client.set_typing(room_id, is_typing, timeout) - def mark_read(self, room_id, event_id): - self.ensure_joined(room_id) - return self.client._send("POST", f"/rooms/{room_id}/receipt/m.read/{event_id}", content={}) + async def mark_read(self, room_id, event_id): + await self.ensure_joined(room_id) + return await self.client._send("POST", f"/rooms/{room_id}/receipt/m.read/{event_id}", content={}) def send_notice(self, room_id, text, html=None): return self.send_text(room_id, text, html, "m.notice") @@ -323,24 +325,24 @@ class IntentAPI: def send_message(self, room_id, body): return self.send_event(room_id, "m.room.message", body) - def error_and_leave(self, room_id, text, html=None): - self.ensure_joined(room_id) - self.send_notice(room_id, text, html=html) - self.leave_room(room_id) + async def error_and_leave(self, room_id, text, html=None): + await self.ensure_joined(room_id) + await self.send_notice(room_id, text, html=html) + await self.leave_room(room_id) - def kick(self, room_id, user_id, message): - self.ensure_joined(room_id) - return self.client.kick_user(room_id, user_id, message) + async def kick(self, room_id, user_id, message): + await self.ensure_joined(room_id) + return await self.client.kick_user(room_id, user_id, message) - def send_event(self, room_id, event_type, body, txn_id=None): - self.ensure_joined(room_id) + async def send_event(self, room_id, event_type, body, txn_id=None): + await self.ensure_joined(room_id) self._ensure_has_power_level_for(room_id, event_type) - return self.client.send_message_event(room_id, event_type, body, txn_id) + return await self.client.send_message_event(room_id, event_type, body, txn_id) - def send_state_event(self, room_id, event_type, body, state_key=""): - self.ensure_joined(room_id) + async def send_state_event(self, room_id, event_type, body, state_key=""): + await self.ensure_joined(room_id) self._ensure_has_power_level_for(room_id, event_type) - return self.client.send_state_event(room_id, event_type, body, state_key) + return await self.client.send_state_event(room_id, event_type, body, state_key) def join_room(self, room_id): return self.ensure_joined(room_id, ignore_cache=True) @@ -352,24 +354,25 @@ class IntentAPI: def get_room_memberships(self, room_id): return self.client.get_room_members(room_id) - def get_room_members(self, room_id, allowed_memberships=("join",)): - memberships = self.get_room_memberships(room_id) + async def get_room_members(self, room_id, allowed_memberships=("join",)): + memberships = await self.get_room_memberships(room_id) return [membership["state_key"] for membership in memberships["chunk"] if membership["content"]["membership"] in allowed_memberships] - def get_room_state(self, room_id): - self.ensure_joined(room_id) - return self.client.get_room_state(room_id) + async def get_room_state(self, room_id): + await self.ensure_joined(room_id) + state = await self.client.get_room_state(room_id) + return state # endregion # region Ensure functions - def ensure_joined(self, room_id, ignore_cache=False): + async def ensure_joined(self, room_id, ignore_cache=False): if not ignore_cache and self.state_store.is_joined(room_id, self.mxid): return - self.ensure_registered() + await self.ensure_registered() try: - self.client.join_room(room_id) + await self.client.join_room(room_id) self.state_store.joined(room_id, self.mxid) except MatrixRequestError as e: if matrix_error_code(e) != "M_FORBIDDEN" or not self.bot: @@ -381,11 +384,11 @@ class IntentAPI: except MatrixRequestError as e2: raise IntentError(f"Failed to join room {room_id} as {self.mxid}", e2) - def ensure_registered(self): + async def ensure_registered(self): if self.state_store.is_registered(self.mxid): return try: - self.client.register({"username": self.localpart}) + await self.client.register({"username": self.localpart}) except MatrixRequestError as e: if matrix_error_code(e) != "M_USER_IN_USE": self.log.exception(f"Failed to register {self.mxid}!") diff --git a/mautrix_appservice/temp_async_api.py b/mautrix_appservice/temp_async_api.py new file mode 100644 index 00000000..9e797438 --- /dev/null +++ b/mautrix_appservice/temp_async_api.py @@ -0,0 +1,92 @@ +import json +from asyncio import sleep +from urllib.parse import quote + +from matrix_client.api import MatrixHttpApi +from matrix_client.errors import MatrixError, MatrixRequestError + + +class AsyncHTTPAPI(MatrixHttpApi): + """ + Contains all raw matrix HTTP client-server API calls using asyncio and coroutines. + Examples + -------- + .. code-block: python + async def main(): + async with aiohttp.ClientSession() as session: + mapi = AsyncHTTPAPI("http://matrix.org", session) + resp = await mapi.get_room_id("#matrix:matrix.org") + print(resp) + loop = asyncio.get_event_loop() + loop.run_until_complete(main()) + """ + + def __init__(self, base_url, client_session, token=None, identity=None): + self.base_url = base_url + self.token = token + self.identity = identity + self.txn_id = 0 + self.validate_cert = True + self.client_session = client_session + + async def _send(self, + method, + path, + content=None, + query_params={}, + headers={}, + api_path="/_matrix/client/r0"): + if not content: + content = {} + + method = method.upper() + if method not in ["GET", "PUT", "DELETE", "POST"]: + raise MatrixError("Unsupported HTTP method: %s" % method) + + if "Content-Type" not in headers: + headers["Content-Type"] = "application/json" + + if self.token: + query_params["access_token"] = self.token + endpoint = self.base_url + api_path + path + + if headers["Content-Type"] == "application/json": + content = json.dumps(content) + + while True: + request = self.client_session.request( + method, + endpoint, + params=query_params, + data=content, + headers=headers) + async with request as response: + if response.status < 200 or response.status >= 300: + raise MatrixRequestError( + code=response.status, content=await response.text()) + + if response.status == 429: + await sleep(response.json()['retry_after_ms'] / 1000) + else: + return await response.json() + + async def get_display_name(self, user_id): + content = await self._send("GET", "/profile/%s/displayname" % user_id) + return content.get('displayname', None) + + async def get_avatar_url(self, user_id): + content = await self._send("GET", "/profile/%s/avatar_url" % user_id) + return content.get('avatar_url', None) + + async def get_room_id(self, room_alias): + """Get room id from its alias + Args: + room_alias(str): The room alias name. + Returns: + Wanted room's id. + """ + content = await self._send( + "GET", + "/directory/room/{}".format(quote(room_alias)), + api_path="/_matrix/client/r0") + return content.get("room_id", None) diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py index f2108908..ee05889c 100644 --- a/mautrix_telegram/__main__.py +++ b/mautrix_telegram/__main__.py @@ -17,6 +17,7 @@ import argparse import sys import logging +import asyncio import sqlalchemy as sql from sqlalchemy import orm @@ -28,10 +29,9 @@ from .config import Config from .matrix import MatrixHandler from .db import init as init_db -from .user import init as init_user +from .user import init as init_user, User from .portal import init as init_portal from .puppet import init as init_puppet -# from .formatter import init as init_formatter log = logging.getLogger("mau") time_formatter = logging.Formatter("[%(asctime)s] [%(levelname)s@%(name)s] %(message)s") @@ -72,16 +72,24 @@ db_session = orm.scoping.scoped_session(db_factory) Base.metadata.bind = db_engine Base.metadata.create_all() +loop = asyncio.get_event_loop() appserv = AppService(config["homeserver.address"], config["homeserver.domain"], config["appservice.as_token"], config["appservice.hs_token"], - config["appservice.bot_username"], log="mau.as") -context = (appserv, db_session, config) + config["appservice.bot_username"], log="mau.as", loop=loop) +context = (appserv, db_session, config, loop) with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start: + MatrixHandler(context) init_db(db_session) - # init_formatter(context) init_portal(context) init_puppet(context) - init_user(context) - MatrixHandler(context) - start() + startup_actions = [] + startup_actions += init_user(context) + startup_actions += [start] + try: + loop.run_until_complete(asyncio.gather(*startup_actions)) + loop.run_forever() + except KeyboardInterrupt: + for user in User.by_tgid.values(): + user.client.disconnect() + sys.exit(0) diff --git a/mautrix_telegram/commands.py b/mautrix_telegram/commands.py index 115a63e6..c414adae 100644 --- a/mautrix_telegram/commands.py +++ b/mautrix_telegram/commands.py @@ -58,7 +58,7 @@ class CommandHandler: log = logging.getLogger("mau.commands") def __init__(self, context): - self.az, self.db, self.config = context + self.az, self.db, self.config, _ = context self.command_prefix = self.config["bridge.command_prefix"] self._room_id = None self._is_management = False @@ -69,13 +69,13 @@ class CommandHandler: def handle(self, room, sender, command, args, is_management, is_portal): with self.handler(sender, room, command, args, is_management, is_portal) as handle_command: try: - handle_command(self, sender, args) + return handle_command(self, sender, args) except FloodWaitError as e: - self.reply(f"Flood error: Please wait {format_duration(e.seconds)}") + return self.reply(f"Flood error: Please wait {format_duration(e.seconds)}") except Exception: - self.reply("Fatal error while handling command. Check logs for more details.") self.log.exception(f"Fatal error handling command " + f"'$cmdprefix {command} {''.join(args)}' from {sender.mxid}") + return self.reply("Fatal error while handling command. Check logs for more details.") @contextmanager def handler(self, sender, room, command, args, is_management, is_portal): @@ -109,195 +109,195 @@ class CommandHandler: html = markdown.markdown(message, safe_mode="escape" if allow_html else False) elif allow_html: html = message - self.az.intent.send_notice(self._room_id, message, html=html) + return self.az.intent.send_notice(self._room_id, message, html=html) # endregion # region Command handlers @command_handler - def ping(self, sender, args): + async def ping(self, sender, args): if not sender.logged_in: - return self.reply("You're not logged in.") - me = sender.client.get_me() + return await self.reply("You're not logged in.") + me = await sender.client.get_me() if me: - return self.reply(f"You're logged in as @{me.username}") + return await self.reply(f"You're logged in as @{me.username}") else: - return self.reply("You're not logged in.") + return await self.reply("You're not logged in.") # region Authentication commands @command_handler def register(self, sender, args): - self.reply("Not yet implemented.") + return self.reply("Not yet implemented.") @command_handler - def login(self, sender, args): + async def login(self, sender, args): if not self._is_management: - return self.reply( + return await self.reply( "`login` is a restricted command: you may only run it in management rooms.") elif sender.logged_in: - return self.reply("You are already logged in.") + return await self.reply("You are already logged in.") elif len(args) == 0: - return self.reply("**Usage:** `$cmdprefix+sp login `") + return await self.reply("**Usage:** `$cmdprefix+sp login `") phone_number = args[0] - sender.client.sign_in(phone_number) + await sender.client.sign_in(phone_number) sender.command_status = { "next": command_handlers["enter_code"], "action": "Login", } - return self.reply(f"Login code sent to {phone_number}. Please send the code here.") + return await self.reply(f"Login code sent to {phone_number}. Please send the code here.") @command_handler - def enter_code(self, sender, args): + async def enter_code(self, sender, args): if not sender.command_status: - return self.reply("Request a login code first with `$cmdprefix+sp login `") + return await self.reply("Request a login code first with `$cmdprefix+sp login `") elif len(args) == 0: - return self.reply("**Usage:** `$cmdprefix+sp enter_code `") + return await self.reply("**Usage:** `$cmdprefix+sp enter_code `") try: - user = sender.client.sign_in(code=args[0]) + user = await sender.client.sign_in(code=args[0]) sender.post_login(user) sender.command_status = None - return self.reply(f"Successfully logged in as @{user.username}") + return await self.reply(f"Successfully logged in as @{user.username}") except PhoneNumberUnoccupiedError: - return self.reply("That phone number has not been registered." + return await self.reply("That phone number has not been registered." "Please register with `$cmdprefix+sp register `.") except PhoneCodeExpiredError: - return self.reply( + return await self.reply( "Phone code expired. Try again with `$cmdprefix+sp login `.") except PhoneCodeInvalidError: - return self.reply("Invalid phone code.") + return await self.reply("Invalid phone code.") except PhoneNumberAppSignupForbiddenError: - return self.reply( + return await self.reply( "Your phone number does not allow 3rd party apps to sign in.") except PhoneNumberFloodError: - return self.reply( + return await self.reply( "Your phone number has been temporarily blocked for flooding. " "The block is usually applied for around a day.") except PhoneNumberBannedError: - return self.reply("Your phone number has been banned from Telegram.") + return await self.reply("Your phone number has been banned from Telegram.") except SessionPasswordNeededError: sender.command_status = { "next": command_handlers["enter_password"], "action": "Login (password entry)", } - return self.reply("Your account has two-factor authentication." + return await self.reply("Your account has two-factor authentication." "Please send your password here.") except Exception: self.log.exception() - return self.reply("Unhandled exception while sending code." + return await self.reply("Unhandled exception while sending code." "Check console for more details.") @command_handler - def enter_password(self, sender, args): + async def enter_password(self, sender, args): if not sender.command_status: - return self.reply("Request a login code first with `$cmdprefix+sp login `") + return await self.reply("Request a login code first with `$cmdprefix+sp login `") elif len(args) == 0: - return self.reply("**Usage:** `$cmdprefix+sp enter_password `") + return await self.reply("**Usage:** `$cmdprefix+sp enter_password `") try: - user = sender.client.sign_in(password=args[0]) + user = await sender.client.sign_in(password=args[0]) sender.post_login(user) sender.command_status = None - return self.reply(f"Successfully logged in as @{user.username}") + return await self.reply(f"Successfully logged in as @{user.username}") except PasswordHashInvalidError: - return self.reply("Incorrect password.") + return await self.reply("Incorrect password.") except Exception: self.log.exception() - return self.reply("Unhandled exception while sending password. " + return await self.reply("Unhandled exception while sending password. " "Check console for more details.") @command_handler - def logout(self, sender, args): + async def logout(self, sender, args): if not sender.logged_in: - return self.reply("You're not logged in.") - if sender.log_out(): - return self.reply("Logged out successfully.") - return self.reply("Failed to log out.") + return await self.reply("You're not logged in.") + if await sender.log_out(): + return await self.reply("Logged out successfully.") + return await self.reply("Failed to log out.") # endregion # region Telegram interaction commands @command_handler - def search(self, sender, args): + async def search(self, sender, args): if len(args) == 0: - return self.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] `") + return await self.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] `") elif not sender.logged_in: - return self.reply("This command requires you to be logged in.") + return await self.reply("This command requires you to be logged in.") # force_remote = False if args[0] in {"-r", "--remote"}: # force_remote = True args.pop(0) query = " ".join(args) if len(query) < 5: - return self.reply("Minimum length of query for remote search is 5 characters.") - found = sender.client(SearchRequest(q=query, limit=10)) + return await self.reply("Minimum length of query for remote search is 5 characters.") + found = await sender.client(SearchRequest(q=query, limit=10)) # reply = ["**People:**", ""] reply = ["**Results from Telegram server:**", ""] for result in found.users: puppet = pu.Puppet.get(result.id) - puppet.update_info(sender, result) + await puppet.update_info(sender, result) reply.append( f"* [{puppet.displayname}](https://matrix.to/#/{puppet.mxid}): {puppet.id}") # reply.extend(("", "**Chats:**", "")) # for result in found.chats: # reply.append(f"* {result.title}") - return self.reply("\n".join(reply)) + return await self.reply("\n".join(reply)) @command_handler - def pm(self, sender, args): + async def pm(self, sender, args): if len(args) == 0: - return self.reply("**Usage:** `$cmdprefix+sp pm `") + return await self.reply("**Usage:** `$cmdprefix+sp pm `") elif not sender.logged_in: - return self.reply("This command requires you to be logged in.") + return await self.reply("This command requires you to be logged in.") - user = sender.client.get_entity(args[0]) + user = await sender.client.get_entity(args[0]) if not user: - return self.reply("User not found.") + return await self.reply("User not found.") elif not isinstance(user, User): - return self.reply("That doesn't seem to be a user.") + return await self.reply("That doesn't seem to be a user.") portal = po.Portal.get_by_entity(user, sender.tgid) - portal.create_matrix_room(sender, user, [sender.mxid]) - self.reply(f"Created private chat room with {pu.Puppet.get_displayname(user, False)}") + await portal.create_matrix_room(sender, user, [sender.mxid]) + return await self.reply(f"Created private chat room with {pu.Puppet.get_displayname(user, False)}") @command_handler - def invitelink(self, sender, args): + async def invitelink(self, sender, args): if not sender.logged_in: - return self.reply("This command requires you to be logged in.") + return await self.reply("This command requires you to be logged in.") portal = po.Portal.get_by_mxid(self._room_id) if not portal: - return self.reply("This is not a portal room.") + return await self.reply("This is not a portal room.") if portal.peer_type == "user": - return self.reply("You can't invite users to private chats.") + return await self.reply("You can't invite users to private chats.") try: - link = portal.get_invite_link(sender) - return self.reply(f"Invite link to {portal.title}: {link}") + link = await portal.get_invite_link(sender) + return await self.reply(f"Invite link to {portal.title}: {link}") except ValueError as e: - return self.reply(e.args[0]) + return await self.reply(e.args[0]) except ChatAdminRequiredError: - return self.reply("You don't have the permission to create an invite link.") + return await self.reply("You don't have the permission to create an invite link.") @command_handler - def deleteportal(self, sender, args): + async def deleteportal(self, sender, args): if not sender.logged_in: - return self.reply("This command requires you to be logged in.") + return await self.reply("This command requires you to be logged in.") elif not sender.is_admin: - return self.reply("This is command requires administrator privileges.") + return await self.reply("This is command requires administrator privileges.") portal = po.Portal.get_by_mxid(self._room_id) if not portal: - return self.reply("This is not a portal room.") + return await self.reply("This is not a portal room.") for user in portal.main_intent.get_room_members(portal.mxid): if user != portal.main_intent.mxid: try: - portal.main_intent.kick(portal.mxid, user, "Portal deleted.") + await portal.main_intent.kick(portal.mxid, user, "Portal deleted.") except MatrixRequestError: pass - portal.main_intent.leave_room(portal.mxid) + await portal.main_intent.leave_room(portal.mxid) portal.delete() @staticmethod @@ -308,55 +308,55 @@ class CommandHandler: return value @command_handler - def join(self, sender, args): + async def join(self, sender, args): if len(args) == 0: - return self.reply("**Usage:** `$cmdprefix+sp join `") + return await self.reply("**Usage:** `$cmdprefix+sp join `") elif not sender.logged_in: - return self.reply("This command requires you to be logged in.") + return await self.reply("This command requires you to be logged in.") regex = re.compile(r"(?:https?://)?t(?:elegram)?\.(?:dog|me)(?:joinchat/)?/(.+)") arg = regex.match(args[0]) if not arg: - return self.reply("That doesn't look like a Telegram invite link.") + return await self.reply("That doesn't look like a Telegram invite link.") arg = arg.group(1) if arg.startswith("joinchat/"): invite_hash = arg[len("joinchat/"):] try: - sender.client(CheckChatInviteRequest(invite_hash)) + await sender.client(CheckChatInviteRequest(invite_hash)) except InviteHashInvalidError: - return self.reply("Invalid invite link.") + return await self.reply("Invalid invite link.") except InviteHashExpiredError: - return self.reply("Invite link expired.") + return await self.reply("Invite link expired.") try: updates = sender.client(ImportChatInviteRequest(invite_hash)) except UserAlreadyParticipantError: - return self.reply("You are already in that chat.") + return await self.reply("You are already in that chat.") else: - channel = sender.client.get_entity(arg) + channel = await sender.client.get_entity(arg) if not channel: - return self.reply("Channel/supergroup not found.") - updates = sender.client(JoinChannelRequest(channel)) + return await self.reply("Channel/supergroup not found.") + updates = await sender.client(JoinChannelRequest(channel)) for chat in updates.chats: portal = po.Portal.get_by_entity(chat) if portal.mxid: - portal.create_matrix_room(sender, chat, [sender.mxid]) - self.reply(f"Created room for {portal.title}") + await portal.create_matrix_room(sender, chat, [sender.mxid]) + return await self.reply(f"Created room for {portal.title}") else: - portal.invite_matrix([sender.mxid]) - self.reply(f"Invited you to portal of {portal.title}") + await portal.invite_matrix([sender.mxid]) + return await self.reply(f"Invited you to portal of {portal.title}") @command_handler - def create(self, sender, args): + async def create(self, sender, args): type = args[0] if len(args) > 0 else "group" if type not in {"chat", "group", "supergroup", "channel"}: - return self.reply("**Usage:** `$cmdprefix+sp create ['group'/'supergroup'/'channel']`") + return await self.reply("**Usage:** `$cmdprefix+sp create ['group'/'supergroup'/'channel']`") elif not sender.logged_in: - return self.reply("This command requires you to be logged in.") + return await self.reply("This command requires you to be logged in.") if po.Portal.get_by_mxid(self._room_id): - return self.reply("This is already a portal room.") + return await self.reply("This is already a portal room.") - state = self.az.intent.get_room_state(self._room_id) + state = await self.az.intent.get_room_state(self._room_id) title = None about = None levels = None @@ -368,16 +368,16 @@ class CommandHandler: elif event["type"] == "m.room.power_levels": levels = event["content"] if not title: - return self.reply("Please set a title before creating a Telegram chat.") + return await self.reply("Please set a title before creating a Telegram chat.") elif (not levels or not levels["users"] or self.az.intent.mxid not in levels["users"] or levels["users"][self.az.intent.mxid] < 100): - return self.reply(f"Please give " + return await self.reply(f"Please give " + f"[the bridge bot](https://matrix.to/#/{self.az.intent.mxid}) " + f"a power level of 100 before creating a Telegram chat.") else: for user, level in levels["users"].items(): if level >= 100 and user != self.az.intent.mxid: - return self.reply(f"Please make sure only the bridge bot has power level above" + return await self.reply(f"Please make sure only the bridge bot has power level above" + f"99 before creating a Telegram chat.\n\n" + f"Use power level 95 instead of 100 for admins.") @@ -391,62 +391,62 @@ class CommandHandler: portal = po.Portal(tgid=None, mxid=self._room_id, title=title, about=about, peer_type=type) try: - portal.create_telegram_chat(sender, supergroup=supergroup) + await portal.create_telegram_chat(sender, supergroup=supergroup) except ValueError as e: - return self.reply(e.args[0]) - self.reply(f"Telegram chat created. ID: {portal.tgid}") + return await self.reply(e.args[0]) + return await self.reply(f"Telegram chat created. ID: {portal.tgid}") @command_handler - def upgrade(self, sender, args): + async def upgrade(self, sender, args): if not sender.logged_in: - return self.reply("This command requires you to be logged in.") + return await self.reply("This command requires you to be logged in.") portal = po.Portal.get_by_mxid(self._room_id) if not portal: - return self.reply("This is not a portal room.") + return await self.reply("This is not a portal room.") elif portal.peer_type == "channel": - return self.reply("This is already a supergroup or a channel.") + return await self.reply("This is already a supergroup or a channel.") elif portal.peer_type == "user": - return self.reply("You can't upgrade private chats.") + return await self.reply("You can't upgrade private chats.") try: - portal.upgrade_telegram_chat(sender) - return self.reply(f"Group upgraded to supergroup. New ID: {portal.tgid}") + await portal.upgrade_telegram_chat(sender) + return await self.reply(f"Group upgraded to supergroup. New ID: {portal.tgid}") except ChatAdminRequiredError: - return self.reply("You don't have the permission to upgrade this group.") + return await self.reply("You don't have the permission to upgrade this group.") except ValueError as e: - return self.reply(e.args[0]) + return await self.reply(e.args[0]) @command_handler - def groupname(self, sender, args): + async def groupname(self, sender, args): if len(args) == 0: - return self.reply("**Usage:** `$cmdprefix+sp groupname `") + return await self.reply("**Usage:** `$cmdprefix+sp groupname `") if not sender.logged_in: - return self.reply("This command requires you to be logged in.") + return await self.reply("This command requires you to be logged in.") portal = po.Portal.get_by_mxid(self._room_id) if not portal: - return self.reply("This is not a portal room.") + return await self.reply("This is not a portal room.") elif portal.peer_type != "channel": - return self.reply("Only channels and supergroups have usernames.") + return await self.reply("Only channels and supergroups have usernames.") try: - portal.set_telegram_username(sender, args[0] if args[0] != "-" else "") + await portal.set_telegram_username(sender, args[0] if args[0] != "-" else "") if portal.username: - return self.reply(f"Username of channel changed to {portal.username}.") + return await self.reply(f"Username of channel changed to {portal.username}.") else: - return self.reply(f"Channel is now private.") + return await self.reply(f"Channel is now private.") except ChatAdminRequiredError: - return self.reply("You don't have the permission to set the username of this channel.") + return await self.reply("You don't have the permission to set the username of this channel.") except UsernameNotModifiedError: if portal.username: - return self.reply("That is already the username of this channel.") + return await self.reply("That is already the username of this channel.") else: - return self.reply("This channel is already private") + return await self.reply("This channel is already private") except UsernameOccupiedError: - return self.reply("That username is already in use.") + return await self.reply("That username is already in use.") except UsernameInvalidError: - return self.reply("Invalid username") + return await self.reply("Invalid username") # endregion # region Command-related commands diff --git a/mautrix_telegram/formatter.py b/mautrix_telegram/formatter.py index c233201b..34125eaa 100644 --- a/mautrix_telegram/formatter.py +++ b/mautrix_telegram/formatter.py @@ -213,7 +213,7 @@ def matrix_to_telegram(html, tg_space=None): # endregion # region Telegram to Matrix -def telegram_event_to_matrix(evt, source, native_replies=False, message_link_in_reply=False, +async def telegram_event_to_matrix(evt, source, native_replies=False, message_link_in_reply=False, main_intent=None): text = evt.message html = telegram_to_matrix(evt.message, evt.entities) if evt.entities else None @@ -230,7 +230,7 @@ def telegram_event_to_matrix(evt, source, native_replies=False, message_link_in_ if puppet and puppet.displayname: fwd_from = f"{puppet.displayname}" else: - user = source.client.get_entity(from_id) + user = await source.client.get_entity(from_id) if user: fwd_from = p.Puppet.get_displayname(user, format=False) else: @@ -249,7 +249,7 @@ def telegram_event_to_matrix(evt, source, native_replies=False, message_link_in_ quote = f"Quote
" else: try: - event = main_intent.get_event(msg.mx_room, msg.mxid) + event = await main_intent.get_event(msg.mx_room, msg.mxid) content = event["content"] body = (content["formatted_body"] if "formatted_body" in content diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py index 93bda0fc..1b13ea3e 100644 --- a/mautrix_telegram/matrix.py +++ b/mautrix_telegram/matrix.py @@ -28,85 +28,87 @@ class MatrixHandler: log = logging.getLogger("mau.mx") def __init__(self, context): - self.az, self.db, self.config = context + self.az, self.db, self.config, _ = context self.commands = CommandHandler(context) self.az.matrix_event_handler(self.handle_event) + + async def init_as_bot(self): self.az.intent.set_display_name( self.config.get("appservice.bot_displayname", "Telegram bridge bot")) - def handle_puppet_invite(self, room, puppet, inviter): + async def handle_puppet_invite(self, room, puppet, inviter): self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room}") if not inviter.logged_in: - puppet.intent.error_and_leave( + await puppet.intent.error_and_leave( room, text="Please log in before inviting Telegram puppets.") return portal = Portal.get_by_mxid(room) if portal: if portal.peer_type == "user": - puppet.intent.error_and_leave( + await puppet.intent.error_and_leave( room, text="You can not invite additional users to private chats.") return - portal.invite_telegram(inviter, puppet) - puppet.intent.join_room(room) + await portal.invite_telegram(inviter, puppet) + await puppet.intent.join_room(room) return try: - members = self.az.intent.get_room_members(room) + members = await self.az.intent.get_room_members(room) except MatrixRequestError: members = [] if self.az.intent.mxid not in members: if len(members) > 1: - puppet.intent.error_and_leave(room, text=None, html=( + await puppet.intent.error_and_leave(room, text=None, html=( f"Please invite " + f"the bridge bot " + f"first if you want to create a Telegram chat.")) return - puppet.intent.join_room(room) + await puppet.intent.join_room(room) portal = Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user") if portal.mxid: try: - puppet.intent.invite(portal.mxid, inviter.mxid) - puppet.intent.send_notice(room, text=None, html=( + await puppet.intent.invite(portal.mxid, inviter.mxid) + await puppet.intent.send_notice(room, text=None, html=( "You already have a private chat with me: " + f"" + "Link to room" + "")) - puppet.intent.leave_room(room) + await puppet.intent.leave_room(room) return except MatrixRequestError: pass portal.mxid = room portal.save() - puppet.intent.send_notice(room, "Portal to private chat created.") + await puppet.intent.send_notice(room, "Portal to private chat created.") else: - puppet.intent.join_room(room) - puppet.intent.send_notice(room, "This puppet will remain inactive until a Telegram " - "chat is created for this room.") + await puppet.intent.join_room(room) + await puppet.intent.send_notice(room, "This puppet will remain inactive until a" + "Telegram chat is created for this room.") - def handle_invite(self, room, user, inviter): + async def handle_invite(self, room, user, inviter): inviter = User.get_by_mxid(inviter) if not inviter.whitelisted: return elif user == self.az.bot_mxid: - self.az.intent.join_room(room) + await self.az.intent.join_room(room) return puppet = Puppet.get_by_mxid(user) if puppet: - self.handle_puppet_invite(room, puppet, inviter) + await self.handle_puppet_invite(room, puppet, inviter) return user = User.get_by_mxid(user, create=False) portal = Portal.get_by_mxid(room) if user and user.has_full_access and portal: - portal.invite_telegram(inviter, user) + await portal.invite_telegram(inviter, user) return # The rest can probably be ignored self.log.debug(f"{inviter} invited {user} to {room}") - def handle_join(self, room, user): + async def handle_join(self, room, user): user = User.get_by_mxid(user) portal = Portal.get_by_mxid(room) @@ -114,19 +116,19 @@ class MatrixHandler: return if not user.whitelisted: - portal.main_intent.kick(room, user.mxid, - "You are not whitelisted on this Telegram bridge.") + await portal.main_intent.kick(room, user.mxid, + "You are not whitelisted on this Telegram bridge.") return elif not user.logged_in: # TODO[waiting-for-bots] once we have bot support, this won't be needed. - portal.main_intent.kick(room, user.mxid, - "You are not logged into this Telegram bridge.") + await portal.main_intent.kick(room, user.mxid, + "You are not logged into this Telegram bridge.") return self.log.debug(f"{user} joined {room}") # TODO join Telegram chat if applicable - def handle_part(self, room, user, sender): + async def handle_part(self, room, user, sender): self.log.debug(f"{user} left {room}") sender = User.get_by_mxid(sender, create=False) @@ -137,11 +139,11 @@ class MatrixHandler: puppet = Puppet.get_by_mxid(user) if sender and puppet: - portal.leave_matrix(puppet, sender) + await portal.leave_matrix(puppet, sender) user = User.get_by_mxid(user, create=False) if user and user.logged_in: - portal.leave_matrix(user, sender) + await portal.leave_matrix(user, sender) def is_command(self, message): text = message.get("body", "") @@ -151,7 +153,7 @@ class MatrixHandler: text = text[len(prefix) + 1:] return is_command, text - def handle_message(self, room, sender, message, event_id): + async def handle_message(self, room, sender, message, event_id): self.log.debug(f"{sender} sent {message} to ${room}") is_command, text = self.is_command(message) @@ -159,14 +161,14 @@ class MatrixHandler: portal = Portal.get_by_mxid(room) if sender.has_full_access and portal and not is_command: - portal.handle_matrix_message(sender, message, event_id) + await portal.handle_matrix_message(sender, message, event_id) return if message["msgtype"] != "m.text": return try: - is_management = len(self.az.intent.get_room_members(room)) == 2 + is_management = len(await self.az.intent.get_room_members(room)) == 2 except MatrixRequestError: # The AS bot is not in the room. return @@ -179,22 +181,22 @@ class MatrixHandler: # Not enough values to unpack, i.e. no arguments command = text args = [] - self.commands.handle(room, sender, command, args, is_management, - is_portal=portal is not None) + await self.commands.handle(room, sender, command, args, is_management, + is_portal=portal is not None) - def handle_redaction(self, room, sender, event_id): + async def handle_redaction(self, room, sender, event_id): portal = Portal.get_by_mxid(room) sender = User.get_by_mxid(sender) if sender.has_full_access and portal: - portal.handle_matrix_deletion(sender, event_id) + await portal.handle_matrix_deletion(sender, event_id) - def handle_power_levels(self, room, sender, new, old): + async def handle_power_levels(self, room, sender, new, old): portal = Portal.get_by_mxid(room) sender = User.get_by_mxid(sender) if sender.has_full_access and portal: - portal.handle_matrix_power_levels(sender, new["users"], old["users"]) + await portal.handle_matrix_power_levels(sender, new["users"], old["users"]) - def handle_room_meta(self, type, room, sender, content): + async def handle_room_meta(self, type, room, sender, content): portal = Portal.get_by_mxid(room) sender = User.get_by_mxid(sender) if sender.has_full_access and portal: @@ -206,13 +208,13 @@ class MatrixHandler: if content_key not in content: # FIXME handle pass - handler(sender, content[content_key]) + await handler(sender, content[content_key]) def filter_matrix_event(self, event): return (event["sender"] == self.az.bot_mxid or Puppet.get_id_from_mxid(event["sender"]) is not None) - def handle_event(self, evt): + async def handle_event(self, evt): if self.filter_matrix_event(evt): return self.log.debug("Received event: %s", evt) @@ -221,17 +223,17 @@ class MatrixHandler: if type == "m.room.member": membership = content.get("membership", "") if membership == "invite": - self.handle_invite(evt["room_id"], evt["state_key"], evt["sender"]) + await self.handle_invite(evt["room_id"], evt["state_key"], evt["sender"]) elif membership == "leave": - self.handle_part(evt["room_id"], evt["state_key"], evt["sender"]) + await self.handle_part(evt["room_id"], evt["state_key"], evt["sender"]) elif membership == "join": - self.handle_join(evt["room_id"], evt["state_key"]) + await self.handle_join(evt["room_id"], evt["state_key"]) elif type == "m.room.message": - self.handle_message(evt["room_id"], evt["sender"], content, evt["event_id"]) + await self.handle_message(evt["room_id"], evt["sender"], content, evt["event_id"]) elif type == "m.room.redaction": - self.handle_redaction(evt["room_id"], evt["sender"], evt["redacts"]) + await self.handle_redaction(evt["room_id"], evt["sender"], evt["redacts"]) elif type == "m.room.power_levels": - self.handle_power_levels(evt["room_id"], evt["sender"], evt["content"], - evt["prev_content"]) + await self.handle_power_levels(evt["room_id"], evt["sender"], evt["content"], + evt["prev_content"]) elif type == "m.room.name" or type == "m.room.avatar" or type == "m.room.topic": - self.handle_room_meta(type, evt["room_id"], evt["sender"], evt["content"]) + await self.handle_room_meta(type, evt["room_id"], evt["sender"], evt["content"]) diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index 29fa8dcf..95bcf767 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -18,6 +18,7 @@ from io import BytesIO from collections import deque from datetime import datetime import random +import asyncio import mimetypes import hashlib import logging @@ -138,37 +139,37 @@ class Portal: self._main_intent = puppet.intent if direct else self.az.intent return self._main_intent - def invite_matrix(self, users): + async def invite_matrix(self, users): if isinstance(users, str): - self.main_intent.invite(self.mxid, users, check_cache=True) + await self.main_intent.invite(self.mxid, users, check_cache=True) elif isinstance(users, list): for user in users: - self.main_intent.invite(self.mxid, user, check_cache=True) + await self.main_intent.invite(self.mxid, user, check_cache=True) else: raise ValueError("Invalid invite identifier given to invite_matrix()") - def update_after_create(self, user, entity, direct, puppet=None): + async def update_after_create(self, user, entity, direct, puppet=None): if not direct: - self.update_info(user, entity) - users, participants = self.get_users(user, entity) - self.sync_telegram_users(user, users) - self.update_telegram_participants(participants) + await self.update_info(user, entity) + users, participants = await self.get_users(user, entity) + await self.sync_telegram_users(user, users) + await self.update_telegram_participants(participants) else: if not puppet: puppet = p.Puppet.get(self.tgid) - puppet.update_info(user, entity) - puppet.intent.join_room(self.mxid) + await puppet.update_info(user, entity) + await puppet.intent.join_room(self.mxid) - def create_matrix_room(self, user, entity=None, invites=None, update_if_exists=True): + async def create_matrix_room(self, user, entity=None, invites=None, update_if_exists=True): if not entity: - entity = user.client.get_entity(self.peer) + entity = await user.client.get_entity(self.peer) self.log.debug("Fetched data: %s", entity) direct = self.peer_type == "user" if self.mxid: if update_if_exists: - self.update_after_create(user, entity, direct) - self.invite_matrix(invites or []) + await self.update_after_create(user, entity, direct) + await self.invite_matrix(invites or []) return self.mxid self.log.debug(f"Creating room for {self.tgid_log}") @@ -194,8 +195,8 @@ class Portal: if alias: # TODO properly handle existing room aliases intent.remove_room_alias(alias) - room = intent.create_room(alias=alias, is_public=public, invitees=invites or [], - name=self.title, is_direct=direct) + room = await intent.create_room(alias=alias, is_public=public, invitees=invites or [], + name=self.title, is_direct=direct) if not room: raise Exception(f"Failed to create room for {self.tgid_log}") @@ -204,93 +205,93 @@ class Portal: self.save() power_level_requirement = 0 if self.peer_type == "chat" and entity.admins_enabled else 50 - levels = self.main_intent.get_power_levels(self.mxid) + levels = await self.main_intent.get_power_levels(self.mxid) levels["ban"] = 100 levels["invite"] = 50 levels["events"]["m.room.name"] = power_level_requirement levels["events"]["m.room.avatar"] = power_level_requirement levels["events"]["m.room.topic"] = 50 if self.peer_type == "channel" else 100 levels["events"]["m.room.power_levels"] = 75 - self.main_intent.set_power_levels(self.mxid, levels) - self.update_after_create(user, entity, direct, puppet) + await self.main_intent.set_power_levels(self.mxid, levels) + await self.update_after_create(user, entity, direct, puppet) def _get_room_alias(self, username=None): username = username or self.username return config.get("bridge.alias_template", "telegram_{groupname}").format( groupname=username) - def sync_telegram_users(self, source, users): + async def sync_telegram_users(self, source, users): for entity in users: puppet = p.Puppet.get(entity.id) - puppet.update_info(source, entity) - puppet.intent.ensure_joined(self.mxid) + await puppet.intent.ensure_joined(self.mxid) + await puppet.update_info(source, entity) - def add_telegram_user(self, user_id, source=None): + async def add_telegram_user(self, user_id, source=None): puppet = p.Puppet.get(user_id) if source: - entity = source.client.get_entity(user_id) - puppet.update_info(source, entity) - puppet.intent.join_room(self.mxid) + entity = await source.client.get_entity(user_id) + await puppet.update_info(source, entity) + await puppet.intent.join_room(self.mxid) user = u.User.get_by_tgid(user_id) if user: - self.main_intent.invite(self.mxid, user.mxid) + await self.main_intent.invite(self.mxid, user.mxid) - def delete_telegram_user(self, user_id, kick_message=None): + async def delete_telegram_user(self, user_id, kick_message=None): puppet = p.Puppet.get(user_id) user = u.User.get_by_tgid(user_id) if kick_message: - self.main_intent.kick(self.mxid, puppet.mxid, kick_message) + await self.main_intent.kick(self.mxid, puppet.mxid, kick_message) else: - puppet.intent.leave_room(self.mxid) + await puppet.intent.leave_room(self.mxid) if user: - self.main_intent.kick(self.mxid, user.mxid, kick_message or "Left Telegram chat") + await self.main_intent.kick(self.mxid, user.mxid, kick_message or "Left Telegram chat") - def update_info(self, user, entity=None): + async def update_info(self, user, entity=None): if self.peer_type == "user": - self.log.warn(f"Called update_info() for direct chat portal {self.tgid_log}") + self.log.warning(f"Called update_info() for direct chat portal {self.tgid_log}") return self.log.debug(f"Updating info of {self.tgid_log}") if not entity: - entity = user.client.get_entity(self.peer) + entity = await user.client.get_entity(self.peer) self.log.debug("Fetched data: %s", entity) changed = False if self.peer_type == "channel": - changed = self.update_username(entity.username) or changed + changed = await self.update_username(entity.username) or changed # TODO update about text # changed = self.update_about(entity.about) or changed - changed = self.update_title(entity.title) or changed + changed = await self.update_title(entity.title) or changed if isinstance(entity.photo, ChatPhoto): - changed = self.update_avatar(user, entity.photo.photo_big) or changed + changed = await self.update_avatar(user, entity.photo.photo_big) or changed if changed: self.save() - def update_username(self, username): + async def update_username(self, username): if self.username != username: if self.username: - self.main_intent.remove_room_alias(self._get_room_alias()) + await self.main_intent.remove_room_alias(self._get_room_alias()) self.username = username or None if self.username: - self.main_intent.add_room_alias(self.mxid, self._get_room_alias()) + await self.main_intent.add_room_alias(self.mxid, self._get_room_alias()) return True return False - def update_about(self, about): + async def update_about(self, about): if self.about != about: self.about = about - self.main_intent.set_room_topic(self.mxid, self.about) + await self.main_intent.set_room_topic(self.mxid, self.about) return True return False - def update_title(self, title): + async def update_title(self, title): if self.title != title: self.title = title - self.main_intent.set_room_name(self.mxid, self.title) + await self.main_intent.set_room_name(self.mxid, self.title) return True return False @@ -299,26 +300,26 @@ class Portal: return max(photo.sizes, key=(lambda photo2: ( len(photo2.bytes) if isinstance(photo2, PhotoCachedSize) else photo2.size))) - def update_avatar(self, user, photo): + async def update_avatar(self, user, photo): photo_id = f"{photo.volume_id}-{photo.local_id}" if self.photo_id != photo_id: try: - file = user.client.download_file_bytes(photo) + file = await user.client.download_file_bytes(photo) except LocationInvalidError: return False - uploaded = self.main_intent.upload_file(file) - self.main_intent.set_room_avatar(self.mxid, uploaded["content_uri"]) + uploaded = await self.main_intent.upload_file(file) + await self.main_intent.set_room_avatar(self.mxid, uploaded["content_uri"]) self.photo_id = photo_id return True return False - def get_users(self, user, entity): + async def get_users(self, user, entity): if self.peer_type == "chat": - chat = user.client(GetFullChatRequest(chat_id=self.tgid)) + chat = await user.client(GetFullChatRequest(chat_id=self.tgid)) return chat.users, chat.full_chat.participants.participants elif self.peer_type == "channel": try: - participants = user.client(GetParticipantsRequest( + participants = await user.client(GetParticipantsRequest( entity, ChannelParticipantsRecent(), offset=0, limit=100, hash=0 )) return participants.users, participants.participants @@ -327,16 +328,16 @@ class Portal: elif self.peer_type == "user": return [entity], [] - def get_invite_link(self, user): + async def get_invite_link(self, user): if self.peer_type == "user": raise ValueError("You can't invite users to private chats.") elif self.peer_type == "chat": - link = user.client(ExportChatInviteRequest(chat_id=self.tgid)) + link = await user.client(ExportChatInviteRequest(chat_id=self.tgid)) elif self.peer_type == "channel": if self.username: return f"https://t.me/{self.username}" - link = user.client( - ExportInviteRequest(channel=self.get_input_entity(user))) + link = await user.client( + ExportInviteRequest(channel=await self.get_input_entity(user))) else: raise ValueError(f"Invalid peer type '{self.peer_type}' for invite link.") @@ -360,48 +361,47 @@ class Portal: file_name = f"matrix_upload{mimetypes.guess_extension(mime)}" return file_name, None if file_name == body else body - def leave_matrix(self, user, source): + async def leave_matrix(self, user, source): if self.peer_type == "user": - self.main_intent.leave_room(self.mxid) + await self.main_intent.leave_room(self.mxid) self.delete() del self.by_tgid[self.tgid_full] del self.by_mxid[self.mxid] elif source and source.tgid != user.tgid: target = user.get_input_entity(source) if self.peer_type == "chat": - source.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=target)) + await source.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=target)) else: - channel = self.get_input_entity(source) + channel = await self.get_input_entity(source) rights = ChannelBannedRights(datetime.fromtimestamp(0), True) - source.client(EditBannedRequest(channel=channel, - user_id=target, - banned_rights=rights)) + await source.client(EditBannedRequest(channel=channel, + user_id=target, + banned_rights=rights)) elif self.peer_type == "chat": - user.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=InputUserSelf())) + await user.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=InputUserSelf())) elif self.peer_type == "channel": - channel = self.get_input_entity(user) - user.client(LeaveChannelRequest(channel=channel)) + channel = await self.get_input_entity(user) + await user.client(LeaveChannelRequest(channel=channel)) - def handle_matrix_message(self, sender, message, event_id): + async def handle_matrix_message(self, sender, message, event_id): type = message["msgtype"] if type in {"m.text", "m.emote"}: if "format" in message and message["format"] == "org.matrix.custom.html": space = self.tgid if self.peer_type == "channel" else sender.tgid - print(sender.username, sender.tgid, space) message, entities = formatter.matrix_to_telegram(message["formatted_body"], space) if type == "m.emote": message = "/me " + message reply_to = None if len(entities) > 0 and isinstance(entities[0], formatter.MessageEntityReply): reply_to = entities.pop(0).msg_id - response = sender.client.send_message(self.peer, message, entities=entities, - reply_to=reply_to) + response = await sender.client.send_message(self.peer, message, entities=entities, + reply_to=reply_to) else: if type == "m.emote": message["body"] = "/me " + message["body"] - response = sender.client.send_message(self.peer, message["body"]) + response = await sender.client.send_message(self.peer, message["body"]) elif type in {"m.image", "m.file", "m.audio", "m.video"}: - file = self.main_intent.download_file(message["url"]) + file = await self.main_intent.download_file(message["url"]) info = message["info"] mime = info["mimetype"] @@ -412,8 +412,8 @@ class Portal: if "w" in info and "h" in info: attributes.append(DocumentAttributeImageSize(w=info["w"], h=info["h"])) - response = sender.client.send_file(self.peer, file, mime, caption, attributes, - file_name) + response = await sender.client.send_file(self.peer, file, mime, caption, attributes, + file_name) else: self.log.debug("Unhandled Matrix event: %s", message) return @@ -426,16 +426,16 @@ class Portal: mxid=event_id)) self.db.commit() - def handle_matrix_deletion(self, deleter, event_id): + async def handle_matrix_deletion(self, deleter, event_id): space = self.tgid if self.peer_type == "channel" else deleter.tgid message = DBMessage.query.filter(DBMessage.mxid == event_id and DBMessage.tg_space == space and DBMessage.mx_room == self.mxid).one_or_none() if not message: return - deleter.client.delete_messages(self.peer, [message.tgid]) + await deleter.client.delete_messages(self.peer, [message.tgid]) - def handle_matrix_power_levels(self, sender, new_users, old_users): + async def handle_matrix_power_levels(self, sender, new_users, old_users): # TODO handle all power level changes and bridge exact admin rights to supergroups/channels for user, level in new_users.items(): user_id = p.Puppet.get_id_from_mxid(user) @@ -446,7 +446,7 @@ class Portal: user_id = mx_user.tgid if user not in old_users or level != old_users[user]: if self.peer_type == "chat": - sender.client(EditChatAdminRequest( + await sender.client(EditChatAdminRequest( chat_id=self.tgid, user_id=user_id, is_admin=level >= 50)) elif self.peer_type == "channel": moderator = level >= 50 @@ -456,47 +456,47 @@ class Portal: ban_users=moderator, invite_users=moderator, invite_link=moderator, pin_messages=moderator, add_admins=admin, manage_call=moderator) - sender.client( + await sender.client( EditAdminRequest(channel=self.get_input_entity(sender), user_id=sender.client.get_input_entity(PeerUser(user_id)), admin_rights=rights)) - def handle_matrix_about(self, sender, about): + async def handle_matrix_about(self, sender, about): if self.peer_type not in {"channel"}: return - channel = self.get_input_entity(sender) - sender.client(EditAboutRequest(channel=channel, about=about)) + channel = await self.get_input_entity(sender) + await sender.client(EditAboutRequest(channel=channel, about=about)) self.about = about self.save() - def handle_matrix_title(self, sender, title): + async def handle_matrix_title(self, sender, title): if self.peer_type not in {"chat", "channel"}: return if self.peer_type == "chat": - sender.client(EditChatTitleRequest(chat_id=self.tgid, title=title)) + await sender.client(EditChatTitleRequest(chat_id=self.tgid, title=title)) else: - channel = self.get_input_entity(sender) - sender.client(EditTitleRequest(channel=channel, title=title)) + channel = await self.get_input_entity(sender) + await sender.client(EditTitleRequest(channel=channel, title=title)) self.title = title self.save() - def handle_matrix_avatar(self, sender, url): + async def handle_matrix_avatar(self, sender, url): if self.peer_type not in {"chat", "channel"}: # Invalid peer type return - file = self.main_intent.download_file(url) + file = await self.main_intent.download_file(url) mime = magic.from_buffer(file, mime=True) ext = mimetypes.guess_extension(mime) - uploaded = sender.client.upload_file(file, file_name=f"avatar{ext}") + uploaded = await sender.client.upload_file(file, file_name=f"avatar{ext}") photo = InputChatUploadedPhoto(file=uploaded) if self.peer_type == "chat": - updates = sender.client(EditChatPhotoRequest(chat_id=self.tgid, photo=photo)) + updates = await sender.client(EditChatPhotoRequest(chat_id=self.tgid, photo=photo)) else: - channel = self.get_input_entity(sender) - updates = sender.client(EditPhotoRequest(channel=channel, photo=photo)) + channel = await self.get_input_entity(sender) + updates = await sender.client(EditPhotoRequest(channel=channel, photo=photo)) for update in updates.updates: is_photo_update = (isinstance(update, UpdateNewMessage) and isinstance(update.message, MessageService) @@ -510,9 +510,9 @@ class Portal: # endregion # region Telegram chat info updating - def _get_telegram_users_in_matrix_room(self): + async def _get_telegram_users_in_matrix_room(self): user_tgids = set() - user_mxids = self.main_intent.get_room_members(self.mxid, ("join", "invite")) + user_mxids = await self.main_intent.get_room_members(self.mxid, ("join", "invite")) for user in user_mxids: if user == self.az.intent.mxid: continue @@ -524,11 +524,11 @@ class Portal: user_tgids.add(puppet_id) return user_tgids - def upgrade_telegram_chat(self, source): + async def upgrade_telegram_chat(self, source): if self.peer_type != "chat": raise ValueError("Only normal group chats are upgradable to supergroups.") - updates = source.client(MigrateChatRequest(chat_id=self.tgid)) + updates = await source.client(MigrateChatRequest(chat_id=self.tgid)) entity = None for chat in updates.chats: if isinstance(chat, Channel): @@ -538,67 +538,71 @@ class Portal: raise ValueError("Upgrade may have failed: output channel not found.") self.peer_type = "channel" self.migrate_and_save(entity.id) - self.update_info(source, entity) + await self.update_info(source, entity) - def set_telegram_username(self, source, username): + async def set_telegram_username(self, source, username): if self.peer_type != "channel": raise ValueError("Only channels and supergroups have usernames.") - success = source.client(UpdateUsernameRequest(self.get_input_entity(source), username)) - if self.update_username(username): + await source.client( + UpdateUsernameRequest(self.get_input_entity(source), username)) + if await self.update_username(username): self.save() - def create_telegram_chat(self, source, supergroup=False): + async def create_telegram_chat(self, source, supergroup=False): if not self.mxid: raise ValueError("Can't create Telegram chat for portal without Matrix room.") elif self.tgid: raise ValueError("Can't create Telegram chat for portal with existing Telegram chat.") - invites = self._get_telegram_users_in_matrix_room() + invites = await self._get_telegram_users_in_matrix_room() if len(invites) < 2: # TODO[waiting-for-bots] This won't happen when the bot is enabled raise ValueError("Not enough Telegram users to create a chat") - invites = [source.client.get_input_entity(id) for id in invites] + invites = [await source.client.get_input_entity(id) for id in invites] if self.peer_type == "chat": - updates = source.client(CreateChatRequest(title=self.title, users=invites)) + updates = await source.client(CreateChatRequest(title=self.title, users=invites)) entity = updates.chats[0] elif self.peer_type == "channel": - updates = source.client(CreateChannelRequest(title=self.title, about=self.about or "", - megagroup=supergroup)) + updates = await source.client(CreateChannelRequest(title=self.title, + about=self.about or "", + megagroup=supergroup)) entity = updates.chats[0] - source.client(InviteToChannelRequest(channel=source.client.get_input_entity(entity), - users=invites)) + await source.client(InviteToChannelRequest( + channel=source.client.get_input_entity(entity), + users=invites)) else: raise ValueError("Invalid peer type for Telegram chat creation") self.tgid = entity.id self.tg_receiver = self.tgid self.by_tgid[self.tgid_full] = self - self.update_info(source, entity) + await self.update_info(source, entity) self.save() - def invite_telegram(self, source, puppet): + async def invite_telegram(self, source, puppet): if self.peer_type == "chat": - source.client(AddChatUserRequest(chat_id=self.tgid, user_id=puppet.tgid, fwd_limit=0)) + await source.client( + AddChatUserRequest(chat_id=self.tgid, user_id=puppet.tgid, fwd_limit=0)) elif self.peer_type == "channel": - target = puppet.get_input_entity(source) - source.client(InviteToChannelRequest(channel=self.peer, users=[target])) + target = await puppet.get_input_entity(source) + await source.client(InviteToChannelRequest(channel=self.peer, users=[target])) else: raise ValueError("Invalid peer type for Telegram user invite") # endregion # region Telegram event handling - def handle_telegram_typing(self, user, event): + async def handle_telegram_typing(self, user, event): if self.mxid: - user.intent.set_typing(self.mxid, is_typing=True) + await user.intent.set_typing(self.mxid, is_typing=True) - def handle_telegram_photo(self, source, sender, media): + async def handle_telegram_photo(self, source, sender, media): largest_size = self._get_largest_photo_size(media.photo) - file = source.client.download_file_bytes(largest_size.location) + file = await source.client.download_file_bytes(largest_size.location) mime_type = magic.from_buffer(file, mime=True) - uploaded = sender.intent.upload_file(file, mime_type) + uploaded = await sender.intent.upload_file(file, mime_type) info = { "h": largest_size.h, "w": largest_size.w, @@ -608,8 +612,9 @@ class Portal: "mimetype": mime_type, } name = media.caption - sender.intent.set_typing(self.mxid, is_typing=False) - return sender.intent.send_image(self.mxid, uploaded["content_uri"], info=info, text=name) + await sender.intent.set_typing(self.mxid, is_typing=False) + return await sender.intent.send_image(self.mxid, uploaded["content_uri"], info=info, + text=name) def convert_webp(self, file, to="png"): try: @@ -621,14 +626,14 @@ class Portal: self.log.exception(f"Failed to convert webp to {to}") return "image/webp", file - def handle_telegram_document(self, source, sender, media): - file = source.client.download_file_bytes(media.document) + async def handle_telegram_document(self, source, sender, media): + file = await source.client.download_file_bytes(media.document) mime_type = magic.from_buffer(file, mime=True) dont_change_mime = False if mime_type == "image/webp": mime_type, file = self.convert_webp(file, to="png") dont_change_mime = True - uploaded = sender.intent.upload_file(file, mime_type) + uploaded = await sender.intent.upload_file(file, mime_type) name = media.caption for attr in media.document.attributes: if not name and isinstance(attr, DocumentAttributeFilename): @@ -650,9 +655,9 @@ class Portal: type = "m.audio" elif mime_type.startswith("image/"): type = "m.image" - sender.intent.set_typing(self.mxid, is_typing=False) - return sender.intent.send_file(self.mxid, uploaded["content_uri"], info=info, text=name, - file_type=type) + await sender.intent.set_typing(self.mxid, is_typing=False) + return await sender.intent.send_file(self.mxid, uploaded["content_uri"], info=info, + text=name, file_type=type) def handle_telegram_location(self, source, sender, location): long = location.long @@ -679,18 +684,18 @@ class Portal: "formatted_body": formatted_body, }) - def handle_telegram_text(self, source, sender, evt): + async def handle_telegram_text(self, source, sender, evt): self.log.debug(f"Sending {evt.message} to {self.mxid} by {sender.id}") - text, html = formatter.telegram_event_to_matrix(evt, source, + text, html = await formatter.telegram_event_to_matrix(evt, source, config["bridge.native_replies"], config["bridge.link_in_reply"], self.main_intent) - sender.intent.set_typing(self.mxid, is_typing=False) - return sender.intent.send_text(self.mxid, text, html=html) + await sender.intent.set_typing(self.mxid, is_typing=False) + return await sender.intent.send_text(self.mxid, text, html=html) - def handle_telegram_message(self, source, sender, evt): + async def handle_telegram_message(self, source, sender, evt): if not self.mxid: - self.create_matrix_room(source, invites=[source.mxid]) + await self.create_matrix_room(source, invites=[source.mxid]) tg_space = self.tgid if self.peer_type == "channel" else source.tgid @@ -705,14 +710,14 @@ class Portal: return if evt.message: - response = self.handle_telegram_text(source, sender, evt) + response = await self.handle_telegram_text(source, sender, evt) elif evt.media: if isinstance(evt.media, MessageMediaPhoto): - response = self.handle_telegram_photo(source, sender, evt.media) + response = await self.handle_telegram_photo(source, sender, evt.media) elif isinstance(evt.media, MessageMediaDocument): - response = self.handle_telegram_document(source, sender, evt.media) + response = await self.handle_telegram_document(source, sender, evt.media) elif isinstance(evt.media, MessageMediaGeo): - response = self.handle_telegram_location(source, sender, evt.media.geo) + response = await self.handle_telegram_location(source, sender, evt.media.geo) else: self.log.debug("Unhandled Telegram media: %s", evt.media) return @@ -728,7 +733,7 @@ class Portal: self.db.add(DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space)) self.db.commit() - def handle_telegram_action(self, source, sender, action): + async def handle_telegram_action(self, source, sender, action): if not self.mxid: create_and_exit = (MessageActionChatCreate, MessageActionChannelCreate) create_and_continue = (MessageActionChatAddUser, MessageActionChatJoinedByLink) @@ -739,47 +744,47 @@ class Portal: # TODO figure out how to see changes to about text / channel username if isinstance(action, MessageActionChatEditTitle): - if self.update_title(action.title): + if await self.update_title(action.title): self.save() elif isinstance(action, MessageActionChatEditPhoto): largest_size = self._get_largest_photo_size(action.photo) - if self.update_avatar(source, largest_size.location): + if await self.update_avatar(source, largest_size.location): self.save() elif isinstance(action, MessageActionChatAddUser): for user_id in action.users: - self.add_telegram_user(user_id, source) + await self.add_telegram_user(user_id, source) elif isinstance(action, MessageActionChatJoinedByLink): - self.add_telegram_user(sender.id, source) + await self.add_telegram_user(sender.id, source) elif isinstance(action, MessageActionChatDeleteUser): kick_message = None if sender.id != action.user_id: kick_message = f"Kicked by {sender.displayname}" - self.delete_telegram_user(action.user_id, kick_message) + await self.delete_telegram_user(action.user_id, kick_message) elif isinstance(action, MessageActionChatMigrateTo): self.peer_type = "channel" self.migrate_and_save(action.channel_id) - sender.intent.send_emote(self.mxid, "upgraded this group to a supergroup.") + await sender.intent.send_emote(self.mxid, "upgraded this group to a supergroup.") else: self.log.debug("Unhandled Telegram action in %s: %s", self.title, action) - def set_telegram_admin(self, puppet, user): - levels = self.main_intent.get_power_levels(self.mxid) + async def set_telegram_admin(self, puppet, user): + levels = await self.main_intent.get_power_levels(self.mxid) if user: levels["users"][user.mxid] = 50 if puppet: levels["users"][puppet.mxid] = 50 - self.main_intent.set_power_levels(self.mxid, levels) + await self.main_intent.set_power_levels(self.mxid, levels) - def update_telegram_pin(self, source, id): + async def update_telegram_pin(self, source, id): space = self.tgid if self.peer_type == "channel" else source.tgid message = DBMessage.query.get((id, space)) if message: - self.main_intent.set_pinned_messages(self.mxid, [message.mxid]) + await self.main_intent.set_pinned_messages(self.mxid, [message.mxid]) else: - self.main_intent.set_pinned_messages(self.mxid, []) + await self.main_intent.set_pinned_messages(self.mxid, []) - def update_telegram_participants(self, participants): - levels = self.main_intent.get_power_levels(self.mxid) + async def update_telegram_participants(self, participants): + levels = await self.main_intent.get_power_levels(self.mxid) changed = False admin_power_level = 75 if self.peer_type == "channel" else 50 @@ -815,15 +820,15 @@ class Portal: levels["users"][puppet.mxid] = new_level changed = True if changed: - self.main_intent.set_power_levels(self.mxid, levels) + await self.main_intent.set_power_levels(self.mxid, levels) - def set_telegram_admins_enabled(self, enabled): + async def set_telegram_admins_enabled(self, enabled): level = 50 if enabled else 10 - levels = self.main_intent.get_power_levels(self.mxid) + levels = await self.main_intent.get_power_levels(self.mxid) levels["invite"] = level levels["events"]["m.room.name"] = level levels["events"]["m.room.avatar"] = level - self.main_intent.set_power_levels(self.mxid, levels) + await self.main_intent.set_power_levels(self.mxid, levels) # endregion # region Database conversion @@ -933,4 +938,4 @@ class Portal: def init(context): global config - Portal.az, Portal.db, config = context + Portal.az, Portal.db, config, _ = context diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py index 5b50a537..0f75be31 100644 --- a/mautrix_telegram/puppet.py +++ b/mautrix_telegram/puppet.py @@ -91,35 +91,35 @@ class Puppet: return config.get("bridge.displayname_template", "{displayname} (Telegram)").format( displayname=name) - def update_info(self, source, info): + async def update_info(self, source, info): changed = False if self.username != info.username: self.username = info.username changed = True - changed = self.update_displayname(source, info) or changed + changed = await self.update_displayname(source, info) or changed if isinstance(info.photo, UserProfilePhoto): - changed = self.update_avatar(source, info.photo.photo_big) + changed = await self.update_avatar(source, info.photo.photo_big) if changed: self.save() - def update_displayname(self, source, info): + async def update_displayname(self, source, info): displayname = self.get_displayname(info) if displayname != self.displayname: - self.intent.set_display_name(displayname) + await self.intent.set_display_name(displayname) self.displayname = displayname return True - def update_avatar(self, source, photo): + async def update_avatar(self, source, photo): photo_id = f"{photo.volume_id}-{photo.local_id}" if self.photo_id != photo_id: try: - file = source.client.download_file_bytes(photo) + file = await source.client.download_file_bytes(photo) except LocationInvalidError: return False - uploaded = self.intent.upload_file(file) - self.intent.set_avatar(uploaded["content_uri"]) + uploaded = await self.intent.upload_file(file) + await self.intent.set_avatar(uploaded["content_uri"]) self.photo_id = photo_id return True return False @@ -170,7 +170,7 @@ class Puppet: def init(context): global config - Puppet.az, Puppet.db, config = context + Puppet.az, Puppet.db, config, _ = context localpart = config.get("bridge.username_template", "telegram_{userid}").format(userid="(.+)") hs = config["homeserver"]["domain"] Puppet.mxid_regex = re.compile(f"@{localpart}:{hs}") diff --git a/mautrix_telegram/tgclient.py b/mautrix_telegram/tgclient.py index 7a6efe53..335d8adf 100644 --- a/mautrix_telegram/tgclient.py +++ b/mautrix_telegram/tgclient.py @@ -22,8 +22,8 @@ from telethon.tl.types import * class MautrixTelegramClient(TelegramClient): - def send_message(self, entity, message, reply_to=None, entities=None, link_preview=True): - entity = self.get_input_entity(entity) + async def send_message(self, entity, message, reply_to=None, entities=None, link_preview=True): + entity = await self.get_input_entity(entity) request = SendMessageRequest( peer=entity, @@ -32,7 +32,7 @@ class MautrixTelegramClient(TelegramClient): no_webpage=not link_preview, reply_to_msg_id=self._get_reply_to(reply_to) ) - result = self(request) + result = await self(request) if isinstance(result, UpdateShortSentMessage): return Message( id=result.id, @@ -46,12 +46,12 @@ class MautrixTelegramClient(TelegramClient): return self._get_response_message(request, result) - def send_file(self, entity, file, mime_type=None, caption=None, attributes=None, file_name=None, + async def send_file(self, entity, file, mime_type=None, caption=None, attributes=None, file_name=None, reply_to=None, **kwargs): - entity = self.get_input_entity(entity) + entity = await self.get_input_entity(entity) reply_to = self._get_reply_to(reply_to) - file_handle = self.upload_file(file, file_name=file_name, use_cache=False) + file_handle = await self.upload_file(file, file_name=file_name, use_cache=False) if mime_type == "image/png": media = InputMediaUploadedPhoto(file_handle, caption or "") @@ -66,9 +66,9 @@ class MautrixTelegramClient(TelegramClient): caption=caption or "") request = SendMediaRequest(entity, media, reply_to_msg_id=reply_to) - return self._get_response_message(request, self(request)) + return self._get_response_message(request, await self(request)) - def download_file_bytes(self, location): + async def download_file_bytes(self, location): if isinstance(location, Document): location = InputDocumentFileLocation(location.id, location.access_hash, location.version) @@ -77,7 +77,7 @@ class MautrixTelegramClient(TelegramClient): file = BytesIO() - self.download_file(location, file) + await self.download_file(location, file) data = file.getvalue() file.close() diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index 630a1c78..26db7d16 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -15,6 +15,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . import logging +import asyncio from telethon.tl.types import * from telethon.tl.types import User as TLUser @@ -81,20 +82,22 @@ class User: # endregion # region Telegram connection management - def start(self): + async def start(self): self.client = MautrixTelegramClient(self.mxid, config["telegram.api_id"], - config["telegram.api_hash"], - update_workers=2) - self.connected = self.client.connect() - if self.logged_in: - self.post_login() + config["telegram.api_hash"]) self.client.add_update_handler(self.update_catch) + self.connected = await self.client.connect() + if self.logged_in: + await self.post_login() return self - def post_login(self, info=None): - self.sync_dialogs() - self.update_info(info) + async def post_login(self, info=None): + try: + await self.sync_dialogs() + await self.update_info(info) + except Exception: + self.log.exception("Failed to run post-login functions") def stop(self): self.client.disconnect() @@ -104,8 +107,8 @@ class User: # endregion # region Telegram actions that need custom methods - def update_info(self, info=None): - info = info or self.client.get_me() + async def update_info(self, info=None): + info = info or await self.client.get_me() changed = False if self.username != info.username: self.username = info.username @@ -127,51 +130,53 @@ class User: self.save() return self.client.log_out() - def sync_dialogs(self): - dialogs = self.client.get_dialogs(limit=30) + async def sync_dialogs(self): + dialogs = await self.client.get_dialogs(limit=30) + creators = [] for dialog in dialogs: entity = dialog.entity if (isinstance(entity, (TLUser, ChatForbidden, ChannelForbidden)) or ( isinstance(entity, Chat) and (entity.deactivated or entity.left))): continue portal = po.Portal.get_by_entity(entity) - portal.create_matrix_room(self, entity, invites=[self.mxid]) + creators.append(portal.create_matrix_room(self, entity, invites=[self.mxid])) + await asyncio.gather(*creators) # endregion # region Telegram update handling - def update_catch(self, update): + async def update_catch(self, update): try: - self.update(update) + await self.update(update) except Exception: self.log.exception("Failed to handle Telegram update") - def update(self, update): + async def update(self, update): if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewMessage, UpdateNewChannelMessage)): - self.update_message(update) + await self.update_message(update) elif isinstance(update, (UpdateChatUserTyping, UpdateUserTyping)): - self.update_typing(update) + await self.update_typing(update) elif isinstance(update, UpdateUserStatus): - self.update_status(update) + await self.update_status(update) elif isinstance(update, (UpdateChatAdmins, UpdateChatParticipantAdmin)): - self.update_admin(update) + await self.update_admin(update) elif isinstance(update, UpdateChatParticipants): portal = po.Portal.get_by_tgid(update.participants.chat_id) if portal and portal.mxid: - portal.update_telegram_participants(update.participants.participants) + await portal.update_telegram_participants(update.participants.participants) elif isinstance(update, UpdateChannelPinnedMessage): portal = po.Portal.get_by_tgid(update.channel_id, peer_type="channel") if portal and portal.mxid: - portal.update_telegram_pin(self, update.id) + await portal.update_telegram_pin(self, update.id) elif isinstance(update, (UpdateUserName, UpdateUserPhoto)): - self.update_others_info(update) + await self.update_others_info(update) elif isinstance(update, UpdateReadHistoryOutbox): - self.update_read_receipt(update) + await self.update_read_receipt(update) else: self.log.debug("Unhandled update: %s", update) - def update_read_receipt(self, update): + async def update_read_receipt(self, update): if not isinstance(update.peer, PeerUser): self.log.debug("Unexpected read receipt peer: %s", update.peer) return @@ -186,40 +191,42 @@ class User: return puppet = pu.Puppet.get(update.peer.user_id) - puppet.intent.mark_read(portal.mxid, message.mxid) + await puppet.intent.mark_read(portal.mxid, message.mxid) - def update_admin(self, update): + async def update_admin(self, update): portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat") if isinstance(update, UpdateChatAdmins): - portal.set_telegram_admins_enabled(update.enabled) + await portal.set_telegram_admins_enabled(update.enabled) elif isinstance(update, UpdateChatParticipantAdmin): puppet = pu.Puppet.get(update.user_id) user = User.get_by_tgid(update.user_id) - portal.set_telegram_admin(puppet, user) + await portal.set_telegram_admin(puppet, user) - def update_typing(self, update): + async def update_typing(self, update): if isinstance(update, UpdateUserTyping): portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user") else: portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat") sender = pu.Puppet.get(update.user_id) - return portal.handle_telegram_typing(sender, update) + await portal.handle_telegram_typing(sender, update) - def update_others_info(self, update): + async def update_others_info(self, update): puppet = pu.Puppet.get(update.user_id) if isinstance(update, UpdateUserName): - if puppet.update_displayname(self, update): + if await puppet.update_displayname(self, update): puppet.save() elif isinstance(update, UpdateUserPhoto): - if puppet.update_avatar(self, update.photo.photo_big): + if await puppet.update_avatar(self, update.photo.photo_big): puppet.save() - def update_status(self, update): + async def update_status(self, update): puppet = pu.Puppet.get(update.user_id) if isinstance(update.status, UserStatusOnline): - puppet.intent.set_presence("online") + await puppet.intent.set_presence("online") elif isinstance(update.status, UserStatusOffline): - puppet.intent.set_presence("offline") + await puppet.intent.set_presence("offline") + else: + self.log.warning("Unexpected user status update: %s", update) return def get_message_details(self, update): @@ -243,7 +250,7 @@ class User: return update, None, None return update, sender, portal - def update_message(self, update): + async def update_message(self, update): update, sender, portal = self.get_message_details(update) if isinstance(update, MessageService): @@ -253,10 +260,10 @@ class User: return self.log.debug("Handling action %s to %s by %d", update.action, portal.tgid_log, sender.id) - portal.handle_telegram_action(self, sender, update.action) + await portal.handle_telegram_action(self, sender, update.action) else: self.log.debug("Handling message %s to %s by %d", update, portal.tgid_log, sender.tgid) - portal.handle_telegram_message(self, sender, update) + await portal.handle_telegram_message(self, sender, update) # endregion # region Class instance lookup @@ -309,8 +316,7 @@ class User: def init(context): global config - User.az, User.db, config = context + User.az, User.db, config, _ = context users = [User.from_db(user) for user in DBUser.query.all()] - for user in users: - user.start() + return [user.start() for user in users] diff --git a/requirements.txt b/requirements.txt index e938a66a..31f8c940 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ aiohttp ruamel.yaml python-magic SQLAlchemy -Telethon +git+git://github.com/LonamiWebs/Telethon@asyncio#egg=Telethon Markdown Pillow future-fstrings