Add missing type hints and fix most type errors except for Optionals.

This commit is contained in:
Kai A. Hiller
2018-08-09 02:19:55 +02:00
parent 01e153662e
commit 0f8009b1e9
26 changed files with 505 additions and 384 deletions
+2 -2
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import Coroutine, List, Optional
import argparse
import asyncio
import logging.config
@@ -115,7 +115,7 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
startup_actions = (init_puppet(context) +
init_user(context) +
[start,
context.mx.init_as_bot()])
context.mx.init_as_bot()]) # type: List[Coroutine]
if context.bot:
startup_actions.append(context.bot.start())
+4 -2
View File
@@ -38,6 +38,7 @@ from .db import Message as DBMessage
from .tgclient import MautrixTelegramClient
if TYPE_CHECKING:
from .types import TelegramId
from .context import Context
from .config import Config
from .bot import Bot
@@ -67,10 +68,11 @@ class AbstractUser(ABC):
self.whitelisted = False # type: bool
self.relaybot_whitelisted = False # type: bool
self.client = None # type: MautrixTelegramClient
self.tgid = None # type: int
self.tgid = None # type: TelegramId
self.mxid = None # type: str
self.is_relaybot = False # type: bool
self.is_bot = False # type: bool
self.relaybot = None # type: Optional[Bot]
@property
def connected(self) -> bool:
@@ -372,7 +374,7 @@ class AbstractUser(ABC):
def init(context: "Context") -> None:
global config, MAX_DELETIONS
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core
AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"]
AbstractUser.session_container = context.session_container
MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10)
+17 -12
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable, Dict, Optional, Pattern, TYPE_CHECKING
from typing import Awaitable, Callable, Dict, List, Optional, Pattern, TYPE_CHECKING
import logging
import re
@@ -27,12 +27,14 @@ from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
from telethon.errors import ChannelInvalidError, ChannelPrivateError
from .types import MatrixUserId
from .abstract_user import AbstractUser
from .db import BotChat
from . import puppet as pu, portal as po, user as u
if TYPE_CHECKING:
from .config import Config
from .context import Context
config = None # type: Config
@@ -145,6 +147,7 @@ class Bot(AbstractUser):
for p in participants:
if p.user_id == tgid:
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
return False
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
if not await self._can_use_commands(event.to_id, event.from_id):
@@ -168,15 +171,16 @@ class Bot(AbstractUser):
return await reply(
"Portal is not public. Use `/invite <mxid>` to get an invite.")
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, mxid: str) -> None:
if len(mxid) == 0:
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc,
mxid_input: MatrixUserId) -> Message:
if len(mxid_input) == 0:
return await reply("Usage: `/invite <mxid>`")
elif not portal.mxid:
return await reply("Portal does not have Matrix room. "
"Create one with /portal first.")
if not self.mxid_regex.match(mxid):
if not self.mxid_regex.match(mxid_input):
return await reply("That doesn't look like a Matrix ID.")
user = await u.User.get_by_mxid(mxid).ensure_started()
user = await u.User.get_by_mxid(MatrixUserId(mxid_input)).ensure_started()
if not user.relaybot_whitelisted:
return await reply("That user is not whitelisted to use the bridge.")
elif await user.is_logged_in():
@@ -187,7 +191,7 @@ class Bot(AbstractUser):
await portal.main_intent.invite(portal.mxid, user.mxid)
return await reply(f"Invited `{user.mxid}` to the portal.")
def handle_command_id(self, message: Message, reply: ReplyFunc) -> None:
def handle_command_id(self, message: Message, reply: ReplyFunc) -> Awaitable[Message]:
# Provide the prefixed ID to the user so that the user wouldn't need to specify whether the
# chat is a normal group or a supergroup/channel when using the ID.
if isinstance(message.to_id, PeerChannel):
@@ -210,7 +214,7 @@ class Bot(AbstractUser):
return False
async def handle_command(self, message: Message) -> None:
def reply(reply_text) -> None:
def reply(reply_text: str) -> Awaitable[Message]:
return self.client.send_message(message.to_id, reply_text, reply_to=message.id)
text = message.message
@@ -231,7 +235,7 @@ class Bot(AbstractUser):
mxid = text[text.index(" ") + 1:]
except ValueError:
mxid = ""
await self.handle_command_invite(portal, reply, mxid=mxid)
await self.handle_command_invite(portal, reply, mxid_input=mxid)
def handle_service_message(self, message: MessageService) -> None:
to_id = message.to_id
@@ -250,11 +254,12 @@ class Bot(AbstractUser):
elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid:
self.remove_chat(to_id)
async def update(self, update) -> None:
async def update(self, update) -> bool:
if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
return
return False
if isinstance(update.message, MessageService):
return self.handle_service_message(update.message)
self.handle_service_message(update.message)
return False
is_command = (isinstance(update.message, Message)
and update.message.entities and len(update.message.entities) > 0
@@ -270,7 +275,7 @@ class Bot(AbstractUser):
return "bot"
def init(context) -> Optional[Bot]:
def init(context: 'Context') -> Optional[Bot]:
global config
config = context.config
token = config["telegram.bot_token"]
+24 -18
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict
from typing import Any, Awaitable, Dict, Optional
import asyncio
from telethon.errors import (
@@ -31,7 +31,7 @@ from ..util import format_duration
@command_handler(needs_auth=False,
help_section=SECTION_AUTH,
help_text="Check if you're logged into Telegram.")
async def ping(evt: CommandEvent) -> None:
async def ping(evt: CommandEvent) -> Optional[Dict]:
me = await evt.sender.client.get_me() if await evt.sender.is_logged_in() else None
if me:
return await evt.reply(f"You're logged in as @{me.username}")
@@ -42,7 +42,7 @@ async def ping(evt: CommandEvent) -> None:
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_AUTH,
help_text="Get the info of the message relay Telegram bot.")
async def ping_bot(evt: CommandEvent) -> None:
async def ping_bot(evt: CommandEvent) -> Optional[Dict]:
if not evt.tgbot:
return await evt.reply("Telegram message relay bot not configured.")
bot_info = await evt.tgbot.client.get_me()
@@ -57,19 +57,19 @@ async def ping_bot(evt: CommandEvent) -> None:
help_section=SECTION_AUTH,
help_text="Revert your Telegram account's Matrix puppet to use the default Matrix "
"account.")
async def logout_matrix(evt: CommandEvent) -> None:
async def logout_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid)
if not puppet.is_real_user:
return await evt.reply("You are not logged in with your Matrix account.")
await puppet.switch_mxid(None, None)
await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
return await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
@command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True,
help_section=SECTION_AUTH,
help_text="Replace your Telegram account's Matrix puppet with your own Matrix "
"account")
async def login_matrix(evt: CommandEvent) -> None:
async def login_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid)
if puppet.is_real_user:
return await evt.reply("You have already logged in with your Matrix account. "
@@ -100,7 +100,7 @@ async def login_matrix(evt: CommandEvent) -> None:
return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def enter_matrix_token(evt: CommandEvent) -> None:
async def enter_matrix_token(evt: CommandEvent) -> Dict:
evt.sender.command_status = None
puppet = pu.Puppet.get(evt.sender.tgid)
@@ -109,10 +109,11 @@ async def enter_matrix_token(evt: CommandEvent) -> None:
"Log out with `$cmdprefix+sp logout-matrix` first.")
resp = await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
if resp == 2:
if resp == pu.PuppetError.OnlyLoginSelf:
return await evt.reply("You can only log in as your own Matrix user.")
elif resp == 1:
elif resp == pu.PuppetError.InvalidAccessToken:
return await evt.reply("Failed to verify access token.")
assert resp == pu.PuppetError.Success, "Encountered an unhandled PuppetError."
return await evt.reply(
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
@@ -121,7 +122,7 @@ async def enter_matrix_token(evt: CommandEvent) -> None:
help_section=SECTION_AUTH,
help_args="<_phone_> <_full name_>",
help_text="Register to Telegram")
async def register(evt: CommandEvent) -> None:
async def register(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.")
elif len(evt.args) < 1:
@@ -138,9 +139,10 @@ async def register(evt: CommandEvent) -> None:
"action": "Register",
"full_name": full_name,
})
return None
async def enter_code_register(evt: CommandEvent) -> None:
async def enter_code_register(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp <code>`")
try:
@@ -169,7 +171,7 @@ async def enter_code_register(evt: CommandEvent) -> None:
@command_handler(needs_auth=False, management_only=True,
help_section=SECTION_AUTH,
help_text="Get instructions on how to log in.")
async def login(evt: CommandEvent) -> None:
async def login(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.")
@@ -200,7 +202,8 @@ async def login(evt: CommandEvent) -> None:
return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, str]) -> None:
async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, Any]
) -> Dict:
ok = False
try:
await evt.sender.ensure_started(even_if_no_session=True)
@@ -232,7 +235,7 @@ async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[s
@command_handler(needs_auth=False)
async def enter_phone_or_token(evt: CommandEvent) -> None:
async def enter_phone_or_token(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token <phone-or-token>`")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -252,10 +255,11 @@ async def enter_phone_or_token(evt: CommandEvent) -> None:
"next": enter_code,
"action": "Login",
})
return None
@command_handler(needs_auth=False)
async def enter_code(evt: CommandEvent) -> None:
async def enter_code(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-code <code>`")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -267,10 +271,11 @@ async def enter_code(evt: CommandEvent) -> None:
evt.log.exception("Error sending phone code")
return await evt.reply("Unhandled exception while sending code. "
"Check console for more details.")
return None
@command_handler(needs_auth=False)
async def enter_password(evt: CommandEvent) -> None:
async def enter_password(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-password <password>`")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -286,9 +291,10 @@ async def enter_password(evt: CommandEvent) -> None:
evt.log.exception("Error sending password")
return await evt.reply("Unhandled exception while sending password. "
"Check console for more details.")
return None
async def sign_in(evt: CommandEvent, **sign_in_info) -> None:
async def sign_in(evt: CommandEvent, **sign_in_info) -> Dict:
try:
await evt.sender.ensure_started(even_if_no_session=True)
user = await evt.sender.client.sign_in(**sign_in_info)
@@ -313,7 +319,7 @@ async def sign_in(evt: CommandEvent, **sign_in_info) -> None:
@command_handler(needs_auth=True,
help_section=SECTION_AUTH,
help_text="Log out from Telegram.")
async def logout(evt: CommandEvent) -> None:
async def logout(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.log_out():
return await evt.reply("Logged out successfully.")
return await evt.reply("Failed to log out.")
+15 -14
View File
@@ -14,21 +14,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, List
from typing import Dict, List, NewType, Optional, Tuple, Union
from mautrix_appservice import MatrixRequestError, IntentAPI
from ..types import MatrixRoomId, MatrixUserId
from . import command_handler, CommandEvent, SECTION_ADMIN
from .. import puppet as pu, portal as po
ManagementRoomList = List[Tuple[str, str]]
RoomIDList = List[str]
ManagementRoom = NewType('ManagementRoom', Tuple[MatrixRoomId, MatrixUserId])
async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList,
async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[MatrixRoomId],
List["po.Portal"], List["po.Portal"]]:
management_rooms = [] # type: ManagementRoomList
unidentified_rooms = [] # type: RoomIDList
management_rooms = [] # type: List[ManagementRoom]
unidentified_rooms = [] # type: List[MatrixRoomId]
portals = [] # type: List[po.Portal]
empty_portals = [] # type: List[po.Portal]
@@ -45,7 +45,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
if pu.Puppet.get_id_from_mxid(other_member):
unidentified_rooms.append(room)
else:
management_rooms.append((room, other_member))
management_rooms.append(ManagementRoom((room, other_member)))
else:
unidentified_rooms.append(room)
else:
@@ -61,7 +61,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
@command_handler(needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms",
help_section=SECTION_ADMIN,
help_text="Clean up unused portal/management rooms.")
async def clean_rooms(evt: CommandEvent) -> None:
async def clean_rooms(evt: CommandEvent) -> Optional[Dict]:
management_rooms, unidentified_rooms, portals, empty_portals = await _find_rooms(evt.az.intent)
reply = ["#### Management rooms (M)"]
@@ -106,13 +106,14 @@ async def clean_rooms(evt: CommandEvent) -> None:
return await evt.reply("\n".join(reply))
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
unidentified_rooms: RoomIDList, portals: List["po.Portal"],
async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
unidentified_rooms: List[MatrixRoomId], portals: List["po.Portal"],
empty_portals: List["po.Portal"]) -> None:
command = evt.args[0]
rooms_to_clean = []
rooms_to_clean = [] # type: List[Union[po.Portal, MatrixRoomId]]
if command == "clean-recommended":
rooms_to_clean = empty_portals + unidentified_rooms
rooms_to_clean += empty_portals
rooms_to_clean += unidentified_rooms
elif command == "clean-groups":
if len(evt.args) < 2:
return await evt.reply("**Usage:** `$cmdprefix+sp clean-groups [M][A][U][I]")
@@ -158,7 +159,7 @@ async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
"`$cmdprefix+sp confirm-clean`.")
async def execute_room_cleanup(evt, rooms_to_clean) -> None:
async def execute_room_cleanup(evt, rooms_to_clean: List[Union[po.Portal, MatrixRoomId]]) -> None:
if len(evt.args) > 0 and evt.args[0] == "confirm-clean":
await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. "
"This might take a while.")
@@ -167,7 +168,7 @@ async def execute_room_cleanup(evt, rooms_to_clean) -> None:
if isinstance(room, po.Portal):
await room.cleanup_and_delete()
cleaned += 1
elif isinstance(room, str):
elif isinstance(room, str): # str is aliased by MatrixRoomId
await po.Portal.cleanup_room(evt.az.intent, room, message="Room deleted")
cleaned += 1
evt.sender.command_status = None
+30 -18
View File
@@ -14,19 +14,20 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Callable, Optional
from typing import Any, Awaitable, Callable, Coroutine, Dict, List, NamedTuple, Optional, Union
from collections import namedtuple
import markdown
import logging
from telethon.errors import FloodWaitError
from ..types import MatrixRoomId
from ..util import format_duration
from .. import user as u, context as c
command_handlers = {} # type: Dict[str, CommandHandler]
HelpSection = namedtuple("HelpSection", "name order description")
HelpSection = NamedTuple('HelpSection', [('name', str), ('order', int), ('description', str)])
SECTION_GENERAL = HelpSection("General", 0, "")
SECTION_AUTH = HelpSection("Authentication", 10, "")
@@ -37,8 +38,8 @@ SECTION_ADMIN = HelpSection("Administration", 50, "")
class CommandEvent:
def __init__(self, processor: "CommandProcessor", room: str, sender: u.User, command: str,
args: List[str], is_management: bool, is_portal: bool) -> None:
def __init__(self, processor: 'CommandProcessor', room: MatrixRoomId, sender: u.User,
command: str, args: List[str], is_management: bool, is_portal: bool) -> None:
self.az = processor.az
self.log = processor.log
self.loop = processor.loop
@@ -53,7 +54,8 @@ class CommandEvent:
self.is_management = is_management
self.is_portal = is_portal
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True) -> None:
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
) -> Awaitable[Dict]:
message = message.replace("$cmdprefix+sp ",
"" if self.is_management else f"{self.command_prefix} ")
message = message.replace("$cmdprefix", self.command_prefix)
@@ -66,7 +68,7 @@ class CommandEvent:
class CommandHandler:
def __init__(self, handler: Callable[[CommandEvent], None], needs_auth: bool,
def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool,
needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool,
management_only: bool, name: str, help_text: str, help_args: str,
help_section: HelpSection) -> None:
@@ -103,7 +105,8 @@ class CommandHandler:
(not self.needs_admin or is_admin) and
(not self.needs_auth or is_logged_in))
async def __call__(self, evt: CommandEvent) -> None:
async def __call__(self, evt: CommandEvent
) -> Dict:
error = await self.get_permission_error(evt)
if error is not None:
return await evt.reply(error)
@@ -118,13 +121,21 @@ class CommandHandler:
return f"**{self.name}** {self._help_args} - {self._help_text}"
def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, needs_auth=True,
needs_puppeting=True, needs_matrix_puppeting=False, needs_admin=False,
management_only=False, name=None, help_text="", help_args="",
help_section=None) -> None:
def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] = None, *,
needs_auth: bool = True,
needs_puppeting: bool = True,
needs_matrix_puppeting: bool = False,
needs_admin: bool = False,
management_only: bool = False,
name: Optional[str] = None,
help_text: str = "",
help_args: str = "",
help_section: HelpSection = None
) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[Dict]]]],
CommandHandler]:
input_name = name
def decorator(func: Callable[[CommandEvent], None]) -> None:
def decorator(func: Callable[[CommandEvent], Awaitable[Optional[Dict]]]) -> CommandHandler:
name = input_name or func.__name__.replace("_", "-")
handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting,
needs_admin, management_only, name, help_text, help_args,
@@ -139,26 +150,26 @@ class CommandProcessor:
log = logging.getLogger("mau.commands")
def __init__(self, context: c.Context) -> None:
self.az, self.db, self.config, self.loop, self.tgbot = context
self.az, self.db, self.config, self.loop, self.tgbot = context.core
self.public_website = context.public_website
self.command_prefix = self.config["bridge.command_prefix"]
async def handle(self, room: str, sender: u.User, command: str, args: List[str],
is_management: bool, is_portal: bool) -> None:
async def handle(self, room: MatrixRoomId, sender: u.User, command: str, args: List[str],
is_management: bool, is_portal: bool) -> Optional[Dict]:
evt = CommandEvent(self, room, sender, command, args, is_management, is_portal)
orig_command = command
command = command.lower()
try:
command = command_handlers[command]
command_handler = command_handlers[command]
except KeyError:
if sender.command_status and "next" in sender.command_status:
args.insert(0, orig_command)
evt.command = ""
command = sender.command_status["next"]
else:
command = command_handlers["unknown-command"]
command_handler = command_handlers["unknown-command"]
try:
await command(evt)
await command_handler(evt)
except FloodWaitError as e:
return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}")
except Exception:
@@ -166,3 +177,4 @@ class CommandProcessor:
f"{evt.command} {' '.join(args)} from {sender.mxid}")
return await evt.reply("Unhandled error while handling command. "
"Check logs for more details.")
return None
+17 -14
View File
@@ -14,46 +14,49 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Tuple
from . import command_handler, CommandEvent, _command_handlers, SECTION_GENERAL
from .handler import HelpSection
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL,
help_text="Cancel an ongoing action (such as login)")
def cancel(evt: CommandEvent) -> None:
async def cancel(evt: CommandEvent) -> Optional[Dict]:
if evt.sender.command_status:
action = evt.sender.command_status["action"]
evt.sender.command_status = None
return evt.reply(f"{action} cancelled.")
return await evt.reply(f"{action} cancelled.")
else:
return evt.reply("No ongoing command.")
return await evt.reply("No ongoing command.")
@command_handler(needs_auth=False, needs_puppeting=False)
def unknown_command(evt: CommandEvent) -> None:
return evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
async def unknown_command(evt: CommandEvent) -> Optional[Dict]:
return await evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
help_cache = {}
help_cache = {} # type: Dict[Tuple[bool, bool, bool, bool, bool], str]
async def _get_help_text(evt: CommandEvent) -> None:
async def _get_help_text(evt: CommandEvent) -> str:
cache_key = (evt.is_management, evt.sender.puppet_whitelisted,
evt.sender.matrix_puppet_whitelisted, evt.sender.is_admin,
await evt.sender.is_logged_in())
if cache_key not in help_cache:
help = {}
help_sections = {} # type: Dict[HelpSection, List[str]]
for handler in _command_handlers.values():
if handler.has_help and handler.has_permission(*cache_key):
help.setdefault(handler.help_section, [])
help[handler.help_section].append(handler.help + " ")
help = sorted(help.items(), key=lambda item: item[0].order)
help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help]
help_sections.setdefault(handler.help_section, [])
help_sections[handler.help_section].append(handler.help + " ")
help_sorted = sorted(help_sections.items(), key=lambda item: item[0].order)
help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help_sorted]
help_cache[cache_key] = "\n".join(help)
return help_cache[cache_key]
def _get_management_status(evt: CommandEvent) -> None:
def _get_management_status(evt: CommandEvent) -> str:
if evt.is_management:
return "This is a management room: prefixing commands with `$cmdprefix` is not required."
elif evt.is_portal:
@@ -65,5 +68,5 @@ def _get_management_status(evt: CommandEvent) -> None:
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL,
help_text="Show this help message.")
async def help(evt: CommandEvent) -> None:
async def help(evt: CommandEvent) -> Optional[Dict]:
return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt))
+44 -36
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Callable
from typing import Awaitable, Dict, Callable, Coroutine, Optional, Tuple, Union, cast
import asyncio
from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
@@ -22,6 +22,7 @@ from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
from telethon.tl.types import ChatForbidden, ChannelForbidden
from mautrix_appservice import MatrixRequestError, IntentAPI
from ..types import MatrixRoomId, TelegramId
from .. import portal as po, user as u
from . import (command_handler, CommandEvent,
SECTION_ADMIN, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT)
@@ -31,7 +32,7 @@ from . import (command_handler, CommandEvent,
help_section=SECTION_ADMIN,
help_args="<_level_> [_mxid_]",
help_text="Set a temporary power level without affecting Telegram.")
async def set_power_level(evt: CommandEvent) -> None:
async def set_power_level(evt: CommandEvent) -> Dict:
try:
level = int(evt.args[0])
except KeyError:
@@ -46,11 +47,12 @@ async def set_power_level(evt: CommandEvent) -> None:
except MatrixRequestError:
evt.log.exception("Failed to set power level.")
return await evt.reply("Failed to set power level.")
return {}
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Get a Telegram invite link to the current chat.")
async def invite_link(evt: CommandEvent) -> None:
async def invite_link(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
return await evt.reply("This is not a portal room.")
@@ -68,7 +70,7 @@ async def invite_link(evt: CommandEvent) -> None:
async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50
) -> None:
) -> bool:
if sender.is_admin:
return True
# Make sure the state store contains the power levels.
@@ -82,8 +84,9 @@ async def user_has_power_level(room: str, intent, sender: u.User, event: str, de
async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
action: Optional[str] = None) -> None:
room_id = evt.args[0] if len(evt.args) > 0 else evt.room_id
action: Optional[str] = None
) -> Tuple[Union[Dict, po.Portal], bool]:
room_id = MatrixRoomId(evt.args[0]) if len(evt.args) > 0 else evt.room_id
portal = po.Portal.get_by_mxid(room_id)
if not portal:
@@ -97,8 +100,8 @@ async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str,
completed_message: str) -> None:
async def post_confirm(confirm) -> None:
completed_message: str) -> Dict:
async def post_confirm(confirm) -> Optional[Dict]:
confirm.sender.command_status = None
if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}":
await function()
@@ -106,6 +109,7 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
return await confirm.reply(completed_message)
else:
return await confirm.reply(f"{action} cancelled.")
return None
return {
"next": post_confirm,
@@ -118,10 +122,11 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
help_text="Remove all users from the current portal room and forget the portal. "
"Only works for group chats; to delete a private chat portal, simply "
"leave the room.")
async def delete_portal(evt: CommandEvent) -> None:
portal, ok = await _get_portal_and_check_permission(evt, "unbridge")
async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
result, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok:
return
return None
portal = cast('po.Portal', result)
evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid,
portal.cleanup_and_delete, "delete",
@@ -139,10 +144,11 @@ async def delete_portal(evt: CommandEvent) -> None:
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Remove puppets from the current portal room and forget the portal.")
async def unbridge(evt: CommandEvent) -> None:
portal, ok = await _get_portal_and_check_permission(evt, "unbridge")
async def unbridge(evt: CommandEvent) -> Optional[Dict]:
result, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok:
return
return None
portal = cast('po.Portal', result)
evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid,
portal.unbridge, "unbridge",
@@ -158,11 +164,11 @@ async def unbridge(evt: CommandEvent) -> None:
help_text="Bridge the current Matrix room to the Telegram chat with the given "
"ID. The ID must be the prefixed version that you get with the `/id` "
"command of the Telegram-side bot.")
async def bridge(evt: CommandEvent) -> None:
async def bridge(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** "
"`$cmdprefix+sp bridge <Telegram chat ID> [Matrix room ID]`")
room_id = evt.args[1] if len(evt.args) > 1 else evt.room_id
room_id = MatrixRoomId(evt.args[1]) if len(evt.args) > 1 else evt.room_id
that_this = "This" if room_id == evt.room_id else "That"
portal = po.Portal.get_by_mxid(room_id)
@@ -173,12 +179,12 @@ async def bridge(evt: CommandEvent) -> None:
return await evt.reply(f"You do not have the permissions to bridge {that_this} room.")
# The /id bot command provides the prefixed ID, so we assume
tgid = evt.args[0]
if tgid.startswith("-100"):
tgid = int(tgid[4:])
tgid_str = evt.args[0]
if tgid_str.startswith("-100"):
tgid = TelegramId(int(tgid_str[4:]))
peer_type = "channel"
elif tgid.startswith("-"):
tgid = -int(tgid)
elif tgid_str.startswith("-"):
tgid = TelegramId(-int(tgid_str))
peer_type = "chat"
else:
return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n"
@@ -224,7 +230,8 @@ async def bridge(evt: CommandEvent) -> None:
"chat to this room, use `$cmdprefix+sp continue`")
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal") -> None:
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"
) -> Tuple[bool, Coroutine[None, None, None]]:
if not portal.mxid:
await evt.reply("The portal seems to have lost its Matrix room between you"
"calling `$cmdprefix+sp bridge` and this command.\n\n"
@@ -247,7 +254,7 @@ async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Porta
return False, None
async def confirm_bridge(evt: CommandEvent) -> None:
async def confirm_bridge(evt: CommandEvent) -> Optional[Dict]:
status = evt.sender.command_status
try:
portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"])
@@ -260,7 +267,7 @@ async def confirm_bridge(evt: CommandEvent) -> None:
if "mxid" in status:
ok, coro = await cleanup_old_portal_while_bridging(evt, portal)
if not ok:
return
return None
elif coro:
asyncio.ensure_future(coro, loop=evt.loop)
await evt.reply("Cleaning up previous portal room...")
@@ -304,7 +311,7 @@ async def confirm_bridge(evt: CommandEvent) -> None:
return await evt.reply("Bridging complete. Portal synchronization should begin momentarily.")
async def get_initial_state(intent: IntentAPI, room_id: str) -> None:
async def get_initial_state(intent: IntentAPI, room_id: str) -> Tuple[str, str, Dict]:
state = await intent.get_room_state(room_id)
title = None
about = None
@@ -330,7 +337,7 @@ async def get_initial_state(intent: IntentAPI, room_id: str) -> None:
help_text="Create a Telegram chat of the given type for the current Matrix room. "
"The type is either `group`, `supergroup` or `channel` (defaults to "
"`group`).")
async def create(evt: CommandEvent) -> None:
async def create(evt: CommandEvent) -> Dict:
type = evt.args[0] if len(evt.args) > 0 else "group"
if type not in {"chat", "group", "supergroup", "channel"}:
return await evt.reply(
@@ -365,7 +372,7 @@ async def create(evt: CommandEvent) -> None:
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Upgrade a normal Telegram group to a supergroup.")
async def upgrade(evt: CommandEvent) -> None:
async def upgrade(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
return await evt.reply("This is not a portal room.")
@@ -387,7 +394,7 @@ async def upgrade(evt: CommandEvent) -> None:
help_args="<_name_|`-`>",
help_text="Change the username of a supergroup/channel. "
"To disable, use a dash (`-`) as the name.")
async def group_name(evt: CommandEvent) -> None:
async def group_name(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp group-name <name/->`")
@@ -423,7 +430,7 @@ async def group_name(evt: CommandEvent) -> None:
help_args="<`whitelist`|`blacklist`>",
help_text="Change whether the bridge will allow or disallow bridging rooms by "
"default.")
async def filter_mode(evt: CommandEvent) -> None:
async def filter_mode(evt: CommandEvent) -> Dict:
try:
mode = evt.args[0]
if mode not in ("whitelist", "blacklist"):
@@ -448,19 +455,19 @@ async def filter_mode(evt: CommandEvent) -> None:
help_section=SECTION_ADMIN,
help_args="<`whitelist`|`blacklist`> <_chat ID_>",
help_text="Allow or disallow bridging a specific chat.")
async def filter(evt: CommandEvent) -> None:
async def filter(evt: CommandEvent) -> Optional[Dict]:
try:
action = evt.args[0]
if action not in ("whitelist", "blacklist", "add", "remove"):
raise ValueError()
id = evt.args[1]
if id.startswith("-100"):
id = int(id[4:])
elif id.startswith("-"):
id = int(id[1:])
id_str = evt.args[1]
if id_str.startswith("-100"):
id = int(id_str[4:])
elif id_str.startswith("-"):
id = int(id_str[1:])
else:
id = int(id)
id = int(id_str)
except (IndexError, ValueError):
return await evt.reply("**Usage:** `$cmdprefix+sp filter <whitelist/blacklist> <chat ID>`")
@@ -490,3 +497,4 @@ async def filter(evt: CommandEvent) -> None:
list.remove(id)
save()
return await evt.reply(f"Chat ID removed from {mode}.")
return None
+11 -7
View File
@@ -14,10 +14,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Optional, Tuple
import re
from telethon.errors import (
InviteHashInvalidError, InviteHashExpiredError, UserAlreadyParticipantError)
from telethon.tl.types import User as TLUser
from telethon.tl.types import TypeUpdates
from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest
from telethon.tl.functions.channels import JoinChannelRequest
@@ -28,7 +31,7 @@ from . import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORT
@command_handler(help_section=SECTION_MISC,
help_args="[_-r|--remote_] <_query_>",
help_text="Search your contacts or the Telegram servers for users.")
async def search(evt: CommandEvent) -> None:
async def search(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`")
@@ -49,7 +52,7 @@ async def search(evt: CommandEvent) -> None:
"Minimum length of remote query is 5 characters.")
return await evt.reply("No results 3:")
reply = []
reply = [] # type: List[str]
if remote:
reply += ["**Results from Telegram server:**", ""]
else:
@@ -70,7 +73,7 @@ async def search(evt: CommandEvent) -> None:
"either the internal user ID, the username or the phone number. "
"**N.B.** The phone numbers you start chats with must already be in "
"your contacts.")
async def private_message(evt: CommandEvent) -> None:
async def private_message(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`")
@@ -89,7 +92,7 @@ async def private_message(evt: CommandEvent) -> None:
f"{pu.Puppet.get_displayname(user, False)}")
async def _join(evt: CommandEvent, arg: str) -> None:
async def _join(evt: CommandEvent, arg: str) -> Tuple[TypeUpdates, Dict]:
if arg.startswith("joinchat/"):
invite_hash = arg[len("joinchat/"):]
try:
@@ -112,7 +115,7 @@ async def _join(evt: CommandEvent, arg: str) -> None:
@command_handler(help_section=SECTION_CREATING_PORTALS,
help_args="<_link_>",
help_text="Join a chat with an invite link.")
async def join(evt: CommandEvent) -> None:
async def join(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp join <invite link>`")
@@ -123,7 +126,7 @@ async def join(evt: CommandEvent) -> None:
updates, _ = await _join(evt, arg.group(1))
if not updates:
return
return None
for chat in updates.chats:
portal = po.Portal.get_by_entity(chat)
@@ -134,12 +137,13 @@ async def join(evt: CommandEvent) -> None:
await evt.reply(f"Creating room for {chat.title}... This might take a while.")
await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid])
return await evt.reply(f"Created room for {portal.title}")
return None
@command_handler(help_section=SECTION_MISC,
help_args="[`chats`|`contacts`|`me`]",
help_text="Synchronize your chat portals, contacts and/or own info.")
async def sync(evt: CommandEvent) -> None:
async def sync(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) > 0:
sync_only = evt.args[0]
if sync_only not in ("chats", "contacts", "me"):
+3 -3
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Any, Optional
from typing import Any, Dict, Optional, Tuple
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
import random
@@ -25,7 +25,7 @@ yaml.indent(4)
class DictWithRecursion:
def __init__(self, data: CommentedMap = None) -> None:
def __init__(self, data: Optional[CommentedMap] = None) -> None:
self._data = data or CommentedMap() # type: CommentedMap
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
@@ -99,7 +99,7 @@ class Config(DictWithRecursion):
self.path = path # type: str
self.registration_path = registration_path # type: str
self.base_path = base_path # type: str
self._registration = None # type: dict
self._registration = None # type: Optional[Dict]
def load(self) -> None:
with open(self.path, 'r') as stream:
+5 -7
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TYPE_CHECKING, Optional
from typing import Generator, Optional, Tuple, Union, TYPE_CHECKING
if TYPE_CHECKING:
import asyncio
@@ -44,9 +44,7 @@ class Context:
self.public_website = None # type: PublicBridgeWebsite
self.provisioning_api = None # type: ProvisioningAPI
def __iter__(self) -> None:
yield self.az
yield self.db
yield self.config
yield self.loop
yield self.bot
@property
def core(self) -> Tuple['AppService', 'scoped_session', 'Config',
'asyncio.AbstractEventLoop', Optional['Bot']]:
return (self.az, self.db, self.config, self.loop, self.bot)
+8 -6
View File
@@ -14,6 +14,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text)
from sqlalchemy.sql import expression
@@ -88,20 +90,20 @@ class RoomState(Base):
room_id = Column(String, primary_key=True)
_power_levels_text = Column("power_levels", Text, nullable=True)
_power_levels_json = None
_power_levels_json = {} # type: Dict
@property
def has_power_levels(self) -> None:
def has_power_levels(self) -> bool:
return bool(self._power_levels_text)
@property
def power_levels(self) -> None:
def power_levels(self) -> Dict:
if not self._power_levels_json and self._power_levels_text:
self._power_levels_json = json.loads(self._power_levels_text)
return self._power_levels_json or {}
return self._power_levels_json
@power_levels.setter
def power_levels(self, val) -> None:
def power_levels(self, val: Dict) -> None:
self._power_levels_json = val
self._power_levels_text = json.dumps(val)
@@ -116,7 +118,7 @@ class UserProfile(Base):
displayname = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
def dict(self) -> None:
def dict(self) -> Dict[str, Column]:
return {
"membership": self.membership,
"displayname": self.displayname,
@@ -80,12 +80,12 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
args["url"] = url
return MessageEntityTextUrl, None
def handle_starttag(self, tag: str, attrs: List[Tuple[str, str]]):
def handle_starttag(self, tag: str, attrs_list: List[Tuple[str, str]]):
self._open_tags.appendleft(tag)
self._open_tags_meta.appendleft(0)
attrs = dict(attrs)
entity_type = None # type: type(TypeMessageEntity)
attrs = dict(attrs_list)
entity_type = None # type: Optional[Type[TypeMessageEntity]]
args = {} # type: Dict[str, Any]
if tag in ("strong", "b"):
entity_type = MessageEntityBold
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, Union, Callable
from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
from lxml import html
from telethon.tl.types import (MessageEntityMention as Mention,
@@ -83,7 +83,7 @@ def offset_length_multiply(amount: int):
class TelegramMessage:
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None):
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None) -> None:
self.text = text # type: str
self.entities = entities or [] # type: List[TypeMessageEntity]
@@ -120,7 +120,7 @@ class TelegramMessage:
self.text = msg.text + self.text
return self
def format(self, entity_type: type(TypeMessageEntity), offset: int = None, length: int = None,
def format(self, entity_type: Type[TypeMessageEntity], offset: int = None, length: int = None,
**kwargs) -> "TelegramMessage":
self.entities.append(entity_type(offset=offset or 0,
length=length if length is not None else len(self.text),
@@ -158,7 +158,8 @@ class TelegramMessage:
return output
@staticmethod
def join(items: List[Union[str, "TelegramMessage"]], separator: str = " ") -> "TelegramMessage":
def join(items: Sequence[Union[str, "TelegramMessage"]],
separator: str = " ") -> "TelegramMessage":
main = TelegramMessage()
for msg in items:
if isinstance(msg, str):
+11 -9
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from html import escape
import logging
import re
@@ -28,6 +28,7 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
from mautrix_appservice import MatrixRequestError
from mautrix_appservice.intent_api import IntentAPI
from ..types import TelegramId
from .. import user as u, puppet as pu, portal as po
from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
@@ -40,14 +41,14 @@ if TYPE_CHECKING:
try:
from lxml.html.diff import htmldiff
except ImportError:
htmldiff = None # type: function
htmldiff = None # type: ignore
log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
should_highlight_edits = False # type: bool
def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict:
def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
if evt.reply_to_msg_id:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
@@ -116,7 +117,7 @@ def highlight_edits(new_html: str, old_html: str) -> str:
async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
relates_to: dict, main_intent: IntentAPI, is_edit: bool
relates_to: Dict, main_intent: IntentAPI, is_edit: bool
) -> Tuple[str, str]:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
@@ -177,10 +178,10 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
async def telegram_to_matrix(evt: Message, source: "AbstractUser",
main_intent: Optional[IntentAPI] = None,
is_edit: bool = False, prefix_text: Optional[str] = None,
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]:
prefix_html: Optional[str] = None) -> Tuple[str, str, Dict]:
text = add_surrogates(evt.message)
html = _telegram_entities_to_matrix_catch(text, evt.entities) if evt.entities else None
relates_to = {}
relates_to = {} # type: Dict
if prefix_html:
html = prefix_html + (html or escape(text))
@@ -217,6 +218,7 @@ def _telegram_entities_to_matrix_catch(text: str, entities: List[TypeMessageEnti
"message=%s\n"
"entities=%s",
text, entities)
return "[failed conversion in _telegram_entities_to_matrix]"
def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -> str:
@@ -290,7 +292,7 @@ def _parse_mention(html: List[str], entity_text: str) -> bool:
return False
def _parse_name_mention(html: List[str], entity_text: str, user_id: int) -> bool:
def _parse_name_mention(html: List[str], entity_text: str, user_id: TelegramId) -> bool:
user = u.User.get_by_tgid(user_id)
if user:
mxid = user.mxid
@@ -315,8 +317,8 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
message_link_match = message_link_regex.match(url)
if message_link_match:
group, msgid = message_link_match.groups()
msgid = int(msgid)
group, msgid_str = message_link_match.groups()
msgid = int(msgid_str)
portal = po.Portal.find_by_username(group)
if portal:
+58 -37
View File
@@ -14,23 +14,31 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Tuple, Set, Match
from typing import Dict, List, Match, Optional, Set, Tuple, TYPE_CHECKING
import logging
import asyncio
import re
from mautrix_appservice import MatrixRequestError, IntentError
from .types import MatrixEvent, MatrixEventId, MatrixRoomId, MatrixUserId
from . import user as u, portal as po, puppet as pu, commands as com
if TYPE_CHECKING:
from mautrix_appservice import AppService
from .context import Context
from sqlalchemy.orm import scoped_session
from .config import Config
from .bot import Bot
class MatrixHandler:
log = logging.getLogger("mau.mx") # type: logging.Logger
def __init__(self, context) -> None:
self.az, self.db, self.config, _, self.tgbot = context
def __init__(self, context: 'Context') -> None:
self.az, self.db, self.config, _, self.tgbot = context.core
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
self.previously_typing = [] # type: List[str]
self.previously_typing = [] # type: List[MatrixUserId]
self.az.matrix_event_handler(self.handle_event)
@@ -50,7 +58,8 @@ class MatrixHandler:
except asyncio.TimeoutError:
self.log.exception("TimeoutError when trying to set avatar")
async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User) -> None:
async def handle_puppet_invite(self, room_id: MatrixRoomId, puppet: pu.Puppet, inviter: u.User
) -> None:
intent = puppet.default_mxid_intent
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in():
@@ -80,6 +89,7 @@ class MatrixHandler:
await intent.join_room(room_id)
portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
# TODO: if portal is None:
if portal.mxid:
try:
await intent.invite(portal.mxid, inviter.mxid)
@@ -95,13 +105,13 @@ class MatrixHandler:
portal.mxid = room_id
portal.save()
inviter.register_portal(portal)
await intent.send_notice(room_id, "po.Portal to private chat created.")
await intent.send_notice(room_id, "Portal to private chat created.")
else:
await intent.join_room(room_id)
await intent.send_notice(room_id, "This puppet will remain inactive until a "
"Telegram chat is created for this room.")
async def accept_bot_invite(self, room_id: str, inviter: u.User) -> None:
async def accept_bot_invite(self, room_id: MatrixRoomId, inviter: u.User) -> None:
tries = 0
while tries < 5:
try:
@@ -126,9 +136,13 @@ class MatrixHandler:
"<code>bridge.permissions</code> section in your config file.")
await self.az.intent.leave_room(room_id)
async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str) -> None:
async def handle_invite(self, room_id: MatrixRoomId, user_id: MatrixUserId,
inviter_mxid: MatrixUserId) -> None:
self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started()
inviter = u.User.get_by_mxid(inviter_mxid)
if inviter is None:
self.log.exception("Failed to find user with Matrix ID {inviter_mxid}")
await inviter.ensure_started()
if user_id == self.az.bot_mxid:
return await self.accept_bot_invite(room_id, inviter)
elif not inviter.whitelisted:
@@ -150,7 +164,8 @@ class MatrixHandler:
# The rest can probably be ignored
async def handle_join(self, room_id: str, user_id: str, event_id: str) -> None:
async def handle_join(self, room_id: MatrixRoomId, user_id: MatrixUserId,
event_id: MatrixEventId) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started()
portal = po.Portal.get_by_mxid(room_id)
@@ -171,7 +186,8 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot:
await portal.join_matrix(user, event_id)
async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str) -> None:
async def handle_part(self, room_id: MatrixRoomId, user_id: MatrixUserId,
sender_mxid: MatrixUserId, event_id: MatrixEventId) -> None:
self.log.debug(f"{user_id} left {room_id}")
sender = u.User.get_by_mxid(sender_mxid, create=False)
@@ -185,6 +201,7 @@ class MatrixHandler:
puppet = pu.Puppet.get_by_mxid(user_id)
if sender and puppet:
# TODO: Puppet should probably be an AbstractUser
await portal.leave_matrix(puppet, sender, event_id)
user = u.User.get_by_mxid(user_id, create=False)
@@ -194,7 +211,7 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot:
await portal.leave_matrix(user, sender, event_id)
def is_command(self, message: dict) -> Tuple[bool, str]:
def is_command(self, message: Dict) -> Tuple[bool, str]:
text = message.get("body", "")
prefix = self.config["bridge.command_prefix"]
is_command = text.startswith(prefix)
@@ -202,9 +219,10 @@ class MatrixHandler:
text = text[len(prefix) + 1:]
return is_command, text
async def handle_message(self, room, sender, message, event_id) -> None:
async def handle_message(self, room: MatrixRoomId, sender_id: MatrixUserId, message: Dict,
event_id: MatrixEventId) -> None:
is_command, text = self.is_command(message)
sender = await u.User.get_by_mxid(sender).ensure_started()
sender = await u.User.get_by_mxid(sender_id).ensure_started()
if not sender.relaybot_whitelisted:
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
" u.User is not whitelisted.")
@@ -237,7 +255,8 @@ class MatrixHandler:
is_portal=portal is not None)
@staticmethod
async def handle_redaction(room_id: str, sender_mxid: str, event_id: str) -> None:
async def handle_redaction(room_id: MatrixRoomId, sender_mxid: MatrixUserId,
event_id: MatrixEventId) -> None:
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if not sender.relaybot_whitelisted:
return
@@ -249,14 +268,15 @@ class MatrixHandler:
await portal.handle_matrix_deletion(sender, event_id)
@staticmethod
async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict) -> None:
async def handle_power_levels(room_id: MatrixRoomId, sender_mxid: MatrixUserId,
new: Dict, old: Dict) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
@staticmethod
async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str,
async def handle_room_meta(evt_type: str, room_id: MatrixRoomId, sender_mxid: MatrixUserId,
content: dict) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
@@ -271,8 +291,8 @@ class MatrixHandler:
await handler(sender, content[content_key])
@staticmethod
async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str],
old_events: Set[str]) -> None:
async def handle_room_pin(room_id: MatrixRoomId, sender_mxid: MatrixUserId,
new_events: Set[str], old_events: Set[str]) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
@@ -285,8 +305,8 @@ class MatrixHandler:
await portal.handle_matrix_pin(sender, None)
@staticmethod
async def handle_name_change(room_id: str, user_id: str, displayname: str,
prev_displayname: str, event_id: str) -> None:
async def handle_name_change(room_id: MatrixRoomId, user_id: MatrixUserId, displayname: str,
prev_displayname: str, event_id: MatrixEventId) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal or not portal.has_bot:
return
@@ -296,13 +316,14 @@ class MatrixHandler:
await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
@staticmethod
def parse_read_receipts(content: dict) -> Dict[str, str]:
def parse_read_receipts(content: Dict) -> Dict[MatrixUserId, MatrixEventId]:
return {user_id: event_id
for event_id, receipts in content.items()
for user_id in receipts.get("m.read", {})}
@staticmethod
async def handle_read_receipts(room_id: str, receipts: Dict[str, str]) -> None:
async def handle_read_receipts(room_id: MatrixRoomId,
receipts: Dict[MatrixUserId, MatrixEventId]) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
@@ -314,13 +335,13 @@ class MatrixHandler:
await portal.mark_read(user, event_id)
@staticmethod
async def handle_presence(user_id: str, presence: str) -> None:
async def handle_presence(user_id: MatrixUserId, presence: str) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in():
return
await user.set_presence(presence == "online")
user.set_presence(presence == "online")
async def handle_typing(self, room_id: str, now_typing: List[str]) -> None:
async def handle_typing(self, room_id: MatrixRoomId, now_typing: List[MatrixUserId]) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
@@ -335,35 +356,35 @@ class MatrixHandler:
if not await user.is_logged_in():
continue
await portal.set_typing(user, is_typing)
portal.set_typing(user, is_typing)
self.previously_typing = now_typing
def filter_matrix_event(self, event: dict) -> None:
def filter_matrix_event(self, event: MatrixEvent) -> bool:
sender = event.get("sender", None)
if not sender:
return False
return (sender == self.az.bot_mxid
or pu.Puppet.get_id_from_mxid(sender) is not None)
async def try_handle_event(self, evt: dict) -> None:
async def try_handle_event(self, evt: MatrixEvent) -> None:
try:
await self.handle_event(evt)
except Exception:
self.log.exception("Error handling manually received Matrix event")
async def handle_event(self, evt: dict) -> None:
async def handle_event(self, evt: MatrixEvent) -> None:
if self.filter_matrix_event(evt):
return
self.log.debug("Received event: %s", evt)
evt_type = evt.get("type", "m.unknown") # type: str
room_id = evt.get("room_id", None) # type: str
event_id = evt.get("event_id", None) # type: str
sender = evt.get("sender", None) # type: str
content = evt.get("content", {}) # type: dict
room_id = evt.get("room_id", None) # type: Optional[MatrixRoomId]
event_id = evt.get("event_id", None) # type: Optional[MatrixEventId]
sender = evt.get("sender", None) # type: Optional[MatrixUserId]
content = evt.get("content", {}) # type: Dict
if evt_type == "m.room.member":
state_key = evt["state_key"] # type: str
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
state_key = evt["state_key"] # type: MatrixUserId
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: Dict
membership = content.get("membership", "") # type: str
prev_membership = prev_content.get("membership", "leave") # type: str
if membership == prev_membership:
@@ -387,7 +408,7 @@ class MatrixHandler:
elif evt_type == "m.room.redaction":
await self.handle_redaction(room_id, sender, evt["redacts"])
elif evt_type == "m.room.power_levels":
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
prev_content = evt.get("unsigned", {}).get("prev_content", {})
await self.handle_power_levels(room_id, sender, evt["content"], prev_content)
elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
+73 -61
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, TYPE_CHECKING
from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, cast, TYPE_CHECKING
from collections import deque
from datetime import datetime
from string import Template
@@ -62,7 +62,7 @@ from telethon.tl.types import (
UserFull)
from mautrix_appservice import MatrixRequestError, IntentError, AppService, IntentAPI
from .types import MatrixEventId, MatrixRoomId, MatrixUserId, TelegramId
from .context import Context
from .db import Portal as DBPortal, Message as DBMessage, TelegramFile as DBTelegramFile
from . import puppet as p, user as u, formatter, util
@@ -105,18 +105,18 @@ class Portal:
by_mxid = {} # type: Dict[str, Portal]
by_tgid = {} # type: Dict[Tuple[int, int], Portal]
def __init__(self, tgid: int, peer_type: str, tg_receiver: Optional[int] = None,
mxid: Optional[str] = None, username: Optional[str] = None,
def __init__(self, tgid: TelegramId, peer_type: str, tg_receiver: Optional[int] = None,
mxid: Optional[MatrixRoomId] = None, username: Optional[str] = None,
megagroup: Optional[bool] = False, title: Optional[str] = None,
about: Optional[str] = None, photo_id: Optional[str] = None,
db_instance: DBPortal = None) -> None:
self.mxid = mxid # type: str
self.tgid = tgid # type: int
self.mxid = mxid # type: Optional[MatrixRoomId]
self.tgid = tgid # type: TelegramId
self.tg_receiver = tg_receiver or tgid # type: int
self.peer_type = peer_type # type: str
self.username = username # type: str
self.megagroup = megagroup # type: bool
self.title = title # type: str
self.title = title # type: Optional[str]
self.about = about # type: str
self.photo_id = photo_id # type: str
self._db_instance = db_instance # type: DBPortal
@@ -161,7 +161,7 @@ class Portal:
@property
def has_bot(self) -> bool:
return self.bot and self.bot.is_in_chat(self.tgid)
return bool(self.bot and self.bot.is_in_chat(self.tgid))
@property
def main_intent(self) -> IntentAPI:
@@ -270,8 +270,8 @@ class Portal:
else:
raise ValueError("Invalid invite identifier given to invite_matrix()")
async def update_matrix_room(self, user: "AbstractUser", entity: TypeChat, direct: bool,
puppet: p.Puppet = None, levels: dict = None,
async def update_matrix_room(self, user: 'AbstractUser', entity: TypeChat, direct: bool,
puppet: p.Puppet = None, levels: Dict = None,
users: List[User] = None,
participants: List[TypeParticipant] = None) -> None:
if not direct:
@@ -303,8 +303,8 @@ class Portal:
async with self._room_create_lock:
return await self._create_matrix_room(user, entity, invites)
async def _create_matrix_room(self, user: "AbstractUser", entity: TypeChat, invites: InviteList
) -> Optional[str]:
async def _create_matrix_room(self, user: 'AbstractUser', entity: TypeChat, invites: InviteList
) -> Optional[MatrixRoomId]:
direct = self.peer_type == "user"
if self.mxid:
@@ -369,6 +369,8 @@ class Portal:
participants=participants),
loop=self.loop)
return self.mxid
def _get_base_power_levels(self, levels: dict = None, entity: TypeChat = None) -> dict:
levels = levels or {}
power_level_requirement = (0 if self.peer_type == "chat" and not entity.admins_enabled
@@ -437,18 +439,19 @@ class Portal:
and config["bridge.max_initial_member_sync"] == -1
and (self.megagroup or self.peer_type != "channel"))
if trust_member_list:
joined_mxids = await self.main_intent.get_room_members(self.mxid)
for user in joined_mxids:
if user == self.az.bot_mxid:
joined_mxids = cast(List[MatrixUserId],
await self.main_intent.get_room_members(self.mxid))
for user_mxid in joined_mxids:
if user_mxid == self.az.bot_mxid:
continue
puppet_id = p.Puppet.get_id_from_mxid(user)
puppet_id = p.Puppet.get_id_from_mxid(user_mxid)
if puppet_id and puppet_id not in allowed_tgids:
if self.bot and puppet_id == self.bot.tgid:
self.bot.remove_chat(self.tgid)
await self.main_intent.kick(self.mxid, user,
await self.main_intent.kick(self.mxid, user_mxid,
"User had left this Telegram chat.")
continue
mx_user = u.User.get_by_mxid(user, create=False)
mx_user = u.User.get_by_mxid(user_mxid, create=False)
if mx_user and mx_user.is_bot and mx_user.tgid not in allowed_tgids:
mx_user.unregister_portal(self)
@@ -457,7 +460,7 @@ class Portal:
"You had left this Telegram chat.")
continue
async def add_telegram_user(self, user_id: int, source: Optional["AbstractUser"] = None
async def add_telegram_user(self, user_id: TelegramId, source: Optional['AbstractUser'] = None
) -> None:
puppet = p.Puppet.get(user_id)
if source:
@@ -470,7 +473,7 @@ class Portal:
user.register_portal(self)
await self.invite_to_matrix(user.mxid)
async def delete_telegram_user(self, user_id: int, sender: p.Puppet) -> None:
async def delete_telegram_user(self, user_id: TelegramId, sender: p.Puppet) -> None:
puppet = p.Puppet.get(user_id)
user = u.User.get_by_tgid(user_id)
kick_message = (f"Kicked by {sender.displayname}"
@@ -568,8 +571,9 @@ class Portal:
return True
return False
async def _get_users(self, user: "AbstractUser", entity: Union[TypeInputPeer, InputUser,
TypeChat, TypeUser]
async def _get_users(self,
user: 'AbstractUser',
entity: Union[TypeInputPeer, InputUser, TypeChat, TypeUser]
) -> Tuple[List[TypeUser], List[TypeParticipant]]:
if self.peer_type == "chat":
chat = await user.client(GetFullChatRequest(chat_id=self.tgid))
@@ -588,7 +592,7 @@ class Portal:
entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0))
return response.users, response.participants
elif limit > 200 or limit == -1:
users, participants = [], []
users, participants = [], [] # type: Tuple[List[TypeUser], List[TypeParticipant]]
offset = 0
remaining_quota = limit if limit > 0 else 1000000
query = (ChannelParticipantsSearch("") if limit == -1
@@ -609,6 +613,7 @@ class Portal:
return [], []
elif self.peer_type == "user":
return [entity], []
return [], []
async def get_invite_link(self, user: 'u.User') -> str:
if self.peer_type == "user":
@@ -688,7 +693,7 @@ class Portal:
return ""
async def _get_state_change_message(self, event: str, user: 'u.User',
arguments: Optional[dict] = None) -> Optional[dict]:
arguments: Optional[Dict] = None) -> Optional[Dict]:
tpl = config[f"bridge.state_event_formats.{event}"]
if len(tpl) == 0:
# Empty format means they don't want the message
@@ -724,11 +729,11 @@ class Portal:
or user.mxid_localpart)
def set_typing(self, user: 'u.User', typing: bool = True,
action=SendMessageTypingAction) -> None:
action: type = SendMessageTypingAction) -> bool:
return user.client(SetTypingRequest(
self.peer, action() if typing else SendMessageCancelAction()))
async def mark_read(self, user: 'u.User', event_id: str) -> None:
async def mark_read(self, user: 'u.User', event_id: MatrixEventId) -> None:
if user.is_bot:
return
space = self.tgid if self.peer_type == "channel" else user.tgid
@@ -743,7 +748,8 @@ class Portal:
else:
await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid))
async def leave_matrix(self, user: 'u.User', source: 'u.User', event_id: str) -> None:
async def leave_matrix(self, user: 'u.User', source: 'u.User', event_id: MatrixEventId
) -> None:
if await user.needs_relaybot(self):
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message("leave", user)
@@ -798,7 +804,7 @@ class Portal:
# We'll just assume the user is already in the chat.
pass
async def _apply_msg_format(self, sender: 'u.User', msgtype: str, message: dict) -> None:
async def _apply_msg_format(self, sender: 'u.User', msgtype: str, message: Dict) -> None:
if "formatted_body" not in message:
message["format"] = "org.matrix.custom.html"
message["formatted_body"] = escape_html(message.get("body", ""))
@@ -823,7 +829,7 @@ class Portal:
await self._apply_msg_format(sender, msgtype, message)
@staticmethod
def _matrix_event_to_entities(event: dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]:
def _matrix_event_to_entities(event: Dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]:
try:
if event.get("format", None) == "org.matrix.custom.html":
message, entities = formatter.matrix_to_telegram(event.get("formatted_body", ""))
@@ -851,7 +857,8 @@ class Portal:
return None
async def _handle_matrix_text(self, sender_id: int, event_id: str, space: int,
client: "MautrixTelegramClient", message: dict, reply_to: int) -> None:
client: 'MautrixTelegramClient', message: Dict, reply_to: int
) -> None:
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_message(self.peer, message, reply_to=reply_to,
@@ -859,7 +866,8 @@ class Portal:
self._add_telegram_message_to_db(event_id, space, response)
async def _handle_matrix_file(self, msgtype: str, sender_id: int, event_id: str, space: int,
client: "MautrixTelegramClient", message: dict, reply_to: int) -> None:
client: 'MautrixTelegramClient', message: dict, reply_to: int
) -> None:
file = await self.main_intent.download_file(message["url"])
info = message.get("info", {})
@@ -893,7 +901,7 @@ class Portal:
self._add_telegram_message_to_db(event_id, space, response)
async def _handle_matrix_location(self, sender_id: int, event_id: str, space: int,
client: "MautrixTelegramClient", message: dict,
client: 'MautrixTelegramClient', message: Dict,
reply_to: int) -> None:
try:
lat, long = message["geo_uri"][len("geo:"):].split(",")
@@ -901,13 +909,13 @@ class Portal:
except (KeyError, ValueError):
self.log.exception("Failed to parse location")
return None
message, entities = self._matrix_event_to_entities(message)
caption, entities = self._matrix_event_to_entities(message)
media = MessageMediaGeo(geo=GeoPoint(lat, long, access_hash=0))
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_media(self.peer, media, reply_to=reply_to,
caption=message, entities=entities)
caption=caption, entities=entities)
self._add_telegram_message_to_db(event_id, space, response)
def _add_telegram_message_to_db(self, event_id: str, space: int,
@@ -963,17 +971,18 @@ class Portal:
except ChatNotModifiedError:
pass
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: str) -> None:
deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
space = self.tgid if self.peer_type == "channel" else deleter.tgid
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventId) -> None:
real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
space = self.tgid if self.peer_type == "channel" else real_deleter.tgid
message = DBMessage.query.filter(DBMessage.mxid == event_id,
DBMessage.tg_space == space,
DBMessage.mx_room == self.mxid).one_or_none()
if not message:
return
await deleter.client.delete_messages(self.peer, [message.tgid])
await real_deleter.client.delete_messages(self.peer, [message.tgid])
async def _update_telegram_power_level(self, sender: 'u.User', user_id: int, level: int) -> None:
async def _update_telegram_power_level(self, sender: 'u.User', user_id: TelegramId,
level: int) -> None:
if self.peer_type == "chat":
await sender.client(EditChatAdminRequest(
chat_id=self.tgid, user_id=user_id, is_admin=level >= 50))
@@ -989,7 +998,8 @@ class Portal:
EditAdminRequest(channel=await self.get_input_entity(sender),
user_id=user_id, admin_rights=rights))
async def handle_matrix_power_levels(self, sender: 'u.User', new_users: Dict[str, int],
async def handle_matrix_power_levels(self, sender: 'u.User',
new_users: Dict[MatrixUserId, int],
old_users: Dict[str, int]) -> None:
# TODO handle all power level changes and bridge exact admin rights to supergroups/channels
for user, level in new_users.items():
@@ -1167,7 +1177,7 @@ class Portal:
return None
async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message,
relates_to=None) -> None:
relates_to: Dict = {}) -> None:
largest_size = self._get_largest_photo_size(evt.media.photo)
file = await util.transfer_file_to_matrix(self.db, source.client, intent,
largest_size.location)
@@ -1197,7 +1207,7 @@ class Portal:
external_url=self.get_external_url(evt))
@staticmethod
def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> dict:
def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> Dict:
attrs = {
"name": None,
"mime_type": None,
@@ -1205,7 +1215,7 @@ class Portal:
"sticker_alt": None,
"width": None,
"height": None,
}
} # type: Dict
for attr in attributes:
if isinstance(attr, DocumentAttributeFilename):
attrs["name"] = attrs["name"] or attr.file_name
@@ -1218,8 +1228,8 @@ class Portal:
return attrs
@staticmethod
def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: dict
) -> Tuple[dict, str]:
def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: Dict
) -> Tuple[Dict, str]:
document = evt.media.document
name = evt.message or attrs["name"]
if attrs["is_sticker"]:
@@ -1253,7 +1263,7 @@ class Portal:
async def handle_telegram_document(self, source: "AbstractUser", intent: IntentAPI,
evt: Message,
relates_to: dict = None) -> Optional[dict]:
relates_to: dict = None) -> Optional[Dict]:
document = evt.media.document
attrs = self._parse_telegram_document_attributes(document.attributes)
@@ -1521,9 +1531,9 @@ class Portal:
else:
self.log.debug("Unhandled Telegram action in %s: %s", self.title, action)
async def set_telegram_admin(self, user_id: int) -> None:
async def set_telegram_admin(self, user_id: TelegramId) -> None:
puppet = p.Puppet.get(user_id)
user = await u.User.get_by_tgid(user_id)
user = u.User.get_by_tgid(user_id)
levels = await self.main_intent.get_power_levels(self.mxid)
if user:
@@ -1558,7 +1568,7 @@ class Portal:
await self.update_telegram_pin()
@staticmethod
def _get_level_from_participant(participant: TypeParticipant, _) -> int:
def _get_level_from_participant(participant: TypeParticipant, _: Dict) -> int:
# TODO use the power level requirements to get better precision in channels
if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)):
return 50
@@ -1599,7 +1609,7 @@ class Portal:
except KeyError:
return 50
def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: dict
def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: Dict
) -> bool:
bot_level = self._get_bot_level(levels)
if bot_level < self._get_powerlevel_level(levels):
@@ -1654,7 +1664,7 @@ class Portal:
mxid=self.mxid, username=self.username, megagroup=self.megagroup,
title=self.title, about=self.about, photo_id=self.photo_id)
def migrate_and_save(self, new_id: int) -> None:
def migrate_and_save(self, new_id: TelegramId) -> None:
existing = DBPortal.query.get(self.tgid_full)
if existing:
self.db.delete(existing)
@@ -1701,7 +1711,7 @@ class Portal:
# region Class instance lookup
@classmethod
def get_by_mxid(cls, mxid: str) -> Optional["Portal"]:
def get_by_mxid(cls, mxid: MatrixRoomId) -> Optional['Portal']:
try:
return cls.by_mxid[mxid]
except KeyError:
@@ -1721,7 +1731,7 @@ class Portal:
return None
@classmethod
def find_by_username(cls, username: str) -> Optional["Portal"]:
def find_by_username(cls, username: str) -> Optional['Portal']:
if not username:
return None
@@ -1729,15 +1739,15 @@ class Portal:
if portal.username and portal.username.lower() == username.lower():
return portal
portal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
if portal:
return cls.from_db(portal)
dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
if dbportal:
return cls.from_db(dbportal)
return None
@classmethod
def get_by_tgid(cls, tgid: int, tg_receiver: int = None, peer_type: str = None
) -> Optional["Portal"]:
def get_by_tgid(cls, tgid: TelegramId, tg_receiver: Optional[TelegramId] = None,
peer_type: str = None) -> Optional['Portal']:
tg_receiver = tg_receiver or tgid
tgid_full = (tgid, tg_receiver)
try:
@@ -1758,8 +1768,10 @@ class Portal:
return None
@classmethod
def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull, TypeInputPeer],
receiver_id: int = None, create: bool = True) -> Optional["Portal"]:
def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull,
TypeInputPeer],
receiver_id: Optional[TelegramId] = None, create: bool = True
) -> Optional['Portal']:
entity_type = type(entity)
if entity_type in {Chat, ChatFull}:
type_name = "chat"
@@ -1790,7 +1802,7 @@ class Portal:
def init(context: Context) -> None:
global config
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core
Portal.bridge_notices = config["bridge.bridge_notices"]
Portal.filter_mode = config["bridge.filter.mode"]
Portal.filter_list = config["bridge.filter.list"]
+106 -80
View File
@@ -14,17 +14,19 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
from typing import Awaitable, Coroutine, Dict, List, NewType, Optional, Pattern, TYPE_CHECKING
from difflib import SequenceMatcher
import re
import logging
import asyncio
from enum import Enum
from sqlalchemy import orm
from telethon.tl.types import UserProfilePhoto
from telethon.tl.types import UserProfilePhoto, User, FileLocation
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
from .types import MatrixUserId, TelegramId
from .db import Puppet as DBPuppet
from . import util
@@ -32,6 +34,11 @@ if TYPE_CHECKING:
from .matrix import MatrixHandler
from .config import Config
from .context import Context
from . import user as u
from .abstract_user import AbstractUser
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
config = None # type: Config
@@ -45,85 +52,98 @@ class Puppet:
mxid_regex = None # type: Pattern
username_template = None # type: str
hs_domain = None # type: str
cache = {} # type: Dict[str, Puppet]
cache = {} # type: Dict[TelegramId, Puppet]
by_custom_mxid = {} # type: Dict[str, Puppet]
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
is_registered=False, db_instance=None) -> None:
self.id = id
self.access_token = access_token
self.custom_mxid = custom_mxid
self.is_real_user = self.custom_mxid and self.access_token
self.default_mxid = self.get_mxid_from_id(self.id)
self.mxid = self.custom_mxid or self.default_mxid
def __init__(self,
id: TelegramId,
access_token: Optional[str] = None,
custom_mxid: Optional[MatrixUserId] = None,
username: Optional[str] = None,
displayname: Optional[str] = None,
displayname_source: Optional[TelegramId] = None,
photo_id: Optional[str] = None,
is_bot: bool = False,
is_registered: bool = False,
db_instance: Optional[DBPuppet] = None) -> None:
self.id = id # type: TelegramId
self.access_token = access_token # type: Optional[str]
self.custom_mxid = custom_mxid # type: Optional[MatrixUserId]
self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserId
self.username = username
self.displayname = displayname
self.displayname_source = displayname_source
self.photo_id = photo_id
self.is_bot = is_bot
self.is_registered = is_registered
self._db_instance = db_instance
self.username = username # type: Optional[str]
self.displayname = displayname # type: Optional[str]
self.displayname_source = displayname_source # type: Optional[TelegramId]
self.photo_id = photo_id # type: Optional[str]
self.is_bot = is_bot # type: bool
self.is_registered = is_registered # type: bool
self._db_instance = db_instance # type: Optional[DBPuppet]
self.default_mxid_intent = self.az.intent.user(self.default_mxid)
self.intent = None # type: IntentAPI
self.refresh_intents()
self.intent = self._fresh_intent() # type: IntentAPI
self.cache[id] = self
if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self
@property
def tgid(self) -> None:
def mxid(self):
return self.custom_mxid or self.default_mxid
@property
def tgid(self) -> TelegramId:
return self.id
@property
def is_real_user(self) -> bool:
""" Is True when the puppet is a real Matrix user. """
return bool(self.custom_mxid and self.access_token)
@staticmethod
async def is_logged_in() -> None:
async def is_logged_in() -> bool:
""" Is True if the puppet is logged in. """
return True
# region Custom puppet management
def refresh_intents(self) -> None:
self.is_real_user = self.custom_mxid and self.access_token
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
def _fresh_intent(self) -> IntentAPI:
return (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
async def switch_mxid(self, access_token, mxid) -> None:
async def switch_mxid(self, access_token: str, mxid: MatrixUserId) -> PuppetError:
prev_mxid = self.custom_mxid
self.custom_mxid = mxid
self.access_token = access_token
self.refresh_intents()
self.intent = self._fresh_intent()
err = await self.init_custom_mxid()
if err != 0:
if err != PuppetError.Success:
return err
try:
del self.by_custom_mxid[prev_mxid]
del self.by_custom_mxid[prev_mxid] # type: ignore
except KeyError:
pass
self.mxid = self.custom_mxid or self.default_mxid
if self.mxid != self.default_mxid:
self.by_custom_mxid[self.mxid] = self
await self.leave_rooms_with_default_user()
self.save()
return 0
return PuppetError.Success
async def init_custom_mxid(self) -> None:
async def init_custom_mxid(self) -> PuppetError:
if not self.is_real_user:
return 0
return PuppetError.Success
mxid = await self.intent.whoami()
if not mxid or mxid != self.custom_mxid:
self.custom_mxid = None
self.access_token = None
self.refresh_intents()
self.intent = self._fresh_intent()
if mxid != self.custom_mxid:
return 2
return 1
return PuppetError.OnlyLoginSelf
return PuppetError.InvalidAccessToken
if config["bridge.sync_with_custom_puppets"]:
asyncio.ensure_future(self.sync(), loop=self.loop)
return 0
return PuppetError.Success
async def leave_rooms_with_default_user(self) -> None:
for room_id in await self.default_mxid_intent.get_joined_rooms():
@@ -159,7 +179,7 @@ class Puppet:
},
})
def filter_events(self, events) -> None:
def filter_events(self, events: List[Dict]) -> List:
new_events = []
for event in events:
evt_type = event.get("type", None)
@@ -186,18 +206,18 @@ class Puppet:
new_events.append(event)
return new_events
def handle_sync(self, presence, ephemeral) -> None:
presence = [self.mx.try_handle_event(event) for event in presence]
def handle_sync(self, presence: List, ephemeral: Dict) -> None:
presence_events = [self.mx.try_handle_event(event) for event in presence]
for room_id, events in ephemeral.items():
for event in events:
event["room_id"] = room_id
ephemeral = [self.mx.try_handle_event(event)
for events in ephemeral.values()
for event in self.filter_events(events)]
ephemeral_events = [self.mx.try_handle_event(event)
for events in ephemeral.values()
for event in self.filter_events(events)]
events = ephemeral + presence
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
coro = asyncio.gather(*events, loop=self.loop)
asyncio.ensure_future(coro, loop=self.loop)
@@ -220,13 +240,14 @@ class Puppet:
while access_token_at_start == self.access_token:
try:
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch,
set_presence="offline")
set_presence="offline") # type: Dict
errors = 0
if next_batch is not None:
presence = sync_resp.get("presence", {}).get("events", [])
presence = sync_resp.get("presence", {}).get("events", []) # type: List
ephemeral = {room: data.get("ephemeral", {}).get("events", [])
for room, data
in sync_resp.get("rooms", {}).get("join", {}).items()}
in sync_resp.get("rooms", {}).get("join", {}).items()
} # type: Dict
self.handle_sync(presence, ephemeral)
next_batch = sync_resp.get("next_batch", None)
except MatrixRequestError as e:
@@ -241,19 +262,19 @@ class Puppet:
# region DB conversion
@property
def db_instance(self) -> None:
def db_instance(self) -> DBPuppet:
if not self._db_instance:
self._db_instance = self.new_db_instance()
return self._db_instance
def new_db_instance(self) -> None:
def new_db_instance(self) -> DBPuppet:
return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid,
username=self.username, displayname=self.displayname,
displayname_source=self.displayname_source, photo_id=self.photo_id,
is_bot=self.is_bot, matrix_registered=self.is_registered)
@classmethod
def from_db(cls, db_puppet) -> None:
def from_db(cls, db_puppet: DBPuppet) -> 'Puppet':
return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid,
db_puppet.username, db_puppet.displayname, db_puppet.displayname_source,
db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered,
@@ -272,16 +293,16 @@ class Puppet:
# endregion
# region Info updating
def similarity(self, query) -> None:
def similarity(self, query: str) -> int:
username_similarity = (SequenceMatcher(None, self.username, query).ratio()
if self.username else 0)
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
if self.displayname else 0)
similarity = max(username_similarity, displayname_similarity)
return round(similarity * 1000) / 10
return int(round(similarity * 1000) / 10)
@staticmethod
def get_displayname(info, enable_format=True) -> None:
def get_displayname(info: User, enable_format: bool = True) -> str:
data = {
"phone number": info.phone if hasattr(info, "phone") else None,
"username": info.username,
@@ -308,7 +329,7 @@ class Puppet:
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
displayname=name)
async def update_info(self, source, info) -> None:
async def update_info(self, source: 'AbstractUser', info: User) -> None:
changed = False
if self.username != info.username:
self.username = info.username
@@ -323,24 +344,26 @@ class Puppet:
if changed:
self.save()
async def update_displayname(self, source, info) -> None:
async def update_displayname(self, source: 'AbstractUser', info: User) -> bool:
ignore_source = (not source.is_relaybot
and self.displayname_source is not None
and self.displayname_source != source.tgid)
if ignore_source:
return
return False
displayname = self.get_displayname(info)
if displayname != self.displayname:
await self.default_mxid_intent.set_display_name(displayname)
self.displayname = displayname
self.displayname_source = source.tgid
self.displayname_source = TelegramId(source.tgid)
return True
elif source.is_relaybot or self.displayname_source is None:
self.displayname_source = source.tgid
self.displayname_source = TelegramId(source.tgid)
return True
else:
return False
async def update_avatar(self, source, photo) -> None:
async def update_avatar(self, source: 'AbstractUser', photo: FileLocation) -> bool:
photo_id = f"{photo.volume_id}-{photo.local_id}"
if self.photo_id != photo_id:
file = await util.transfer_file_to_matrix(self.db, source.client,
@@ -355,7 +378,7 @@ class Puppet:
# region Getters
@classmethod
def get(cls, tgid, create=True) -> "Optional[Puppet]":
def get(cls, tgid: TelegramId, create: bool = True) -> Optional['Puppet']:
try:
return cls.cache[tgid]
except KeyError:
@@ -374,12 +397,15 @@ class Puppet:
return None
@classmethod
def get_by_mxid(cls, mxid, create=True) -> "Optional[Puppet]":
def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['Puppet']:
tgid = cls.get_id_from_mxid(mxid)
return cls.get(tgid, create) if tgid else None
if tgid:
return cls.get(tgid, create)
return None
@classmethod
def get_by_custom_mxid(cls, mxid) -> None:
def get_by_custom_mxid(cls, mxid: MatrixUserId) -> Optional['Puppet']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -396,25 +422,25 @@ class Puppet:
return None
@classmethod
def get_all_with_custom_mxid(cls) -> None:
def get_all_with_custom_mxid(cls) -> List['Puppet']:
return [cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
@classmethod
def get_id_from_mxid(cls, mxid) -> None:
def get_id_from_mxid(cls, mxid: MatrixUserId) -> Optional[TelegramId]:
match = cls.mxid_regex.match(mxid)
if match:
return int(match.group(1))
return TelegramId(int(match.group(1)))
return None
@classmethod
def get_mxid_from_id(cls, tgid) -> None:
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
def get_mxid_from_id(cls, tgid: TelegramId) -> MatrixUserId:
return MatrixUserId(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
@classmethod
def find_by_username(cls, username) -> "Optional[Puppet]":
def find_by_username(cls, username: str) -> Optional['Puppet']:
if not username:
return None
@@ -422,14 +448,14 @@ class Puppet:
if puppet.username and puppet.username.lower() == username.lower():
return puppet
puppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
if puppet:
return cls.from_db(puppet)
dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
if dbpuppet:
return cls.from_db(dbpuppet)
return None
@classmethod
def find_by_displayname(cls, displayname) -> "Optional[Puppet]":
def find_by_displayname(cls, displayname: str) -> Optional['Puppet']:
if not displayname:
return None
@@ -437,17 +463,17 @@ class Puppet:
if puppet.displayname and puppet.displayname == displayname:
return puppet
puppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
if puppet:
return cls.from_db(puppet)
dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
if dbpuppet:
return cls.from_db(dbpuppet)
return None
# endregion
def init(context: "Context") -> List[Awaitable[int]]:
def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"]
@@ -40,7 +40,7 @@ telematrix_db_engine.dispose()
portals = {}
chats = {}
messages = {}
puppets = {}
puppets = {} # Dict[int, Puppet]
for chat_link in chat_links:
if type(chat_link.tg_room) is str:
+15 -13
View File
@@ -20,37 +20,39 @@ from sqlalchemy import orm
from mautrix_appservice import StateStore
from .types import MatrixUserId, MatrixRoomId
from . import puppet as pu
from .db import RoomState, UserProfile
class SQLStateStore(StateStore):
def __init__(self, db) -> None:
def __init__(self, db: orm.Session) -> None:
super().__init__()
self.db = db # type: orm.Session
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
self.room_state_cache = {} # type: Dict[str, RoomState]
@staticmethod
def is_registered(user: str) -> bool:
def is_registered(user: MatrixUserId) -> bool:
puppet = pu.Puppet.get_by_mxid(user)
return puppet.is_registered if puppet else False
@staticmethod
def registered(user: str) -> None:
def registered(user: MatrixUserId) -> None:
puppet = pu.Puppet.get_by_mxid(user)
if puppet:
puppet.is_registered = True
puppet.save()
def update_state(self, event: dict) -> None:
def update_state(self, event: Dict) -> None:
event_type = event["type"]
if event_type == "m.room.power_levels":
self.set_power_levels(event["room_id"], event["content"])
elif event_type == "m.room.member":
self.set_member(event["room_id"], event["state_key"], event["content"])
def _get_user_profile(self, room_id: str, user_id: str, create: bool = True) -> UserProfile:
def _get_user_profile(self, room_id: MatrixRoomId, user_id: MatrixUserId, create: bool = True
) -> UserProfile:
key = (room_id, user_id)
try:
return self.profile_cache[key]
@@ -67,22 +69,22 @@ class SQLStateStore(StateStore):
self.profile_cache[key] = profile
return profile
def get_member(self, room: str, user: str) -> dict:
def get_member(self, room: MatrixRoomId, user: MatrixUserId) -> Dict:
return self._get_user_profile(room, user).dict()
def set_member(self, room: str, user: str, member: dict) -> None:
def set_member(self, room: MatrixRoomId, user: MatrixUserId, member: Dict) -> None:
profile = self._get_user_profile(room, user)
profile.membership = member.get("membership", profile.membership or "leave")
profile.displayname = member.get("displayname", profile.displayname)
profile.avatar_url = member.get("avatar_url", profile.avatar_url)
self.db.commit()
def set_membership(self, room: str, user: str, membership: str) -> None:
def set_membership(self, room: MatrixRoomId, user: MatrixUserId, membership: str) -> None:
self.set_member(room, user, {
"membership": membership,
})
def _get_room_state(self, room_id: str, create: bool = True) -> RoomState:
def _get_room_state(self, room_id: MatrixRoomId, create: bool = True) -> RoomState:
try:
return self.room_state_cache[room_id]
except KeyError:
@@ -96,13 +98,13 @@ class SQLStateStore(StateStore):
self.room_state_cache[room_id] = room
return room
def has_power_levels(self, room: str) -> bool:
def has_power_levels(self, room: MatrixRoomId) -> bool:
return self._get_room_state(room).has_power_levels
def get_power_levels(self, room: str) -> dict:
def get_power_levels(self, room: MatrixRoomId) -> Dict:
return self._get_room_state(room).power_levels
def set_power_level(self, room: str, user: str, level: int) -> None:
def set_power_level(self, room: MatrixRoomId, user: MatrixUserId, level: int) -> None:
room_state = self._get_room_state(room)
power_levels = room_state.power_levels
if not power_levels:
@@ -114,7 +116,7 @@ class SQLStateStore(StateStore):
room_state.power_levels = power_levels
self.db.commit()
def set_power_levels(self, room: str, content: dict) -> None:
def set_power_levels(self, room: MatrixRoomId, content: Dict) -> None:
state = self._get_room_state(room)
state.power_levels = content
self.db.commit()
+10
View File
@@ -0,0 +1,10 @@
from typing import Dict, NewType
# MatrixId = NewType('MatrixId', str)
MatrixUserId = NewType('MatrixUserId', str)
MatrixRoomId = NewType('MatrixRoomId', str)
MatrixEventId = NewType('MatrixEventId', str)
MatrixEvent = NewType('MatrixEvent', Dict)
TelegramId = NewType('TelegramId', int)
+26 -22
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Match, Optional, Tuple, TYPE_CHECKING
from typing import Coroutine, Dict, List, Match, Optional, Tuple, cast, TYPE_CHECKING
import logging
import asyncio
import re
@@ -28,6 +28,7 @@ from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError
from .types import MatrixUserId, TelegramId
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
from .abstract_user import AbstractUser
from . import portal as po, puppet as pu
@@ -46,23 +47,23 @@ class User(AbstractUser):
by_mxid = {} # type: Dict[str, User]
by_tgid = {} # type: Dict[int, User]
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
def __init__(self, mxid: MatrixUserId, tgid: Optional[TelegramId] = None,
username: Optional[str] = None, db_contacts: Optional[List[DBContact]] = None,
saved_contacts: int = 0, is_bot: bool = False, db_portals: List[DBPortal] = [],
db_instance: Optional[DBUser] = None) -> None:
super().__init__()
self.mxid = mxid # type: str
self.tgid = tgid # type: int
self.mxid = mxid # type: MatrixUserId
self.tgid = tgid # type: TelegramId
self.is_bot = is_bot # type: bool
self.username = username # type: str
self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
self.db_portals = db_portals # type: List[DBPortal]
self._db_instance = db_instance # type: DBUser
self.db_portals = db_portals or [] # type: List[DBPortal]
self._db_instance = db_instance # type: Optional[DBUser]
self.command_status = None # type: dict
self.command_status = None # type: Dict
(self.relaybot_whitelisted,
self.whitelisted,
@@ -169,9 +170,9 @@ class User(AbstractUser):
except Exception:
self.log.exception("Failed to run post-login functions for %s", self.mxid)
async def update(self, update: TypeUpdate) -> None:
async def update(self, update: TypeUpdate) -> bool:
if not self.is_bot:
return
return False
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
message = update.message
@@ -185,19 +186,22 @@ class User(AbstractUser):
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else:
return
return False
self.register_portal(portal)
if portal:
self.register_portal(portal)
return True
# endregion
# region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
return super().ensure_started(even_if_no_session)
def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']:
return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session))
def set_presence(self, online: bool = True) -> None:
def set_presence(self, online: bool = True) -> bool:
if self.is_bot:
return
return False
return self.client(UpdateStatusRequest(offline=not online))
async def update_info(self, info: TLUser = None) -> None:
@@ -215,7 +219,7 @@ class User(AbstractUser):
if changed:
self.save()
async def log_out(self) -> None:
async def log_out(self) -> bool:
puppet = pu.Puppet.get(self.tgid)
if puppet.is_real_user:
await puppet.switch_mxid(None, None)
@@ -328,7 +332,7 @@ class User(AbstractUser):
# region Class instance lookup
@classmethod
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['User']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -351,7 +355,7 @@ class User(AbstractUser):
return None
@classmethod
def get_by_tgid(cls, tgid: int) -> "Optional[User]":
def get_by_tgid(cls, tgid: int) -> Optional['User']:
try:
return cls.by_tgid[tgid]
except KeyError:
@@ -365,7 +369,7 @@ class User(AbstractUser):
return None
@classmethod
def find_by_username(cls, username: str) -> "Optional[User]":
def find_by_username(cls, username: str) -> Optional['User']:
if not username:
return None
@@ -381,7 +385,7 @@ class User(AbstractUser):
# endregion
def init(context: "Context") -> List[Awaitable[User]]:
def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser]
global config
config = context.config
+2 -2
View File
@@ -17,10 +17,10 @@
def format_duration(seconds: int) -> str:
def pluralize(count, singular) -> None:
def pluralize(count: int, singular: str) -> str:
return singular if count == 1 else singular + "s"
def include(count, word) -> None:
def include(count: int, word: str) -> str:
return f"{count} {pluralize(count, word)}" if count > 0 else ""
minutes, seconds = divmod(seconds, 60)
+6 -6
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import Dict, Optional
import json
import base64
import hashlib
@@ -28,13 +28,13 @@ def _get_checksum(key: str, payload: bytes) -> str:
return checksum
def sign_token(key: str, payload: dict) -> str:
payload = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
checksum = _get_checksum(key, payload)
return f"{checksum}:{payload.decode('utf-8')}"
def sign_token(key: str, payload: Dict) -> str:
payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
checksum = _get_checksum(key, payload_b64)
return f"{checksum}:{payload_b64.decode('utf-8')}"
def verify_token(key: str, data: str) -> Optional[dict]:
def verify_token(key: str, data: str) -> Optional[Dict]:
if not data:
return None
+4 -3
View File
@@ -23,7 +23,7 @@ from telethon.errors import *
from ...commands.auth import enter_password
from ...util import format_duration
from ...puppet import Puppet
from ...puppet import Puppet, PuppetError
from ...user import User
@@ -51,12 +51,13 @@ class AuthAPI(abc.ABC):
"account.", errcode="already-logged-in")
resp = await puppet.switch_mxid(token, user.mxid)
if resp == 2:
if resp == PuppetError.OnlyLoginSelf:
return self.get_mx_login_response(status=403, errcode="only-login-self",
error="You can only log in as your own Matrix user.")
elif resp == 1:
elif resp == PuppetError.InvalidAccessToken:
return self.get_mx_login_response(status=401, errcode="invalid-access-token",
error="Failed to verify access token.")
assert resp == PuppetError.Success, "Encountered an unhandled PuppetError."
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
@@ -15,7 +15,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from typing import Tuple, Optional, Callable, Awaitable, TYPE_CHECKING
from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING
import asyncio
import logging
import json
@@ -24,6 +24,7 @@ from telethon.utils import get_peer_id, resolve_id
from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat
from mautrix_appservice import AppService, MatrixRequestError, IntentError
from ...types import MatrixUserId
from ...user import User
from ...portal import Portal
from ...commands.portal import user_has_power_level, get_initial_state
@@ -36,7 +37,7 @@ if TYPE_CHECKING:
class ProvisioningAPI(AuthAPI):
log = logging.getLogger("mau.web.provisioning")
def __init__(self, context: "Context"):
def __init__(self, context: "Context") -> None:
super().__init__(context.loop)
self.secret = context.config["appservice.provisioning.shared_secret"]
self.az = context.az # type: AppService
@@ -411,7 +412,7 @@ class ProvisioningAPI(AuthAPI):
except json.JSONDecodeError:
return None
async def get_user(self, mxid: str, expect_logged_in: Optional[bool] = False,
async def get_user(self, mxid: MatrixUserId, expect_logged_in: Optional[bool] = False,
require_puppeting: bool = True, require_user: bool = True
) -> Tuple[Optional[User], Optional[web.Response]]:
if not mxid:
@@ -439,7 +440,7 @@ class ProvisioningAPI(AuthAPI):
expect_logged_in: Optional[bool] = False,
require_puppeting: bool = False,
want_data: bool = True,
) -> (Tuple[Optional[dict],
) -> (Tuple[Optional[Dict],
Optional[User],
Optional[web.Response]]):
err = self.check_authorization(request)