Port Message table to SQLAlchemy Core
This commit is contained in:
@@ -113,7 +113,7 @@ if config["appservice.provisioning.enabled"]:
|
||||
context.provisioning_api = provisioning_api
|
||||
|
||||
with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
|
||||
init_db(db_session)
|
||||
init_db(db_session, db_engine)
|
||||
init_abstract_user(context)
|
||||
context.bot = init_bot(context)
|
||||
context.mx = MatrixHandler(context)
|
||||
|
||||
@@ -230,7 +230,7 @@ class AbstractUser(ABC):
|
||||
return
|
||||
|
||||
# We check that these are user read receipts, so tg_space is always the user ID.
|
||||
message = DBMessage.query.get((update.max_id, self.tgid))
|
||||
message = DBMessage.get_by_tgid(update.max_id, self.tgid)
|
||||
if not message:
|
||||
return
|
||||
|
||||
@@ -323,12 +323,11 @@ class AbstractUser(ABC):
|
||||
return
|
||||
|
||||
for message in update.messages:
|
||||
message = DBMessage.query.get((message, self.tgid))
|
||||
message = DBMessage.get_by_tgid(TelegramID(message), self.tgid)
|
||||
if not message:
|
||||
continue
|
||||
self.db.delete(message)
|
||||
number_left = DBMessage.query.filter(DBMessage.mxid == message.mxid,
|
||||
DBMessage.mx_room == message.mx_room).count()
|
||||
message.delete()
|
||||
number_left = DBMessage.count_spaces_by_mxid(message.mxid, message.mx_room)
|
||||
if number_left == 0:
|
||||
portal = po.Portal.get_by_mxid(message.mx_room)
|
||||
await self._try_redact(portal, message)
|
||||
@@ -343,10 +342,10 @@ class AbstractUser(ABC):
|
||||
return
|
||||
|
||||
for message in update.messages:
|
||||
message = DBMessage.query.get((message, portal.tgid))
|
||||
message = DBMessage.get_by_tgid(TelegramID(message), portal.tgid)
|
||||
if not message:
|
||||
continue
|
||||
self.db.delete(message)
|
||||
message.delete()
|
||||
await self._try_redact(portal, message)
|
||||
self.db.commit()
|
||||
|
||||
|
||||
+72
-4
@@ -15,9 +15,12 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
||||
BigInteger, String, Boolean, Text)
|
||||
BigInteger, String, Boolean, Text, Table,
|
||||
and_, func, select)
|
||||
from sqlalchemy.engine import Engine, RowProxy
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.orm import relationship, Query
|
||||
from sqlalchemy.sql.base import ImmutableColumnCollection
|
||||
from typing import Dict, Optional, List
|
||||
import json
|
||||
|
||||
@@ -49,7 +52,9 @@ class Portal(Base):
|
||||
|
||||
|
||||
class Message(Base):
|
||||
query = None # type: Query
|
||||
db = None # type: Engine
|
||||
t = None # type: Table
|
||||
c = None # type: ImmutableColumnCollection
|
||||
__tablename__ = "message"
|
||||
|
||||
mxid = Column(String) # type: MatrixEventID
|
||||
@@ -59,6 +64,67 @@ class Message(Base):
|
||||
|
||||
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
|
||||
|
||||
@staticmethod
|
||||
def _one_or_none(rows: RowProxy) -> Optional['Message']:
|
||||
try:
|
||||
mxid, mx_room, tgid, tg_space = next(rows)
|
||||
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _all(rows: RowProxy) -> List['Message']:
|
||||
return [Message(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3])
|
||||
for row in rows]
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
|
||||
rows = cls.db.execute(cls.t.select()
|
||||
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
|
||||
return cls._one_or_none(rows)
|
||||
|
||||
@classmethod
|
||||
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
|
||||
rows = cls.db.execute(select([func.count(cls.c.tg_space)])
|
||||
.where(and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room)))
|
||||
try:
|
||||
count, = next(rows)
|
||||
return count
|
||||
except StopIteration:
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
|
||||
) -> Optional['Message']:
|
||||
rows = cls.db.execute(cls.t.select().where(
|
||||
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
|
||||
return cls._one_or_none(rows)
|
||||
|
||||
@classmethod
|
||||
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
|
||||
cls.db.execute(cls.t.update()
|
||||
.where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space))
|
||||
.values(**values))
|
||||
|
||||
@classmethod
|
||||
def update_by_mxid(cls, s_mxid: MatrixEventID, s_mx_room: MatrixRoomID, **values) -> None:
|
||||
cls.db.execute(cls.t.update()
|
||||
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
|
||||
.values(**values))
|
||||
|
||||
def update(self, **values) -> None:
|
||||
for key, value in values.items():
|
||||
setattr(self, key, value)
|
||||
self.update_by_tgid(self.tgid, self.tg_space, **values)
|
||||
|
||||
def delete(self) -> None:
|
||||
self.db.execute(self.t.delete().where(
|
||||
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
|
||||
tg_space=self.tg_space))
|
||||
|
||||
|
||||
class UserPortal(Base):
|
||||
query = None # type: Query
|
||||
@@ -178,9 +244,11 @@ class TelegramFile(Base):
|
||||
thumbnail = relationship("TelegramFile", uselist=False)
|
||||
|
||||
|
||||
def init(db_session) -> None:
|
||||
def init(db_session, db_engine) -> None:
|
||||
Portal.query = db_session.query_property()
|
||||
Message.query = db_session.query_property()
|
||||
Message.db = db_engine
|
||||
Message.t = Message.__table__
|
||||
Message.c = Message.t.c
|
||||
UserPortal.query = db_session.query_property()
|
||||
User.query = db_session.query_property()
|
||||
Puppet.query = db_session.query_property()
|
||||
|
||||
@@ -105,9 +105,7 @@ def matrix_reply_to_telegram(content: Dict[str, Any], tg_space: TelegramID,
|
||||
pass
|
||||
content["body"] = trim_reply_fallback_text(content["body"])
|
||||
|
||||
message = DBMessage.query.filter(DBMessage.mxid == event_id,
|
||||
DBMessage.tg_space == tg_space,
|
||||
DBMessage.mx_room == room_id).one_or_none()
|
||||
message = DBMessage.get_by_mxid(event_id, room_id, tg_space)
|
||||
if message:
|
||||
return message.tgid
|
||||
except KeyError:
|
||||
|
||||
@@ -54,7 +54,7 @@ def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
|
||||
space = (evt.to_id.channel_id
|
||||
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
||||
else source.tgid)
|
||||
msg = DBMessage.query.get((evt.reply_to_msg_id, space))
|
||||
msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
|
||||
if msg:
|
||||
return {
|
||||
"m.in_reply_to": {
|
||||
@@ -124,7 +124,7 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
|
||||
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
||||
else source.tgid)
|
||||
|
||||
msg = DBMessage.query.get((evt.reply_to_msg_id, space))
|
||||
msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
|
||||
if not msg:
|
||||
return text, html
|
||||
|
||||
@@ -325,7 +325,7 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
|
||||
|
||||
portal = po.Portal.find_by_username(group)
|
||||
if portal:
|
||||
message = DBMessage.query.get((msgid, portal.tgid))
|
||||
message = DBMessage.get_by_tgid(TelegramID(msgid), portal.tgid)
|
||||
if message:
|
||||
url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}"
|
||||
|
||||
|
||||
+19
-35
@@ -772,9 +772,7 @@ class Portal:
|
||||
if user.is_bot:
|
||||
return
|
||||
space = self.tgid if self.peer_type == "channel" else user.tgid
|
||||
message = DBMessage.query.filter(DBMessage.mxid == event_id,
|
||||
DBMessage.mx_room == self.mxid,
|
||||
DBMessage.tg_space == space).one_or_none()
|
||||
message = DBMessage.get_by_mxid(event_id, self.mxid, space)
|
||||
if not message:
|
||||
return
|
||||
if self.peer_type == "channel":
|
||||
@@ -959,12 +957,11 @@ class Portal:
|
||||
response: TypeMessage) -> None:
|
||||
self.log.debug("Handled Matrix message: %s", response)
|
||||
self.is_duplicate(response, (event_id, space))
|
||||
self.db.add(DBMessage(
|
||||
DBMessage(
|
||||
tgid=response.id,
|
||||
tg_space=space,
|
||||
mx_room=self.mxid,
|
||||
mxid=event_id))
|
||||
self.db.commit()
|
||||
mxid=event_id).insert()
|
||||
|
||||
async def handle_matrix_message(self, sender: 'u.User', message: Dict[str, Any],
|
||||
event_id: MatrixEventID) -> None:
|
||||
@@ -1009,9 +1006,10 @@ class Portal:
|
||||
if not pinned_message:
|
||||
await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=0))
|
||||
else:
|
||||
message = DBMessage.query.filter(DBMessage.mxid == pinned_message,
|
||||
DBMessage.tg_space == self.tgid,
|
||||
DBMessage.mx_room == self.mxid).one_or_none()
|
||||
message = DBMessage.get_by_mxid(pinned_message, self.mxid, self.tgid)
|
||||
if message is None:
|
||||
self.log.warning(f"Could not find pinned {pinned_message} in {self.mxid}")
|
||||
return
|
||||
await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=message.tgid))
|
||||
except ChatNotModifiedError:
|
||||
pass
|
||||
@@ -1019,9 +1017,7 @@ class Portal:
|
||||
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventID) -> None:
|
||||
real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
|
||||
space = self.tgid if self.peer_type == "channel" else real_deleter.tgid
|
||||
message = DBMessage.query.filter(DBMessage.mxid == event_id,
|
||||
DBMessage.tg_space == space,
|
||||
DBMessage.mx_room == self.mxid).one_or_none()
|
||||
message = DBMessage.get_by_mxid(event_id, self.mxid, space)
|
||||
if not message:
|
||||
return
|
||||
await real_deleter.client.delete_messages(self.peer, [message.tgid])
|
||||
@@ -1413,10 +1409,9 @@ class Portal:
|
||||
if duplicate_found:
|
||||
mxid, other_tg_space = duplicate_found
|
||||
if tg_space != other_tg_space:
|
||||
msg = DBMessage.query.get((evt.id, tg_space))
|
||||
msg.mxid = mxid
|
||||
msg.mx_room = self.mxid
|
||||
self.db.commit()
|
||||
DBMessage.update_by_tgid(evt.id, tg_space,
|
||||
mxid=mxid,
|
||||
mx_room=self.mxid)
|
||||
return
|
||||
|
||||
evt.reply_to_msg_id = evt.id
|
||||
@@ -1429,19 +1424,14 @@ class Portal:
|
||||
|
||||
mxid = response["event_id"]
|
||||
|
||||
msg = DBMessage.query.get((evt.id, tg_space))
|
||||
msg = DBMessage.get_by_tgid(evt.id, tg_space)
|
||||
if not msg:
|
||||
self.log.info(f"Didn't find edited message {evt.id}@{tg_space} (src {source.tgid}) "
|
||||
"in database.")
|
||||
# Oh crap
|
||||
return
|
||||
msg.mxid = mxid
|
||||
msg.mx_room = self.mxid
|
||||
DBMessage.query \
|
||||
.filter(DBMessage.mx_room == self.mxid,
|
||||
DBMessage.mxid == temporary_identifier) \
|
||||
.update({"mxid": mxid})
|
||||
self.db.commit()
|
||||
msg.update(mxid=mxid, mx_room=self.mxid)
|
||||
DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=mxid)
|
||||
|
||||
async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet,
|
||||
evt: Message) -> None:
|
||||
@@ -1463,13 +1453,11 @@ class Portal:
|
||||
self.log.debug(f"Ignoring message {evt.id}@{tg_space} (src {source.tgid}) "
|
||||
f"as it was already handled (in space {other_tg_space})")
|
||||
if tg_space != other_tg_space:
|
||||
self.db.add(
|
||||
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space))
|
||||
self.db.commit()
|
||||
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space).insert()
|
||||
return
|
||||
|
||||
if self.dedup_pre_db_check and self.peer_type == "channel":
|
||||
msg = DBMessage.query.get((evt.id, tg_space))
|
||||
msg = DBMessage.get_by_tgid(evt.id, tg_space)
|
||||
if msg:
|
||||
self.log.debug(f"Ignoring message {evt.id} (src {source.tgid}) as it was already"
|
||||
f"handled into {msg.mxid}. This duplicate was catched in the db "
|
||||
@@ -1523,12 +1511,8 @@ class Portal:
|
||||
|
||||
self.log.debug("Handled Telegram message: %s", evt)
|
||||
try:
|
||||
self.db.add(DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space))
|
||||
self.db.commit()
|
||||
DBMessage.query \
|
||||
.filter(DBMessage.mx_room == self.mxid,
|
||||
DBMessage.mxid == temporary_identifier) \
|
||||
.update({"mxid": mxid})
|
||||
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space).insert()
|
||||
DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=mxid)
|
||||
except FlushError as e:
|
||||
self.log.exception(f"{e.__class__.__name__} while saving message mapping. "
|
||||
"This might mean that an update was handled after it left the "
|
||||
@@ -1610,7 +1594,7 @@ class Portal:
|
||||
self._temp_pinned_message_id = None
|
||||
self._temp_pinned_message_sender = None
|
||||
|
||||
message = DBMessage.query.get((msg_id, self.tgid))
|
||||
message = DBMessage.get_by_tgid(msg_id, self.tgid)
|
||||
if message:
|
||||
await intent.set_pinned_messages(self.mxid, [message.mxid])
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user