diff --git a/mautrix_telegram/base.py b/mautrix_telegram/base.py index 0b62d886..c3e1756f 100644 --- a/mautrix_telegram/base.py +++ b/mautrix_telegram/base.py @@ -1,2 +1,41 @@ +from abc import abstractmethod + +from sqlalchemy import Table +from sqlalchemy.engine.base import Engine +from sqlalchemy.engine.result import RowProxy +from sqlalchemy.sql.base import ImmutableColumnCollection from sqlalchemy.ext.declarative import declarative_base -Base = declarative_base() # type: declarative_base + + +class BaseBase: + db = None # type: Engine + t = None # type: Table + __table__ = None # type: Table + c = None # type: ImmutableColumnCollection + + @classmethod + @abstractmethod + def _one_or_none(cls, rows: RowProxy): + pass + + @classmethod + def _select_one_or_none(cls, *args): + return cls._one_or_none(cls.db.execute(cls.t.select().where(*args))) + + @property + @abstractmethod + def _edit_identity(self): + pass + + def update(self, **values) -> None: + self.db.execute(self.t.update() + .where(self._edit_identity) + .values(**values)) + for key, value in values.items(): + setattr(self, key, value) + + def delete(self) -> None: + self.db.execute(self.t.delete().where(self._edit_identity)) + + +Base = declarative_base(cls=BaseBase) diff --git a/mautrix_telegram/base.pyi b/mautrix_telegram/base.pyi new file mode 100644 index 00000000..8575893d --- /dev/null +++ b/mautrix_telegram/base.pyi @@ -0,0 +1,26 @@ +from abc import abstractmethod + +from sqlalchemy import Table +from sqlalchemy.engine.base import Engine +from sqlalchemy.engine.result import RowProxy +from sqlalchemy.sql.base import ImmutableColumnCollection +from sqlalchemy.ext.declarative import declarative_base + +class Base(declarative_base): + db: Engine + t: Table + __table__: Table + c: ImmutableColumnCollection + + @classmethod + @abstractmethod + def _one_or_none(cls, rows: RowProxy): ... + + @classmethod + def _select_one_or_none(cls, *args): ... + + def _edit_identity(self): ... + + def update(self, **values) -> None: ... + + def delete(self) -> None: ... diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py index d04b58ea..5afc5c66 100644 --- a/mautrix_telegram/db.py +++ b/mautrix_telegram/db.py @@ -15,13 +15,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, - BigInteger, String, Boolean, Text, Table, + BigInteger, String, Boolean, Text, and_, func, select) -from sqlalchemy.engine import Engine, RowProxy +from sqlalchemy.engine.result import RowProxy from sqlalchemy.sql import expression from sqlalchemy.orm import relationship, Query -from sqlalchemy.sql.base import ImmutableColumnCollection -from typing import Dict, Optional, List +from typing import Dict, Optional, List, Iterable import json from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID @@ -30,7 +29,6 @@ from .base import Base class Portal(Base): - query = None # type: Query __tablename__ = "portal" # Telegram chat information @@ -50,11 +48,41 @@ class Portal(Base): about = Column(String, nullable=True) photo_id = Column(String, nullable=True) + @classmethod + def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']: + try: + (tgid, tg_receiver, peer_type, megagroup, mxid, config, + username, title, about, photo_id) = next(rows) + return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, + megagroup=megagroup, mxid=mxid, config=config, username=username, + title=title, about=about, photo_id=photo_id) + except StopIteration: + return None + + @classmethod + def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']: + return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver)) + + @classmethod + def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['Portal']: + return cls._select_one_or_none(cls.c.mxid == mxid) + + @classmethod + def get_by_username(cls, username: str) -> Optional['Portal']: + return cls._select_one_or_none(cls.c.username == username) + + @property + def _edit_identity(self): + return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver) + + def insert(self) -> None: + self.db.execute(self.t.insert().values( + tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type, + megagroup=self.megagroup, mxid=self.mxid, config=self.config, username=self.username, + title=self.title, about=self.about, photo_id=self.photo_id)) + class Message(Base): - db = None # type: Engine - t = None # type: Table - c = None # type: ImmutableColumnCollection __tablename__ = "message" mxid = Column(String) # type: MatrixEventID @@ -64,11 +92,11 @@ class Message(Base): __table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),) - @staticmethod - def _one_or_none(rows: RowProxy) -> Optional['Message']: + @classmethod + def _one_or_none(cls, rows: RowProxy) -> Optional['Message']: try: mxid, mx_room, tgid, tg_space = next(rows) - return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space) + return cls(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space) except StopIteration: return None @@ -79,9 +107,7 @@ class Message(Base): @classmethod def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']: - rows = cls.db.execute(cls.t.select() - .where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))) - return cls._one_or_none(rows) + return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)) @classmethod def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int: @@ -96,9 +122,9 @@ class Message(Base): @classmethod def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID ) -> Optional['Message']: - rows = cls.db.execute(cls.t.select().where( - and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space))) - return cls._one_or_none(rows) + return cls._select_one_or_none(and_(cls.c.mxid == mxid, + cls.c.mx_room == mx_room, + cls.c.tg_space == tg_space)) @classmethod def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None: @@ -112,36 +138,16 @@ class Message(Base): .where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room)) .values(**values)) - def update(self, **values) -> None: - for key, value in values.items(): - setattr(self, key, value) - self.update_by_tgid(self.tgid, self.tg_space, **values) - - def delete(self) -> None: - self.db.execute(self.t.delete().where( - and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space))) + @property + def _edit_identity(self): + return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space) def insert(self) -> None: self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid, tg_space=self.tg_space)) -class UserPortal(Base): - query = None # type: Query - __tablename__ = "user_portal" - - user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"), - primary_key=True) # type: TelegramID - portal = Column(Integer, primary_key=True) # type: TelegramID - portal_receiver = Column(Integer, primary_key=True) # type: TelegramID - - __table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"), - ("portal.tgid", "portal.tg_receiver"), - onupdate="CASCADE", ondelete="CASCADE"),) - - class User(Base): - query = None # type: Query __tablename__ = "user" mxid = Column(String, primary_key=True) # type: MatrixUserID @@ -154,11 +160,66 @@ class User(Base): ) # type: List[Contact] portals = relationship("Portal", secondary="user_portal") + @classmethod + def _one_or_none(cls, rows: RowProxy) -> Optional['User']: + try: + mxid, tgid, tg_username, tg_phone, saved_contacts = next(rows) + return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone, + saved_contacts=saved_contacts) + except StopIteration: + return None + + @classmethod + def get_all(cls) -> Iterable['User']: + rows = cls.db.execute(cls.t.select()) + for row in rows: + mxid, tgid, tg_username, tg_phone, saved_contacts = row + yield cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone, + saved_contacts=saved_contacts) + + @classmethod + def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']: + return cls._select_one_or_none(cls.c.tgid == tgid) + + @classmethod + def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['User']: + return cls._select_one_or_none(cls.c.mxid == mxid) + + @classmethod + def get_by_username(cls, username: str) -> Optional['User']: + return cls._select_one_or_none(cls.c.username == username) + + @property + def _edit_identity(self): + return self.c.mxid == self.mxid + + def insert(self) -> None: + self.db.execute(self.t.insert().values( + mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone, + saved_contacts=self.saved_contacts)) + + +class UserPortal(Base): + __tablename__ = "user_portal" + + user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"), + primary_key=True) # type: TelegramID + portal = Column(Integer, primary_key=True) # type: TelegramID + portal_receiver = Column(Integer, primary_key=True) # type: TelegramID + + __table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"), + ("portal.tgid", "portal.tg_receiver"), + onupdate="CASCADE", ondelete="CASCADE"),) + + +class Contact(Base): + __tablename__ = "contact" + + user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID + contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID + class RoomState(Base): - db = None # type: Engine - t = None # type: Table - c = None # type: ImmutableColumnCollection __tablename__ = "mx_room_state" room_id = Column(String, primary_key=True) # type: MatrixRoomID @@ -177,18 +238,17 @@ class RoomState(Base): rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id)) try: room_id, power_levels_text = next(rows) - return RoomState(room_id=room_id, power_levels=(json.loads(power_levels_text) - if power_levels_text else None)) + return cls(room_id=room_id, power_levels=(json.loads(power_levels_text) + if power_levels_text else None)) except StopIteration: return None def update(self) -> None: - self.db.execute(self.t.update() - .where(self.c.room_id == self.room_id) - .values(power_levels=self._power_levels_text)) + return super().update(power_levels=self._power_levels_text) - def delete(self) -> None: - self.db.execute(self.t.delete().where(self.c.room_id == self.room_id)) + @property + def _edit_identity(self): + return self.c.room_id == self.room_id def insert(self) -> None: self.db.execute(self.t.insert().values(room_id=self.room_id, @@ -196,9 +256,6 @@ class RoomState(Base): class UserProfile(Base): - db = None # type: Engine - t = None # type: Table - c = None # type: ImmutableColumnCollection __tablename__ = "mx_user_profile" room_id = Column(String, primary_key=True) # type: MatrixRoomID @@ -220,8 +277,8 @@ class UserProfile(Base): cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id))) try: room_id, user_id, membership, displayname, avatar_url = next(rows) - return UserProfile(room_id=room_id, user_id=user_id, membership=membership, - displayname=displayname, avatar_url=avatar_url) + return cls(room_id=room_id, user_id=user_id, membership=membership, + displayname=displayname, avatar_url=avatar_url) except StopIteration: return None @@ -230,14 +287,12 @@ class UserProfile(Base): cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id)) def update(self) -> None: - self.db.execute(self.t.update() - .where(and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)) - .values(membership=self.membership, displayname=self.displayname, - avatar_url=self.avatar_url)) + super().update(membership=self.membership, displayname=self.displayname, + avatar_url=self.avatar_url) - def delete(self) -> None: - self.db.execute(self.t.delete().where(and_(self.c.room_id == self.room_id, - self.c.user_id == self.user_id))) + @property + def _edit_identity(self): + return and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id) def insert(self) -> None: self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id, @@ -246,14 +301,6 @@ class UserProfile(Base): avatar_url=self.avatar_url)) -class Contact(Base): - query = None # type: Query - __tablename__ = "contact" - - user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID - contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID - - class Puppet(Base): query = None # type: Query __tablename__ = "puppet" @@ -278,9 +325,6 @@ class BotChat(Base): class TelegramFile(Base): - db = None # type: Engine - t = None # type: Table - c = None # type: ImmutableColumnCollection __tablename__ = "telegram_file" id = Column(String, primary_key=True) @@ -302,8 +346,8 @@ class TelegramFile(Base): thumb = None if thumb_id: thumb = cls.get(thumb_id) - return TelegramFile(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts, - size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb) + return cls(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts, + size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb) except StopIteration: return None @@ -316,8 +360,7 @@ class TelegramFile(Base): def init(db_session, db_engine) -> None: query = db_session.query_property() - for table in (Portal, Message, UserPortal, User, Puppet, BotChat, TelegramFile, UserProfile, - RoomState): + for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState): table.query = query table.db = db_engine table.t = table.__table__ diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index 9c411a21..3a04b74f 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -31,7 +31,6 @@ import json import re import magic -from sqlalchemy import orm from sqlalchemy.exc import IntegrityError from telethon.tl.functions.messages import ( @@ -89,7 +88,6 @@ InviteList = Union[MatrixUserID, List[MatrixUserID]] class Portal: log = logging.getLogger("mau.portal") # type: logging.Logger - db = None # type: orm.Session az = None # type: AppService bot = None # type: Bot loop = None # type: asyncio.AbstractEventLoop @@ -1255,8 +1253,7 @@ class Portal: self.tg_receiver = self.tgid self.by_tgid[self.tgid_full] = self await self.update_info(source, entity) - self.db.add(self.db_instance) - self.save() + self.db_instance.insert() if self.bot and self.bot.tgid in invites: self.bot.add_chat(self.tgid, self.peer_type) @@ -1842,15 +1839,13 @@ class Portal: del self.by_tgid[self.tgid_full] except KeyError: pass - self.tgid = new_id - self.tg_receiver = new_id - existing = self.by_tgid[self.tgid_full] + existing = self.by_tgid[(new_id, new_id)] if existing: existing.delete() + self.db_instance.update(tgid=new_id, tg_receiver=new_id) + self.tgid = new_id + self.tg_receiver = new_id self.by_tgid[self.tgid_full] = self - self.db_instance.tgid = self.tgid - self.db_instance.tg_receiver = self.tg_receiver - self.save() def migrate_and_save_matrix(self, new_id: MatrixRoomID) -> None: try: @@ -1858,17 +1853,13 @@ class Portal: except KeyError: pass self.mxid = new_id + self.db_instance.update(mxid=self.mxid) self.by_mxid[self.mxid] = self - self.save() def save(self) -> None: - self.db_instance.mxid = self.mxid - self.db_instance.username = self.username - self.db_instance.title = self.title - self.db_instance.about = self.about - self.db_instance.photo_id = self.photo_id - self.db_instance.config = json.dumps(self.local_config) - self.db.commit() + self.db_instance.update(mxid=self.mxid, username=self.username, title=self.title, + about=self.about, photo_id=self.photo_id, + config=json.dumps(self.local_config)) def delete(self) -> None: try: @@ -1880,8 +1871,7 @@ class Portal: except KeyError: pass if self._db_instance: - self.db.delete(self._db_instance) - self.db.commit() + self._db_instance.delete() self.deleted = True @classmethod @@ -1902,7 +1892,7 @@ class Portal: except KeyError: pass - portal = DBPortal.query.filter(DBPortal.mxid == mxid).one_or_none() + portal = DBPortal.get_by_mxid(mxid) if portal: return cls.from_db(portal) @@ -1924,7 +1914,7 @@ class Portal: if portal.username and portal.username.lower() == username.lower(): return portal - dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none() + dbportal = DBPortal.get_by_username(username) if dbportal: return cls.from_db(dbportal) @@ -1940,14 +1930,13 @@ class Portal: except KeyError: pass - portal = DBPortal.query.get(tgid_full) + portal = DBPortal.get_by_tgid(tgid, tg_receiver) if portal: return cls.from_db(portal) if peer_type: portal = Portal(tgid, peer_type=peer_type, tg_receiver=tg_receiver) - cls.db.add(portal.db_instance) - cls.db.commit() + portal.db_instance.insert() return portal return None @@ -1987,7 +1976,7 @@ class Portal: def init(context: Context) -> None: global config - Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core + Portal.az, _, config, Portal.loop, Portal.bot = context.core Portal.max_initial_member_sync = config["bridge.max_initial_member_sync"] Portal.sync_channel_members = config["bridge.sync_channel_members"] Portal.sync_matrix_state = config["bridge.sync_matrix_state"] diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index 55929475..4c71df0e 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Dict, List, Match, NewType, Optional, Tuple, TYPE_CHECKING +from typing import Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, TYPE_CHECKING import logging import asyncio import re @@ -101,20 +101,19 @@ class User(AbstractUser): return self.displayname @property - def db_contacts(self) -> List[DBContact]: - return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id)) - for puppet in self.contacts] + def db_contacts(self) -> Iterable[DBContact]: + return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts) @db_contacts.setter - def db_contacts(self, contacts: List[DBContact]) -> None: + def db_contacts(self, contacts: Iterable[DBContact]) -> None: self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else [] @property - def db_portals(self) -> List[DBPortal]: - return [portal.db_instance for portal in self.portals.values() if not portal.deleted] + def db_portals(self) -> Iterable[DBPortal]: + return (portal.db_instance for portal in self.portals.values() if not portal.deleted) @db_portals.setter - def db_portals(self, portals: List[DBPortal]) -> None: + def db_portals(self, portals: Iterable[DBPortal]) -> None: self.portals = { (portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver) @@ -135,13 +134,8 @@ class User(AbstractUser): portals=self.db_portals) def save(self) -> None: - self.db_instance.tgid = self.tgid - self.db_instance.tg_username = self.username - self.db_instance.tg_phone = self.phone - self.db_instance.contacts = self.db_contacts - self.db_instance.saved_contacts = self.saved_contacts - self.db_instance.portals = self.db_portals - self.db.commit() + self.db_instance.update(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone, + saved_contacts=self.saved_contacts) def delete(self) -> None: try: @@ -150,8 +144,7 @@ class User(AbstractUser): except KeyError: pass if self._db_instance: - self.db.delete(self._db_instance) - self.db.commit() + self._db_instance.delete() @classmethod def from_db(cls, db_user: DBUser) -> 'User': @@ -358,15 +351,14 @@ class User(AbstractUser): except KeyError: pass - user = DBUser.query.get(mxid) + user = DBUser.get_by_mxid(mxid) if user: user = cls.from_db(user) return user if create: user = cls(mxid) - cls.db.add(user.db_instance) - cls.db.commit() + user.db_instance.insert() return user return None @@ -378,7 +370,7 @@ class User(AbstractUser): except KeyError: pass - user = DBUser.query.filter(DBUser.tgid == tgid).one_or_none() + user = DBUser.get_by_tgid(tgid) if user: user = cls.from_db(user) return user @@ -394,7 +386,7 @@ class User(AbstractUser): if user.username and user.username.lower() == username.lower(): return user - puppet = DBUser.query.filter(DBUser.tg_username == username).one_or_none() + puppet = DBUser.get_by_username(username) if puppet: return cls.from_db(puppet) @@ -406,5 +398,5 @@ def init(context: 'Context') -> List[Awaitable['User']]: global config config = context.config - users = [User.from_db(user) for user in DBUser.query.all()] + users = [User.from_db(user) for user in DBUser.get_all()] return [user.ensure_started() for user in users if user.tgid]