Merge pull request #206 from V02460/master

Add type annotations
This commit is contained in:
Tulir Asokan
2018-08-15 10:18:39 +03:00
committed by GitHub
28 changed files with 705 additions and 531 deletions
+2 -2
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import Coroutine, List, Optional
import argparse
import asyncio
import logging.config
@@ -115,7 +115,7 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
startup_actions = (init_puppet(context) +
init_user(context) +
[start,
context.mx.init_as_bot()])
context.mx.init_as_bot()]) # type: List[Coroutine]
if context.bot:
startup_actions.append(context.bot.start())
+27 -24
View File
@@ -38,6 +38,7 @@ from .db import Message as DBMessage
from .tgclient import MautrixTelegramClient
if TYPE_CHECKING:
from .types import TelegramID
from .context import Context
from .config import Config
from .bot import Bot
@@ -60,17 +61,18 @@ class AbstractUser(ABC):
bot = None # type: Bot
ignore_incoming_bot_events = True # type: bool
def __init__(self):
def __init__(self) -> None:
self.is_admin = False # type: bool
self.matrix_puppet_whitelisted = False # type: bool
self.puppet_whitelisted = False # type: bool
self.whitelisted = False # type: bool
self.relaybot_whitelisted = False # type: bool
self.client = None # type: MautrixTelegramClient
self.tgid = None # type: int
self.tgid = None # type: TelegramID
self.mxid = None # type: str
self.is_relaybot = False # type: bool
self.is_bot = False # type: bool
self.relaybot = None # type: Optional[Bot]
@property
def connected(self) -> bool:
@@ -93,7 +95,7 @@ class AbstractUser(ABC):
config["telegram.proxy.rdns"],
config["telegram.proxy.username"], config["telegram.proxy.password"])
def _init_client(self):
def _init_client(self) -> None:
self.log.debug(f"Initializing client for {self.name}")
device = f"{platform.system()} {platform.release()}"
sysversion = MautrixTelegramClient.__version__
@@ -114,18 +116,18 @@ class AbstractUser(ABC):
return False
@abstractmethod
async def post_login(self):
async def post_login(self) -> None:
raise NotImplementedError()
@abstractmethod
def register_portal(self, portal: po.Portal):
def register_portal(self, portal: po.Portal) -> None:
raise NotImplementedError()
@abstractmethod
def unregister_portal(self, portal: po.Portal):
def unregister_portal(self, portal: po.Portal) -> None:
raise NotImplementedError()
async def _update_catch(self, update: TypeUpdate):
async def _update_catch(self, update: TypeUpdate) -> None:
try:
if not await self.update(update):
await self._update(update)
@@ -154,14 +156,14 @@ class AbstractUser(ABC):
and (not self.is_bot or allow_bot)
and await self.is_logged_in())
async def start(self, delete_unless_authenticated: bool = False) -> "AbstractUser":
async def start(self, delete_unless_authenticated: bool = False) -> 'AbstractUser':
if not self.client:
self._init_client()
await self.client.connect()
self.log.debug("%s connected: %s", self.mxid, self.connected)
return self
async def ensure_started(self, even_if_no_session=False) -> "AbstractUser":
async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
if not self.puppet_whitelisted:
return self
self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)",
@@ -175,13 +177,13 @@ class AbstractUser(ABC):
await self.start(delete_unless_authenticated=not even_if_no_session)
return self
async def stop(self):
async def stop(self) -> None:
await self.client.disconnect()
self.client = None
# region Telegram update handling
async def _update(self, update: TypeUpdate):
async def _update(self, update: TypeUpdate) -> None:
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)):
await self.update_message(update)
@@ -207,18 +209,18 @@ class AbstractUser(ABC):
self.log.debug("Unhandled update: %s", update)
@staticmethod
async def update_pinned_messages(update: UpdateChannelPinnedMessage):
async def update_pinned_messages(update: UpdateChannelPinnedMessage) -> None:
portal = po.Portal.get_by_tgid(update.channel_id)
if portal and portal.mxid:
await portal.receive_telegram_pin_id(update.id)
@staticmethod
async def update_participants(update: UpdateChatParticipants):
async def update_participants(update: UpdateChatParticipants) -> None:
portal = po.Portal.get_by_tgid(update.participants.chat_id)
if portal and portal.mxid:
await portal.update_telegram_participants(update.participants.participants)
async def update_read_receipt(self, update: UpdateReadHistoryOutbox):
async def update_read_receipt(self, update: UpdateReadHistoryOutbox) -> None:
if not isinstance(update.peer, PeerUser):
self.log.debug("Unexpected read receipt peer: %s", update.peer)
return
@@ -235,7 +237,8 @@ class AbstractUser(ABC):
puppet = pu.Puppet.get(update.peer.user_id)
await puppet.intent.mark_read(portal.mxid, message.mxid)
async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]):
async def update_admin(self,
update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]) -> None:
# TODO duplication not checked
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
if isinstance(update, UpdateChatAdmins):
@@ -245,7 +248,7 @@ class AbstractUser(ABC):
else:
self.log.warning("Unexpected admin status update: %s", update)
async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]):
async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]) -> None:
if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else:
@@ -253,7 +256,7 @@ class AbstractUser(ABC):
sender = pu.Puppet.get(update.user_id)
await portal.handle_telegram_typing(sender, update)
async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]):
async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]) -> None:
# TODO duplication not checked
puppet = pu.Puppet.get(update.user_id)
if isinstance(update, UpdateUserName):
@@ -265,7 +268,7 @@ class AbstractUser(ABC):
else:
self.log.warning("Unexpected other user info update: %s", update)
async def update_status(self, update: UpdateUserStatus):
async def update_status(self, update: UpdateUserStatus) -> None:
puppet = pu.Puppet.get(update.user_id)
if isinstance(update.status, UserStatusOnline):
await puppet.default_mxid_intent.set_presence("online")
@@ -300,7 +303,7 @@ class AbstractUser(ABC):
return update, sender, portal
@staticmethod
async def _try_redact(portal: po.Portal, message: DBMessage):
async def _try_redact(portal: po.Portal, message: DBMessage) -> None:
if not portal:
return
try:
@@ -308,7 +311,7 @@ class AbstractUser(ABC):
except MatrixRequestError:
pass
async def delete_message(self, update: UpdateDeleteMessages):
async def delete_message(self, update: UpdateDeleteMessages) -> None:
if len(update.messages) > MAX_DELETIONS:
return
@@ -324,7 +327,7 @@ class AbstractUser(ABC):
await self._try_redact(portal, message)
self.db.commit()
async def delete_channel_message(self, update: UpdateDeleteChannelMessages):
async def delete_channel_message(self, update: UpdateDeleteChannelMessages) -> None:
if len(update.messages) > MAX_DELETIONS:
return
@@ -340,7 +343,7 @@ class AbstractUser(ABC):
await self._try_redact(portal, message)
self.db.commit()
async def update_message(self, original_update: UpdateMessage):
async def update_message(self, original_update: UpdateMessage) -> None:
update, sender, portal = self.get_message_details(original_update)
if self.ignore_incoming_bot_events and self.bot and sender.id == self.bot.tgid:
self.log.debug(f"Ignoring relaybot-sent message %s to %s", update, portal.tgid_log)
@@ -369,9 +372,9 @@ class AbstractUser(ABC):
# endregion
def init(context: "Context"):
def init(context: "Context") -> None:
global config, MAX_DELETIONS
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core
AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"]
AbstractUser.session_container = context.session_container
MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10)
+33 -24
View File
@@ -14,21 +14,27 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable, Pattern, Dict, TYPE_CHECKING
from typing import Awaitable, Callable, Dict, List, Optional, Pattern, TYPE_CHECKING
import logging
import re
from telethon.tl.types import *
from telethon.tl.types import (
ChannelParticipantAdmin, ChannelParticipantCreator, ChatForbidden, ChatParticipantAdmin,
ChatParticipantCreator, InputChannel, InputUser, Message, MessageActionChatAddUser,
MessageActionChatDeleteUser, MessageEntityBotCommand, MessageService, PeerChannel, PeerChat,
TypePeer, UpdateNewChannelMessage, UpdateNewMessage)
from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
from telethon.errors import ChannelInvalidError, ChannelPrivateError
from .types import MatrixUserID
from .abstract_user import AbstractUser
from .db import BotChat
from . import puppet as pu, portal as po, user as u
if TYPE_CHECKING:
from .config import Config
from .context import Context
config = None # type: Config
@@ -39,7 +45,7 @@ class Bot(AbstractUser):
log = logging.getLogger("mau.bot") # type: logging.Logger
mxid_regex = re.compile("@.+:.+") # type: Pattern
def __init__(self, token: str):
def __init__(self, token: str) -> None:
super().__init__()
self.token = token # type: str
self.puppet_whitelisted = True # type: bool
@@ -53,7 +59,7 @@ class Bot(AbstractUser):
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
or False) # type: bool
async def init_permissions(self):
async def init_permissions(self) -> None:
whitelist = config["bridge.relaybot.whitelist"] or []
for id in whitelist:
if isinstance(id, str):
@@ -65,14 +71,14 @@ class Bot(AbstractUser):
if isinstance(id, int):
self.tg_whitelist.append(id)
async def start(self, delete_unless_authenticated: bool = False) -> "Bot":
async def start(self, delete_unless_authenticated: bool = False) -> 'Bot':
await super().start(delete_unless_authenticated)
if not await self.is_logged_in():
await self.client.sign_in(bot_token=self.token)
await self.post_login()
return self
async def post_login(self):
async def post_login(self) -> None:
await self.init_permissions()
info = await self.client.get_me()
self.tgid = info.id
@@ -100,19 +106,19 @@ class Bot(AbstractUser):
except Exception:
self.log.exception("Failed to run catch_up() for bot")
def register_portal(self, portal: po.Portal):
def register_portal(self, portal: po.Portal) -> None:
self.add_chat(portal.tgid, portal.peer_type)
def unregister_portal(self, portal: po.Portal):
def unregister_portal(self, portal: po.Portal) -> None:
self.remove_chat(portal.tgid)
def add_chat(self, id: int, type: str):
def add_chat(self, id: int, type: str) -> None:
if id not in self.chats:
self.chats[id] = type
self.db.add(BotChat(id=id, type=type))
self.db.commit()
def remove_chat(self, id: int):
def remove_chat(self, id: int) -> None:
try:
del self.chats[id]
except KeyError:
@@ -141,6 +147,7 @@ class Bot(AbstractUser):
for p in participants:
if p.user_id == tgid:
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
return False
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
if not await self._can_use_commands(event.to_id, event.from_id):
@@ -148,7 +155,7 @@ class Bot(AbstractUser):
return False
return True
async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc):
async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc) -> None:
if not config["bridge.relaybot.authless_portals"]:
return await reply("This bridge doesn't allow portal creation from Telegram.")
@@ -164,15 +171,16 @@ class Bot(AbstractUser):
return await reply(
"Portal is not public. Use `/invite <mxid>` to get an invite.")
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, mxid: str):
if len(mxid) == 0:
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc,
mxid_input: MatrixUserID) -> Message:
if len(mxid_input) == 0:
return await reply("Usage: `/invite <mxid>`")
elif not portal.mxid:
return await reply("Portal does not have Matrix room. "
"Create one with /portal first.")
if not self.mxid_regex.match(mxid):
if not self.mxid_regex.match(mxid_input):
return await reply("That doesn't look like a Matrix ID.")
user = await u.User.get_by_mxid(mxid).ensure_started()
user = await u.User.get_by_mxid(MatrixUserID(mxid_input)).ensure_started()
if not user.relaybot_whitelisted:
return await reply("That user is not whitelisted to use the bridge.")
elif await user.is_logged_in():
@@ -183,7 +191,7 @@ class Bot(AbstractUser):
await portal.main_intent.invite(portal.mxid, user.mxid)
return await reply(f"Invited `{user.mxid}` to the portal.")
def handle_command_id(self, message: Message, reply: ReplyFunc):
def handle_command_id(self, message: Message, reply: ReplyFunc) -> Awaitable[Message]:
# Provide the prefixed ID to the user so that the user wouldn't need to specify whether the
# chat is a normal group or a supergroup/channel when using the ID.
if isinstance(message.to_id, PeerChannel):
@@ -205,8 +213,8 @@ class Bot(AbstractUser):
return False
async def handle_command(self, message: Message):
def reply(reply_text):
async def handle_command(self, message: Message) -> None:
def reply(reply_text: str) -> Awaitable[Message]:
return self.client.send_message(message.to_id, reply_text, reply_to=message.id)
text = message.message
@@ -227,9 +235,9 @@ class Bot(AbstractUser):
mxid = text[text.index(" ") + 1:]
except ValueError:
mxid = ""
await self.handle_command_invite(portal, reply, mxid=mxid)
await self.handle_command_invite(portal, reply, mxid_input=mxid)
def handle_service_message(self, message: MessageService):
def handle_service_message(self, message: MessageService) -> None:
to_id = message.to_id
if isinstance(to_id, PeerChannel):
to_id = to_id.channel_id
@@ -246,11 +254,12 @@ class Bot(AbstractUser):
elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid:
self.remove_chat(to_id)
async def update(self, update):
async def update(self, update) -> bool:
if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
return
return False
if isinstance(update.message, MessageService):
return self.handle_service_message(update.message)
self.handle_service_message(update.message)
return False
is_command = (isinstance(update.message, Message)
and update.message.entities and len(update.message.entities) > 0
@@ -266,7 +275,7 @@ class Bot(AbstractUser):
return "bot"
def init(context) -> Optional[Bot]:
def init(context: 'Context') -> Optional[Bot]:
global config
config = context.config
token = config["telegram.bot_token"]
+29 -19
View File
@@ -14,10 +14,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict
from typing import Any, Awaitable, Dict, Optional
import asyncio
from telethon.errors import *
from telethon.errors import (
AccessTokenExpiredError, AccessTokenInvalidError, FirstNameInvalidError, FloodWaitError,
PasswordHashInvalidError, PhoneCodeExpiredError, PhoneCodeInvalidError,
PhoneNumberAppSignupForbiddenError, PhoneNumberBannedError, PhoneNumberFloodError,
PhoneNumberOccupiedError, PhoneNumberUnoccupiedError, SessionPasswordNeededError)
from . import command_handler, CommandEvent, SECTION_AUTH
from .. import puppet as pu
@@ -27,7 +31,7 @@ from ..util import format_duration
@command_handler(needs_auth=False,
help_section=SECTION_AUTH,
help_text="Check if you're logged into Telegram.")
async def ping(evt: CommandEvent):
async def ping(evt: CommandEvent) -> Optional[Dict]:
me = await evt.sender.client.get_me() if await evt.sender.is_logged_in() else None
if me:
return await evt.reply(f"You're logged in as @{me.username}")
@@ -38,7 +42,7 @@ async def ping(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_AUTH,
help_text="Get the info of the message relay Telegram bot.")
async def ping_bot(evt: CommandEvent):
async def ping_bot(evt: CommandEvent) -> Optional[Dict]:
if not evt.tgbot:
return await evt.reply("Telegram message relay bot not configured.")
bot_info = await evt.tgbot.client.get_me()
@@ -53,19 +57,19 @@ async def ping_bot(evt: CommandEvent):
help_section=SECTION_AUTH,
help_text="Revert your Telegram account's Matrix puppet to use the default Matrix "
"account.")
async def logout_matrix(evt: CommandEvent):
async def logout_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid)
if not puppet.is_real_user:
return await evt.reply("You are not logged in with your Matrix account.")
await puppet.switch_mxid(None, None)
await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
return await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
@command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True,
help_section=SECTION_AUTH,
help_text="Replace your Telegram account's Matrix puppet with your own Matrix "
"account")
async def login_matrix(evt: CommandEvent):
async def login_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid)
if puppet.is_real_user:
return await evt.reply("You have already logged in with your Matrix account. "
@@ -96,7 +100,7 @@ async def login_matrix(evt: CommandEvent):
return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def enter_matrix_token(evt: CommandEvent):
async def enter_matrix_token(evt: CommandEvent) -> Dict:
evt.sender.command_status = None
puppet = pu.Puppet.get(evt.sender.tgid)
@@ -105,10 +109,11 @@ async def enter_matrix_token(evt: CommandEvent):
"Log out with `$cmdprefix+sp logout-matrix` first.")
resp = await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
if resp == 2:
if resp == pu.PuppetError.OnlyLoginSelf:
return await evt.reply("You can only log in as your own Matrix user.")
elif resp == 1:
elif resp == pu.PuppetError.InvalidAccessToken:
return await evt.reply("Failed to verify access token.")
assert resp == pu.PuppetError.Success, "Encountered an unhandled PuppetError."
return await evt.reply(
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
@@ -117,7 +122,7 @@ async def enter_matrix_token(evt: CommandEvent):
help_section=SECTION_AUTH,
help_args="<_phone_> <_full name_>",
help_text="Register to Telegram")
async def register(evt: CommandEvent):
async def register(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.")
elif len(evt.args) < 1:
@@ -134,9 +139,10 @@ async def register(evt: CommandEvent):
"action": "Register",
"full_name": full_name,
})
return None
async def enter_code_register(evt: CommandEvent):
async def enter_code_register(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp <code>`")
try:
@@ -165,7 +171,7 @@ async def enter_code_register(evt: CommandEvent):
@command_handler(needs_auth=False, management_only=True,
help_section=SECTION_AUTH,
help_text="Get instructions on how to log in.")
async def login(evt: CommandEvent):
async def login(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.")
@@ -196,7 +202,8 @@ async def login(evt: CommandEvent):
return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, str]):
async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, Any]
) -> Dict:
ok = False
try:
await evt.sender.ensure_started(even_if_no_session=True)
@@ -228,7 +235,7 @@ async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[s
@command_handler(needs_auth=False)
async def enter_phone_or_token(evt: CommandEvent):
async def enter_phone_or_token(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token <phone-or-token>`")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -248,10 +255,11 @@ async def enter_phone_or_token(evt: CommandEvent):
"next": enter_code,
"action": "Login",
})
return None
@command_handler(needs_auth=False)
async def enter_code(evt: CommandEvent):
async def enter_code(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-code <code>`")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -263,10 +271,11 @@ async def enter_code(evt: CommandEvent):
evt.log.exception("Error sending phone code")
return await evt.reply("Unhandled exception while sending code. "
"Check console for more details.")
return None
@command_handler(needs_auth=False)
async def enter_password(evt: CommandEvent):
async def enter_password(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-password <password>`")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -282,9 +291,10 @@ async def enter_password(evt: CommandEvent):
evt.log.exception("Error sending password")
return await evt.reply("Unhandled exception while sending password. "
"Check console for more details.")
return None
async def sign_in(evt: CommandEvent, **sign_in_info):
async def sign_in(evt: CommandEvent, **sign_in_info) -> Dict:
try:
await evt.sender.ensure_started(even_if_no_session=True)
user = await evt.sender.client.sign_in(**sign_in_info)
@@ -309,7 +319,7 @@ async def sign_in(evt: CommandEvent, **sign_in_info):
@command_handler(needs_auth=True,
help_section=SECTION_AUTH,
help_text="Log out from Telegram.")
async def logout(evt: CommandEvent):
async def logout(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.log_out():
return await evt.reply("Logged out successfully.")
return await evt.reply("Failed to log out.")
+17 -16
View File
@@ -14,21 +14,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, List
from typing import Dict, List, NewType, Optional, Tuple, Union
from mautrix_appservice import MatrixRequestError, IntentAPI
from ..types import MatrixRoomID, MatrixUserID
from . import command_handler, CommandEvent, SECTION_ADMIN
from .. import puppet as pu, portal as po
ManagementRoomList = List[Tuple[str, str]]
RoomIDList = List[str]
ManagementRoom = NewType('ManagementRoom', Tuple[MatrixRoomID, MatrixUserID])
async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList,
List["po.Portal"], List["po.Portal"]]:
management_rooms = [] # type: ManagementRoomList
unidentified_rooms = [] # type: RoomIDList
async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[MatrixRoomID],
List['po.Portal'], List['po.Portal']]:
management_rooms = [] # type: List[ManagementRoom]
unidentified_rooms = [] # type: List[MatrixRoomID]
portals = [] # type: List[po.Portal]
empty_portals = [] # type: List[po.Portal]
@@ -45,7 +45,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
if pu.Puppet.get_id_from_mxid(other_member):
unidentified_rooms.append(room)
else:
management_rooms.append((room, other_member))
management_rooms.append(ManagementRoom((room, other_member)))
else:
unidentified_rooms.append(room)
else:
@@ -61,7 +61,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
@command_handler(needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms",
help_section=SECTION_ADMIN,
help_text="Clean up unused portal/management rooms.")
async def clean_rooms(evt: CommandEvent):
async def clean_rooms(evt: CommandEvent) -> Optional[Dict]:
management_rooms, unidentified_rooms, portals, empty_portals = await _find_rooms(evt.az.intent)
reply = ["#### Management rooms (M)"]
@@ -106,13 +106,14 @@ async def clean_rooms(evt: CommandEvent):
return await evt.reply("\n".join(reply))
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
unidentified_rooms: RoomIDList, portals: List["po.Portal"],
empty_portals: List["po.Portal"]):
async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
unidentified_rooms: List[MatrixRoomID], portals: List["po.Portal"],
empty_portals: List["po.Portal"]) -> None:
command = evt.args[0]
rooms_to_clean = []
rooms_to_clean = [] # type: List[Union[po.Portal, MatrixRoomID]]
if command == "clean-recommended":
rooms_to_clean = empty_portals + unidentified_rooms
rooms_to_clean += empty_portals
rooms_to_clean += unidentified_rooms
elif command == "clean-groups":
if len(evt.args) < 2:
return await evt.reply("**Usage:** `$cmdprefix+sp clean-groups [M][A][U][I]")
@@ -158,7 +159,7 @@ async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
"`$cmdprefix+sp confirm-clean`.")
async def execute_room_cleanup(evt, rooms_to_clean):
async def execute_room_cleanup(evt, rooms_to_clean: List[Union[po.Portal, MatrixRoomID]]) -> None:
if len(evt.args) > 0 and evt.args[0] == "confirm-clean":
await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. "
"This might take a while.")
@@ -167,7 +168,7 @@ async def execute_room_cleanup(evt, rooms_to_clean):
if isinstance(room, po.Portal):
await room.cleanup_and_delete()
cleaned += 1
elif isinstance(room, str):
elif isinstance(room, str): # str is aliased by MatrixRoomID
await po.Portal.cleanup_room(evt.az.intent, room, message="Room deleted")
cleaned += 1
evt.sender.command_status = None
+32 -20
View File
@@ -14,19 +14,20 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Callable, Optional
from typing import Any, Awaitable, Callable, Coroutine, Dict, List, NamedTuple, Optional, Union
from collections import namedtuple
import markdown
import logging
from telethon.errors import FloodWaitError
from ..types import MatrixRoomID
from ..util import format_duration
from .. import user as u, context as c
command_handlers = {} # type: Dict[str, CommandHandler]
HelpSection = namedtuple("HelpSection", "name order description")
HelpSection = NamedTuple('HelpSection', [('name', str), ('order', int), ('description', str)])
SECTION_GENERAL = HelpSection("General", 0, "")
SECTION_AUTH = HelpSection("Authentication", 10, "")
@@ -37,8 +38,8 @@ SECTION_ADMIN = HelpSection("Administration", 50, "")
class CommandEvent:
def __init__(self, processor: "CommandProcessor", room: str, sender: u.User, command: str,
args: List[str], is_management: bool, is_portal: bool):
def __init__(self, processor: 'CommandProcessor', room: MatrixRoomID, sender: u.User,
command: str, args: List[str], is_management: bool, is_portal: bool) -> None:
self.az = processor.az
self.log = processor.log
self.loop = processor.loop
@@ -53,7 +54,8 @@ class CommandEvent:
self.is_management = is_management
self.is_portal = is_portal
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True):
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
) -> Awaitable[Dict]:
message = message.replace("$cmdprefix+sp ",
"" if self.is_management else f"{self.command_prefix} ")
message = message.replace("$cmdprefix", self.command_prefix)
@@ -66,10 +68,10 @@ class CommandEvent:
class CommandHandler:
def __init__(self, handler: Callable[[CommandEvent], None], needs_auth: bool,
def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool,
needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool,
management_only: bool, name: str, help_text: str, help_args: str,
help_section: HelpSection):
help_section: HelpSection) -> None:
self._handler = handler
self.needs_auth = needs_auth
self.needs_puppeting = needs_puppeting
@@ -103,7 +105,8 @@ class CommandHandler:
(not self.needs_admin or is_admin) and
(not self.needs_auth or is_logged_in))
async def __call__(self, evt: CommandEvent):
async def __call__(self, evt: CommandEvent
) -> Dict:
error = await self.get_permission_error(evt)
if error is not None:
return await evt.reply(error)
@@ -118,13 +121,21 @@ class CommandHandler:
return f"**{self.name}** {self._help_args} - {self._help_text}"
def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, needs_auth=True,
needs_puppeting=True, needs_matrix_puppeting=False, needs_admin=False,
management_only=False, name=None, help_text="", help_args="",
help_section=None):
def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] = None, *,
needs_auth: bool = True,
needs_puppeting: bool = True,
needs_matrix_puppeting: bool = False,
needs_admin: bool = False,
management_only: bool = False,
name: Optional[str] = None,
help_text: str = "",
help_args: str = "",
help_section: HelpSection = None
) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[Dict]]]],
CommandHandler]:
input_name = name
def decorator(func: Callable[[CommandEvent], None]):
def decorator(func: Callable[[CommandEvent], Awaitable[Optional[Dict]]]) -> CommandHandler:
name = input_name or func.__name__.replace("_", "-")
handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting,
needs_admin, management_only, name, help_text, help_args,
@@ -138,27 +149,27 @@ def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, n
class CommandProcessor:
log = logging.getLogger("mau.commands")
def __init__(self, context: c.Context):
self.az, self.db, self.config, self.loop, self.tgbot = context
def __init__(self, context: c.Context) -> None:
self.az, self.db, self.config, self.loop, self.tgbot = context.core
self.public_website = context.public_website
self.command_prefix = self.config["bridge.command_prefix"]
async def handle(self, room: str, sender: u.User, command: str, args: List[str],
is_management: bool, is_portal: bool):
async def handle(self, room: MatrixRoomID, sender: u.User, command: str, args: List[str],
is_management: bool, is_portal: bool) -> Optional[Dict]:
evt = CommandEvent(self, room, sender, command, args, is_management, is_portal)
orig_command = command
command = command.lower()
try:
command = command_handlers[command]
command_handler = command_handlers[command]
except KeyError:
if sender.command_status and "next" in sender.command_status:
args.insert(0, orig_command)
evt.command = ""
command = sender.command_status["next"]
else:
command = command_handlers["unknown-command"]
command_handler = command_handlers["unknown-command"]
try:
await command(evt)
await command_handler(evt)
except FloodWaitError as e:
return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}")
except Exception:
@@ -166,3 +177,4 @@ class CommandProcessor:
f"{evt.command} {' '.join(args)} from {sender.mxid}")
return await evt.reply("Unhandled error while handling command. "
"Check logs for more details.")
return None
+17 -14
View File
@@ -14,46 +14,49 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Tuple
from . import command_handler, CommandEvent, _command_handlers, SECTION_GENERAL
from .handler import HelpSection
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL,
help_text="Cancel an ongoing action (such as login)")
def cancel(evt: CommandEvent):
async def cancel(evt: CommandEvent) -> Optional[Dict]:
if evt.sender.command_status:
action = evt.sender.command_status["action"]
evt.sender.command_status = None
return evt.reply(f"{action} cancelled.")
return await evt.reply(f"{action} cancelled.")
else:
return evt.reply("No ongoing command.")
return await evt.reply("No ongoing command.")
@command_handler(needs_auth=False, needs_puppeting=False)
def unknown_command(evt: CommandEvent):
return evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
async def unknown_command(evt: CommandEvent) -> Optional[Dict]:
return await evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
help_cache = {}
help_cache = {} # type: Dict[Tuple[bool, bool, bool, bool, bool], str]
async def _get_help_text(evt: CommandEvent):
async def _get_help_text(evt: CommandEvent) -> str:
cache_key = (evt.is_management, evt.sender.puppet_whitelisted,
evt.sender.matrix_puppet_whitelisted, evt.sender.is_admin,
await evt.sender.is_logged_in())
if cache_key not in help_cache:
help = {}
help_sections = {} # type: Dict[HelpSection, List[str]]
for handler in _command_handlers.values():
if handler.has_help and handler.has_permission(*cache_key):
help.setdefault(handler.help_section, [])
help[handler.help_section].append(handler.help + " ")
help = sorted(help.items(), key=lambda item: item[0].order)
help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help]
help_sections.setdefault(handler.help_section, [])
help_sections[handler.help_section].append(handler.help + " ")
help_sorted = sorted(help_sections.items(), key=lambda item: item[0].order)
help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help_sorted]
help_cache[cache_key] = "\n".join(help)
return help_cache[cache_key]
def _get_management_status(evt: CommandEvent):
def _get_management_status(evt: CommandEvent) -> str:
if evt.is_management:
return "This is a management room: prefixing commands with `$cmdprefix` is not required."
elif evt.is_portal:
@@ -65,5 +68,5 @@ def _get_management_status(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL,
help_text="Show this help message.")
async def help(evt: CommandEvent):
async def help(evt: CommandEvent) -> Optional[Dict]:
return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt))
+48 -38
View File
@@ -14,13 +14,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Callable
from typing import Awaitable, Dict, Callable, Coroutine, Optional, Tuple, Union, cast
import asyncio
from telethon.errors import *
from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
UsernameNotModifiedError, UsernameOccupiedError)
from telethon.tl.types import ChatForbidden, ChannelForbidden
from mautrix_appservice import MatrixRequestError, IntentAPI
from ..types import MatrixRoomID, TelegramID
from .. import portal as po, user as u
from . import (command_handler, CommandEvent,
SECTION_ADMIN, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT)
@@ -30,7 +32,7 @@ from . import (command_handler, CommandEvent,
help_section=SECTION_ADMIN,
help_args="<_level_> [_mxid_]",
help_text="Set a temporary power level without affecting Telegram.")
async def set_power_level(evt: CommandEvent):
async def set_power_level(evt: CommandEvent) -> Dict:
try:
level = int(evt.args[0])
except KeyError:
@@ -45,11 +47,12 @@ async def set_power_level(evt: CommandEvent):
except MatrixRequestError:
evt.log.exception("Failed to set power level.")
return await evt.reply("Failed to set power level.")
return {}
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Get a Telegram invite link to the current chat.")
async def invite_link(evt: CommandEvent):
async def invite_link(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
return await evt.reply("This is not a portal room.")
@@ -66,7 +69,8 @@ async def invite_link(evt: CommandEvent):
return await evt.reply("You don't have the permission to create an invite link.")
async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50):
async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50
) -> bool:
if sender.is_admin:
return True
# Make sure the state store contains the power levels.
@@ -80,8 +84,9 @@ async def user_has_power_level(room: str, intent, sender: u.User, event: str, de
async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
action: Optional[str] = None):
room_id = evt.args[0] if len(evt.args) > 0 else evt.room_id
action: Optional[str] = None
) -> Tuple[Union[Dict, po.Portal], bool]:
room_id = MatrixRoomID(evt.args[0]) if len(evt.args) > 0 else evt.room_id
portal = po.Portal.get_by_mxid(room_id)
if not portal:
@@ -95,8 +100,8 @@ async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str,
completed_message: str):
async def post_confirm(confirm):
completed_message: str) -> Dict:
async def post_confirm(confirm) -> Optional[Dict]:
confirm.sender.command_status = None
if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}":
await function()
@@ -104,6 +109,7 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
return await confirm.reply(completed_message)
else:
return await confirm.reply(f"{action} cancelled.")
return None
return {
"next": post_confirm,
@@ -116,10 +122,11 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
help_text="Remove all users from the current portal room and forget the portal. "
"Only works for group chats; to delete a private chat portal, simply "
"leave the room.")
async def delete_portal(evt: CommandEvent):
portal, ok = await _get_portal_and_check_permission(evt, "unbridge")
async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
result, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok:
return
return None
portal = cast('po.Portal', result)
evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid,
portal.cleanup_and_delete, "delete",
@@ -137,10 +144,11 @@ async def delete_portal(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Remove puppets from the current portal room and forget the portal.")
async def unbridge(evt: CommandEvent):
portal, ok = await _get_portal_and_check_permission(evt, "unbridge")
async def unbridge(evt: CommandEvent) -> Optional[Dict]:
result, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok:
return
return None
portal = cast('po.Portal', result)
evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid,
portal.unbridge, "unbridge",
@@ -156,11 +164,11 @@ async def unbridge(evt: CommandEvent):
help_text="Bridge the current Matrix room to the Telegram chat with the given "
"ID. The ID must be the prefixed version that you get with the `/id` "
"command of the Telegram-side bot.")
async def bridge(evt: CommandEvent):
async def bridge(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** "
"`$cmdprefix+sp bridge <Telegram chat ID> [Matrix room ID]`")
room_id = evt.args[1] if len(evt.args) > 1 else evt.room_id
room_id = MatrixRoomID(evt.args[1]) if len(evt.args) > 1 else evt.room_id
that_this = "This" if room_id == evt.room_id else "That"
portal = po.Portal.get_by_mxid(room_id)
@@ -171,12 +179,12 @@ async def bridge(evt: CommandEvent):
return await evt.reply(f"You do not have the permissions to bridge {that_this} room.")
# The /id bot command provides the prefixed ID, so we assume
tgid = evt.args[0]
if tgid.startswith("-100"):
tgid = int(tgid[4:])
tgid_str = evt.args[0]
if tgid_str.startswith("-100"):
tgid = TelegramID(int(tgid_str[4:]))
peer_type = "channel"
elif tgid.startswith("-"):
tgid = -int(tgid)
elif tgid_str.startswith("-"):
tgid = TelegramID(-int(tgid_str))
peer_type = "chat"
else:
return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n"
@@ -222,7 +230,8 @@ async def bridge(evt: CommandEvent):
"chat to this room, use `$cmdprefix+sp continue`")
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"):
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"
) -> Tuple[bool, Coroutine[None, None, None]]:
if not portal.mxid:
await evt.reply("The portal seems to have lost its Matrix room between you"
"calling `$cmdprefix+sp bridge` and this command.\n\n"
@@ -245,7 +254,7 @@ async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Porta
return False, None
async def confirm_bridge(evt: CommandEvent):
async def confirm_bridge(evt: CommandEvent) -> Optional[Dict]:
status = evt.sender.command_status
try:
portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"])
@@ -258,7 +267,7 @@ async def confirm_bridge(evt: CommandEvent):
if "mxid" in status:
ok, coro = await cleanup_old_portal_while_bridging(evt, portal)
if not ok:
return
return None
elif coro:
asyncio.ensure_future(coro, loop=evt.loop)
await evt.reply("Cleaning up previous portal room...")
@@ -302,7 +311,7 @@ async def confirm_bridge(evt: CommandEvent):
return await evt.reply("Bridging complete. Portal synchronization should begin momentarily.")
async def get_initial_state(intent: IntentAPI, room_id: str):
async def get_initial_state(intent: IntentAPI, room_id: str) -> Tuple[str, str, Dict]:
state = await intent.get_room_state(room_id)
title = None
about = None
@@ -328,7 +337,7 @@ async def get_initial_state(intent: IntentAPI, room_id: str):
help_text="Create a Telegram chat of the given type for the current Matrix room. "
"The type is either `group`, `supergroup` or `channel` (defaults to "
"`group`).")
async def create(evt: CommandEvent):
async def create(evt: CommandEvent) -> Dict:
type = evt.args[0] if len(evt.args) > 0 else "group"
if type not in {"chat", "group", "supergroup", "channel"}:
return await evt.reply(
@@ -363,7 +372,7 @@ async def create(evt: CommandEvent):
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Upgrade a normal Telegram group to a supergroup.")
async def upgrade(evt: CommandEvent):
async def upgrade(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
return await evt.reply("This is not a portal room.")
@@ -385,7 +394,7 @@ async def upgrade(evt: CommandEvent):
help_args="<_name_|`-`>",
help_text="Change the username of a supergroup/channel. "
"To disable, use a dash (`-`) as the name.")
async def group_name(evt: CommandEvent):
async def group_name(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp group-name <name/->`")
@@ -421,7 +430,7 @@ async def group_name(evt: CommandEvent):
help_args="<`whitelist`|`blacklist`>",
help_text="Change whether the bridge will allow or disallow bridging rooms by "
"default.")
async def filter_mode(evt: CommandEvent):
async def filter_mode(evt: CommandEvent) -> Dict:
try:
mode = evt.args[0]
if mode not in ("whitelist", "blacklist"):
@@ -446,19 +455,19 @@ async def filter_mode(evt: CommandEvent):
help_section=SECTION_ADMIN,
help_args="<`whitelist`|`blacklist`> <_chat ID_>",
help_text="Allow or disallow bridging a specific chat.")
async def filter(evt: CommandEvent):
async def filter(evt: CommandEvent) -> Optional[Dict]:
try:
action = evt.args[0]
if action not in ("whitelist", "blacklist", "add", "remove"):
raise ValueError()
id = evt.args[1]
if id.startswith("-100"):
id = int(id[4:])
elif id.startswith("-"):
id = int(id[1:])
id_str = evt.args[1]
if id_str.startswith("-100"):
id = int(id_str[4:])
elif id_str.startswith("-"):
id = int(id_str[1:])
else:
id = int(id)
id = int(id_str)
except (IndexError, ValueError):
return await evt.reply("**Usage:** `$cmdprefix+sp filter <whitelist/blacklist> <chat ID>`")
@@ -471,7 +480,7 @@ async def filter(evt: CommandEvent):
if action in ("blacklist", "whitelist"):
action = "add" if mode == action else "remove"
def save():
def save() -> None:
evt.config["bridge.filter.list"] = list
evt.config.save()
po.Portal.filter_list = list
@@ -488,3 +497,4 @@ async def filter(evt: CommandEvent):
list.remove(id)
save()
return await evt.reply(f"Chat ID removed from {mode}.")
return None
+14 -8
View File
@@ -14,8 +14,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from telethon.errors import *
from typing import Awaitable, Dict, List, Optional, Tuple
import re
from telethon.errors import (
InviteHashInvalidError, InviteHashExpiredError, UserAlreadyParticipantError)
from telethon.tl.types import User as TLUser
from telethon.tl.types import TypeUpdates
from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest
from telethon.tl.functions.channels import JoinChannelRequest
@@ -26,7 +31,7 @@ from . import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORT
@command_handler(help_section=SECTION_MISC,
help_args="[_-r|--remote_] <_query_>",
help_text="Search your contacts or the Telegram servers for users.")
async def search(evt: CommandEvent):
async def search(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`")
@@ -47,7 +52,7 @@ async def search(evt: CommandEvent):
"Minimum length of remote query is 5 characters.")
return await evt.reply("No results 3:")
reply = []
reply = [] # type: List[str]
if remote:
reply += ["**Results from Telegram server:**", ""]
else:
@@ -68,7 +73,7 @@ async def search(evt: CommandEvent):
"either the internal user ID, the username or the phone number. "
"**N.B.** The phone numbers you start chats with must already be in "
"your contacts.")
async def private_message(evt: CommandEvent):
async def private_message(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`")
@@ -87,7 +92,7 @@ async def private_message(evt: CommandEvent):
f"{pu.Puppet.get_displayname(user, False)}")
async def _join(evt: CommandEvent, arg: str):
async def _join(evt: CommandEvent, arg: str) -> Tuple[TypeUpdates, Dict]:
if arg.startswith("joinchat/"):
invite_hash = arg[len("joinchat/"):]
try:
@@ -110,7 +115,7 @@ async def _join(evt: CommandEvent, arg: str):
@command_handler(help_section=SECTION_CREATING_PORTALS,
help_args="<_link_>",
help_text="Join a chat with an invite link.")
async def join(evt: CommandEvent):
async def join(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp join <invite link>`")
@@ -121,7 +126,7 @@ async def join(evt: CommandEvent):
updates, _ = await _join(evt, arg.group(1))
if not updates:
return
return None
for chat in updates.chats:
portal = po.Portal.get_by_entity(chat)
@@ -132,12 +137,13 @@ async def join(evt: CommandEvent):
await evt.reply(f"Creating room for {chat.title}... This might take a while.")
await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid])
return await evt.reply(f"Created room for {portal.title}")
return None
@command_handler(help_section=SECTION_MISC,
help_args="[`chats`|`contacts`|`me`]",
help_text="Synchronize your chat portals, contacts and/or own info.")
async def sync(evt: CommandEvent):
async def sync(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) > 0:
sync_only = evt.args[0]
if sync_only not in ("chats", "contacts", "me"):
+16 -16
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Any, Optional
from typing import Any, Dict, Optional, Tuple
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
import random
@@ -25,7 +25,7 @@ yaml.indent(4)
class DictWithRecursion:
def __init__(self, data: CommentedMap = None):
def __init__(self, data: Optional[CommentedMap] = None) -> None:
self._data = data or CommentedMap() # type: CommentedMap
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
@@ -46,7 +46,7 @@ class DictWithRecursion:
def __contains__(self, key: str) -> bool:
return self[key] is not None
def _recursive_set(self, data: CommentedMap, key: str, value: Any):
def _recursive_set(self, data: CommentedMap, key: str, value: Any) -> None:
if '.' in key:
key, next_key = key.split('.', 1)
if key not in data:
@@ -56,16 +56,16 @@ class DictWithRecursion:
return
data[key] = value
def set(self, key: str, value: Any, allow_recursion: bool = True):
def set(self, key: str, value: Any, allow_recursion: bool = True) -> None:
if allow_recursion and '.' in key:
self._recursive_set(self._data, key, value)
return
self._data[key] = value
def __setitem__(self, key: str, value: Any):
def __setitem__(self, key: str, value: Any) -> None:
self.set(key, value)
def _recursive_del(self, data: CommentedMap, key: str):
def _recursive_del(self, data: CommentedMap, key: str) -> None:
if '.' in key:
key, next_key = key.split('.', 1)
if key not in data:
@@ -79,7 +79,7 @@ class DictWithRecursion:
except KeyError:
pass
def delete(self, key: str, allow_recursion: bool = True):
def delete(self, key: str, allow_recursion: bool = True) -> None:
if allow_recursion and '.' in key:
self._recursive_del(self._data, key)
return
@@ -89,19 +89,19 @@ class DictWithRecursion:
except KeyError:
pass
def __delitem__(self, key: str):
def __delitem__(self, key: str) -> None:
self.delete(key)
class Config(DictWithRecursion):
def __init__(self, path: str, registration_path: str, base_path: str):
def __init__(self, path: str, registration_path: str, base_path: str) -> None:
super().__init__()
self.path = path # type: str
self.registration_path = registration_path # type: str
self.base_path = base_path # type: str
self._registration = None # type: dict
self._registration = None # type: Optional[Dict]
def load(self):
def load(self) -> None:
with open(self.path, 'r') as stream:
self._data = yaml.load(stream)
@@ -113,7 +113,7 @@ class Config(DictWithRecursion):
pass
return None
def save(self):
def save(self) -> None:
with open(self.path, 'w') as stream:
yaml.dump(self._data, stream)
if self._registration and self.registration_path:
@@ -124,16 +124,16 @@ class Config(DictWithRecursion):
def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
def update(self):
def update(self) -> None:
base = self.load_base()
if not base:
return
def copy(from_path, to_path=None):
def copy(from_path, to_path=None) -> None:
if from_path in self:
base[to_path or from_path] = self[from_path]
def copy_dict(from_path, to_path=None, override_existing_map=True):
def copy_dict(from_path, to_path=None, override_existing_map=True) -> None:
if from_path in self:
to_path = to_path or from_path
if override_existing_map or to_path not in base:
@@ -273,7 +273,7 @@ class Config(DictWithRecursion):
return self._get_permissions("*")
def generate_registration(self):
def generate_registration(self) -> None:
homeserver = self["homeserver.domain"]
username_format = self.get("bridge.username_template", "telegram_{userid}") \
+7 -8
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TYPE_CHECKING, Optional
from typing import Generator, Optional, Tuple, Union, TYPE_CHECKING
if TYPE_CHECKING:
import asyncio
@@ -32,7 +32,8 @@ if TYPE_CHECKING:
class Context:
def __init__(self, az: "AppService", db: "scoped_session", config: "Config",
loop: "asyncio.AbstractEventLoop", session_container: "AlchemySessionContainer"):
loop: "asyncio.AbstractEventLoop", session_container: "AlchemySessionContainer"
) -> None:
self.az = az # type: AppService
self.db = db # type: scoped_session
self.config = config # type: Config
@@ -43,9 +44,7 @@ class Context:
self.public_website = None # type: PublicBridgeWebsite
self.provisioning_api = None # type: ProvisioningAPI
def __iter__(self):
yield self.az
yield self.db
yield self.config
yield self.loop
yield self.bot
@property
def core(self) -> Tuple['AppService', 'scoped_session', 'Config',
'asyncio.AbstractEventLoop', Optional['Bot']]:
return (self.az, self.db, self.config, self.loop, self.bot)
+9 -7
View File
@@ -14,6 +14,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text)
from sqlalchemy.sql import expression
@@ -88,20 +90,20 @@ class RoomState(Base):
room_id = Column(String, primary_key=True)
_power_levels_text = Column("power_levels", Text, nullable=True)
_power_levels_json = None
_power_levels_json = {} # type: Dict
@property
def has_power_levels(self):
def has_power_levels(self) -> bool:
return bool(self._power_levels_text)
@property
def power_levels(self):
def power_levels(self) -> Dict:
if not self._power_levels_json and self._power_levels_text:
self._power_levels_json = json.loads(self._power_levels_text)
return self._power_levels_json or {}
return self._power_levels_json
@power_levels.setter
def power_levels(self, val):
def power_levels(self, val: Dict) -> None:
self._power_levels_json = val
self._power_levels_text = json.dumps(val)
@@ -116,7 +118,7 @@ class UserProfile(Base):
displayname = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
def dict(self):
def dict(self) -> Dict[str, Column]:
return {
"membership": self.membership,
"displayname": self.displayname,
@@ -171,7 +173,7 @@ class TelegramFile(Base):
thumbnail = relationship("TelegramFile", uselist=False)
def init(db_session):
def init(db_session) -> None:
Portal.query = db_session.query_property()
Message.query = db_session.query_property()
UserPortal.query = db_session.query_property()
@@ -80,12 +80,12 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
args["url"] = url
return MessageEntityTextUrl, None
def handle_starttag(self, tag: str, attrs: List[Tuple[str, str]]):
def handle_starttag(self, tag: str, attrs_list: List[Tuple[str, str]]):
self._open_tags.appendleft(tag)
self._open_tags_meta.appendleft(0)
attrs = dict(attrs)
entity_type = None # type: type(TypeMessageEntity)
attrs = dict(attrs_list)
entity_type = None # type: Optional[Type[TypeMessageEntity]]
args = {} # type: Dict[str, Any]
if tag in ("strong", "b"):
entity_type = MessageEntityBold
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, Union, Callable
from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
from lxml import html
from telethon.tl.types import (MessageEntityMention as Mention,
@@ -83,11 +83,11 @@ def offset_length_multiply(amount: int):
class TelegramMessage:
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None):
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None) -> None:
self.text = text # type: str
self.entities = entities or [] # type: List[TypeMessageEntity]
def offset_entities(self, offset: int) -> "TelegramMessage":
def offset_entities(self, offset: int) -> 'TelegramMessage':
def apply_offset(entity: TypeMessageEntity, inner_offset: int
) -> Optional[TypeMessageEntity]:
entity = Entity.copy(entity)
@@ -104,7 +104,7 @@ class TelegramMessage:
self.entities = [x for x in self.entities if x is not None]
return self
def append(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
def append(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
@@ -112,7 +112,7 @@ class TelegramMessage:
self.text += msg.text
return self
def prepend(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
def prepend(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
@@ -120,17 +120,17 @@ class TelegramMessage:
self.text = msg.text + self.text
return self
def format(self, entity_type: type(TypeMessageEntity), offset: int = None, length: int = None,
**kwargs) -> "TelegramMessage":
def format(self, entity_type: Type[TypeMessageEntity], offset: int = None, length: int = None,
**kwargs) -> 'TelegramMessage':
self.entities.append(entity_type(offset=offset or 0,
length=length if length is not None else len(self.text),
**kwargs))
return self
def concat(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
def concat(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
return TelegramMessage().append(self, *args)
def trim(self) -> "TelegramMessage":
def trim(self) -> 'TelegramMessage':
orig_len = len(self.text)
self.text = self.text.lstrip()
diff = orig_len - len(self.text)
@@ -138,7 +138,7 @@ class TelegramMessage:
self.offset_entities(-diff)
return self
def split(self, separator, max_items: int = 0) -> List["TelegramMessage"]:
def split(self, separator, max_items: int = 0) -> List['TelegramMessage']:
text_parts = self.text.split(separator, max_items - 1)
output = [] # type: List[TelegramMessage]
@@ -158,7 +158,8 @@ class TelegramMessage:
return output
@staticmethod
def join(items: List[Union[str, "TelegramMessage"]], separator: str = " ") -> "TelegramMessage":
def join(items: Sequence[Union[str, 'TelegramMessage']],
separator: str = " ") -> 'TelegramMessage':
main = TelegramMessage()
for msg in items:
if isinstance(msg, str):
+12 -10
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from html import escape
import logging
import re
@@ -28,6 +28,7 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
from mautrix_appservice import MatrixRequestError
from mautrix_appservice.intent_api import IntentAPI
from ..types import TelegramID
from .. import user as u, puppet as pu, portal as po
from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
@@ -40,14 +41,14 @@ if TYPE_CHECKING:
try:
from lxml.html.diff import htmldiff
except ImportError:
htmldiff = None # type: function
htmldiff = None # type: ignore
log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
should_highlight_edits = False # type: bool
def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict:
def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
if evt.reply_to_msg_id:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
@@ -116,7 +117,7 @@ def highlight_edits(new_html: str, old_html: str) -> str:
async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
relates_to: dict, main_intent: IntentAPI, is_edit: bool
relates_to: Dict, main_intent: IntentAPI, is_edit: bool
) -> Tuple[str, str]:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
@@ -177,10 +178,10 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
async def telegram_to_matrix(evt: Message, source: "AbstractUser",
main_intent: Optional[IntentAPI] = None,
is_edit: bool = False, prefix_text: Optional[str] = None,
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]:
prefix_html: Optional[str] = None) -> Tuple[str, str, Dict]:
text = add_surrogates(evt.message)
html = _telegram_entities_to_matrix_catch(text, evt.entities) if evt.entities else None
relates_to = {}
relates_to = {} # type: Dict
if prefix_html:
html = prefix_html + (html or escape(text))
@@ -217,6 +218,7 @@ def _telegram_entities_to_matrix_catch(text: str, entities: List[TypeMessageEnti
"message=%s\n"
"entities=%s",
text, entities)
return "[failed conversion in _telegram_entities_to_matrix]"
def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -> str:
@@ -290,7 +292,7 @@ def _parse_mention(html: List[str], entity_text: str) -> bool:
return False
def _parse_name_mention(html: List[str], entity_text: str, user_id: int) -> bool:
def _parse_name_mention(html: List[str], entity_text: str, user_id: TelegramID) -> bool:
user = u.User.get_by_tgid(user_id)
if user:
mxid = user.mxid
@@ -315,8 +317,8 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
message_link_match = message_link_regex.match(url)
if message_link_match:
group, msgid = message_link_match.groups()
msgid = int(msgid)
group, msgid_str = message_link_match.groups()
msgid = int(msgid_str)
portal = po.Portal.find_by_username(group)
if portal:
@@ -328,6 +330,6 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
return False
def init_tg(context: "Context"):
def init_tg(context: "Context") -> None:
global should_highlight_edits
should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"]
+60 -38
View File
@@ -14,27 +14,35 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Tuple, Set, Match
from typing import Dict, List, Match, Optional, Set, Tuple, TYPE_CHECKING
import logging
import asyncio
import re
from mautrix_appservice import MatrixRequestError, IntentError
from .types import MatrixEvent, MatrixEventID, MatrixRoomID, MatrixUserID
from . import user as u, portal as po, puppet as pu, commands as com
if TYPE_CHECKING:
from mautrix_appservice import AppService
from .context import Context
from sqlalchemy.orm import scoped_session
from .config import Config
from .bot import Bot
class MatrixHandler:
log = logging.getLogger("mau.mx") # type: logging.Logger
def __init__(self, context):
self.az, self.db, self.config, _, self.tgbot = context
def __init__(self, context: 'Context') -> None:
self.az, self.db, self.config, _, self.tgbot = context.core
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
self.previously_typing = [] # type: List[str]
self.previously_typing = [] # type: List[MatrixUserID]
self.az.matrix_event_handler(self.handle_event)
async def init_as_bot(self):
async def init_as_bot(self) -> None:
displayname = self.config["appservice.bot_displayname"]
if displayname:
try:
@@ -50,7 +58,8 @@ class MatrixHandler:
except asyncio.TimeoutError:
self.log.exception("TimeoutError when trying to set avatar")
async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User):
async def handle_puppet_invite(self, room_id: MatrixRoomID, puppet: pu.Puppet, inviter: u.User
) -> None:
intent = puppet.default_mxid_intent
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in():
@@ -80,6 +89,7 @@ class MatrixHandler:
await intent.join_room(room_id)
portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
# TODO: if portal is None:
if portal.mxid:
try:
await intent.invite(portal.mxid, inviter.mxid)
@@ -95,13 +105,13 @@ class MatrixHandler:
portal.mxid = room_id
portal.save()
inviter.register_portal(portal)
await intent.send_notice(room_id, "po.Portal to private chat created.")
await intent.send_notice(room_id, "Portal to private chat created.")
else:
await intent.join_room(room_id)
await intent.send_notice(room_id, "This puppet will remain inactive until a "
"Telegram chat is created for this room.")
async def accept_bot_invite(self, room_id: str, inviter: u.User):
async def accept_bot_invite(self, room_id: MatrixRoomID, inviter: u.User) -> None:
tries = 0
while tries < 5:
try:
@@ -126,9 +136,13 @@ class MatrixHandler:
"<code>bridge.permissions</code> section in your config file.")
await self.az.intent.leave_room(room_id)
async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str):
async def handle_invite(self, room_id: MatrixRoomID, user_id: MatrixUserID,
inviter_mxid: MatrixUserID) -> None:
self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started()
inviter = u.User.get_by_mxid(inviter_mxid)
if inviter is None:
self.log.exception("Failed to find user with Matrix ID {inviter_mxid}")
await inviter.ensure_started()
if user_id == self.az.bot_mxid:
return await self.accept_bot_invite(room_id, inviter)
elif not inviter.whitelisted:
@@ -150,7 +164,8 @@ class MatrixHandler:
# The rest can probably be ignored
async def handle_join(self, room_id: str, user_id: str, event_id: str):
async def handle_join(self, room_id: MatrixRoomID, user_id: MatrixUserID,
event_id: MatrixEventID) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started()
portal = po.Portal.get_by_mxid(room_id)
@@ -171,7 +186,8 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot:
await portal.join_matrix(user, event_id)
async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str):
async def handle_part(self, room_id: MatrixRoomID, user_id: MatrixUserID,
sender_mxid: MatrixUserID, event_id: MatrixEventID) -> None:
self.log.debug(f"{user_id} left {room_id}")
sender = u.User.get_by_mxid(sender_mxid, create=False)
@@ -185,6 +201,7 @@ class MatrixHandler:
puppet = pu.Puppet.get_by_mxid(user_id)
if sender and puppet:
# TODO: Puppet should probably be an AbstractUser
await portal.leave_matrix(puppet, sender, event_id)
user = u.User.get_by_mxid(user_id, create=False)
@@ -194,7 +211,7 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot:
await portal.leave_matrix(user, sender, event_id)
def is_command(self, message: dict) -> Tuple[bool, str]:
def is_command(self, message: Dict) -> Tuple[bool, str]:
text = message.get("body", "")
prefix = self.config["bridge.command_prefix"]
is_command = text.startswith(prefix)
@@ -202,9 +219,10 @@ class MatrixHandler:
text = text[len(prefix) + 1:]
return is_command, text
async def handle_message(self, room, sender, message, event_id):
async def handle_message(self, room: MatrixRoomID, sender_id: MatrixUserID, message: Dict,
event_id: MatrixEventID) -> None:
is_command, text = self.is_command(message)
sender = await u.User.get_by_mxid(sender).ensure_started()
sender = await u.User.get_by_mxid(sender_id).ensure_started()
if not sender.relaybot_whitelisted:
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
" u.User is not whitelisted.")
@@ -237,7 +255,8 @@ class MatrixHandler:
is_portal=portal is not None)
@staticmethod
async def handle_redaction(room_id: str, sender_mxid: str, event_id: str):
async def handle_redaction(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
event_id: MatrixEventID) -> None:
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if not sender.relaybot_whitelisted:
return
@@ -249,14 +268,16 @@ class MatrixHandler:
await portal.handle_matrix_deletion(sender, event_id)
@staticmethod
async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict):
async def handle_power_levels(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
new: Dict, old: Dict) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
@staticmethod
async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, content: dict):
async def handle_room_meta(evt_type: str, room_id: MatrixRoomID, sender_mxid: MatrixUserID,
content: dict) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
@@ -270,8 +291,8 @@ class MatrixHandler:
await handler(sender, content[content_key])
@staticmethod
async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str],
old_events: Set[str]):
async def handle_room_pin(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
new_events: Set[str], old_events: Set[str]) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
@@ -284,8 +305,8 @@ class MatrixHandler:
await portal.handle_matrix_pin(sender, None)
@staticmethod
async def handle_name_change(room_id: str, user_id: str, displayname: str,
prev_displayname: str, event_id: str):
async def handle_name_change(room_id: MatrixRoomID, user_id: MatrixUserID, displayname: str,
prev_displayname: str, event_id: MatrixEventID) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal or not portal.has_bot:
return
@@ -295,13 +316,14 @@ class MatrixHandler:
await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
@staticmethod
def parse_read_receipts(content: dict) -> Dict[str, str]:
def parse_read_receipts(content: Dict) -> Dict[MatrixUserID, MatrixEventID]:
return {user_id: event_id
for event_id, receipts in content.items()
for user_id in receipts.get("m.read", {})}
@staticmethod
async def handle_read_receipts(room_id: str, receipts: Dict[str, str]):
async def handle_read_receipts(room_id: MatrixRoomID,
receipts: Dict[MatrixUserID, MatrixEventID]) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
@@ -313,13 +335,13 @@ class MatrixHandler:
await portal.mark_read(user, event_id)
@staticmethod
async def handle_presence(user_id: str, presence: str):
async def handle_presence(user_id: MatrixUserID, presence: str) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in():
return
await user.set_presence(presence == "online")
user.set_presence(presence == "online")
async def handle_typing(self, room_id: str, now_typing: List[str]):
async def handle_typing(self, room_id: MatrixRoomID, now_typing: List[MatrixUserID]) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
@@ -334,35 +356,35 @@ class MatrixHandler:
if not await user.is_logged_in():
continue
await portal.set_typing(user, is_typing)
portal.set_typing(user, is_typing)
self.previously_typing = now_typing
def filter_matrix_event(self, event: dict):
def filter_matrix_event(self, event: MatrixEvent) -> bool:
sender = event.get("sender", None)
if not sender:
return False
return (sender == self.az.bot_mxid
or pu.Puppet.get_id_from_mxid(sender) is not None)
async def try_handle_event(self, evt: dict):
async def try_handle_event(self, evt: MatrixEvent) -> None:
try:
await self.handle_event(evt)
except Exception:
self.log.exception("Error handling manually received Matrix event")
async def handle_event(self, evt: dict):
async def handle_event(self, evt: MatrixEvent) -> None:
if self.filter_matrix_event(evt):
return
self.log.debug("Received event: %s", evt)
evt_type = evt.get("type", "m.unknown") # type: str
room_id = evt.get("room_id", None) # type: str
event_id = evt.get("event_id", None) # type: str
sender = evt.get("sender", None) # type: str
content = evt.get("content", {}) # type: dict
room_id = evt.get("room_id", None) # type: Optional[MatrixRoomID]
event_id = evt.get("event_id", None) # type: Optional[MatrixEventID]
sender = evt.get("sender", None) # type: Optional[MatrixUserID]
content = evt.get("content", {}) # type: Dict
if evt_type == "m.room.member":
state_key = evt["state_key"] # type: str
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
state_key = evt["state_key"] # type: MatrixUserID
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: Dict
membership = content.get("membership", "") # type: str
prev_membership = prev_content.get("membership", "leave") # type: str
if membership == prev_membership:
@@ -386,7 +408,7 @@ class MatrixHandler:
elif evt_type == "m.room.redaction":
await self.handle_redaction(room_id, sender, evt["redacts"])
elif evt_type == "m.room.power_levels":
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
prev_content = evt.get("unsigned", {}).get("prev_content", {})
await self.handle_power_levels(room_id, sender, evt["content"], prev_content)
elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
+151 -109
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Pattern, Dict, Tuple, Awaitable, TYPE_CHECKING
from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, cast, TYPE_CHECKING
from collections import deque
from datetime import datetime
from string import Template
@@ -32,14 +32,37 @@ from sqlalchemy import orm
from sqlalchemy.exc import IntegrityError, InvalidRequestError
from sqlalchemy.orm.exc import FlushError
from telethon.tl.functions.messages import *
from telethon.tl.functions.channels import *
from telethon.tl.functions.messages import (
AddChatUserRequest, CreateChatRequest, DeleteChatUserRequest, EditChatAdminRequest,
EditChatPhotoRequest, EditChatTitleRequest, ExportChatInviteRequest, GetFullChatRequest,
MigrateChatRequest, SetTypingRequest)
from telethon.tl.functions.channels import (
CreateChannelRequest, EditAboutRequest, EditAdminRequest, EditBannedRequest, EditPhotoRequest,
EditTitleRequest, ExportInviteRequest, GetParticipantsRequest, InviteToChannelRequest,
JoinChannelRequest, LeaveChannelRequest, UpdatePinnedMessageRequest, UpdateUsernameRequest)
from telethon.tl.functions.messages import ReadHistoryRequest as ReadMessageHistoryRequest
from telethon.tl.functions.channels import ReadHistoryRequest as ReadChannelHistoryRequest
from telethon.errors import ChatAdminRequiredError, ChatNotModifiedError
from telethon.tl.types import *
from telethon.tl.types import (
Channel, ChannelAdminRights, ChannelBannedRights, ChannelFull, ChannelParticipantAdmin,
ChannelParticipantCreator, ChannelParticipantsRecent, ChannelParticipantsSearch, Chat,
ChatFull, ChatInviteEmpty, ChatParticipantAdmin, ChatParticipantCreator, ChatPhoto,
DocumentAttributeFilename, DocumentAttributeImageSize, DocumentAttributeSticker,
DocumentAttributeVideo, FileLocation, GeoPoint, InputChannel, InputChatUploadedPhoto,
InputPeerChannel, InputPeerChat, InputPeerUser, InputUser, InputUserSelf, Message,
MessageActionChannelCreate, MessageActionChatAddUser, MessageActionChatCreate,
MessageActionChatDeletePhoto, MessageActionChatDeleteUser, MessageActionChatEditPhoto,
MessageActionChatEditTitle, MessageActionChatJoinedByLink, MessageActionChatMigrateTo,
MessageActionPinMessage, MessageMediaContact, MessageMediaDocument, MessageMediaGeo,
MessageMediaPhoto, MessageService, PeerChannel, PeerChat, PeerUser, Photo, PhotoCachedSize,
SendMessageCancelAction, SendMessageTypingAction, TypeChannelParticipant, TypeChat,
TypeChatParticipant, TypeDocumentAttribute, TypeInputPeer, TypeMessageAction,
TypeMessageEntity, TypePeer, TypePhotoSize, TypeUpdates, TypeUser, TypeUserFull,
UpdateChatUserTyping, UpdateNewChannelMessage, UpdateNewMessage, UpdateUserTyping, User,
UserFull)
from mautrix_appservice import MatrixRequestError, IntentError, AppService, IntentAPI
from .types import MatrixEventID, MatrixRoomID, MatrixUserID, TelegramID
from .context import Context
from .db import Portal as DBPortal, Message as DBMessage, TelegramFile as DBTelegramFile
from . import puppet as p, user as u, formatter, util
@@ -82,18 +105,18 @@ class Portal:
by_mxid = {} # type: Dict[str, Portal]
by_tgid = {} # type: Dict[Tuple[int, int], Portal]
def __init__(self, tgid: int, peer_type: str, tg_receiver: Optional[int] = None,
mxid: Optional[str] = None, username: Optional[str] = None,
def __init__(self, tgid: TelegramID, peer_type: str, tg_receiver: Optional[int] = None,
mxid: Optional[MatrixRoomID] = None, username: Optional[str] = None,
megagroup: Optional[bool] = False, title: Optional[str] = None,
about: Optional[str] = None, photo_id: Optional[str] = None,
db_instance: DBPortal = None):
self.mxid = mxid # type: str
self.tgid = tgid # type: int
db_instance: DBPortal = None) -> None:
self.mxid = mxid # type: Optional[MatrixRoomID]
self.tgid = tgid # type: TelegramID
self.tg_receiver = tg_receiver or tgid # type: int
self.peer_type = peer_type # type: str
self.username = username # type: str
self.megagroup = megagroup # type: bool
self.title = title # type: str
self.title = title # type: Optional[str]
self.about = about # type: str
self.photo_id = photo_id # type: str
self._db_instance = db_instance # type: DBPortal
@@ -138,7 +161,7 @@ class Portal:
@property
def has_bot(self) -> bool:
return self.bot and self.bot.is_in_chat(self.tgid)
return bool(self.bot and self.bot.is_in_chat(self.tgid))
@property
def main_intent(self) -> IntentAPI:
@@ -232,13 +255,13 @@ class Portal:
del self._dedup_mxid[self._dedup.popleft()]
return None
def get_input_entity(self, user: u.User) -> Awaitable[TypeInputPeer]:
def get_input_entity(self, user: 'u.User') -> Awaitable[TypeInputPeer]:
return user.client.get_input_entity(self.peer)
# endregion
# region Matrix room info updating
async def invite_to_matrix(self, users: InviteList):
async def invite_to_matrix(self, users: InviteList) -> None:
if isinstance(users, str):
await self.main_intent.invite(self.mxid, users, check_cache=True)
elif isinstance(users, list):
@@ -247,10 +270,10 @@ class Portal:
else:
raise ValueError("Invalid invite identifier given to invite_matrix()")
async def update_matrix_room(self, user: "AbstractUser", entity: TypeChat, direct: bool,
puppet: p.Puppet = None, levels: dict = None,
async def update_matrix_room(self, user: 'AbstractUser', entity: TypeChat, direct: bool,
puppet: p.Puppet = None, levels: Dict = None,
users: List[User] = None,
participants: List[TypeParticipant] = None):
participants: List[TypeParticipant] = None) -> None:
if not direct:
await self.update_info(user, entity)
if not users or not participants:
@@ -280,8 +303,8 @@ class Portal:
async with self._room_create_lock:
return await self._create_matrix_room(user, entity, invites)
async def _create_matrix_room(self, user: "AbstractUser", entity: TypeChat, invites: InviteList
) -> Optional[str]:
async def _create_matrix_room(self, user: 'AbstractUser', entity: TypeChat, invites: InviteList
) -> Optional[MatrixRoomID]:
direct = self.peer_type == "user"
if self.mxid:
@@ -346,6 +369,8 @@ class Portal:
participants=participants),
loop=self.loop)
return self.mxid
def _get_base_power_levels(self, levels: dict = None, entity: TypeChat = None) -> dict:
levels = levels or {}
power_level_requirement = (0 if self.peer_type == "chat" and not entity.admins_enabled
@@ -383,7 +408,7 @@ class Portal:
return None
return self.alias_template.format(groupname=username)
def add_bot_chat(self, bot: User):
def add_bot_chat(self, bot: User) -> None:
if self.bot and bot.id == self.bot.tgid:
self.bot.add_chat(self.tgid, self.peer_type)
return
@@ -392,7 +417,7 @@ class Portal:
if user and user.is_bot:
user.register_portal(self)
async def sync_telegram_users(self, source: "AbstractUser", users: List[User]):
async def sync_telegram_users(self, source: "AbstractUser", users: List[User]) -> None:
allowed_tgids = set()
for entity in users:
puppet = p.Puppet.get(entity.id)
@@ -414,18 +439,19 @@ class Portal:
and config["bridge.max_initial_member_sync"] == -1
and (self.megagroup or self.peer_type != "channel"))
if trust_member_list:
joined_mxids = await self.main_intent.get_room_members(self.mxid)
for user in joined_mxids:
if user == self.az.bot_mxid:
joined_mxids = cast(List[MatrixUserID],
await self.main_intent.get_room_members(self.mxid))
for user_mxid in joined_mxids:
if user_mxid == self.az.bot_mxid:
continue
puppet_id = p.Puppet.get_id_from_mxid(user)
puppet_id = p.Puppet.get_id_from_mxid(user_mxid)
if puppet_id and puppet_id not in allowed_tgids:
if self.bot and puppet_id == self.bot.tgid:
self.bot.remove_chat(self.tgid)
await self.main_intent.kick(self.mxid, user,
await self.main_intent.kick(self.mxid, user_mxid,
"User had left this Telegram chat.")
continue
mx_user = u.User.get_by_mxid(user, create=False)
mx_user = u.User.get_by_mxid(user_mxid, create=False)
if mx_user and mx_user.is_bot and mx_user.tgid not in allowed_tgids:
mx_user.unregister_portal(self)
@@ -434,7 +460,8 @@ class Portal:
"You had left this Telegram chat.")
continue
async def add_telegram_user(self, user_id: int, source: Optional["AbstractUser"] = None):
async def add_telegram_user(self, user_id: TelegramID, source: Optional['AbstractUser'] = None
) -> None:
puppet = p.Puppet.get(user_id)
if source:
entity = await source.client.get_entity(PeerUser(user_id))
@@ -446,7 +473,7 @@ class Portal:
user.register_portal(self)
await self.invite_to_matrix(user.mxid)
async def delete_telegram_user(self, user_id: int, sender: p.Puppet):
async def delete_telegram_user(self, user_id: TelegramID, sender: p.Puppet) -> None:
puppet = p.Puppet.get(user_id)
user = u.User.get_by_tgid(user_id)
kick_message = (f"Kicked by {sender.displayname}"
@@ -460,7 +487,7 @@ class Portal:
user.unregister_portal(self)
await self.main_intent.kick(self.mxid, user.mxid, kick_message)
async def update_info(self, user: "AbstractUser", entity: TypeChat = None):
async def update_info(self, user: "AbstractUser", entity: TypeChat = None) -> None:
if self.peer_type == "user":
self.log.warning(f"Called update_info() for direct chat portal {self.tgid_log}")
return
@@ -524,7 +551,7 @@ class Portal:
return max(photo.sizes, key=(lambda photo2: (
len(photo2.bytes) if isinstance(photo2, PhotoCachedSize) else photo2.size)))
async def remove_avatar(self, _: "AbstractUser", save: bool = False):
async def remove_avatar(self, _: "AbstractUser", save: bool = False) -> None:
await self.main_intent.set_room_avatar(self.mxid, None)
self.photo_id = None
if save:
@@ -544,8 +571,9 @@ class Portal:
return True
return False
async def _get_users(self, user: "AbstractUser", entity: Union[TypeInputPeer, InputUser,
TypeChat, TypeUser]
async def _get_users(self,
user: 'AbstractUser',
entity: Union[TypeInputPeer, InputUser, TypeChat, TypeUser]
) -> Tuple[List[TypeUser], List[TypeParticipant]]:
if self.peer_type == "chat":
chat = await user.client(GetFullChatRequest(chat_id=self.tgid))
@@ -564,7 +592,7 @@ class Portal:
entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0))
return response.users, response.participants
elif limit > 200 or limit == -1:
users, participants = [], []
users, participants = [], [] # type: Tuple[List[TypeUser], List[TypeParticipant]]
offset = 0
remaining_quota = limit if limit > 0 else 1000000
query = (ChannelParticipantsSearch("") if limit == -1
@@ -585,8 +613,9 @@ class Portal:
return [], []
elif self.peer_type == "user":
return [entity], []
return [], []
async def get_invite_link(self, user: u.User) -> str:
async def get_invite_link(self, user: 'u.User') -> str:
if self.peer_type == "user":
raise ValueError("You can't invite users to private chats.")
elif self.peer_type == "chat":
@@ -604,7 +633,7 @@ class Portal:
return link.link
async def get_authenticated_matrix_users(self) -> List[u.User]:
async def get_authenticated_matrix_users(self) -> List['u.User']:
try:
members = await self.main_intent.get_room_members(self.mxid)
except MatrixRequestError:
@@ -622,7 +651,7 @@ class Portal:
@staticmethod
async def cleanup_room(intent: IntentAPI, room_id: str, message: str = "Portal deleted",
puppets_only: bool = False):
puppets_only: bool = False) -> None:
try:
members = await intent.get_room_members(room_id)
except MatrixRequestError:
@@ -639,11 +668,11 @@ class Portal:
pass
await intent.leave_room(room_id)
async def unbridge(self):
async def unbridge(self) -> None:
await self.cleanup_room(self.main_intent, self.mxid, "Room unbridged", puppets_only=True)
self.delete()
async def cleanup_and_delete(self):
async def cleanup_and_delete(self) -> None:
await self.cleanup_room(self.main_intent, self.mxid)
self.delete()
@@ -663,8 +692,8 @@ class Portal:
else:
return ""
async def _get_state_change_message(self, event: str, user: u.User,
arguments: Optional[dict] = None) -> Optional[dict]:
async def _get_state_change_message(self, event: str, user: 'u.User',
arguments: Optional[Dict] = None) -> Optional[Dict]:
tpl = config[f"bridge.state_event_formats.{event}"]
if len(tpl) == 0:
# Empty format means they don't want the message
@@ -681,8 +710,8 @@ class Portal:
"formatted_body": message,
}
async def name_change_matrix(self, user: u.User, displayname: str, prev_displayname: str,
event_id: str):
async def name_change_matrix(self, user: 'u.User', displayname: str, prev_displayname: str,
event_id: str) -> None:
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message(
"name_change", user,
@@ -695,15 +724,16 @@ class Portal:
space = self.tgid if self.peer_type == "channel" else self.bot.tgid
self.is_duplicate(response, (event_id, space))
async def get_displayname(self, user: u.User) -> str:
async def get_displayname(self, user: 'u.User') -> str:
return (await self.main_intent.get_displayname(self.mxid, user.mxid)
or user.mxid_localpart)
def set_typing(self, user: u.User, typing: bool = True, action=SendMessageTypingAction):
def set_typing(self, user: 'u.User', typing: bool = True,
action: type = SendMessageTypingAction) -> bool:
return user.client(SetTypingRequest(
self.peer, action() if typing else SendMessageCancelAction()))
async def mark_read(self, user: u.User, event_id: str):
async def mark_read(self, user: 'u.User', event_id: MatrixEventID) -> None:
if user.is_bot:
return
space = self.tgid if self.peer_type == "channel" else user.tgid
@@ -718,7 +748,8 @@ class Portal:
else:
await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid))
async def leave_matrix(self, user: u.User, source: u.User, event_id: str):
async def leave_matrix(self, user: 'u.User', source: 'u.User', event_id: MatrixEventID
) -> None:
if await user.needs_relaybot(self):
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message("leave", user)
@@ -754,7 +785,7 @@ class Portal:
channel = await self.get_input_entity(user)
await user.client(LeaveChannelRequest(channel=channel))
async def join_matrix(self, user: u.User, event_id: str):
async def join_matrix(self, user: 'u.User', event_id: str) -> None:
if await user.needs_relaybot(self):
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message("join", user)
@@ -773,7 +804,7 @@ class Portal:
# We'll just assume the user is already in the chat.
pass
async def _apply_msg_format(self, sender: u.User, msgtype: str, message: dict):
async def _apply_msg_format(self, sender: 'u.User', msgtype: str, message: Dict) -> None:
if "formatted_body" not in message:
message["format"] = "org.matrix.custom.html"
message["formatted_body"] = escape_html(message.get("body", ""))
@@ -788,7 +819,8 @@ class Portal:
message=body)
message["formatted_body"] = Template(tpl).safe_substitute(tpl_args)
async def _pre_process_matrix_message(self, sender: u.User, use_relaybot: bool, message: dict):
async def _pre_process_matrix_message(self, sender: 'u.User', use_relaybot: bool,
message: dict) -> None:
msgtype = message.get("msgtype", "m.text")
if msgtype == "m.emote":
await self._apply_msg_format(sender, msgtype, message)
@@ -797,7 +829,7 @@ class Portal:
await self._apply_msg_format(sender, msgtype, message)
@staticmethod
def _matrix_event_to_entities(event: dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]:
def _matrix_event_to_entities(event: Dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]:
try:
if event.get("format", None) == "org.matrix.custom.html":
message, entities = formatter.matrix_to_telegram(event.get("formatted_body", ""))
@@ -825,7 +857,8 @@ class Portal:
return None
async def _handle_matrix_text(self, sender_id: int, event_id: str, space: int,
client: "MautrixTelegramClient", message: dict, reply_to: int):
client: 'MautrixTelegramClient', message: Dict, reply_to: int
) -> None:
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_message(self.peer, message, reply_to=reply_to,
@@ -833,7 +866,8 @@ class Portal:
self._add_telegram_message_to_db(event_id, space, response)
async def _handle_matrix_file(self, msgtype: str, sender_id: int, event_id: str, space: int,
client: "MautrixTelegramClient", message: dict, reply_to: int):
client: 'MautrixTelegramClient', message: dict, reply_to: int
) -> None:
file = await self.main_intent.download_file(message["url"])
info = message.get("info", {})
@@ -867,24 +901,25 @@ class Portal:
self._add_telegram_message_to_db(event_id, space, response)
async def _handle_matrix_location(self, sender_id: int, event_id: str, space: int,
client: "MautrixTelegramClient", message: dict,
reply_to: int):
client: 'MautrixTelegramClient', message: Dict,
reply_to: int) -> None:
try:
lat, long = message["geo_uri"][len("geo:"):].split(",")
lat, long = float(lat), float(long)
except (KeyError, ValueError):
self.log.exception("Failed to parse location")
return None
message, entities = self._matrix_event_to_entities(message)
caption, entities = self._matrix_event_to_entities(message)
media = MessageMediaGeo(geo=GeoPoint(lat, long, access_hash=0))
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_media(self.peer, media, reply_to=reply_to,
caption=message, entities=entities)
caption=caption, entities=entities)
self._add_telegram_message_to_db(event_id, space, response)
def _add_telegram_message_to_db(self, event_id: str, space: int, response: TypeMessage):
def _add_telegram_message_to_db(self, event_id: str, space: int,
response: TypeMessage) -> None:
self.log.debug("Handled Matrix message: %s", response)
self.is_duplicate(response, (event_id, space))
self.db.add(DBMessage(
@@ -894,7 +929,7 @@ class Portal:
mxid=event_id))
self.db.commit()
async def handle_matrix_message(self, sender: u.User, message: dict, event_id: str):
async def handle_matrix_message(self, sender: 'u.User', message: dict, event_id: str) -> None:
puppet = p.Puppet.get_by_custom_mxid(sender.mxid)
if puppet and message.get("net.maunium.telegram.puppet", False):
self.log.debug("Ignoring puppet-sent message by confirmed puppet user %s", sender.mxid)
@@ -922,7 +957,7 @@ class Portal:
else:
self.log.debug(f"Unhandled Matrix event: {message}")
async def handle_matrix_pin(self, sender: u.User, pinned_message: Optional[str]):
async def handle_matrix_pin(self, sender: 'u.User', pinned_message: Optional[str]) -> None:
if self.peer_type != "channel":
return
try:
@@ -936,17 +971,18 @@ class Portal:
except ChatNotModifiedError:
pass
async def handle_matrix_deletion(self, deleter: u.User, event_id: str):
deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
space = self.tgid if self.peer_type == "channel" else deleter.tgid
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventID) -> None:
real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
space = self.tgid if self.peer_type == "channel" else real_deleter.tgid
message = DBMessage.query.filter(DBMessage.mxid == event_id,
DBMessage.tg_space == space,
DBMessage.mx_room == self.mxid).one_or_none()
if not message:
return
await deleter.client.delete_messages(self.peer, [message.tgid])
await real_deleter.client.delete_messages(self.peer, [message.tgid])
async def _update_telegram_power_level(self, sender: u.User, user_id: int, level: int):
async def _update_telegram_power_level(self, sender: 'u.User', user_id: TelegramID,
level: int) -> None:
if self.peer_type == "chat":
await sender.client(EditChatAdminRequest(
chat_id=self.tgid, user_id=user_id, is_admin=level >= 50))
@@ -962,8 +998,9 @@ class Portal:
EditAdminRequest(channel=await self.get_input_entity(sender),
user_id=user_id, admin_rights=rights))
async def handle_matrix_power_levels(self, sender: u.User, new_users: Dict[str, int],
old_users: Dict[str, int]):
async def handle_matrix_power_levels(self, sender: 'u.User',
new_users: Dict[MatrixUserID, int],
old_users: Dict[str, int]) -> None:
# TODO handle all power level changes and bridge exact admin rights to supergroups/channels
for user, level in new_users.items():
if not user or user == self.main_intent.mxid or user == sender.mxid:
@@ -979,7 +1016,7 @@ class Portal:
if user not in old_users or level != old_users[user]:
await self._update_telegram_power_level(sender, user_id, level)
async def handle_matrix_about(self, sender: u.User, about: str):
async def handle_matrix_about(self, sender: 'u.User', about: str) -> None:
if self.peer_type not in {"channel"}:
return
channel = await self.get_input_entity(sender)
@@ -987,7 +1024,7 @@ class Portal:
self.about = about
self.save()
async def handle_matrix_title(self, sender: u.User, title: str):
async def handle_matrix_title(self, sender: 'u.User', title: str) -> None:
if self.peer_type not in {"chat", "channel"}:
return
@@ -1000,7 +1037,7 @@ class Portal:
self.title = title
self.save()
async def handle_matrix_avatar(self, sender: u.User, url: str):
async def handle_matrix_avatar(self, sender: 'u.User', url: str) -> None:
if self.peer_type not in {"chat", "channel"}:
# Invalid peer type
return
@@ -1027,7 +1064,7 @@ class Portal:
self.save()
break
def _register_outgoing_actions_for_dedup(self, response: TypeUpdates):
def _register_outgoing_actions_for_dedup(self, response: TypeUpdates) -> None:
for update in response.updates:
check_dedup = (isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage))
and isinstance(update.message, MessageService))
@@ -1051,7 +1088,7 @@ class Portal:
user_tgids.add(puppet_id)
return list(user_tgids)
async def upgrade_telegram_chat(self, source: u.User):
async def upgrade_telegram_chat(self, source: 'u.User') -> None:
if self.peer_type != "chat":
raise ValueError("Only normal group chats are upgradable to supergroups.")
@@ -1067,7 +1104,7 @@ class Portal:
self.migrate_and_save(entity.id)
await self.update_info(source, entity)
async def set_telegram_username(self, source: u.User, username: str):
async def set_telegram_username(self, source: 'u.User', username: str) -> None:
if self.peer_type != "channel":
raise ValueError("Only channels and supergroups have usernames.")
await source.client(
@@ -1075,7 +1112,7 @@ class Portal:
if await self.update_username(username):
self.save()
async def create_telegram_chat(self, source: u.User, supergroup: bool = False):
async def create_telegram_chat(self, source: 'u.User', supergroup: bool = False) -> None:
if not self.mxid:
raise ValueError("Can't create Telegram chat for portal without Matrix room.")
elif self.tgid:
@@ -1116,7 +1153,8 @@ class Portal:
await self.main_intent.set_power_levels(self.mxid, levels)
await self.handle_matrix_power_levels(source, levels["users"], {})
async def invite_telegram(self, source: u.User, puppet: Union[p.Puppet, "AbstractUser"]):
async def invite_telegram(self, source: 'u.User',
puppet: Union[p.Puppet, "AbstractUser"]) -> None:
if self.peer_type == "chat":
await source.client(
AddChatUserRequest(chat_id=self.tgid, user_id=puppet.tgid, fwd_limit=0))
@@ -1129,7 +1167,7 @@ class Portal:
# region Telegram event handling
async def handle_telegram_typing(self, user: p.Puppet,
_: Union[UpdateUserTyping, UpdateChatUserTyping]):
_: Union[UpdateUserTyping, UpdateChatUserTyping]) -> None:
if self.mxid:
await user.intent.set_typing(self.mxid, is_typing=True)
@@ -1139,7 +1177,7 @@ class Portal:
return None
async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message,
relates_to=None):
relates_to: Dict = {}) -> None:
largest_size = self._get_largest_photo_size(evt.media.photo)
file = await util.transfer_file_to_matrix(self.db, source.client, intent,
largest_size.location)
@@ -1169,7 +1207,7 @@ class Portal:
external_url=self.get_external_url(evt))
@staticmethod
def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> dict:
def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> Dict:
attrs = {
"name": None,
"mime_type": None,
@@ -1177,7 +1215,7 @@ class Portal:
"sticker_alt": None,
"width": None,
"height": None,
}
} # type: Dict
for attr in attributes:
if isinstance(attr, DocumentAttributeFilename):
attrs["name"] = attrs["name"] or attr.file_name
@@ -1190,8 +1228,8 @@ class Portal:
return attrs
@staticmethod
def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: dict
) -> Tuple[dict, str]:
def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: Dict
) -> Tuple[Dict, str]:
document = evt.media.document
name = evt.message or attrs["name"]
if attrs["is_sticker"]:
@@ -1225,7 +1263,7 @@ class Portal:
async def handle_telegram_document(self, source: "AbstractUser", intent: IntentAPI,
evt: Message,
relates_to: dict = None) -> Optional[dict]:
relates_to: dict = None) -> Optional[Dict]:
document = evt.media.document
attrs = self._parse_telegram_document_attributes(document.attributes)
@@ -1300,7 +1338,8 @@ class Portal:
msgtype=msgtype, timestamp=evt.date,
external_url=self.get_external_url(evt))
async def handle_telegram_edit(self, source: "AbstractUser", sender: p.Puppet, evt: Message):
async def handle_telegram_edit(self, source: "AbstractUser", sender: p.Puppet,
evt: Message) -> None:
if not self.mxid:
return
elif not config["bridge.edits_as_replies"]:
@@ -1349,7 +1388,8 @@ class Portal:
.update({"mxid": mxid})
self.db.commit()
async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet, evt: Message):
async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet,
evt: Message) -> None:
if not self.mxid:
await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=False)
@@ -1461,7 +1501,7 @@ class Portal:
return True
async def handle_telegram_action(self, source: "AbstractUser", sender: p.Puppet,
update: MessageService):
update: MessageService) -> None:
action = update.action
should_ignore = ((not self.mxid and not await self._create_room_on_action(source, action))
or self.is_duplicate_action(update))
@@ -1491,9 +1531,9 @@ class Portal:
else:
self.log.debug("Unhandled Telegram action in %s: %s", self.title, action)
async def set_telegram_admin(self, user_id: int):
async def set_telegram_admin(self, user_id: TelegramID) -> None:
puppet = p.Puppet.get(user_id)
user = await u.User.get_by_tgid(user_id)
user = u.User.get_by_tgid(user_id)
levels = await self.main_intent.get_power_levels(self.mxid)
if user:
@@ -1502,12 +1542,12 @@ class Portal:
levels["users"][puppet.mxid] = 50
await self.main_intent.set_power_levels(self.mxid, levels)
async def receive_telegram_pin_sender(self, sender: p.Puppet):
async def receive_telegram_pin_sender(self, sender: p.Puppet) -> None:
self._temp_pinned_message_sender = sender
if self._temp_pinned_message_id:
await self.update_telegram_pin()
async def update_telegram_pin(self):
async def update_telegram_pin(self) -> None:
intent = (self._temp_pinned_message_sender.intent
if self._temp_pinned_message_sender else self.main_intent)
msg_id = self._temp_pinned_message_id
@@ -1520,7 +1560,7 @@ class Portal:
else:
await intent.set_pinned_messages(self.mxid, [])
async def receive_telegram_pin_id(self, msg_id: int):
async def receive_telegram_pin_id(self, msg_id: int) -> None:
if msg_id == 0:
return await self.update_telegram_pin()
self._temp_pinned_message_id = msg_id
@@ -1528,7 +1568,7 @@ class Portal:
await self.update_telegram_pin()
@staticmethod
def _get_level_from_participant(participant: TypeParticipant, _) -> int:
def _get_level_from_participant(participant: TypeParticipant, _: Dict) -> int:
# TODO use the power level requirements to get better precision in channels
if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)):
return 50
@@ -1537,7 +1577,7 @@ class Portal:
return 0
@staticmethod
def _participant_to_power_levels(levels: dict, user: Union[u.User, p.Puppet], new_level: int,
def _participant_to_power_levels(levels: dict, user: Union['u.User', p.Puppet], new_level: int,
bot_level: int) -> bool:
new_level = min(new_level, bot_level)
default_level = levels["users_default"] if "users_default" in levels else 0
@@ -1569,7 +1609,7 @@ class Portal:
except KeyError:
return 50
def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: dict
def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: Dict
) -> bool:
bot_level = self._get_bot_level(levels)
if bot_level < self._get_powerlevel_level(levels):
@@ -1596,13 +1636,13 @@ class Portal:
return changed
async def update_telegram_participants(self, participants: List[TypeParticipant],
levels: dict = None):
levels: dict = None) -> None:
if not levels:
levels = await self.main_intent.get_power_levels(self.mxid)
if self._participants_to_power_levels(participants, levels):
await self.main_intent.set_power_levels(self.mxid, levels)
async def set_telegram_admins_enabled(self, enabled: bool):
async def set_telegram_admins_enabled(self, enabled: bool) -> None:
level = 50 if enabled else 10
levels = await self.main_intent.get_power_levels(self.mxid)
levels["invite"] = level
@@ -1624,7 +1664,7 @@ class Portal:
mxid=self.mxid, username=self.username, megagroup=self.megagroup,
title=self.title, about=self.about, photo_id=self.photo_id)
def migrate_and_save(self, new_id: int):
def migrate_and_save(self, new_id: TelegramID) -> None:
existing = DBPortal.query.get(self.tgid_full)
if existing:
self.db.delete(existing)
@@ -1637,7 +1677,7 @@ class Portal:
self.by_tgid[self.tgid_full] = self
self.save()
def save(self):
def save(self) -> None:
self.db_instance.mxid = self.mxid
self.db_instance.username = self.username
self.db_instance.title = self.title
@@ -1645,7 +1685,7 @@ class Portal:
self.db_instance.photo_id = self.photo_id
self.db.commit()
def delete(self):
def delete(self) -> None:
try:
del self.by_tgid[self.tgid_full]
except KeyError:
@@ -1660,7 +1700,7 @@ class Portal:
self.deleted = True
@classmethod
def from_db(cls, db_portal: DBPortal) -> "Portal":
def from_db(cls, db_portal: DBPortal) -> 'Portal':
return Portal(tgid=db_portal.tgid, tg_receiver=db_portal.tg_receiver,
peer_type=db_portal.peer_type, mxid=db_portal.mxid,
username=db_portal.username, megagroup=db_portal.megagroup,
@@ -1671,7 +1711,7 @@ class Portal:
# region Class instance lookup
@classmethod
def get_by_mxid(cls, mxid: str) -> Optional["Portal"]:
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['Portal']:
try:
return cls.by_mxid[mxid]
except KeyError:
@@ -1691,7 +1731,7 @@ class Portal:
return None
@classmethod
def find_by_username(cls, username: str) -> Optional["Portal"]:
def find_by_username(cls, username: str) -> Optional['Portal']:
if not username:
return None
@@ -1699,15 +1739,15 @@ class Portal:
if portal.username and portal.username.lower() == username.lower():
return portal
portal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
if portal:
return cls.from_db(portal)
dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
if dbportal:
return cls.from_db(dbportal)
return None
@classmethod
def get_by_tgid(cls, tgid: int, tg_receiver: int = None, peer_type: str = None
) -> Optional["Portal"]:
def get_by_tgid(cls, tgid: TelegramID, tg_receiver: Optional[TelegramID] = None,
peer_type: str = None) -> Optional['Portal']:
tg_receiver = tg_receiver or tgid
tgid_full = (tgid, tg_receiver)
try:
@@ -1728,8 +1768,10 @@ class Portal:
return None
@classmethod
def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull, TypeInputPeer],
receiver_id: int = None, create: bool = True) -> Optional["Portal"]:
def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull,
TypeInputPeer],
receiver_id: Optional[TelegramID] = None, create: bool = True
) -> Optional['Portal']:
entity_type = type(entity)
if entity_type in {Chat, ChatFull}:
type_name = "chat"
@@ -1758,9 +1800,9 @@ class Portal:
# endregion
def init(context: Context):
def init(context: Context) -> None:
global config
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core
Portal.bridge_notices = config["bridge.bridge_notices"]
Portal.filter_mode = config["bridge.filter.mode"]
Portal.filter_list = config["bridge.filter.list"]
+110 -84
View File
@@ -14,17 +14,19 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
from typing import Awaitable, Coroutine, Dict, List, NewType, Optional, Pattern, TYPE_CHECKING
from difflib import SequenceMatcher
import re
import logging
import asyncio
from enum import Enum
from sqlalchemy import orm
from telethon.tl.types import UserProfilePhoto
from telethon.tl.types import UserProfilePhoto, User, FileLocation
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
from .types import MatrixUserID, TelegramID
from .db import Puppet as DBPuppet
from . import util
@@ -32,6 +34,11 @@ if TYPE_CHECKING:
from .matrix import MatrixHandler
from .config import Config
from .context import Context
from . import user as u
from .abstract_user import AbstractUser
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
config = None # type: Config
@@ -45,87 +52,100 @@ class Puppet:
mxid_regex = None # type: Pattern
username_template = None # type: str
hs_domain = None # type: str
cache = {} # type: Dict[str, Puppet]
cache = {} # type: Dict[TelegramID, Puppet]
by_custom_mxid = {} # type: Dict[str, Puppet]
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
is_registered=False, db_instance=None):
self.id = id
self.access_token = access_token
self.custom_mxid = custom_mxid
self.is_real_user = self.custom_mxid and self.access_token
self.default_mxid = self.get_mxid_from_id(self.id)
self.mxid = self.custom_mxid or self.default_mxid
def __init__(self,
id: TelegramID,
access_token: Optional[str] = None,
custom_mxid: Optional[MatrixUserID] = None,
username: Optional[str] = None,
displayname: Optional[str] = None,
displayname_source: Optional[TelegramID] = None,
photo_id: Optional[str] = None,
is_bot: bool = False,
is_registered: bool = False,
db_instance: Optional[DBPuppet] = None) -> None:
self.id = id # type: TelegramID
self.access_token = access_token # type: Optional[str]
self.custom_mxid = custom_mxid # type: Optional[MatrixUserID]
self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserID
self.username = username
self.displayname = displayname
self.displayname_source = displayname_source
self.photo_id = photo_id
self.is_bot = is_bot
self.is_registered = is_registered
self._db_instance = db_instance
self.username = username # type: Optional[str]
self.displayname = displayname # type: Optional[str]
self.displayname_source = displayname_source # type: Optional[TelegramID]
self.photo_id = photo_id # type: Optional[str]
self.is_bot = is_bot # type: bool
self.is_registered = is_registered # type: bool
self._db_instance = db_instance # type: Optional[DBPuppet]
self.default_mxid_intent = self.az.intent.user(self.default_mxid)
self.intent = None # type: IntentAPI
self.refresh_intents()
self.intent = self._fresh_intent() # type: IntentAPI
self.cache[id] = self
if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self
@property
def tgid(self):
def mxid(self):
return self.custom_mxid or self.default_mxid
@property
def tgid(self) -> TelegramID:
return self.id
@property
def is_real_user(self) -> bool:
""" Is True when the puppet is a real Matrix user. """
return bool(self.custom_mxid and self.access_token)
@staticmethod
async def is_logged_in():
async def is_logged_in() -> bool:
""" Is True if the puppet is logged in. """
return True
# region Custom puppet management
def refresh_intents(self):
self.is_real_user = self.custom_mxid and self.access_token
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
def _fresh_intent(self) -> IntentAPI:
return (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
async def switch_mxid(self, access_token, mxid):
async def switch_mxid(self, access_token: str, mxid: MatrixUserID) -> PuppetError:
prev_mxid = self.custom_mxid
self.custom_mxid = mxid
self.access_token = access_token
self.refresh_intents()
self.intent = self._fresh_intent()
err = await self.init_custom_mxid()
if err != 0:
if err != PuppetError.Success:
return err
try:
del self.by_custom_mxid[prev_mxid]
del self.by_custom_mxid[prev_mxid] # type: ignore
except KeyError:
pass
self.mxid = self.custom_mxid or self.default_mxid
if self.mxid != self.default_mxid:
self.by_custom_mxid[self.mxid] = self
await self.leave_rooms_with_default_user()
self.save()
return 0
return PuppetError.Success
async def init_custom_mxid(self):
async def init_custom_mxid(self) -> PuppetError:
if not self.is_real_user:
return 0
return PuppetError.Success
mxid = await self.intent.whoami()
if not mxid or mxid != self.custom_mxid:
self.custom_mxid = None
self.access_token = None
self.refresh_intents()
self.intent = self._fresh_intent()
if mxid != self.custom_mxid:
return 2
return 1
return PuppetError.OnlyLoginSelf
return PuppetError.InvalidAccessToken
if config["bridge.sync_with_custom_puppets"]:
asyncio.ensure_future(self.sync(), loop=self.loop)
return 0
return PuppetError.Success
async def leave_rooms_with_default_user(self):
async def leave_rooms_with_default_user(self) -> None:
for room_id in await self.default_mxid_intent.get_joined_rooms():
try:
await self.default_mxid_intent.leave_room(room_id)
@@ -159,7 +179,7 @@ class Puppet:
},
})
def filter_events(self, events):
def filter_events(self, events: List[Dict]) -> List:
new_events = []
for event in events:
evt_type = event.get("type", None)
@@ -186,28 +206,28 @@ class Puppet:
new_events.append(event)
return new_events
def handle_sync(self, presence, ephemeral):
presence = [self.mx.try_handle_event(event) for event in presence]
def handle_sync(self, presence: List, ephemeral: Dict) -> None:
presence_events = [self.mx.try_handle_event(event) for event in presence]
for room_id, events in ephemeral.items():
for event in events:
event["room_id"] = room_id
ephemeral = [self.mx.try_handle_event(event)
for events in ephemeral.values()
for event in self.filter_events(events)]
ephemeral_events = [self.mx.try_handle_event(event)
for events in ephemeral.values()
for event in self.filter_events(events)]
events = ephemeral + presence
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
coro = asyncio.gather(*events, loop=self.loop)
asyncio.ensure_future(coro, loop=self.loop)
async def sync(self):
async def sync(self) -> None:
try:
await self._sync()
except Exception:
self.log.exception("Fatal error syncing")
async def _sync(self):
async def _sync(self) -> None:
if not self.is_real_user:
self.log.warning("Called sync() for non-custom puppet.")
return
@@ -220,13 +240,14 @@ class Puppet:
while access_token_at_start == self.access_token:
try:
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch,
set_presence="offline")
set_presence="offline") # type: Dict
errors = 0
if next_batch is not None:
presence = sync_resp.get("presence", {}).get("events", [])
presence = sync_resp.get("presence", {}).get("events", []) # type: List
ephemeral = {room: data.get("ephemeral", {}).get("events", [])
for room, data
in sync_resp.get("rooms", {}).get("join", {}).items()}
in sync_resp.get("rooms", {}).get("join", {}).items()
} # type: Dict
self.handle_sync(presence, ephemeral)
next_batch = sync_resp.get("next_batch", None)
except MatrixRequestError as e:
@@ -241,25 +262,25 @@ class Puppet:
# region DB conversion
@property
def db_instance(self):
def db_instance(self) -> DBPuppet:
if not self._db_instance:
self._db_instance = self.new_db_instance()
return self._db_instance
def new_db_instance(self):
def new_db_instance(self) -> DBPuppet:
return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid,
username=self.username, displayname=self.displayname,
displayname_source=self.displayname_source, photo_id=self.photo_id,
is_bot=self.is_bot, matrix_registered=self.is_registered)
@classmethod
def from_db(cls, db_puppet):
def from_db(cls, db_puppet: DBPuppet) -> 'Puppet':
return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid,
db_puppet.username, db_puppet.displayname, db_puppet.displayname_source,
db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered,
db_instance=db_puppet)
def save(self):
def save(self) -> None:
self.db_instance.access_token = self.access_token
self.db_instance.custom_mxid = self.custom_mxid
self.db_instance.username = self.username
@@ -272,16 +293,16 @@ class Puppet:
# endregion
# region Info updating
def similarity(self, query):
def similarity(self, query: str) -> int:
username_similarity = (SequenceMatcher(None, self.username, query).ratio()
if self.username else 0)
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
if self.displayname else 0)
similarity = max(username_similarity, displayname_similarity)
return round(similarity * 1000) / 10
return int(round(similarity * 1000) / 10)
@staticmethod
def get_displayname(info, enable_format=True):
def get_displayname(info: User, enable_format: bool = True) -> str:
data = {
"phone number": info.phone if hasattr(info, "phone") else None,
"username": info.username,
@@ -308,7 +329,7 @@ class Puppet:
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
displayname=name)
async def update_info(self, source, info):
async def update_info(self, source: 'AbstractUser', info: User) -> None:
changed = False
if self.username != info.username:
self.username = info.username
@@ -323,24 +344,26 @@ class Puppet:
if changed:
self.save()
async def update_displayname(self, source, info):
async def update_displayname(self, source: 'AbstractUser', info: User) -> bool:
ignore_source = (not source.is_relaybot
and self.displayname_source is not None
and self.displayname_source != source.tgid)
if ignore_source:
return
return False
displayname = self.get_displayname(info)
if displayname != self.displayname:
await self.default_mxid_intent.set_display_name(displayname)
self.displayname = displayname
self.displayname_source = source.tgid
self.displayname_source = TelegramID(source.tgid)
return True
elif source.is_relaybot or self.displayname_source is None:
self.displayname_source = source.tgid
self.displayname_source = TelegramID(source.tgid)
return True
else:
return False
async def update_avatar(self, source, photo):
async def update_avatar(self, source: 'AbstractUser', photo: FileLocation) -> bool:
photo_id = f"{photo.volume_id}-{photo.local_id}"
if self.photo_id != photo_id:
file = await util.transfer_file_to_matrix(self.db, source.client,
@@ -355,7 +378,7 @@ class Puppet:
# region Getters
@classmethod
def get(cls, tgid, create=True) -> "Optional[Puppet]":
def get(cls, tgid: TelegramID, create: bool = True) -> Optional['Puppet']:
try:
return cls.cache[tgid]
except KeyError:
@@ -374,12 +397,15 @@ class Puppet:
return None
@classmethod
def get_by_mxid(cls, mxid, create=True) -> "Optional[Puppet]":
def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['Puppet']:
tgid = cls.get_id_from_mxid(mxid)
return cls.get(tgid, create) if tgid else None
if tgid:
return cls.get(tgid, create)
return None
@classmethod
def get_by_custom_mxid(cls, mxid):
def get_by_custom_mxid(cls, mxid: MatrixUserID) -> Optional['Puppet']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -396,25 +422,25 @@ class Puppet:
return None
@classmethod
def get_all_with_custom_mxid(cls):
def get_all_with_custom_mxid(cls) -> List['Puppet']:
return [cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
@classmethod
def get_id_from_mxid(cls, mxid):
def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]:
match = cls.mxid_regex.match(mxid)
if match:
return int(match.group(1))
return TelegramID(int(match.group(1)))
return None
@classmethod
def get_mxid_from_id(cls, tgid):
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
def get_mxid_from_id(cls, tgid: TelegramID) -> MatrixUserID:
return MatrixUserID(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
@classmethod
def find_by_username(cls, username) -> "Optional[Puppet]":
def find_by_username(cls, username: str) -> Optional['Puppet']:
if not username:
return None
@@ -422,14 +448,14 @@ class Puppet:
if puppet.username and puppet.username.lower() == username.lower():
return puppet
puppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
if puppet:
return cls.from_db(puppet)
dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
if dbpuppet:
return cls.from_db(dbpuppet)
return None
@classmethod
def find_by_displayname(cls, displayname) -> "Optional[Puppet]":
def find_by_displayname(cls, displayname: str) -> Optional['Puppet']:
if not displayname:
return None
@@ -437,17 +463,17 @@ class Puppet:
if puppet.displayname and puppet.displayname == displayname:
return puppet
puppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
if puppet:
return cls.from_db(puppet)
dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
if dbpuppet:
return cls.from_db(dbpuppet)
return None
# endregion
def init(context: "Context") -> List[Awaitable[int]]:
def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"]
@@ -40,7 +40,7 @@ telematrix_db_engine.dispose()
portals = {}
chats = {}
messages = {}
puppets = {}
puppets = {} # Dict[int, Puppet]
for chat_link in chat_links:
if type(chat_link.tg_room) is str:
+15 -13
View File
@@ -20,37 +20,39 @@ from sqlalchemy import orm
from mautrix_appservice import StateStore
from .types import MatrixUserID, MatrixRoomID
from . import puppet as pu
from .db import RoomState, UserProfile
class SQLStateStore(StateStore):
def __init__(self, db):
def __init__(self, db: orm.Session) -> None:
super().__init__()
self.db = db # type: orm.Session
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
self.room_state_cache = {} # type: Dict[str, RoomState]
@staticmethod
def is_registered(user: str) -> bool:
def is_registered(user: MatrixUserID) -> bool:
puppet = pu.Puppet.get_by_mxid(user)
return puppet.is_registered if puppet else False
@staticmethod
def registered(user: str):
def registered(user: MatrixUserID) -> None:
puppet = pu.Puppet.get_by_mxid(user)
if puppet:
puppet.is_registered = True
puppet.save()
def update_state(self, event: dict):
def update_state(self, event: Dict) -> None:
event_type = event["type"]
if event_type == "m.room.power_levels":
self.set_power_levels(event["room_id"], event["content"])
elif event_type == "m.room.member":
self.set_member(event["room_id"], event["state_key"], event["content"])
def _get_user_profile(self, room_id: str, user_id: str, create: bool = True) -> UserProfile:
def _get_user_profile(self, room_id: MatrixRoomID, user_id: MatrixUserID, create: bool = True
) -> UserProfile:
key = (room_id, user_id)
try:
return self.profile_cache[key]
@@ -67,22 +69,22 @@ class SQLStateStore(StateStore):
self.profile_cache[key] = profile
return profile
def get_member(self, room: str, user: str) -> dict:
def get_member(self, room: MatrixRoomID, user: MatrixUserID) -> Dict:
return self._get_user_profile(room, user).dict()
def set_member(self, room: str, user: str, member: dict):
def set_member(self, room: MatrixRoomID, user: MatrixUserID, member: Dict) -> None:
profile = self._get_user_profile(room, user)
profile.membership = member.get("membership", profile.membership or "leave")
profile.displayname = member.get("displayname", profile.displayname)
profile.avatar_url = member.get("avatar_url", profile.avatar_url)
self.db.commit()
def set_membership(self, room: str, user: str, membership: str):
def set_membership(self, room: MatrixRoomID, user: MatrixUserID, membership: str) -> None:
self.set_member(room, user, {
"membership": membership,
})
def _get_room_state(self, room_id: str, create: bool = True) -> RoomState:
def _get_room_state(self, room_id: MatrixRoomID, create: bool = True) -> RoomState:
try:
return self.room_state_cache[room_id]
except KeyError:
@@ -96,13 +98,13 @@ class SQLStateStore(StateStore):
self.room_state_cache[room_id] = room
return room
def has_power_levels(self, room: str) -> bool:
def has_power_levels(self, room: MatrixRoomID) -> bool:
return self._get_room_state(room).has_power_levels
def get_power_levels(self, room: str) -> dict:
def get_power_levels(self, room: MatrixRoomID) -> Dict:
return self._get_room_state(room).power_levels
def set_power_level(self, room: str, user: str, level: int):
def set_power_level(self, room: MatrixRoomID, user: MatrixUserID, level: int) -> None:
room_state = self._get_room_state(room)
power_levels = room_state.power_levels
if not power_levels:
@@ -114,7 +116,7 @@ class SQLStateStore(StateStore):
room_state.power_levels = power_levels
self.db.commit()
def set_power_levels(self, room: str, content: dict):
def set_power_levels(self, room: MatrixRoomID, content: Dict) -> None:
state = self._get_room_state(room)
state.power_levels = content
self.db.commit()
+5 -1
View File
@@ -14,9 +14,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Union, Optional
from telethon import TelegramClient, utils
from telethon.tl.functions.messages import SendMediaRequest
from telethon.tl.types import *
from telethon.tl.types import (
InputMediaUploadedDocument, InputMediaUploadedPhoto, TypeDocumentAttribute, TypeInputMedia,
TypeInputPeer, TypeMessageEntity, TypeMessageMedia, TypePeer)
from telethon.tl import custom
+10
View File
@@ -0,0 +1,10 @@
from typing import Dict, NewType
# MatrixId = NewType('MatrixId', str)
MatrixUserID = NewType('MatrixUserID', str)
MatrixRoomID = NewType('MatrixRoomID', str)
MatrixEventID = NewType('MatrixEventID', str)
MatrixEvent = NewType('MatrixEvent', Dict)
TelegramID = NewType('TelegramID', int)
+51 -44
View File
@@ -14,18 +14,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING
from typing import Coroutine, Dict, List, Match, NewType, Optional, Tuple, cast, TYPE_CHECKING
import logging
import asyncio
import re
from telethon.tl.types import *
from telethon.tl.types import (
TypeUpdate, UpdateNewMessage, UpdateNewChannelMessage, PeerUser,
UpdateShortChatMessage, UpdateShortMessage)
from telethon.tl.types import User as TLUser
from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError
from .types import MatrixUserID, TelegramID
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
from .abstract_user import AbstractUser
from . import portal as po, puppet as pu
@@ -36,7 +39,7 @@ if TYPE_CHECKING:
config = None # type: Config
SearchResults = List[Tuple["pu.Puppet", int]]
SearchResult = NewType('SearchResult', Tuple['pu.Puppet', int])
class User(AbstractUser):
@@ -44,23 +47,23 @@ class User(AbstractUser):
by_mxid = {} # type: Dict[str, User]
by_tgid = {} # type: Dict[int, User]
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
db_instance: Optional[DBUser] = None):
def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
username: Optional[str] = None, db_contacts: Optional[List[DBContact]] = None,
saved_contacts: int = 0, is_bot: bool = False, db_portals: List[DBPortal] = [],
db_instance: Optional[DBUser] = None) -> None:
super().__init__()
self.mxid = mxid # type: str
self.tgid = tgid # type: int
self.mxid = mxid # type: MatrixUserID
self.tgid = tgid # type: TelegramID
self.is_bot = is_bot # type: bool
self.username = username # type: str
self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
self.db_portals = db_portals # type: List[DBPortal]
self._db_instance = db_instance # type: DBUser
self.db_portals = db_portals or [] # type: List[DBPortal]
self._db_instance = db_instance # type: Optional[DBUser]
self.command_status = None # type: dict
self.command_status = None # type: Dict
(self.relaybot_whitelisted,
self.whitelisted,
@@ -93,7 +96,7 @@ class User(AbstractUser):
for puppet in self.contacts]
@db_contacts.setter
def db_contacts(self, contacts: List[DBContact]):
def db_contacts(self, contacts: List[DBContact]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
@property
@@ -101,7 +104,7 @@ class User(AbstractUser):
return [portal.db_instance for portal in self.portals.values() if not portal.deleted]
@db_portals.setter
def db_portals(self, portals: List[DBPortal]):
def db_portals(self, portals: List[DBPortal]) -> None:
self.portals = {(portal.tgid, portal.tg_receiver):
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
for portal in portals} if portals else {}
@@ -119,7 +122,7 @@ class User(AbstractUser):
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
portals=self.db_portals)
def save(self):
def save(self) -> None:
self.db_instance.tgid = self.tgid
self.db_instance.username = self.username
self.db_instance.contacts = self.db_contacts
@@ -127,7 +130,7 @@ class User(AbstractUser):
self.db_instance.portals = self.db_portals
self.db.commit()
def delete(self):
def delete(self) -> None:
try:
del self.by_mxid[self.mxid]
del self.by_tgid[self.tgid]
@@ -138,14 +141,14 @@ class User(AbstractUser):
self.db.commit()
@classmethod
def from_db(cls, db_user: DBUser) -> "User":
def from_db(cls, db_user: DBUser) -> 'User':
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
# endregion
# region Telegram connection management
async def start(self, delete_unless_authenticated: bool = False) -> "User":
async def start(self, delete_unless_authenticated: bool = False) -> 'User':
await super().start()
if await self.is_logged_in():
self.log.debug(f"Ensuring post_login() for {self.name}")
@@ -156,7 +159,7 @@ class User(AbstractUser):
self.client.session.delete()
return self
async def post_login(self, info: TLUser = None):
async def post_login(self, info: TLUser = None) -> None:
try:
await self.update_info(info)
if not self.is_bot:
@@ -167,9 +170,9 @@ class User(AbstractUser):
except Exception:
self.log.exception("Failed to run post-login functions for %s", self.mxid)
async def update(self, update: TypeUpdate):
async def update(self, update: TypeUpdate) -> bool:
if not self.is_bot:
return
return False
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
message = update.message
@@ -183,22 +186,25 @@ class User(AbstractUser):
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else:
return
return False
self.register_portal(portal)
if portal:
self.register_portal(portal)
return True
# endregion
# region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
return super().ensure_started(even_if_no_session)
def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']:
return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session))
def set_presence(self, online: bool = True):
def set_presence(self, online: bool = True) -> bool:
if self.is_bot:
return
return False
return self.client(UpdateStatusRequest(offline=not online))
async def update_info(self, info: TLUser = None):
async def update_info(self, info: TLUser = None) -> None:
info = info or await self.client.get_me()
changed = False
if self.is_bot != info.bot:
@@ -213,7 +219,7 @@ class User(AbstractUser):
if changed:
self.save()
async def log_out(self):
async def log_out(self) -> bool:
puppet = pu.Puppet.get(self.tgid)
if puppet.is_real_user:
await puppet.switch_mxid(None, None)
@@ -241,28 +247,29 @@ class User(AbstractUser):
return True
def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
) -> SearchResults:
results = [] # type: SearchResults
) -> List[SearchResult]:
results = [] # type: List[SearchResult]
for contact in self.contacts:
similarity = contact.similarity(query)
if similarity >= min_similarity:
results.append((contact, similarity))
results.append(SearchResult((contact, similarity)))
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults:
async def _search_remote(self, query: str, max_results: int = 5) -> List[SearchResult]:
if len(query) < 5:
return []
server_results = await self.client(SearchRequest(q=query, limit=max_results))
results = [] # type: SearchResults
results = [] # type: List[SearchResult]
for user in server_results.users:
puppet = pu.Puppet.get(user.id)
await puppet.update_info(self, user)
results.append((puppet, puppet.similarity(query)))
results.append(SearchResult((puppet, puppet.similarity(query))))
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]:
async def search(self, query: str, force_remote: bool = False
) -> Tuple[List[SearchResult], bool]:
if force_remote:
return await self._search_remote(query), True
@@ -272,7 +279,7 @@ class User(AbstractUser):
return await self._search_remote(query), True
async def sync_dialogs(self, synchronous_create: bool = False):
async def sync_dialogs(self, synchronous_create: bool = False) -> None:
creators = []
for entity in await self.get_dialogs(limit=30):
portal = po.Portal.get_by_entity(entity)
@@ -283,7 +290,7 @@ class User(AbstractUser):
self.save()
await asyncio.gather(*creators, loop=self.loop)
def register_portal(self, portal: po.Portal):
def register_portal(self, portal: po.Portal) -> None:
try:
if self.portals[portal.tgid_full] == portal:
return
@@ -292,7 +299,7 @@ class User(AbstractUser):
self.portals[portal.tgid_full] = portal
self.save()
def unregister_portal(self, portal: po.Portal):
def unregister_portal(self, portal: po.Portal) -> None:
try:
del self.portals[portal.tgid_full]
self.save()
@@ -309,7 +316,7 @@ class User(AbstractUser):
acc = (acc * 20261 + id) & 0xffffffff
return acc & 0x7fffffff
async def sync_contacts(self):
async def sync_contacts(self) -> None:
response = await self.client(GetContactsRequest(hash=self._hash_contacts()))
if isinstance(response, ContactsNotModified):
return
@@ -326,7 +333,7 @@ class User(AbstractUser):
# region Class instance lookup
@classmethod
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['User']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -349,7 +356,7 @@ class User(AbstractUser):
return None
@classmethod
def get_by_tgid(cls, tgid: int) -> "Optional[User]":
def get_by_tgid(cls, tgid: int) -> Optional['User']:
try:
return cls.by_tgid[tgid]
except KeyError:
@@ -363,7 +370,7 @@ class User(AbstractUser):
return None
@classmethod
def find_by_username(cls, username: str) -> "Optional[User]":
def find_by_username(cls, username: str) -> Optional['User']:
if not username:
return None
@@ -379,7 +386,7 @@ class User(AbstractUser):
# endregion
def init(context: "Context") -> List[Awaitable[User]]:
def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser]
global config
config = context.config
+2 -1
View File
@@ -27,7 +27,8 @@ from sqlalchemy.orm.exc import FlushError
from telethon.tl.types import (Document, FileLocation, InputFileLocation,
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
from telethon.errors import *
from telethon.errors import (AuthBytesInvalidError, AuthKeyInvalidError, LocationInvalidError,
SecurityError)
from mautrix_appservice import IntentAPI
from ..tgclient import MautrixTelegramClient
+2 -2
View File
@@ -17,10 +17,10 @@
def format_duration(seconds: int) -> str:
def pluralize(count, singular):
def pluralize(count: int, singular: str) -> str:
return singular if count == 1 else singular + "s"
def include(count, word):
def include(count: int, word: str) -> str:
return f"{count} {pluralize(count, word)}" if count > 0 else ""
minutes, seconds = divmod(seconds, 60)
+6 -6
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import Dict, Optional
import json
import base64
import hashlib
@@ -28,13 +28,13 @@ def _get_checksum(key: str, payload: bytes) -> str:
return checksum
def sign_token(key: str, payload: dict) -> str:
payload = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
checksum = _get_checksum(key, payload)
return f"{checksum}:{payload.decode('utf-8')}"
def sign_token(key: str, payload: Dict) -> str:
payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
checksum = _get_checksum(key, payload_b64)
return f"{checksum}:{payload_b64.decode('utf-8')}"
def verify_token(key: str, data: str) -> Optional[dict]:
def verify_token(key: str, data: str) -> Optional[Dict]:
if not data:
return None
+4 -3
View File
@@ -23,7 +23,7 @@ from telethon.errors import *
from ...commands.auth import enter_password
from ...util import format_duration
from ...puppet import Puppet
from ...puppet import Puppet, PuppetError
from ...user import User
@@ -51,12 +51,13 @@ class AuthAPI(abc.ABC):
"account.", errcode="already-logged-in")
resp = await puppet.switch_mxid(token, user.mxid)
if resp == 2:
if resp == PuppetError.OnlyLoginSelf:
return self.get_mx_login_response(status=403, errcode="only-login-self",
error="You can only log in as your own Matrix user.")
elif resp == 1:
elif resp == PuppetError.InvalidAccessToken:
return self.get_mx_login_response(status=401, errcode="invalid-access-token",
error="Failed to verify access token.")
assert resp == PuppetError.Success, "Encountered an unhandled PuppetError."
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
+10 -9
View File
@@ -15,7 +15,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from typing import Tuple, Optional, Callable, Awaitable, TYPE_CHECKING
from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING
import asyncio
import logging
import json
@@ -24,6 +24,7 @@ from telethon.utils import get_peer_id, resolve_id
from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat
from mautrix_appservice import AppService, MatrixRequestError, IntentError
from ...types import MatrixUserID, TelegramID
from ...user import User
from ...portal import Portal
from ...commands.portal import user_has_power_level, get_initial_state
@@ -36,7 +37,7 @@ if TYPE_CHECKING:
class ProvisioningAPI(AuthAPI):
log = logging.getLogger("mau.web.provisioning")
def __init__(self, context: "Context"):
def __init__(self, context: "Context") -> None:
super().__init__(context.loop)
self.secret = context.config["appservice.provisioning.shared_secret"]
self.az = context.az # type: AppService
@@ -118,10 +119,10 @@ class ProvisioningAPI(AuthAPI):
chat_id = request.match_info["chat_id"]
if chat_id.startswith("-100"):
tgid = int(chat_id[4:])
tgid = TelegramID(int(chat_id[4:]))
peer_type = "channel"
elif chat_id.startswith("-"):
tgid = -int(chat_id)
tgid = TelegramID(-int(chat_id))
peer_type = "chat"
else:
return self.get_error_response(400, "tgid_invalid", "Invalid Telegram chat ID.")
@@ -153,14 +154,14 @@ class ProvisioningAPI(AuthAPI):
"Matrix room.")
is_logged_in = user is not None and await user.is_logged_in()
user = user if is_logged_in else self.context.bot
if not user:
acting_user = user if is_logged_in else self.context.bot
if not acting_user:
return self.get_login_response(status=403, errcode="not_logged_in",
error="You are not logged in and there is no relay bot.")
entity = None # type: Optional[TypeChat]
try:
entity = await user.client.get_entity(portal.peer)
entity = await acting_user.client.get_entity(portal.peer)
except Exception:
self.log.exception("Failed to get_entity(%s) for manual bridging.", portal.peer)
@@ -411,7 +412,7 @@ class ProvisioningAPI(AuthAPI):
except json.JSONDecodeError:
return None
async def get_user(self, mxid: str, expect_logged_in: Optional[bool] = False,
async def get_user(self, mxid: MatrixUserID, expect_logged_in: Optional[bool] = False,
require_puppeting: bool = True, require_user: bool = True
) -> Tuple[Optional[User], Optional[web.Response]]:
if not mxid:
@@ -439,7 +440,7 @@ class ProvisioningAPI(AuthAPI):
expect_logged_in: Optional[bool] = False,
require_puppeting: bool = False,
want_data: bool = True,
) -> (Tuple[Optional[dict],
) -> (Tuple[Optional[Dict],
Optional[User],
Optional[web.Response]]):
err = self.check_authorization(request)