From 2064f2b2d160acc7f71e2794b081569a550a0cdf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Feb 2018 00:58:03 +0200 Subject: [PATCH] Store user portals and kick when logging out. Fixes #53 --- mautrix_telegram/db.py | 20 ++++++++- mautrix_telegram/matrix.py | 1 + mautrix_telegram/portal.py | 6 ++- mautrix_telegram/user.py | 83 ++++++++++++++++++++++++++++++-------- 4 files changed, 90 insertions(+), 20 deletions(-) diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py index d993b02c..0b6341d8 100644 --- a/mautrix_telegram/db.py +++ b/mautrix_telegram/db.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from sqlalchemy import Column, UniqueConstraint, ForeignKey, Integer, String +from sqlalchemy import Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, String from sqlalchemy.orm import relationship from .base import Base @@ -50,6 +50,18 @@ class Message(Base): __table_args__ = (UniqueConstraint('mxid', 'mx_room', 'tg_space', name='_mx_id_room'),) +class UserPortal(Base): + query = None + __tablename__ = "user_portal" + + user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) + portal = Column(Integer, primary_key=True) + portal_receiver = Column(Integer, primary_key=True) + + __table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"), + ("portal.tgid", "portal.tg_receiver")),) + + class User(Base): query = None __tablename__ = "user" @@ -58,7 +70,10 @@ class User(Base): tgid = Column(Integer, nullable=True) tg_username = Column(String, nullable=True) saved_contacts = Column(Integer, default=0) - contacts = relationship("Contact", uselist=True) + contacts = relationship("Contact", uselist=True, + cascade="save-update, merge, delete, delete-orphan") + portals = relationship("Portal", secondary="user_portal", single_parent=True, + cascade="save-update, merge, delete, delete-orphan") class Contact(Base): @@ -82,5 +97,6 @@ class Puppet(Base): def init(db_session): Portal.query = db_session.query_property() Message.query = db_session.query_property() + UserPortal.query = db_session.query_property() User.query = db_session.query_property() Puppet.query = db_session.query_property() diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py index e631ff4f..cbca6d09 100644 --- a/mautrix_telegram/matrix.py +++ b/mautrix_telegram/matrix.py @@ -79,6 +79,7 @@ class MatrixHandler: pass portal.mxid = room portal.save() + inviter.register_portal(portal) await puppet.intent.send_notice(room, "Portal to private chat created.") else: await puppet.intent.join_room(room) diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index 6351897a..9903b9f5 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -204,7 +204,7 @@ class Portal: if alias: # TODO properly handle existing room aliases - intent.remove_room_alias(alias) + await intent.remove_room_alias(alias) room = await intent.create_room(alias=alias, is_public=public, invitees=invites or [], name=self.title, is_direct=direct) if not room: @@ -213,6 +213,7 @@ class Portal: self.mxid = room["room_id"] self.by_mxid[self.mxid] = self self.save() + user.register_portal(self) power_level_requirement = 0 if self.peer_type == "chat" and entity.admins_enabled else 50 levels = await self.main_intent.get_power_levels(self.mxid) @@ -245,6 +246,7 @@ class Portal: user = u.User.get_by_tgid(user_id) if user: + user.register_portal(self) await self.main_intent.invite(self.mxid, user.mxid) async def delete_telegram_user(self, user_id, kick_message=None): @@ -255,6 +257,7 @@ class Portal: else: await puppet.intent.leave_room(self.mxid) if user: + user.unregister_portal(self) await self.main_intent.kick(self.mxid, user.mxid, kick_message or "Left Telegram chat") async def update_info(self, user, entity=None): @@ -840,6 +843,7 @@ class Portal: user_levels = levels["users"] if user: + user.register_portal(self) user_level_defined = user.mxid in user_levels user_has_right_level = (user_levels[user.mxid] == new_level if user_level_defined else new_level == 0) diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index 04ab26c4..d64be891 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -21,12 +21,12 @@ from telethon.tl.types import * from telethon.tl.types.contacts import ContactsNotModified from telethon.tl.types import User as TLUser from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest +from mautrix_appservice import MatrixRequestError from .db import User as DBUser, Message as DBMessage, Contact as DBContact from .tgclient import MautrixTelegramClient from . import portal as po, puppet as pu, __version__ - config = None @@ -38,26 +38,20 @@ class User: by_mxid = {} by_tgid = {} - def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0): + def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0, + db_portals=None): self.mxid = mxid self.tgid = tgid self.username = username self.contacts = [] self.saved_contacts = saved_contacts self.db_contacts = db_contacts + self.portals = {} + self.db_portals = db_portals self.command_status = None self.connected = False - device = f"{platform.system()} {platform.release()}" - sysversion = MautrixTelegramClient.__version__ - self.client = MautrixTelegramClient(self.mxid, - config["telegram.api_id"], - config["telegram.api_hash"], - loop=self.loop, - app_version=__version__, - system_version=sysversion, - device_model=device) - self.client.add_update_handler(self.update_catch) + self._init_client() self.is_admin = self.mxid in config.get("bridge.admins", []) @@ -91,6 +85,19 @@ class User: else: self.contacts = [] + @property + def db_portals(self): + return [portal.to_db(merge=False) for _, portal in self.portals.items()] + + @db_portals.setter + def db_portals(self, portals): + if portals: + self.portals = {(portal.tgid, portal.tg_receiver): + po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver) + for portal in portals} + else: + self.portals = {} + def get_input_entity(self, user): return user.client.get_input_entity(InputUser(user_id=self.tgid, access_hash=0)) @@ -99,7 +106,8 @@ class User: def to_db(self): return self.db.merge( DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username, - contacts=self.db_contacts, saved_contacts=self.saved_contacts)) + contacts=self.db_contacts, saved_contacts=self.saved_contacts, + portals=self.db_portals)) def save(self): self.to_db() @@ -108,11 +116,23 @@ class User: @classmethod def from_db(cls, db_user): return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts, - db_user.saved_contacts) + db_user.saved_contacts, db_user.portals) # endregion # region Telegram connection management + def _init_client(self): + device = f"{platform.system()} {platform.release()}" + sysversion = MautrixTelegramClient.__version__ + self.client = MautrixTelegramClient(self.mxid, + config["telegram.api_id"], + config["telegram.api_hash"], + loop=self.loop, + app_version=__version__, + system_version=sysversion, + device_model=device) + self.client.add_update_handler(self.update_catch) + async def start(self): self.connected = await self.client.connect() if self.logged_in: @@ -148,7 +168,14 @@ class User: self.save() async def log_out(self): - self.connected = False + for _, portal in self.portals.items(): + try: + await portal.main_intent.kick(portal.mxid, self.mxid, "Logged out of Telegram.") + except MatrixRequestError: + pass + self.portals = {} + self.contacts = [] + self.save() if self.tgid: try: del self.by_tgid[self.tgid] @@ -156,8 +183,12 @@ class User: pass self.tgid = None self.save() - await self.client.log_out() - # TODO kick user from portals + ok = await self.client.log_out() + if not ok: + return False + self._init_client() + await self.start() + return True def _search_local(self, query, max_results=5, min_similarity=45): results = [] @@ -200,9 +231,27 @@ class User: if invalid: continue portal = po.Portal.get_by_entity(entity) + self.portals[portal.tgid_full] = portal creators.append(portal.create_matrix_room(self, entity, invites=[self.mxid])) + self.save() await asyncio.gather(*creators, loop=self.loop) + def register_portal(self, portal): + try: + if self.portals[portal.tgid_full] == portal: + return + except KeyError: + pass + self.portals[portal.tgid_full] = portal + self.save() + + def unregister_portal(self, portal): + try: + del self.portals[portal.tgid_full] + self.save() + except KeyError: + pass + def _hash_contacts(self): acc = 0 for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):