Start using new db base functions

This commit is contained in:
Tulir Asokan
2019-09-02 22:02:50 +03:00
parent 2c443a3b93
commit fbb1267609
11 changed files with 27 additions and 124 deletions
+1 -1
View File
@@ -7,7 +7,7 @@ from os.path import abspath, dirname
sys.path.insert(0, dirname(dirname(abspath(__file__))))
from mautrix.bridge.db import Base
from mautrix.util.db import Base
import mautrix_telegram.db
from mautrix_telegram.config import Config
from alchemysession import AlchemySessionContainer
@@ -12,7 +12,7 @@ from alembic import context, op
import sqlalchemy.orm as orm
import sqlalchemy as sa
from mautrix.bridge.db import Base
from mautrix.util.db import Base
from mautrix_telegram.config import Config
+1 -1
View File
@@ -19,7 +19,7 @@ from itertools import chain
from alchemysession import AlchemySessionContainer
from mautrix.bridge import Bridge
from mautrix.bridge.db import Base
from mautrix.util.db import Base
from .web.provisioning import ProvisioningAPI
from .web.public import PublicBridgeWebsite
+1
View File
@@ -31,3 +31,4 @@ def init(db_engine: Engine) -> None:
table.db = db_engine
table.t = table.__table__
table.c = table.t.c
table.column_names = table.c.keys()
+1 -10
View File
@@ -16,9 +16,8 @@
from typing import Iterable
from sqlalchemy import Column, Integer, String
from sqlalchemy.engine.result import RowProxy
from mautrix.bridge.db import Base
from mautrix.util.db import Base
from ..types import TelegramID
@@ -34,14 +33,6 @@ class BotChat(Base):
with cls.db.begin() as conn:
conn.execute(cls.t.delete().where(cls.c.id == chat_id))
@classmethod
def scan(cls, row: RowProxy) -> 'BotChat':
return cls(id=row[0], type=row[1])
@classmethod
def all(cls) -> Iterable['BotChat']:
return cls._select_all()
def insert(self) -> None:
with self.db.begin() as conn:
conn.execute(self.t.insert().values(id=self.id, type=self.type))
+10 -32
View File
@@ -16,11 +16,9 @@
from typing import Optional, Iterator
from sqlalchemy import Column, UniqueConstraint, Integer, String, and_, func, desc, select
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.expression import ClauseElement
from mautrix.types import RoomID, EventID
from mautrix.bridge.db import Base
from mautrix.util.db import Base
from ..types import TelegramID
@@ -36,29 +34,21 @@ class Message(Base):
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room_2"),)
@classmethod
def scan(cls, row: RowProxy) -> 'Message':
return cls(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3], edit_index=row[4])
@classmethod
def get_all_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Iterator['Message']:
return cls._all(cls.db.execute(cls.t.select().where(and_(cls.c.tgid == tgid,
cls.c.tg_space == tg_space))))
return cls._select_all(cls.c.tgid == tgid, cls.c.tg_space == tg_space)
@classmethod
def get_one_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID, edit_index: int = 0
) -> Optional['Message']:
query = cls.t.select()
if edit_index < 0:
query = (query
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
.order_by(desc(cls.c.edit_index))
.limit(1)
.offset(-edit_index - 1))
return cls._one_or_none(cls.t.select()
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
.order_by(desc(cls.c.edit_index))
.limit(1).offset(-edit_index - 1))
else:
query = query.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space,
cls.c.edit_index == edit_index))
return cls._one_or_none(cls.db.execute(query))
return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_space == tg_space,
cls.c.edit_index == edit_index)
@classmethod
def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int:
@@ -73,9 +63,8 @@ class Message(Base):
@classmethod
def get_by_mxid(cls, mxid: EventID, mx_room: RoomID, tg_space: TelegramID
) -> Optional['Message']:
return cls._select_one_or_none(and_(cls.c.mxid == mxid,
cls.c.mx_room == mx_room,
cls.c.tg_space == tg_space))
return cls._select_one_or_none(cls.c.mxid == mxid, cls.c.mx_room == mx_room,
cls.c.tg_space == tg_space)
@classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, s_edit_index: int,
@@ -92,14 +81,3 @@ class Message(Base):
conn.execute(cls.t.update()
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
@property
def _edit_identity(self) -> ClauseElement:
return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space,
self.c.edit_index == self.edit_index)
def insert(self) -> None:
with self.db.begin() as conn:
conn.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room,
tgid=self.tgid, tg_space=self.tg_space,
edit_index=self.edit_index))
+3 -24
View File
@@ -15,12 +15,10 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from sqlalchemy import Column, Integer, String, Boolean, Text, and_
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy import Column, Integer, String, Boolean, Text
from mautrix.types import RoomID
from mautrix.bridge.db import Base
from mautrix.util.db import Base
from ..types import TelegramID
@@ -45,17 +43,9 @@ class Portal(Base):
about: str = Column(String, nullable=True)
photo_id: str = Column(String, nullable=True)
@classmethod
def scan(cls, row: RowProxy) -> Optional['Portal']:
(tgid, tg_receiver, peer_type, megagroup, mxid, config, username, title, about,
photo_id) = row
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, megagroup=megagroup,
mxid=mxid, config=config, username=username, title=title, about=about,
photo_id=photo_id)
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']:
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver))
return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver)
@classmethod
def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
@@ -64,14 +54,3 @@ class Portal(Base):
@classmethod
def get_by_username(cls, username: str) -> Optional['Portal']:
return cls._select_one_or_none(cls.c.username == username)
@property
def _edit_identity(self) -> ClauseElement:
return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver)
def insert(self) -> None:
with self.db.begin() as conn:
conn.execute(self.t.insert().values(
tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
megagroup=self.megagroup, mxid=self.mxid, config=self.config,
username=self.username, title=self.title, about=self.about, photo_id=self.photo_id))
+2 -28
View File
@@ -17,11 +17,9 @@ from typing import Optional, Iterable
from sqlalchemy import Column, Integer, String, Boolean
from sqlalchemy.sql import expression
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.expression import ClauseElement
from mautrix.types import UserID, SyncToken
from mautrix.bridge.db import Base
from mautrix.util.db import Base
from ..types import TelegramID
@@ -41,20 +39,9 @@ class Puppet(Base):
matrix_registered: bool = Column(Boolean, nullable=False, server_default=expression.false())
disable_updates: bool = Column(Boolean, nullable=False, server_default=expression.false())
@classmethod
def scan(cls, row: RowProxy) -> Optional['Puppet']:
(id, custom_mxid, access_token, next_batch, displayname, displayname_source, username,
photo_id, is_bot, matrix_registered, disable_updates) = row
return cls(id=id, custom_mxid=custom_mxid, access_token=access_token, username=username,
next_batch=next_batch, displayname=displayname, photo_id=photo_id,
displayname_source=displayname_source, matrix_registered=matrix_registered,
disable_updates=disable_updates, is_bot=is_bot)
@classmethod
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
rows = cls.db.execute(cls.t.select().where(cls.c.custom_mxid != None))
for row in rows:
yield cls.scan(row)
yield from cls._select_all(cls.c.custom_mxid != None)
@classmethod
def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']:
@@ -71,16 +58,3 @@ class Puppet(Base):
@classmethod
def get_by_displayname(cls, displayname: str) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.displayname == displayname)
@property
def _edit_identity(self) -> ClauseElement:
return self.c.id == self.id
def insert(self) -> None:
with self.db.begin() as conn:
conn.execute(self.t.insert().values(
id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token,
next_batch=self.next_batch, displayname=self.displayname, username=self.username,
displayname_source=self.displayname_source, photo_id=self.photo_id,
is_bot=self.is_bot, matrix_registered=self.matrix_registered,
disable_updates=self.disable_updates))
+5 -7
View File
@@ -19,7 +19,7 @@ from sqlalchemy import Column, ForeignKey, Integer, BigInteger, String, Boolean
from sqlalchemy.engine.result import RowProxy
from mautrix.types import ContentURI
from mautrix.bridge.db import Base
from mautrix.util.db import Base
class TelegramFile(Base):
@@ -38,12 +38,10 @@ class TelegramFile(Base):
@classmethod
def scan(cls, row: RowProxy) -> 'TelegramFile':
loc_id, mxc, mime, conv, ts, s, w, h, thumb_id = row
thumb = None
if thumb_id:
thumb = cls.get(thumb_id)
return cls(id=loc_id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
telegram_file: TelegramFile = super().scan(row)
if telegram_file.thumbnail_id:
telegram_file.thumbnail = cls.get(telegram_file.thumbnail_id)
return telegram_file
@classmethod
def get(cls, loc_id: str) -> Optional['TelegramFile']:
+1 -19
View File
@@ -16,11 +16,9 @@
from typing import Optional, Iterable, Tuple
from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, Integer, String
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.expression import ClauseElement
from mautrix.types import UserID
from mautrix.bridge.db import Base
from mautrix.util.db import Base
from ..types import TelegramID
@@ -34,12 +32,6 @@ class User(Base):
tg_phone: str = Column(String, nullable=True)
saved_contacts: int = Column(Integer, default=0, nullable=False)
@classmethod
def scan(cls, row: RowProxy) -> 'User':
mxid, tgid, tg_username, tg_phone, saved_contacts = row
return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
saved_contacts=saved_contacts)
@classmethod
def all_with_tgid(cls) -> Iterable['User']:
return cls._select_all(cls.c.tgid != None)
@@ -56,16 +48,6 @@ class User(Base):
def get_by_username(cls, username: str) -> Optional['User']:
return cls._select_one_or_none(cls.c.tg_username == username)
@property
def _edit_identity(self) -> ClauseElement:
return self.c.mxid == self.mxid
def insert(self) -> None:
with self.db.begin() as conn:
conn.execute(self.t.insert().values(
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username,
tg_phone=self.tg_phone, saved_contacts=self.saved_contacts))
@property
def contacts(self) -> Iterable[TelegramID]:
rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid))
@@ -19,7 +19,7 @@ import argparse
from sqlalchemy import orm
import sqlalchemy as sql
from mautrix.bridge.db import Base
from mautrix.util.db import Base
from mautrix_telegram.db import Portal, Message, Puppet, BotChat
from mautrix_telegram.config import Config