Start moving portals and users to SQLAlchemy Core

This commit is contained in:
Tulir Asokan
2019-02-12 01:19:12 +02:00
parent c028e1befc
commit 53489e7356
5 changed files with 216 additions and 127 deletions
+40 -1
View File
@@ -1,2 +1,41 @@
from abc import abstractmethod
from sqlalchemy import Table
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base() # type: declarative_base
class BaseBase:
db = None # type: Engine
t = None # type: Table
__table__ = None # type: Table
c = None # type: ImmutableColumnCollection
@classmethod
@abstractmethod
def _one_or_none(cls, rows: RowProxy):
pass
@classmethod
def _select_one_or_none(cls, *args):
return cls._one_or_none(cls.db.execute(cls.t.select().where(*args)))
@property
@abstractmethod
def _edit_identity(self):
pass
def update(self, **values) -> None:
self.db.execute(self.t.update()
.where(self._edit_identity)
.values(**values))
for key, value in values.items():
setattr(self, key, value)
def delete(self) -> None:
self.db.execute(self.t.delete().where(self._edit_identity))
Base = declarative_base(cls=BaseBase)
+26
View File
@@ -0,0 +1,26 @@
from abc import abstractmethod
from sqlalchemy import Table
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.ext.declarative import declarative_base
class Base(declarative_base):
db: Engine
t: Table
__table__: Table
c: ImmutableColumnCollection
@classmethod
@abstractmethod
def _one_or_none(cls, rows: RowProxy): ...
@classmethod
def _select_one_or_none(cls, *args): ...
def _edit_identity(self): ...
def update(self, **values) -> None: ...
def delete(self) -> None: ...
+120 -77
View File
@@ -15,13 +15,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text, Table,
BigInteger, String, Boolean, Text,
and_, func, select)
from sqlalchemy.engine import Engine, RowProxy
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query
from sqlalchemy.sql.base import ImmutableColumnCollection
from typing import Dict, Optional, List
from typing import Dict, Optional, List, Iterable
import json
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
@@ -30,7 +29,6 @@ from .base import Base
class Portal(Base):
query = None # type: Query
__tablename__ = "portal"
# Telegram chat information
@@ -50,11 +48,41 @@ class Portal(Base):
about = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
try:
(tgid, tg_receiver, peer_type, megagroup, mxid, config,
username, title, about, photo_id) = next(rows)
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
megagroup=megagroup, mxid=mxid, config=config, username=username,
title=title, about=about, photo_id=photo_id)
except StopIteration:
return None
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']:
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver))
@classmethod
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['Portal']:
return cls._select_one_or_none(cls.c.mxid == mxid)
@classmethod
def get_by_username(cls, username: str) -> Optional['Portal']:
return cls._select_one_or_none(cls.c.username == username)
@property
def _edit_identity(self):
return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver)
def insert(self) -> None:
self.db.execute(self.t.insert().values(
tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
megagroup=self.megagroup, mxid=self.mxid, config=self.config, username=self.username,
title=self.title, about=self.about, photo_id=self.photo_id))
class Message(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "message"
mxid = Column(String) # type: MatrixEventID
@@ -64,11 +92,11 @@ class Message(Base):
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
@staticmethod
def _one_or_none(rows: RowProxy) -> Optional['Message']:
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Message']:
try:
mxid, mx_room, tgid, tg_space = next(rows)
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
return cls(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
except StopIteration:
return None
@@ -79,9 +107,7 @@ class Message(Base):
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
rows = cls.db.execute(cls.t.select()
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
@classmethod
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
@@ -96,9 +122,9 @@ class Message(Base):
@classmethod
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
) -> Optional['Message']:
rows = cls.db.execute(cls.t.select().where(
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
return cls._select_one_or_none(and_(cls.c.mxid == mxid,
cls.c.mx_room == mx_room,
cls.c.tg_space == tg_space))
@classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
@@ -112,36 +138,16 @@ class Message(Base):
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
def update(self, **values) -> None:
for key, value in values.items():
setattr(self, key, value)
self.update_by_tgid(self.tgid, self.tg_space, **values)
def delete(self) -> None:
self.db.execute(self.t.delete().where(
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
@property
def _edit_identity(self):
return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)
def insert(self) -> None:
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
tg_space=self.tg_space))
class UserPortal(Base):
query = None # type: Query
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) # type: TelegramID
portal = Column(Integer, primary_key=True) # type: TelegramID
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"),
onupdate="CASCADE", ondelete="CASCADE"),)
class User(Base):
query = None # type: Query
__tablename__ = "user"
mxid = Column(String, primary_key=True) # type: MatrixUserID
@@ -154,11 +160,66 @@ class User(Base):
) # type: List[Contact]
portals = relationship("Portal", secondary="user_portal")
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
try:
mxid, tgid, tg_username, tg_phone, saved_contacts = next(rows)
return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
saved_contacts=saved_contacts)
except StopIteration:
return None
@classmethod
def get_all(cls) -> Iterable['User']:
rows = cls.db.execute(cls.t.select())
for row in rows:
mxid, tgid, tg_username, tg_phone, saved_contacts = row
yield cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
saved_contacts=saved_contacts)
@classmethod
def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
return cls._select_one_or_none(cls.c.tgid == tgid)
@classmethod
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['User']:
return cls._select_one_or_none(cls.c.mxid == mxid)
@classmethod
def get_by_username(cls, username: str) -> Optional['User']:
return cls._select_one_or_none(cls.c.username == username)
@property
def _edit_identity(self):
return self.c.mxid == self.mxid
def insert(self) -> None:
self.db.execute(self.t.insert().values(
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
saved_contacts=self.saved_contacts))
class UserPortal(Base):
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) # type: TelegramID
portal = Column(Integer, primary_key=True) # type: TelegramID
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"),
onupdate="CASCADE", ondelete="CASCADE"),)
class Contact(Base):
__tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
class RoomState(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True) # type: MatrixRoomID
@@ -177,18 +238,17 @@ class RoomState(Base):
rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id))
try:
room_id, power_levels_text = next(rows)
return RoomState(room_id=room_id, power_levels=(json.loads(power_levels_text)
if power_levels_text else None))
return cls(room_id=room_id, power_levels=(json.loads(power_levels_text)
if power_levels_text else None))
except StopIteration:
return None
def update(self) -> None:
self.db.execute(self.t.update()
.where(self.c.room_id == self.room_id)
.values(power_levels=self._power_levels_text))
return super().update(power_levels=self._power_levels_text)
def delete(self) -> None:
self.db.execute(self.t.delete().where(self.c.room_id == self.room_id))
@property
def _edit_identity(self):
return self.c.room_id == self.room_id
def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id,
@@ -196,9 +256,6 @@ class RoomState(Base):
class UserProfile(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True) # type: MatrixRoomID
@@ -220,8 +277,8 @@ class UserProfile(Base):
cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id)))
try:
room_id, user_id, membership, displayname, avatar_url = next(rows)
return UserProfile(room_id=room_id, user_id=user_id, membership=membership,
displayname=displayname, avatar_url=avatar_url)
return cls(room_id=room_id, user_id=user_id, membership=membership,
displayname=displayname, avatar_url=avatar_url)
except StopIteration:
return None
@@ -230,14 +287,12 @@ class UserProfile(Base):
cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id))
def update(self) -> None:
self.db.execute(self.t.update()
.where(and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id))
.values(membership=self.membership, displayname=self.displayname,
avatar_url=self.avatar_url))
super().update(membership=self.membership, displayname=self.displayname,
avatar_url=self.avatar_url)
def delete(self) -> None:
self.db.execute(self.t.delete().where(and_(self.c.room_id == self.room_id,
self.c.user_id == self.user_id)))
@property
def _edit_identity(self):
return and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)
def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
@@ -246,14 +301,6 @@ class UserProfile(Base):
avatar_url=self.avatar_url))
class Contact(Base):
query = None # type: Query
__tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
class Puppet(Base):
query = None # type: Query
__tablename__ = "puppet"
@@ -278,9 +325,6 @@ class BotChat(Base):
class TelegramFile(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "telegram_file"
id = Column(String, primary_key=True)
@@ -302,8 +346,8 @@ class TelegramFile(Base):
thumb = None
if thumb_id:
thumb = cls.get(thumb_id)
return TelegramFile(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
return cls(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
except StopIteration:
return None
@@ -316,8 +360,7 @@ class TelegramFile(Base):
def init(db_session, db_engine) -> None:
query = db_session.query_property()
for table in (Portal, Message, UserPortal, User, Puppet, BotChat, TelegramFile, UserProfile,
RoomState):
for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState):
table.query = query
table.db = db_engine
table.t = table.__table__
+15 -26
View File
@@ -31,7 +31,6 @@ import json
import re
import magic
from sqlalchemy import orm
from sqlalchemy.exc import IntegrityError
from telethon.tl.functions.messages import (
@@ -89,7 +88,6 @@ InviteList = Union[MatrixUserID, List[MatrixUserID]]
class Portal:
log = logging.getLogger("mau.portal") # type: logging.Logger
db = None # type: orm.Session
az = None # type: AppService
bot = None # type: Bot
loop = None # type: asyncio.AbstractEventLoop
@@ -1255,8 +1253,7 @@ class Portal:
self.tg_receiver = self.tgid
self.by_tgid[self.tgid_full] = self
await self.update_info(source, entity)
self.db.add(self.db_instance)
self.save()
self.db_instance.insert()
if self.bot and self.bot.tgid in invites:
self.bot.add_chat(self.tgid, self.peer_type)
@@ -1842,15 +1839,13 @@ class Portal:
del self.by_tgid[self.tgid_full]
except KeyError:
pass
self.tgid = new_id
self.tg_receiver = new_id
existing = self.by_tgid[self.tgid_full]
existing = self.by_tgid[(new_id, new_id)]
if existing:
existing.delete()
self.db_instance.update(tgid=new_id, tg_receiver=new_id)
self.tgid = new_id
self.tg_receiver = new_id
self.by_tgid[self.tgid_full] = self
self.db_instance.tgid = self.tgid
self.db_instance.tg_receiver = self.tg_receiver
self.save()
def migrate_and_save_matrix(self, new_id: MatrixRoomID) -> None:
try:
@@ -1858,17 +1853,13 @@ class Portal:
except KeyError:
pass
self.mxid = new_id
self.db_instance.update(mxid=self.mxid)
self.by_mxid[self.mxid] = self
self.save()
def save(self) -> None:
self.db_instance.mxid = self.mxid
self.db_instance.username = self.username
self.db_instance.title = self.title
self.db_instance.about = self.about
self.db_instance.photo_id = self.photo_id
self.db_instance.config = json.dumps(self.local_config)
self.db.commit()
self.db_instance.update(mxid=self.mxid, username=self.username, title=self.title,
about=self.about, photo_id=self.photo_id,
config=json.dumps(self.local_config))
def delete(self) -> None:
try:
@@ -1880,8 +1871,7 @@ class Portal:
except KeyError:
pass
if self._db_instance:
self.db.delete(self._db_instance)
self.db.commit()
self._db_instance.delete()
self.deleted = True
@classmethod
@@ -1902,7 +1892,7 @@ class Portal:
except KeyError:
pass
portal = DBPortal.query.filter(DBPortal.mxid == mxid).one_or_none()
portal = DBPortal.get_by_mxid(mxid)
if portal:
return cls.from_db(portal)
@@ -1924,7 +1914,7 @@ class Portal:
if portal.username and portal.username.lower() == username.lower():
return portal
dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
dbportal = DBPortal.get_by_username(username)
if dbportal:
return cls.from_db(dbportal)
@@ -1940,14 +1930,13 @@ class Portal:
except KeyError:
pass
portal = DBPortal.query.get(tgid_full)
portal = DBPortal.get_by_tgid(tgid, tg_receiver)
if portal:
return cls.from_db(portal)
if peer_type:
portal = Portal(tgid, peer_type=peer_type, tg_receiver=tg_receiver)
cls.db.add(portal.db_instance)
cls.db.commit()
portal.db_instance.insert()
return portal
return None
@@ -1987,7 +1976,7 @@ class Portal:
def init(context: Context) -> None:
global config
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core
Portal.az, _, config, Portal.loop, Portal.bot = context.core
Portal.max_initial_member_sync = config["bridge.max_initial_member_sync"]
Portal.sync_channel_members = config["bridge.sync_channel_members"]
Portal.sync_matrix_state = config["bridge.sync_matrix_state"]
+15 -23
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Match, NewType, Optional, Tuple, TYPE_CHECKING
from typing import Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, TYPE_CHECKING
import logging
import asyncio
import re
@@ -101,20 +101,19 @@ class User(AbstractUser):
return self.displayname
@property
def db_contacts(self) -> List[DBContact]:
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
for puppet in self.contacts]
def db_contacts(self) -> Iterable[DBContact]:
return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts)
@db_contacts.setter
def db_contacts(self, contacts: List[DBContact]) -> None:
def db_contacts(self, contacts: Iterable[DBContact]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
@property
def db_portals(self) -> List[DBPortal]:
return [portal.db_instance for portal in self.portals.values() if not portal.deleted]
def db_portals(self) -> Iterable[DBPortal]:
return (portal.db_instance for portal in self.portals.values() if not portal.deleted)
@db_portals.setter
def db_portals(self, portals: List[DBPortal]) -> None:
def db_portals(self, portals: Iterable[DBPortal]) -> None:
self.portals = {
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
portal.tg_receiver)
@@ -135,13 +134,8 @@ class User(AbstractUser):
portals=self.db_portals)
def save(self) -> None:
self.db_instance.tgid = self.tgid
self.db_instance.tg_username = self.username
self.db_instance.tg_phone = self.phone
self.db_instance.contacts = self.db_contacts
self.db_instance.saved_contacts = self.saved_contacts
self.db_instance.portals = self.db_portals
self.db.commit()
self.db_instance.update(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone,
saved_contacts=self.saved_contacts)
def delete(self) -> None:
try:
@@ -150,8 +144,7 @@ class User(AbstractUser):
except KeyError:
pass
if self._db_instance:
self.db.delete(self._db_instance)
self.db.commit()
self._db_instance.delete()
@classmethod
def from_db(cls, db_user: DBUser) -> 'User':
@@ -358,15 +351,14 @@ class User(AbstractUser):
except KeyError:
pass
user = DBUser.query.get(mxid)
user = DBUser.get_by_mxid(mxid)
if user:
user = cls.from_db(user)
return user
if create:
user = cls(mxid)
cls.db.add(user.db_instance)
cls.db.commit()
user.db_instance.insert()
return user
return None
@@ -378,7 +370,7 @@ class User(AbstractUser):
except KeyError:
pass
user = DBUser.query.filter(DBUser.tgid == tgid).one_or_none()
user = DBUser.get_by_tgid(tgid)
if user:
user = cls.from_db(user)
return user
@@ -394,7 +386,7 @@ class User(AbstractUser):
if user.username and user.username.lower() == username.lower():
return user
puppet = DBUser.query.filter(DBUser.tg_username == username).one_or_none()
puppet = DBUser.get_by_username(username)
if puppet:
return cls.from_db(puppet)
@@ -406,5 +398,5 @@ def init(context: 'Context') -> List[Awaitable['User']]:
global config
config = context.config
users = [User.from_db(user) for user in DBUser.query.all()]
users = [User.from_db(user) for user in DBUser.get_all()]
return [user.ensure_started() for user in users if user.tgid]