From 388e4f8601064335da78ec167a441d965ea3d765 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 20 Oct 2018 23:11:10 +0300 Subject: [PATCH] Port Message table to SQLAlchemy Core --- mautrix_telegram/__main__.py | 2 +- mautrix_telegram/abstract_user.py | 13 ++-- mautrix_telegram/db.py | 76 ++++++++++++++++++- .../formatter/from_matrix/__init__.py | 4 +- mautrix_telegram/formatter/from_telegram.py | 6 +- mautrix_telegram/portal.py | 54 +++++-------- 6 files changed, 102 insertions(+), 53 deletions(-) diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py index 8bacb4dc..ffd7e0c7 100644 --- a/mautrix_telegram/__main__.py +++ b/mautrix_telegram/__main__.py @@ -113,7 +113,7 @@ if config["appservice.provisioning.enabled"]: context.provisioning_api = provisioning_api with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start: - init_db(db_session) + init_db(db_session, db_engine) init_abstract_user(context) context.bot = init_bot(context) context.mx = MatrixHandler(context) diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py index 1766e7b7..5075fd67 100644 --- a/mautrix_telegram/abstract_user.py +++ b/mautrix_telegram/abstract_user.py @@ -230,7 +230,7 @@ class AbstractUser(ABC): return # We check that these are user read receipts, so tg_space is always the user ID. - message = DBMessage.query.get((update.max_id, self.tgid)) + message = DBMessage.get_by_tgid(update.max_id, self.tgid) if not message: return @@ -323,12 +323,11 @@ class AbstractUser(ABC): return for message in update.messages: - message = DBMessage.query.get((message, self.tgid)) + message = DBMessage.get_by_tgid(TelegramID(message), self.tgid) if not message: continue - self.db.delete(message) - number_left = DBMessage.query.filter(DBMessage.mxid == message.mxid, - DBMessage.mx_room == message.mx_room).count() + message.delete() + number_left = DBMessage.count_spaces_by_mxid(message.mxid, message.mx_room) if number_left == 0: portal = po.Portal.get_by_mxid(message.mx_room) await self._try_redact(portal, message) @@ -343,10 +342,10 @@ class AbstractUser(ABC): return for message in update.messages: - message = DBMessage.query.get((message, portal.tgid)) + message = DBMessage.get_by_tgid(TelegramID(message), portal.tgid) if not message: continue - self.db.delete(message) + message.delete() await self._try_redact(portal, message) self.db.commit() diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py index c78da4a5..5da04161 100644 --- a/mautrix_telegram/db.py +++ b/mautrix_telegram/db.py @@ -15,9 +15,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, - BigInteger, String, Boolean, Text) + BigInteger, String, Boolean, Text, Table, + and_, func, select) +from sqlalchemy.engine import Engine, RowProxy from sqlalchemy.sql import expression from sqlalchemy.orm import relationship, Query +from sqlalchemy.sql.base import ImmutableColumnCollection from typing import Dict, Optional, List import json @@ -49,7 +52,9 @@ class Portal(Base): class Message(Base): - query = None # type: Query + db = None # type: Engine + t = None # type: Table + c = None # type: ImmutableColumnCollection __tablename__ = "message" mxid = Column(String) # type: MatrixEventID @@ -59,6 +64,67 @@ class Message(Base): __table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),) + @staticmethod + def _one_or_none(rows: RowProxy) -> Optional['Message']: + try: + mxid, mx_room, tgid, tg_space = next(rows) + return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space) + except StopIteration: + return None + + @staticmethod + def _all(rows: RowProxy) -> List['Message']: + return [Message(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3]) + for row in rows] + + @classmethod + def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']: + rows = cls.db.execute(cls.t.select() + .where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))) + return cls._one_or_none(rows) + + @classmethod + def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int: + rows = cls.db.execute(select([func.count(cls.c.tg_space)]) + .where(and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room))) + try: + count, = next(rows) + return count + except StopIteration: + return 0 + + @classmethod + def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID + ) -> Optional['Message']: + rows = cls.db.execute(cls.t.select().where( + and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space))) + return cls._one_or_none(rows) + + @classmethod + def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None: + cls.db.execute(cls.t.update() + .where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space)) + .values(**values)) + + @classmethod + def update_by_mxid(cls, s_mxid: MatrixEventID, s_mx_room: MatrixRoomID, **values) -> None: + cls.db.execute(cls.t.update() + .where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room)) + .values(**values)) + + def update(self, **values) -> None: + for key, value in values.items(): + setattr(self, key, value) + self.update_by_tgid(self.tgid, self.tg_space, **values) + + def delete(self) -> None: + self.db.execute(self.t.delete().where( + and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space))) + + def insert(self) -> None: + self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid, + tg_space=self.tg_space)) + class UserPortal(Base): query = None # type: Query @@ -178,9 +244,11 @@ class TelegramFile(Base): thumbnail = relationship("TelegramFile", uselist=False) -def init(db_session) -> None: +def init(db_session, db_engine) -> None: Portal.query = db_session.query_property() - Message.query = db_session.query_property() + Message.db = db_engine + Message.t = Message.__table__ + Message.c = Message.t.c UserPortal.query = db_session.query_property() User.query = db_session.query_property() Puppet.query = db_session.query_property() diff --git a/mautrix_telegram/formatter/from_matrix/__init__.py b/mautrix_telegram/formatter/from_matrix/__init__.py index c95e85af..206165cf 100644 --- a/mautrix_telegram/formatter/from_matrix/__init__.py +++ b/mautrix_telegram/formatter/from_matrix/__init__.py @@ -105,9 +105,7 @@ def matrix_reply_to_telegram(content: Dict[str, Any], tg_space: TelegramID, pass content["body"] = trim_reply_fallback_text(content["body"]) - message = DBMessage.query.filter(DBMessage.mxid == event_id, - DBMessage.tg_space == tg_space, - DBMessage.mx_room == room_id).one_or_none() + message = DBMessage.get_by_mxid(event_id, room_id, tg_space) if message: return message.tgid except KeyError: diff --git a/mautrix_telegram/formatter/from_telegram.py b/mautrix_telegram/formatter/from_telegram.py index 86cd5b5e..eb132b47 100644 --- a/mautrix_telegram/formatter/from_telegram.py +++ b/mautrix_telegram/formatter/from_telegram.py @@ -54,7 +54,7 @@ def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict: space = (evt.to_id.channel_id if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) else source.tgid) - msg = DBMessage.query.get((evt.reply_to_msg_id, space)) + msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space) if msg: return { "m.in_reply_to": { @@ -124,7 +124,7 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) else source.tgid) - msg = DBMessage.query.get((evt.reply_to_msg_id, space)) + msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space) if not msg: return text, html @@ -325,7 +325,7 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool: portal = po.Portal.find_by_username(group) if portal: - message = DBMessage.query.get((msgid, portal.tgid)) + message = DBMessage.get_by_tgid(TelegramID(msgid), portal.tgid) if message: url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}" diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index 3d865192..92e29839 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -772,9 +772,7 @@ class Portal: if user.is_bot: return space = self.tgid if self.peer_type == "channel" else user.tgid - message = DBMessage.query.filter(DBMessage.mxid == event_id, - DBMessage.mx_room == self.mxid, - DBMessage.tg_space == space).one_or_none() + message = DBMessage.get_by_mxid(event_id, self.mxid, space) if not message: return if self.peer_type == "channel": @@ -959,12 +957,11 @@ class Portal: response: TypeMessage) -> None: self.log.debug("Handled Matrix message: %s", response) self.is_duplicate(response, (event_id, space)) - self.db.add(DBMessage( + DBMessage( tgid=response.id, tg_space=space, mx_room=self.mxid, - mxid=event_id)) - self.db.commit() + mxid=event_id).insert() async def handle_matrix_message(self, sender: 'u.User', message: Dict[str, Any], event_id: MatrixEventID) -> None: @@ -1009,9 +1006,10 @@ class Portal: if not pinned_message: await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=0)) else: - message = DBMessage.query.filter(DBMessage.mxid == pinned_message, - DBMessage.tg_space == self.tgid, - DBMessage.mx_room == self.mxid).one_or_none() + message = DBMessage.get_by_mxid(pinned_message, self.mxid, self.tgid) + if message is None: + self.log.warning(f"Could not find pinned {pinned_message} in {self.mxid}") + return await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=message.tgid)) except ChatNotModifiedError: pass @@ -1019,9 +1017,7 @@ class Portal: async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventID) -> None: real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot space = self.tgid if self.peer_type == "channel" else real_deleter.tgid - message = DBMessage.query.filter(DBMessage.mxid == event_id, - DBMessage.tg_space == space, - DBMessage.mx_room == self.mxid).one_or_none() + message = DBMessage.get_by_mxid(event_id, self.mxid, space) if not message: return await real_deleter.client.delete_messages(self.peer, [message.tgid]) @@ -1413,10 +1409,9 @@ class Portal: if duplicate_found: mxid, other_tg_space = duplicate_found if tg_space != other_tg_space: - msg = DBMessage.query.get((evt.id, tg_space)) - msg.mxid = mxid - msg.mx_room = self.mxid - self.db.commit() + DBMessage.update_by_tgid(evt.id, tg_space, + mxid=mxid, + mx_room=self.mxid) return evt.reply_to_msg_id = evt.id @@ -1429,19 +1424,14 @@ class Portal: mxid = response["event_id"] - msg = DBMessage.query.get((evt.id, tg_space)) + msg = DBMessage.get_by_tgid(evt.id, tg_space) if not msg: self.log.info(f"Didn't find edited message {evt.id}@{tg_space} (src {source.tgid}) " "in database.") # Oh crap return - msg.mxid = mxid - msg.mx_room = self.mxid - DBMessage.query \ - .filter(DBMessage.mx_room == self.mxid, - DBMessage.mxid == temporary_identifier) \ - .update({"mxid": mxid}) - self.db.commit() + msg.update(mxid=mxid, mx_room=self.mxid) + DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=mxid) async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet, evt: Message) -> None: @@ -1463,13 +1453,11 @@ class Portal: self.log.debug(f"Ignoring message {evt.id}@{tg_space} (src {source.tgid}) " f"as it was already handled (in space {other_tg_space})") if tg_space != other_tg_space: - self.db.add( - DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space)) - self.db.commit() + DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space).insert() return if self.dedup_pre_db_check and self.peer_type == "channel": - msg = DBMessage.query.get((evt.id, tg_space)) + msg = DBMessage.get_by_tgid(evt.id, tg_space) if msg: self.log.debug(f"Ignoring message {evt.id} (src {source.tgid}) as it was already" f"handled into {msg.mxid}. This duplicate was catched in the db " @@ -1523,12 +1511,8 @@ class Portal: self.log.debug("Handled Telegram message: %s", evt) try: - self.db.add(DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space)) - self.db.commit() - DBMessage.query \ - .filter(DBMessage.mx_room == self.mxid, - DBMessage.mxid == temporary_identifier) \ - .update({"mxid": mxid}) + DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space).insert() + DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=mxid) except FlushError as e: self.log.exception(f"{e.__class__.__name__} while saving message mapping. " "This might mean that an update was handled after it left the " @@ -1610,7 +1594,7 @@ class Portal: self._temp_pinned_message_id = None self._temp_pinned_message_sender = None - message = DBMessage.query.get((msg_id, self.tgid)) + message = DBMessage.get_by_tgid(msg_id, self.tgid) if message: await intent.set_pinned_messages(self.mxid, [message.mxid]) else: