diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py index 430696eb..fff2d868 100644 --- a/mautrix_telegram/__main__.py +++ b/mautrix_telegram/__main__.py @@ -110,8 +110,10 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st context.mx = MatrixHandler(context) init_formatter(context) init_portal(context) - init_puppet(context) - startup_actions = init_user(context) + [start, context.mx.init_as_bot()] + startup_actions = (init_puppet(context) + + init_user(context) + + [start, + context.mx.init_as_bot()]) if context.bot: startup_actions.append(context.bot.start()) diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py index a4ae028c..dd17234f 100644 --- a/mautrix_telegram/abstract_user.py +++ b/mautrix_telegram/abstract_user.py @@ -124,7 +124,7 @@ class AbstractUser: self.log.debug("%s connected: %s", self.mxid, self.connected) return self - async def ensure_started(self, even_if_no_session=False): + async def ensure_started(self, even_if_no_session=False) -> "AbstractUser": if not self.puppet_whitelisted: return self self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)", diff --git a/mautrix_telegram/commands/auth.py b/mautrix_telegram/commands/auth.py index 38bcdf52..33c88cb7 100644 --- a/mautrix_telegram/commands/auth.py +++ b/mautrix_telegram/commands/auth.py @@ -56,15 +56,11 @@ async def ping_bot(evt: CommandEvent): "account") async def login_matrix(evt: CommandEvent): puppet = pu.Puppet.get(evt.sender.tgid) - prev_info = puppet.custom_mxid, puppet.access_token - puppet.custom_mxid = evt.sender.mxid - puppet.access_token = " ".join(evt.args) - puppet.refresh_intents() - if not await puppet.get_profile(): - puppet.custom_mxid, puppet.access_token = prev_info - puppet.refresh_intents() + resp = puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid) + if resp == 2: + return await evt.reply("You can only log in as your own Matrix user.") + elif resp == 1: return await evt.reply("Failed to verify access token.") - puppet.save() return await evt.reply( f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.") diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py index 32099908..81bc0598 100644 --- a/mautrix_telegram/db.py +++ b/mautrix_telegram/db.py @@ -17,14 +17,14 @@ from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, BigInteger, String, Boolean, Text) from sqlalchemy.sql import expression -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, Query import json from .base import Base class Portal(Base): - query = None + query = None # type: Query __tablename__ = "portal" # Telegram chat information @@ -42,9 +42,8 @@ class Portal(Base): about = Column(String, nullable=True) photo_id = Column(String, nullable=True) - class Message(Base): - query = None + query = None # type: Query __tablename__ = "message" mxid = Column(String) @@ -56,7 +55,7 @@ class Message(Base): class UserPortal(Base): - query = None + query = None # type: Query __tablename__ = "user_portal" user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"), @@ -70,7 +69,7 @@ class UserPortal(Base): class User(Base): - query = None + query = None # type: Query __tablename__ = "user" mxid = Column(String, primary_key=True) @@ -83,7 +82,7 @@ class User(Base): class RoomState(Base): - query = None + query = None # type: Query __tablename__ = "mx_room_state" room_id = Column(String, primary_key=True) @@ -107,7 +106,7 @@ class RoomState(Base): class UserProfile(Base): - query = None + query = None # type: Query __tablename__ = "mx_user_profile" room_id = Column(String, primary_key=True) @@ -125,7 +124,7 @@ class UserProfile(Base): class Contact(Base): - query = None + query = None # type: Query __tablename__ = "contact" user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) @@ -133,7 +132,7 @@ class Contact(Base): class Puppet(Base): - query = None + query = None # type: Query __tablename__ = "puppet" id = Column(Integer, primary_key=True) @@ -149,14 +148,14 @@ class Puppet(Base): # Fucking Telegram not telling bots what chats they are in 3:< class BotChat(Base): - query = None + query = None # type: Query __tablename__ = "bot_chat" id = Column(Integer, primary_key=True) type = Column(String, nullable=False) class TelegramFile(Base): - query = None + query = None # type: Query __tablename__ = "telegram_file" id = Column(String, primary_key=True) diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index 35872438..c78869d2 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -824,7 +824,12 @@ class Portal: mxid=event_id)) self.db.commit() - async def handle_matrix_message(self, sender, message, event_id): + async def handle_matrix_message(self, sender: u.User, message: dict, event_id: str): + puppet = p.Puppet.get_by_custom_mxid(sender.mxid) + if puppet and message.get("net.maunium.telegram.puppet", False): + self.log.debug("Ignoring puppet-sent message by confirmed puppet user %s", sender.mxid) + return + logged_in = not await sender.needs_relaybot(self) client = sender.client if logged_in else self.bot.client sender_id = sender.tgid if logged_in else self.bot.tgid diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py index 1403e64d..f708196e 100644 --- a/mautrix_telegram/puppet.py +++ b/mautrix_telegram/puppet.py @@ -36,6 +36,7 @@ class Puppet: username_template = None hs_domain = None cache = {} + by_custom_mxid = {} def __init__(self, id=None, access_token=None, custom_mxid=None, username=None, displayname=None, displayname_source=None, photo_id=None, is_bot=None, @@ -60,22 +61,51 @@ class Puppet: self.refresh_intents() self.cache[id] = self + if self.custom_mxid: + self.by_custom_mxid[self.custom_mxid] = self def refresh_intents(self): self.is_real_user = self.custom_mxid and self.access_token self.intent = (self.az.intent.user(self.custom_mxid, self.access_token) if self.is_real_user else self.default_mxid_intent) - async def get_profile(self): - try: - return await self.intent.get_profile(self.custom_mxid) - except MatrixError: - return None - @property def tgid(self): return self.id + async def switch_mxid(self, access_token, mxid): + prev_mxid = self.custom_mxid + self.custom_mxid = mxid + self.access_token = access_token + self.refresh_intents() + + err = await self.test_custom_mxid() + if err != 0: + return err + + try: + del self.by_custom_mxid[prev_mxid] + except KeyError: + pass + self.mxid = self.custom_mxid or self.default_mxid + self.by_custom_mxid[self.mxid] = self + self.save() + return 0 + + async def test_custom_mxid(self): + if not self.is_real_user: + return 0 + + mxid = await self.intent.whoami() + if not mxid or mxid != self.custom_mxid: + self.custom_mxid = None + self.access_token = None + self.refresh_intents() + if mxid != self.custom_mxid: + return 2 + return 1 + return 0 + async def is_logged_in(self): return True @@ -212,6 +242,30 @@ class Puppet: tgid = cls.get_id_from_mxid(mxid) return cls.get(tgid, create) if tgid else None + @classmethod + def get_by_custom_mxid(cls, mxid): + if not mxid: + raise ValueError("Matrix ID can't be empty") + + try: + return cls.by_custom_mxid[mxid] + except KeyError: + pass + + puppet = DBPuppet.query.filter(DBPuppet.custom_mxid == mxid).one_or_none() + if puppet: + puppet = cls.from_db(puppet) + return puppet + + return None + + @classmethod + def get_all_with_custom_mxid(cls): + return [cls.by_custom_mxid[puppet.mxid] + if puppet.custom_mxid in cls.by_custom_mxid + else cls.from_db(puppet) + for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()] + @classmethod def get_id_from_mxid(cls, mxid): match = cls.mxid_regex.match(mxid) @@ -261,3 +315,4 @@ def init(context): Puppet.hs_domain = config["homeserver"]["domain"] localpart = Puppet.username_template.format(userid="(.+)") Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}") + return [puppet.test_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()] diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index 43ad5dec..110c9f8e 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict +from typing import Dict, Awaitable, Optional import logging import asyncio import re @@ -185,6 +185,9 @@ class User(AbstractUser): # endregion # region Telegram actions that need custom methods + def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]": + return super().ensure_started(even_if_no_session) + async def update_info(self, info: User = None): info = info or await self.client.get_me() changed = False @@ -309,7 +312,7 @@ class User(AbstractUser): # region Class instance lookup @classmethod - def get_by_mxid(cls, mxid, create=True): + def get_by_mxid(cls, mxid, create=True) -> "Optional[User]": if not mxid: raise ValueError("Matrix ID can't be empty") @@ -332,7 +335,7 @@ class User(AbstractUser): return None @classmethod - def get_by_tgid(cls, tgid): + def get_by_tgid(cls, tgid) -> "Optional[User]": try: return cls.by_tgid[tgid] except KeyError: @@ -346,7 +349,7 @@ class User(AbstractUser): return None @classmethod - def find_by_username(cls, username): + def find_by_username(cls, username) -> "Optional[User]": if not username: return None