Merge pull request #239 from tulir/sqlalchemy-core

Port Message table to SQLAlchemy Core
This commit is contained in:
Tulir Asokan
2018-10-21 00:32:14 +03:00
committed by GitHub
6 changed files with 102 additions and 53 deletions
+1 -1
View File
@@ -113,7 +113,7 @@ if config["appservice.provisioning.enabled"]:
context.provisioning_api = provisioning_api
with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
init_db(db_session)
init_db(db_session, db_engine)
init_abstract_user(context)
context.bot = init_bot(context)
context.mx = MatrixHandler(context)
+6 -7
View File
@@ -230,7 +230,7 @@ class AbstractUser(ABC):
return
# We check that these are user read receipts, so tg_space is always the user ID.
message = DBMessage.query.get((update.max_id, self.tgid))
message = DBMessage.get_by_tgid(update.max_id, self.tgid)
if not message:
return
@@ -323,12 +323,11 @@ class AbstractUser(ABC):
return
for message in update.messages:
message = DBMessage.query.get((message, self.tgid))
message = DBMessage.get_by_tgid(TelegramID(message), self.tgid)
if not message:
continue
self.db.delete(message)
number_left = DBMessage.query.filter(DBMessage.mxid == message.mxid,
DBMessage.mx_room == message.mx_room).count()
message.delete()
number_left = DBMessage.count_spaces_by_mxid(message.mxid, message.mx_room)
if number_left == 0:
portal = po.Portal.get_by_mxid(message.mx_room)
await self._try_redact(portal, message)
@@ -343,10 +342,10 @@ class AbstractUser(ABC):
return
for message in update.messages:
message = DBMessage.query.get((message, portal.tgid))
message = DBMessage.get_by_tgid(TelegramID(message), portal.tgid)
if not message:
continue
self.db.delete(message)
message.delete()
await self._try_redact(portal, message)
self.db.commit()
+72 -4
View File
@@ -15,9 +15,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text)
BigInteger, String, Boolean, Text, Table,
and_, func, select)
from sqlalchemy.engine import Engine, RowProxy
from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query
from sqlalchemy.sql.base import ImmutableColumnCollection
from typing import Dict, Optional, List
import json
@@ -49,7 +52,9 @@ class Portal(Base):
class Message(Base):
query = None # type: Query
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "message"
mxid = Column(String) # type: MatrixEventID
@@ -59,6 +64,67 @@ class Message(Base):
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
@staticmethod
def _one_or_none(rows: RowProxy) -> Optional['Message']:
try:
mxid, mx_room, tgid, tg_space = next(rows)
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
except StopIteration:
return None
@staticmethod
def _all(rows: RowProxy) -> List['Message']:
return [Message(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3])
for row in rows]
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
rows = cls.db.execute(cls.t.select()
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
@classmethod
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
rows = cls.db.execute(select([func.count(cls.c.tg_space)])
.where(and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room)))
try:
count, = next(rows)
return count
except StopIteration:
return 0
@classmethod
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
) -> Optional['Message']:
rows = cls.db.execute(cls.t.select().where(
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
@classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
cls.db.execute(cls.t.update()
.where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space))
.values(**values))
@classmethod
def update_by_mxid(cls, s_mxid: MatrixEventID, s_mx_room: MatrixRoomID, **values) -> None:
cls.db.execute(cls.t.update()
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
def update(self, **values) -> None:
for key, value in values.items():
setattr(self, key, value)
self.update_by_tgid(self.tgid, self.tg_space, **values)
def delete(self) -> None:
self.db.execute(self.t.delete().where(
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
def insert(self) -> None:
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
tg_space=self.tg_space))
class UserPortal(Base):
query = None # type: Query
@@ -178,9 +244,11 @@ class TelegramFile(Base):
thumbnail = relationship("TelegramFile", uselist=False)
def init(db_session) -> None:
def init(db_session, db_engine) -> None:
Portal.query = db_session.query_property()
Message.query = db_session.query_property()
Message.db = db_engine
Message.t = Message.__table__
Message.c = Message.t.c
UserPortal.query = db_session.query_property()
User.query = db_session.query_property()
Puppet.query = db_session.query_property()
@@ -105,9 +105,7 @@ def matrix_reply_to_telegram(content: Dict[str, Any], tg_space: TelegramID,
pass
content["body"] = trim_reply_fallback_text(content["body"])
message = DBMessage.query.filter(DBMessage.mxid == event_id,
DBMessage.tg_space == tg_space,
DBMessage.mx_room == room_id).one_or_none()
message = DBMessage.get_by_mxid(event_id, room_id, tg_space)
if message:
return message.tgid
except KeyError:
+3 -3
View File
@@ -54,7 +54,7 @@ def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
else source.tgid)
msg = DBMessage.query.get((evt.reply_to_msg_id, space))
msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
if msg:
return {
"m.in_reply_to": {
@@ -124,7 +124,7 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
else source.tgid)
msg = DBMessage.query.get((evt.reply_to_msg_id, space))
msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
if not msg:
return text, html
@@ -325,7 +325,7 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
portal = po.Portal.find_by_username(group)
if portal:
message = DBMessage.query.get((msgid, portal.tgid))
message = DBMessage.get_by_tgid(TelegramID(msgid), portal.tgid)
if message:
url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}"
+19 -35
View File
@@ -772,9 +772,7 @@ class Portal:
if user.is_bot:
return
space = self.tgid if self.peer_type == "channel" else user.tgid
message = DBMessage.query.filter(DBMessage.mxid == event_id,
DBMessage.mx_room == self.mxid,
DBMessage.tg_space == space).one_or_none()
message = DBMessage.get_by_mxid(event_id, self.mxid, space)
if not message:
return
if self.peer_type == "channel":
@@ -959,12 +957,11 @@ class Portal:
response: TypeMessage) -> None:
self.log.debug("Handled Matrix message: %s", response)
self.is_duplicate(response, (event_id, space))
self.db.add(DBMessage(
DBMessage(
tgid=response.id,
tg_space=space,
mx_room=self.mxid,
mxid=event_id))
self.db.commit()
mxid=event_id).insert()
async def handle_matrix_message(self, sender: 'u.User', message: Dict[str, Any],
event_id: MatrixEventID) -> None:
@@ -1009,9 +1006,10 @@ class Portal:
if not pinned_message:
await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=0))
else:
message = DBMessage.query.filter(DBMessage.mxid == pinned_message,
DBMessage.tg_space == self.tgid,
DBMessage.mx_room == self.mxid).one_or_none()
message = DBMessage.get_by_mxid(pinned_message, self.mxid, self.tgid)
if message is None:
self.log.warning(f"Could not find pinned {pinned_message} in {self.mxid}")
return
await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=message.tgid))
except ChatNotModifiedError:
pass
@@ -1019,9 +1017,7 @@ class Portal:
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventID) -> None:
real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
space = self.tgid if self.peer_type == "channel" else real_deleter.tgid
message = DBMessage.query.filter(DBMessage.mxid == event_id,
DBMessage.tg_space == space,
DBMessage.mx_room == self.mxid).one_or_none()
message = DBMessage.get_by_mxid(event_id, self.mxid, space)
if not message:
return
await real_deleter.client.delete_messages(self.peer, [message.tgid])
@@ -1413,10 +1409,9 @@ class Portal:
if duplicate_found:
mxid, other_tg_space = duplicate_found
if tg_space != other_tg_space:
msg = DBMessage.query.get((evt.id, tg_space))
msg.mxid = mxid
msg.mx_room = self.mxid
self.db.commit()
DBMessage.update_by_tgid(evt.id, tg_space,
mxid=mxid,
mx_room=self.mxid)
return
evt.reply_to_msg_id = evt.id
@@ -1429,19 +1424,14 @@ class Portal:
mxid = response["event_id"]
msg = DBMessage.query.get((evt.id, tg_space))
msg = DBMessage.get_by_tgid(evt.id, tg_space)
if not msg:
self.log.info(f"Didn't find edited message {evt.id}@{tg_space} (src {source.tgid}) "
"in database.")
# Oh crap
return
msg.mxid = mxid
msg.mx_room = self.mxid
DBMessage.query \
.filter(DBMessage.mx_room == self.mxid,
DBMessage.mxid == temporary_identifier) \
.update({"mxid": mxid})
self.db.commit()
msg.update(mxid=mxid, mx_room=self.mxid)
DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=mxid)
async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet,
evt: Message) -> None:
@@ -1463,13 +1453,11 @@ class Portal:
self.log.debug(f"Ignoring message {evt.id}@{tg_space} (src {source.tgid}) "
f"as it was already handled (in space {other_tg_space})")
if tg_space != other_tg_space:
self.db.add(
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space))
self.db.commit()
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space).insert()
return
if self.dedup_pre_db_check and self.peer_type == "channel":
msg = DBMessage.query.get((evt.id, tg_space))
msg = DBMessage.get_by_tgid(evt.id, tg_space)
if msg:
self.log.debug(f"Ignoring message {evt.id} (src {source.tgid}) as it was already"
f"handled into {msg.mxid}. This duplicate was catched in the db "
@@ -1523,12 +1511,8 @@ class Portal:
self.log.debug("Handled Telegram message: %s", evt)
try:
self.db.add(DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space))
self.db.commit()
DBMessage.query \
.filter(DBMessage.mx_room == self.mxid,
DBMessage.mxid == temporary_identifier) \
.update({"mxid": mxid})
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space).insert()
DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=mxid)
except FlushError as e:
self.log.exception(f"{e.__class__.__name__} while saving message mapping. "
"This might mean that an update was handled after it left the "
@@ -1610,7 +1594,7 @@ class Portal:
self._temp_pinned_message_id = None
self._temp_pinned_message_sender = None
message = DBMessage.query.get((msg_id, self.tgid))
message = DBMessage.get_by_tgid(msg_id, self.tgid)
if message:
await intent.set_pinned_messages(self.mxid, [message.mxid])
else: