Add missing type hints and fix most type errors except for Optionals.
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional
|
||||
from typing import Coroutine, List, Optional
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging.config
|
||||
@@ -115,7 +115,7 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
|
||||
startup_actions = (init_puppet(context) +
|
||||
init_user(context) +
|
||||
[start,
|
||||
context.mx.init_as_bot()])
|
||||
context.mx.init_as_bot()]) # type: List[Coroutine]
|
||||
|
||||
if context.bot:
|
||||
startup_actions.append(context.bot.start())
|
||||
|
||||
@@ -38,6 +38,7 @@ from .db import Message as DBMessage
|
||||
from .tgclient import MautrixTelegramClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types import TelegramId
|
||||
from .context import Context
|
||||
from .config import Config
|
||||
from .bot import Bot
|
||||
@@ -67,10 +68,11 @@ class AbstractUser(ABC):
|
||||
self.whitelisted = False # type: bool
|
||||
self.relaybot_whitelisted = False # type: bool
|
||||
self.client = None # type: MautrixTelegramClient
|
||||
self.tgid = None # type: int
|
||||
self.tgid = None # type: TelegramId
|
||||
self.mxid = None # type: str
|
||||
self.is_relaybot = False # type: bool
|
||||
self.is_bot = False # type: bool
|
||||
self.relaybot = None # type: Optional[Bot]
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
@@ -372,7 +374,7 @@ class AbstractUser(ABC):
|
||||
|
||||
def init(context: "Context") -> None:
|
||||
global config, MAX_DELETIONS
|
||||
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context
|
||||
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core
|
||||
AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"]
|
||||
AbstractUser.session_container = context.session_container
|
||||
MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10)
|
||||
|
||||
+17
-12
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Callable, Dict, Optional, Pattern, TYPE_CHECKING
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Pattern, TYPE_CHECKING
|
||||
import logging
|
||||
import re
|
||||
|
||||
@@ -27,12 +27,14 @@ from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
|
||||
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
|
||||
from telethon.errors import ChannelInvalidError, ChannelPrivateError
|
||||
|
||||
from .types import MatrixUserId
|
||||
from .abstract_user import AbstractUser
|
||||
from .db import BotChat
|
||||
from . import puppet as pu, portal as po, user as u
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import Config
|
||||
from .context import Context
|
||||
|
||||
config = None # type: Config
|
||||
|
||||
@@ -145,6 +147,7 @@ class Bot(AbstractUser):
|
||||
for p in participants:
|
||||
if p.user_id == tgid:
|
||||
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
|
||||
return False
|
||||
|
||||
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
|
||||
if not await self._can_use_commands(event.to_id, event.from_id):
|
||||
@@ -168,15 +171,16 @@ class Bot(AbstractUser):
|
||||
return await reply(
|
||||
"Portal is not public. Use `/invite <mxid>` to get an invite.")
|
||||
|
||||
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, mxid: str) -> None:
|
||||
if len(mxid) == 0:
|
||||
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc,
|
||||
mxid_input: MatrixUserId) -> Message:
|
||||
if len(mxid_input) == 0:
|
||||
return await reply("Usage: `/invite <mxid>`")
|
||||
elif not portal.mxid:
|
||||
return await reply("Portal does not have Matrix room. "
|
||||
"Create one with /portal first.")
|
||||
if not self.mxid_regex.match(mxid):
|
||||
if not self.mxid_regex.match(mxid_input):
|
||||
return await reply("That doesn't look like a Matrix ID.")
|
||||
user = await u.User.get_by_mxid(mxid).ensure_started()
|
||||
user = await u.User.get_by_mxid(MatrixUserId(mxid_input)).ensure_started()
|
||||
if not user.relaybot_whitelisted:
|
||||
return await reply("That user is not whitelisted to use the bridge.")
|
||||
elif await user.is_logged_in():
|
||||
@@ -187,7 +191,7 @@ class Bot(AbstractUser):
|
||||
await portal.main_intent.invite(portal.mxid, user.mxid)
|
||||
return await reply(f"Invited `{user.mxid}` to the portal.")
|
||||
|
||||
def handle_command_id(self, message: Message, reply: ReplyFunc) -> None:
|
||||
def handle_command_id(self, message: Message, reply: ReplyFunc) -> Awaitable[Message]:
|
||||
# Provide the prefixed ID to the user so that the user wouldn't need to specify whether the
|
||||
# chat is a normal group or a supergroup/channel when using the ID.
|
||||
if isinstance(message.to_id, PeerChannel):
|
||||
@@ -210,7 +214,7 @@ class Bot(AbstractUser):
|
||||
return False
|
||||
|
||||
async def handle_command(self, message: Message) -> None:
|
||||
def reply(reply_text) -> None:
|
||||
def reply(reply_text: str) -> Awaitable[Message]:
|
||||
return self.client.send_message(message.to_id, reply_text, reply_to=message.id)
|
||||
|
||||
text = message.message
|
||||
@@ -231,7 +235,7 @@ class Bot(AbstractUser):
|
||||
mxid = text[text.index(" ") + 1:]
|
||||
except ValueError:
|
||||
mxid = ""
|
||||
await self.handle_command_invite(portal, reply, mxid=mxid)
|
||||
await self.handle_command_invite(portal, reply, mxid_input=mxid)
|
||||
|
||||
def handle_service_message(self, message: MessageService) -> None:
|
||||
to_id = message.to_id
|
||||
@@ -250,11 +254,12 @@ class Bot(AbstractUser):
|
||||
elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid:
|
||||
self.remove_chat(to_id)
|
||||
|
||||
async def update(self, update) -> None:
|
||||
async def update(self, update) -> bool:
|
||||
if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
|
||||
return
|
||||
return False
|
||||
if isinstance(update.message, MessageService):
|
||||
return self.handle_service_message(update.message)
|
||||
self.handle_service_message(update.message)
|
||||
return False
|
||||
|
||||
is_command = (isinstance(update.message, Message)
|
||||
and update.message.entities and len(update.message.entities) > 0
|
||||
@@ -270,7 +275,7 @@ class Bot(AbstractUser):
|
||||
return "bot"
|
||||
|
||||
|
||||
def init(context) -> Optional[Bot]:
|
||||
def init(context: 'Context') -> Optional[Bot]:
|
||||
global config
|
||||
config = context.config
|
||||
token = config["telegram.bot_token"]
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict
|
||||
from typing import Any, Awaitable, Dict, Optional
|
||||
import asyncio
|
||||
|
||||
from telethon.errors import (
|
||||
@@ -31,7 +31,7 @@ from ..util import format_duration
|
||||
@command_handler(needs_auth=False,
|
||||
help_section=SECTION_AUTH,
|
||||
help_text="Check if you're logged into Telegram.")
|
||||
async def ping(evt: CommandEvent) -> None:
|
||||
async def ping(evt: CommandEvent) -> Optional[Dict]:
|
||||
me = await evt.sender.client.get_me() if await evt.sender.is_logged_in() else None
|
||||
if me:
|
||||
return await evt.reply(f"You're logged in as @{me.username}")
|
||||
@@ -42,7 +42,7 @@ async def ping(evt: CommandEvent) -> None:
|
||||
@command_handler(needs_auth=False, needs_puppeting=False,
|
||||
help_section=SECTION_AUTH,
|
||||
help_text="Get the info of the message relay Telegram bot.")
|
||||
async def ping_bot(evt: CommandEvent) -> None:
|
||||
async def ping_bot(evt: CommandEvent) -> Optional[Dict]:
|
||||
if not evt.tgbot:
|
||||
return await evt.reply("Telegram message relay bot not configured.")
|
||||
bot_info = await evt.tgbot.client.get_me()
|
||||
@@ -57,19 +57,19 @@ async def ping_bot(evt: CommandEvent) -> None:
|
||||
help_section=SECTION_AUTH,
|
||||
help_text="Revert your Telegram account's Matrix puppet to use the default Matrix "
|
||||
"account.")
|
||||
async def logout_matrix(evt: CommandEvent) -> None:
|
||||
async def logout_matrix(evt: CommandEvent) -> Optional[Dict]:
|
||||
puppet = pu.Puppet.get(evt.sender.tgid)
|
||||
if not puppet.is_real_user:
|
||||
return await evt.reply("You are not logged in with your Matrix account.")
|
||||
await puppet.switch_mxid(None, None)
|
||||
await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
|
||||
return await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
|
||||
|
||||
|
||||
@command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True,
|
||||
help_section=SECTION_AUTH,
|
||||
help_text="Replace your Telegram account's Matrix puppet with your own Matrix "
|
||||
"account")
|
||||
async def login_matrix(evt: CommandEvent) -> None:
|
||||
async def login_matrix(evt: CommandEvent) -> Optional[Dict]:
|
||||
puppet = pu.Puppet.get(evt.sender.tgid)
|
||||
if puppet.is_real_user:
|
||||
return await evt.reply("You have already logged in with your Matrix account. "
|
||||
@@ -100,7 +100,7 @@ async def login_matrix(evt: CommandEvent) -> None:
|
||||
return await evt.reply("This bridge instance has been configured to not allow logging in.")
|
||||
|
||||
|
||||
async def enter_matrix_token(evt: CommandEvent) -> None:
|
||||
async def enter_matrix_token(evt: CommandEvent) -> Dict:
|
||||
evt.sender.command_status = None
|
||||
|
||||
puppet = pu.Puppet.get(evt.sender.tgid)
|
||||
@@ -109,10 +109,11 @@ async def enter_matrix_token(evt: CommandEvent) -> None:
|
||||
"Log out with `$cmdprefix+sp logout-matrix` first.")
|
||||
|
||||
resp = await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
|
||||
if resp == 2:
|
||||
if resp == pu.PuppetError.OnlyLoginSelf:
|
||||
return await evt.reply("You can only log in as your own Matrix user.")
|
||||
elif resp == 1:
|
||||
elif resp == pu.PuppetError.InvalidAccessToken:
|
||||
return await evt.reply("Failed to verify access token.")
|
||||
assert resp == pu.PuppetError.Success, "Encountered an unhandled PuppetError."
|
||||
return await evt.reply(
|
||||
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
|
||||
|
||||
@@ -121,7 +122,7 @@ async def enter_matrix_token(evt: CommandEvent) -> None:
|
||||
help_section=SECTION_AUTH,
|
||||
help_args="<_phone_> <_full name_>",
|
||||
help_text="Register to Telegram")
|
||||
async def register(evt: CommandEvent) -> None:
|
||||
async def register(evt: CommandEvent) -> Optional[Dict]:
|
||||
if await evt.sender.is_logged_in():
|
||||
return await evt.reply("You are already logged in.")
|
||||
elif len(evt.args) < 1:
|
||||
@@ -138,9 +139,10 @@ async def register(evt: CommandEvent) -> None:
|
||||
"action": "Register",
|
||||
"full_name": full_name,
|
||||
})
|
||||
return None
|
||||
|
||||
|
||||
async def enter_code_register(evt: CommandEvent) -> None:
|
||||
async def enter_code_register(evt: CommandEvent) -> Dict:
|
||||
if len(evt.args) == 0:
|
||||
return await evt.reply("**Usage:** `$cmdprefix+sp <code>`")
|
||||
try:
|
||||
@@ -169,7 +171,7 @@ async def enter_code_register(evt: CommandEvent) -> None:
|
||||
@command_handler(needs_auth=False, management_only=True,
|
||||
help_section=SECTION_AUTH,
|
||||
help_text="Get instructions on how to log in.")
|
||||
async def login(evt: CommandEvent) -> None:
|
||||
async def login(evt: CommandEvent) -> Optional[Dict]:
|
||||
if await evt.sender.is_logged_in():
|
||||
return await evt.reply("You are already logged in.")
|
||||
|
||||
@@ -200,7 +202,8 @@ async def login(evt: CommandEvent) -> None:
|
||||
return await evt.reply("This bridge instance has been configured to not allow logging in.")
|
||||
|
||||
|
||||
async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, str]) -> None:
|
||||
async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, Any]
|
||||
) -> Dict:
|
||||
ok = False
|
||||
try:
|
||||
await evt.sender.ensure_started(even_if_no_session=True)
|
||||
@@ -232,7 +235,7 @@ async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[s
|
||||
|
||||
|
||||
@command_handler(needs_auth=False)
|
||||
async def enter_phone_or_token(evt: CommandEvent) -> None:
|
||||
async def enter_phone_or_token(evt: CommandEvent) -> Optional[Dict]:
|
||||
if len(evt.args) == 0:
|
||||
return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token <phone-or-token>`")
|
||||
elif not evt.config.get("bridge.allow_matrix_login", True):
|
||||
@@ -252,10 +255,11 @@ async def enter_phone_or_token(evt: CommandEvent) -> None:
|
||||
"next": enter_code,
|
||||
"action": "Login",
|
||||
})
|
||||
return None
|
||||
|
||||
|
||||
@command_handler(needs_auth=False)
|
||||
async def enter_code(evt: CommandEvent) -> None:
|
||||
async def enter_code(evt: CommandEvent) -> Optional[Dict]:
|
||||
if len(evt.args) == 0:
|
||||
return await evt.reply("**Usage:** `$cmdprefix+sp enter-code <code>`")
|
||||
elif not evt.config.get("bridge.allow_matrix_login", True):
|
||||
@@ -267,10 +271,11 @@ async def enter_code(evt: CommandEvent) -> None:
|
||||
evt.log.exception("Error sending phone code")
|
||||
return await evt.reply("Unhandled exception while sending code. "
|
||||
"Check console for more details.")
|
||||
return None
|
||||
|
||||
|
||||
@command_handler(needs_auth=False)
|
||||
async def enter_password(evt: CommandEvent) -> None:
|
||||
async def enter_password(evt: CommandEvent) -> Optional[Dict]:
|
||||
if len(evt.args) == 0:
|
||||
return await evt.reply("**Usage:** `$cmdprefix+sp enter-password <password>`")
|
||||
elif not evt.config.get("bridge.allow_matrix_login", True):
|
||||
@@ -286,9 +291,10 @@ async def enter_password(evt: CommandEvent) -> None:
|
||||
evt.log.exception("Error sending password")
|
||||
return await evt.reply("Unhandled exception while sending password. "
|
||||
"Check console for more details.")
|
||||
return None
|
||||
|
||||
|
||||
async def sign_in(evt: CommandEvent, **sign_in_info) -> None:
|
||||
async def sign_in(evt: CommandEvent, **sign_in_info) -> Dict:
|
||||
try:
|
||||
await evt.sender.ensure_started(even_if_no_session=True)
|
||||
user = await evt.sender.client.sign_in(**sign_in_info)
|
||||
@@ -313,7 +319,7 @@ async def sign_in(evt: CommandEvent, **sign_in_info) -> None:
|
||||
@command_handler(needs_auth=True,
|
||||
help_section=SECTION_AUTH,
|
||||
help_text="Log out from Telegram.")
|
||||
async def logout(evt: CommandEvent) -> None:
|
||||
async def logout(evt: CommandEvent) -> Optional[Dict]:
|
||||
if await evt.sender.log_out():
|
||||
return await evt.reply("Logged out successfully.")
|
||||
return await evt.reply("Failed to log out.")
|
||||
|
||||
@@ -14,21 +14,21 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Tuple, List
|
||||
from typing import Dict, List, NewType, Optional, Tuple, Union
|
||||
|
||||
from mautrix_appservice import MatrixRequestError, IntentAPI
|
||||
|
||||
from ..types import MatrixRoomId, MatrixUserId
|
||||
from . import command_handler, CommandEvent, SECTION_ADMIN
|
||||
from .. import puppet as pu, portal as po
|
||||
|
||||
ManagementRoomList = List[Tuple[str, str]]
|
||||
RoomIDList = List[str]
|
||||
ManagementRoom = NewType('ManagementRoom', Tuple[MatrixRoomId, MatrixUserId])
|
||||
|
||||
|
||||
async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList,
|
||||
async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[MatrixRoomId],
|
||||
List["po.Portal"], List["po.Portal"]]:
|
||||
management_rooms = [] # type: ManagementRoomList
|
||||
unidentified_rooms = [] # type: RoomIDList
|
||||
management_rooms = [] # type: List[ManagementRoom]
|
||||
unidentified_rooms = [] # type: List[MatrixRoomId]
|
||||
portals = [] # type: List[po.Portal]
|
||||
empty_portals = [] # type: List[po.Portal]
|
||||
|
||||
@@ -45,7 +45,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
|
||||
if pu.Puppet.get_id_from_mxid(other_member):
|
||||
unidentified_rooms.append(room)
|
||||
else:
|
||||
management_rooms.append((room, other_member))
|
||||
management_rooms.append(ManagementRoom((room, other_member)))
|
||||
else:
|
||||
unidentified_rooms.append(room)
|
||||
else:
|
||||
@@ -61,7 +61,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
|
||||
@command_handler(needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms",
|
||||
help_section=SECTION_ADMIN,
|
||||
help_text="Clean up unused portal/management rooms.")
|
||||
async def clean_rooms(evt: CommandEvent) -> None:
|
||||
async def clean_rooms(evt: CommandEvent) -> Optional[Dict]:
|
||||
management_rooms, unidentified_rooms, portals, empty_portals = await _find_rooms(evt.az.intent)
|
||||
|
||||
reply = ["#### Management rooms (M)"]
|
||||
@@ -106,13 +106,14 @@ async def clean_rooms(evt: CommandEvent) -> None:
|
||||
return await evt.reply("\n".join(reply))
|
||||
|
||||
|
||||
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
|
||||
unidentified_rooms: RoomIDList, portals: List["po.Portal"],
|
||||
async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
|
||||
unidentified_rooms: List[MatrixRoomId], portals: List["po.Portal"],
|
||||
empty_portals: List["po.Portal"]) -> None:
|
||||
command = evt.args[0]
|
||||
rooms_to_clean = []
|
||||
rooms_to_clean = [] # type: List[Union[po.Portal, MatrixRoomId]]
|
||||
if command == "clean-recommended":
|
||||
rooms_to_clean = empty_portals + unidentified_rooms
|
||||
rooms_to_clean += empty_portals
|
||||
rooms_to_clean += unidentified_rooms
|
||||
elif command == "clean-groups":
|
||||
if len(evt.args) < 2:
|
||||
return await evt.reply("**Usage:** `$cmdprefix+sp clean-groups [M][A][U][I]")
|
||||
@@ -158,7 +159,7 @@ async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
|
||||
"`$cmdprefix+sp confirm-clean`.")
|
||||
|
||||
|
||||
async def execute_room_cleanup(evt, rooms_to_clean) -> None:
|
||||
async def execute_room_cleanup(evt, rooms_to_clean: List[Union[po.Portal, MatrixRoomId]]) -> None:
|
||||
if len(evt.args) > 0 and evt.args[0] == "confirm-clean":
|
||||
await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. "
|
||||
"This might take a while.")
|
||||
@@ -167,7 +168,7 @@ async def execute_room_cleanup(evt, rooms_to_clean) -> None:
|
||||
if isinstance(room, po.Portal):
|
||||
await room.cleanup_and_delete()
|
||||
cleaned += 1
|
||||
elif isinstance(room, str):
|
||||
elif isinstance(room, str): # str is aliased by MatrixRoomId
|
||||
await po.Portal.cleanup_room(evt.az.intent, room, message="Room deleted")
|
||||
cleaned += 1
|
||||
evt.sender.command_status = None
|
||||
|
||||
@@ -14,19 +14,20 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import List, Dict, Callable, Optional
|
||||
from typing import Any, Awaitable, Callable, Coroutine, Dict, List, NamedTuple, Optional, Union
|
||||
from collections import namedtuple
|
||||
import markdown
|
||||
import logging
|
||||
|
||||
from telethon.errors import FloodWaitError
|
||||
|
||||
from ..types import MatrixRoomId
|
||||
from ..util import format_duration
|
||||
from .. import user as u, context as c
|
||||
|
||||
command_handlers = {} # type: Dict[str, CommandHandler]
|
||||
|
||||
HelpSection = namedtuple("HelpSection", "name order description")
|
||||
HelpSection = NamedTuple('HelpSection', [('name', str), ('order', int), ('description', str)])
|
||||
|
||||
SECTION_GENERAL = HelpSection("General", 0, "")
|
||||
SECTION_AUTH = HelpSection("Authentication", 10, "")
|
||||
@@ -37,8 +38,8 @@ SECTION_ADMIN = HelpSection("Administration", 50, "")
|
||||
|
||||
|
||||
class CommandEvent:
|
||||
def __init__(self, processor: "CommandProcessor", room: str, sender: u.User, command: str,
|
||||
args: List[str], is_management: bool, is_portal: bool) -> None:
|
||||
def __init__(self, processor: 'CommandProcessor', room: MatrixRoomId, sender: u.User,
|
||||
command: str, args: List[str], is_management: bool, is_portal: bool) -> None:
|
||||
self.az = processor.az
|
||||
self.log = processor.log
|
||||
self.loop = processor.loop
|
||||
@@ -53,7 +54,8 @@ class CommandEvent:
|
||||
self.is_management = is_management
|
||||
self.is_portal = is_portal
|
||||
|
||||
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True) -> None:
|
||||
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
|
||||
) -> Awaitable[Dict]:
|
||||
message = message.replace("$cmdprefix+sp ",
|
||||
"" if self.is_management else f"{self.command_prefix} ")
|
||||
message = message.replace("$cmdprefix", self.command_prefix)
|
||||
@@ -66,7 +68,7 @@ class CommandEvent:
|
||||
|
||||
|
||||
class CommandHandler:
|
||||
def __init__(self, handler: Callable[[CommandEvent], None], needs_auth: bool,
|
||||
def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool,
|
||||
needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool,
|
||||
management_only: bool, name: str, help_text: str, help_args: str,
|
||||
help_section: HelpSection) -> None:
|
||||
@@ -103,7 +105,8 @@ class CommandHandler:
|
||||
(not self.needs_admin or is_admin) and
|
||||
(not self.needs_auth or is_logged_in))
|
||||
|
||||
async def __call__(self, evt: CommandEvent) -> None:
|
||||
async def __call__(self, evt: CommandEvent
|
||||
) -> Dict:
|
||||
error = await self.get_permission_error(evt)
|
||||
if error is not None:
|
||||
return await evt.reply(error)
|
||||
@@ -118,13 +121,21 @@ class CommandHandler:
|
||||
return f"**{self.name}** {self._help_args} - {self._help_text}"
|
||||
|
||||
|
||||
def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, needs_auth=True,
|
||||
needs_puppeting=True, needs_matrix_puppeting=False, needs_admin=False,
|
||||
management_only=False, name=None, help_text="", help_args="",
|
||||
help_section=None) -> None:
|
||||
def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] = None, *,
|
||||
needs_auth: bool = True,
|
||||
needs_puppeting: bool = True,
|
||||
needs_matrix_puppeting: bool = False,
|
||||
needs_admin: bool = False,
|
||||
management_only: bool = False,
|
||||
name: Optional[str] = None,
|
||||
help_text: str = "",
|
||||
help_args: str = "",
|
||||
help_section: HelpSection = None
|
||||
) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[Dict]]]],
|
||||
CommandHandler]:
|
||||
input_name = name
|
||||
|
||||
def decorator(func: Callable[[CommandEvent], None]) -> None:
|
||||
def decorator(func: Callable[[CommandEvent], Awaitable[Optional[Dict]]]) -> CommandHandler:
|
||||
name = input_name or func.__name__.replace("_", "-")
|
||||
handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting,
|
||||
needs_admin, management_only, name, help_text, help_args,
|
||||
@@ -139,26 +150,26 @@ class CommandProcessor:
|
||||
log = logging.getLogger("mau.commands")
|
||||
|
||||
def __init__(self, context: c.Context) -> None:
|
||||
self.az, self.db, self.config, self.loop, self.tgbot = context
|
||||
self.az, self.db, self.config, self.loop, self.tgbot = context.core
|
||||
self.public_website = context.public_website
|
||||
self.command_prefix = self.config["bridge.command_prefix"]
|
||||
|
||||
async def handle(self, room: str, sender: u.User, command: str, args: List[str],
|
||||
is_management: bool, is_portal: bool) -> None:
|
||||
async def handle(self, room: MatrixRoomId, sender: u.User, command: str, args: List[str],
|
||||
is_management: bool, is_portal: bool) -> Optional[Dict]:
|
||||
evt = CommandEvent(self, room, sender, command, args, is_management, is_portal)
|
||||
orig_command = command
|
||||
command = command.lower()
|
||||
try:
|
||||
command = command_handlers[command]
|
||||
command_handler = command_handlers[command]
|
||||
except KeyError:
|
||||
if sender.command_status and "next" in sender.command_status:
|
||||
args.insert(0, orig_command)
|
||||
evt.command = ""
|
||||
command = sender.command_status["next"]
|
||||
else:
|
||||
command = command_handlers["unknown-command"]
|
||||
command_handler = command_handlers["unknown-command"]
|
||||
try:
|
||||
await command(evt)
|
||||
await command_handler(evt)
|
||||
except FloodWaitError as e:
|
||||
return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}")
|
||||
except Exception:
|
||||
@@ -166,3 +177,4 @@ class CommandProcessor:
|
||||
f"{evt.command} {' '.join(args)} from {sender.mxid}")
|
||||
return await evt.reply("Unhandled error while handling command. "
|
||||
"Check logs for more details.")
|
||||
return None
|
||||
|
||||
@@ -14,46 +14,49 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from . import command_handler, CommandEvent, _command_handlers, SECTION_GENERAL
|
||||
from .handler import HelpSection
|
||||
|
||||
|
||||
@command_handler(needs_auth=False, needs_puppeting=False,
|
||||
help_section=SECTION_GENERAL,
|
||||
help_text="Cancel an ongoing action (such as login)")
|
||||
def cancel(evt: CommandEvent) -> None:
|
||||
async def cancel(evt: CommandEvent) -> Optional[Dict]:
|
||||
if evt.sender.command_status:
|
||||
action = evt.sender.command_status["action"]
|
||||
evt.sender.command_status = None
|
||||
return evt.reply(f"{action} cancelled.")
|
||||
return await evt.reply(f"{action} cancelled.")
|
||||
else:
|
||||
return evt.reply("No ongoing command.")
|
||||
return await evt.reply("No ongoing command.")
|
||||
|
||||
|
||||
@command_handler(needs_auth=False, needs_puppeting=False)
|
||||
def unknown_command(evt: CommandEvent) -> None:
|
||||
return evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
|
||||
async def unknown_command(evt: CommandEvent) -> Optional[Dict]:
|
||||
return await evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
|
||||
|
||||
|
||||
help_cache = {}
|
||||
help_cache = {} # type: Dict[Tuple[bool, bool, bool, bool, bool], str]
|
||||
|
||||
|
||||
async def _get_help_text(evt: CommandEvent) -> None:
|
||||
async def _get_help_text(evt: CommandEvent) -> str:
|
||||
cache_key = (evt.is_management, evt.sender.puppet_whitelisted,
|
||||
evt.sender.matrix_puppet_whitelisted, evt.sender.is_admin,
|
||||
await evt.sender.is_logged_in())
|
||||
if cache_key not in help_cache:
|
||||
help = {}
|
||||
help_sections = {} # type: Dict[HelpSection, List[str]]
|
||||
for handler in _command_handlers.values():
|
||||
if handler.has_help and handler.has_permission(*cache_key):
|
||||
help.setdefault(handler.help_section, [])
|
||||
help[handler.help_section].append(handler.help + " ")
|
||||
help = sorted(help.items(), key=lambda item: item[0].order)
|
||||
help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help]
|
||||
help_sections.setdefault(handler.help_section, [])
|
||||
help_sections[handler.help_section].append(handler.help + " ")
|
||||
help_sorted = sorted(help_sections.items(), key=lambda item: item[0].order)
|
||||
help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help_sorted]
|
||||
help_cache[cache_key] = "\n".join(help)
|
||||
return help_cache[cache_key]
|
||||
|
||||
|
||||
def _get_management_status(evt: CommandEvent) -> None:
|
||||
def _get_management_status(evt: CommandEvent) -> str:
|
||||
if evt.is_management:
|
||||
return "This is a management room: prefixing commands with `$cmdprefix` is not required."
|
||||
elif evt.is_portal:
|
||||
@@ -65,5 +68,5 @@ def _get_management_status(evt: CommandEvent) -> None:
|
||||
@command_handler(needs_auth=False, needs_puppeting=False,
|
||||
help_section=SECTION_GENERAL,
|
||||
help_text="Show this help message.")
|
||||
async def help(evt: CommandEvent) -> None:
|
||||
async def help(evt: CommandEvent) -> Optional[Dict]:
|
||||
return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt))
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Callable
|
||||
from typing import Awaitable, Dict, Callable, Coroutine, Optional, Tuple, Union, cast
|
||||
import asyncio
|
||||
|
||||
from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
|
||||
@@ -22,6 +22,7 @@ from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
|
||||
from telethon.tl.types import ChatForbidden, ChannelForbidden
|
||||
from mautrix_appservice import MatrixRequestError, IntentAPI
|
||||
|
||||
from ..types import MatrixRoomId, TelegramId
|
||||
from .. import portal as po, user as u
|
||||
from . import (command_handler, CommandEvent,
|
||||
SECTION_ADMIN, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT)
|
||||
@@ -31,7 +32,7 @@ from . import (command_handler, CommandEvent,
|
||||
help_section=SECTION_ADMIN,
|
||||
help_args="<_level_> [_mxid_]",
|
||||
help_text="Set a temporary power level without affecting Telegram.")
|
||||
async def set_power_level(evt: CommandEvent) -> None:
|
||||
async def set_power_level(evt: CommandEvent) -> Dict:
|
||||
try:
|
||||
level = int(evt.args[0])
|
||||
except KeyError:
|
||||
@@ -46,11 +47,12 @@ async def set_power_level(evt: CommandEvent) -> None:
|
||||
except MatrixRequestError:
|
||||
evt.log.exception("Failed to set power level.")
|
||||
return await evt.reply("Failed to set power level.")
|
||||
return {}
|
||||
|
||||
|
||||
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
|
||||
help_text="Get a Telegram invite link to the current chat.")
|
||||
async def invite_link(evt: CommandEvent) -> None:
|
||||
async def invite_link(evt: CommandEvent) -> Dict:
|
||||
portal = po.Portal.get_by_mxid(evt.room_id)
|
||||
if not portal:
|
||||
return await evt.reply("This is not a portal room.")
|
||||
@@ -68,7 +70,7 @@ async def invite_link(evt: CommandEvent) -> None:
|
||||
|
||||
|
||||
async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50
|
||||
) -> None:
|
||||
) -> bool:
|
||||
if sender.is_admin:
|
||||
return True
|
||||
# Make sure the state store contains the power levels.
|
||||
@@ -82,8 +84,9 @@ async def user_has_power_level(room: str, intent, sender: u.User, event: str, de
|
||||
|
||||
|
||||
async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
|
||||
action: Optional[str] = None) -> None:
|
||||
room_id = evt.args[0] if len(evt.args) > 0 else evt.room_id
|
||||
action: Optional[str] = None
|
||||
) -> Tuple[Union[Dict, po.Portal], bool]:
|
||||
room_id = MatrixRoomId(evt.args[0]) if len(evt.args) > 0 else evt.room_id
|
||||
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal:
|
||||
@@ -97,8 +100,8 @@ async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
|
||||
|
||||
|
||||
def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str,
|
||||
completed_message: str) -> None:
|
||||
async def post_confirm(confirm) -> None:
|
||||
completed_message: str) -> Dict:
|
||||
async def post_confirm(confirm) -> Optional[Dict]:
|
||||
confirm.sender.command_status = None
|
||||
if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}":
|
||||
await function()
|
||||
@@ -106,6 +109,7 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
|
||||
return await confirm.reply(completed_message)
|
||||
else:
|
||||
return await confirm.reply(f"{action} cancelled.")
|
||||
return None
|
||||
|
||||
return {
|
||||
"next": post_confirm,
|
||||
@@ -118,10 +122,11 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
|
||||
help_text="Remove all users from the current portal room and forget the portal. "
|
||||
"Only works for group chats; to delete a private chat portal, simply "
|
||||
"leave the room.")
|
||||
async def delete_portal(evt: CommandEvent) -> None:
|
||||
portal, ok = await _get_portal_and_check_permission(evt, "unbridge")
|
||||
async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
|
||||
result, ok = await _get_portal_and_check_permission(evt, "unbridge")
|
||||
if not ok:
|
||||
return
|
||||
return None
|
||||
portal = cast('po.Portal', result)
|
||||
|
||||
evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid,
|
||||
portal.cleanup_and_delete, "delete",
|
||||
@@ -139,10 +144,11 @@ async def delete_portal(evt: CommandEvent) -> None:
|
||||
@command_handler(needs_auth=False, needs_puppeting=False,
|
||||
help_section=SECTION_PORTAL_MANAGEMENT,
|
||||
help_text="Remove puppets from the current portal room and forget the portal.")
|
||||
async def unbridge(evt: CommandEvent) -> None:
|
||||
portal, ok = await _get_portal_and_check_permission(evt, "unbridge")
|
||||
async def unbridge(evt: CommandEvent) -> Optional[Dict]:
|
||||
result, ok = await _get_portal_and_check_permission(evt, "unbridge")
|
||||
if not ok:
|
||||
return
|
||||
return None
|
||||
portal = cast('po.Portal', result)
|
||||
|
||||
evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid,
|
||||
portal.unbridge, "unbridge",
|
||||
@@ -158,11 +164,11 @@ async def unbridge(evt: CommandEvent) -> None:
|
||||
help_text="Bridge the current Matrix room to the Telegram chat with the given "
|
||||
"ID. The ID must be the prefixed version that you get with the `/id` "
|
||||
"command of the Telegram-side bot.")
|
||||
async def bridge(evt: CommandEvent) -> None:
|
||||
async def bridge(evt: CommandEvent) -> Dict:
|
||||
if len(evt.args) == 0:
|
||||
return await evt.reply("**Usage:** "
|
||||
"`$cmdprefix+sp bridge <Telegram chat ID> [Matrix room ID]`")
|
||||
room_id = evt.args[1] if len(evt.args) > 1 else evt.room_id
|
||||
room_id = MatrixRoomId(evt.args[1]) if len(evt.args) > 1 else evt.room_id
|
||||
that_this = "This" if room_id == evt.room_id else "That"
|
||||
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
@@ -173,12 +179,12 @@ async def bridge(evt: CommandEvent) -> None:
|
||||
return await evt.reply(f"You do not have the permissions to bridge {that_this} room.")
|
||||
|
||||
# The /id bot command provides the prefixed ID, so we assume
|
||||
tgid = evt.args[0]
|
||||
if tgid.startswith("-100"):
|
||||
tgid = int(tgid[4:])
|
||||
tgid_str = evt.args[0]
|
||||
if tgid_str.startswith("-100"):
|
||||
tgid = TelegramId(int(tgid_str[4:]))
|
||||
peer_type = "channel"
|
||||
elif tgid.startswith("-"):
|
||||
tgid = -int(tgid)
|
||||
elif tgid_str.startswith("-"):
|
||||
tgid = TelegramId(-int(tgid_str))
|
||||
peer_type = "chat"
|
||||
else:
|
||||
return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n"
|
||||
@@ -224,7 +230,8 @@ async def bridge(evt: CommandEvent) -> None:
|
||||
"chat to this room, use `$cmdprefix+sp continue`")
|
||||
|
||||
|
||||
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal") -> None:
|
||||
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"
|
||||
) -> Tuple[bool, Coroutine[None, None, None]]:
|
||||
if not portal.mxid:
|
||||
await evt.reply("The portal seems to have lost its Matrix room between you"
|
||||
"calling `$cmdprefix+sp bridge` and this command.\n\n"
|
||||
@@ -247,7 +254,7 @@ async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Porta
|
||||
return False, None
|
||||
|
||||
|
||||
async def confirm_bridge(evt: CommandEvent) -> None:
|
||||
async def confirm_bridge(evt: CommandEvent) -> Optional[Dict]:
|
||||
status = evt.sender.command_status
|
||||
try:
|
||||
portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"])
|
||||
@@ -260,7 +267,7 @@ async def confirm_bridge(evt: CommandEvent) -> None:
|
||||
if "mxid" in status:
|
||||
ok, coro = await cleanup_old_portal_while_bridging(evt, portal)
|
||||
if not ok:
|
||||
return
|
||||
return None
|
||||
elif coro:
|
||||
asyncio.ensure_future(coro, loop=evt.loop)
|
||||
await evt.reply("Cleaning up previous portal room...")
|
||||
@@ -304,7 +311,7 @@ async def confirm_bridge(evt: CommandEvent) -> None:
|
||||
return await evt.reply("Bridging complete. Portal synchronization should begin momentarily.")
|
||||
|
||||
|
||||
async def get_initial_state(intent: IntentAPI, room_id: str) -> None:
|
||||
async def get_initial_state(intent: IntentAPI, room_id: str) -> Tuple[str, str, Dict]:
|
||||
state = await intent.get_room_state(room_id)
|
||||
title = None
|
||||
about = None
|
||||
@@ -330,7 +337,7 @@ async def get_initial_state(intent: IntentAPI, room_id: str) -> None:
|
||||
help_text="Create a Telegram chat of the given type for the current Matrix room. "
|
||||
"The type is either `group`, `supergroup` or `channel` (defaults to "
|
||||
"`group`).")
|
||||
async def create(evt: CommandEvent) -> None:
|
||||
async def create(evt: CommandEvent) -> Dict:
|
||||
type = evt.args[0] if len(evt.args) > 0 else "group"
|
||||
if type not in {"chat", "group", "supergroup", "channel"}:
|
||||
return await evt.reply(
|
||||
@@ -365,7 +372,7 @@ async def create(evt: CommandEvent) -> None:
|
||||
|
||||
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
|
||||
help_text="Upgrade a normal Telegram group to a supergroup.")
|
||||
async def upgrade(evt: CommandEvent) -> None:
|
||||
async def upgrade(evt: CommandEvent) -> Dict:
|
||||
portal = po.Portal.get_by_mxid(evt.room_id)
|
||||
if not portal:
|
||||
return await evt.reply("This is not a portal room.")
|
||||
@@ -387,7 +394,7 @@ async def upgrade(evt: CommandEvent) -> None:
|
||||
help_args="<_name_|`-`>",
|
||||
help_text="Change the username of a supergroup/channel. "
|
||||
"To disable, use a dash (`-`) as the name.")
|
||||
async def group_name(evt: CommandEvent) -> None:
|
||||
async def group_name(evt: CommandEvent) -> Dict:
|
||||
if len(evt.args) == 0:
|
||||
return await evt.reply("**Usage:** `$cmdprefix+sp group-name <name/->`")
|
||||
|
||||
@@ -423,7 +430,7 @@ async def group_name(evt: CommandEvent) -> None:
|
||||
help_args="<`whitelist`|`blacklist`>",
|
||||
help_text="Change whether the bridge will allow or disallow bridging rooms by "
|
||||
"default.")
|
||||
async def filter_mode(evt: CommandEvent) -> None:
|
||||
async def filter_mode(evt: CommandEvent) -> Dict:
|
||||
try:
|
||||
mode = evt.args[0]
|
||||
if mode not in ("whitelist", "blacklist"):
|
||||
@@ -448,19 +455,19 @@ async def filter_mode(evt: CommandEvent) -> None:
|
||||
help_section=SECTION_ADMIN,
|
||||
help_args="<`whitelist`|`blacklist`> <_chat ID_>",
|
||||
help_text="Allow or disallow bridging a specific chat.")
|
||||
async def filter(evt: CommandEvent) -> None:
|
||||
async def filter(evt: CommandEvent) -> Optional[Dict]:
|
||||
try:
|
||||
action = evt.args[0]
|
||||
if action not in ("whitelist", "blacklist", "add", "remove"):
|
||||
raise ValueError()
|
||||
|
||||
id = evt.args[1]
|
||||
if id.startswith("-100"):
|
||||
id = int(id[4:])
|
||||
elif id.startswith("-"):
|
||||
id = int(id[1:])
|
||||
id_str = evt.args[1]
|
||||
if id_str.startswith("-100"):
|
||||
id = int(id_str[4:])
|
||||
elif id_str.startswith("-"):
|
||||
id = int(id_str[1:])
|
||||
else:
|
||||
id = int(id)
|
||||
id = int(id_str)
|
||||
except (IndexError, ValueError):
|
||||
return await evt.reply("**Usage:** `$cmdprefix+sp filter <whitelist/blacklist> <chat ID>`")
|
||||
|
||||
@@ -490,3 +497,4 @@ async def filter(evt: CommandEvent) -> None:
|
||||
list.remove(id)
|
||||
save()
|
||||
return await evt.reply(f"Chat ID removed from {mode}.")
|
||||
return None
|
||||
|
||||
@@ -14,10 +14,13 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Dict, List, Optional, Tuple
|
||||
import re
|
||||
|
||||
from telethon.errors import (
|
||||
InviteHashInvalidError, InviteHashExpiredError, UserAlreadyParticipantError)
|
||||
from telethon.tl.types import User as TLUser
|
||||
from telethon.tl.types import TypeUpdates
|
||||
from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest
|
||||
from telethon.tl.functions.channels import JoinChannelRequest
|
||||
|
||||
@@ -28,7 +31,7 @@ from . import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORT
|
||||
@command_handler(help_section=SECTION_MISC,
|
||||
help_args="[_-r|--remote_] <_query_>",
|
||||
help_text="Search your contacts or the Telegram servers for users.")
|
||||
async def search(evt: CommandEvent) -> None:
|
||||
async def search(evt: CommandEvent) -> Optional[Dict]:
|
||||
if len(evt.args) == 0:
|
||||
return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`")
|
||||
|
||||
@@ -49,7 +52,7 @@ async def search(evt: CommandEvent) -> None:
|
||||
"Minimum length of remote query is 5 characters.")
|
||||
return await evt.reply("No results 3:")
|
||||
|
||||
reply = []
|
||||
reply = [] # type: List[str]
|
||||
if remote:
|
||||
reply += ["**Results from Telegram server:**", ""]
|
||||
else:
|
||||
@@ -70,7 +73,7 @@ async def search(evt: CommandEvent) -> None:
|
||||
"either the internal user ID, the username or the phone number. "
|
||||
"**N.B.** The phone numbers you start chats with must already be in "
|
||||
"your contacts.")
|
||||
async def private_message(evt: CommandEvent) -> None:
|
||||
async def private_message(evt: CommandEvent) -> Optional[Dict]:
|
||||
if len(evt.args) == 0:
|
||||
return await evt.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`")
|
||||
|
||||
@@ -89,7 +92,7 @@ async def private_message(evt: CommandEvent) -> None:
|
||||
f"{pu.Puppet.get_displayname(user, False)}")
|
||||
|
||||
|
||||
async def _join(evt: CommandEvent, arg: str) -> None:
|
||||
async def _join(evt: CommandEvent, arg: str) -> Tuple[TypeUpdates, Dict]:
|
||||
if arg.startswith("joinchat/"):
|
||||
invite_hash = arg[len("joinchat/"):]
|
||||
try:
|
||||
@@ -112,7 +115,7 @@ async def _join(evt: CommandEvent, arg: str) -> None:
|
||||
@command_handler(help_section=SECTION_CREATING_PORTALS,
|
||||
help_args="<_link_>",
|
||||
help_text="Join a chat with an invite link.")
|
||||
async def join(evt: CommandEvent) -> None:
|
||||
async def join(evt: CommandEvent) -> Optional[Dict]:
|
||||
if len(evt.args) == 0:
|
||||
return await evt.reply("**Usage:** `$cmdprefix+sp join <invite link>`")
|
||||
|
||||
@@ -123,7 +126,7 @@ async def join(evt: CommandEvent) -> None:
|
||||
|
||||
updates, _ = await _join(evt, arg.group(1))
|
||||
if not updates:
|
||||
return
|
||||
return None
|
||||
|
||||
for chat in updates.chats:
|
||||
portal = po.Portal.get_by_entity(chat)
|
||||
@@ -134,12 +137,13 @@ async def join(evt: CommandEvent) -> None:
|
||||
await evt.reply(f"Creating room for {chat.title}... This might take a while.")
|
||||
await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid])
|
||||
return await evt.reply(f"Created room for {portal.title}")
|
||||
return None
|
||||
|
||||
|
||||
@command_handler(help_section=SECTION_MISC,
|
||||
help_args="[`chats`|`contacts`|`me`]",
|
||||
help_text="Synchronize your chat portals, contacts and/or own info.")
|
||||
async def sync(evt: CommandEvent) -> None:
|
||||
async def sync(evt: CommandEvent) -> Optional[Dict]:
|
||||
if len(evt.args) > 0:
|
||||
sync_only = evt.args[0]
|
||||
if sync_only not in ("chats", "contacts", "me"):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Tuple, Any, Optional
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.comments import CommentedMap
|
||||
import random
|
||||
@@ -25,7 +25,7 @@ yaml.indent(4)
|
||||
|
||||
|
||||
class DictWithRecursion:
|
||||
def __init__(self, data: CommentedMap = None) -> None:
|
||||
def __init__(self, data: Optional[CommentedMap] = None) -> None:
|
||||
self._data = data or CommentedMap() # type: CommentedMap
|
||||
|
||||
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
|
||||
@@ -99,7 +99,7 @@ class Config(DictWithRecursion):
|
||||
self.path = path # type: str
|
||||
self.registration_path = registration_path # type: str
|
||||
self.base_path = base_path # type: str
|
||||
self._registration = None # type: dict
|
||||
self._registration = None # type: Optional[Dict]
|
||||
|
||||
def load(self) -> None:
|
||||
with open(self.path, 'r') as stream:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import Generator, Optional, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import asyncio
|
||||
@@ -44,9 +44,7 @@ class Context:
|
||||
self.public_website = None # type: PublicBridgeWebsite
|
||||
self.provisioning_api = None # type: ProvisioningAPI
|
||||
|
||||
def __iter__(self) -> None:
|
||||
yield self.az
|
||||
yield self.db
|
||||
yield self.config
|
||||
yield self.loop
|
||||
yield self.bot
|
||||
@property
|
||||
def core(self) -> Tuple['AppService', 'scoped_session', 'Config',
|
||||
'asyncio.AbstractEventLoop', Optional['Bot']]:
|
||||
return (self.az, self.db, self.config, self.loop, self.bot)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict
|
||||
|
||||
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
||||
BigInteger, String, Boolean, Text)
|
||||
from sqlalchemy.sql import expression
|
||||
@@ -88,20 +90,20 @@ class RoomState(Base):
|
||||
|
||||
room_id = Column(String, primary_key=True)
|
||||
_power_levels_text = Column("power_levels", Text, nullable=True)
|
||||
_power_levels_json = None
|
||||
_power_levels_json = {} # type: Dict
|
||||
|
||||
@property
|
||||
def has_power_levels(self) -> None:
|
||||
def has_power_levels(self) -> bool:
|
||||
return bool(self._power_levels_text)
|
||||
|
||||
@property
|
||||
def power_levels(self) -> None:
|
||||
def power_levels(self) -> Dict:
|
||||
if not self._power_levels_json and self._power_levels_text:
|
||||
self._power_levels_json = json.loads(self._power_levels_text)
|
||||
return self._power_levels_json or {}
|
||||
return self._power_levels_json
|
||||
|
||||
@power_levels.setter
|
||||
def power_levels(self, val) -> None:
|
||||
def power_levels(self, val: Dict) -> None:
|
||||
self._power_levels_json = val
|
||||
self._power_levels_text = json.dumps(val)
|
||||
|
||||
@@ -116,7 +118,7 @@ class UserProfile(Base):
|
||||
displayname = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
|
||||
def dict(self) -> None:
|
||||
def dict(self) -> Dict[str, Column]:
|
||||
return {
|
||||
"membership": self.membership,
|
||||
"displayname": self.displayname,
|
||||
|
||||
@@ -80,12 +80,12 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
|
||||
args["url"] = url
|
||||
return MessageEntityTextUrl, None
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: List[Tuple[str, str]]):
|
||||
def handle_starttag(self, tag: str, attrs_list: List[Tuple[str, str]]):
|
||||
self._open_tags.appendleft(tag)
|
||||
self._open_tags_meta.appendleft(0)
|
||||
|
||||
attrs = dict(attrs)
|
||||
entity_type = None # type: type(TypeMessageEntity)
|
||||
attrs = dict(attrs_list)
|
||||
entity_type = None # type: Optional[Type[TypeMessageEntity]]
|
||||
args = {} # type: Dict[str, Any]
|
||||
if tag in ("strong", "b"):
|
||||
entity_type = MessageEntityBold
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, List, Tuple, Union, Callable
|
||||
from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
|
||||
from lxml import html
|
||||
|
||||
from telethon.tl.types import (MessageEntityMention as Mention,
|
||||
@@ -83,7 +83,7 @@ def offset_length_multiply(amount: int):
|
||||
|
||||
|
||||
class TelegramMessage:
|
||||
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None):
|
||||
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None) -> None:
|
||||
self.text = text # type: str
|
||||
self.entities = entities or [] # type: List[TypeMessageEntity]
|
||||
|
||||
@@ -120,7 +120,7 @@ class TelegramMessage:
|
||||
self.text = msg.text + self.text
|
||||
return self
|
||||
|
||||
def format(self, entity_type: type(TypeMessageEntity), offset: int = None, length: int = None,
|
||||
def format(self, entity_type: Type[TypeMessageEntity], offset: int = None, length: int = None,
|
||||
**kwargs) -> "TelegramMessage":
|
||||
self.entities.append(entity_type(offset=offset or 0,
|
||||
length=length if length is not None else len(self.text),
|
||||
@@ -158,7 +158,8 @@ class TelegramMessage:
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def join(items: List[Union[str, "TelegramMessage"]], separator: str = " ") -> "TelegramMessage":
|
||||
def join(items: Sequence[Union[str, "TelegramMessage"]],
|
||||
separator: str = " ") -> "TelegramMessage":
|
||||
main = TelegramMessage()
|
||||
for msg in items:
|
||||
if isinstance(msg, str):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from html import escape
|
||||
import logging
|
||||
import re
|
||||
@@ -28,6 +28,7 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
|
||||
from mautrix_appservice import MatrixRequestError
|
||||
from mautrix_appservice.intent_api import IntentAPI
|
||||
|
||||
from ..types import TelegramId
|
||||
from .. import user as u, puppet as pu, portal as po
|
||||
from ..db import Message as DBMessage
|
||||
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
|
||||
@@ -40,14 +41,14 @@ if TYPE_CHECKING:
|
||||
try:
|
||||
from lxml.html.diff import htmldiff
|
||||
except ImportError:
|
||||
htmldiff = None # type: function
|
||||
htmldiff = None # type: ignore
|
||||
|
||||
|
||||
log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
|
||||
should_highlight_edits = False # type: bool
|
||||
|
||||
|
||||
def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict:
|
||||
def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
|
||||
if evt.reply_to_msg_id:
|
||||
space = (evt.to_id.channel_id
|
||||
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
||||
@@ -116,7 +117,7 @@ def highlight_edits(new_html: str, old_html: str) -> str:
|
||||
|
||||
|
||||
async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
|
||||
relates_to: dict, main_intent: IntentAPI, is_edit: bool
|
||||
relates_to: Dict, main_intent: IntentAPI, is_edit: bool
|
||||
) -> Tuple[str, str]:
|
||||
space = (evt.to_id.channel_id
|
||||
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
||||
@@ -177,10 +178,10 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
|
||||
async def telegram_to_matrix(evt: Message, source: "AbstractUser",
|
||||
main_intent: Optional[IntentAPI] = None,
|
||||
is_edit: bool = False, prefix_text: Optional[str] = None,
|
||||
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]:
|
||||
prefix_html: Optional[str] = None) -> Tuple[str, str, Dict]:
|
||||
text = add_surrogates(evt.message)
|
||||
html = _telegram_entities_to_matrix_catch(text, evt.entities) if evt.entities else None
|
||||
relates_to = {}
|
||||
relates_to = {} # type: Dict
|
||||
|
||||
if prefix_html:
|
||||
html = prefix_html + (html or escape(text))
|
||||
@@ -217,6 +218,7 @@ def _telegram_entities_to_matrix_catch(text: str, entities: List[TypeMessageEnti
|
||||
"message=%s\n"
|
||||
"entities=%s",
|
||||
text, entities)
|
||||
return "[failed conversion in _telegram_entities_to_matrix]"
|
||||
|
||||
|
||||
def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -> str:
|
||||
@@ -290,7 +292,7 @@ def _parse_mention(html: List[str], entity_text: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _parse_name_mention(html: List[str], entity_text: str, user_id: int) -> bool:
|
||||
def _parse_name_mention(html: List[str], entity_text: str, user_id: TelegramId) -> bool:
|
||||
user = u.User.get_by_tgid(user_id)
|
||||
if user:
|
||||
mxid = user.mxid
|
||||
@@ -315,8 +317,8 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
|
||||
|
||||
message_link_match = message_link_regex.match(url)
|
||||
if message_link_match:
|
||||
group, msgid = message_link_match.groups()
|
||||
msgid = int(msgid)
|
||||
group, msgid_str = message_link_match.groups()
|
||||
msgid = int(msgid_str)
|
||||
|
||||
portal = po.Portal.find_by_username(group)
|
||||
if portal:
|
||||
|
||||
+58
-37
@@ -14,23 +14,31 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import List, Dict, Tuple, Set, Match
|
||||
from typing import Dict, List, Match, Optional, Set, Tuple, TYPE_CHECKING
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from mautrix_appservice import MatrixRequestError, IntentError
|
||||
|
||||
from .types import MatrixEvent, MatrixEventId, MatrixRoomId, MatrixUserId
|
||||
from . import user as u, portal as po, puppet as pu, commands as com
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mautrix_appservice import AppService
|
||||
from .context import Context
|
||||
from sqlalchemy.orm import scoped_session
|
||||
from .config import Config
|
||||
from .bot import Bot
|
||||
|
||||
|
||||
class MatrixHandler:
|
||||
log = logging.getLogger("mau.mx") # type: logging.Logger
|
||||
|
||||
def __init__(self, context) -> None:
|
||||
self.az, self.db, self.config, _, self.tgbot = context
|
||||
def __init__(self, context: 'Context') -> None:
|
||||
self.az, self.db, self.config, _, self.tgbot = context.core
|
||||
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
|
||||
self.previously_typing = [] # type: List[str]
|
||||
self.previously_typing = [] # type: List[MatrixUserId]
|
||||
|
||||
self.az.matrix_event_handler(self.handle_event)
|
||||
|
||||
@@ -50,7 +58,8 @@ class MatrixHandler:
|
||||
except asyncio.TimeoutError:
|
||||
self.log.exception("TimeoutError when trying to set avatar")
|
||||
|
||||
async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User) -> None:
|
||||
async def handle_puppet_invite(self, room_id: MatrixRoomId, puppet: pu.Puppet, inviter: u.User
|
||||
) -> None:
|
||||
intent = puppet.default_mxid_intent
|
||||
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
|
||||
if not await inviter.is_logged_in():
|
||||
@@ -80,6 +89,7 @@ class MatrixHandler:
|
||||
|
||||
await intent.join_room(room_id)
|
||||
portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
|
||||
# TODO: if portal is None:
|
||||
if portal.mxid:
|
||||
try:
|
||||
await intent.invite(portal.mxid, inviter.mxid)
|
||||
@@ -95,13 +105,13 @@ class MatrixHandler:
|
||||
portal.mxid = room_id
|
||||
portal.save()
|
||||
inviter.register_portal(portal)
|
||||
await intent.send_notice(room_id, "po.Portal to private chat created.")
|
||||
await intent.send_notice(room_id, "Portal to private chat created.")
|
||||
else:
|
||||
await intent.join_room(room_id)
|
||||
await intent.send_notice(room_id, "This puppet will remain inactive until a "
|
||||
"Telegram chat is created for this room.")
|
||||
|
||||
async def accept_bot_invite(self, room_id: str, inviter: u.User) -> None:
|
||||
async def accept_bot_invite(self, room_id: MatrixRoomId, inviter: u.User) -> None:
|
||||
tries = 0
|
||||
while tries < 5:
|
||||
try:
|
||||
@@ -126,9 +136,13 @@ class MatrixHandler:
|
||||
"<code>bridge.permissions</code> section in your config file.")
|
||||
await self.az.intent.leave_room(room_id)
|
||||
|
||||
async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str) -> None:
|
||||
async def handle_invite(self, room_id: MatrixRoomId, user_id: MatrixUserId,
|
||||
inviter_mxid: MatrixUserId) -> None:
|
||||
self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
|
||||
inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started()
|
||||
inviter = u.User.get_by_mxid(inviter_mxid)
|
||||
if inviter is None:
|
||||
self.log.exception("Failed to find user with Matrix ID {inviter_mxid}")
|
||||
await inviter.ensure_started()
|
||||
if user_id == self.az.bot_mxid:
|
||||
return await self.accept_bot_invite(room_id, inviter)
|
||||
elif not inviter.whitelisted:
|
||||
@@ -150,7 +164,8 @@ class MatrixHandler:
|
||||
|
||||
# The rest can probably be ignored
|
||||
|
||||
async def handle_join(self, room_id: str, user_id: str, event_id: str) -> None:
|
||||
async def handle_join(self, room_id: MatrixRoomId, user_id: MatrixUserId,
|
||||
event_id: MatrixEventId) -> None:
|
||||
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
@@ -171,7 +186,8 @@ class MatrixHandler:
|
||||
if await user.is_logged_in() or portal.has_bot:
|
||||
await portal.join_matrix(user, event_id)
|
||||
|
||||
async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str) -> None:
|
||||
async def handle_part(self, room_id: MatrixRoomId, user_id: MatrixUserId,
|
||||
sender_mxid: MatrixUserId, event_id: MatrixEventId) -> None:
|
||||
self.log.debug(f"{user_id} left {room_id}")
|
||||
|
||||
sender = u.User.get_by_mxid(sender_mxid, create=False)
|
||||
@@ -185,6 +201,7 @@ class MatrixHandler:
|
||||
|
||||
puppet = pu.Puppet.get_by_mxid(user_id)
|
||||
if sender and puppet:
|
||||
# TODO: Puppet should probably be an AbstractUser
|
||||
await portal.leave_matrix(puppet, sender, event_id)
|
||||
|
||||
user = u.User.get_by_mxid(user_id, create=False)
|
||||
@@ -194,7 +211,7 @@ class MatrixHandler:
|
||||
if await user.is_logged_in() or portal.has_bot:
|
||||
await portal.leave_matrix(user, sender, event_id)
|
||||
|
||||
def is_command(self, message: dict) -> Tuple[bool, str]:
|
||||
def is_command(self, message: Dict) -> Tuple[bool, str]:
|
||||
text = message.get("body", "")
|
||||
prefix = self.config["bridge.command_prefix"]
|
||||
is_command = text.startswith(prefix)
|
||||
@@ -202,9 +219,10 @@ class MatrixHandler:
|
||||
text = text[len(prefix) + 1:]
|
||||
return is_command, text
|
||||
|
||||
async def handle_message(self, room, sender, message, event_id) -> None:
|
||||
async def handle_message(self, room: MatrixRoomId, sender_id: MatrixUserId, message: Dict,
|
||||
event_id: MatrixEventId) -> None:
|
||||
is_command, text = self.is_command(message)
|
||||
sender = await u.User.get_by_mxid(sender).ensure_started()
|
||||
sender = await u.User.get_by_mxid(sender_id).ensure_started()
|
||||
if not sender.relaybot_whitelisted:
|
||||
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
|
||||
" u.User is not whitelisted.")
|
||||
@@ -237,7 +255,8 @@ class MatrixHandler:
|
||||
is_portal=portal is not None)
|
||||
|
||||
@staticmethod
|
||||
async def handle_redaction(room_id: str, sender_mxid: str, event_id: str) -> None:
|
||||
async def handle_redaction(room_id: MatrixRoomId, sender_mxid: MatrixUserId,
|
||||
event_id: MatrixEventId) -> None:
|
||||
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||
if not sender.relaybot_whitelisted:
|
||||
return
|
||||
@@ -249,14 +268,15 @@ class MatrixHandler:
|
||||
await portal.handle_matrix_deletion(sender, event_id)
|
||||
|
||||
@staticmethod
|
||||
async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict) -> None:
|
||||
async def handle_power_levels(room_id: MatrixRoomId, sender_mxid: MatrixUserId,
|
||||
new: Dict, old: Dict) -> None:
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||
if await sender.has_full_access(allow_bot=True) and portal:
|
||||
await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
|
||||
|
||||
@staticmethod
|
||||
async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str,
|
||||
async def handle_room_meta(evt_type: str, room_id: MatrixRoomId, sender_mxid: MatrixUserId,
|
||||
content: dict) -> None:
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||
@@ -271,8 +291,8 @@ class MatrixHandler:
|
||||
await handler(sender, content[content_key])
|
||||
|
||||
@staticmethod
|
||||
async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str],
|
||||
old_events: Set[str]) -> None:
|
||||
async def handle_room_pin(room_id: MatrixRoomId, sender_mxid: MatrixUserId,
|
||||
new_events: Set[str], old_events: Set[str]) -> None:
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||
if await sender.has_full_access(allow_bot=True) and portal:
|
||||
@@ -285,8 +305,8 @@ class MatrixHandler:
|
||||
await portal.handle_matrix_pin(sender, None)
|
||||
|
||||
@staticmethod
|
||||
async def handle_name_change(room_id: str, user_id: str, displayname: str,
|
||||
prev_displayname: str, event_id: str) -> None:
|
||||
async def handle_name_change(room_id: MatrixRoomId, user_id: MatrixUserId, displayname: str,
|
||||
prev_displayname: str, event_id: MatrixEventId) -> None:
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal or not portal.has_bot:
|
||||
return
|
||||
@@ -296,13 +316,14 @@ class MatrixHandler:
|
||||
await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
|
||||
|
||||
@staticmethod
|
||||
def parse_read_receipts(content: dict) -> Dict[str, str]:
|
||||
def parse_read_receipts(content: Dict) -> Dict[MatrixUserId, MatrixEventId]:
|
||||
return {user_id: event_id
|
||||
for event_id, receipts in content.items()
|
||||
for user_id in receipts.get("m.read", {})}
|
||||
|
||||
@staticmethod
|
||||
async def handle_read_receipts(room_id: str, receipts: Dict[str, str]) -> None:
|
||||
async def handle_read_receipts(room_id: MatrixRoomId,
|
||||
receipts: Dict[MatrixUserId, MatrixEventId]) -> None:
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal:
|
||||
return
|
||||
@@ -314,13 +335,13 @@ class MatrixHandler:
|
||||
await portal.mark_read(user, event_id)
|
||||
|
||||
@staticmethod
|
||||
async def handle_presence(user_id: str, presence: str) -> None:
|
||||
async def handle_presence(user_id: MatrixUserId, presence: str) -> None:
|
||||
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||
if not await user.is_logged_in():
|
||||
return
|
||||
await user.set_presence(presence == "online")
|
||||
user.set_presence(presence == "online")
|
||||
|
||||
async def handle_typing(self, room_id: str, now_typing: List[str]) -> None:
|
||||
async def handle_typing(self, room_id: MatrixRoomId, now_typing: List[MatrixUserId]) -> None:
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal:
|
||||
return
|
||||
@@ -335,35 +356,35 @@ class MatrixHandler:
|
||||
if not await user.is_logged_in():
|
||||
continue
|
||||
|
||||
await portal.set_typing(user, is_typing)
|
||||
portal.set_typing(user, is_typing)
|
||||
|
||||
self.previously_typing = now_typing
|
||||
|
||||
def filter_matrix_event(self, event: dict) -> None:
|
||||
def filter_matrix_event(self, event: MatrixEvent) -> bool:
|
||||
sender = event.get("sender", None)
|
||||
if not sender:
|
||||
return False
|
||||
return (sender == self.az.bot_mxid
|
||||
or pu.Puppet.get_id_from_mxid(sender) is not None)
|
||||
|
||||
async def try_handle_event(self, evt: dict) -> None:
|
||||
async def try_handle_event(self, evt: MatrixEvent) -> None:
|
||||
try:
|
||||
await self.handle_event(evt)
|
||||
except Exception:
|
||||
self.log.exception("Error handling manually received Matrix event")
|
||||
|
||||
async def handle_event(self, evt: dict) -> None:
|
||||
async def handle_event(self, evt: MatrixEvent) -> None:
|
||||
if self.filter_matrix_event(evt):
|
||||
return
|
||||
self.log.debug("Received event: %s", evt)
|
||||
evt_type = evt.get("type", "m.unknown") # type: str
|
||||
room_id = evt.get("room_id", None) # type: str
|
||||
event_id = evt.get("event_id", None) # type: str
|
||||
sender = evt.get("sender", None) # type: str
|
||||
content = evt.get("content", {}) # type: dict
|
||||
room_id = evt.get("room_id", None) # type: Optional[MatrixRoomId]
|
||||
event_id = evt.get("event_id", None) # type: Optional[MatrixEventId]
|
||||
sender = evt.get("sender", None) # type: Optional[MatrixUserId]
|
||||
content = evt.get("content", {}) # type: Dict
|
||||
if evt_type == "m.room.member":
|
||||
state_key = evt["state_key"] # type: str
|
||||
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
|
||||
state_key = evt["state_key"] # type: MatrixUserId
|
||||
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: Dict
|
||||
membership = content.get("membership", "") # type: str
|
||||
prev_membership = prev_content.get("membership", "leave") # type: str
|
||||
if membership == prev_membership:
|
||||
@@ -387,7 +408,7 @@ class MatrixHandler:
|
||||
elif evt_type == "m.room.redaction":
|
||||
await self.handle_redaction(room_id, sender, evt["redacts"])
|
||||
elif evt_type == "m.room.power_levels":
|
||||
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
|
||||
prev_content = evt.get("unsigned", {}).get("prev_content", {})
|
||||
await self.handle_power_levels(room_id, sender, evt["content"], prev_content)
|
||||
elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
|
||||
await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
|
||||
|
||||
+73
-61
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, TYPE_CHECKING
|
||||
from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, cast, TYPE_CHECKING
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from string import Template
|
||||
@@ -62,7 +62,7 @@ from telethon.tl.types import (
|
||||
UserFull)
|
||||
from mautrix_appservice import MatrixRequestError, IntentError, AppService, IntentAPI
|
||||
|
||||
|
||||
from .types import MatrixEventId, MatrixRoomId, MatrixUserId, TelegramId
|
||||
from .context import Context
|
||||
from .db import Portal as DBPortal, Message as DBMessage, TelegramFile as DBTelegramFile
|
||||
from . import puppet as p, user as u, formatter, util
|
||||
@@ -105,18 +105,18 @@ class Portal:
|
||||
by_mxid = {} # type: Dict[str, Portal]
|
||||
by_tgid = {} # type: Dict[Tuple[int, int], Portal]
|
||||
|
||||
def __init__(self, tgid: int, peer_type: str, tg_receiver: Optional[int] = None,
|
||||
mxid: Optional[str] = None, username: Optional[str] = None,
|
||||
def __init__(self, tgid: TelegramId, peer_type: str, tg_receiver: Optional[int] = None,
|
||||
mxid: Optional[MatrixRoomId] = None, username: Optional[str] = None,
|
||||
megagroup: Optional[bool] = False, title: Optional[str] = None,
|
||||
about: Optional[str] = None, photo_id: Optional[str] = None,
|
||||
db_instance: DBPortal = None) -> None:
|
||||
self.mxid = mxid # type: str
|
||||
self.tgid = tgid # type: int
|
||||
self.mxid = mxid # type: Optional[MatrixRoomId]
|
||||
self.tgid = tgid # type: TelegramId
|
||||
self.tg_receiver = tg_receiver or tgid # type: int
|
||||
self.peer_type = peer_type # type: str
|
||||
self.username = username # type: str
|
||||
self.megagroup = megagroup # type: bool
|
||||
self.title = title # type: str
|
||||
self.title = title # type: Optional[str]
|
||||
self.about = about # type: str
|
||||
self.photo_id = photo_id # type: str
|
||||
self._db_instance = db_instance # type: DBPortal
|
||||
@@ -161,7 +161,7 @@ class Portal:
|
||||
|
||||
@property
|
||||
def has_bot(self) -> bool:
|
||||
return self.bot and self.bot.is_in_chat(self.tgid)
|
||||
return bool(self.bot and self.bot.is_in_chat(self.tgid))
|
||||
|
||||
@property
|
||||
def main_intent(self) -> IntentAPI:
|
||||
@@ -270,8 +270,8 @@ class Portal:
|
||||
else:
|
||||
raise ValueError("Invalid invite identifier given to invite_matrix()")
|
||||
|
||||
async def update_matrix_room(self, user: "AbstractUser", entity: TypeChat, direct: bool,
|
||||
puppet: p.Puppet = None, levels: dict = None,
|
||||
async def update_matrix_room(self, user: 'AbstractUser', entity: TypeChat, direct: bool,
|
||||
puppet: p.Puppet = None, levels: Dict = None,
|
||||
users: List[User] = None,
|
||||
participants: List[TypeParticipant] = None) -> None:
|
||||
if not direct:
|
||||
@@ -303,8 +303,8 @@ class Portal:
|
||||
async with self._room_create_lock:
|
||||
return await self._create_matrix_room(user, entity, invites)
|
||||
|
||||
async def _create_matrix_room(self, user: "AbstractUser", entity: TypeChat, invites: InviteList
|
||||
) -> Optional[str]:
|
||||
async def _create_matrix_room(self, user: 'AbstractUser', entity: TypeChat, invites: InviteList
|
||||
) -> Optional[MatrixRoomId]:
|
||||
direct = self.peer_type == "user"
|
||||
|
||||
if self.mxid:
|
||||
@@ -369,6 +369,8 @@ class Portal:
|
||||
participants=participants),
|
||||
loop=self.loop)
|
||||
|
||||
return self.mxid
|
||||
|
||||
def _get_base_power_levels(self, levels: dict = None, entity: TypeChat = None) -> dict:
|
||||
levels = levels or {}
|
||||
power_level_requirement = (0 if self.peer_type == "chat" and not entity.admins_enabled
|
||||
@@ -437,18 +439,19 @@ class Portal:
|
||||
and config["bridge.max_initial_member_sync"] == -1
|
||||
and (self.megagroup or self.peer_type != "channel"))
|
||||
if trust_member_list:
|
||||
joined_mxids = await self.main_intent.get_room_members(self.mxid)
|
||||
for user in joined_mxids:
|
||||
if user == self.az.bot_mxid:
|
||||
joined_mxids = cast(List[MatrixUserId],
|
||||
await self.main_intent.get_room_members(self.mxid))
|
||||
for user_mxid in joined_mxids:
|
||||
if user_mxid == self.az.bot_mxid:
|
||||
continue
|
||||
puppet_id = p.Puppet.get_id_from_mxid(user)
|
||||
puppet_id = p.Puppet.get_id_from_mxid(user_mxid)
|
||||
if puppet_id and puppet_id not in allowed_tgids:
|
||||
if self.bot and puppet_id == self.bot.tgid:
|
||||
self.bot.remove_chat(self.tgid)
|
||||
await self.main_intent.kick(self.mxid, user,
|
||||
await self.main_intent.kick(self.mxid, user_mxid,
|
||||
"User had left this Telegram chat.")
|
||||
continue
|
||||
mx_user = u.User.get_by_mxid(user, create=False)
|
||||
mx_user = u.User.get_by_mxid(user_mxid, create=False)
|
||||
if mx_user and mx_user.is_bot and mx_user.tgid not in allowed_tgids:
|
||||
mx_user.unregister_portal(self)
|
||||
|
||||
@@ -457,7 +460,7 @@ class Portal:
|
||||
"You had left this Telegram chat.")
|
||||
continue
|
||||
|
||||
async def add_telegram_user(self, user_id: int, source: Optional["AbstractUser"] = None
|
||||
async def add_telegram_user(self, user_id: TelegramId, source: Optional['AbstractUser'] = None
|
||||
) -> None:
|
||||
puppet = p.Puppet.get(user_id)
|
||||
if source:
|
||||
@@ -470,7 +473,7 @@ class Portal:
|
||||
user.register_portal(self)
|
||||
await self.invite_to_matrix(user.mxid)
|
||||
|
||||
async def delete_telegram_user(self, user_id: int, sender: p.Puppet) -> None:
|
||||
async def delete_telegram_user(self, user_id: TelegramId, sender: p.Puppet) -> None:
|
||||
puppet = p.Puppet.get(user_id)
|
||||
user = u.User.get_by_tgid(user_id)
|
||||
kick_message = (f"Kicked by {sender.displayname}"
|
||||
@@ -568,8 +571,9 @@ class Portal:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _get_users(self, user: "AbstractUser", entity: Union[TypeInputPeer, InputUser,
|
||||
TypeChat, TypeUser]
|
||||
async def _get_users(self,
|
||||
user: 'AbstractUser',
|
||||
entity: Union[TypeInputPeer, InputUser, TypeChat, TypeUser]
|
||||
) -> Tuple[List[TypeUser], List[TypeParticipant]]:
|
||||
if self.peer_type == "chat":
|
||||
chat = await user.client(GetFullChatRequest(chat_id=self.tgid))
|
||||
@@ -588,7 +592,7 @@ class Portal:
|
||||
entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0))
|
||||
return response.users, response.participants
|
||||
elif limit > 200 or limit == -1:
|
||||
users, participants = [], []
|
||||
users, participants = [], [] # type: Tuple[List[TypeUser], List[TypeParticipant]]
|
||||
offset = 0
|
||||
remaining_quota = limit if limit > 0 else 1000000
|
||||
query = (ChannelParticipantsSearch("") if limit == -1
|
||||
@@ -609,6 +613,7 @@ class Portal:
|
||||
return [], []
|
||||
elif self.peer_type == "user":
|
||||
return [entity], []
|
||||
return [], []
|
||||
|
||||
async def get_invite_link(self, user: 'u.User') -> str:
|
||||
if self.peer_type == "user":
|
||||
@@ -688,7 +693,7 @@ class Portal:
|
||||
return ""
|
||||
|
||||
async def _get_state_change_message(self, event: str, user: 'u.User',
|
||||
arguments: Optional[dict] = None) -> Optional[dict]:
|
||||
arguments: Optional[Dict] = None) -> Optional[Dict]:
|
||||
tpl = config[f"bridge.state_event_formats.{event}"]
|
||||
if len(tpl) == 0:
|
||||
# Empty format means they don't want the message
|
||||
@@ -724,11 +729,11 @@ class Portal:
|
||||
or user.mxid_localpart)
|
||||
|
||||
def set_typing(self, user: 'u.User', typing: bool = True,
|
||||
action=SendMessageTypingAction) -> None:
|
||||
action: type = SendMessageTypingAction) -> bool:
|
||||
return user.client(SetTypingRequest(
|
||||
self.peer, action() if typing else SendMessageCancelAction()))
|
||||
|
||||
async def mark_read(self, user: 'u.User', event_id: str) -> None:
|
||||
async def mark_read(self, user: 'u.User', event_id: MatrixEventId) -> None:
|
||||
if user.is_bot:
|
||||
return
|
||||
space = self.tgid if self.peer_type == "channel" else user.tgid
|
||||
@@ -743,7 +748,8 @@ class Portal:
|
||||
else:
|
||||
await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid))
|
||||
|
||||
async def leave_matrix(self, user: 'u.User', source: 'u.User', event_id: str) -> None:
|
||||
async def leave_matrix(self, user: 'u.User', source: 'u.User', event_id: MatrixEventId
|
||||
) -> None:
|
||||
if await user.needs_relaybot(self):
|
||||
async with self.require_send_lock(self.bot.tgid):
|
||||
message = await self._get_state_change_message("leave", user)
|
||||
@@ -798,7 +804,7 @@ class Portal:
|
||||
# We'll just assume the user is already in the chat.
|
||||
pass
|
||||
|
||||
async def _apply_msg_format(self, sender: 'u.User', msgtype: str, message: dict) -> None:
|
||||
async def _apply_msg_format(self, sender: 'u.User', msgtype: str, message: Dict) -> None:
|
||||
if "formatted_body" not in message:
|
||||
message["format"] = "org.matrix.custom.html"
|
||||
message["formatted_body"] = escape_html(message.get("body", ""))
|
||||
@@ -823,7 +829,7 @@ class Portal:
|
||||
await self._apply_msg_format(sender, msgtype, message)
|
||||
|
||||
@staticmethod
|
||||
def _matrix_event_to_entities(event: dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]:
|
||||
def _matrix_event_to_entities(event: Dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]:
|
||||
try:
|
||||
if event.get("format", None) == "org.matrix.custom.html":
|
||||
message, entities = formatter.matrix_to_telegram(event.get("formatted_body", ""))
|
||||
@@ -851,7 +857,8 @@ class Portal:
|
||||
return None
|
||||
|
||||
async def _handle_matrix_text(self, sender_id: int, event_id: str, space: int,
|
||||
client: "MautrixTelegramClient", message: dict, reply_to: int) -> None:
|
||||
client: 'MautrixTelegramClient', message: Dict, reply_to: int
|
||||
) -> None:
|
||||
lock = self.require_send_lock(sender_id)
|
||||
async with lock:
|
||||
response = await client.send_message(self.peer, message, reply_to=reply_to,
|
||||
@@ -859,7 +866,8 @@ class Portal:
|
||||
self._add_telegram_message_to_db(event_id, space, response)
|
||||
|
||||
async def _handle_matrix_file(self, msgtype: str, sender_id: int, event_id: str, space: int,
|
||||
client: "MautrixTelegramClient", message: dict, reply_to: int) -> None:
|
||||
client: 'MautrixTelegramClient', message: dict, reply_to: int
|
||||
) -> None:
|
||||
file = await self.main_intent.download_file(message["url"])
|
||||
|
||||
info = message.get("info", {})
|
||||
@@ -893,7 +901,7 @@ class Portal:
|
||||
self._add_telegram_message_to_db(event_id, space, response)
|
||||
|
||||
async def _handle_matrix_location(self, sender_id: int, event_id: str, space: int,
|
||||
client: "MautrixTelegramClient", message: dict,
|
||||
client: 'MautrixTelegramClient', message: Dict,
|
||||
reply_to: int) -> None:
|
||||
try:
|
||||
lat, long = message["geo_uri"][len("geo:"):].split(",")
|
||||
@@ -901,13 +909,13 @@ class Portal:
|
||||
except (KeyError, ValueError):
|
||||
self.log.exception("Failed to parse location")
|
||||
return None
|
||||
message, entities = self._matrix_event_to_entities(message)
|
||||
caption, entities = self._matrix_event_to_entities(message)
|
||||
media = MessageMediaGeo(geo=GeoPoint(lat, long, access_hash=0))
|
||||
|
||||
lock = self.require_send_lock(sender_id)
|
||||
async with lock:
|
||||
response = await client.send_media(self.peer, media, reply_to=reply_to,
|
||||
caption=message, entities=entities)
|
||||
caption=caption, entities=entities)
|
||||
self._add_telegram_message_to_db(event_id, space, response)
|
||||
|
||||
def _add_telegram_message_to_db(self, event_id: str, space: int,
|
||||
@@ -963,17 +971,18 @@ class Portal:
|
||||
except ChatNotModifiedError:
|
||||
pass
|
||||
|
||||
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: str) -> None:
|
||||
deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
|
||||
space = self.tgid if self.peer_type == "channel" else deleter.tgid
|
||||
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventId) -> None:
|
||||
real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
|
||||
space = self.tgid if self.peer_type == "channel" else real_deleter.tgid
|
||||
message = DBMessage.query.filter(DBMessage.mxid == event_id,
|
||||
DBMessage.tg_space == space,
|
||||
DBMessage.mx_room == self.mxid).one_or_none()
|
||||
if not message:
|
||||
return
|
||||
await deleter.client.delete_messages(self.peer, [message.tgid])
|
||||
await real_deleter.client.delete_messages(self.peer, [message.tgid])
|
||||
|
||||
async def _update_telegram_power_level(self, sender: 'u.User', user_id: int, level: int) -> None:
|
||||
async def _update_telegram_power_level(self, sender: 'u.User', user_id: TelegramId,
|
||||
level: int) -> None:
|
||||
if self.peer_type == "chat":
|
||||
await sender.client(EditChatAdminRequest(
|
||||
chat_id=self.tgid, user_id=user_id, is_admin=level >= 50))
|
||||
@@ -989,7 +998,8 @@ class Portal:
|
||||
EditAdminRequest(channel=await self.get_input_entity(sender),
|
||||
user_id=user_id, admin_rights=rights))
|
||||
|
||||
async def handle_matrix_power_levels(self, sender: 'u.User', new_users: Dict[str, int],
|
||||
async def handle_matrix_power_levels(self, sender: 'u.User',
|
||||
new_users: Dict[MatrixUserId, int],
|
||||
old_users: Dict[str, int]) -> None:
|
||||
# TODO handle all power level changes and bridge exact admin rights to supergroups/channels
|
||||
for user, level in new_users.items():
|
||||
@@ -1167,7 +1177,7 @@ class Portal:
|
||||
return None
|
||||
|
||||
async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message,
|
||||
relates_to=None) -> None:
|
||||
relates_to: Dict = {}) -> None:
|
||||
largest_size = self._get_largest_photo_size(evt.media.photo)
|
||||
file = await util.transfer_file_to_matrix(self.db, source.client, intent,
|
||||
largest_size.location)
|
||||
@@ -1197,7 +1207,7 @@ class Portal:
|
||||
external_url=self.get_external_url(evt))
|
||||
|
||||
@staticmethod
|
||||
def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> dict:
|
||||
def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> Dict:
|
||||
attrs = {
|
||||
"name": None,
|
||||
"mime_type": None,
|
||||
@@ -1205,7 +1215,7 @@ class Portal:
|
||||
"sticker_alt": None,
|
||||
"width": None,
|
||||
"height": None,
|
||||
}
|
||||
} # type: Dict
|
||||
for attr in attributes:
|
||||
if isinstance(attr, DocumentAttributeFilename):
|
||||
attrs["name"] = attrs["name"] or attr.file_name
|
||||
@@ -1218,8 +1228,8 @@ class Portal:
|
||||
return attrs
|
||||
|
||||
@staticmethod
|
||||
def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: dict
|
||||
) -> Tuple[dict, str]:
|
||||
def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: Dict
|
||||
) -> Tuple[Dict, str]:
|
||||
document = evt.media.document
|
||||
name = evt.message or attrs["name"]
|
||||
if attrs["is_sticker"]:
|
||||
@@ -1253,7 +1263,7 @@ class Portal:
|
||||
|
||||
async def handle_telegram_document(self, source: "AbstractUser", intent: IntentAPI,
|
||||
evt: Message,
|
||||
relates_to: dict = None) -> Optional[dict]:
|
||||
relates_to: dict = None) -> Optional[Dict]:
|
||||
document = evt.media.document
|
||||
attrs = self._parse_telegram_document_attributes(document.attributes)
|
||||
|
||||
@@ -1521,9 +1531,9 @@ class Portal:
|
||||
else:
|
||||
self.log.debug("Unhandled Telegram action in %s: %s", self.title, action)
|
||||
|
||||
async def set_telegram_admin(self, user_id: int) -> None:
|
||||
async def set_telegram_admin(self, user_id: TelegramId) -> None:
|
||||
puppet = p.Puppet.get(user_id)
|
||||
user = await u.User.get_by_tgid(user_id)
|
||||
user = u.User.get_by_tgid(user_id)
|
||||
|
||||
levels = await self.main_intent.get_power_levels(self.mxid)
|
||||
if user:
|
||||
@@ -1558,7 +1568,7 @@ class Portal:
|
||||
await self.update_telegram_pin()
|
||||
|
||||
@staticmethod
|
||||
def _get_level_from_participant(participant: TypeParticipant, _) -> int:
|
||||
def _get_level_from_participant(participant: TypeParticipant, _: Dict) -> int:
|
||||
# TODO use the power level requirements to get better precision in channels
|
||||
if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)):
|
||||
return 50
|
||||
@@ -1599,7 +1609,7 @@ class Portal:
|
||||
except KeyError:
|
||||
return 50
|
||||
|
||||
def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: dict
|
||||
def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: Dict
|
||||
) -> bool:
|
||||
bot_level = self._get_bot_level(levels)
|
||||
if bot_level < self._get_powerlevel_level(levels):
|
||||
@@ -1654,7 +1664,7 @@ class Portal:
|
||||
mxid=self.mxid, username=self.username, megagroup=self.megagroup,
|
||||
title=self.title, about=self.about, photo_id=self.photo_id)
|
||||
|
||||
def migrate_and_save(self, new_id: int) -> None:
|
||||
def migrate_and_save(self, new_id: TelegramId) -> None:
|
||||
existing = DBPortal.query.get(self.tgid_full)
|
||||
if existing:
|
||||
self.db.delete(existing)
|
||||
@@ -1701,7 +1711,7 @@ class Portal:
|
||||
# region Class instance lookup
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: str) -> Optional["Portal"]:
|
||||
def get_by_mxid(cls, mxid: MatrixRoomId) -> Optional['Portal']:
|
||||
try:
|
||||
return cls.by_mxid[mxid]
|
||||
except KeyError:
|
||||
@@ -1721,7 +1731,7 @@ class Portal:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username: str) -> Optional["Portal"]:
|
||||
def find_by_username(cls, username: str) -> Optional['Portal']:
|
||||
if not username:
|
||||
return None
|
||||
|
||||
@@ -1729,15 +1739,15 @@ class Portal:
|
||||
if portal.username and portal.username.lower() == username.lower():
|
||||
return portal
|
||||
|
||||
portal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
|
||||
if portal:
|
||||
return cls.from_db(portal)
|
||||
dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
|
||||
if dbportal:
|
||||
return cls.from_db(dbportal)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: int, tg_receiver: int = None, peer_type: str = None
|
||||
) -> Optional["Portal"]:
|
||||
def get_by_tgid(cls, tgid: TelegramId, tg_receiver: Optional[TelegramId] = None,
|
||||
peer_type: str = None) -> Optional['Portal']:
|
||||
tg_receiver = tg_receiver or tgid
|
||||
tgid_full = (tgid, tg_receiver)
|
||||
try:
|
||||
@@ -1758,8 +1768,10 @@ class Portal:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull, TypeInputPeer],
|
||||
receiver_id: int = None, create: bool = True) -> Optional["Portal"]:
|
||||
def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull,
|
||||
TypeInputPeer],
|
||||
receiver_id: Optional[TelegramId] = None, create: bool = True
|
||||
) -> Optional['Portal']:
|
||||
entity_type = type(entity)
|
||||
if entity_type in {Chat, ChatFull}:
|
||||
type_name = "chat"
|
||||
@@ -1790,7 +1802,7 @@ class Portal:
|
||||
|
||||
def init(context: Context) -> None:
|
||||
global config
|
||||
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context
|
||||
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core
|
||||
Portal.bridge_notices = config["bridge.bridge_notices"]
|
||||
Portal.filter_mode = config["bridge.filter.mode"]
|
||||
Portal.filter_list = config["bridge.filter.list"]
|
||||
|
||||
+106
-80
@@ -14,17 +14,19 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
|
||||
from typing import Awaitable, Coroutine, Dict, List, NewType, Optional, Pattern, TYPE_CHECKING
|
||||
from difflib import SequenceMatcher
|
||||
import re
|
||||
import logging
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import orm
|
||||
|
||||
from telethon.tl.types import UserProfilePhoto
|
||||
from telethon.tl.types import UserProfilePhoto, User, FileLocation
|
||||
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
|
||||
|
||||
from .types import MatrixUserId, TelegramId
|
||||
from .db import Puppet as DBPuppet
|
||||
from . import util
|
||||
|
||||
@@ -32,6 +34,11 @@ if TYPE_CHECKING:
|
||||
from .matrix import MatrixHandler
|
||||
from .config import Config
|
||||
from .context import Context
|
||||
from . import user as u
|
||||
from .abstract_user import AbstractUser
|
||||
|
||||
|
||||
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
|
||||
|
||||
config = None # type: Config
|
||||
|
||||
@@ -45,85 +52,98 @@ class Puppet:
|
||||
mxid_regex = None # type: Pattern
|
||||
username_template = None # type: str
|
||||
hs_domain = None # type: str
|
||||
cache = {} # type: Dict[str, Puppet]
|
||||
cache = {} # type: Dict[TelegramId, Puppet]
|
||||
by_custom_mxid = {} # type: Dict[str, Puppet]
|
||||
|
||||
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
|
||||
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
|
||||
is_registered=False, db_instance=None) -> None:
|
||||
self.id = id
|
||||
self.access_token = access_token
|
||||
self.custom_mxid = custom_mxid
|
||||
self.is_real_user = self.custom_mxid and self.access_token
|
||||
self.default_mxid = self.get_mxid_from_id(self.id)
|
||||
self.mxid = self.custom_mxid or self.default_mxid
|
||||
def __init__(self,
|
||||
id: TelegramId,
|
||||
access_token: Optional[str] = None,
|
||||
custom_mxid: Optional[MatrixUserId] = None,
|
||||
username: Optional[str] = None,
|
||||
displayname: Optional[str] = None,
|
||||
displayname_source: Optional[TelegramId] = None,
|
||||
photo_id: Optional[str] = None,
|
||||
is_bot: bool = False,
|
||||
is_registered: bool = False,
|
||||
db_instance: Optional[DBPuppet] = None) -> None:
|
||||
self.id = id # type: TelegramId
|
||||
self.access_token = access_token # type: Optional[str]
|
||||
self.custom_mxid = custom_mxid # type: Optional[MatrixUserId]
|
||||
self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserId
|
||||
|
||||
self.username = username
|
||||
self.displayname = displayname
|
||||
self.displayname_source = displayname_source
|
||||
self.photo_id = photo_id
|
||||
self.is_bot = is_bot
|
||||
self.is_registered = is_registered
|
||||
self._db_instance = db_instance
|
||||
self.username = username # type: Optional[str]
|
||||
self.displayname = displayname # type: Optional[str]
|
||||
self.displayname_source = displayname_source # type: Optional[TelegramId]
|
||||
self.photo_id = photo_id # type: Optional[str]
|
||||
self.is_bot = is_bot # type: bool
|
||||
self.is_registered = is_registered # type: bool
|
||||
self._db_instance = db_instance # type: Optional[DBPuppet]
|
||||
|
||||
self.default_mxid_intent = self.az.intent.user(self.default_mxid)
|
||||
self.intent = None # type: IntentAPI
|
||||
self.refresh_intents()
|
||||
self.intent = self._fresh_intent() # type: IntentAPI
|
||||
|
||||
self.cache[id] = self
|
||||
if self.custom_mxid:
|
||||
self.by_custom_mxid[self.custom_mxid] = self
|
||||
|
||||
@property
|
||||
def tgid(self) -> None:
|
||||
def mxid(self):
|
||||
return self.custom_mxid or self.default_mxid
|
||||
|
||||
@property
|
||||
def tgid(self) -> TelegramId:
|
||||
return self.id
|
||||
|
||||
@property
|
||||
def is_real_user(self) -> bool:
|
||||
""" Is True when the puppet is a real Matrix user. """
|
||||
return bool(self.custom_mxid and self.access_token)
|
||||
|
||||
@staticmethod
|
||||
async def is_logged_in() -> None:
|
||||
async def is_logged_in() -> bool:
|
||||
""" Is True if the puppet is logged in. """
|
||||
return True
|
||||
|
||||
# region Custom puppet management
|
||||
def refresh_intents(self) -> None:
|
||||
self.is_real_user = self.custom_mxid and self.access_token
|
||||
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
|
||||
if self.is_real_user else self.default_mxid_intent)
|
||||
def _fresh_intent(self) -> IntentAPI:
|
||||
return (self.az.intent.user(self.custom_mxid, self.access_token)
|
||||
if self.is_real_user else self.default_mxid_intent)
|
||||
|
||||
async def switch_mxid(self, access_token, mxid) -> None:
|
||||
async def switch_mxid(self, access_token: str, mxid: MatrixUserId) -> PuppetError:
|
||||
prev_mxid = self.custom_mxid
|
||||
self.custom_mxid = mxid
|
||||
self.access_token = access_token
|
||||
self.refresh_intents()
|
||||
self.intent = self._fresh_intent()
|
||||
|
||||
err = await self.init_custom_mxid()
|
||||
if err != 0:
|
||||
if err != PuppetError.Success:
|
||||
return err
|
||||
|
||||
try:
|
||||
del self.by_custom_mxid[prev_mxid]
|
||||
del self.by_custom_mxid[prev_mxid] # type: ignore
|
||||
except KeyError:
|
||||
pass
|
||||
self.mxid = self.custom_mxid or self.default_mxid
|
||||
if self.mxid != self.default_mxid:
|
||||
self.by_custom_mxid[self.mxid] = self
|
||||
await self.leave_rooms_with_default_user()
|
||||
self.save()
|
||||
return 0
|
||||
return PuppetError.Success
|
||||
|
||||
async def init_custom_mxid(self) -> None:
|
||||
async def init_custom_mxid(self) -> PuppetError:
|
||||
if not self.is_real_user:
|
||||
return 0
|
||||
return PuppetError.Success
|
||||
|
||||
mxid = await self.intent.whoami()
|
||||
if not mxid or mxid != self.custom_mxid:
|
||||
self.custom_mxid = None
|
||||
self.access_token = None
|
||||
self.refresh_intents()
|
||||
self.intent = self._fresh_intent()
|
||||
if mxid != self.custom_mxid:
|
||||
return 2
|
||||
return 1
|
||||
return PuppetError.OnlyLoginSelf
|
||||
return PuppetError.InvalidAccessToken
|
||||
if config["bridge.sync_with_custom_puppets"]:
|
||||
asyncio.ensure_future(self.sync(), loop=self.loop)
|
||||
return 0
|
||||
return PuppetError.Success
|
||||
|
||||
async def leave_rooms_with_default_user(self) -> None:
|
||||
for room_id in await self.default_mxid_intent.get_joined_rooms():
|
||||
@@ -159,7 +179,7 @@ class Puppet:
|
||||
},
|
||||
})
|
||||
|
||||
def filter_events(self, events) -> None:
|
||||
def filter_events(self, events: List[Dict]) -> List:
|
||||
new_events = []
|
||||
for event in events:
|
||||
evt_type = event.get("type", None)
|
||||
@@ -186,18 +206,18 @@ class Puppet:
|
||||
new_events.append(event)
|
||||
return new_events
|
||||
|
||||
def handle_sync(self, presence, ephemeral) -> None:
|
||||
presence = [self.mx.try_handle_event(event) for event in presence]
|
||||
def handle_sync(self, presence: List, ephemeral: Dict) -> None:
|
||||
presence_events = [self.mx.try_handle_event(event) for event in presence]
|
||||
|
||||
for room_id, events in ephemeral.items():
|
||||
for event in events:
|
||||
event["room_id"] = room_id
|
||||
|
||||
ephemeral = [self.mx.try_handle_event(event)
|
||||
for events in ephemeral.values()
|
||||
for event in self.filter_events(events)]
|
||||
ephemeral_events = [self.mx.try_handle_event(event)
|
||||
for events in ephemeral.values()
|
||||
for event in self.filter_events(events)]
|
||||
|
||||
events = ephemeral + presence
|
||||
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
|
||||
coro = asyncio.gather(*events, loop=self.loop)
|
||||
asyncio.ensure_future(coro, loop=self.loop)
|
||||
|
||||
@@ -220,13 +240,14 @@ class Puppet:
|
||||
while access_token_at_start == self.access_token:
|
||||
try:
|
||||
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch,
|
||||
set_presence="offline")
|
||||
set_presence="offline") # type: Dict
|
||||
errors = 0
|
||||
if next_batch is not None:
|
||||
presence = sync_resp.get("presence", {}).get("events", [])
|
||||
presence = sync_resp.get("presence", {}).get("events", []) # type: List
|
||||
ephemeral = {room: data.get("ephemeral", {}).get("events", [])
|
||||
for room, data
|
||||
in sync_resp.get("rooms", {}).get("join", {}).items()}
|
||||
in sync_resp.get("rooms", {}).get("join", {}).items()
|
||||
} # type: Dict
|
||||
self.handle_sync(presence, ephemeral)
|
||||
next_batch = sync_resp.get("next_batch", None)
|
||||
except MatrixRequestError as e:
|
||||
@@ -241,19 +262,19 @@ class Puppet:
|
||||
# region DB conversion
|
||||
|
||||
@property
|
||||
def db_instance(self) -> None:
|
||||
def db_instance(self) -> DBPuppet:
|
||||
if not self._db_instance:
|
||||
self._db_instance = self.new_db_instance()
|
||||
return self._db_instance
|
||||
|
||||
def new_db_instance(self) -> None:
|
||||
def new_db_instance(self) -> DBPuppet:
|
||||
return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid,
|
||||
username=self.username, displayname=self.displayname,
|
||||
displayname_source=self.displayname_source, photo_id=self.photo_id,
|
||||
is_bot=self.is_bot, matrix_registered=self.is_registered)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_puppet) -> None:
|
||||
def from_db(cls, db_puppet: DBPuppet) -> 'Puppet':
|
||||
return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid,
|
||||
db_puppet.username, db_puppet.displayname, db_puppet.displayname_source,
|
||||
db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered,
|
||||
@@ -272,16 +293,16 @@ class Puppet:
|
||||
|
||||
# endregion
|
||||
# region Info updating
|
||||
def similarity(self, query) -> None:
|
||||
def similarity(self, query: str) -> int:
|
||||
username_similarity = (SequenceMatcher(None, self.username, query).ratio()
|
||||
if self.username else 0)
|
||||
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
|
||||
if self.displayname else 0)
|
||||
similarity = max(username_similarity, displayname_similarity)
|
||||
return round(similarity * 1000) / 10
|
||||
return int(round(similarity * 1000) / 10)
|
||||
|
||||
@staticmethod
|
||||
def get_displayname(info, enable_format=True) -> None:
|
||||
def get_displayname(info: User, enable_format: bool = True) -> str:
|
||||
data = {
|
||||
"phone number": info.phone if hasattr(info, "phone") else None,
|
||||
"username": info.username,
|
||||
@@ -308,7 +329,7 @@ class Puppet:
|
||||
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
|
||||
displayname=name)
|
||||
|
||||
async def update_info(self, source, info) -> None:
|
||||
async def update_info(self, source: 'AbstractUser', info: User) -> None:
|
||||
changed = False
|
||||
if self.username != info.username:
|
||||
self.username = info.username
|
||||
@@ -323,24 +344,26 @@ class Puppet:
|
||||
if changed:
|
||||
self.save()
|
||||
|
||||
async def update_displayname(self, source, info) -> None:
|
||||
async def update_displayname(self, source: 'AbstractUser', info: User) -> bool:
|
||||
ignore_source = (not source.is_relaybot
|
||||
and self.displayname_source is not None
|
||||
and self.displayname_source != source.tgid)
|
||||
if ignore_source:
|
||||
return
|
||||
return False
|
||||
|
||||
displayname = self.get_displayname(info)
|
||||
if displayname != self.displayname:
|
||||
await self.default_mxid_intent.set_display_name(displayname)
|
||||
self.displayname = displayname
|
||||
self.displayname_source = source.tgid
|
||||
self.displayname_source = TelegramId(source.tgid)
|
||||
return True
|
||||
elif source.is_relaybot or self.displayname_source is None:
|
||||
self.displayname_source = source.tgid
|
||||
self.displayname_source = TelegramId(source.tgid)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def update_avatar(self, source, photo) -> None:
|
||||
async def update_avatar(self, source: 'AbstractUser', photo: FileLocation) -> bool:
|
||||
photo_id = f"{photo.volume_id}-{photo.local_id}"
|
||||
if self.photo_id != photo_id:
|
||||
file = await util.transfer_file_to_matrix(self.db, source.client,
|
||||
@@ -355,7 +378,7 @@ class Puppet:
|
||||
# region Getters
|
||||
|
||||
@classmethod
|
||||
def get(cls, tgid, create=True) -> "Optional[Puppet]":
|
||||
def get(cls, tgid: TelegramId, create: bool = True) -> Optional['Puppet']:
|
||||
try:
|
||||
return cls.cache[tgid]
|
||||
except KeyError:
|
||||
@@ -374,12 +397,15 @@ class Puppet:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid, create=True) -> "Optional[Puppet]":
|
||||
def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['Puppet']:
|
||||
tgid = cls.get_id_from_mxid(mxid)
|
||||
return cls.get(tgid, create) if tgid else None
|
||||
if tgid:
|
||||
return cls.get(tgid, create)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_custom_mxid(cls, mxid) -> None:
|
||||
def get_by_custom_mxid(cls, mxid: MatrixUserId) -> Optional['Puppet']:
|
||||
if not mxid:
|
||||
raise ValueError("Matrix ID can't be empty")
|
||||
|
||||
@@ -396,25 +422,25 @@ class Puppet:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_all_with_custom_mxid(cls) -> None:
|
||||
def get_all_with_custom_mxid(cls) -> List['Puppet']:
|
||||
return [cls.by_custom_mxid[puppet.mxid]
|
||||
if puppet.custom_mxid in cls.by_custom_mxid
|
||||
else cls.from_db(puppet)
|
||||
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
|
||||
|
||||
@classmethod
|
||||
def get_id_from_mxid(cls, mxid) -> None:
|
||||
def get_id_from_mxid(cls, mxid: MatrixUserId) -> Optional[TelegramId]:
|
||||
match = cls.mxid_regex.match(mxid)
|
||||
if match:
|
||||
return int(match.group(1))
|
||||
return TelegramId(int(match.group(1)))
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_mxid_from_id(cls, tgid) -> None:
|
||||
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
|
||||
def get_mxid_from_id(cls, tgid: TelegramId) -> MatrixUserId:
|
||||
return MatrixUserId(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username) -> "Optional[Puppet]":
|
||||
def find_by_username(cls, username: str) -> Optional['Puppet']:
|
||||
if not username:
|
||||
return None
|
||||
|
||||
@@ -422,14 +448,14 @@ class Puppet:
|
||||
if puppet.username and puppet.username.lower() == username.lower():
|
||||
return puppet
|
||||
|
||||
puppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
|
||||
if puppet:
|
||||
return cls.from_db(puppet)
|
||||
dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
|
||||
if dbpuppet:
|
||||
return cls.from_db(dbpuppet)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_by_displayname(cls, displayname) -> "Optional[Puppet]":
|
||||
def find_by_displayname(cls, displayname: str) -> Optional['Puppet']:
|
||||
if not displayname:
|
||||
return None
|
||||
|
||||
@@ -437,17 +463,17 @@ class Puppet:
|
||||
if puppet.displayname and puppet.displayname == displayname:
|
||||
return puppet
|
||||
|
||||
puppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
|
||||
if puppet:
|
||||
return cls.from_db(puppet)
|
||||
dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
|
||||
if dbpuppet:
|
||||
return cls.from_db(dbpuppet)
|
||||
|
||||
return None
|
||||
# endregion
|
||||
|
||||
|
||||
def init(context: "Context") -> List[Awaitable[int]]:
|
||||
def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
|
||||
global config
|
||||
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
|
||||
Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core
|
||||
Puppet.mx = context.mx
|
||||
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
|
||||
Puppet.hs_domain = config["homeserver"]["domain"]
|
||||
|
||||
@@ -40,7 +40,7 @@ telematrix_db_engine.dispose()
|
||||
portals = {}
|
||||
chats = {}
|
||||
messages = {}
|
||||
puppets = {}
|
||||
puppets = {} # Dict[int, Puppet]
|
||||
|
||||
for chat_link in chat_links:
|
||||
if type(chat_link.tg_room) is str:
|
||||
|
||||
@@ -20,37 +20,39 @@ from sqlalchemy import orm
|
||||
|
||||
from mautrix_appservice import StateStore
|
||||
|
||||
from .types import MatrixUserId, MatrixRoomId
|
||||
from . import puppet as pu
|
||||
from .db import RoomState, UserProfile
|
||||
|
||||
|
||||
class SQLStateStore(StateStore):
|
||||
def __init__(self, db) -> None:
|
||||
def __init__(self, db: orm.Session) -> None:
|
||||
super().__init__()
|
||||
self.db = db # type: orm.Session
|
||||
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
|
||||
self.room_state_cache = {} # type: Dict[str, RoomState]
|
||||
|
||||
@staticmethod
|
||||
def is_registered(user: str) -> bool:
|
||||
def is_registered(user: MatrixUserId) -> bool:
|
||||
puppet = pu.Puppet.get_by_mxid(user)
|
||||
return puppet.is_registered if puppet else False
|
||||
|
||||
@staticmethod
|
||||
def registered(user: str) -> None:
|
||||
def registered(user: MatrixUserId) -> None:
|
||||
puppet = pu.Puppet.get_by_mxid(user)
|
||||
if puppet:
|
||||
puppet.is_registered = True
|
||||
puppet.save()
|
||||
|
||||
def update_state(self, event: dict) -> None:
|
||||
def update_state(self, event: Dict) -> None:
|
||||
event_type = event["type"]
|
||||
if event_type == "m.room.power_levels":
|
||||
self.set_power_levels(event["room_id"], event["content"])
|
||||
elif event_type == "m.room.member":
|
||||
self.set_member(event["room_id"], event["state_key"], event["content"])
|
||||
|
||||
def _get_user_profile(self, room_id: str, user_id: str, create: bool = True) -> UserProfile:
|
||||
def _get_user_profile(self, room_id: MatrixRoomId, user_id: MatrixUserId, create: bool = True
|
||||
) -> UserProfile:
|
||||
key = (room_id, user_id)
|
||||
try:
|
||||
return self.profile_cache[key]
|
||||
@@ -67,22 +69,22 @@ class SQLStateStore(StateStore):
|
||||
self.profile_cache[key] = profile
|
||||
return profile
|
||||
|
||||
def get_member(self, room: str, user: str) -> dict:
|
||||
def get_member(self, room: MatrixRoomId, user: MatrixUserId) -> Dict:
|
||||
return self._get_user_profile(room, user).dict()
|
||||
|
||||
def set_member(self, room: str, user: str, member: dict) -> None:
|
||||
def set_member(self, room: MatrixRoomId, user: MatrixUserId, member: Dict) -> None:
|
||||
profile = self._get_user_profile(room, user)
|
||||
profile.membership = member.get("membership", profile.membership or "leave")
|
||||
profile.displayname = member.get("displayname", profile.displayname)
|
||||
profile.avatar_url = member.get("avatar_url", profile.avatar_url)
|
||||
self.db.commit()
|
||||
|
||||
def set_membership(self, room: str, user: str, membership: str) -> None:
|
||||
def set_membership(self, room: MatrixRoomId, user: MatrixUserId, membership: str) -> None:
|
||||
self.set_member(room, user, {
|
||||
"membership": membership,
|
||||
})
|
||||
|
||||
def _get_room_state(self, room_id: str, create: bool = True) -> RoomState:
|
||||
def _get_room_state(self, room_id: MatrixRoomId, create: bool = True) -> RoomState:
|
||||
try:
|
||||
return self.room_state_cache[room_id]
|
||||
except KeyError:
|
||||
@@ -96,13 +98,13 @@ class SQLStateStore(StateStore):
|
||||
self.room_state_cache[room_id] = room
|
||||
return room
|
||||
|
||||
def has_power_levels(self, room: str) -> bool:
|
||||
def has_power_levels(self, room: MatrixRoomId) -> bool:
|
||||
return self._get_room_state(room).has_power_levels
|
||||
|
||||
def get_power_levels(self, room: str) -> dict:
|
||||
def get_power_levels(self, room: MatrixRoomId) -> Dict:
|
||||
return self._get_room_state(room).power_levels
|
||||
|
||||
def set_power_level(self, room: str, user: str, level: int) -> None:
|
||||
def set_power_level(self, room: MatrixRoomId, user: MatrixUserId, level: int) -> None:
|
||||
room_state = self._get_room_state(room)
|
||||
power_levels = room_state.power_levels
|
||||
if not power_levels:
|
||||
@@ -114,7 +116,7 @@ class SQLStateStore(StateStore):
|
||||
room_state.power_levels = power_levels
|
||||
self.db.commit()
|
||||
|
||||
def set_power_levels(self, room: str, content: dict) -> None:
|
||||
def set_power_levels(self, room: MatrixRoomId, content: Dict) -> None:
|
||||
state = self._get_room_state(room)
|
||||
state.power_levels = content
|
||||
self.db.commit()
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from typing import Dict, NewType
|
||||
|
||||
# MatrixId = NewType('MatrixId', str)
|
||||
MatrixUserId = NewType('MatrixUserId', str)
|
||||
MatrixRoomId = NewType('MatrixRoomId', str)
|
||||
MatrixEventId = NewType('MatrixEventId', str)
|
||||
|
||||
MatrixEvent = NewType('MatrixEvent', Dict)
|
||||
|
||||
TelegramId = NewType('TelegramId', int)
|
||||
+26
-22
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Dict, List, Match, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Coroutine, Dict, List, Match, Optional, Tuple, cast, TYPE_CHECKING
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
@@ -28,6 +28,7 @@ from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
|
||||
from telethon.tl.functions.account import UpdateStatusRequest
|
||||
from mautrix_appservice import MatrixRequestError
|
||||
|
||||
from .types import MatrixUserId, TelegramId
|
||||
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
|
||||
from .abstract_user import AbstractUser
|
||||
from . import portal as po, puppet as pu
|
||||
@@ -46,23 +47,23 @@ class User(AbstractUser):
|
||||
by_mxid = {} # type: Dict[str, User]
|
||||
by_tgid = {} # type: Dict[int, User]
|
||||
|
||||
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
|
||||
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
|
||||
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
|
||||
def __init__(self, mxid: MatrixUserId, tgid: Optional[TelegramId] = None,
|
||||
username: Optional[str] = None, db_contacts: Optional[List[DBContact]] = None,
|
||||
saved_contacts: int = 0, is_bot: bool = False, db_portals: List[DBPortal] = [],
|
||||
db_instance: Optional[DBUser] = None) -> None:
|
||||
super().__init__()
|
||||
self.mxid = mxid # type: str
|
||||
self.tgid = tgid # type: int
|
||||
self.mxid = mxid # type: MatrixUserId
|
||||
self.tgid = tgid # type: TelegramId
|
||||
self.is_bot = is_bot # type: bool
|
||||
self.username = username # type: str
|
||||
self.contacts = [] # type: List[pu.Puppet]
|
||||
self.saved_contacts = saved_contacts # type: int
|
||||
self.db_contacts = db_contacts # type: List[DBContact]
|
||||
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
|
||||
self.db_portals = db_portals # type: List[DBPortal]
|
||||
self._db_instance = db_instance # type: DBUser
|
||||
self.db_portals = db_portals or [] # type: List[DBPortal]
|
||||
self._db_instance = db_instance # type: Optional[DBUser]
|
||||
|
||||
self.command_status = None # type: dict
|
||||
self.command_status = None # type: Dict
|
||||
|
||||
(self.relaybot_whitelisted,
|
||||
self.whitelisted,
|
||||
@@ -169,9 +170,9 @@ class User(AbstractUser):
|
||||
except Exception:
|
||||
self.log.exception("Failed to run post-login functions for %s", self.mxid)
|
||||
|
||||
async def update(self, update: TypeUpdate) -> None:
|
||||
async def update(self, update: TypeUpdate) -> bool:
|
||||
if not self.is_bot:
|
||||
return
|
||||
return False
|
||||
|
||||
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
|
||||
message = update.message
|
||||
@@ -185,19 +186,22 @@ class User(AbstractUser):
|
||||
elif isinstance(update, UpdateShortMessage):
|
||||
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
|
||||
else:
|
||||
return
|
||||
return False
|
||||
|
||||
self.register_portal(portal)
|
||||
if portal:
|
||||
self.register_portal(portal)
|
||||
|
||||
return True
|
||||
|
||||
# endregion
|
||||
# region Telegram actions that need custom methods
|
||||
|
||||
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
|
||||
return super().ensure_started(even_if_no_session)
|
||||
def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']:
|
||||
return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session))
|
||||
|
||||
def set_presence(self, online: bool = True) -> None:
|
||||
def set_presence(self, online: bool = True) -> bool:
|
||||
if self.is_bot:
|
||||
return
|
||||
return False
|
||||
return self.client(UpdateStatusRequest(offline=not online))
|
||||
|
||||
async def update_info(self, info: TLUser = None) -> None:
|
||||
@@ -215,7 +219,7 @@ class User(AbstractUser):
|
||||
if changed:
|
||||
self.save()
|
||||
|
||||
async def log_out(self) -> None:
|
||||
async def log_out(self) -> bool:
|
||||
puppet = pu.Puppet.get(self.tgid)
|
||||
if puppet.is_real_user:
|
||||
await puppet.switch_mxid(None, None)
|
||||
@@ -328,7 +332,7 @@ class User(AbstractUser):
|
||||
# region Class instance lookup
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
|
||||
def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['User']:
|
||||
if not mxid:
|
||||
raise ValueError("Matrix ID can't be empty")
|
||||
|
||||
@@ -351,7 +355,7 @@ class User(AbstractUser):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: int) -> "Optional[User]":
|
||||
def get_by_tgid(cls, tgid: int) -> Optional['User']:
|
||||
try:
|
||||
return cls.by_tgid[tgid]
|
||||
except KeyError:
|
||||
@@ -365,7 +369,7 @@ class User(AbstractUser):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username: str) -> "Optional[User]":
|
||||
def find_by_username(cls, username: str) -> Optional['User']:
|
||||
if not username:
|
||||
return None
|
||||
|
||||
@@ -381,7 +385,7 @@ class User(AbstractUser):
|
||||
# endregion
|
||||
|
||||
|
||||
def init(context: "Context") -> List[Awaitable[User]]:
|
||||
def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser]
|
||||
global config
|
||||
config = context.config
|
||||
|
||||
|
||||
@@ -17,10 +17,10 @@
|
||||
|
||||
|
||||
def format_duration(seconds: int) -> str:
|
||||
def pluralize(count, singular) -> None:
|
||||
def pluralize(count: int, singular: str) -> str:
|
||||
return singular if count == 1 else singular + "s"
|
||||
|
||||
def include(count, word) -> None:
|
||||
def include(count: int, word: str) -> str:
|
||||
return f"{count} {pluralize(count, word)}" if count > 0 else ""
|
||||
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
import json
|
||||
import base64
|
||||
import hashlib
|
||||
@@ -28,13 +28,13 @@ def _get_checksum(key: str, payload: bytes) -> str:
|
||||
return checksum
|
||||
|
||||
|
||||
def sign_token(key: str, payload: dict) -> str:
|
||||
payload = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
|
||||
checksum = _get_checksum(key, payload)
|
||||
return f"{checksum}:{payload.decode('utf-8')}"
|
||||
def sign_token(key: str, payload: Dict) -> str:
|
||||
payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
|
||||
checksum = _get_checksum(key, payload_b64)
|
||||
return f"{checksum}:{payload_b64.decode('utf-8')}"
|
||||
|
||||
|
||||
def verify_token(key: str, data: str) -> Optional[dict]:
|
||||
def verify_token(key: str, data: str) -> Optional[Dict]:
|
||||
if not data:
|
||||
return None
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from telethon.errors import *
|
||||
|
||||
from ...commands.auth import enter_password
|
||||
from ...util import format_duration
|
||||
from ...puppet import Puppet
|
||||
from ...puppet import Puppet, PuppetError
|
||||
from ...user import User
|
||||
|
||||
|
||||
@@ -51,12 +51,13 @@ class AuthAPI(abc.ABC):
|
||||
"account.", errcode="already-logged-in")
|
||||
|
||||
resp = await puppet.switch_mxid(token, user.mxid)
|
||||
if resp == 2:
|
||||
if resp == PuppetError.OnlyLoginSelf:
|
||||
return self.get_mx_login_response(status=403, errcode="only-login-self",
|
||||
error="You can only log in as your own Matrix user.")
|
||||
elif resp == 1:
|
||||
elif resp == PuppetError.InvalidAccessToken:
|
||||
return self.get_mx_login_response(status=401, errcode="invalid-access-token",
|
||||
error="Failed to verify access token.")
|
||||
assert resp == PuppetError.Success, "Encountered an unhandled PuppetError."
|
||||
|
||||
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from aiohttp import web
|
||||
from typing import Tuple, Optional, Callable, Awaitable, TYPE_CHECKING
|
||||
from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
@@ -24,6 +24,7 @@ from telethon.utils import get_peer_id, resolve_id
|
||||
from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat
|
||||
from mautrix_appservice import AppService, MatrixRequestError, IntentError
|
||||
|
||||
from ...types import MatrixUserId
|
||||
from ...user import User
|
||||
from ...portal import Portal
|
||||
from ...commands.portal import user_has_power_level, get_initial_state
|
||||
@@ -36,7 +37,7 @@ if TYPE_CHECKING:
|
||||
class ProvisioningAPI(AuthAPI):
|
||||
log = logging.getLogger("mau.web.provisioning")
|
||||
|
||||
def __init__(self, context: "Context"):
|
||||
def __init__(self, context: "Context") -> None:
|
||||
super().__init__(context.loop)
|
||||
self.secret = context.config["appservice.provisioning.shared_secret"]
|
||||
self.az = context.az # type: AppService
|
||||
@@ -411,7 +412,7 @@ class ProvisioningAPI(AuthAPI):
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
async def get_user(self, mxid: str, expect_logged_in: Optional[bool] = False,
|
||||
async def get_user(self, mxid: MatrixUserId, expect_logged_in: Optional[bool] = False,
|
||||
require_puppeting: bool = True, require_user: bool = True
|
||||
) -> Tuple[Optional[User], Optional[web.Response]]:
|
||||
if not mxid:
|
||||
@@ -439,7 +440,7 @@ class ProvisioningAPI(AuthAPI):
|
||||
expect_logged_in: Optional[bool] = False,
|
||||
require_puppeting: bool = False,
|
||||
want_data: bool = True,
|
||||
) -> (Tuple[Optional[dict],
|
||||
) -> (Tuple[Optional[Dict],
|
||||
Optional[User],
|
||||
Optional[web.Response]]):
|
||||
err = self.check_authorization(request)
|
||||
|
||||
Reference in New Issue
Block a user