Finish moving portals and users to SQLAlchemy Core

This commit is contained in:
Tulir Asokan
2019-02-12 14:42:03 +02:00
parent 53489e7356
commit cf847d3b8e
3 changed files with 118 additions and 40 deletions
+93 -15
View File
@@ -20,7 +20,7 @@ from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstrai
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query
from typing import Dict, Optional, List, Iterable
from typing import Dict, Optional, List, Iterable, Tuple
import json
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
@@ -48,14 +48,18 @@ class Portal(Base):
about = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
@classmethod
def scan(cls, row) -> Optional['Portal']:
(tgid, tg_receiver, peer_type, megagroup, mxid, config, username, title, about,
photo_id) = row
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, megagroup=megagroup,
mxid=mxid, config=config, username=username, title=title, about=about,
photo_id=photo_id)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
try:
(tgid, tg_receiver, peer_type, megagroup, mxid, config,
username, title, about, photo_id) = next(rows)
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
megagroup=megagroup, mxid=mxid, config=config, username=username,
title=title, about=about, photo_id=photo_id)
return cls.scan(next(rows))
except StopIteration:
return None
@@ -155,10 +159,6 @@ class User(Base):
tg_username = Column(String, nullable=True)
tg_phone = Column(String, nullable=True)
saved_contacts = Column(Integer, default=0, nullable=False)
contacts = relationship("Contact", uselist=True,
cascade="save-update, merge, delete, delete-orphan"
) # type: List[Contact]
portals = relationship("Portal", secondary="user_portal")
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
@@ -170,7 +170,7 @@ class User(Base):
return None
@classmethod
def get_all(cls) -> Iterable['User']:
def all(cls) -> Iterable['User']:
rows = cls.db.execute(cls.t.select())
for row in rows:
mxid, tgid, tg_username, tg_phone, saved_contacts = row
@@ -198,6 +198,36 @@ class User(Base):
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
saved_contacts=self.saved_contacts))
@property
def contacts(self) -> Iterable[TelegramID]:
rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid))
for row in rows:
user, contact = row
yield contact
@contacts.setter
def contacts(self, puppets: Iterable[TelegramID]) -> None:
self.db.execute(Contact.t.delete().where(Contact.c.user == self.tgid))
self.db.execute(Contact.t.insert(), [{"user": self.tgid, "contact": tgid}
for tgid in puppets])
@property
def portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
rows = self.db.execute(UserPortal.t.select().where(UserPortal.c.user == self.tgid))
for row in rows:
user, portal, portal_receiver = row
yield (portal, portal_receiver)
@portals.setter
def portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
self.db.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid))
self.db.execute(UserPortal.t.insert(),
[{
"user": self.tgid,
"portal": tgid,
"portal_receiver": tg_receiver
} for tgid, tg_receiver in portals])
class UserPortal(Base):
__tablename__ = "user_portal"
@@ -302,7 +332,6 @@ class UserProfile(Base):
class Puppet(Base):
query = None # type: Query
__tablename__ = "puppet"
id = Column(Integer, primary_key=True) # type: TelegramID
@@ -315,6 +344,55 @@ class Puppet(Base):
is_bot = Column(Boolean, nullable=True)
matrix_registered = Column(Boolean, nullable=False, server_default=expression.false())
@classmethod
def scan(cls, row) -> Optional['Puppet']:
(id, custom_mxid, access_token, displayname, displayname_source, username, photo_id,
is_bot, matrix_registered) = row
return cls(id=id, custom_mxid=custom_mxid, access_token=access_token,
displayname=displayname, displayname_source=displayname_source,
username=username, photo_id=photo_id, is_bot=is_bot,
matrix_registered=matrix_registered)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Puppet']:
try:
return cls.scan(next(rows))
except StopIteration:
return None
@classmethod
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
rows = cls.db.execute(cls.t.select().where(cls.c.custom_mxid != None))
for row in rows:
yield cls.scan(row)
@classmethod
def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.id == tgid)
@classmethod
def get_by_custom_mxid(cls, mxid: MatrixRoomID) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.custom_mxid == mxid)
@classmethod
def get_by_username(cls, username: str) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.username == username)
@classmethod
def get_by_displayname(cls, displayname: str) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.displayname == displayname)
@property
def _edit_identity(self):
return self.c.id == self.id
def insert(self) -> None:
self.db.execute(self.t.insert().values(
id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token,
displayname=self.displayname, displayname_source=self.displayname_source,
username=self.username, photo_id=self.photo_id, is_bot=self.is_bot,
matrix_registered=self.matrix_registered))
# Fucking Telegram not telling bots what chats they are in 3:<
class BotChat(Base):
@@ -359,9 +437,9 @@ class TelegramFile(Base):
def init(db_session, db_engine) -> None:
query = db_session.query_property()
for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState):
table.query = query
BotChat.query = db_session.query_property()
for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile,
RoomState):
table.db = db_engine
table.t = table.__table__
table.c = table.t.c
+10 -9
View File
@@ -14,7 +14,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Coroutine, Dict, List, Optional, Pattern, Union, TYPE_CHECKING
from typing import (Awaitable, Coroutine, Dict, List, Iterable, Optional, Pattern, Union,
TYPE_CHECKING)
from difflib import SequenceMatcher
from enum import Enum
from aiohttp import ServerDisconnectedError
@@ -396,7 +397,7 @@ class Puppet:
except KeyError:
pass
puppet = DBPuppet.query.get(tgid)
puppet = DBPuppet.get_by_tgid(tgid)
if puppet:
return cls.from_db(puppet)
@@ -426,7 +427,7 @@ class Puppet:
except KeyError:
pass
puppet = DBPuppet.query.filter(DBPuppet.custom_mxid == mxid).one_or_none()
puppet = DBPuppet.get_by_custom_mxid(mxid)
if puppet:
puppet = cls.from_db(puppet)
return puppet
@@ -434,11 +435,11 @@ class Puppet:
return None
@classmethod
def get_all_with_custom_mxid(cls) -> List['Puppet']:
return [cls.by_custom_mxid[puppet.mxid]
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
return (cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
for puppet in DBPuppet.all_with_custom_mxid())
@classmethod
def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]:
@@ -460,7 +461,7 @@ class Puppet:
if puppet.username and puppet.username.lower() == username.lower():
return puppet
dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
dbpuppet = DBPuppet.get_by_username(username)
if dbpuppet:
return cls.from_db(dbpuppet)
@@ -475,7 +476,7 @@ class Puppet:
if puppet.displayname and puppet.displayname == displayname:
return puppet
dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
dbpuppet = DBPuppet.get_by_displayname(displayname)
if dbpuppet:
return cls.from_db(dbpuppet)
@@ -491,4 +492,4 @@ def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
Puppet.hs_domain = config["homeserver"]["domain"]
Puppet.mxid_regex = re.compile(
f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}")
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
return [puppet.init_custom_mxid() for puppet in Puppet.all_with_custom_mxid()]
+15 -16
View File
@@ -48,9 +48,9 @@ class User(AbstractUser):
def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
username: Optional[str] = None, phone: Optional[str] = None,
db_contacts: Optional[List[DBContact]] = None,
db_contacts: Optional[Iterable[TelegramID]] = None,
saved_contacts: int = 0, is_bot: bool = False,
db_portals: Optional[List[DBPortal]] = None,
db_portals: Optional[Iterable[Tuple[TelegramID, TelegramID]]] = None,
db_instance: Optional[DBUser] = None) -> None:
super().__init__()
self.mxid = mxid # type: MatrixUserID
@@ -60,9 +60,9 @@ class User(AbstractUser):
self.phone = phone # type: str
self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
self.db_portals = db_portals or [] # type: List[DBPortal]
self.db_contacts = db_contacts
self.portals = {} # type: Dict[Tuple[TelegramID, TelegramID], po.Portal]
self.db_portals = db_portals or []
self._db_instance = db_instance # type: Optional[DBUser]
self.command_status = None # type: Dict
@@ -101,23 +101,22 @@ class User(AbstractUser):
return self.displayname
@property
def db_contacts(self) -> Iterable[DBContact]:
return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts)
def db_contacts(self) -> Iterable[TelegramID]:
return (puppet.id for puppet in self.contacts)
@db_contacts.setter
def db_contacts(self, contacts: Iterable[DBContact]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
def db_contacts(self, contacts: Iterable[TelegramID]) -> None:
self.contacts = [pu.Puppet.get(entry) for entry in contacts] if contacts else []
@property
def db_portals(self) -> Iterable[DBPortal]:
return (portal.db_instance for portal in self.portals.values() if not portal.deleted)
def db_portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
return (portal.tgid_full for portal in self.portals.values() if not portal.deleted)
@db_portals.setter
def db_portals(self, portals: Iterable[DBPortal]) -> None:
def db_portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
self.portals = {
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
portal.tg_receiver)
for portal in portals
tgid_full: po.Portal.get_by_tgid(*tgid_full)
for tgid_full in portals
} if portals else {}
# region Database conversion
@@ -398,5 +397,5 @@ def init(context: 'Context') -> List[Awaitable['User']]:
global config
config = context.config
users = [User.from_db(user) for user in DBUser.get_all()]
users = [User.from_db(user) for user in DBUser.all()]
return [user.ensure_started() for user in users if user.tgid]