Add more type hints
This commit is contained in:
@@ -14,34 +14,33 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional
|
||||
import argparse
|
||||
import sys
|
||||
import logging
|
||||
import logging.config
|
||||
import asyncio
|
||||
import logging.config
|
||||
import sys
|
||||
|
||||
import sqlalchemy as sql
|
||||
from sqlalchemy import orm
|
||||
import sqlalchemy as sql
|
||||
|
||||
from alchemysession import AlchemySessionContainer
|
||||
from mautrix_appservice import AppService
|
||||
from alchemysession import AlchemySessionContainer
|
||||
|
||||
from .base import Base
|
||||
from .config import Config
|
||||
from .matrix import MatrixHandler
|
||||
|
||||
from . import __version__
|
||||
from .db import init as init_db
|
||||
from .web.provisioning import ProvisioningAPI
|
||||
from .web.public import PublicBridgeWebsite
|
||||
from .abstract_user import init as init_abstract_user
|
||||
from .user import init as init_user, User
|
||||
from .base import Base
|
||||
from .bot import init as init_bot
|
||||
from .config import Config
|
||||
from .context import Context
|
||||
from .db import init as init_db
|
||||
from .formatter import init as init_formatter
|
||||
from .matrix import MatrixHandler
|
||||
from .portal import init as init_portal
|
||||
from .puppet import init as init_puppet
|
||||
from .formatter import init as init_formatter
|
||||
from .web.public import PublicBridgeWebsite
|
||||
from .web.provisioning import ProvisioningAPI
|
||||
from .context import Context
|
||||
from .sqlstatestore import SQLStateStore
|
||||
from .user import User, init as init_user
|
||||
from . import __version__
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A Matrix-Telegram puppeting bridge.",
|
||||
@@ -68,7 +67,7 @@ if args.generate_registration:
|
||||
sys.exit(0)
|
||||
|
||||
logging.config.dictConfig(config["logging"])
|
||||
log = logging.getLogger("mau.init")
|
||||
log = logging.getLogger("mau.init") # type: logging.Logger
|
||||
log.debug(f"Initializing mautrix-telegram {__version__}")
|
||||
|
||||
db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db")
|
||||
@@ -80,7 +79,7 @@ session_container = AlchemySessionContainer(engine=db_engine, session=db_session
|
||||
table_base=Base, table_prefix="telethon_",
|
||||
manage_tables=False)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_event_loop() # type: asyncio.AbstractEventLoop
|
||||
|
||||
state_store = SQLStateStore(db_session)
|
||||
appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
|
||||
@@ -89,8 +88,8 @@ appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
|
||||
verify_ssl=config["homeserver.verify_ssl"], state_store=state_store,
|
||||
real_user_content_key="net.maunium.telegram.puppet")
|
||||
|
||||
public_website = None
|
||||
provisioning_api = None
|
||||
public_website = None # type: Optional[PublicBridgeWebsite]
|
||||
provisioning_api = None # type: Optional[ProvisioningAPI]
|
||||
|
||||
if config["appservice.public.enabled"]:
|
||||
public_website = PublicBridgeWebsite(loop)
|
||||
|
||||
@@ -14,26 +14,48 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Tuple, Optional, List, Union, TYPE_CHECKING
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
import logging
|
||||
import platform
|
||||
|
||||
from telethon.tl.types import *
|
||||
from mautrix_appservice import MatrixRequestError
|
||||
from sqlalchemy import orm
|
||||
from telethon.tl.types import Channel, ChannelForbidden, Chat, ChatForbidden, Message, \
|
||||
MessageActionChannelMigrateFrom, MessageService, PeerUser, TypeUpdate, \
|
||||
UpdateChannelPinnedMessage, UpdateChatAdmins, UpdateChatParticipantAdmin, \
|
||||
UpdateChatParticipants, UpdateChatUserTyping, UpdateDeleteChannelMessages, \
|
||||
UpdateDeleteMessages, UpdateEditChannelMessage, UpdateEditMessage, UpdateNewChannelMessage, \
|
||||
UpdateNewMessage, UpdateReadHistoryOutbox, UpdateShortChatMessage, UpdateShortMessage, \
|
||||
UpdateUserName, UpdateUserPhoto, UpdateUserStatus, UpdateUserTyping, User, UserStatusOffline, \
|
||||
UserStatusOnline
|
||||
|
||||
from mautrix_appservice import MatrixRequestError, AppService
|
||||
from alchemysession import AlchemySessionContainer
|
||||
|
||||
from .tgclient import MautrixTelegramClient
|
||||
from .db import Message as DBMessage
|
||||
from . import portal as po, puppet as pu, __version__
|
||||
from .db import Message as DBMessage
|
||||
from .tgclient import MautrixTelegramClient
|
||||
|
||||
config = None
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
from .config import Config
|
||||
|
||||
config = None # type: Config
|
||||
# Value updated from config in init()
|
||||
MAX_DELETIONS = 10
|
||||
MAX_DELETIONS = 10 # type: int
|
||||
|
||||
UpdateMessage = Union[UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
|
||||
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage]
|
||||
UpdateMessageContent = Union[UpdateShortMessage, UpdateShortChatMessage, Message, MessageService]
|
||||
|
||||
|
||||
class AbstractUser:
|
||||
session_container = None
|
||||
loop = None
|
||||
log = None
|
||||
db = None
|
||||
az = None
|
||||
class AbstractUser(ABC):
|
||||
session_container = None # type: AlchemySessionContainer
|
||||
loop = None # type: asyncio.AbstractEventLoop
|
||||
log = None # type: logging.Logger
|
||||
db = None # type: orm.Session
|
||||
az = None # type: AppService
|
||||
|
||||
def __init__(self):
|
||||
self.puppet_whitelisted = False # type: bool
|
||||
@@ -47,22 +69,22 @@ class AbstractUser:
|
||||
self.is_bot = False # type: bool
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
def connected(self) -> bool:
|
||||
return self.client and self.client.is_connected()
|
||||
|
||||
@property
|
||||
def _proxy_settings(self):
|
||||
type = config["telegram.proxy.type"].lower()
|
||||
if type == "disabled":
|
||||
def _proxy_settings(self) -> Optional[Tuple[int, str, str, str, str, str]]:
|
||||
proxy_type = config["telegram.proxy.type"].lower()
|
||||
if proxy_type == "disabled":
|
||||
return None
|
||||
elif type == "socks4":
|
||||
type = 1
|
||||
elif type == "socks5":
|
||||
type = 2
|
||||
elif type == "http":
|
||||
type = 3
|
||||
elif proxy_type == "socks4":
|
||||
proxy_type = 1
|
||||
elif proxy_type == "socks5":
|
||||
proxy_type = 2
|
||||
elif proxy_type == "http":
|
||||
proxy_type = 3
|
||||
|
||||
return (type,
|
||||
return (proxy_type,
|
||||
config["telegram.proxy.address"], config["telegram.proxy.port"],
|
||||
config["telegram.proxy.rdns"],
|
||||
config["telegram.proxy.username"], config["telegram.proxy.password"])
|
||||
@@ -83,20 +105,30 @@ class AbstractUser:
|
||||
proxy=self._proxy_settings)
|
||||
self.client.add_event_handler(self._update_catch)
|
||||
|
||||
async def update(self, update):
|
||||
@abstractmethod
|
||||
async def update(self, update: TypeUpdate) -> bool:
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
async def post_login(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _update_catch(self, update):
|
||||
@abstractmethod
|
||||
def register_portal(self, portal: po.Portal):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def unregister_portal(self, portal: po.Portal):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _update_catch(self, update: TypeUpdate):
|
||||
try:
|
||||
if not await self.update(update):
|
||||
await self._update(update)
|
||||
except Exception:
|
||||
self.log.exception("Failed to handle Telegram update")
|
||||
|
||||
async def get_dialogs(self, limit=None) -> List[Union[Chat, Channel]]:
|
||||
async def get_dialogs(self, limit: int = None) -> List[Union[Chat, Channel]]:
|
||||
if self.is_bot:
|
||||
return []
|
||||
dialogs = await self.client.get_dialogs(limit=limit)
|
||||
@@ -106,18 +138,19 @@ class AbstractUser:
|
||||
and (dialog.entity.deactivated or dialog.entity.left)))]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def is_logged_in(self):
|
||||
async def is_logged_in(self) -> bool:
|
||||
return self.client and await self.client.is_user_authorized()
|
||||
|
||||
async def has_full_access(self, allow_bot=False):
|
||||
async def has_full_access(self, allow_bot: bool = False) -> bool:
|
||||
return (self.puppet_whitelisted
|
||||
and (not self.is_bot or allow_bot)
|
||||
and await self.is_logged_in())
|
||||
|
||||
async def start(self, delete_unless_authenticated=False):
|
||||
async def start(self, delete_unless_authenticated: bool = False) -> "AbstractUser":
|
||||
if not self.client:
|
||||
self._init_client()
|
||||
await self.client.connect()
|
||||
@@ -144,7 +177,7 @@ class AbstractUser:
|
||||
|
||||
# region Telegram update handling
|
||||
|
||||
async def _update(self, update):
|
||||
async def _update(self, update: TypeUpdate):
|
||||
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
|
||||
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)):
|
||||
await self.update_message(update)
|
||||
@@ -169,17 +202,19 @@ class AbstractUser:
|
||||
else:
|
||||
self.log.debug("Unhandled update: %s", update)
|
||||
|
||||
async def update_pinned_messages(self, update):
|
||||
@staticmethod
|
||||
async def update_pinned_messages(update: UpdateChannelPinnedMessage):
|
||||
portal = po.Portal.get_by_tgid(update.channel_id)
|
||||
if portal and portal.mxid:
|
||||
await portal.receive_telegram_pin_id(update.id)
|
||||
|
||||
async def update_participants(self, update):
|
||||
@staticmethod
|
||||
async def update_participants(update: UpdateChatParticipants):
|
||||
portal = po.Portal.get_by_tgid(update.participants.chat_id)
|
||||
if portal and portal.mxid:
|
||||
await portal.update_telegram_participants(update.participants.participants)
|
||||
|
||||
async def update_read_receipt(self, update):
|
||||
async def update_read_receipt(self, update: UpdateReadHistoryOutbox):
|
||||
if not isinstance(update.peer, PeerUser):
|
||||
self.log.debug("Unexpected read receipt peer: %s", update.peer)
|
||||
return
|
||||
@@ -196,7 +231,7 @@ class AbstractUser:
|
||||
puppet = pu.Puppet.get(update.peer.user_id)
|
||||
await puppet.intent.mark_read(portal.mxid, message.mxid)
|
||||
|
||||
async def update_admin(self, update):
|
||||
async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]):
|
||||
# TODO duplication not checked
|
||||
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
|
||||
if isinstance(update, UpdateChatAdmins):
|
||||
@@ -206,7 +241,7 @@ class AbstractUser:
|
||||
else:
|
||||
self.log.warning("Unexpected admin status update: %s", update)
|
||||
|
||||
async def update_typing(self, update):
|
||||
async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]):
|
||||
if isinstance(update, UpdateUserTyping):
|
||||
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
|
||||
else:
|
||||
@@ -214,7 +249,7 @@ class AbstractUser:
|
||||
sender = pu.Puppet.get(update.user_id)
|
||||
await portal.handle_telegram_typing(sender, update)
|
||||
|
||||
async def update_others_info(self, update):
|
||||
async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]):
|
||||
# TODO duplication not checked
|
||||
puppet = pu.Puppet.get(update.user_id)
|
||||
if isinstance(update, UpdateUserName):
|
||||
@@ -226,7 +261,7 @@ class AbstractUser:
|
||||
else:
|
||||
self.log.warning("Unexpected other user info update: %s", update)
|
||||
|
||||
async def update_status(self, update):
|
||||
async def update_status(self, update: UpdateUserStatus):
|
||||
puppet = pu.Puppet.get(update.user_id)
|
||||
if isinstance(update.status, UserStatusOnline):
|
||||
await puppet.default_mxid_intent.set_presence("online")
|
||||
@@ -236,7 +271,9 @@ class AbstractUser:
|
||||
self.log.warning("Unexpected user status update: %s", update)
|
||||
return
|
||||
|
||||
def get_message_details(self, update):
|
||||
def get_message_details(self, update: UpdateMessage) -> Tuple[UpdateMessageContent,
|
||||
Optional[pu.Puppet],
|
||||
Optional[po.Portal]]:
|
||||
if isinstance(update, UpdateShortChatMessage):
|
||||
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
|
||||
sender = pu.Puppet.get(update.from_id)
|
||||
@@ -259,7 +296,7 @@ class AbstractUser:
|
||||
return update, sender, portal
|
||||
|
||||
@staticmethod
|
||||
async def _try_redact(portal, message):
|
||||
async def _try_redact(portal: po.Portal, message: DBMessage):
|
||||
if not portal:
|
||||
return
|
||||
try:
|
||||
@@ -267,7 +304,7 @@ class AbstractUser:
|
||||
except MatrixRequestError:
|
||||
pass
|
||||
|
||||
async def delete_message(self, update):
|
||||
async def delete_message(self, update: UpdateDeleteMessages):
|
||||
if len(update.messages) > MAX_DELETIONS:
|
||||
return
|
||||
|
||||
@@ -283,7 +320,7 @@ class AbstractUser:
|
||||
await self._try_redact(portal, message)
|
||||
self.db.commit()
|
||||
|
||||
async def delete_channel_message(self, update):
|
||||
async def delete_channel_message(self, update: UpdateDeleteChannelMessages):
|
||||
if len(update.messages) > MAX_DELETIONS:
|
||||
return
|
||||
|
||||
@@ -299,7 +336,7 @@ class AbstractUser:
|
||||
await self._try_redact(portal, message)
|
||||
self.db.commit()
|
||||
|
||||
async def update_message(self, original_update):
|
||||
async def update_message(self, original_update: UpdateMessage):
|
||||
update, sender, portal = self.get_message_details(original_update)
|
||||
|
||||
if isinstance(update, MessageService):
|
||||
@@ -325,7 +362,7 @@ class AbstractUser:
|
||||
# endregion
|
||||
|
||||
|
||||
def init(context):
|
||||
def init(context: "Context"):
|
||||
global config, MAX_DELETIONS
|
||||
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, _ = context
|
||||
AbstractUser.session_container = context.session_container
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
Base = declarative_base()
|
||||
Base = declarative_base() # type: declarative_base
|
||||
|
||||
+22
-18
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Callable
|
||||
from typing import Awaitable, Callable, Pattern, Dict, TYPE_CHECKING
|
||||
import logging
|
||||
import re
|
||||
|
||||
@@ -27,27 +27,31 @@ from .abstract_user import AbstractUser
|
||||
from .db import BotChat
|
||||
from . import puppet as pu, portal as po, user as u
|
||||
|
||||
config = None
|
||||
if TYPE_CHECKING:
|
||||
from .config import Config
|
||||
|
||||
config = None # type: Config
|
||||
|
||||
ReplyFunc = Callable[[str], Awaitable[Message]]
|
||||
|
||||
|
||||
class Bot(AbstractUser):
|
||||
log = logging.getLogger("mau.bot")
|
||||
mxid_regex = re.compile("@.+:.+")
|
||||
log = logging.getLogger("mau.bot") # type: logging.Logger
|
||||
mxid_regex = re.compile("@.+:.+") # type: Pattern
|
||||
|
||||
def __init__(self, token: str):
|
||||
super().__init__()
|
||||
self.token = token
|
||||
self.puppet_whitelisted = True
|
||||
self.whitelisted = True
|
||||
self.relaybot_whitelisted = True
|
||||
self.username = None
|
||||
self.is_relaybot = True
|
||||
self.is_bot = True
|
||||
self.chats = {chat.id: chat.type for chat in BotChat.query.all()}
|
||||
self.tg_whitelist = []
|
||||
self.whitelist_group_admins = config["bridge.relaybot.whitelist_group_admins"] or False
|
||||
self.token = token # type: str
|
||||
self.puppet_whitelisted = True # type: bool
|
||||
self.whitelisted = True # type: bool
|
||||
self.relaybot_whitelisted = True # type: bool
|
||||
self.username = None # type: str
|
||||
self.is_relaybot = True # type: bool
|
||||
self.is_bot = True # type: bool
|
||||
self.chats = {chat.id: chat.type for chat in BotChat.query.all()} # type: Dict[int, str]
|
||||
self.tg_whitelist = [] # type: List[int]
|
||||
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
|
||||
or False) # type: bool
|
||||
|
||||
async def init_permissions(self):
|
||||
whitelist = config["bridge.relaybot.whitelist"] or []
|
||||
@@ -61,7 +65,7 @@ class Bot(AbstractUser):
|
||||
if isinstance(id, int):
|
||||
self.tg_whitelist.append(id)
|
||||
|
||||
async def start(self, delete_unless_authenticated=False):
|
||||
async def start(self, delete_unless_authenticated: bool = False) -> "Bot":
|
||||
await super().start(delete_unless_authenticated)
|
||||
if not await self.is_logged_in():
|
||||
await self.client.sign_in(bot_token=self.token)
|
||||
@@ -118,7 +122,7 @@ class Bot(AbstractUser):
|
||||
self.db.delete(existing_chat)
|
||||
self.db.commit()
|
||||
|
||||
async def _can_use_commands(self, chat, tgid):
|
||||
async def _can_use_commands(self, chat: TypePeer, tgid: int) -> bool:
|
||||
if tgid in self.tg_whitelist:
|
||||
return True
|
||||
|
||||
@@ -138,7 +142,7 @@ class Bot(AbstractUser):
|
||||
if p.user_id == tgid:
|
||||
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
|
||||
|
||||
async def check_can_use_commands(self, event: Message, reply: ReplyFunc):
|
||||
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
|
||||
if not await self._can_use_commands(event.to_id, event.from_id):
|
||||
await reply("You do not have the permission to use that command.")
|
||||
return False
|
||||
@@ -262,7 +266,7 @@ class Bot(AbstractUser):
|
||||
return "bot"
|
||||
|
||||
|
||||
def init(context):
|
||||
def init(context) -> Optional[Bot]:
|
||||
global config
|
||||
config = context.config
|
||||
token = config["telegram.bot_token"]
|
||||
|
||||
@@ -23,15 +23,14 @@ from .. import puppet as pu, portal as po
|
||||
|
||||
ManagementRoomList = List[Tuple[str, str]]
|
||||
RoomIDList = List[str]
|
||||
PortalList = List[po.Portal]
|
||||
|
||||
|
||||
async def _find_rooms(intent: IntentAPI) -> Tuple[
|
||||
ManagementRoomList, RoomIDList, PortalList, PortalList]:
|
||||
async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList,
|
||||
List["po.Portal"], List["po.Portal"]]:
|
||||
management_rooms = [] # type: ManagementRoomList
|
||||
unidentified_rooms = [] # type: RoomIDList
|
||||
portals = [] # type: PortalList
|
||||
empty_portals = [] # type: PortalList
|
||||
portals = [] # type: List[po.Portal]
|
||||
empty_portals = [] # type: List[po.Portal]
|
||||
|
||||
rooms = await intent.get_joined_rooms()
|
||||
for room in rooms:
|
||||
@@ -108,8 +107,8 @@ async def clean_rooms(evt: CommandEvent):
|
||||
|
||||
|
||||
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
|
||||
unidentified_rooms: RoomIDList, portals: PortalList,
|
||||
empty_portals: PortalList):
|
||||
unidentified_rooms: RoomIDList, portals: List["po.Portal"],
|
||||
empty_portals: List["po.Portal"]):
|
||||
command = evt.args[0]
|
||||
rooms_to_clean = []
|
||||
if command == "clean-recommended":
|
||||
|
||||
@@ -222,7 +222,7 @@ async def bridge(evt: CommandEvent):
|
||||
"chat to this room, use `$cmdprefix+sp continue`")
|
||||
|
||||
|
||||
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: po.Portal):
|
||||
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"):
|
||||
if not portal.mxid:
|
||||
await evt.reply("The portal seems to have lost its Matrix room between you"
|
||||
"calling `$cmdprefix+sp bridge` and this command.\n\n"
|
||||
|
||||
+22
-21
@@ -14,6 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Tuple, Any, Optional
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.comments import CommentedMap
|
||||
import random
|
||||
@@ -24,28 +25,28 @@ yaml.indent(4)
|
||||
|
||||
|
||||
class DictWithRecursion:
|
||||
def __init__(self, data=None):
|
||||
self._data = data or CommentedMap()
|
||||
def __init__(self, data: CommentedMap = None):
|
||||
self._data = data or CommentedMap() # type: CommentedMap
|
||||
|
||||
def _recursive_get(self, data, key, default_value):
|
||||
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
|
||||
if '.' in key:
|
||||
key, next_key = key.split('.', 1)
|
||||
next_data = data.get(key, CommentedMap())
|
||||
return self._recursive_get(next_data, next_key, default_value)
|
||||
return data.get(key, default_value)
|
||||
|
||||
def get(self, key, default_value, allow_recursion=True):
|
||||
def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any:
|
||||
if allow_recursion and '.' in key:
|
||||
return self._recursive_get(self._data, key, default_value)
|
||||
return self._data.get(key, default_value)
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self.get(key, None)
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return self[key] is not None
|
||||
|
||||
def _recursive_set(self, data, key, value):
|
||||
def _recursive_set(self, data: CommentedMap, key: str, value: Any):
|
||||
if '.' in key:
|
||||
key, next_key = key.split('.', 1)
|
||||
if key not in data:
|
||||
@@ -55,16 +56,16 @@ class DictWithRecursion:
|
||||
return
|
||||
data[key] = value
|
||||
|
||||
def set(self, key, value, allow_recursion=True):
|
||||
def set(self, key: str, value: Any, allow_recursion: bool = True):
|
||||
if allow_recursion and '.' in key:
|
||||
self._recursive_set(self._data, key, value)
|
||||
return
|
||||
self._data[key] = value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
def __setitem__(self, key: str, value: Any):
|
||||
self.set(key, value)
|
||||
|
||||
def _recursive_del(self, data, key):
|
||||
def _recursive_del(self, data: CommentedMap, key: str):
|
||||
if '.' in key:
|
||||
key, next_key = key.split('.', 1)
|
||||
if key not in data:
|
||||
@@ -78,7 +79,7 @@ class DictWithRecursion:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def delete(self, key, allow_recursion=True):
|
||||
def delete(self, key: str, allow_recursion: bool = True):
|
||||
if allow_recursion and '.' in key:
|
||||
self._recursive_del(self._data, key)
|
||||
return
|
||||
@@ -88,23 +89,23 @@ class DictWithRecursion:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def __delitem__(self, key):
|
||||
def __delitem__(self, key: str):
|
||||
self.delete(key)
|
||||
|
||||
|
||||
class Config(DictWithRecursion):
|
||||
def __init__(self, path, registration_path, base_path):
|
||||
def __init__(self, path: str, registration_path: str, base_path: str):
|
||||
super().__init__()
|
||||
self.path = path
|
||||
self.registration_path = registration_path
|
||||
self.base_path = base_path
|
||||
self._registration = None
|
||||
self.path = path # type: str
|
||||
self.registration_path = registration_path # type: str
|
||||
self.base_path = base_path # type: str
|
||||
self._registration = None # type: dict
|
||||
|
||||
def load(self):
|
||||
with open(self.path, 'r') as stream:
|
||||
self._data = yaml.load(stream)
|
||||
|
||||
def load_base(self):
|
||||
def load_base(self) -> Optional[DictWithRecursion]:
|
||||
try:
|
||||
with open(self.base_path, 'r') as stream:
|
||||
return DictWithRecursion(yaml.load(stream))
|
||||
@@ -120,7 +121,7 @@ class Config(DictWithRecursion):
|
||||
yaml.dump(self._registration, stream)
|
||||
|
||||
@staticmethod
|
||||
def _new_token():
|
||||
def _new_token() -> str:
|
||||
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
|
||||
|
||||
def update(self):
|
||||
@@ -246,7 +247,7 @@ class Config(DictWithRecursion):
|
||||
self._data = base._data
|
||||
self.save()
|
||||
|
||||
def _get_permissions(self, key):
|
||||
def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, bool, bool]:
|
||||
level = self["bridge.permissions"].get(key, "")
|
||||
admin = level == "admin"
|
||||
puppeting = level == "full" or admin
|
||||
@@ -254,7 +255,7 @@ class Config(DictWithRecursion):
|
||||
relaybot = level == "relaybot" or user
|
||||
return relaybot, user, puppeting, admin, level
|
||||
|
||||
def get_permissions(self, mxid):
|
||||
def get_permissions(self, mxid: str) -> Tuple[bool, bool, bool, bool, bool]:
|
||||
permissions = self["bridge.permissions"] or {}
|
||||
if mxid in permissions:
|
||||
return self._get_permissions(mxid)
|
||||
|
||||
+18
-12
@@ -14,21 +14,27 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Tuple
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy.orm import scoped_session
|
||||
|
||||
from alchemysession import AlchemySessionContainer
|
||||
from mautrix_appservice import AppService
|
||||
|
||||
from .web import PublicBridgeWebsite, ProvisioningAPI
|
||||
from .config import Config
|
||||
from .bot import Bot
|
||||
from .matrix import MatrixHandler
|
||||
|
||||
from sqlalchemy.orm import scoped_session
|
||||
from alchemysession import AlchemySessionContainer
|
||||
from mautrix_appservice import AppService
|
||||
|
||||
class Context:
|
||||
def __init__(self, az, db, config, loop, bot, mx, session_container, public_website,
|
||||
provisioning_api):
|
||||
from .web import PublicBridgeWebsite, ProvisioningAPI
|
||||
from .config import Config
|
||||
from .bot import Bot
|
||||
from .matrix import MatrixHandler
|
||||
|
||||
def __init__(self, az: "AppService", db: "scoped_session", config: "Config",
|
||||
loop: "asyncio.AbstractEventLoop", bot: "Bot", mx: "MatrixHandler",
|
||||
session_container: "AlchemySessionContainer",
|
||||
public_website: "PublicBridgeWebsite", provisioning_api: "ProvisioningAPI"):
|
||||
self.az = az # type: AppService
|
||||
self.db = db # type: scoped_session
|
||||
self.config = config # type: Config
|
||||
|
||||
@@ -42,6 +42,7 @@ class Portal(Base):
|
||||
about = Column(String, nullable=True)
|
||||
photo_id = Column(String, nullable=True)
|
||||
|
||||
|
||||
class Message(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "message"
|
||||
|
||||
@@ -14,10 +14,10 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import (Optional, List, Tuple, Type, Callable, Dict, Any, Pattern, Deque, Match, TYPE_CHECKING)
|
||||
from html import unescape
|
||||
from html.parser import HTMLParser
|
||||
from collections import deque
|
||||
from typing import Optional, List, Tuple, Type, Callable, Dict, Any
|
||||
import math
|
||||
import re
|
||||
import logging
|
||||
@@ -27,37 +27,40 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M
|
||||
MessageEntityItalic, MessageEntityCode, MessageEntityPre,
|
||||
MessageEntityBotCommand, TypeMessageEntity)
|
||||
|
||||
from .. import user as u, puppet as pu, portal as po, context as c
|
||||
from .. import user as u, puppet as pu, portal as po
|
||||
from ..db import Message as DBMessage
|
||||
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
|
||||
trim_reply_fallback_text, html_to_unicode)
|
||||
|
||||
log = logging.getLogger("mau.fmt.mx")
|
||||
should_bridge_plaintext_highlights = False
|
||||
if TYPE_CHECKING:
|
||||
from ..context import Context
|
||||
|
||||
log = logging.getLogger("mau.fmt.mx") # type: logging.Logger
|
||||
should_bridge_plaintext_highlights = False # type: bool
|
||||
|
||||
|
||||
class MatrixParser(HTMLParser):
|
||||
mention_regex = re.compile("https://matrix.to/#/(@.+:.+)")
|
||||
room_regex = re.compile("https://matrix.to/#/(#.+:.+)")
|
||||
mention_regex = re.compile("https://matrix.to/#/(@.+:.+)") # type: Pattern
|
||||
room_regex = re.compile("https://matrix.to/#/(#.+:.+)") # type: Pattern
|
||||
block_tags = ("br", "p", "pre", "blockquote",
|
||||
"ol", "ul", "li",
|
||||
"h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"div", "hr", "table")
|
||||
"div", "hr", "table") # type: Tuple[str, ...]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.text = ""
|
||||
self.entities = []
|
||||
self._building_entities = {}
|
||||
self._list_counter = 0
|
||||
self._open_tags = deque()
|
||||
self._open_tags_meta = deque()
|
||||
self._line_is_new = True
|
||||
self._list_entry_is_new = False
|
||||
self.text = "" # type: str
|
||||
self.entities = [] # type: List[TypeMessageEntity]
|
||||
self._building_entities = {} # type: Dict[str, TypeMessageEntity]
|
||||
self._list_counter = 0 # type: int
|
||||
self._open_tags = deque() # type: Deque[str]
|
||||
self._open_tags_meta = deque() # type: Deque[Any]
|
||||
self._line_is_new = True # type: bool
|
||||
self._list_entry_is_new = False # type: bool
|
||||
|
||||
def _parse_url(self, url: str, args: Dict[str, Any]
|
||||
) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]:
|
||||
mention = self.mention_regex.match(url)
|
||||
mention = self.mention_regex.match(url) # type: Match
|
||||
if mention:
|
||||
mxid = mention.group(1)
|
||||
user = (pu.Puppet.get_by_mxid(mxid)
|
||||
@@ -72,7 +75,7 @@ class MatrixParser(HTMLParser):
|
||||
else:
|
||||
return None, None
|
||||
|
||||
room = self.room_regex.match(url)
|
||||
room = self.room_regex.match(url) # type: Match
|
||||
if room:
|
||||
username = po.Portal.get_username_from_mx_alias(room.group(1))
|
||||
portal = po.Portal.find_by_username(username)
|
||||
@@ -92,8 +95,8 @@ class MatrixParser(HTMLParser):
|
||||
self._open_tags_meta.appendleft(0)
|
||||
|
||||
attrs = dict(attrs)
|
||||
entity_type = None
|
||||
args = {}
|
||||
entity_type = None # type: type(TypeMessageEntity)
|
||||
args = {} # type: Dict[str, Any]
|
||||
if tag in ("strong", "b"):
|
||||
entity_type = MessageEntityBold
|
||||
elif tag in ("em", "i"):
|
||||
@@ -243,12 +246,12 @@ class MatrixParser(HTMLParser):
|
||||
self._newline(allow_multi=tag == "br")
|
||||
|
||||
|
||||
command_regex = re.compile(r"^!([A-Za-z0-9@]+)")
|
||||
not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)")
|
||||
plain_mention_regex = None
|
||||
command_regex = re.compile(r"^!([A-Za-z0-9@]+)") # type: Pattern
|
||||
not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)") # type: Pattern
|
||||
plain_mention_regex = None # type: Pattern
|
||||
|
||||
|
||||
def plain_mention_to_html(match):
|
||||
def plain_mention_to_html(match: Match) -> str:
|
||||
puppet = pu.Puppet.find_by_displayname(match.group(2))
|
||||
if puppet:
|
||||
return (f"{match.group(1)}"
|
||||
@@ -351,7 +354,7 @@ def plain_mention_to_text() -> Tuple[List[TypeMessageEntity], Callable[[str], st
|
||||
return entities, replacer
|
||||
|
||||
|
||||
def init_mx(context: c.Context):
|
||||
def init_mx(context: "Context"):
|
||||
global plain_mention_regex, should_bridge_plaintext_highlights
|
||||
config = context.config
|
||||
dn_template = config.get("bridge.displayname_template", "{displayname} (Telegram)")
|
||||
|
||||
@@ -14,13 +14,8 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING
|
||||
from html import escape
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
try:
|
||||
from lxml.html.diff import htmldiff
|
||||
except ImportError:
|
||||
htmldiff = None # type: function
|
||||
import logging
|
||||
import re
|
||||
|
||||
@@ -33,16 +28,26 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
|
||||
from mautrix_appservice import MatrixRequestError
|
||||
from mautrix_appservice.intent_api import IntentAPI
|
||||
|
||||
from .. import user as u, puppet as pu, portal as po, context as c
|
||||
from .. import user as u, puppet as pu, portal as po
|
||||
from ..db import Message as DBMessage
|
||||
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
|
||||
trim_reply_fallback_text, unicode_to_html)
|
||||
|
||||
log = logging.getLogger("mau.fmt.tg")
|
||||
should_highlight_edits = False
|
||||
if TYPE_CHECKING:
|
||||
from ..abstract_user import AbstractUser
|
||||
from ..context import Context
|
||||
|
||||
try:
|
||||
from lxml.html.diff import htmldiff
|
||||
except ImportError:
|
||||
htmldiff = None # type: function
|
||||
|
||||
|
||||
def telegram_reply_to_matrix(evt: Message, source: u.User) -> dict:
|
||||
log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
|
||||
should_highlight_edits = False # type: bool
|
||||
|
||||
|
||||
def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict:
|
||||
if evt.reply_to_msg_id:
|
||||
space = (evt.to_id.channel_id
|
||||
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
||||
@@ -78,7 +83,7 @@ async def _add_forward_header(source, text: str, html: Optional[str],
|
||||
if not fwd_from_text:
|
||||
user = await source.client.get_entity(PeerUser(fwd_from.from_id))
|
||||
if user:
|
||||
fwd_from_text = pu.Puppet.get_displayname(user, format=False)
|
||||
fwd_from_text = pu.Puppet.get_displayname(user, False)
|
||||
fwd_from_html = f"<b>{fwd_from_text}</b>"
|
||||
|
||||
if not fwd_from_text:
|
||||
@@ -110,8 +115,9 @@ def highlight_edits(new_html: str, old_html: str) -> str:
|
||||
return new_html
|
||||
|
||||
|
||||
async def _add_reply_header(source: u.User, text: str, html: str, evt: Message, relates_to: dict,
|
||||
main_intent: IntentAPI, is_edit: bool) -> Tuple[str, str]:
|
||||
async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
|
||||
relates_to: dict, main_intent: IntentAPI, is_edit: bool
|
||||
) -> Tuple[str, str]:
|
||||
space = (evt.to_id.channel_id
|
||||
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
||||
else source.tgid)
|
||||
@@ -142,7 +148,7 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
|
||||
|
||||
if is_edit and should_highlight_edits:
|
||||
html = highlight_edits(html or escape(text), r_html_body)
|
||||
except (ValueError, KeyError, MatrixRequestError) as e:
|
||||
except (ValueError, KeyError, MatrixRequestError):
|
||||
r_sender_link = "unknown user"
|
||||
r_displayname = "unknown user"
|
||||
r_text_body = "Failed to fetch message"
|
||||
@@ -154,8 +160,9 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
|
||||
|
||||
r_keyword = "In reply to" if not is_edit else "Edit to"
|
||||
r_msg_link = f"<a href='https://matrix.to/#/{msg.mx_room}/{msg.mxid}'>{r_keyword}</a>"
|
||||
html = (f"<mx-reply><blockquote>{r_msg_link} {r_sender_link}\n{r_html_body}</blockquote></mx-reply>"
|
||||
+ (html or escape(text)))
|
||||
html = (
|
||||
f"<mx-reply><blockquote>{r_msg_link} {r_sender_link}\n{r_html_body}</blockquote></mx-reply>"
|
||||
+ (html or escape(text)))
|
||||
|
||||
lines = r_text_body.strip().split("\n")
|
||||
text_with_quote = f"> <{r_displayname}> {lines.pop(0)}"
|
||||
@@ -167,7 +174,8 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
|
||||
return text_with_quote, html
|
||||
|
||||
|
||||
async def telegram_to_matrix(evt: Message, source: u.User, main_intent: Optional[IntentAPI] = None,
|
||||
async def telegram_to_matrix(evt: Message, source: "AbstractUser",
|
||||
main_intent: Optional[IntentAPI] = None,
|
||||
is_edit: bool = False, prefix_text: Optional[str] = None,
|
||||
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]:
|
||||
text = add_surrogates(evt.message)
|
||||
@@ -320,6 +328,6 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def init_tg(context: c.Context):
|
||||
def init_tg(context: "Context"):
|
||||
global should_highlight_edits
|
||||
should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"]
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Pattern
|
||||
from html import escape
|
||||
from typing import Optional
|
||||
import struct
|
||||
import re
|
||||
|
||||
@@ -47,7 +47,7 @@ def trim_reply_fallback_text(text: str) -> str:
|
||||
|
||||
html_reply_fallback_regex = re.compile("^<mx-reply>"
|
||||
r"[\s\S]+?"
|
||||
"</mx-reply>")
|
||||
"</mx-reply>") # type: Pattern
|
||||
|
||||
|
||||
def trim_reply_fallback_html(html: str) -> str:
|
||||
|
||||
+114
-108
@@ -14,26 +14,23 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Tuple, Set, Match
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from mautrix_appservice import MatrixRequestError, IntentError
|
||||
|
||||
from .user import User
|
||||
from .portal import Portal
|
||||
from .puppet import Puppet
|
||||
from .commands import CommandProcessor
|
||||
from . import user as u, portal as po, puppet as pu, commands as com
|
||||
|
||||
|
||||
class MatrixHandler:
|
||||
log = logging.getLogger("mau.mx")
|
||||
log = logging.getLogger("mau.mx") # type: logging.Logger
|
||||
|
||||
def __init__(self, context):
|
||||
self.az, self.db, self.config, _, self.tgbot = context
|
||||
self.commands = CommandProcessor(context)
|
||||
self.previously_typing = []
|
||||
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
|
||||
self.previously_typing = [] # type: List[str]
|
||||
|
||||
self.az.matrix_event_handler(self.handle_event)
|
||||
|
||||
@@ -53,68 +50,68 @@ class MatrixHandler:
|
||||
except asyncio.TimeoutError:
|
||||
self.log.exception("TimeoutError when trying to set avatar")
|
||||
|
||||
async def handle_puppet_invite(self, room, puppet, inviter):
|
||||
async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User):
|
||||
intent = puppet.default_mxid_intent
|
||||
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room}")
|
||||
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
|
||||
if not await inviter.is_logged_in():
|
||||
await intent.error_and_leave(
|
||||
room, text="Please log in before inviting Telegram puppets.")
|
||||
room_id, text="Please log in before inviting Telegram puppets.")
|
||||
return
|
||||
portal = Portal.get_by_mxid(room)
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if portal:
|
||||
if portal.peer_type == "user":
|
||||
await intent.error_and_leave(
|
||||
room, text="You can not invite additional users to private chats.")
|
||||
room_id, text="You can not invite additional users to private chats.")
|
||||
return
|
||||
await portal.invite_telegram(inviter, puppet)
|
||||
await intent.join_room(room)
|
||||
await intent.join_room(room_id)
|
||||
return
|
||||
try:
|
||||
members = await self.az.intent.get_room_members(room)
|
||||
members = await self.az.intent.get_room_members(room_id)
|
||||
except MatrixRequestError:
|
||||
members = []
|
||||
if self.az.bot_mxid not in members:
|
||||
if len(members) > 1:
|
||||
await intent.error_and_leave(room, text=None, html=(
|
||||
await intent.error_and_leave(room_id, text=None, html=(
|
||||
f"Please invite "
|
||||
f"<a href='https://matrix.to/#/{self.az.bot_mxid}'>the bridge bot</a> "
|
||||
f"first if you want to create a Telegram chat."))
|
||||
return
|
||||
|
||||
await intent.join_room(room)
|
||||
portal = Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
|
||||
await intent.join_room(room_id)
|
||||
portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
|
||||
if portal.mxid:
|
||||
try:
|
||||
await intent.invite(portal.mxid, inviter.mxid)
|
||||
await intent.send_notice(room, text=None, html=(
|
||||
await intent.send_notice(room_id, text=None, html=(
|
||||
"You already have a private chat with me: "
|
||||
f"<a href='https://matrix.to/#/{portal.mxid}'>"
|
||||
"Link to room"
|
||||
"</a>"))
|
||||
await intent.leave_room(room)
|
||||
await intent.leave_room(room_id)
|
||||
return
|
||||
except MatrixRequestError:
|
||||
pass
|
||||
portal.mxid = room
|
||||
portal.mxid = room_id
|
||||
portal.save()
|
||||
inviter.register_portal(portal)
|
||||
await intent.send_notice(room, "Portal to private chat created.")
|
||||
await intent.send_notice(room_id, "po.Portal to private chat created.")
|
||||
else:
|
||||
await intent.join_room(room)
|
||||
await intent.send_notice(room, "This puppet will remain inactive until a "
|
||||
"Telegram chat is created for this room.")
|
||||
await intent.join_room(room_id)
|
||||
await intent.send_notice(room_id, "This puppet will remain inactive until a "
|
||||
"Telegram chat is created for this room.")
|
||||
|
||||
async def accept_bot_invite(self, room, inviter):
|
||||
async def accept_bot_invite(self, room_id: str, inviter: u.User):
|
||||
tries = 0
|
||||
while tries < 5:
|
||||
try:
|
||||
await self.az.intent.join_room(room)
|
||||
await self.az.intent.join_room(room_id)
|
||||
break
|
||||
except (IntentError, MatrixRequestError) as e:
|
||||
except (IntentError, MatrixRequestError):
|
||||
tries += 1
|
||||
wait_for_seconds = (tries + 1) * 10
|
||||
if tries < 5:
|
||||
self.log.exception(f"Failed to join room {room} with bridge bot, "
|
||||
self.log.exception(f"Failed to join room {room_id} with bridge bot, "
|
||||
f"retrying in {wait_for_seconds} seconds...")
|
||||
await asyncio.sleep(wait_for_seconds)
|
||||
else:
|
||||
@@ -123,81 +120,81 @@ class MatrixHandler:
|
||||
|
||||
if not inviter.whitelisted:
|
||||
await self.az.intent.send_notice(
|
||||
room, text=None,
|
||||
room_id, text=None,
|
||||
html="You are not whitelisted to use this bridge.<br/><br/>"
|
||||
"If you are the owner of this bridge, see the "
|
||||
"<code>bridge.permissions</code> section in your config file.")
|
||||
await self.az.intent.leave_room(room)
|
||||
await self.az.intent.leave_room(room_id)
|
||||
|
||||
async def handle_invite(self, room, user, inviter):
|
||||
self.log.debug(f"{inviter} invited {user} to {room}")
|
||||
inviter = await User.get_by_mxid(inviter).ensure_started()
|
||||
if user == self.az.bot_mxid:
|
||||
return await self.accept_bot_invite(room, inviter)
|
||||
async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str):
|
||||
self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
|
||||
inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started()
|
||||
if user_id == self.az.bot_mxid:
|
||||
return await self.accept_bot_invite(room_id, inviter)
|
||||
elif not inviter.whitelisted:
|
||||
return
|
||||
|
||||
puppet = Puppet.get_by_mxid(user)
|
||||
puppet = pu.Puppet.get_by_mxid(user_id)
|
||||
if puppet:
|
||||
await self.handle_puppet_invite(room, puppet, inviter)
|
||||
await self.handle_puppet_invite(room_id, puppet, inviter)
|
||||
return
|
||||
|
||||
user = User.get_by_mxid(user, create=False)
|
||||
user = u.User.get_by_mxid(user_id, create=False)
|
||||
if not user:
|
||||
return
|
||||
await user.ensure_started()
|
||||
portal = Portal.get_by_mxid(room)
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if user and await user.has_full_access(allow_bot=True) and portal:
|
||||
await portal.invite_telegram(inviter, user)
|
||||
return
|
||||
|
||||
# The rest can probably be ignored
|
||||
|
||||
async def handle_join(self, room, user, event_id):
|
||||
user = await User.get_by_mxid(user).ensure_started()
|
||||
async def handle_join(self, room_id: str, user_id: str, event_id: str):
|
||||
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||
|
||||
portal = Portal.get_by_mxid(room)
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal:
|
||||
return
|
||||
|
||||
if not user.relaybot_whitelisted:
|
||||
await portal.main_intent.kick(room, user.mxid,
|
||||
await portal.main_intent.kick(room_id, user.mxid,
|
||||
"You are not whitelisted on this Telegram bridge.")
|
||||
return
|
||||
elif not await user.is_logged_in() and not portal.has_bot:
|
||||
await portal.main_intent.kick(room, user.mxid,
|
||||
await portal.main_intent.kick(room_id, user.mxid,
|
||||
"This chat does not have a bot relaying "
|
||||
"messages for unauthenticated users.")
|
||||
return
|
||||
|
||||
self.log.debug(f"{user} joined {room}")
|
||||
self.log.debug(f"{user} joined {room_id}")
|
||||
if await user.is_logged_in() or portal.has_bot:
|
||||
await portal.join_matrix(user, event_id)
|
||||
|
||||
async def handle_part(self, room, user, sender, event_id):
|
||||
self.log.debug(f"{user} left {room}")
|
||||
async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str):
|
||||
self.log.debug(f"{user_id} left {room_id}")
|
||||
|
||||
sender = User.get_by_mxid(sender, create=False)
|
||||
sender = u.User.get_by_mxid(sender_mxid, create=False)
|
||||
if not sender:
|
||||
return
|
||||
await sender.ensure_started()
|
||||
|
||||
portal = Portal.get_by_mxid(room)
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal:
|
||||
return
|
||||
|
||||
puppet = Puppet.get_by_mxid(user)
|
||||
puppet = pu.Puppet.get_by_mxid(user_id)
|
||||
if sender and puppet:
|
||||
await portal.leave_matrix(puppet, sender, event_id)
|
||||
|
||||
user = User.get_by_mxid(user, create=False)
|
||||
user = u.User.get_by_mxid(user_id, create=False)
|
||||
if not user:
|
||||
return
|
||||
await user.ensure_started()
|
||||
if await user.is_logged_in() or portal.has_bot:
|
||||
await portal.leave_matrix(user, sender, event_id)
|
||||
|
||||
def is_command(self, message):
|
||||
def is_command(self, message: dict) -> Tuple[bool, str]:
|
||||
text = message.get("body", "")
|
||||
prefix = self.config["bridge.command_prefix"]
|
||||
is_command = text.startswith(prefix)
|
||||
@@ -207,14 +204,14 @@ class MatrixHandler:
|
||||
|
||||
async def handle_message(self, room, sender, message, event_id):
|
||||
is_command, text = self.is_command(message)
|
||||
sender = await User.get_by_mxid(sender).ensure_started()
|
||||
sender = await u.User.get_by_mxid(sender).ensure_started()
|
||||
if not sender.relaybot_whitelisted:
|
||||
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
|
||||
" User is not whitelisted.")
|
||||
" u.User is not whitelisted.")
|
||||
return
|
||||
self.log.debug(f"Received Matrix event \"{message}\" from {sender} in {room}")
|
||||
|
||||
portal = Portal.get_by_mxid(room)
|
||||
portal = po.Portal.get_by_mxid(room)
|
||||
if not is_command and portal and (await sender.is_logged_in() or portal.has_bot):
|
||||
await portal.handle_matrix_message(sender, message, event_id)
|
||||
return
|
||||
@@ -239,39 +236,44 @@ class MatrixHandler:
|
||||
await self.commands.handle(room, sender, command, args, is_management,
|
||||
is_portal=portal is not None)
|
||||
|
||||
async def handle_redaction(self, room, sender, event_id):
|
||||
sender = await User.get_by_mxid(sender).ensure_started()
|
||||
@staticmethod
|
||||
async def handle_redaction(room_id: str, sender_mxid: str, event_id: str):
|
||||
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||
if not sender.relaybot_whitelisted:
|
||||
return
|
||||
|
||||
portal = Portal.get_by_mxid(room)
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal:
|
||||
return
|
||||
|
||||
await portal.handle_matrix_deletion(sender, event_id)
|
||||
|
||||
async def handle_power_levels(self, room, sender, new, old):
|
||||
portal = Portal.get_by_mxid(room)
|
||||
sender = await User.get_by_mxid(sender).ensure_started()
|
||||
@staticmethod
|
||||
async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict):
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||
if await sender.has_full_access(allow_bot=True) and portal:
|
||||
await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
|
||||
|
||||
async def handle_room_meta(self, type, room, sender, content):
|
||||
portal = Portal.get_by_mxid(room)
|
||||
sender = await User.get_by_mxid(sender).ensure_started()
|
||||
@staticmethod
|
||||
async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, content: dict):
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||
if await sender.has_full_access(allow_bot=True) and portal:
|
||||
handler, content_key = {
|
||||
"m.room.name": (portal.handle_matrix_title, "name"),
|
||||
"m.room.topic": (portal.handle_matrix_about, "topic"),
|
||||
"m.room.avatar": (portal.handle_matrix_avatar, "url"),
|
||||
}[type]
|
||||
}[evt_type]
|
||||
if content_key not in content:
|
||||
return
|
||||
await handler(sender, content[content_key])
|
||||
|
||||
async def handle_room_pin(self, room, sender, new_events, old_events):
|
||||
portal = Portal.get_by_mxid(room)
|
||||
sender = await User.get_by_mxid(sender).ensure_started()
|
||||
@staticmethod
|
||||
async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str],
|
||||
old_events: Set[str]):
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||
if await sender.has_full_access(allow_bot=True) and portal:
|
||||
events = new_events - old_events
|
||||
if len(events) > 0:
|
||||
@@ -281,12 +283,14 @@ class MatrixHandler:
|
||||
# All pinned events removed, remove pinned event in Telegram.
|
||||
await portal.handle_matrix_pin(sender, None)
|
||||
|
||||
async def handle_name_change(self, room, user, displayname, prev_displayname, event_id):
|
||||
portal = Portal.get_by_mxid(room)
|
||||
@staticmethod
|
||||
async def handle_name_change(room_id: str, user_id: str, displayname: str,
|
||||
prev_displayname: str, event_id: str):
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal or not portal.has_bot:
|
||||
return
|
||||
|
||||
user = await User.get_by_mxid(user).ensure_started()
|
||||
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||
if await user.needs_relaybot(portal):
|
||||
await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
|
||||
|
||||
@@ -296,25 +300,27 @@ class MatrixHandler:
|
||||
for event_id, receipts in content.items()
|
||||
for user_id in receipts.get("m.read", {})}
|
||||
|
||||
async def handle_read_receipts(self, room_id: str, receipts: Dict[str, str]):
|
||||
portal = Portal.get_by_mxid(room_id)
|
||||
@staticmethod
|
||||
async def handle_read_receipts(room_id: str, receipts: Dict[str, str]):
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal:
|
||||
return
|
||||
|
||||
for user_id, event_id in receipts.items():
|
||||
user = await User.get_by_mxid(user_id).ensure_started()
|
||||
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||
if not await user.is_logged_in():
|
||||
continue
|
||||
await portal.mark_read(user, event_id)
|
||||
|
||||
async def handle_presence(self, user: str, presence: str):
|
||||
user = await User.get_by_mxid(user).ensure_started()
|
||||
@staticmethod
|
||||
async def handle_presence(user_id: str, presence: str):
|
||||
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||
if not await user.is_logged_in():
|
||||
return
|
||||
await user.set_presence(presence == "online")
|
||||
|
||||
async def handle_typing(self, room_id: str, now_typing: List[str]):
|
||||
portal = Portal.get_by_mxid(room_id)
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal:
|
||||
return
|
||||
|
||||
@@ -324,7 +330,7 @@ class MatrixHandler:
|
||||
if is_typing and was_typing:
|
||||
continue
|
||||
|
||||
user = await User.get_by_mxid(user_id).ensure_started()
|
||||
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||
if not await user.is_logged_in():
|
||||
continue
|
||||
|
||||
@@ -332,38 +338,38 @@ class MatrixHandler:
|
||||
|
||||
self.previously_typing = now_typing
|
||||
|
||||
def filter_matrix_event(self, event):
|
||||
def filter_matrix_event(self, event: dict):
|
||||
sender = event.get("sender", None)
|
||||
if not sender:
|
||||
return False
|
||||
return (sender == self.az.bot_mxid
|
||||
or Puppet.get_id_from_mxid(sender) is not None)
|
||||
or pu.Puppet.get_id_from_mxid(sender) is not None)
|
||||
|
||||
async def try_handle_event(self, evt):
|
||||
async def try_handle_event(self, evt: dict):
|
||||
try:
|
||||
await self.handle_event(evt)
|
||||
except Exception:
|
||||
self.log.exception("Error handling manually received Matrix event")
|
||||
|
||||
async def handle_event(self, evt):
|
||||
async def handle_event(self, evt: dict):
|
||||
if self.filter_matrix_event(evt):
|
||||
return
|
||||
self.log.debug("Received event: %s", evt)
|
||||
type = evt.get("type", "m.unknown")
|
||||
room_id = evt.get("room_id", None)
|
||||
event_id = evt.get("event_id", None)
|
||||
sender = evt.get("sender", None)
|
||||
content = evt.get("content", {})
|
||||
if type == "m.room.member":
|
||||
state_key = evt["state_key"]
|
||||
prev_content = evt.get("unsigned", {}).get("prev_content", {})
|
||||
membership = content.get("membership", "")
|
||||
prev_membership = prev_content.get("membership", "leave")
|
||||
evt_type = evt.get("type", "m.unknown") # type: str
|
||||
room_id = evt.get("room_id", None) # type: str
|
||||
event_id = evt.get("event_id", None) # type: str
|
||||
sender = evt.get("sender", None) # type: str
|
||||
content = evt.get("content", {}) # type: dict
|
||||
if evt_type == "m.room.member":
|
||||
state_key = evt["state_key"] # type: str
|
||||
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
|
||||
membership = content.get("membership", "") # type: str
|
||||
prev_membership = prev_content.get("membership", "leave") # type: str
|
||||
if membership == prev_membership:
|
||||
match = re.compile("@(.+):(.+)").match(state_key)
|
||||
localpart = match.group(1)
|
||||
displayname = content.get("displayname", localpart)
|
||||
prev_displayname = prev_content.get("displayname", localpart)
|
||||
match = re.compile("@(.+):(.+)").match(state_key) # type: Match
|
||||
localpart = match.group(1) # type: str
|
||||
displayname = content.get("displayname", localpart) # type: str
|
||||
prev_displayname = prev_content.get("displayname", localpart) # type: str
|
||||
if displayname != prev_displayname:
|
||||
await self.handle_name_change(room_id, state_key, displayname,
|
||||
prev_displayname, event_id)
|
||||
@@ -373,26 +379,26 @@ class MatrixHandler:
|
||||
await self.handle_part(room_id, state_key, sender, event_id)
|
||||
elif membership == "join":
|
||||
await self.handle_join(room_id, state_key, event_id)
|
||||
elif type in ("m.room.message", "m.sticker"):
|
||||
if type != "m.room.message":
|
||||
content["msgtype"] = type
|
||||
elif evt_type in ("m.room.message", "m.sticker"):
|
||||
if evt_type != "m.room.message":
|
||||
content["msgtype"] = evt_type
|
||||
await self.handle_message(room_id, sender, content, event_id)
|
||||
elif type == "m.room.redaction":
|
||||
elif evt_type == "m.room.redaction":
|
||||
await self.handle_redaction(room_id, sender, evt["redacts"])
|
||||
elif type == "m.room.power_levels":
|
||||
elif evt_type == "m.room.power_levels":
|
||||
await self.handle_power_levels(room_id, sender, evt["content"], evt["prev_content"])
|
||||
elif type in ("m.room.name", "m.room.avatar", "m.room.topic"):
|
||||
await self.handle_room_meta(type, room_id, sender, evt["content"])
|
||||
elif type == "m.room.pinned_events":
|
||||
elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
|
||||
await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
|
||||
elif evt_type == "m.room.pinned_events":
|
||||
new_events = set(evt["content"]["pinned"])
|
||||
try:
|
||||
old_events = set(evt["unsigned"]["prev_content"]["pinned"])
|
||||
except KeyError:
|
||||
old_events = set()
|
||||
await self.handle_room_pin(room_id, sender, new_events, old_events)
|
||||
elif type == "m.receipt":
|
||||
elif evt_type == "m.receipt":
|
||||
await self.handle_read_receipts(room_id, self.parse_read_receipts(content))
|
||||
elif type == "m.presence":
|
||||
elif evt_type == "m.presence":
|
||||
await self.handle_presence(sender, content.get("presence", "offline"))
|
||||
elif type == "m.typing":
|
||||
elif evt_type == "m.typing":
|
||||
await self.handle_typing(room_id, content.get("user_ids", []))
|
||||
|
||||
+249
-196
File diff suppressed because it is too large
Load Diff
+34
-26
@@ -14,32 +14,39 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Optional, Awaitable
|
||||
import re
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import orm
|
||||
|
||||
from telethon.tl.types import UserProfilePhoto
|
||||
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
|
||||
|
||||
from .db import Puppet as DBPuppet
|
||||
from . import util, matrix
|
||||
from . import util
|
||||
|
||||
config = None
|
||||
if TYPE_CHECKING:
|
||||
from .matrix import MatrixHandler
|
||||
from .config import Config
|
||||
from .context import Context
|
||||
|
||||
config = None # type: Config
|
||||
|
||||
|
||||
class Puppet:
|
||||
log = logging.getLogger("mau.puppet")
|
||||
db = None
|
||||
log = logging.getLogger("mau.puppet") # type: logging.Logger
|
||||
db = None # type: orm.Session
|
||||
az = None # type: AppService
|
||||
mx = None # type: matrix.MatrixHandler
|
||||
mx = None # type: MatrixHandler
|
||||
loop = None # type: asyncio.AbstractEventLoop
|
||||
mxid_regex = None
|
||||
username_template = None
|
||||
hs_domain = None
|
||||
cache = {}
|
||||
by_custom_mxid = {}
|
||||
mxid_regex = None # type: Pattern
|
||||
username_template = None # type: str
|
||||
hs_domain = None # type: str
|
||||
cache = {} # type: Dict[str, Puppet]
|
||||
by_custom_mxid = {} # type: Dict[str, Puppet]
|
||||
|
||||
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
|
||||
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
|
||||
@@ -71,7 +78,8 @@ class Puppet:
|
||||
def tgid(self):
|
||||
return self.id
|
||||
|
||||
async def is_logged_in(self):
|
||||
@staticmethod
|
||||
async def is_logged_in():
|
||||
return True
|
||||
|
||||
# region Custom puppet management
|
||||
@@ -154,12 +162,12 @@ class Puppet:
|
||||
def filter_events(self, events):
|
||||
new_events = []
|
||||
for event in events:
|
||||
type = event.get("type", None)
|
||||
evt_type = event.get("type", None)
|
||||
event.setdefault("content", {})
|
||||
if type == "m.typing":
|
||||
if evt_type == "m.typing":
|
||||
is_typing = self.custom_mxid in event["content"].get("user_ids", [])
|
||||
event["content"]["user_ids"] = [self.custom_mxid] if is_typing else []
|
||||
elif type == "m.receipt":
|
||||
elif evt_type == "m.receipt":
|
||||
val = None
|
||||
evt = None
|
||||
for event_id in event["content"]:
|
||||
@@ -273,7 +281,7 @@ class Puppet:
|
||||
return round(similarity * 1000) / 10
|
||||
|
||||
@staticmethod
|
||||
def get_displayname(info, format=True):
|
||||
def get_displayname(info, enable_format=True):
|
||||
data = {
|
||||
"phone number": info.phone if hasattr(info, "phone") else None,
|
||||
"username": info.username,
|
||||
@@ -295,7 +303,7 @@ class Puppet:
|
||||
elif not name:
|
||||
name = info.id
|
||||
|
||||
if not format:
|
||||
if not enable_format:
|
||||
return name
|
||||
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
|
||||
displayname=name)
|
||||
@@ -347,18 +355,18 @@ class Puppet:
|
||||
# region Getters
|
||||
|
||||
@classmethod
|
||||
def get(cls, id, create=True) -> "Optional[Puppet]":
|
||||
def get(cls, tgid, create=True) -> "Optional[Puppet]":
|
||||
try:
|
||||
return cls.cache[id]
|
||||
return cls.cache[tgid]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
puppet = DBPuppet.query.get(id)
|
||||
puppet = DBPuppet.query.get(tgid)
|
||||
if puppet:
|
||||
return cls.from_db(puppet)
|
||||
|
||||
if create:
|
||||
puppet = cls(id)
|
||||
puppet = cls(tgid)
|
||||
cls.db.add(puppet.db_instance)
|
||||
cls.db.commit()
|
||||
return puppet
|
||||
@@ -402,8 +410,8 @@ class Puppet:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_mxid_from_id(cls, id):
|
||||
return f"@{cls.username_template.format(userid=id)}:{cls.hs_domain}"
|
||||
def get_mxid_from_id(cls, tgid):
|
||||
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username) -> "Optional[Puppet]":
|
||||
@@ -437,12 +445,12 @@ class Puppet:
|
||||
# endregion
|
||||
|
||||
|
||||
def init(context):
|
||||
def init(context: "Context") -> List[Awaitable[int]]:
|
||||
global config
|
||||
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
|
||||
Puppet.mx = context.mx
|
||||
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
|
||||
Puppet.hs_domain = config["homeserver"]["domain"]
|
||||
localpart = Puppet.username_template.format(userid="(.+)")
|
||||
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
|
||||
Puppet.mxid_regex = re.compile(
|
||||
f"@{Puppet.username_template.format(userid='(.+)')}:{Puppet.hs_domain}")
|
||||
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from sqlalchemy import orm
|
||||
|
||||
from mautrix_appservice import StateStore
|
||||
|
||||
from . import puppet as pu
|
||||
@@ -25,15 +27,17 @@ from .db import RoomState, UserProfile
|
||||
class SQLStateStore(StateStore):
|
||||
def __init__(self, db):
|
||||
super().__init__()
|
||||
self.db = db
|
||||
self.db = db # type: orm.Session
|
||||
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
|
||||
self.room_state_cache = {} # type: Dict[str, RoomState]
|
||||
|
||||
def is_registered(self, user: str) -> bool:
|
||||
@staticmethod
|
||||
def is_registered(user: str) -> bool:
|
||||
puppet = pu.Puppet.get_by_mxid(user)
|
||||
return puppet.is_registered if puppet else False
|
||||
|
||||
def registered(self, user: str):
|
||||
@staticmethod
|
||||
def registered(user: str):
|
||||
puppet = pu.Puppet.get_by_mxid(user)
|
||||
if puppet:
|
||||
puppet.is_registered = True
|
||||
|
||||
@@ -17,10 +17,14 @@
|
||||
from telethon import TelegramClient, utils
|
||||
from telethon.tl.functions.messages import SendMediaRequest
|
||||
from telethon.tl.types import *
|
||||
from telethon.tl import custom
|
||||
|
||||
|
||||
class MautrixTelegramClient(TelegramClient):
|
||||
async def upload_file(self, file, mime_type=None, attributes=None, file_name=None):
|
||||
async def upload_file_direct(self, file: bytes, mime_type: str = None,
|
||||
attributes: List[TypeDocumentAttribute] = None,
|
||||
file_name: str = None
|
||||
) -> Union[InputMediaUploadedDocument, InputMediaUploadedPhoto]:
|
||||
file_handle = await super().upload_file(file, file_name=file_name, use_cache=False)
|
||||
|
||||
if mime_type == "image/png" or mime_type == "image/jpeg":
|
||||
@@ -34,7 +38,10 @@ class MautrixTelegramClient(TelegramClient):
|
||||
mime_type=mime_type or "application/octet-stream",
|
||||
attributes=list(attr_dict.values()))
|
||||
|
||||
async def send_media(self, entity, media, caption=None, entities=None, reply_to=None):
|
||||
async def send_media(self, entity: Union[TypeInputPeer, TypePeer],
|
||||
media: Union[TypeInputMedia, TypeMessageMedia],
|
||||
caption: str = None, entities: List[TypeMessageEntity] = None,
|
||||
reply_to: int = None) -> Optional[custom.Message]:
|
||||
entity = await self.get_input_entity(entity)
|
||||
reply_to = utils.get_message_id(reply_to)
|
||||
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [],
|
||||
|
||||
+58
-54
@@ -14,42 +14,51 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, Awaitable, Optional
|
||||
from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from telethon.tl.types import *
|
||||
from telethon.tl.types import User as TLUser
|
||||
from telethon.tl.types.contacts import ContactsNotModified
|
||||
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
|
||||
from telethon.tl.functions.account import UpdateStatusRequest
|
||||
from mautrix_appservice import MatrixRequestError
|
||||
|
||||
from .db import User as DBUser, Contact as DBContact
|
||||
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
|
||||
from .abstract_user import AbstractUser
|
||||
from . import portal as po, puppet as pu
|
||||
|
||||
config = None
|
||||
if TYPE_CHECKING:
|
||||
from .config import Config
|
||||
from .context import Context
|
||||
|
||||
config = None # type: Config
|
||||
|
||||
SearchResults = List[Tuple["pu.Puppet", int]]
|
||||
|
||||
|
||||
class User(AbstractUser):
|
||||
log = logging.getLogger("mau.user")
|
||||
by_mxid = {}
|
||||
by_tgid = {}
|
||||
log = logging.getLogger("mau.user") # type: logging.Logger
|
||||
by_mxid = {} # type: Dict[str, User]
|
||||
by_tgid = {} # type: Dict[int, User]
|
||||
|
||||
def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0,
|
||||
is_bot=False, db_portals=None, db_instance=None):
|
||||
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
|
||||
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
|
||||
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
|
||||
db_instance: Optional[DBUser] = None):
|
||||
super().__init__()
|
||||
self.mxid = mxid # type: str
|
||||
self.tgid = tgid # type: int
|
||||
self.is_bot = is_bot # type: bool
|
||||
self.username = username # type: str
|
||||
self.contacts = []
|
||||
self.saved_contacts = saved_contacts
|
||||
self.db_contacts = db_contacts
|
||||
self.portals = {} # type: Dict[str, po.Portal]
|
||||
self.db_portals = db_portals
|
||||
self._db_instance = db_instance
|
||||
self.contacts = [] # type: List[pu.Puppet]
|
||||
self.saved_contacts = saved_contacts # type: int
|
||||
self.db_contacts = db_contacts # type: List[DBContact]
|
||||
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
|
||||
self.db_portals = db_portals # type: List[DBPortal]
|
||||
self._db_instance = db_instance # type: DBUser
|
||||
|
||||
self.command_status = None # type: dict
|
||||
|
||||
@@ -64,53 +73,47 @@ class User(AbstractUser):
|
||||
self.by_tgid[tgid] = self
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return self.mxid
|
||||
|
||||
@property
|
||||
def mxid_localpart(self):
|
||||
match = re.compile("@(.+):(.+)").match(self.mxid)
|
||||
def mxid_localpart(self) -> str:
|
||||
match = re.compile("@(.+):(.+)").match(self.mxid) # type: Match
|
||||
return match.group(1)
|
||||
|
||||
# TODO replace with proper displayname getting everywhere
|
||||
@property
|
||||
def displayname(self):
|
||||
def displayname(self) -> str:
|
||||
return self.mxid_localpart
|
||||
|
||||
@property
|
||||
def db_contacts(self):
|
||||
def db_contacts(self) -> List[DBContact]:
|
||||
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
|
||||
for puppet in self.contacts]
|
||||
|
||||
@db_contacts.setter
|
||||
def db_contacts(self, contacts):
|
||||
if contacts:
|
||||
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts]
|
||||
else:
|
||||
self.contacts = []
|
||||
def db_contacts(self, contacts: List[DBContact]):
|
||||
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
|
||||
|
||||
@property
|
||||
def db_portals(self):
|
||||
def db_portals(self) -> List[DBPortal]:
|
||||
return [portal.db_instance for portal in self.portals.values()]
|
||||
|
||||
@db_portals.setter
|
||||
def db_portals(self, portals):
|
||||
if portals:
|
||||
self.portals = {(portal.tgid, portal.tg_receiver):
|
||||
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
|
||||
for portal in portals}
|
||||
else:
|
||||
self.portals = {}
|
||||
def db_portals(self, portals: List[DBPortal]):
|
||||
self.portals = {(portal.tgid, portal.tg_receiver):
|
||||
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
|
||||
for portal in portals} if portals else {}
|
||||
|
||||
# region Database conversion
|
||||
|
||||
@property
|
||||
def db_instance(self):
|
||||
def db_instance(self) -> DBUser:
|
||||
if not self._db_instance:
|
||||
self._db_instance = self.new_db_instance()
|
||||
return self._db_instance
|
||||
|
||||
def new_db_instance(self):
|
||||
def new_db_instance(self) -> DBUser:
|
||||
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
|
||||
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
|
||||
portals=self.db_portals)
|
||||
@@ -134,14 +137,14 @@ class User(AbstractUser):
|
||||
self.db.commit()
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_user):
|
||||
def from_db(cls, db_user: DBUser) -> "User":
|
||||
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
|
||||
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
|
||||
|
||||
# endregion
|
||||
# region Telegram connection management
|
||||
|
||||
async def start(self, delete_unless_authenticated=False):
|
||||
async def start(self, delete_unless_authenticated: bool = False) -> "User":
|
||||
await super().start()
|
||||
if await self.is_logged_in():
|
||||
self.log.debug(f"Ensuring post_login() for {self.name}")
|
||||
@@ -152,7 +155,7 @@ class User(AbstractUser):
|
||||
self.client.session.delete()
|
||||
return self
|
||||
|
||||
async def post_login(self, info=None):
|
||||
async def post_login(self, info: TLUser = None):
|
||||
try:
|
||||
await self.update_info(info)
|
||||
if not self.is_bot:
|
||||
@@ -163,7 +166,7 @@ class User(AbstractUser):
|
||||
except Exception:
|
||||
self.log.exception("Failed to run post-login functions for %s", self.mxid)
|
||||
|
||||
async def update(self, update):
|
||||
async def update(self, update: TypeUpdate):
|
||||
if not self.is_bot:
|
||||
return
|
||||
|
||||
@@ -186,7 +189,7 @@ class User(AbstractUser):
|
||||
# endregion
|
||||
# region Telegram actions that need custom methods
|
||||
|
||||
def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]":
|
||||
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
|
||||
return super().ensure_started(even_if_no_session)
|
||||
|
||||
def set_presence(self, online: bool = True):
|
||||
@@ -194,7 +197,7 @@ class User(AbstractUser):
|
||||
return
|
||||
return self.client(UpdateStatusRequest(offline=not online))
|
||||
|
||||
async def update_info(self, info: User = None):
|
||||
async def update_info(self, info: TLUser = None):
|
||||
info = info or await self.client.get_me()
|
||||
changed = False
|
||||
if self.is_bot != info.bot:
|
||||
@@ -233,8 +236,9 @@ class User(AbstractUser):
|
||||
self.delete()
|
||||
return True
|
||||
|
||||
def _search_local(self, query, max_results=5, min_similarity=45):
|
||||
results = []
|
||||
def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
|
||||
) -> SearchResults:
|
||||
results = [] # type: SearchResults
|
||||
for contact in self.contacts:
|
||||
similarity = contact.similarity(query)
|
||||
if similarity >= min_similarity:
|
||||
@@ -242,11 +246,11 @@ class User(AbstractUser):
|
||||
results.sort(key=lambda tup: tup[1], reverse=True)
|
||||
return results[0:max_results]
|
||||
|
||||
async def _search_remote(self, query, max_results=5):
|
||||
async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults:
|
||||
if len(query) < 5:
|
||||
return []
|
||||
server_results = await self.client(SearchRequest(q=query, limit=max_results))
|
||||
results = []
|
||||
results = [] # type: SearchResults
|
||||
for user in server_results.users:
|
||||
puppet = pu.Puppet.get(user.id)
|
||||
await puppet.update_info(self, user)
|
||||
@@ -254,7 +258,7 @@ class User(AbstractUser):
|
||||
results.sort(key=lambda tup: tup[1], reverse=True)
|
||||
return results[0:max_results]
|
||||
|
||||
async def search(self, query, force_remote=False):
|
||||
async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]:
|
||||
if force_remote:
|
||||
return await self._search_remote(query), True
|
||||
|
||||
@@ -264,7 +268,7 @@ class User(AbstractUser):
|
||||
|
||||
return await self._search_remote(query), True
|
||||
|
||||
async def sync_dialogs(self, synchronous_create=False):
|
||||
async def sync_dialogs(self, synchronous_create: bool = False):
|
||||
creators = []
|
||||
for entity in await self.get_dialogs(limit=30):
|
||||
portal = po.Portal.get_by_entity(entity)
|
||||
@@ -275,7 +279,7 @@ class User(AbstractUser):
|
||||
self.save()
|
||||
await asyncio.gather(*creators, loop=self.loop)
|
||||
|
||||
def register_portal(self, portal):
|
||||
def register_portal(self, portal: po.Portal):
|
||||
try:
|
||||
if self.portals[portal.tgid_full] == portal:
|
||||
return
|
||||
@@ -284,18 +288,18 @@ class User(AbstractUser):
|
||||
self.portals[portal.tgid_full] = portal
|
||||
self.save()
|
||||
|
||||
def unregister_portal(self, portal):
|
||||
def unregister_portal(self, portal: po.Portal):
|
||||
try:
|
||||
del self.portals[portal.tgid_full]
|
||||
self.save()
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
async def needs_relaybot(self, portal):
|
||||
async def needs_relaybot(self, portal: po.Portal) -> bool:
|
||||
return not await self.is_logged_in() or (
|
||||
self.is_bot and portal.tgid_full not in self.portals)
|
||||
|
||||
def _hash_contacts(self):
|
||||
def _hash_contacts(self) -> int:
|
||||
acc = 0
|
||||
for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):
|
||||
acc = (acc * 20261 + id) & 0xffffffff
|
||||
@@ -318,7 +322,7 @@ class User(AbstractUser):
|
||||
# region Class instance lookup
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid, create=True) -> "Optional[User]":
|
||||
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
|
||||
if not mxid:
|
||||
raise ValueError("Matrix ID can't be empty")
|
||||
|
||||
@@ -341,7 +345,7 @@ class User(AbstractUser):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid) -> "Optional[User]":
|
||||
def get_by_tgid(cls, tgid: int) -> "Optional[User]":
|
||||
try:
|
||||
return cls.by_tgid[tgid]
|
||||
except KeyError:
|
||||
@@ -355,7 +359,7 @@ class User(AbstractUser):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username) -> "Optional[User]":
|
||||
def find_by_username(cls, username: str) -> "Optional[User]":
|
||||
if not username:
|
||||
return None
|
||||
|
||||
@@ -371,7 +375,7 @@ class User(AbstractUser):
|
||||
# endregion
|
||||
|
||||
|
||||
def init(context):
|
||||
def init(context: "Context") -> List[Awaitable[User]]:
|
||||
global config
|
||||
config = context.config
|
||||
|
||||
|
||||
@@ -14,15 +14,25 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Tuple, Union, Dict
|
||||
from io import BytesIO
|
||||
import time
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
import magic
|
||||
from sqlalchemy import orm
|
||||
from sqlalchemy.exc import IntegrityError, InvalidRequestError
|
||||
from sqlalchemy.orm.exc import FlushError
|
||||
|
||||
from telethon.tl.types import (Document, FileLocation, InputFileLocation,
|
||||
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
|
||||
from telethon.errors import *
|
||||
from mautrix_appservice import IntentAPI
|
||||
|
||||
from ..tgclient import MautrixTelegramClient
|
||||
from ..db import TelegramFile as DBTelegramFile
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
@@ -36,20 +46,18 @@ try:
|
||||
except ImportError:
|
||||
VideoFileClip = random = string = os = mimetypes = None
|
||||
|
||||
from telethon.tl.types import (Document, FileLocation, InputFileLocation,
|
||||
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
|
||||
from telethon.errors import *
|
||||
log = logging.getLogger("mau.util") # type: logging.Logger
|
||||
|
||||
from ..db import TelegramFile as DBTelegramFile
|
||||
|
||||
log = logging.getLogger("mau.util")
|
||||
TypeLocation = Union[Document, InputDocumentFileLocation, FileLocation, InputFileLocation]
|
||||
|
||||
|
||||
def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=None):
|
||||
def convert_image(file: bytes, source_mime: str = "image/webp", target_type: str = "png",
|
||||
thumbnail_to: Optional[Tuple[int, int]] = None
|
||||
) -> Tuple[str, bytes, Optional[int], Optional[int]]:
|
||||
if not Image:
|
||||
return source_mime, file, None, None
|
||||
try:
|
||||
image = Image.open(BytesIO(file)).convert("RGBA")
|
||||
image = Image.open(BytesIO(file)).convert("RGBA") # type: Image.Image
|
||||
if thumbnail_to:
|
||||
image.thumbnail(thumbnail_to, Image.ANTIALIAS)
|
||||
new_file = BytesIO()
|
||||
@@ -61,13 +69,14 @@ def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_t
|
||||
return source_mime, file, None, None
|
||||
|
||||
|
||||
def _temp_file_name(ext):
|
||||
def _temp_file_name(ext: str) -> str:
|
||||
return ("/tmp/mxtg-video-"
|
||||
+ "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
|
||||
+ ext)
|
||||
|
||||
|
||||
def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024, 720)):
|
||||
def _read_video_thumbnail(data: bytes, video_ext: str = "mp4", frame_ext: str = "png",
|
||||
max_size: Tuple[int, int] = (1024, 720)) -> Tuple[bytes, int, int]:
|
||||
# We don't have any way to read the video from memory, so save it to disk.
|
||||
temp_file = _temp_file_name(video_ext)
|
||||
with open(temp_file, "wb") as file:
|
||||
@@ -90,21 +99,21 @@ def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024
|
||||
return thumbnail_file.getvalue(), w, h
|
||||
|
||||
|
||||
def _location_to_id(location):
|
||||
def _location_to_id(location: TypeLocation) -> str:
|
||||
if isinstance(location, (Document, InputDocumentFileLocation)):
|
||||
return f"{location.id}-{location.version}"
|
||||
elif isinstance(location, (FileLocation, InputFileLocation)):
|
||||
return f"{location.volume_id}-{location.local_id}"
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mime):
|
||||
async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
|
||||
thumbnail_loc: TypeLocation, video: bytes,
|
||||
mime: str) -> Optional[DBTelegramFile]:
|
||||
if not Image or not VideoFileClip:
|
||||
return None
|
||||
|
||||
id = _location_to_id(thumbnail_loc)
|
||||
if not id:
|
||||
loc_id = _location_to_id(thumbnail_loc)
|
||||
if not loc_id:
|
||||
return None
|
||||
|
||||
video_ext = mimetypes.guess_extension(mime)
|
||||
@@ -121,36 +130,40 @@ async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mim
|
||||
|
||||
content_uri = await intent.upload_file(file, mime_type)
|
||||
|
||||
return DBTelegramFile(id=id, mxc=content_uri, mime_type=mime_type,
|
||||
return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
|
||||
was_converted=False, timestamp=int(time.time()), size=len(file),
|
||||
width=width, height=height)
|
||||
|
||||
|
||||
transfer_locks = {}
|
||||
transfer_locks_lock = asyncio.Lock()
|
||||
transfer_locks = {} # type: Dict[str, asyncio.Lock]
|
||||
|
||||
|
||||
async def transfer_file_to_matrix(db, client, intent, location, thumbnail=None, is_sticker=False):
|
||||
id = _location_to_id(location)
|
||||
if not id:
|
||||
async def transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient, intent: IntentAPI,
|
||||
location: TypeLocation, thumbnail: Optional[TypeLocation] = None,
|
||||
is_sticker: bool = False) -> Optional[DBTelegramFile]:
|
||||
location_id = _location_to_id(location)
|
||||
if not location_id:
|
||||
return None
|
||||
|
||||
db_file = DBTelegramFile.query.get(id)
|
||||
db_file = DBTelegramFile.query.get(location_id)
|
||||
if db_file:
|
||||
return db_file
|
||||
|
||||
async with transfer_locks_lock:
|
||||
try:
|
||||
lock = transfer_locks[id]
|
||||
except KeyError:
|
||||
lock = asyncio.Lock()
|
||||
transfer_locks[id] = lock
|
||||
try:
|
||||
lock = transfer_locks[location_id]
|
||||
except KeyError:
|
||||
lock = asyncio.Lock()
|
||||
transfer_locks[location_id] = lock
|
||||
async with lock:
|
||||
return await _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker)
|
||||
return await _unlocked_transfer_file_to_matrix(db, client, intent, location_id, location,
|
||||
thumbnail, is_sticker)
|
||||
|
||||
|
||||
async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker):
|
||||
db_file = DBTelegramFile.query.get(id)
|
||||
async def _unlocked_transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient,
|
||||
intent: IntentAPI, loc_id: str, location: TypeLocation,
|
||||
thumbnail: Optional[TypeLocation],
|
||||
is_sticker: bool) -> Optional[DBTelegramFile]:
|
||||
db_file = DBTelegramFile.query.get(loc_id)
|
||||
if db_file:
|
||||
return db_file
|
||||
|
||||
@@ -167,15 +180,16 @@ async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, th
|
||||
|
||||
image_converted = False
|
||||
if mime_type == "image/webp":
|
||||
new_mime_type, file, width, height = convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=(
|
||||
256, 256) if is_sticker else None)
|
||||
new_mime_type, file, width, height = convert_image(
|
||||
file, source_mime="image/webp", target_type="png",
|
||||
thumbnail_to=(256, 256) if is_sticker else None)
|
||||
image_converted = new_mime_type != mime_type
|
||||
mime_type = new_mime_type
|
||||
thumbnail = None
|
||||
|
||||
content_uri = await intent.upload_file(file, mime_type)
|
||||
|
||||
db_file = DBTelegramFile(id=id, mxc=content_uri,
|
||||
db_file = DBTelegramFile(id=loc_id, mxc=content_uri,
|
||||
mime_type=mime_type, was_converted=image_converted,
|
||||
timestamp=int(time.time()), size=len(file),
|
||||
width=width, height=height)
|
||||
|
||||
@@ -16,10 +16,12 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
|
||||
def format_duration(seconds):
|
||||
def pluralize(count, singular): return singular if count == 1 else singular + "s"
|
||||
def format_duration(seconds: int) -> str:
|
||||
def pluralize(count, singular):
|
||||
return singular if count == 1 else singular + "s"
|
||||
|
||||
def include(count, word): return f"{count} {pluralize(count, word)}" if count > 0 else ""
|
||||
def include(count, word):
|
||||
return f"{count} {pluralize(count, word)}" if count > 0 else ""
|
||||
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
|
||||
Reference in New Issue
Block a user