Start moving portals and users to SQLAlchemy Core
This commit is contained in:
+120
-77
@@ -15,13 +15,12 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
||||
BigInteger, String, Boolean, Text, Table,
|
||||
BigInteger, String, Boolean, Text,
|
||||
and_, func, select)
|
||||
from sqlalchemy.engine import Engine, RowProxy
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.orm import relationship, Query
|
||||
from sqlalchemy.sql.base import ImmutableColumnCollection
|
||||
from typing import Dict, Optional, List
|
||||
from typing import Dict, Optional, List, Iterable
|
||||
import json
|
||||
|
||||
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
|
||||
@@ -30,7 +29,6 @@ from .base import Base
|
||||
|
||||
|
||||
class Portal(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "portal"
|
||||
|
||||
# Telegram chat information
|
||||
@@ -50,11 +48,41 @@ class Portal(Base):
|
||||
about = Column(String, nullable=True)
|
||||
photo_id = Column(String, nullable=True)
|
||||
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
|
||||
try:
|
||||
(tgid, tg_receiver, peer_type, megagroup, mxid, config,
|
||||
username, title, about, photo_id) = next(rows)
|
||||
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
|
||||
megagroup=megagroup, mxid=mxid, config=config, username=username,
|
||||
title=title, about=about, photo_id=photo_id)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver))
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(cls.c.mxid == mxid)
|
||||
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(cls.c.username == username)
|
||||
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver)
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(
|
||||
tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
|
||||
megagroup=self.megagroup, mxid=self.mxid, config=self.config, username=self.username,
|
||||
title=self.title, about=self.about, photo_id=self.photo_id))
|
||||
|
||||
|
||||
class Message(Base):
|
||||
db = None # type: Engine
|
||||
t = None # type: Table
|
||||
c = None # type: ImmutableColumnCollection
|
||||
__tablename__ = "message"
|
||||
|
||||
mxid = Column(String) # type: MatrixEventID
|
||||
@@ -64,11 +92,11 @@ class Message(Base):
|
||||
|
||||
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
|
||||
|
||||
@staticmethod
|
||||
def _one_or_none(rows: RowProxy) -> Optional['Message']:
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['Message']:
|
||||
try:
|
||||
mxid, mx_room, tgid, tg_space = next(rows)
|
||||
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
|
||||
return cls(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@@ -79,9 +107,7 @@ class Message(Base):
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
|
||||
rows = cls.db.execute(cls.t.select()
|
||||
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
|
||||
return cls._one_or_none(rows)
|
||||
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
|
||||
|
||||
@classmethod
|
||||
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
|
||||
@@ -96,9 +122,9 @@ class Message(Base):
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
|
||||
) -> Optional['Message']:
|
||||
rows = cls.db.execute(cls.t.select().where(
|
||||
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
|
||||
return cls._one_or_none(rows)
|
||||
return cls._select_one_or_none(and_(cls.c.mxid == mxid,
|
||||
cls.c.mx_room == mx_room,
|
||||
cls.c.tg_space == tg_space))
|
||||
|
||||
@classmethod
|
||||
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
|
||||
@@ -112,36 +138,16 @@ class Message(Base):
|
||||
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
|
||||
.values(**values))
|
||||
|
||||
def update(self, **values) -> None:
|
||||
for key, value in values.items():
|
||||
setattr(self, key, value)
|
||||
self.update_by_tgid(self.tgid, self.tg_space, **values)
|
||||
|
||||
def delete(self) -> None:
|
||||
self.db.execute(self.t.delete().where(
|
||||
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
|
||||
tg_space=self.tg_space))
|
||||
|
||||
|
||||
class UserPortal(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "user_portal"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
|
||||
primary_key=True) # type: TelegramID
|
||||
portal = Column(Integer, primary_key=True) # type: TelegramID
|
||||
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
|
||||
|
||||
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
|
||||
("portal.tgid", "portal.tg_receiver"),
|
||||
onupdate="CASCADE", ondelete="CASCADE"),)
|
||||
|
||||
|
||||
class User(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "user"
|
||||
|
||||
mxid = Column(String, primary_key=True) # type: MatrixUserID
|
||||
@@ -154,11 +160,66 @@ class User(Base):
|
||||
) # type: List[Contact]
|
||||
portals = relationship("Portal", secondary="user_portal")
|
||||
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
|
||||
try:
|
||||
mxid, tgid, tg_username, tg_phone, saved_contacts = next(rows)
|
||||
return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
|
||||
saved_contacts=saved_contacts)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_all(cls) -> Iterable['User']:
|
||||
rows = cls.db.execute(cls.t.select())
|
||||
for row in rows:
|
||||
mxid, tgid, tg_username, tg_phone, saved_contacts = row
|
||||
yield cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
|
||||
saved_contacts=saved_contacts)
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
|
||||
return cls._select_one_or_none(cls.c.tgid == tgid)
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['User']:
|
||||
return cls._select_one_or_none(cls.c.mxid == mxid)
|
||||
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str) -> Optional['User']:
|
||||
return cls._select_one_or_none(cls.c.username == username)
|
||||
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return self.c.mxid == self.mxid
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(
|
||||
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
|
||||
saved_contacts=self.saved_contacts))
|
||||
|
||||
|
||||
class UserPortal(Base):
|
||||
__tablename__ = "user_portal"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
|
||||
primary_key=True) # type: TelegramID
|
||||
portal = Column(Integer, primary_key=True) # type: TelegramID
|
||||
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
|
||||
|
||||
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
|
||||
("portal.tgid", "portal.tg_receiver"),
|
||||
onupdate="CASCADE", ondelete="CASCADE"),)
|
||||
|
||||
|
||||
class Contact(Base):
|
||||
__tablename__ = "contact"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
|
||||
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
|
||||
|
||||
|
||||
class RoomState(Base):
|
||||
db = None # type: Engine
|
||||
t = None # type: Table
|
||||
c = None # type: ImmutableColumnCollection
|
||||
__tablename__ = "mx_room_state"
|
||||
|
||||
room_id = Column(String, primary_key=True) # type: MatrixRoomID
|
||||
@@ -177,18 +238,17 @@ class RoomState(Base):
|
||||
rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id))
|
||||
try:
|
||||
room_id, power_levels_text = next(rows)
|
||||
return RoomState(room_id=room_id, power_levels=(json.loads(power_levels_text)
|
||||
if power_levels_text else None))
|
||||
return cls(room_id=room_id, power_levels=(json.loads(power_levels_text)
|
||||
if power_levels_text else None))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def update(self) -> None:
|
||||
self.db.execute(self.t.update()
|
||||
.where(self.c.room_id == self.room_id)
|
||||
.values(power_levels=self._power_levels_text))
|
||||
return super().update(power_levels=self._power_levels_text)
|
||||
|
||||
def delete(self) -> None:
|
||||
self.db.execute(self.t.delete().where(self.c.room_id == self.room_id))
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return self.c.room_id == self.room_id
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(room_id=self.room_id,
|
||||
@@ -196,9 +256,6 @@ class RoomState(Base):
|
||||
|
||||
|
||||
class UserProfile(Base):
|
||||
db = None # type: Engine
|
||||
t = None # type: Table
|
||||
c = None # type: ImmutableColumnCollection
|
||||
__tablename__ = "mx_user_profile"
|
||||
|
||||
room_id = Column(String, primary_key=True) # type: MatrixRoomID
|
||||
@@ -220,8 +277,8 @@ class UserProfile(Base):
|
||||
cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id)))
|
||||
try:
|
||||
room_id, user_id, membership, displayname, avatar_url = next(rows)
|
||||
return UserProfile(room_id=room_id, user_id=user_id, membership=membership,
|
||||
displayname=displayname, avatar_url=avatar_url)
|
||||
return cls(room_id=room_id, user_id=user_id, membership=membership,
|
||||
displayname=displayname, avatar_url=avatar_url)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@@ -230,14 +287,12 @@ class UserProfile(Base):
|
||||
cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id))
|
||||
|
||||
def update(self) -> None:
|
||||
self.db.execute(self.t.update()
|
||||
.where(and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id))
|
||||
.values(membership=self.membership, displayname=self.displayname,
|
||||
avatar_url=self.avatar_url))
|
||||
super().update(membership=self.membership, displayname=self.displayname,
|
||||
avatar_url=self.avatar_url)
|
||||
|
||||
def delete(self) -> None:
|
||||
self.db.execute(self.t.delete().where(and_(self.c.room_id == self.room_id,
|
||||
self.c.user_id == self.user_id)))
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
|
||||
@@ -246,14 +301,6 @@ class UserProfile(Base):
|
||||
avatar_url=self.avatar_url))
|
||||
|
||||
|
||||
class Contact(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "contact"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
|
||||
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
|
||||
|
||||
|
||||
class Puppet(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "puppet"
|
||||
@@ -278,9 +325,6 @@ class BotChat(Base):
|
||||
|
||||
|
||||
class TelegramFile(Base):
|
||||
db = None # type: Engine
|
||||
t = None # type: Table
|
||||
c = None # type: ImmutableColumnCollection
|
||||
__tablename__ = "telegram_file"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
@@ -302,8 +346,8 @@ class TelegramFile(Base):
|
||||
thumb = None
|
||||
if thumb_id:
|
||||
thumb = cls.get(thumb_id)
|
||||
return TelegramFile(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
|
||||
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
|
||||
return cls(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
|
||||
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@@ -316,8 +360,7 @@ class TelegramFile(Base):
|
||||
|
||||
def init(db_session, db_engine) -> None:
|
||||
query = db_session.query_property()
|
||||
for table in (Portal, Message, UserPortal, User, Puppet, BotChat, TelegramFile, UserProfile,
|
||||
RoomState):
|
||||
for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState):
|
||||
table.query = query
|
||||
table.db = db_engine
|
||||
table.t = table.__table__
|
||||
|
||||
Reference in New Issue
Block a user