Even more migrations to mautrix-python

This commit is contained in:
Tulir Asokan
2019-07-19 21:36:21 +03:00
parent eef498d47a
commit d4e3956941
20 changed files with 215 additions and 337 deletions
+17 -17
View File
@@ -13,41 +13,41 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, NewType, Optional, Tuple, Union
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
from mautrix_appservice import MatrixRequestError, IntentAPI
from mautrix.appservice import IntentAPI
from mautrix.errors import MatrixRequestError
from mautrix.types import RoomID, UserID
from ..types import MatrixRoomID, MatrixUserID
from . import command_handler, CommandEvent, SECTION_ADMIN
from .. import puppet as pu, portal as po
ManagementRoom = NewType('ManagementRoom', Tuple[MatrixRoomID, MatrixUserID])
ManagementRoom = NamedTuple('ManagementRoom', room_id=RoomID, user_id=UserID)
async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[MatrixRoomID],
async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[RoomID],
List['po.Portal'], List['po.Portal']]:
management_rooms = [] # type: List[ManagementRoom]
unidentified_rooms = [] # type: List[MatrixRoomID]
portals = [] # type: List[po.Portal]
empty_portals = [] # type: List[po.Portal]
management_rooms: List[ManagementRoom] = []
unidentified_rooms: List[RoomID] = []
portals: List[po.Portal] = []
empty_portals: List[po.Portal] = []
rooms = await intent.get_joined_rooms()
for room_str in rooms:
room = MatrixRoomID(room_str)
portal = po.Portal.get_by_mxid(room)
for room_id in rooms:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
try:
members = await intent.get_room_members(room)
members = await intent.get_room_members(room_id)
except MatrixRequestError:
members = []
if len(members) == 2:
other_member = MatrixUserID(members[0] if members[0] != intent.mxid else members[1])
other_member = members[0] if members[0] != intent.mxid else members[1]
if pu.Puppet.get_id_from_mxid(other_member):
unidentified_rooms.append(room)
unidentified_rooms.append(room_id)
else:
management_rooms.append(ManagementRoom((room, other_member)))
management_rooms.append(ManagementRoom(room_id, other_member))
else:
unidentified_rooms.append(room)
unidentified_rooms.append(room_id)
else:
members = await portal.get_authenticated_matrix_users()
if len(members) == 0:
+12 -11
View File
@@ -22,11 +22,12 @@ import commonmark
from telethon.errors import FloodWaitError
from ..types import MatrixRoomID, MatrixEventID
from mautrix.types import RoomID, EventID
from ..util import format_duration
from .. import user as u, context as c
command_handlers = {} # type: Dict[str, CommandHandler]
command_handlers: Dict[str, 'CommandHandler'] = {}
HelpSection = NamedTuple('HelpSection', [('name', str), ('order', int), ('description', str)])
@@ -82,7 +83,7 @@ class CommandEvent:
is a portal.
"""
def __init__(self, processor: 'CommandProcessor', room: MatrixRoomID, event: MatrixEventID,
def __init__(self, processor: 'CommandProcessor', room: RoomID, event: EventID,
sender: u.User, command: str, args: List[str], is_management: bool,
is_portal: bool) -> None:
self.az = processor.az
@@ -101,7 +102,7 @@ class CommandEvent:
self.is_portal = is_portal
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
) -> Awaitable[Dict]:
) -> Awaitable[EventID]:
"""Write a reply to the room in which the command was issued.
Replaces occurences of "$cmdprefix" in the message with the command
@@ -178,7 +179,7 @@ class CommandHandler:
help_section: Section of the help in which this command will appear.
"""
def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool,
def __init__(self, handler: Callable[[CommandEvent], Awaitable[EventID]], needs_auth: bool,
needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool,
management_only: bool, name: str, help_text: str, help_args: str,
help_section: HelpSection) -> None:
@@ -255,7 +256,7 @@ class CommandHandler:
(not self.needs_admin or is_admin) and
(not self.needs_auth or is_logged_in))
async def __call__(self, evt: CommandEvent) -> Dict:
async def __call__(self, evt: CommandEvent) -> EventID:
"""Executes the command if evt was issued with proper rights.
Args:
@@ -283,14 +284,14 @@ class CommandHandler:
return f"**{self.name}** {self._help_args} - {self._help_text}"
def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] = None, *,
def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[EventID]]] = None, *,
needs_auth: bool = True, needs_puppeting: bool = True,
needs_matrix_puppeting: bool = False, needs_admin: bool = False,
management_only: bool = False, name: Optional[str] = None,
help_text: str = "", help_args: str = "", help_section: HelpSection = None
) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[Dict]]]],
) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[EventID]]]],
CommandHandler]:
def decorator(func: Callable[[CommandEvent], Awaitable[Optional[Dict]]]) -> CommandHandler:
def decorator(func: Callable[[CommandEvent], Awaitable[Optional[EventID]]]) -> CommandHandler:
actual_name = name or func.__name__.replace("_", "-")
handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting,
needs_admin, management_only, actual_name, help_text, help_args,
@@ -310,9 +311,9 @@ class CommandProcessor:
self.public_website = context.public_website
self.command_prefix = self.config["bridge.command_prefix"]
async def handle(self, room: MatrixRoomID, event_id: MatrixEventID, sender: u.User,
async def handle(self, room: RoomID, event_id: EventID, sender: u.User,
command: str, args: List[str], is_management: bool, is_portal: bool
) -> Optional[Dict]:
) -> Optional[EventID]:
"""Handles the raw commands issued by a user to the Matrix bot.
If the command is not known, it might be a followup command and is
+2 -3
View File
@@ -14,14 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Tuple, TYPE_CHECKING
import asyncio
from alchemysession import AlchemySessionContainer
from mautrix_appservice import AppService
from mautrix.appservice import AppService
if TYPE_CHECKING:
from .web import PublicBridgeWebsite, ProvisioningAPI
from .config import Config
from .bot import Bot
-2
View File
@@ -18,10 +18,8 @@ from .bot_chat import BotChat
from .message import Message
from .portal import Portal
from .puppet import Puppet
from .room_state import RoomState
from .telegram_file import TelegramFile
from .user import User, UserPortal, Contact
from .user_profile import UserProfile
def init(db_engine) -> None:
+5 -4
View File
@@ -23,10 +23,10 @@ from sqlalchemy.ext.declarative import declarative_base
class BaseBase:
db = None # type: Engine
t = None # type: Table
__table__ = None # type: Table
c = None # type: ImmutableColumnCollection
db: Engine = None
t: Table = None
__table__: Table = None
c: ImmutableColumnCollection = None
@classmethod
@abstractmethod
@@ -54,4 +54,5 @@ class BaseBase:
with self.db.begin() as conn:
conn.execute(self.t.delete().where(self._edit_identity))
Base = declarative_base(cls=BaseBase)
+1 -1
View File
@@ -38,7 +38,7 @@ from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
if TYPE_CHECKING:
from ..abstract_user import AbstractUser
log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
log: logging.Logger = logging.getLogger("mau.fmt.tg")
def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
+3 -4
View File
@@ -14,7 +14,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Pattern
from html import escape
import struct
import re
@@ -47,9 +46,9 @@ def trim_reply_fallback_text(text: str) -> str:
return "\n".join(lines)
html_reply_fallback_regex = re.compile("^<mx-reply>"
r"[\s\S]+?"
"</mx-reply>") # type: Pattern
html_reply_fallback_regex: Pattern = re.compile("^<mx-reply>"
r"[\s\S]+?"
"</mx-reply>")
def trim_reply_fallback_html(html: str) -> str:
+1 -3
View File
@@ -13,9 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Match, Optional, Set, Tuple, Union, Iterable, TYPE_CHECKING
import time
import re
from typing import Dict, Set, Tuple, Union, Iterable, TYPE_CHECKING
from mautrix.bridge import BaseMatrixHandler
from mautrix.types import (Event, EventType, RoomID, UserID, EventID, ReceiptEvent, ReceiptType,
+2 -2
View File
@@ -80,7 +80,7 @@ if TYPE_CHECKING:
from .config import Config
from .tgclient import MautrixTelegramClient
config: Optional[Config] = None
config: Optional['Config'] = None
TypeMessage = Union[Message, MessageService]
TypeParticipant = Union[TypeChatParticipant, TypeChannelParticipant]
@@ -91,7 +91,7 @@ InviteList = Union[UserID, List[UserID]]
class Portal:
base_log: logging.Logger = logging.getLogger("mau.portal")
az: AppService = None
bot: Bot = None
bot: 'Bot' = None
loop: asyncio.AbstractEventLoop = None
# Config cache
+88 -222
View File
@@ -13,19 +13,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Any, Dict, List, Iterable, Optional, Pattern, Union, TYPE_CHECKING
from typing import Awaitable, Any, Dict, Iterable, Optional, Union, TYPE_CHECKING
from difflib import SequenceMatcher
from enum import Enum
from aiohttp import ServerDisconnectedError
import asyncio
import logging
import re
from telethon.tl.types import (UserProfilePhoto, User, UpdateUserName, PeerUser, TypeInputPeer,
InputPeerPhotoFileLocation, UserProfilePhotoEmpty)
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
from .types import MatrixUserID, TelegramID
from mautrix.appservice import AppService, IntentAPI
from mautrix.errors import MatrixRequestError
from mautrix.bridge import CustomPuppetMixin
from mautrix.types import UserID
from .types import TelegramID
from .db import Puppet as DBPuppet
from . import util
@@ -35,26 +37,48 @@ if TYPE_CHECKING:
from .context import Context
from .abstract_user import AbstractUser
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
config = None # type: Config
config: Optional['Config'] = None
class Puppet:
log = logging.getLogger("mau.puppet") # type: logging.Logger
az = None # type: AppService
mx = None # type: MatrixHandler
loop = None # type: asyncio.AbstractEventLoop
mxid_regex = None # type: Pattern
username_template = None # type: str
hs_domain = None # type: str
cache = {} # type: Dict[TelegramID, Puppet]
by_custom_mxid = {} # type: Dict[str, Puppet]
class Puppet(CustomPuppetMixin):
log: logging.Logger = logging.getLogger("mau.puppet")
az: AppService
mx: 'MatrixHandler'
loop: asyncio.AbstractEventLoop
username_template: str
hs_domain: str
_mxid_prefix: str
_mxid_suffix: str
_displayname_prefix: str
_displayname_suffix: str
cache: Dict[TelegramID, 'Puppet'] = {}
by_custom_mxid: Dict[UserID, 'Puppet'] = {}
id: TelegramID
access_token: Optional[str]
custom_mxid: Optional[UserID]
default_mxid: UserID
username: Optional[str]
displayname: Optional[str]
displayname_source: Optional[TelegramID]
photo_id: Optional[str]
is_bot: bool
is_registered: bool
disable_updates: bool
default_mxid_intent: IntentAPI
intent: IntentAPI
sync_task: Optional[asyncio.Future]
_db_instance: Optional[DBPuppet]
def __init__(self,
id: TelegramID,
access_token: Optional[str] = None,
custom_mxid: Optional[MatrixUserID] = None,
custom_mxid: Optional[UserID] = None,
username: Optional[str] = None,
displayname: Optional[str] = None,
displayname_source: Optional[TelegramID] = None,
@@ -63,41 +87,32 @@ class Puppet:
is_registered: bool = False,
disable_updates: bool = False,
db_instance: Optional[DBPuppet] = None) -> None:
self.id = id # type: TelegramID
self.access_token = access_token # type: Optional[str]
self.custom_mxid = custom_mxid # type: Optional[MatrixUserID]
self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserID
self.id = id
self.access_token = access_token
self.custom_mxid = custom_mxid
self.default_mxid = self.get_mxid_from_id(self.id)
self.username = username # type: Optional[str]
self.displayname = displayname # type: Optional[str]
self.displayname_source = displayname_source # type: Optional[TelegramID]
self.photo_id = photo_id # type: Optional[str]
self.is_bot = is_bot # type: bool
self.is_registered = is_registered # type: bool
self.disable_updates = disable_updates # type: bool
self._db_instance = db_instance # type: Optional[DBPuppet]
self.username = username
self.displayname = displayname
self.displayname_source = displayname_source
self.photo_id = photo_id
self.is_bot = is_bot
self.is_registered = is_registered
self.disable_updates = disable_updates
self._db_instance = db_instance
self.default_mxid_intent = self.az.intent.user(self.default_mxid)
self.intent = self._fresh_intent() # type: IntentAPI
self.sync_task = None # type: Optional[asyncio.Future]
self.intent = self._fresh_intent()
self.sync_task = None
self.cache[id] = self
if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self
@property
def mxid(self) -> MatrixUserID:
return self.custom_mxid or self.default_mxid
@property
def tgid(self) -> TelegramID:
return self.id
@property
def is_real_user(self) -> bool:
""" Is True when the puppet is a real Matrix user. """
return bool(self.custom_mxid and self.access_token)
@staticmethod
async def is_logged_in() -> bool:
""" Is True if the puppet is logged in. """
@@ -105,175 +120,15 @@ class Puppet:
@property
def plain_displayname(self) -> str:
tpl = config["bridge.displayname_template"]
if tpl == "{displayname}":
# Template has no extra stuff, no need to parse.
return self.displayname
regex = re.compile("^" + re.escape(tpl).replace(re.escape("{displayname}"), "(.+?)") + "$")
match = regex.match(self.displayname)
return match.group(1) or self.displayname
prefix = self._mxid_prefix
suffix = self._mxid_suffix
if self.displayname[:len(prefix)] == prefix and self.displayname[-len(suffix):] == suffix:
return self.displayname[len(prefix):-len(suffix)]
return self.displayname
def get_input_entity(self, user: 'AbstractUser') -> Awaitable[TypeInputPeer]:
return user.client.get_input_entity(PeerUser(user_id=self.tgid))
# region Custom puppet management
def _fresh_intent(self) -> IntentAPI:
return (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
async def switch_mxid(self, access_token: Optional[str],
mxid: Optional[MatrixUserID]) -> PuppetError:
prev_mxid = self.custom_mxid
self.custom_mxid = mxid
self.access_token = access_token
self.intent = self._fresh_intent()
err = await self.init_custom_mxid()
if err != PuppetError.Success:
return err
try:
del self.by_custom_mxid[prev_mxid] # type: ignore
except KeyError:
pass
if self.mxid != self.default_mxid:
self.by_custom_mxid[self.mxid] = self
await self.leave_rooms_with_default_user()
self.save()
return PuppetError.Success
async def init_custom_mxid(self) -> PuppetError:
if not self.is_real_user:
return PuppetError.Success
mxid = await self.intent.whoami()
if not mxid or mxid != self.custom_mxid:
self.custom_mxid = None
self.access_token = None
self.intent = self._fresh_intent()
if mxid != self.custom_mxid:
return PuppetError.OnlyLoginSelf
return PuppetError.InvalidAccessToken
if config["bridge.sync_with_custom_puppets"]:
self.sync_task = asyncio.ensure_future(self.sync(), loop=self.loop)
return PuppetError.Success
async def leave_rooms_with_default_user(self) -> None:
for room_id in await self.default_mxid_intent.get_joined_rooms():
try:
await self.default_mxid_intent.leave_room(room_id)
await self.intent.ensure_joined(room_id)
except (IntentError, MatrixRequestError):
pass
def create_sync_filter(self) -> Awaitable[str]:
return self.intent.client.create_filter(self.custom_mxid, {
"room": {
"include_leave": False,
"state": {
"types": []
},
"timeline": {
"types": [],
},
"ephemeral": {
"types": ["m.typing", "m.receipt"],
},
"account_data": {
"types": []
}
},
"account_data": {
"types": [],
},
"presence": {
"types": ["m.presence"],
"senders": [self.custom_mxid],
},
})
def filter_events(self, events: List[Dict]) -> List:
new_events = []
for event in events:
evt_type = event.get("type", None)
event.setdefault("content", {})
if evt_type == "m.typing":
is_typing = self.custom_mxid in event["content"].get("user_ids", [])
event["content"]["user_ids"] = [self.custom_mxid] if is_typing else []
elif evt_type == "m.receipt":
val = None
evt = None
for event_id in event["content"]:
try:
val = event["content"][event_id]["m.read"][self.custom_mxid]
evt = event_id
break
except KeyError:
pass
if val and evt:
event["content"] = {evt: {"m.read": {
self.custom_mxid: val
}}}
else:
continue
new_events.append(event)
return new_events
def handle_sync(self, presence: List, ephemeral: Dict) -> None:
presence_events = [self.mx.try_handle_ephemeral_event(event) for event in presence]
for room_id, events in ephemeral.items():
for event in events:
event["room_id"] = room_id
ephemeral_events = [self.mx.try_handle_ephemeral_event(event)
for events in ephemeral.values()
for event in self.filter_events(events)]
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
coro = asyncio.gather(*events, loop=self.loop)
asyncio.ensure_future(coro, loop=self.loop)
async def sync(self) -> None:
try:
await self._sync()
except asyncio.CancelledError:
self.log.info("Syncing cancelled")
except Exception:
self.log.exception("Fatal error syncing")
async def _sync(self) -> None:
if not self.is_real_user:
self.log.warning("Called sync() for non-custom puppet.")
return
custom_mxid = self.custom_mxid
access_token_at_start = self.access_token
errors = 0
next_batch = None
filter_id = await self.create_sync_filter()
self.log.debug(f"Starting syncer for {custom_mxid} with sync filter {filter_id}.")
while access_token_at_start == self.access_token:
try:
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch,
set_presence="offline") # type: Dict
errors = 0
if next_batch is not None:
presence = sync_resp.get("presence", {}).get("events", []) # type: List
ephemeral = {room: data.get("ephemeral", {}).get("events", [])
for room, data
in sync_resp.get("rooms", {}).get("join", {}).items()
} # type: Dict
self.handle_sync(presence, ephemeral)
next_batch = sync_resp.get("next_batch", None)
except (MatrixRequestError, ServerDisconnectedError) as e:
wait = min(errors, 11) ** 2
self.log.warning(f"Syncer for {custom_mxid} errored: {e}. "
f"Waiting for {wait} seconds...")
errors += 1
await asyncio.sleep(wait)
self.log.debug(f"Syncer for custom puppet {custom_mxid} stopped.")
# endregion
# region DB conversion
@property
@@ -378,7 +233,7 @@ class Puppet:
self.displayname = displayname
self.displayname_source = source.tgid
try:
await self.default_mxid_intent.set_display_name(displayname[:100])
await self.default_mxid_intent.set_displayname(displayname[:100])
except MatrixRequestError:
self.log.exception("Failed to set displayname")
self.displayname = ""
@@ -402,7 +257,7 @@ class Puppet:
if not photo_id:
self.photo_id = ""
try:
await self.default_mxid_intent.set_avatar("")
await self.default_mxid_intent.set_avatar_url("")
except MatrixRequestError:
self.log.exception("Failed to set avatar")
self.photo_id = ""
@@ -418,7 +273,7 @@ class Puppet:
if file:
self.photo_id = photo_id
try:
await self.default_mxid_intent.set_avatar(file.mxc)
await self.default_mxid_intent.set_avatar_url(file.mxc)
except MatrixRequestError:
self.log.exception("Failed to set avatar")
self.photo_id = ""
@@ -447,7 +302,7 @@ class Puppet:
return None
@classmethod
def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['Puppet']:
def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
tgid = cls.get_id_from_mxid(mxid)
if tgid:
return cls.get(tgid, create)
@@ -455,7 +310,7 @@ class Puppet:
return None
@classmethod
def get_by_custom_mxid(cls, mxid: MatrixUserID) -> Optional['Puppet']:
def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -479,15 +334,16 @@ class Puppet:
for puppet in DBPuppet.all_with_custom_mxid())
@classmethod
def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]:
match = cls.mxid_regex.match(mxid)
if match:
return TelegramID(int(match.group(1)))
def get_id_from_mxid(cls, mxid: UserID) -> Optional[TelegramID]:
prefix = cls._mxid_prefix
suffix = cls._mxid_suffix
if mxid[:len(prefix)] == prefix and mxid[-len(suffix):] == suffix:
return TelegramID(int(mxid[len(prefix):-len(suffix)]))
return None
@classmethod
def get_mxid_from_id(cls, tgid: TelegramID) -> MatrixUserID:
return MatrixUserID(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
def get_mxid_from_id(cls, tgid: TelegramID) -> UserID:
return UserID(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
@classmethod
def find_by_username(cls, username: str) -> Optional['Puppet']:
@@ -525,8 +381,18 @@ def init(context: 'Context') -> Iterable[Awaitable[Any]]:
global config
Puppet.az, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"]
Puppet.mxid_regex = re.compile(
f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}")
return (puppet.init_custom_mxid() for puppet in Puppet.all_with_custom_mxid())
Puppet.username_template = config["bridge.username_template"]
index = Puppet.username_template.index("{userid}")
length = len("{userid}")
Puppet._mxid_prefix = f"@{Puppet.username_template[:index]}"
Puppet._mxid_suffix = f"{Puppet.username_template[index + length:]}:{Puppet.hs_domain}"
displayname_template = config["bridge.displayname_template"]
index = displayname_template.index("{displayname}")
length = len("{displayname}")
Puppet._displayname_prefix = displayname_template[:index]
Puppet._displayname_suffix = displayname_template[index+length:]
return (puppet.start() for puppet in Puppet.all_with_custom_mxid())
@@ -1,7 +1,8 @@
import argparse
import sqlalchemy as sql
from typing import Union
from sqlalchemy import orm
from sqlalchemy.ext.declarative import declarative_base
import sqlalchemy as sql
import argparse
from alchemysession import AlchemySessionContainer
@@ -24,11 +25,12 @@ def log(message, end="\n"):
def connect(to):
import mautrix_telegram.db.base as base
base.Base = declarative_base(cls=base.BaseBase)
from mautrix_telegram.db import (Portal, Message, UserPortal, User, RoomState, UserProfile,
Contact, Puppet, BotChat, TelegramFile)
from mautrix_telegram.db import (Portal, Message, UserPortal, User, Contact, Puppet, BotChat,
TelegramFile)
from mautrix.bridge.db import RoomState, UserProfile
db_engine = sql.create_engine(to)
db_factory = orm.sessionmaker(bind=db_engine)
db_session = orm.scoped_session(db_factory) # type: orm.Session
db_session: Union[orm.Session, orm.scoped_session] = orm.scoped_session(db_factory)
base.Base.metadata.bind = db_engine
session_container = AlchemySessionContainer(engine=db_engine, session=db_session,
table_base=base.Base, table_prefix="telethon_",
@@ -52,6 +54,7 @@ def connect(to):
"TelegramFile": TelegramFile,
}
log("Connecting to old database")
session, tables = connect(args.from_url)
+1 -1
View File
@@ -31,7 +31,7 @@ class MautrixTelegramClient(TelegramClient):
attributes: List[TypeDocumentAttribute] = None,
file_name: str = None, max_image_size: float = 10 * 1000 ** 2,
) -> Union[InputMediaUploadedDocument, InputMediaUploadedPhoto]:
file_handle = await super().upload_file(file, file_name=file_name, use_cache=False)
file_handle = await super().upload_file(file, file_name=file_name)
if (mime_type == "image/png" or mime_type == "image/jpeg") and len(file) < max_image_size:
return InputMediaUploadedPhoto(file_handle)
+1 -1
View File
@@ -1,3 +1,3 @@
from typing import Dict, NewType
from typing import NewType
TelegramID = NewType('TelegramID', int)
+8 -5
View File
@@ -25,9 +25,11 @@ from telethon.tl.types import (
from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError
from .types import MatrixUserID, TelegramID
from mautrix.errors import MatrixRequestError
from mautrix.types import UserID
from .types import TelegramID
from .db import User as DBUser
from .abstract_user import AbstractUser
from . import portal as po, puppet as pu
@@ -54,7 +56,7 @@ class User(AbstractUser):
_db_instance: Optional[DBUser]
def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
def __init__(self, mxid: UserID, tgid: Optional[TelegramID] = None,
username: Optional[str] = None, phone: Optional[str] = None,
db_contacts: Optional[Iterable[TelegramID]] = None,
saved_contacts: int = 0, is_bot: bool = False,
@@ -250,7 +252,8 @@ class User(AbstractUser):
if not portal or portal.deleted or not portal.mxid or portal.has_bot:
continue
try:
await portal.main_intent.kick(portal.mxid, self.mxid, "Logged out of Telegram.")
await portal.main_intent.kick_user(portal.mxid, self.mxid,
"Logged out of Telegram.")
except MatrixRequestError:
pass
self.portals = {}
@@ -356,7 +359,7 @@ class User(AbstractUser):
# region Class instance lookup
@classmethod
def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['User']:
def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
+4 -1
View File
@@ -1,7 +1,10 @@
from asyncio import Future
from .file_transfer import transfer_file_to_matrix, convert_image
from .format_duration import format_duration
from .signed_token import sign_token, verify_token
from .recursive_dict import recursive_del, recursive_set, recursive_get
def ignore_coro(coro):
def ignore_coro(_: Future) -> None:
pass
+2 -1
View File
@@ -27,7 +27,8 @@ from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLoc
InputPeerPhotoFileLocation)
from telethon.errors import (AuthBytesInvalidError, AuthKeyInvalidError, LocationInvalidError,
SecurityError, FileIdInvalidError)
from mautrix_appservice import IntentAPI
from mautrix.appservice import IntentAPI
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile
+5 -4
View File
@@ -14,11 +14,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Any
from ..config import DictWithRecursion
from mautrix.util.config import RecursiveDict
def recursive_set(data: Dict[str, Any], key: str, value: Any) -> bool:
key, next_key = DictWithRecursion._parse_key(key)
key, next_key = RecursiveDict.parse_key(key)
if next_key is not None:
if key not in data:
data[key] = {}
@@ -31,7 +32,7 @@ def recursive_set(data: Dict[str, Any], key: str, value: Any) -> bool:
def recursive_get(data: Dict[str, Any], key: str) -> Any:
key, next_key = DictWithRecursion._parse_key(key)
key, next_key = RecursiveDict.parse_key(key)
if next_key is not None:
next_data = data.get(key, None)
if not next_data:
@@ -41,7 +42,7 @@ def recursive_get(data: Dict[str, Any], key: str) -> Any:
def recursive_del(data: Dict[str, any], key: str) -> bool:
key, next_key = DictWithRecursion._parse_key(key)
key, next_key = RecursiveDict.parse_key(key)
if next_key is not None:
if key not in data:
return False
+13 -11
View File
@@ -13,27 +13,30 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from abc import abstractmethod
from typing import Optional
from aiohttp import web
from abc import abstractmethod
import abc
import asyncio
import logging
from aiohttp import web
from telethon.errors import *
from mautrix.bridge import OnlyLoginSelf, InvalidAccessToken
from ...commands.telegram.auth import enter_password
from ...util import format_duration, ignore_coro
from ...puppet import Puppet, PuppetError
from ...puppet import Puppet
from ...user import User
class AuthAPI(abc.ABC):
log = logging.getLogger("mau.web.auth") # type: logging.Logger
log: logging.Logger = logging.getLogger("mau.web.auth")
loop: asyncio.AbstractEventLoop
def __init__(self, loop: asyncio.AbstractEventLoop):
self.loop = loop # type: asyncio.AbstractEventLoop
self.loop = loop
@abstractmethod
def get_login_response(self, status: int = 200, state: str = "", username: str = "",
@@ -55,15 +58,14 @@ class AuthAPI(abc.ABC):
error="You have already logged in with your Matrix "
"account.", errcode="already-logged-in")
resp = await puppet.switch_mxid(token.strip(), user.mxid)
if resp == PuppetError.OnlyLoginSelf:
try:
await puppet.switch_mxid(token.strip(), user.mxid)
except OnlyLoginSelf:
return self.get_mx_login_response(status=403, errcode="only-login-self",
error="You can only log in as your own Matrix user.")
elif resp == PuppetError.InvalidAccessToken:
except InvalidAccessToken:
return self.get_mx_login_response(status=401, errcode="invalid-access-token",
error="Failed to verify access token.")
assert resp == PuppetError.Success, "Encountered an unhandled PuppetError."
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
async def post_matrix_password(self, user: User, password: str) -> web.Response:
+25 -27
View File
@@ -13,17 +13,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING
import asyncio
import logging
import json
from aiohttp import web
from telethon.utils import get_peer_id, resolve_id
from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat
from mautrix_appservice import AppService, MatrixRequestError, IntentError
from ...types import MatrixUserID, TelegramID
from mautrix.appservice import AppService
from mautrix.errors import MatrixRequestError, IntentError
from mautrix.types import UserID
from ...types import TelegramID
from ...user import User
from ...portal import Portal
from ...util import ignore_coro
@@ -35,16 +39,19 @@ if TYPE_CHECKING:
class ProvisioningAPI(AuthAPI):
log = logging.getLogger("mau.web.provisioning") # type: logging.Logger
log: logging.Logger = logging.getLogger("mau.web.provisioning")
secret: str
az: AppService
context: 'Context'
app: web.Application
def __init__(self, context: "Context") -> None:
super().__init__(context.loop)
self.secret = context.config["appservice.provisioning.shared_secret"] # type: str
self.az = context.az # type: AppService
self.context = context # type: Context
self.secret = context.config["appservice.provisioning.shared_secret"]
self.az = context.az
self.context = context
self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware]
) # type: web.Application
self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware])
portal_prefix = "/portal/{mxid:![^/]+}"
self.app.router.add_route("GET", f"{portal_prefix}", self.get_portal_by_mxid)
@@ -76,18 +83,7 @@ class ProvisioningAPI(AuthAPI):
if not portal:
return self.get_error_response(404, "portal_not_found",
"Portal with given Matrix ID not found.")
user, _ = await self.get_user(request.query.get("user_id", None), expect_logged_in=None,
require_puppeting=False)
return web.json_response({
"mxid": portal.mxid,
"chat_id": get_peer_id(portal.peer),
"peer_type": portal.peer_type,
"title": portal.title,
"about": portal.about,
"username": portal.username,
"megagroup": portal.megagroup,
"can_unbridge": (await portal.can_user_perform(user, "unbridge")) if user else False,
})
return await self._get_portal_response(UserID(request.query.get("user_id", "")), portal)
async def get_portal_by_tgid(self, request: web.Request) -> web.Response:
err = self.check_authorization(request)
@@ -103,8 +99,10 @@ class ProvisioningAPI(AuthAPI):
if not portal:
return self.get_error_response(404, "portal_not_found",
"Portal to given Telegram chat not found.")
user, _ = await self.get_user(request.query.get("user_id", None), expect_logged_in=None,
require_puppeting=False)
return await self._get_portal_response(UserID(request.query.get("user_id", "")), portal)
async def _get_portal_response(self, user_id: UserID, portal: Portal) -> web.Response:
user, _ = await self.get_user(user_id, expect_logged_in=None, require_puppeting=False)
return web.json_response({
"mxid": portal.mxid,
"chat_id": get_peer_id(portal.peer),
@@ -364,7 +362,8 @@ class ProvisioningAPI(AuthAPI):
async def bridge_info(self, request: web.Request) -> web.Response:
return web.json_response({
"relaybot_username": self.context.bot.username if self.context.bot is not None else None,
"relaybot_username": (self.context.bot.username
if self.context.bot is not None else None),
}, status=200)
@staticmethod
@@ -430,7 +429,7 @@ class ProvisioningAPI(AuthAPI):
except json.JSONDecodeError:
return None
async def get_user(self, mxid: MatrixUserID, expect_logged_in: Optional[bool] = False,
async def get_user(self, mxid: Optional[UserID], expect_logged_in: Optional[bool] = False,
require_puppeting: bool = True, require_user: bool = True
) -> Tuple[Optional[User], Optional[web.Response]]:
if not mxid:
@@ -459,8 +458,7 @@ class ProvisioningAPI(AuthAPI):
expect_logged_in: Optional[bool] = False,
require_puppeting: bool = False,
want_data: bool = True,
) -> (Tuple[Optional[Dict],
Optional[User],
) -> (Tuple[Optional[Dict], Optional[User],
Optional[web.Response]]):
err = self.check_authorization(request)
if err is not None:
+17 -12
View File
@@ -14,16 +14,18 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from aiohttp import web
from mako.template import Template
import pkg_resources
import asyncio
import logging
import random
import string
import time
from ...types import MatrixUserID
from mako.template import Template
from aiohttp import web
import pkg_resources
from mautrix.types import UserID
from ...util import sign_token, verify_token
from ...user import User
from ...puppet import Puppet
@@ -31,20 +33,23 @@ from ..common import AuthAPI
class PublicBridgeWebsite(AuthAPI):
log = logging.getLogger("mau.web.public") # type: logging.Logger
log: logging.Logger = logging.getLogger("mau.web.public")
secret_key: str
login: Template
mx_login: Template
app: web.Application
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
self.secret_key = "".join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) # type: str
self.secret_key = "".join(random.choices(string.ascii_lowercase + string.digits, k=64))
self.login = Template(pkg_resources.resource_string(
"mautrix_telegram", "web/public/login.html.mako")) # type: Template
"mautrix_telegram", "web/public/login.html.mako"))
self.mx_login = Template(pkg_resources.resource_string(
"mautrix_telegram", "web/public/matrix-login.html.mako")) # type: Template
"mautrix_telegram", "web/public/matrix-login.html.mako"))
self.app = web.Application(loop=loop) # type: web.Application
self.app = web.Application(loop=loop)
self.app.router.add_route("GET", "/login", self.get_login)
self.app.router.add_route("POST", "/login", self.post_login)
self.app.router.add_route("GET", "/matrix-login", self.get_matrix_login)
@@ -59,11 +64,11 @@ class PublicBridgeWebsite(AuthAPI):
"expiry": int(time.time()) + expires_in,
})
def verify_token(self, token: str, endpoint: str = "/login") -> Optional[MatrixUserID]:
def verify_token(self, token: str, endpoint: str = "/login") -> Optional[UserID]:
token = verify_token(self.secret_key, token)
if token and (token.get("expiry", 0) > int(time.time()) and
token.get("endpoint", None) == endpoint):
return MatrixUserID(token.get("mxid", None))
return UserID(token.get("mxid", None))
return None
async def get_login(self, request: web.Request) -> web.Response: