diff --git a/mautrix_telegram/base.py b/mautrix_telegram/base.py
index 0b62d886..c3e1756f 100644
--- a/mautrix_telegram/base.py
+++ b/mautrix_telegram/base.py
@@ -1,2 +1,41 @@
+from abc import abstractmethod
+
+from sqlalchemy import Table
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.engine.result import RowProxy
+from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.ext.declarative import declarative_base
-Base = declarative_base() # type: declarative_base
+
+
+class BaseBase:
+ db = None # type: Engine
+ t = None # type: Table
+ __table__ = None # type: Table
+ c = None # type: ImmutableColumnCollection
+
+ @classmethod
+ @abstractmethod
+ def _one_or_none(cls, rows: RowProxy):
+ pass
+
+ @classmethod
+ def _select_one_or_none(cls, *args):
+ return cls._one_or_none(cls.db.execute(cls.t.select().where(*args)))
+
+ @property
+ @abstractmethod
+ def _edit_identity(self):
+ pass
+
+ def update(self, **values) -> None:
+ self.db.execute(self.t.update()
+ .where(self._edit_identity)
+ .values(**values))
+ for key, value in values.items():
+ setattr(self, key, value)
+
+ def delete(self) -> None:
+ self.db.execute(self.t.delete().where(self._edit_identity))
+
+
+Base = declarative_base(cls=BaseBase)
diff --git a/mautrix_telegram/base.pyi b/mautrix_telegram/base.pyi
new file mode 100644
index 00000000..8575893d
--- /dev/null
+++ b/mautrix_telegram/base.pyi
@@ -0,0 +1,26 @@
+from abc import abstractmethod
+
+from sqlalchemy import Table
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.engine.result import RowProxy
+from sqlalchemy.sql.base import ImmutableColumnCollection
+from sqlalchemy.ext.declarative import declarative_base
+
+class Base(declarative_base):
+ db: Engine
+ t: Table
+ __table__: Table
+ c: ImmutableColumnCollection
+
+ @classmethod
+ @abstractmethod
+ def _one_or_none(cls, rows: RowProxy): ...
+
+ @classmethod
+ def _select_one_or_none(cls, *args): ...
+
+ def _edit_identity(self): ...
+
+ def update(self, **values) -> None: ...
+
+ def delete(self) -> None: ...
diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py
index d04b58ea..5afc5c66 100644
--- a/mautrix_telegram/db.py
+++ b/mautrix_telegram/db.py
@@ -15,13 +15,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
- BigInteger, String, Boolean, Text, Table,
+ BigInteger, String, Boolean, Text,
and_, func, select)
-from sqlalchemy.engine import Engine, RowProxy
+from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query
-from sqlalchemy.sql.base import ImmutableColumnCollection
-from typing import Dict, Optional, List
+from typing import Dict, Optional, List, Iterable
import json
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
@@ -30,7 +29,6 @@ from .base import Base
class Portal(Base):
- query = None # type: Query
__tablename__ = "portal"
# Telegram chat information
@@ -50,11 +48,41 @@ class Portal(Base):
about = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
+ @classmethod
+ def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
+ try:
+ (tgid, tg_receiver, peer_type, megagroup, mxid, config,
+ username, title, about, photo_id) = next(rows)
+ return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
+ megagroup=megagroup, mxid=mxid, config=config, username=username,
+ title=title, about=about, photo_id=photo_id)
+ except StopIteration:
+ return None
+
+ @classmethod
+ def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']:
+ return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver))
+
+ @classmethod
+ def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['Portal']:
+ return cls._select_one_or_none(cls.c.mxid == mxid)
+
+ @classmethod
+ def get_by_username(cls, username: str) -> Optional['Portal']:
+ return cls._select_one_or_none(cls.c.username == username)
+
+ @property
+ def _edit_identity(self):
+ return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver)
+
+ def insert(self) -> None:
+ self.db.execute(self.t.insert().values(
+ tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
+ megagroup=self.megagroup, mxid=self.mxid, config=self.config, username=self.username,
+ title=self.title, about=self.about, photo_id=self.photo_id))
+
class Message(Base):
- db = None # type: Engine
- t = None # type: Table
- c = None # type: ImmutableColumnCollection
__tablename__ = "message"
mxid = Column(String) # type: MatrixEventID
@@ -64,11 +92,11 @@ class Message(Base):
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
- @staticmethod
- def _one_or_none(rows: RowProxy) -> Optional['Message']:
+ @classmethod
+ def _one_or_none(cls, rows: RowProxy) -> Optional['Message']:
try:
mxid, mx_room, tgid, tg_space = next(rows)
- return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
+ return cls(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
except StopIteration:
return None
@@ -79,9 +107,7 @@ class Message(Base):
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
- rows = cls.db.execute(cls.t.select()
- .where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
- return cls._one_or_none(rows)
+ return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
@classmethod
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
@@ -96,9 +122,9 @@ class Message(Base):
@classmethod
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
) -> Optional['Message']:
- rows = cls.db.execute(cls.t.select().where(
- and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
- return cls._one_or_none(rows)
+ return cls._select_one_or_none(and_(cls.c.mxid == mxid,
+ cls.c.mx_room == mx_room,
+ cls.c.tg_space == tg_space))
@classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
@@ -112,36 +138,16 @@ class Message(Base):
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
- def update(self, **values) -> None:
- for key, value in values.items():
- setattr(self, key, value)
- self.update_by_tgid(self.tgid, self.tg_space, **values)
-
- def delete(self) -> None:
- self.db.execute(self.t.delete().where(
- and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
+ @property
+ def _edit_identity(self):
+ return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)
def insert(self) -> None:
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
tg_space=self.tg_space))
-class UserPortal(Base):
- query = None # type: Query
- __tablename__ = "user_portal"
-
- user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
- primary_key=True) # type: TelegramID
- portal = Column(Integer, primary_key=True) # type: TelegramID
- portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
-
- __table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
- ("portal.tgid", "portal.tg_receiver"),
- onupdate="CASCADE", ondelete="CASCADE"),)
-
-
class User(Base):
- query = None # type: Query
__tablename__ = "user"
mxid = Column(String, primary_key=True) # type: MatrixUserID
@@ -154,11 +160,66 @@ class User(Base):
) # type: List[Contact]
portals = relationship("Portal", secondary="user_portal")
+ @classmethod
+ def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
+ try:
+ mxid, tgid, tg_username, tg_phone, saved_contacts = next(rows)
+ return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
+ saved_contacts=saved_contacts)
+ except StopIteration:
+ return None
+
+ @classmethod
+ def get_all(cls) -> Iterable['User']:
+ rows = cls.db.execute(cls.t.select())
+ for row in rows:
+ mxid, tgid, tg_username, tg_phone, saved_contacts = row
+ yield cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
+ saved_contacts=saved_contacts)
+
+ @classmethod
+ def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
+ return cls._select_one_or_none(cls.c.tgid == tgid)
+
+ @classmethod
+ def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['User']:
+ return cls._select_one_or_none(cls.c.mxid == mxid)
+
+ @classmethod
+ def get_by_username(cls, username: str) -> Optional['User']:
+ return cls._select_one_or_none(cls.c.username == username)
+
+ @property
+ def _edit_identity(self):
+ return self.c.mxid == self.mxid
+
+ def insert(self) -> None:
+ self.db.execute(self.t.insert().values(
+ mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
+ saved_contacts=self.saved_contacts))
+
+
+class UserPortal(Base):
+ __tablename__ = "user_portal"
+
+ user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
+ primary_key=True) # type: TelegramID
+ portal = Column(Integer, primary_key=True) # type: TelegramID
+ portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
+
+ __table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
+ ("portal.tgid", "portal.tg_receiver"),
+ onupdate="CASCADE", ondelete="CASCADE"),)
+
+
+class Contact(Base):
+ __tablename__ = "contact"
+
+ user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
+ contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
+
class RoomState(Base):
- db = None # type: Engine
- t = None # type: Table
- c = None # type: ImmutableColumnCollection
__tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True) # type: MatrixRoomID
@@ -177,18 +238,17 @@ class RoomState(Base):
rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id))
try:
room_id, power_levels_text = next(rows)
- return RoomState(room_id=room_id, power_levels=(json.loads(power_levels_text)
- if power_levels_text else None))
+ return cls(room_id=room_id, power_levels=(json.loads(power_levels_text)
+ if power_levels_text else None))
except StopIteration:
return None
def update(self) -> None:
- self.db.execute(self.t.update()
- .where(self.c.room_id == self.room_id)
- .values(power_levels=self._power_levels_text))
+ return super().update(power_levels=self._power_levels_text)
- def delete(self) -> None:
- self.db.execute(self.t.delete().where(self.c.room_id == self.room_id))
+ @property
+ def _edit_identity(self):
+ return self.c.room_id == self.room_id
def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id,
@@ -196,9 +256,6 @@ class RoomState(Base):
class UserProfile(Base):
- db = None # type: Engine
- t = None # type: Table
- c = None # type: ImmutableColumnCollection
__tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True) # type: MatrixRoomID
@@ -220,8 +277,8 @@ class UserProfile(Base):
cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id)))
try:
room_id, user_id, membership, displayname, avatar_url = next(rows)
- return UserProfile(room_id=room_id, user_id=user_id, membership=membership,
- displayname=displayname, avatar_url=avatar_url)
+ return cls(room_id=room_id, user_id=user_id, membership=membership,
+ displayname=displayname, avatar_url=avatar_url)
except StopIteration:
return None
@@ -230,14 +287,12 @@ class UserProfile(Base):
cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id))
def update(self) -> None:
- self.db.execute(self.t.update()
- .where(and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id))
- .values(membership=self.membership, displayname=self.displayname,
- avatar_url=self.avatar_url))
+ super().update(membership=self.membership, displayname=self.displayname,
+ avatar_url=self.avatar_url)
- def delete(self) -> None:
- self.db.execute(self.t.delete().where(and_(self.c.room_id == self.room_id,
- self.c.user_id == self.user_id)))
+ @property
+ def _edit_identity(self):
+ return and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)
def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
@@ -246,14 +301,6 @@ class UserProfile(Base):
avatar_url=self.avatar_url))
-class Contact(Base):
- query = None # type: Query
- __tablename__ = "contact"
-
- user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
- contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
-
-
class Puppet(Base):
query = None # type: Query
__tablename__ = "puppet"
@@ -278,9 +325,6 @@ class BotChat(Base):
class TelegramFile(Base):
- db = None # type: Engine
- t = None # type: Table
- c = None # type: ImmutableColumnCollection
__tablename__ = "telegram_file"
id = Column(String, primary_key=True)
@@ -302,8 +346,8 @@ class TelegramFile(Base):
thumb = None
if thumb_id:
thumb = cls.get(thumb_id)
- return TelegramFile(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
- size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
+ return cls(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
+ size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
except StopIteration:
return None
@@ -316,8 +360,7 @@ class TelegramFile(Base):
def init(db_session, db_engine) -> None:
query = db_session.query_property()
- for table in (Portal, Message, UserPortal, User, Puppet, BotChat, TelegramFile, UserProfile,
- RoomState):
+ for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState):
table.query = query
table.db = db_engine
table.t = table.__table__
diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py
index 9c411a21..3a04b74f 100644
--- a/mautrix_telegram/portal.py
+++ b/mautrix_telegram/portal.py
@@ -31,7 +31,6 @@ import json
import re
import magic
-from sqlalchemy import orm
from sqlalchemy.exc import IntegrityError
from telethon.tl.functions.messages import (
@@ -89,7 +88,6 @@ InviteList = Union[MatrixUserID, List[MatrixUserID]]
class Portal:
log = logging.getLogger("mau.portal") # type: logging.Logger
- db = None # type: orm.Session
az = None # type: AppService
bot = None # type: Bot
loop = None # type: asyncio.AbstractEventLoop
@@ -1255,8 +1253,7 @@ class Portal:
self.tg_receiver = self.tgid
self.by_tgid[self.tgid_full] = self
await self.update_info(source, entity)
- self.db.add(self.db_instance)
- self.save()
+ self.db_instance.insert()
if self.bot and self.bot.tgid in invites:
self.bot.add_chat(self.tgid, self.peer_type)
@@ -1842,15 +1839,13 @@ class Portal:
del self.by_tgid[self.tgid_full]
except KeyError:
pass
- self.tgid = new_id
- self.tg_receiver = new_id
- existing = self.by_tgid[self.tgid_full]
+ existing = self.by_tgid[(new_id, new_id)]
if existing:
existing.delete()
+ self.db_instance.update(tgid=new_id, tg_receiver=new_id)
+ self.tgid = new_id
+ self.tg_receiver = new_id
self.by_tgid[self.tgid_full] = self
- self.db_instance.tgid = self.tgid
- self.db_instance.tg_receiver = self.tg_receiver
- self.save()
def migrate_and_save_matrix(self, new_id: MatrixRoomID) -> None:
try:
@@ -1858,17 +1853,13 @@ class Portal:
except KeyError:
pass
self.mxid = new_id
+ self.db_instance.update(mxid=self.mxid)
self.by_mxid[self.mxid] = self
- self.save()
def save(self) -> None:
- self.db_instance.mxid = self.mxid
- self.db_instance.username = self.username
- self.db_instance.title = self.title
- self.db_instance.about = self.about
- self.db_instance.photo_id = self.photo_id
- self.db_instance.config = json.dumps(self.local_config)
- self.db.commit()
+ self.db_instance.update(mxid=self.mxid, username=self.username, title=self.title,
+ about=self.about, photo_id=self.photo_id,
+ config=json.dumps(self.local_config))
def delete(self) -> None:
try:
@@ -1880,8 +1871,7 @@ class Portal:
except KeyError:
pass
if self._db_instance:
- self.db.delete(self._db_instance)
- self.db.commit()
+ self._db_instance.delete()
self.deleted = True
@classmethod
@@ -1902,7 +1892,7 @@ class Portal:
except KeyError:
pass
- portal = DBPortal.query.filter(DBPortal.mxid == mxid).one_or_none()
+ portal = DBPortal.get_by_mxid(mxid)
if portal:
return cls.from_db(portal)
@@ -1924,7 +1914,7 @@ class Portal:
if portal.username and portal.username.lower() == username.lower():
return portal
- dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
+ dbportal = DBPortal.get_by_username(username)
if dbportal:
return cls.from_db(dbportal)
@@ -1940,14 +1930,13 @@ class Portal:
except KeyError:
pass
- portal = DBPortal.query.get(tgid_full)
+ portal = DBPortal.get_by_tgid(tgid, tg_receiver)
if portal:
return cls.from_db(portal)
if peer_type:
portal = Portal(tgid, peer_type=peer_type, tg_receiver=tg_receiver)
- cls.db.add(portal.db_instance)
- cls.db.commit()
+ portal.db_instance.insert()
return portal
return None
@@ -1987,7 +1976,7 @@ class Portal:
def init(context: Context) -> None:
global config
- Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core
+ Portal.az, _, config, Portal.loop, Portal.bot = context.core
Portal.max_initial_member_sync = config["bridge.max_initial_member_sync"]
Portal.sync_channel_members = config["bridge.sync_channel_members"]
Portal.sync_matrix_state = config["bridge.sync_matrix_state"]
diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py
index 55929475..4c71df0e 100644
--- a/mautrix_telegram/user.py
+++ b/mautrix_telegram/user.py
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Awaitable, Dict, List, Match, NewType, Optional, Tuple, TYPE_CHECKING
+from typing import Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, TYPE_CHECKING
import logging
import asyncio
import re
@@ -101,20 +101,19 @@ class User(AbstractUser):
return self.displayname
@property
- def db_contacts(self) -> List[DBContact]:
- return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
- for puppet in self.contacts]
+ def db_contacts(self) -> Iterable[DBContact]:
+ return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts)
@db_contacts.setter
- def db_contacts(self, contacts: List[DBContact]) -> None:
+ def db_contacts(self, contacts: Iterable[DBContact]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
@property
- def db_portals(self) -> List[DBPortal]:
- return [portal.db_instance for portal in self.portals.values() if not portal.deleted]
+ def db_portals(self) -> Iterable[DBPortal]:
+ return (portal.db_instance for portal in self.portals.values() if not portal.deleted)
@db_portals.setter
- def db_portals(self, portals: List[DBPortal]) -> None:
+ def db_portals(self, portals: Iterable[DBPortal]) -> None:
self.portals = {
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
portal.tg_receiver)
@@ -135,13 +134,8 @@ class User(AbstractUser):
portals=self.db_portals)
def save(self) -> None:
- self.db_instance.tgid = self.tgid
- self.db_instance.tg_username = self.username
- self.db_instance.tg_phone = self.phone
- self.db_instance.contacts = self.db_contacts
- self.db_instance.saved_contacts = self.saved_contacts
- self.db_instance.portals = self.db_portals
- self.db.commit()
+ self.db_instance.update(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone,
+ saved_contacts=self.saved_contacts)
def delete(self) -> None:
try:
@@ -150,8 +144,7 @@ class User(AbstractUser):
except KeyError:
pass
if self._db_instance:
- self.db.delete(self._db_instance)
- self.db.commit()
+ self._db_instance.delete()
@classmethod
def from_db(cls, db_user: DBUser) -> 'User':
@@ -358,15 +351,14 @@ class User(AbstractUser):
except KeyError:
pass
- user = DBUser.query.get(mxid)
+ user = DBUser.get_by_mxid(mxid)
if user:
user = cls.from_db(user)
return user
if create:
user = cls(mxid)
- cls.db.add(user.db_instance)
- cls.db.commit()
+ user.db_instance.insert()
return user
return None
@@ -378,7 +370,7 @@ class User(AbstractUser):
except KeyError:
pass
- user = DBUser.query.filter(DBUser.tgid == tgid).one_or_none()
+ user = DBUser.get_by_tgid(tgid)
if user:
user = cls.from_db(user)
return user
@@ -394,7 +386,7 @@ class User(AbstractUser):
if user.username and user.username.lower() == username.lower():
return user
- puppet = DBUser.query.filter(DBUser.tg_username == username).one_or_none()
+ puppet = DBUser.get_by_username(username)
if puppet:
return cls.from_db(puppet)
@@ -406,5 +398,5 @@ def init(context: 'Context') -> List[Awaitable['User']]:
global config
config = context.config
- users = [User.from_db(user) for user in DBUser.query.all()]
+ users = [User.from_db(user) for user in DBUser.get_all()]
return [user.ensure_started() for user in users if user.tgid]