diff --git a/mautrix_telegram/commands.py b/mautrix_telegram/commands.py index c414adae..1480224c 100644 --- a/mautrix_telegram/commands.py +++ b/mautrix_telegram/commands.py @@ -14,9 +14,9 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from contextlib import contextmanager import markdown import logging +import asyncio from matrix_client.errors import MatrixRequestError @@ -54,33 +54,45 @@ def format_duration(seconds): return " and ".join(parts) +class CommandEvent: + def __init__(self, az, command_prefix, room, sender, args, is_management, is_portal): + self.az = az + self.command_prefix = command_prefix + self.room_id = room + self.sender = sender + self.args = args + self.is_management = is_management + self.is_portal = is_portal + + def reply(self, message, allow_html=False, render_markdown=True): + if not self.room_id: + raise AttributeError("the reply function can only be used from within" + "the `CommandHandler.run` context manager") + + message = message.replace("$cmdprefix+sp ", + "" if self.is_management else f"{self.command_prefix} ") + message = message.replace("$cmdprefix", self.command_prefix) + html = None + if render_markdown: + html = markdown.markdown(message, safe_mode="escape" if allow_html else False) + elif allow_html: + html = message + return self.az.intent.send_notice(self.room_id, message, html=html) + + class CommandHandler: log = logging.getLogger("mau.commands") def __init__(self, context): - self.az, self.db, self.config, _ = context + self.az, self.db, self.config, self.loop = context self.command_prefix = self.config["bridge.command_prefix"] - self._room_id = None - self._is_management = False - self._is_portal = False # region Utility functions for handling commands - def handle(self, room, sender, command, args, is_management, is_portal): - with self.handler(sender, room, command, args, is_management, is_portal) as handle_command: - try: - return handle_command(self, sender, args) - except FloodWaitError as e: - return self.reply(f"Flood error: Please wait {format_duration(e.seconds)}") - except Exception: - self.log.exception(f"Fatal error handling command " - + f"'$cmdprefix {command} {''.join(args)}' from {sender.mxid}") - return self.reply("Fatal error while handling command. Check logs for more details.") - - @contextmanager - def handler(self, sender, room, command, args, is_management, is_portal): + async def handle(self, room, sender, command, args, is_management, is_portal): + evt = CommandEvent(self.az, self.command_prefix, room, sender, args, is_management, + is_portal) command = command.lower() - self._room_id = room try: command = command_handlers[command] except KeyError: @@ -89,207 +101,197 @@ class CommandHandler: command = sender.command_status["next"] else: command = command_handlers["unknown_command"] - self._is_management = is_management - self._is_portal = is_portal - yield command - self._is_management = None - self._is_portal = None - self._room_id = None - - def reply(self, message, allow_html=False, render_markdown=True): - if not self._room_id: - raise AttributeError("the reply function can only be used from within" - "the `CommandHandler.run` context manager") - - message = message.replace("$cmdprefix+sp ", - "" if self._is_management else f"{self.command_prefix} ") - message = message.replace("$cmdprefix", self.command_prefix) - html = None - if render_markdown: - html = markdown.markdown(message, safe_mode="escape" if allow_html else False) - elif allow_html: - html = message - return self.az.intent.send_notice(self._room_id, message, html=html) + try: + await command(self, evt) + except FloodWaitError as e: + return evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}") + except Exception: + self.log.exception(f"Fatal error handling command " + + f"'$cmdprefix {command} {''.join(args)}' from {sender.mxid}") + return evt.reply("Fatal error while handling command. Check logs for more details.") # endregion # region Command handlers @command_handler - async def ping(self, sender, args): - if not sender.logged_in: - return await self.reply("You're not logged in.") - me = await sender.client.get_me() + async def ping(self, evt): + if not evt.sender.logged_in: + return await evt.reply("You're not logged in.") + me = await evt.sender.client.get_me() if me: - return await self.reply(f"You're logged in as @{me.username}") + return await evt.reply(f"You're logged in as @{me.username}") else: - return await self.reply("You're not logged in.") + return await evt.reply("You're not logged in.") # region Authentication commands @command_handler - def register(self, sender, args): - return self.reply("Not yet implemented.") + def register(self, evt): + return evt.reply("Not yet implemented.") @command_handler - async def login(self, sender, args): - if not self._is_management: - return await self.reply( + async def login(self, evt): + if not evt.is_management: + return await evt.reply( "`login` is a restricted command: you may only run it in management rooms.") - elif sender.logged_in: - return await self.reply("You are already logged in.") - elif len(args) == 0: - return await self.reply("**Usage:** `$cmdprefix+sp login `") - phone_number = args[0] - await sender.client.sign_in(phone_number) - sender.command_status = { + elif evt.sender.logged_in: + return await evt.reply("You are already logged in.") + elif len(evt.args) == 0: + return await evt.reply("**Usage:** `$cmdprefix+sp login `") + phone_number = evt.args[0] + await evt.sender.client.sign_in(phone_number) + evt.sender.command_status = { "next": command_handlers["enter_code"], "action": "Login", } - return await self.reply(f"Login code sent to {phone_number}. Please send the code here.") + return await evt.reply(f"Login code sent to {phone_number}. Please send the code here.") @command_handler - async def enter_code(self, sender, args): - if not sender.command_status: - return await self.reply("Request a login code first with `$cmdprefix+sp login `") - elif len(args) == 0: - return await self.reply("**Usage:** `$cmdprefix+sp enter_code `") + async def enter_code(self, evt): + if not evt.sender.command_status: + return await evt.reply( + "Request a login code first with `$cmdprefix+sp login `") + elif len(evt.args) == 0: + return await evt.reply("**Usage:** `$cmdprefix+sp enter_code `") try: - user = await sender.client.sign_in(code=args[0]) - sender.post_login(user) - sender.command_status = None - return await self.reply(f"Successfully logged in as @{user.username}") + user = await evt.sender.client.sign_in(code=evt.args[0]) + asyncio.ensure_future(evt.sender.post_login(user), loop=self.loop) + evt.sender.command_status = None + return await evt.reply(f"Successfully logged in as @{user.username}") except PhoneNumberUnoccupiedError: - return await self.reply("That phone number has not been registered." - "Please register with `$cmdprefix+sp register `.") + return await evt.reply("That phone number has not been registered." + "Please register with `$cmdprefix+sp register `.") except PhoneCodeExpiredError: - return await self.reply( + return await evt.reply( "Phone code expired. Try again with `$cmdprefix+sp login `.") except PhoneCodeInvalidError: - return await self.reply("Invalid phone code.") + return await evt.reply("Invalid phone code.") except PhoneNumberAppSignupForbiddenError: - return await self.reply( + return await evt.reply( "Your phone number does not allow 3rd party apps to sign in.") except PhoneNumberFloodError: - return await self.reply( + return await evt.reply( "Your phone number has been temporarily blocked for flooding. " "The block is usually applied for around a day.") except PhoneNumberBannedError: - return await self.reply("Your phone number has been banned from Telegram.") + return await evt.reply("Your phone number has been banned from Telegram.") except SessionPasswordNeededError: - sender.command_status = { + evt.sender.command_status = { "next": command_handlers["enter_password"], "action": "Login (password entry)", } - return await self.reply("Your account has two-factor authentication." - "Please send your password here.") + return await evt.reply("Your account has two-factor authentication." + "Please send your password here.") except Exception: - self.log.exception() - return await self.reply("Unhandled exception while sending code." - "Check console for more details.") + self.log.exception("Error sending phone code") + return await evt.reply("Unhandled exception while sending code." + "Check console for more details.") @command_handler - async def enter_password(self, sender, args): - if not sender.command_status: - return await self.reply("Request a login code first with `$cmdprefix+sp login `") - elif len(args) == 0: - return await self.reply("**Usage:** `$cmdprefix+sp enter_password `") + async def enter_password(self, evt): + if not evt.sender.command_status: + return await evt.reply( + "Request a login code first with `$cmdprefix+sp login `") + elif len(evt.args) == 0: + return await evt.reply("**Usage:** `$cmdprefix+sp enter_password `") try: - user = await sender.client.sign_in(password=args[0]) - sender.post_login(user) - sender.command_status = None - return await self.reply(f"Successfully logged in as @{user.username}") + user = await evt.sender.client.sign_in(password=evt.args[0]) + asyncio.ensure_future(evt.sender.post_login(user), loop=self.loop) + evt.sender.command_status = None + return await evt.reply(f"Successfully logged in as @{user.username}") except PasswordHashInvalidError: - return await self.reply("Incorrect password.") + return await evt.reply("Incorrect password.") except Exception: - self.log.exception() - return await self.reply("Unhandled exception while sending password. " - "Check console for more details.") + self.log.exception("Error sending password") + return await evt.reply("Unhandled exception while sending password. " + "Check console for more details.") @command_handler - async def logout(self, sender, args): - if not sender.logged_in: - return await self.reply("You're not logged in.") - if await sender.log_out(): - return await self.reply("Logged out successfully.") - return await self.reply("Failed to log out.") + async def logout(self, evt): + if not evt.sender.logged_in: + return await evt.reply("You're not logged in.") + if await evt.sender.log_out(): + return await evt.reply("Logged out successfully.") + return await evt.reply("Failed to log out.") # endregion # region Telegram interaction commands @command_handler - async def search(self, sender, args): - if len(args) == 0: - return await self.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] `") - elif not sender.logged_in: - return await self.reply("This command requires you to be logged in.") + async def search(self, evt): + if len(evt.args) == 0: + return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] `") + elif not evt.sender.logged_in: + return await evt.reply("This command requires you to be logged in.") # force_remote = False - if args[0] in {"-r", "--remote"}: + if evt.args[0] in {"-r", "--remote"}: # force_remote = True - args.pop(0) - query = " ".join(args) + evt.args.pop(0) + query = " ".join(evt.args) if len(query) < 5: - return await self.reply("Minimum length of query for remote search is 5 characters.") - found = await sender.client(SearchRequest(q=query, limit=10)) + return await evt.reply("Minimum length of query for remote search is 5 characters.") + found = await evt.sender.client(SearchRequest(q=query, limit=10)) # reply = ["**People:**", ""] reply = ["**Results from Telegram server:**", ""] for result in found.users: puppet = pu.Puppet.get(result.id) - await puppet.update_info(sender, result) + await puppet.update_info(evt.sender, result) reply.append( f"* [{puppet.displayname}](https://matrix.to/#/{puppet.mxid}): {puppet.id}") # reply.extend(("", "**Chats:**", "")) # for result in found.chats: # reply.append(f"* {result.title}") - return await self.reply("\n".join(reply)) + return await evt.reply("\n".join(reply)) @command_handler - async def pm(self, sender, args): - if len(args) == 0: - return await self.reply("**Usage:** `$cmdprefix+sp pm `") - elif not sender.logged_in: - return await self.reply("This command requires you to be logged in.") + async def pm(self, evt): + if len(evt.args) == 0: + return await evt.reply("**Usage:** `$cmdprefix+sp pm `") + elif not evt.sender.logged_in: + return await evt.reply("This command requires you to be logged in.") - user = await sender.client.get_entity(args[0]) + user = await evt.sender.client.get_entity(evt.args[0]) if not user: - return await self.reply("User not found.") + return await evt.reply("User not found.") elif not isinstance(user, User): - return await self.reply("That doesn't seem to be a user.") - portal = po.Portal.get_by_entity(user, sender.tgid) - await portal.create_matrix_room(sender, user, [sender.mxid]) - return await self.reply(f"Created private chat room with {pu.Puppet.get_displayname(user, False)}") + return await evt.reply("That doesn't seem to be a user.") + portal = po.Portal.get_by_entity(user, evt.sender.tgid) + await portal.create_matrix_room(evt.sender, user, [evt.sender.mxid]) + return await evt.reply( + f"Created private chat room with {pu.Puppet.get_displayname(user, False)}") @command_handler - async def invitelink(self, sender, args): - if not sender.logged_in: - return await self.reply("This command requires you to be logged in.") + async def invitelink(self, evt): + if not evt.sender.logged_in: + return await evt.reply("This command requires you to be logged in.") - portal = po.Portal.get_by_mxid(self._room_id) + portal = po.Portal.get_by_mxid(evt.room_id) if not portal: - return await self.reply("This is not a portal room.") + return await evt.reply("This is not a portal room.") if portal.peer_type == "user": - return await self.reply("You can't invite users to private chats.") + return await evt.reply("You can't invite users to private chats.") try: - link = await portal.get_invite_link(sender) - return await self.reply(f"Invite link to {portal.title}: {link}") + link = await portal.get_invite_link(evt.sender) + return await evt.reply(f"Invite link to {portal.title}: {link}") except ValueError as e: - return await self.reply(e.args[0]) + return await evt.reply(e.args[0]) except ChatAdminRequiredError: - return await self.reply("You don't have the permission to create an invite link.") + return await evt.reply("You don't have the permission to create an invite link.") @command_handler - async def deleteportal(self, sender, args): - if not sender.logged_in: - return await self.reply("This command requires you to be logged in.") - elif not sender.is_admin: - return await self.reply("This is command requires administrator privileges.") + async def deleteportal(self, evt): + if not evt.sender.logged_in: + return await evt.reply("This command requires you to be logged in.") + elif not evt.sender.is_admin: + return await evt.reply("This is command requires administrator privileges.") - portal = po.Portal.get_by_mxid(self._room_id) + portal = po.Portal.get_by_mxid(evt.room_id) if not portal: - return await self.reply("This is not a portal room.") + return await evt.reply("This is not a portal room.") for user in portal.main_intent.get_room_members(portal.mxid): if user != portal.main_intent.mxid: @@ -308,55 +310,56 @@ class CommandHandler: return value @command_handler - async def join(self, sender, args): - if len(args) == 0: - return await self.reply("**Usage:** `$cmdprefix+sp join `") - elif not sender.logged_in: - return await self.reply("This command requires you to be logged in.") + async def join(self, evt): + if len(evt.args) == 0: + return await evt.reply("**Usage:** `$cmdprefix+sp join `") + elif not evt.sender.logged_in: + return await evt.reply("This command requires you to be logged in.") regex = re.compile(r"(?:https?://)?t(?:elegram)?\.(?:dog|me)(?:joinchat/)?/(.+)") - arg = regex.match(args[0]) + arg = regex.match(evt.args[0]) if not arg: - return await self.reply("That doesn't look like a Telegram invite link.") + return await evt.reply("That doesn't look like a Telegram invite link.") arg = arg.group(1) if arg.startswith("joinchat/"): invite_hash = arg[len("joinchat/"):] try: - await sender.client(CheckChatInviteRequest(invite_hash)) + await evt.sender.client(CheckChatInviteRequest(invite_hash)) except InviteHashInvalidError: - return await self.reply("Invalid invite link.") + return await evt.reply("Invalid invite link.") except InviteHashExpiredError: - return await self.reply("Invite link expired.") + return await evt.reply("Invite link expired.") try: - updates = sender.client(ImportChatInviteRequest(invite_hash)) + updates = evt.sender.client(ImportChatInviteRequest(invite_hash)) except UserAlreadyParticipantError: - return await self.reply("You are already in that chat.") + return await evt.reply("You are already in that chat.") else: - channel = await sender.client.get_entity(arg) + channel = await evt.sender.client.get_entity(arg) if not channel: - return await self.reply("Channel/supergroup not found.") - updates = await sender.client(JoinChannelRequest(channel)) + return await evt.reply("Channel/supergroup not found.") + updates = await evt.sender.client(JoinChannelRequest(channel)) for chat in updates.chats: portal = po.Portal.get_by_entity(chat) if portal.mxid: - await portal.create_matrix_room(sender, chat, [sender.mxid]) - return await self.reply(f"Created room for {portal.title}") + await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid]) + return await evt.reply(f"Created room for {portal.title}") else: - await portal.invite_matrix([sender.mxid]) - return await self.reply(f"Invited you to portal of {portal.title}") + await portal.invite_matrix([evt.sender.mxid]) + return await evt.reply(f"Invited you to portal of {portal.title}") @command_handler - async def create(self, sender, args): - type = args[0] if len(args) > 0 else "group" + async def create(self, evt): + type = evt.args[0] if len(evt.args) > 0 else "group" if type not in {"chat", "group", "supergroup", "channel"}: - return await self.reply("**Usage:** `$cmdprefix+sp create ['group'/'supergroup'/'channel']`") - elif not sender.logged_in: - return await self.reply("This command requires you to be logged in.") + return await evt.reply( + "**Usage:** `$cmdprefix+sp create ['group'/'supergroup'/'channel']`") + elif not evt.sender.logged_in: + return await evt.reply("This command requires you to be logged in.") - if po.Portal.get_by_mxid(self._room_id): - return await self.reply("This is already a portal room.") + if po.Portal.get_by_mxid(evt.room_id): + return await evt.reply("This is already a portal room.") - state = await self.az.intent.get_room_state(self._room_id) + state = await self.az.intent.get_room_state(evt.room_id) title = None about = None levels = None @@ -368,18 +371,19 @@ class CommandHandler: elif event["type"] == "m.room.power_levels": levels = event["content"] if not title: - return await self.reply("Please set a title before creating a Telegram chat.") + return await evt.reply("Please set a title before creating a Telegram chat.") elif (not levels or not levels["users"] or self.az.intent.mxid not in levels["users"] or levels["users"][self.az.intent.mxid] < 100): - return await self.reply(f"Please give " - + f"[the bridge bot](https://matrix.to/#/{self.az.intent.mxid}) " - + f"a power level of 100 before creating a Telegram chat.") + return await evt.reply(f"Please give " + + f"[the bridge bot](https://matrix.to/#/{self.az.intent.mxid})" + + f" a power level of 100 before creating a Telegram chat.") else: for user, level in levels["users"].items(): if level >= 100 and user != self.az.intent.mxid: - return await self.reply(f"Please make sure only the bridge bot has power level above" - + f"99 before creating a Telegram chat.\n\n" - + f"Use power level 95 instead of 100 for admins.") + return await evt.reply( + f"Please make sure only the bridge bot has power level above" + + f"99 before creating a Telegram chat.\n\n" + + f"Use power level 95 instead of 100 for admins.") supergroup = type == "supergroup" type = { @@ -389,86 +393,88 @@ class CommandHandler: "group": "chat", }[type] - portal = po.Portal(tgid=None, mxid=self._room_id, title=title, about=about, peer_type=type) + portal = po.Portal(tgid=None, mxid=evt.room_id, title=title, about=about, peer_type=type) try: - await portal.create_telegram_chat(sender, supergroup=supergroup) + await portal.create_telegram_chat(evt.sender, supergroup=supergroup) except ValueError as e: - return await self.reply(e.args[0]) - return await self.reply(f"Telegram chat created. ID: {portal.tgid}") + return await evt.reply(e.args[0]) + return await evt.reply(f"Telegram chat created. ID: {portal.tgid}") @command_handler - async def upgrade(self, sender, args): - if not sender.logged_in: - return await self.reply("This command requires you to be logged in.") + async def upgrade(self, evt): + if not evt.sender.logged_in: + return await evt.reply("This command requires you to be logged in.") - portal = po.Portal.get_by_mxid(self._room_id) + portal = po.Portal.get_by_mxid(evt.room_id) if not portal: - return await self.reply("This is not a portal room.") + return await evt.reply("This is not a portal room.") elif portal.peer_type == "channel": - return await self.reply("This is already a supergroup or a channel.") + return await evt.reply("This is already a supergroup or a channel.") elif portal.peer_type == "user": - return await self.reply("You can't upgrade private chats.") + return await evt.reply("You can't upgrade private chats.") try: - await portal.upgrade_telegram_chat(sender) - return await self.reply(f"Group upgraded to supergroup. New ID: {portal.tgid}") + await portal.upgrade_telegram_chat(evt.sender) + return await evt.reply(f"Group upgraded to supergroup. New ID: {portal.tgid}") except ChatAdminRequiredError: - return await self.reply("You don't have the permission to upgrade this group.") + return await evt.reply("You don't have the permission to upgrade this group.") except ValueError as e: - return await self.reply(e.args[0]) + return await evt.reply(e.args[0]) @command_handler - async def groupname(self, sender, args): - if len(args) == 0: - return await self.reply("**Usage:** `$cmdprefix+sp groupname `") - if not sender.logged_in: - return await self.reply("This command requires you to be logged in.") + async def groupname(self, evt): + if len(evt.args) == 0: + return await evt.reply("**Usage:** `$cmdprefix+sp groupname `") + if not evt.sender.logged_in: + return await evt.reply("This command requires you to be logged in.") - portal = po.Portal.get_by_mxid(self._room_id) + portal = po.Portal.get_by_mxid(evt.room_id) if not portal: - return await self.reply("This is not a portal room.") + return await evt.reply("This is not a portal room.") elif portal.peer_type != "channel": - return await self.reply("Only channels and supergroups have usernames.") + return await evt.reply("Only channels and supergroups have usernames.") try: - await portal.set_telegram_username(sender, args[0] if args[0] != "-" else "") + await portal.set_telegram_username(evt.sender, + evt.args[0] if evt.args[0] != "-" else "") if portal.username: - return await self.reply(f"Username of channel changed to {portal.username}.") + return await evt.reply(f"Username of channel changed to {portal.username}.") else: - return await self.reply(f"Channel is now private.") + return await evt.reply(f"Channel is now private.") except ChatAdminRequiredError: - return await self.reply("You don't have the permission to set the username of this channel.") + return await evt.reply( + "You don't have the permission to set the username of this channel.") except UsernameNotModifiedError: if portal.username: - return await self.reply("That is already the username of this channel.") + return await evt.reply("That is already the username of this channel.") else: - return await self.reply("This channel is already private") + return await evt.reply("This channel is already private") except UsernameOccupiedError: - return await self.reply("That username is already in use.") + return await evt.reply("That username is already in use.") except UsernameInvalidError: - return await self.reply("Invalid username") + return await evt.reply("Invalid username") # endregion # region Command-related commands @command_handler - def cancel(self, sender, args): - if sender.command_status: - action = sender.command_status["action"] - sender.command_status = None - return self.reply(f"{action} cancelled.") + def cancel(self, evt): + if evt.sender.command_status: + action = evt.sender.command_status["action"] + evt.sender.command_status = None + return evt.reply(f"{action} cancelled.") else: - return self.reply("No ongoing command.") + return evt.reply("No ongoing command.") @command_handler - def unknown_command(self, sender, args): - return self.reply("Unknown command. Try `$cmdprefix+sp help` for help.") + def unknown_command(self, evt): + return evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.") @command_handler - def help(self, sender, args): - if self._is_management: + def help(self, evt): + if evt.is_management: management_status = ("This is a management room: prefixing commands " "with `$cmdprefix` is not required.\n") - elif self._is_portal: + elif evt.is_portal: management_status = ("**This is a portal room**: you must always " "prefix commands with `$cmdprefix`.\n" "Management commands will not be sent to Telegram.") @@ -503,7 +509,7 @@ class CommandHandler: **groupname** <_name_|`-`> - Change the username of a supergroup/channel. To disable, use a dash (`-`) as the name. """ - return self.reply(management_status + help) + return evt.reply(management_status + help) # endregion # endregion