From cf847d3b8e0cb4edf91d3b76db88ffd87ee6ae9a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 12 Feb 2019 14:42:03 +0200 Subject: [PATCH] Finish moving portals and users to SQLAlchemy Core --- mautrix_telegram/db.py | 108 +++++++++++++++++++++++++++++++------ mautrix_telegram/puppet.py | 19 +++---- mautrix_telegram/user.py | 31 ++++++----- 3 files changed, 118 insertions(+), 40 deletions(-) diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py index 5afc5c66..463f2259 100644 --- a/mautrix_telegram/db.py +++ b/mautrix_telegram/db.py @@ -20,7 +20,7 @@ from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstrai from sqlalchemy.engine.result import RowProxy from sqlalchemy.sql import expression from sqlalchemy.orm import relationship, Query -from typing import Dict, Optional, List, Iterable +from typing import Dict, Optional, List, Iterable, Tuple import json from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID @@ -48,14 +48,18 @@ class Portal(Base): about = Column(String, nullable=True) photo_id = Column(String, nullable=True) + @classmethod + def scan(cls, row) -> Optional['Portal']: + (tgid, tg_receiver, peer_type, megagroup, mxid, config, username, title, about, + photo_id) = row + return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, megagroup=megagroup, + mxid=mxid, config=config, username=username, title=title, about=about, + photo_id=photo_id) + @classmethod def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']: try: - (tgid, tg_receiver, peer_type, megagroup, mxid, config, - username, title, about, photo_id) = next(rows) - return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, - megagroup=megagroup, mxid=mxid, config=config, username=username, - title=title, about=about, photo_id=photo_id) + return cls.scan(next(rows)) except StopIteration: return None @@ -155,10 +159,6 @@ class User(Base): tg_username = Column(String, nullable=True) tg_phone = Column(String, nullable=True) saved_contacts = Column(Integer, default=0, nullable=False) - contacts = relationship("Contact", uselist=True, - cascade="save-update, merge, delete, delete-orphan" - ) # type: List[Contact] - portals = relationship("Portal", secondary="user_portal") @classmethod def _one_or_none(cls, rows: RowProxy) -> Optional['User']: @@ -170,7 +170,7 @@ class User(Base): return None @classmethod - def get_all(cls) -> Iterable['User']: + def all(cls) -> Iterable['User']: rows = cls.db.execute(cls.t.select()) for row in rows: mxid, tgid, tg_username, tg_phone, saved_contacts = row @@ -198,6 +198,36 @@ class User(Base): mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone, saved_contacts=self.saved_contacts)) + @property + def contacts(self) -> Iterable[TelegramID]: + rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid)) + for row in rows: + user, contact = row + yield contact + + @contacts.setter + def contacts(self, puppets: Iterable[TelegramID]) -> None: + self.db.execute(Contact.t.delete().where(Contact.c.user == self.tgid)) + self.db.execute(Contact.t.insert(), [{"user": self.tgid, "contact": tgid} + for tgid in puppets]) + + @property + def portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]: + rows = self.db.execute(UserPortal.t.select().where(UserPortal.c.user == self.tgid)) + for row in rows: + user, portal, portal_receiver = row + yield (portal, portal_receiver) + + @portals.setter + def portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None: + self.db.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid)) + self.db.execute(UserPortal.t.insert(), + [{ + "user": self.tgid, + "portal": tgid, + "portal_receiver": tg_receiver + } for tgid, tg_receiver in portals]) + class UserPortal(Base): __tablename__ = "user_portal" @@ -302,7 +332,6 @@ class UserProfile(Base): class Puppet(Base): - query = None # type: Query __tablename__ = "puppet" id = Column(Integer, primary_key=True) # type: TelegramID @@ -315,6 +344,55 @@ class Puppet(Base): is_bot = Column(Boolean, nullable=True) matrix_registered = Column(Boolean, nullable=False, server_default=expression.false()) + @classmethod + def scan(cls, row) -> Optional['Puppet']: + (id, custom_mxid, access_token, displayname, displayname_source, username, photo_id, + is_bot, matrix_registered) = row + return cls(id=id, custom_mxid=custom_mxid, access_token=access_token, + displayname=displayname, displayname_source=displayname_source, + username=username, photo_id=photo_id, is_bot=is_bot, + matrix_registered=matrix_registered) + + @classmethod + def _one_or_none(cls, rows: RowProxy) -> Optional['Puppet']: + try: + return cls.scan(next(rows)) + except StopIteration: + return None + + @classmethod + def all_with_custom_mxid(cls) -> Iterable['Puppet']: + rows = cls.db.execute(cls.t.select().where(cls.c.custom_mxid != None)) + for row in rows: + yield cls.scan(row) + + @classmethod + def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']: + return cls._select_one_or_none(cls.c.id == tgid) + + @classmethod + def get_by_custom_mxid(cls, mxid: MatrixRoomID) -> Optional['Puppet']: + return cls._select_one_or_none(cls.c.custom_mxid == mxid) + + @classmethod + def get_by_username(cls, username: str) -> Optional['Puppet']: + return cls._select_one_or_none(cls.c.username == username) + + @classmethod + def get_by_displayname(cls, displayname: str) -> Optional['Puppet']: + return cls._select_one_or_none(cls.c.displayname == displayname) + + @property + def _edit_identity(self): + return self.c.id == self.id + + def insert(self) -> None: + self.db.execute(self.t.insert().values( + id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token, + displayname=self.displayname, displayname_source=self.displayname_source, + username=self.username, photo_id=self.photo_id, is_bot=self.is_bot, + matrix_registered=self.matrix_registered)) + # Fucking Telegram not telling bots what chats they are in 3:< class BotChat(Base): @@ -359,9 +437,9 @@ class TelegramFile(Base): def init(db_session, db_engine) -> None: - query = db_session.query_property() - for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState): - table.query = query + BotChat.query = db_session.query_property() + for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile, + RoomState): table.db = db_engine table.t = table.__table__ table.c = table.t.c diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py index 4f9070ba..4a05345d 100644 --- a/mautrix_telegram/puppet.py +++ b/mautrix_telegram/puppet.py @@ -14,7 +14,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Coroutine, Dict, List, Optional, Pattern, Union, TYPE_CHECKING +from typing import (Awaitable, Coroutine, Dict, List, Iterable, Optional, Pattern, Union, + TYPE_CHECKING) from difflib import SequenceMatcher from enum import Enum from aiohttp import ServerDisconnectedError @@ -396,7 +397,7 @@ class Puppet: except KeyError: pass - puppet = DBPuppet.query.get(tgid) + puppet = DBPuppet.get_by_tgid(tgid) if puppet: return cls.from_db(puppet) @@ -426,7 +427,7 @@ class Puppet: except KeyError: pass - puppet = DBPuppet.query.filter(DBPuppet.custom_mxid == mxid).one_or_none() + puppet = DBPuppet.get_by_custom_mxid(mxid) if puppet: puppet = cls.from_db(puppet) return puppet @@ -434,11 +435,11 @@ class Puppet: return None @classmethod - def get_all_with_custom_mxid(cls) -> List['Puppet']: - return [cls.by_custom_mxid[puppet.mxid] + def all_with_custom_mxid(cls) -> Iterable['Puppet']: + return (cls.by_custom_mxid[puppet.mxid] if puppet.custom_mxid in cls.by_custom_mxid else cls.from_db(puppet) - for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()] + for puppet in DBPuppet.all_with_custom_mxid()) @classmethod def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]: @@ -460,7 +461,7 @@ class Puppet: if puppet.username and puppet.username.lower() == username.lower(): return puppet - dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none() + dbpuppet = DBPuppet.get_by_username(username) if dbpuppet: return cls.from_db(dbpuppet) @@ -475,7 +476,7 @@ class Puppet: if puppet.displayname and puppet.displayname == displayname: return puppet - dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none() + dbpuppet = DBPuppet.get_by_displayname(displayname) if dbpuppet: return cls.from_db(dbpuppet) @@ -491,4 +492,4 @@ def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError] Puppet.hs_domain = config["homeserver"]["domain"] Puppet.mxid_regex = re.compile( f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}") - return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()] + return [puppet.init_custom_mxid() for puppet in Puppet.all_with_custom_mxid()] diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index 4c71df0e..dcb571cb 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -48,9 +48,9 @@ class User(AbstractUser): def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None, username: Optional[str] = None, phone: Optional[str] = None, - db_contacts: Optional[List[DBContact]] = None, + db_contacts: Optional[Iterable[TelegramID]] = None, saved_contacts: int = 0, is_bot: bool = False, - db_portals: Optional[List[DBPortal]] = None, + db_portals: Optional[Iterable[Tuple[TelegramID, TelegramID]]] = None, db_instance: Optional[DBUser] = None) -> None: super().__init__() self.mxid = mxid # type: MatrixUserID @@ -60,9 +60,9 @@ class User(AbstractUser): self.phone = phone # type: str self.contacts = [] # type: List[pu.Puppet] self.saved_contacts = saved_contacts # type: int - self.db_contacts = db_contacts # type: List[DBContact] - self.portals = {} # type: Dict[Tuple[int, int], po.Portal] - self.db_portals = db_portals or [] # type: List[DBPortal] + self.db_contacts = db_contacts + self.portals = {} # type: Dict[Tuple[TelegramID, TelegramID], po.Portal] + self.db_portals = db_portals or [] self._db_instance = db_instance # type: Optional[DBUser] self.command_status = None # type: Dict @@ -101,23 +101,22 @@ class User(AbstractUser): return self.displayname @property - def db_contacts(self) -> Iterable[DBContact]: - return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts) + def db_contacts(self) -> Iterable[TelegramID]: + return (puppet.id for puppet in self.contacts) @db_contacts.setter - def db_contacts(self, contacts: Iterable[DBContact]) -> None: - self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else [] + def db_contacts(self, contacts: Iterable[TelegramID]) -> None: + self.contacts = [pu.Puppet.get(entry) for entry in contacts] if contacts else [] @property - def db_portals(self) -> Iterable[DBPortal]: - return (portal.db_instance for portal in self.portals.values() if not portal.deleted) + def db_portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]: + return (portal.tgid_full for portal in self.portals.values() if not portal.deleted) @db_portals.setter - def db_portals(self, portals: Iterable[DBPortal]) -> None: + def db_portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None: self.portals = { - (portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid, - portal.tg_receiver) - for portal in portals + tgid_full: po.Portal.get_by_tgid(*tgid_full) + for tgid_full in portals } if portals else {} # region Database conversion @@ -398,5 +397,5 @@ def init(context: 'Context') -> List[Awaitable['User']]: global config config = context.config - users = [User.from_db(user) for user in DBUser.get_all()] + users = [User.from_db(user) for user in DBUser.all()] return [user.ensure_started() for user in users if user.tgid]