diff --git a/alembic/env.py b/alembic/env.py index ef158c7d..7d7e8967 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -7,7 +7,7 @@ from os.path import abspath, dirname sys.path.insert(0, dirname(dirname(abspath(__file__)))) -from mautrix.bridge.db import Base +from mautrix.util.db import Base import mautrix_telegram.db from mautrix_telegram.config import Config from alchemysession import AlchemySessionContainer diff --git a/alembic/versions/6ca3d74d51e4_move_state_store_to_main_database.py b/alembic/versions/6ca3d74d51e4_move_state_store_to_main_database.py index a66b21d5..b0ebf1ce 100644 --- a/alembic/versions/6ca3d74d51e4_move_state_store_to_main_database.py +++ b/alembic/versions/6ca3d74d51e4_move_state_store_to_main_database.py @@ -12,7 +12,7 @@ from alembic import context, op import sqlalchemy.orm as orm import sqlalchemy as sa -from mautrix.bridge.db import Base +from mautrix.util.db import Base from mautrix_telegram.config import Config diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py index 1a7e07d0..a59c0e53 100644 --- a/mautrix_telegram/__main__.py +++ b/mautrix_telegram/__main__.py @@ -19,7 +19,7 @@ from itertools import chain from alchemysession import AlchemySessionContainer from mautrix.bridge import Bridge -from mautrix.bridge.db import Base +from mautrix.util.db import Base from .web.provisioning import ProvisioningAPI from .web.public import PublicBridgeWebsite diff --git a/mautrix_telegram/db/__init__.py b/mautrix_telegram/db/__init__.py index 28106767..92a824f0 100644 --- a/mautrix_telegram/db/__init__.py +++ b/mautrix_telegram/db/__init__.py @@ -31,3 +31,4 @@ def init(db_engine: Engine) -> None: table.db = db_engine table.t = table.__table__ table.c = table.t.c + table.column_names = table.c.keys() diff --git a/mautrix_telegram/db/bot_chat.py b/mautrix_telegram/db/bot_chat.py index 9903a630..449afbe1 100644 --- a/mautrix_telegram/db/bot_chat.py +++ b/mautrix_telegram/db/bot_chat.py @@ -16,9 +16,8 @@ from typing import Iterable from sqlalchemy import Column, Integer, String -from sqlalchemy.engine.result import RowProxy -from mautrix.bridge.db import Base +from mautrix.util.db import Base from ..types import TelegramID @@ -34,14 +33,6 @@ class BotChat(Base): with cls.db.begin() as conn: conn.execute(cls.t.delete().where(cls.c.id == chat_id)) - @classmethod - def scan(cls, row: RowProxy) -> 'BotChat': - return cls(id=row[0], type=row[1]) - @classmethod def all(cls) -> Iterable['BotChat']: return cls._select_all() - - def insert(self) -> None: - with self.db.begin() as conn: - conn.execute(self.t.insert().values(id=self.id, type=self.type)) diff --git a/mautrix_telegram/db/message.py b/mautrix_telegram/db/message.py index 83082718..06608328 100644 --- a/mautrix_telegram/db/message.py +++ b/mautrix_telegram/db/message.py @@ -16,11 +16,9 @@ from typing import Optional, Iterator from sqlalchemy import Column, UniqueConstraint, Integer, String, and_, func, desc, select -from sqlalchemy.engine.result import RowProxy -from sqlalchemy.sql.expression import ClauseElement from mautrix.types import RoomID, EventID -from mautrix.bridge.db import Base +from mautrix.util.db import Base from ..types import TelegramID @@ -36,29 +34,21 @@ class Message(Base): __table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room_2"),) - @classmethod - def scan(cls, row: RowProxy) -> 'Message': - return cls(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3], edit_index=row[4]) - @classmethod def get_all_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Iterator['Message']: - return cls._all(cls.db.execute(cls.t.select().where(and_(cls.c.tgid == tgid, - cls.c.tg_space == tg_space)))) + return cls._select_all(cls.c.tgid == tgid, cls.c.tg_space == tg_space) @classmethod def get_one_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID, edit_index: int = 0 ) -> Optional['Message']: - query = cls.t.select() if edit_index < 0: - query = (query - .where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)) - .order_by(desc(cls.c.edit_index)) - .limit(1) - .offset(-edit_index - 1)) + return cls._one_or_none(cls.t.select() + .where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)) + .order_by(desc(cls.c.edit_index)) + .limit(1).offset(-edit_index - 1)) else: - query = query.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space, - cls.c.edit_index == edit_index)) - return cls._one_or_none(cls.db.execute(query)) + return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_space == tg_space, + cls.c.edit_index == edit_index) @classmethod def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int: @@ -73,9 +63,8 @@ class Message(Base): @classmethod def get_by_mxid(cls, mxid: EventID, mx_room: RoomID, tg_space: TelegramID ) -> Optional['Message']: - return cls._select_one_or_none(and_(cls.c.mxid == mxid, - cls.c.mx_room == mx_room, - cls.c.tg_space == tg_space)) + return cls._select_one_or_none(cls.c.mxid == mxid, cls.c.mx_room == mx_room, + cls.c.tg_space == tg_space) @classmethod def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, s_edit_index: int, @@ -92,14 +81,3 @@ class Message(Base): conn.execute(cls.t.update() .where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room)) .values(**values)) - - @property - def _edit_identity(self) -> ClauseElement: - return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space, - self.c.edit_index == self.edit_index) - - def insert(self) -> None: - with self.db.begin() as conn: - conn.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, - tgid=self.tgid, tg_space=self.tg_space, - edit_index=self.edit_index)) diff --git a/mautrix_telegram/db/portal.py b/mautrix_telegram/db/portal.py index c5a04846..dcc418d8 100644 --- a/mautrix_telegram/db/portal.py +++ b/mautrix_telegram/db/portal.py @@ -15,12 +15,10 @@ # along with this program. If not, see . from typing import Optional -from sqlalchemy import Column, Integer, String, Boolean, Text, and_ -from sqlalchemy.engine.result import RowProxy -from sqlalchemy.sql.expression import ClauseElement +from sqlalchemy import Column, Integer, String, Boolean, Text from mautrix.types import RoomID -from mautrix.bridge.db import Base +from mautrix.util.db import Base from ..types import TelegramID @@ -45,17 +43,9 @@ class Portal(Base): about: str = Column(String, nullable=True) photo_id: str = Column(String, nullable=True) - @classmethod - def scan(cls, row: RowProxy) -> Optional['Portal']: - (tgid, tg_receiver, peer_type, megagroup, mxid, config, username, title, about, - photo_id) = row - return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, megagroup=megagroup, - mxid=mxid, config=config, username=username, title=title, about=about, - photo_id=photo_id) - @classmethod def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']: - return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver)) + return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver) @classmethod def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']: @@ -64,14 +54,3 @@ class Portal(Base): @classmethod def get_by_username(cls, username: str) -> Optional['Portal']: return cls._select_one_or_none(cls.c.username == username) - - @property - def _edit_identity(self) -> ClauseElement: - return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver) - - def insert(self) -> None: - with self.db.begin() as conn: - conn.execute(self.t.insert().values( - tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type, - megagroup=self.megagroup, mxid=self.mxid, config=self.config, - username=self.username, title=self.title, about=self.about, photo_id=self.photo_id)) diff --git a/mautrix_telegram/db/puppet.py b/mautrix_telegram/db/puppet.py index 8b3027e2..c647b8d3 100644 --- a/mautrix_telegram/db/puppet.py +++ b/mautrix_telegram/db/puppet.py @@ -17,11 +17,9 @@ from typing import Optional, Iterable from sqlalchemy import Column, Integer, String, Boolean from sqlalchemy.sql import expression -from sqlalchemy.engine.result import RowProxy -from sqlalchemy.sql.expression import ClauseElement from mautrix.types import UserID, SyncToken -from mautrix.bridge.db import Base +from mautrix.util.db import Base from ..types import TelegramID @@ -41,20 +39,9 @@ class Puppet(Base): matrix_registered: bool = Column(Boolean, nullable=False, server_default=expression.false()) disable_updates: bool = Column(Boolean, nullable=False, server_default=expression.false()) - @classmethod - def scan(cls, row: RowProxy) -> Optional['Puppet']: - (id, custom_mxid, access_token, next_batch, displayname, displayname_source, username, - photo_id, is_bot, matrix_registered, disable_updates) = row - return cls(id=id, custom_mxid=custom_mxid, access_token=access_token, username=username, - next_batch=next_batch, displayname=displayname, photo_id=photo_id, - displayname_source=displayname_source, matrix_registered=matrix_registered, - disable_updates=disable_updates, is_bot=is_bot) - @classmethod def all_with_custom_mxid(cls) -> Iterable['Puppet']: - rows = cls.db.execute(cls.t.select().where(cls.c.custom_mxid != None)) - for row in rows: - yield cls.scan(row) + yield from cls._select_all(cls.c.custom_mxid != None) @classmethod def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']: @@ -71,16 +58,3 @@ class Puppet(Base): @classmethod def get_by_displayname(cls, displayname: str) -> Optional['Puppet']: return cls._select_one_or_none(cls.c.displayname == displayname) - - @property - def _edit_identity(self) -> ClauseElement: - return self.c.id == self.id - - def insert(self) -> None: - with self.db.begin() as conn: - conn.execute(self.t.insert().values( - id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token, - next_batch=self.next_batch, displayname=self.displayname, username=self.username, - displayname_source=self.displayname_source, photo_id=self.photo_id, - is_bot=self.is_bot, matrix_registered=self.matrix_registered, - disable_updates=self.disable_updates)) diff --git a/mautrix_telegram/db/telegram_file.py b/mautrix_telegram/db/telegram_file.py index 909bd782..4ac05293 100644 --- a/mautrix_telegram/db/telegram_file.py +++ b/mautrix_telegram/db/telegram_file.py @@ -19,7 +19,7 @@ from sqlalchemy import Column, ForeignKey, Integer, BigInteger, String, Boolean from sqlalchemy.engine.result import RowProxy from mautrix.types import ContentURI -from mautrix.bridge.db import Base +from mautrix.util.db import Base class TelegramFile(Base): @@ -38,12 +38,10 @@ class TelegramFile(Base): @classmethod def scan(cls, row: RowProxy) -> 'TelegramFile': - loc_id, mxc, mime, conv, ts, s, w, h, thumb_id = row - thumb = None - if thumb_id: - thumb = cls.get(thumb_id) - return cls(id=loc_id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts, - size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb) + telegram_file: TelegramFile = super().scan(row) + if telegram_file.thumbnail_id: + telegram_file.thumbnail = cls.get(telegram_file.thumbnail_id) + return telegram_file @classmethod def get(cls, loc_id: str) -> Optional['TelegramFile']: diff --git a/mautrix_telegram/db/user.py b/mautrix_telegram/db/user.py index 1580e277..36b9d7b7 100644 --- a/mautrix_telegram/db/user.py +++ b/mautrix_telegram/db/user.py @@ -16,11 +16,9 @@ from typing import Optional, Iterable, Tuple from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, Integer, String -from sqlalchemy.engine.result import RowProxy -from sqlalchemy.sql.expression import ClauseElement from mautrix.types import UserID -from mautrix.bridge.db import Base +from mautrix.util.db import Base from ..types import TelegramID @@ -34,12 +32,6 @@ class User(Base): tg_phone: str = Column(String, nullable=True) saved_contacts: int = Column(Integer, default=0, nullable=False) - @classmethod - def scan(cls, row: RowProxy) -> 'User': - mxid, tgid, tg_username, tg_phone, saved_contacts = row - return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone, - saved_contacts=saved_contacts) - @classmethod def all_with_tgid(cls) -> Iterable['User']: return cls._select_all(cls.c.tgid != None) @@ -56,16 +48,6 @@ class User(Base): def get_by_username(cls, username: str) -> Optional['User']: return cls._select_one_or_none(cls.c.tg_username == username) - @property - def _edit_identity(self) -> ClauseElement: - return self.c.mxid == self.mxid - - def insert(self) -> None: - with self.db.begin() as conn: - conn.execute(self.t.insert().values( - mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, - tg_phone=self.tg_phone, saved_contacts=self.saved_contacts)) - @property def contacts(self) -> Iterable[TelegramID]: rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid)) diff --git a/mautrix_telegram/scripts/telematrix_import/__main__.py b/mautrix_telegram/scripts/telematrix_import/__main__.py index 5324cc8b..b6ef97c0 100644 --- a/mautrix_telegram/scripts/telematrix_import/__main__.py +++ b/mautrix_telegram/scripts/telematrix_import/__main__.py @@ -19,7 +19,7 @@ import argparse from sqlalchemy import orm import sqlalchemy as sql -from mautrix.bridge.db import Base +from mautrix.util.db import Base from mautrix_telegram.db import Portal, Message, Puppet, BotChat from mautrix_telegram.config import Config