Start using new db base functions
This commit is contained in:
+1
-1
@@ -7,7 +7,7 @@ from os.path import abspath, dirname
|
||||
|
||||
sys.path.insert(0, dirname(dirname(abspath(__file__))))
|
||||
|
||||
from mautrix.bridge.db import Base
|
||||
from mautrix.util.db import Base
|
||||
import mautrix_telegram.db
|
||||
from mautrix_telegram.config import Config
|
||||
from alchemysession import AlchemySessionContainer
|
||||
|
||||
@@ -12,7 +12,7 @@ from alembic import context, op
|
||||
import sqlalchemy.orm as orm
|
||||
import sqlalchemy as sa
|
||||
|
||||
from mautrix.bridge.db import Base
|
||||
from mautrix.util.db import Base
|
||||
|
||||
from mautrix_telegram.config import Config
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from itertools import chain
|
||||
from alchemysession import AlchemySessionContainer
|
||||
|
||||
from mautrix.bridge import Bridge
|
||||
from mautrix.bridge.db import Base
|
||||
from mautrix.util.db import Base
|
||||
|
||||
from .web.provisioning import ProvisioningAPI
|
||||
from .web.public import PublicBridgeWebsite
|
||||
|
||||
@@ -31,3 +31,4 @@ def init(db_engine: Engine) -> None:
|
||||
table.db = db_engine
|
||||
table.t = table.__table__
|
||||
table.c = table.t.c
|
||||
table.column_names = table.c.keys()
|
||||
|
||||
@@ -16,9 +16,8 @@
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import Column, Integer, String
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
|
||||
from mautrix.bridge.db import Base
|
||||
from mautrix.util.db import Base
|
||||
|
||||
from ..types import TelegramID
|
||||
|
||||
@@ -34,14 +33,6 @@ class BotChat(Base):
|
||||
with cls.db.begin() as conn:
|
||||
conn.execute(cls.t.delete().where(cls.c.id == chat_id))
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row: RowProxy) -> 'BotChat':
|
||||
return cls(id=row[0], type=row[1])
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable['BotChat']:
|
||||
return cls._select_all()
|
||||
|
||||
def insert(self) -> None:
|
||||
with self.db.begin() as conn:
|
||||
conn.execute(self.t.insert().values(id=self.id, type=self.type))
|
||||
|
||||
@@ -16,11 +16,9 @@
|
||||
from typing import Optional, Iterator
|
||||
|
||||
from sqlalchemy import Column, UniqueConstraint, Integer, String, and_, func, desc, select
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql.expression import ClauseElement
|
||||
|
||||
from mautrix.types import RoomID, EventID
|
||||
from mautrix.bridge.db import Base
|
||||
from mautrix.util.db import Base
|
||||
|
||||
from ..types import TelegramID
|
||||
|
||||
@@ -36,29 +34,21 @@ class Message(Base):
|
||||
|
||||
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room_2"),)
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row: RowProxy) -> 'Message':
|
||||
return cls(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3], edit_index=row[4])
|
||||
|
||||
@classmethod
|
||||
def get_all_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Iterator['Message']:
|
||||
return cls._all(cls.db.execute(cls.t.select().where(and_(cls.c.tgid == tgid,
|
||||
cls.c.tg_space == tg_space))))
|
||||
return cls._select_all(cls.c.tgid == tgid, cls.c.tg_space == tg_space)
|
||||
|
||||
@classmethod
|
||||
def get_one_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID, edit_index: int = 0
|
||||
) -> Optional['Message']:
|
||||
query = cls.t.select()
|
||||
if edit_index < 0:
|
||||
query = (query
|
||||
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
|
||||
.order_by(desc(cls.c.edit_index))
|
||||
.limit(1)
|
||||
.offset(-edit_index - 1))
|
||||
return cls._one_or_none(cls.t.select()
|
||||
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
|
||||
.order_by(desc(cls.c.edit_index))
|
||||
.limit(1).offset(-edit_index - 1))
|
||||
else:
|
||||
query = query.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space,
|
||||
cls.c.edit_index == edit_index))
|
||||
return cls._one_or_none(cls.db.execute(query))
|
||||
return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_space == tg_space,
|
||||
cls.c.edit_index == edit_index)
|
||||
|
||||
@classmethod
|
||||
def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int:
|
||||
@@ -73,9 +63,8 @@ class Message(Base):
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: EventID, mx_room: RoomID, tg_space: TelegramID
|
||||
) -> Optional['Message']:
|
||||
return cls._select_one_or_none(and_(cls.c.mxid == mxid,
|
||||
cls.c.mx_room == mx_room,
|
||||
cls.c.tg_space == tg_space))
|
||||
return cls._select_one_or_none(cls.c.mxid == mxid, cls.c.mx_room == mx_room,
|
||||
cls.c.tg_space == tg_space)
|
||||
|
||||
@classmethod
|
||||
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, s_edit_index: int,
|
||||
@@ -92,14 +81,3 @@ class Message(Base):
|
||||
conn.execute(cls.t.update()
|
||||
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
|
||||
.values(**values))
|
||||
|
||||
@property
|
||||
def _edit_identity(self) -> ClauseElement:
|
||||
return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space,
|
||||
self.c.edit_index == self.edit_index)
|
||||
|
||||
def insert(self) -> None:
|
||||
with self.db.begin() as conn:
|
||||
conn.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room,
|
||||
tgid=self.tgid, tg_space=self.tg_space,
|
||||
edit_index=self.edit_index))
|
||||
|
||||
@@ -15,12 +15,10 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Boolean, Text, and_
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql.expression import ClauseElement
|
||||
from sqlalchemy import Column, Integer, String, Boolean, Text
|
||||
|
||||
from mautrix.types import RoomID
|
||||
from mautrix.bridge.db import Base
|
||||
from mautrix.util.db import Base
|
||||
|
||||
from ..types import TelegramID
|
||||
|
||||
@@ -45,17 +43,9 @@ class Portal(Base):
|
||||
about: str = Column(String, nullable=True)
|
||||
photo_id: str = Column(String, nullable=True)
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row: RowProxy) -> Optional['Portal']:
|
||||
(tgid, tg_receiver, peer_type, megagroup, mxid, config, username, title, about,
|
||||
photo_id) = row
|
||||
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, megagroup=megagroup,
|
||||
mxid=mxid, config=config, username=username, title=title, about=about,
|
||||
photo_id=photo_id)
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver))
|
||||
return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver)
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
|
||||
@@ -64,14 +54,3 @@ class Portal(Base):
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(cls.c.username == username)
|
||||
|
||||
@property
|
||||
def _edit_identity(self) -> ClauseElement:
|
||||
return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver)
|
||||
|
||||
def insert(self) -> None:
|
||||
with self.db.begin() as conn:
|
||||
conn.execute(self.t.insert().values(
|
||||
tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
|
||||
megagroup=self.megagroup, mxid=self.mxid, config=self.config,
|
||||
username=self.username, title=self.title, about=self.about, photo_id=self.photo_id))
|
||||
|
||||
@@ -17,11 +17,9 @@ from typing import Optional, Iterable
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Boolean
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql.expression import ClauseElement
|
||||
|
||||
from mautrix.types import UserID, SyncToken
|
||||
from mautrix.bridge.db import Base
|
||||
from mautrix.util.db import Base
|
||||
|
||||
from ..types import TelegramID
|
||||
|
||||
@@ -41,20 +39,9 @@ class Puppet(Base):
|
||||
matrix_registered: bool = Column(Boolean, nullable=False, server_default=expression.false())
|
||||
disable_updates: bool = Column(Boolean, nullable=False, server_default=expression.false())
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row: RowProxy) -> Optional['Puppet']:
|
||||
(id, custom_mxid, access_token, next_batch, displayname, displayname_source, username,
|
||||
photo_id, is_bot, matrix_registered, disable_updates) = row
|
||||
return cls(id=id, custom_mxid=custom_mxid, access_token=access_token, username=username,
|
||||
next_batch=next_batch, displayname=displayname, photo_id=photo_id,
|
||||
displayname_source=displayname_source, matrix_registered=matrix_registered,
|
||||
disable_updates=disable_updates, is_bot=is_bot)
|
||||
|
||||
@classmethod
|
||||
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
|
||||
rows = cls.db.execute(cls.t.select().where(cls.c.custom_mxid != None))
|
||||
for row in rows:
|
||||
yield cls.scan(row)
|
||||
yield from cls._select_all(cls.c.custom_mxid != None)
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']:
|
||||
@@ -71,16 +58,3 @@ class Puppet(Base):
|
||||
@classmethod
|
||||
def get_by_displayname(cls, displayname: str) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.displayname == displayname)
|
||||
|
||||
@property
|
||||
def _edit_identity(self) -> ClauseElement:
|
||||
return self.c.id == self.id
|
||||
|
||||
def insert(self) -> None:
|
||||
with self.db.begin() as conn:
|
||||
conn.execute(self.t.insert().values(
|
||||
id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token,
|
||||
next_batch=self.next_batch, displayname=self.displayname, username=self.username,
|
||||
displayname_source=self.displayname_source, photo_id=self.photo_id,
|
||||
is_bot=self.is_bot, matrix_registered=self.matrix_registered,
|
||||
disable_updates=self.disable_updates))
|
||||
|
||||
@@ -19,7 +19,7 @@ from sqlalchemy import Column, ForeignKey, Integer, BigInteger, String, Boolean
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
|
||||
from mautrix.types import ContentURI
|
||||
from mautrix.bridge.db import Base
|
||||
from mautrix.util.db import Base
|
||||
|
||||
|
||||
class TelegramFile(Base):
|
||||
@@ -38,12 +38,10 @@ class TelegramFile(Base):
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row: RowProxy) -> 'TelegramFile':
|
||||
loc_id, mxc, mime, conv, ts, s, w, h, thumb_id = row
|
||||
thumb = None
|
||||
if thumb_id:
|
||||
thumb = cls.get(thumb_id)
|
||||
return cls(id=loc_id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
|
||||
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
|
||||
telegram_file: TelegramFile = super().scan(row)
|
||||
if telegram_file.thumbnail_id:
|
||||
telegram_file.thumbnail = cls.get(telegram_file.thumbnail_id)
|
||||
return telegram_file
|
||||
|
||||
@classmethod
|
||||
def get(cls, loc_id: str) -> Optional['TelegramFile']:
|
||||
|
||||
@@ -16,11 +16,9 @@
|
||||
from typing import Optional, Iterable, Tuple
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, Integer, String
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql.expression import ClauseElement
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.bridge.db import Base
|
||||
from mautrix.util.db import Base
|
||||
|
||||
from ..types import TelegramID
|
||||
|
||||
@@ -34,12 +32,6 @@ class User(Base):
|
||||
tg_phone: str = Column(String, nullable=True)
|
||||
saved_contacts: int = Column(Integer, default=0, nullable=False)
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row: RowProxy) -> 'User':
|
||||
mxid, tgid, tg_username, tg_phone, saved_contacts = row
|
||||
return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
|
||||
saved_contacts=saved_contacts)
|
||||
|
||||
@classmethod
|
||||
def all_with_tgid(cls) -> Iterable['User']:
|
||||
return cls._select_all(cls.c.tgid != None)
|
||||
@@ -56,16 +48,6 @@ class User(Base):
|
||||
def get_by_username(cls, username: str) -> Optional['User']:
|
||||
return cls._select_one_or_none(cls.c.tg_username == username)
|
||||
|
||||
@property
|
||||
def _edit_identity(self) -> ClauseElement:
|
||||
return self.c.mxid == self.mxid
|
||||
|
||||
def insert(self) -> None:
|
||||
with self.db.begin() as conn:
|
||||
conn.execute(self.t.insert().values(
|
||||
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username,
|
||||
tg_phone=self.tg_phone, saved_contacts=self.saved_contacts))
|
||||
|
||||
@property
|
||||
def contacts(self) -> Iterable[TelegramID]:
|
||||
rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid))
|
||||
|
||||
@@ -19,7 +19,7 @@ import argparse
|
||||
from sqlalchemy import orm
|
||||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.bridge.db import Base
|
||||
from mautrix.util.db import Base
|
||||
|
||||
from mautrix_telegram.db import Portal, Message, Puppet, BotChat
|
||||
from mautrix_telegram.config import Config
|
||||
|
||||
Reference in New Issue
Block a user