From ac24bc86a09711d71023735e4f01e30faebaa44a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Aug 2019 22:21:24 +0300 Subject: [PATCH] Minor improvements --- mautrix_telegram/bot.py | 26 ++++++++++++-------------- mautrix_telegram/portal/base.py | 22 +++++++++++++--------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/mautrix_telegram/bot.py b/mautrix_telegram/bot.py index f03d072a..cf87543e 100644 --- a/mautrix_telegram/bot.py +++ b/mautrix_telegram/bot.py @@ -13,9 +13,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Callable, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from typing import Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING import logging -import re from telethon.tl.patched import Message, MessageService from telethon.tl.types import ( @@ -44,7 +43,6 @@ ReplyFunc = Callable[[str], Awaitable[Message]] class Bot(AbstractUser): log: logging.Logger = logging.getLogger("mau.user.bot") - mxid_regex: Pattern = re.compile("@.+:.+") token: str chats: Dict[int, str] @@ -110,9 +108,9 @@ class Bot(AbstractUser): if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated: self.remove_chat(TelegramID(chat.id)) - channel_ids = [InputChannel(chat_id, 0) + channel_ids = (InputChannel(chat_id, 0) for chat_id, chat_type in self.chats.items() - if chat_type == "channel"] + if chat_type == "channel") for channel_id in channel_ids: try: await self.client(GetChannelsRequest([channel_id])) @@ -165,7 +163,7 @@ class Bot(AbstractUser): return False async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool: - if not await self._can_use_commands(event.to_id, event.from_id): + if not await self._can_use_commands(event.to_id, TelegramID(event.from_id)): await reply("You do not have the permission to use that command.") return False return True @@ -193,7 +191,7 @@ class Bot(AbstractUser): elif not portal.mxid: return await reply("Portal does not have Matrix room. " "Create one with /portal first.") - if not self.mxid_regex.match(mxid_input): + if mxid_input[0] != '@' or mxid_input.find(':') < 2: return await reply("That doesn't look like a Matrix ID.") user = await u.User.get_by_mxid(mxid_input).ensure_started() if not user.relaybot_whitelisted: @@ -203,7 +201,7 @@ class Bot(AbstractUser): return await reply("That user seems to be logged in. " f"Just invite [{displayname}](tg://user?id={user.tgid})") else: - await portal.main_intent.invite(portal.mxid, user.mxid) + await portal.main_intent.invite_user(portal.mxid, user.mxid) return await reply(f"Invited `{user.mxid}` to the portal.") @staticmethod @@ -252,15 +250,15 @@ class Bot(AbstractUser): mxid = text[text.index(" ") + 1:] except ValueError: mxid = "" - await self.handle_command_invite(portal, reply, mxid_input=mxid) + await self.handle_command_invite(portal, reply, mxid_input=UserID(mxid)) def handle_service_message(self, message: MessageService) -> None: - to_id: TelegramID = message.to_id - if isinstance(to_id, PeerChannel): - to_id = to_id.channel_id + to_peer = message.to_id + if isinstance(to_peer, PeerChannel): + to_id = TelegramID(to_peer.channel_id) chat_type = "channel" - elif isinstance(to_id, PeerChat): - to_id = to_id.chat_id + elif isinstance(to_peer, PeerChat): + to_id = TelegramID(to_peer.chat_id) chat_type = "chat" else: return diff --git a/mautrix_telegram/portal/base.py b/mautrix_telegram/portal/base.py index 046ab5ea..cd53d58b 100644 --- a/mautrix_telegram/portal/base.py +++ b/mautrix_telegram/portal/base.py @@ -13,12 +13,11 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, Any, TYPE_CHECKING +from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHECKING from abc import ABC, abstractmethod import asyncio import logging import json -import re from telethon.tl.functions.messages import ExportChatInviteRequest from telethon.tl.types import (Channel, ChannelFull, Chat, ChatFull, ChatInviteEmpty, InputChannel, @@ -69,7 +68,8 @@ class BasePortal(ABC): public_portals: bool = False alias_template: str = None - mx_alias_regex: Pattern = None + _mx_alias_prefix: str = None + _mx_alias_suffix: str = None hs_domain: str = None # Instance cache @@ -346,9 +346,10 @@ class BasePortal(ABC): @classmethod def get_username_from_mx_alias(cls, alias: str) -> Optional[str]: - match = cls.mx_alias_regex.match(alias) - if match: - return match.group(1) + prefix = cls._mx_alias_prefix + suffix = cls._mx_alias_suffix + if alias[:len(prefix)] == prefix and alias[-len(suffix):] == suffix: + return alias[len(prefix):-len(suffix)] return None @classmethod @@ -473,7 +474,10 @@ def init(context: Context) -> None: BasePortal.public_portals = config["bridge.public_portals"] BasePortal.filter_mode = config["bridge.filter.mode"] BasePortal.filter_list = config["bridge.filter.list"] - BasePortal.alias_template = config.get("bridge.alias_template", "telegram_{groupname}") BasePortal.hs_domain = config["homeserver.domain"] - BasePortal.mx_alias_regex = re.compile( - f"#{BasePortal.alias_template.format(groupname='(.+)')}:{BasePortal.hs_domain}") + BasePortal.alias_template = config["bridge.alias_template"] + index = BasePortal.alias_template.index("{groupname}") + length = len("{groupname}") + BasePortal._mx_alias_prefix = f"#{BasePortal.alias_template[:index]}" + BasePortal._mx_alias_suffix = (f"{BasePortal.alias_template[index + length:]}" + f":{BasePortal.hs_domain}")