Start moving portals and users to SQLAlchemy Core

This commit is contained in:
Tulir Asokan
2019-02-12 01:19:12 +02:00
parent c028e1befc
commit 53489e7356
5 changed files with 216 additions and 127 deletions
+120 -77
View File
@@ -15,13 +15,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text, Table,
BigInteger, String, Boolean, Text,
and_, func, select)
from sqlalchemy.engine import Engine, RowProxy
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query
from sqlalchemy.sql.base import ImmutableColumnCollection
from typing import Dict, Optional, List
from typing import Dict, Optional, List, Iterable
import json
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
@@ -30,7 +29,6 @@ from .base import Base
class Portal(Base):
query = None # type: Query
__tablename__ = "portal"
# Telegram chat information
@@ -50,11 +48,41 @@ class Portal(Base):
about = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
try:
(tgid, tg_receiver, peer_type, megagroup, mxid, config,
username, title, about, photo_id) = next(rows)
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
megagroup=megagroup, mxid=mxid, config=config, username=username,
title=title, about=about, photo_id=photo_id)
except StopIteration:
return None
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']:
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver))
@classmethod
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['Portal']:
return cls._select_one_or_none(cls.c.mxid == mxid)
@classmethod
def get_by_username(cls, username: str) -> Optional['Portal']:
return cls._select_one_or_none(cls.c.username == username)
@property
def _edit_identity(self):
return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver)
def insert(self) -> None:
self.db.execute(self.t.insert().values(
tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
megagroup=self.megagroup, mxid=self.mxid, config=self.config, username=self.username,
title=self.title, about=self.about, photo_id=self.photo_id))
class Message(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "message"
mxid = Column(String) # type: MatrixEventID
@@ -64,11 +92,11 @@ class Message(Base):
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
@staticmethod
def _one_or_none(rows: RowProxy) -> Optional['Message']:
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Message']:
try:
mxid, mx_room, tgid, tg_space = next(rows)
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
return cls(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
except StopIteration:
return None
@@ -79,9 +107,7 @@ class Message(Base):
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
rows = cls.db.execute(cls.t.select()
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
@classmethod
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
@@ -96,9 +122,9 @@ class Message(Base):
@classmethod
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
) -> Optional['Message']:
rows = cls.db.execute(cls.t.select().where(
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
return cls._select_one_or_none(and_(cls.c.mxid == mxid,
cls.c.mx_room == mx_room,
cls.c.tg_space == tg_space))
@classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
@@ -112,36 +138,16 @@ class Message(Base):
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
def update(self, **values) -> None:
for key, value in values.items():
setattr(self, key, value)
self.update_by_tgid(self.tgid, self.tg_space, **values)
def delete(self) -> None:
self.db.execute(self.t.delete().where(
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
@property
def _edit_identity(self):
return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)
def insert(self) -> None:
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
tg_space=self.tg_space))
class UserPortal(Base):
query = None # type: Query
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) # type: TelegramID
portal = Column(Integer, primary_key=True) # type: TelegramID
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"),
onupdate="CASCADE", ondelete="CASCADE"),)
class User(Base):
query = None # type: Query
__tablename__ = "user"
mxid = Column(String, primary_key=True) # type: MatrixUserID
@@ -154,11 +160,66 @@ class User(Base):
) # type: List[Contact]
portals = relationship("Portal", secondary="user_portal")
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
try:
mxid, tgid, tg_username, tg_phone, saved_contacts = next(rows)
return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
saved_contacts=saved_contacts)
except StopIteration:
return None
@classmethod
def get_all(cls) -> Iterable['User']:
rows = cls.db.execute(cls.t.select())
for row in rows:
mxid, tgid, tg_username, tg_phone, saved_contacts = row
yield cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
saved_contacts=saved_contacts)
@classmethod
def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
return cls._select_one_or_none(cls.c.tgid == tgid)
@classmethod
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['User']:
return cls._select_one_or_none(cls.c.mxid == mxid)
@classmethod
def get_by_username(cls, username: str) -> Optional['User']:
return cls._select_one_or_none(cls.c.username == username)
@property
def _edit_identity(self):
return self.c.mxid == self.mxid
def insert(self) -> None:
self.db.execute(self.t.insert().values(
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
saved_contacts=self.saved_contacts))
class UserPortal(Base):
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) # type: TelegramID
portal = Column(Integer, primary_key=True) # type: TelegramID
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"),
onupdate="CASCADE", ondelete="CASCADE"),)
class Contact(Base):
__tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
class RoomState(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True) # type: MatrixRoomID
@@ -177,18 +238,17 @@ class RoomState(Base):
rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id))
try:
room_id, power_levels_text = next(rows)
return RoomState(room_id=room_id, power_levels=(json.loads(power_levels_text)
if power_levels_text else None))
return cls(room_id=room_id, power_levels=(json.loads(power_levels_text)
if power_levels_text else None))
except StopIteration:
return None
def update(self) -> None:
self.db.execute(self.t.update()
.where(self.c.room_id == self.room_id)
.values(power_levels=self._power_levels_text))
return super().update(power_levels=self._power_levels_text)
def delete(self) -> None:
self.db.execute(self.t.delete().where(self.c.room_id == self.room_id))
@property
def _edit_identity(self):
return self.c.room_id == self.room_id
def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id,
@@ -196,9 +256,6 @@ class RoomState(Base):
class UserProfile(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True) # type: MatrixRoomID
@@ -220,8 +277,8 @@ class UserProfile(Base):
cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id)))
try:
room_id, user_id, membership, displayname, avatar_url = next(rows)
return UserProfile(room_id=room_id, user_id=user_id, membership=membership,
displayname=displayname, avatar_url=avatar_url)
return cls(room_id=room_id, user_id=user_id, membership=membership,
displayname=displayname, avatar_url=avatar_url)
except StopIteration:
return None
@@ -230,14 +287,12 @@ class UserProfile(Base):
cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id))
def update(self) -> None:
self.db.execute(self.t.update()
.where(and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id))
.values(membership=self.membership, displayname=self.displayname,
avatar_url=self.avatar_url))
super().update(membership=self.membership, displayname=self.displayname,
avatar_url=self.avatar_url)
def delete(self) -> None:
self.db.execute(self.t.delete().where(and_(self.c.room_id == self.room_id,
self.c.user_id == self.user_id)))
@property
def _edit_identity(self):
return and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)
def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
@@ -246,14 +301,6 @@ class UserProfile(Base):
avatar_url=self.avatar_url))
class Contact(Base):
query = None # type: Query
__tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
class Puppet(Base):
query = None # type: Query
__tablename__ = "puppet"
@@ -278,9 +325,6 @@ class BotChat(Base):
class TelegramFile(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "telegram_file"
id = Column(String, primary_key=True)
@@ -302,8 +346,8 @@ class TelegramFile(Base):
thumb = None
if thumb_id:
thumb = cls.get(thumb_id)
return TelegramFile(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
return cls(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
except StopIteration:
return None
@@ -316,8 +360,7 @@ class TelegramFile(Base):
def init(db_session, db_engine) -> None:
query = db_session.query_property()
for table in (Portal, Message, UserPortal, User, Puppet, BotChat, TelegramFile, UserProfile,
RoomState):
for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState):
table.query = query
table.db = db_engine
table.t = table.__table__