Finish moving portals and users to SQLAlchemy Core
This commit is contained in:
+93
-15
@@ -20,7 +20,7 @@ from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstrai
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.orm import relationship, Query
|
||||
from typing import Dict, Optional, List, Iterable
|
||||
from typing import Dict, Optional, List, Iterable, Tuple
|
||||
import json
|
||||
|
||||
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
|
||||
@@ -48,14 +48,18 @@ class Portal(Base):
|
||||
about = Column(String, nullable=True)
|
||||
photo_id = Column(String, nullable=True)
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row) -> Optional['Portal']:
|
||||
(tgid, tg_receiver, peer_type, megagroup, mxid, config, username, title, about,
|
||||
photo_id) = row
|
||||
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, megagroup=megagroup,
|
||||
mxid=mxid, config=config, username=username, title=title, about=about,
|
||||
photo_id=photo_id)
|
||||
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
|
||||
try:
|
||||
(tgid, tg_receiver, peer_type, megagroup, mxid, config,
|
||||
username, title, about, photo_id) = next(rows)
|
||||
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
|
||||
megagroup=megagroup, mxid=mxid, config=config, username=username,
|
||||
title=title, about=about, photo_id=photo_id)
|
||||
return cls.scan(next(rows))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@@ -155,10 +159,6 @@ class User(Base):
|
||||
tg_username = Column(String, nullable=True)
|
||||
tg_phone = Column(String, nullable=True)
|
||||
saved_contacts = Column(Integer, default=0, nullable=False)
|
||||
contacts = relationship("Contact", uselist=True,
|
||||
cascade="save-update, merge, delete, delete-orphan"
|
||||
) # type: List[Contact]
|
||||
portals = relationship("Portal", secondary="user_portal")
|
||||
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
|
||||
@@ -170,7 +170,7 @@ class User(Base):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_all(cls) -> Iterable['User']:
|
||||
def all(cls) -> Iterable['User']:
|
||||
rows = cls.db.execute(cls.t.select())
|
||||
for row in rows:
|
||||
mxid, tgid, tg_username, tg_phone, saved_contacts = row
|
||||
@@ -198,6 +198,36 @@ class User(Base):
|
||||
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
|
||||
saved_contacts=self.saved_contacts))
|
||||
|
||||
@property
|
||||
def contacts(self) -> Iterable[TelegramID]:
|
||||
rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid))
|
||||
for row in rows:
|
||||
user, contact = row
|
||||
yield contact
|
||||
|
||||
@contacts.setter
|
||||
def contacts(self, puppets: Iterable[TelegramID]) -> None:
|
||||
self.db.execute(Contact.t.delete().where(Contact.c.user == self.tgid))
|
||||
self.db.execute(Contact.t.insert(), [{"user": self.tgid, "contact": tgid}
|
||||
for tgid in puppets])
|
||||
|
||||
@property
|
||||
def portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
|
||||
rows = self.db.execute(UserPortal.t.select().where(UserPortal.c.user == self.tgid))
|
||||
for row in rows:
|
||||
user, portal, portal_receiver = row
|
||||
yield (portal, portal_receiver)
|
||||
|
||||
@portals.setter
|
||||
def portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
|
||||
self.db.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid))
|
||||
self.db.execute(UserPortal.t.insert(),
|
||||
[{
|
||||
"user": self.tgid,
|
||||
"portal": tgid,
|
||||
"portal_receiver": tg_receiver
|
||||
} for tgid, tg_receiver in portals])
|
||||
|
||||
|
||||
class UserPortal(Base):
|
||||
__tablename__ = "user_portal"
|
||||
@@ -302,7 +332,6 @@ class UserProfile(Base):
|
||||
|
||||
|
||||
class Puppet(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "puppet"
|
||||
|
||||
id = Column(Integer, primary_key=True) # type: TelegramID
|
||||
@@ -315,6 +344,55 @@ class Puppet(Base):
|
||||
is_bot = Column(Boolean, nullable=True)
|
||||
matrix_registered = Column(Boolean, nullable=False, server_default=expression.false())
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row) -> Optional['Puppet']:
|
||||
(id, custom_mxid, access_token, displayname, displayname_source, username, photo_id,
|
||||
is_bot, matrix_registered) = row
|
||||
return cls(id=id, custom_mxid=custom_mxid, access_token=access_token,
|
||||
displayname=displayname, displayname_source=displayname_source,
|
||||
username=username, photo_id=photo_id, is_bot=is_bot,
|
||||
matrix_registered=matrix_registered)
|
||||
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['Puppet']:
|
||||
try:
|
||||
return cls.scan(next(rows))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
|
||||
rows = cls.db.execute(cls.t.select().where(cls.c.custom_mxid != None))
|
||||
for row in rows:
|
||||
yield cls.scan(row)
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.id == tgid)
|
||||
|
||||
@classmethod
|
||||
def get_by_custom_mxid(cls, mxid: MatrixRoomID) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.custom_mxid == mxid)
|
||||
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.username == username)
|
||||
|
||||
@classmethod
|
||||
def get_by_displayname(cls, displayname: str) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.displayname == displayname)
|
||||
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return self.c.id == self.id
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(
|
||||
id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token,
|
||||
displayname=self.displayname, displayname_source=self.displayname_source,
|
||||
username=self.username, photo_id=self.photo_id, is_bot=self.is_bot,
|
||||
matrix_registered=self.matrix_registered))
|
||||
|
||||
|
||||
# Fucking Telegram not telling bots what chats they are in 3:<
|
||||
class BotChat(Base):
|
||||
@@ -359,9 +437,9 @@ class TelegramFile(Base):
|
||||
|
||||
|
||||
def init(db_session, db_engine) -> None:
|
||||
query = db_session.query_property()
|
||||
for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState):
|
||||
table.query = query
|
||||
BotChat.query = db_session.query_property()
|
||||
for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile,
|
||||
RoomState):
|
||||
table.db = db_engine
|
||||
table.t = table.__table__
|
||||
table.c = table.t.c
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Coroutine, Dict, List, Optional, Pattern, Union, TYPE_CHECKING
|
||||
from typing import (Awaitable, Coroutine, Dict, List, Iterable, Optional, Pattern, Union,
|
||||
TYPE_CHECKING)
|
||||
from difflib import SequenceMatcher
|
||||
from enum import Enum
|
||||
from aiohttp import ServerDisconnectedError
|
||||
@@ -396,7 +397,7 @@ class Puppet:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
puppet = DBPuppet.query.get(tgid)
|
||||
puppet = DBPuppet.get_by_tgid(tgid)
|
||||
if puppet:
|
||||
return cls.from_db(puppet)
|
||||
|
||||
@@ -426,7 +427,7 @@ class Puppet:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
puppet = DBPuppet.query.filter(DBPuppet.custom_mxid == mxid).one_or_none()
|
||||
puppet = DBPuppet.get_by_custom_mxid(mxid)
|
||||
if puppet:
|
||||
puppet = cls.from_db(puppet)
|
||||
return puppet
|
||||
@@ -434,11 +435,11 @@ class Puppet:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_all_with_custom_mxid(cls) -> List['Puppet']:
|
||||
return [cls.by_custom_mxid[puppet.mxid]
|
||||
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
|
||||
return (cls.by_custom_mxid[puppet.mxid]
|
||||
if puppet.custom_mxid in cls.by_custom_mxid
|
||||
else cls.from_db(puppet)
|
||||
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
|
||||
for puppet in DBPuppet.all_with_custom_mxid())
|
||||
|
||||
@classmethod
|
||||
def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]:
|
||||
@@ -460,7 +461,7 @@ class Puppet:
|
||||
if puppet.username and puppet.username.lower() == username.lower():
|
||||
return puppet
|
||||
|
||||
dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
|
||||
dbpuppet = DBPuppet.get_by_username(username)
|
||||
if dbpuppet:
|
||||
return cls.from_db(dbpuppet)
|
||||
|
||||
@@ -475,7 +476,7 @@ class Puppet:
|
||||
if puppet.displayname and puppet.displayname == displayname:
|
||||
return puppet
|
||||
|
||||
dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
|
||||
dbpuppet = DBPuppet.get_by_displayname(displayname)
|
||||
if dbpuppet:
|
||||
return cls.from_db(dbpuppet)
|
||||
|
||||
@@ -491,4 +492,4 @@ def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
|
||||
Puppet.hs_domain = config["homeserver"]["domain"]
|
||||
Puppet.mxid_regex = re.compile(
|
||||
f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}")
|
||||
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
|
||||
return [puppet.init_custom_mxid() for puppet in Puppet.all_with_custom_mxid()]
|
||||
|
||||
+15
-16
@@ -48,9 +48,9 @@ class User(AbstractUser):
|
||||
|
||||
def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
|
||||
username: Optional[str] = None, phone: Optional[str] = None,
|
||||
db_contacts: Optional[List[DBContact]] = None,
|
||||
db_contacts: Optional[Iterable[TelegramID]] = None,
|
||||
saved_contacts: int = 0, is_bot: bool = False,
|
||||
db_portals: Optional[List[DBPortal]] = None,
|
||||
db_portals: Optional[Iterable[Tuple[TelegramID, TelegramID]]] = None,
|
||||
db_instance: Optional[DBUser] = None) -> None:
|
||||
super().__init__()
|
||||
self.mxid = mxid # type: MatrixUserID
|
||||
@@ -60,9 +60,9 @@ class User(AbstractUser):
|
||||
self.phone = phone # type: str
|
||||
self.contacts = [] # type: List[pu.Puppet]
|
||||
self.saved_contacts = saved_contacts # type: int
|
||||
self.db_contacts = db_contacts # type: List[DBContact]
|
||||
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
|
||||
self.db_portals = db_portals or [] # type: List[DBPortal]
|
||||
self.db_contacts = db_contacts
|
||||
self.portals = {} # type: Dict[Tuple[TelegramID, TelegramID], po.Portal]
|
||||
self.db_portals = db_portals or []
|
||||
self._db_instance = db_instance # type: Optional[DBUser]
|
||||
|
||||
self.command_status = None # type: Dict
|
||||
@@ -101,23 +101,22 @@ class User(AbstractUser):
|
||||
return self.displayname
|
||||
|
||||
@property
|
||||
def db_contacts(self) -> Iterable[DBContact]:
|
||||
return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts)
|
||||
def db_contacts(self) -> Iterable[TelegramID]:
|
||||
return (puppet.id for puppet in self.contacts)
|
||||
|
||||
@db_contacts.setter
|
||||
def db_contacts(self, contacts: Iterable[DBContact]) -> None:
|
||||
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
|
||||
def db_contacts(self, contacts: Iterable[TelegramID]) -> None:
|
||||
self.contacts = [pu.Puppet.get(entry) for entry in contacts] if contacts else []
|
||||
|
||||
@property
|
||||
def db_portals(self) -> Iterable[DBPortal]:
|
||||
return (portal.db_instance for portal in self.portals.values() if not portal.deleted)
|
||||
def db_portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
|
||||
return (portal.tgid_full for portal in self.portals.values() if not portal.deleted)
|
||||
|
||||
@db_portals.setter
|
||||
def db_portals(self, portals: Iterable[DBPortal]) -> None:
|
||||
def db_portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
|
||||
self.portals = {
|
||||
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
|
||||
portal.tg_receiver)
|
||||
for portal in portals
|
||||
tgid_full: po.Portal.get_by_tgid(*tgid_full)
|
||||
for tgid_full in portals
|
||||
} if portals else {}
|
||||
|
||||
# region Database conversion
|
||||
@@ -398,5 +397,5 @@ def init(context: 'Context') -> List[Awaitable['User']]:
|
||||
global config
|
||||
config = context.config
|
||||
|
||||
users = [User.from_db(user) for user in DBUser.get_all()]
|
||||
users = [User.from_db(user) for user in DBUser.all()]
|
||||
return [user.ensure_started() for user in users if user.tgid]
|
||||
|
||||
Reference in New Issue
Block a user