Start moving portals and users to SQLAlchemy Core
This commit is contained in:
@@ -1,2 +1,41 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql.base import ImmutableColumnCollection
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
Base = declarative_base() # type: declarative_base
|
||||
|
||||
|
||||
class BaseBase:
|
||||
db = None # type: Engine
|
||||
t = None # type: Table
|
||||
__table__ = None # type: Table
|
||||
c = None # type: ImmutableColumnCollection
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _one_or_none(cls, rows: RowProxy):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _select_one_or_none(cls, *args):
|
||||
return cls._one_or_none(cls.db.execute(cls.t.select().where(*args)))
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _edit_identity(self):
|
||||
pass
|
||||
|
||||
def update(self, **values) -> None:
|
||||
self.db.execute(self.t.update()
|
||||
.where(self._edit_identity)
|
||||
.values(**values))
|
||||
for key, value in values.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def delete(self) -> None:
|
||||
self.db.execute(self.t.delete().where(self._edit_identity))
|
||||
|
||||
|
||||
Base = declarative_base(cls=BaseBase)
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql.base import ImmutableColumnCollection
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
class Base(declarative_base):
|
||||
db: Engine
|
||||
t: Table
|
||||
__table__: Table
|
||||
c: ImmutableColumnCollection
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _one_or_none(cls, rows: RowProxy): ...
|
||||
|
||||
@classmethod
|
||||
def _select_one_or_none(cls, *args): ...
|
||||
|
||||
def _edit_identity(self): ...
|
||||
|
||||
def update(self, **values) -> None: ...
|
||||
|
||||
def delete(self) -> None: ...
|
||||
+120
-77
@@ -15,13 +15,12 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
||||
BigInteger, String, Boolean, Text, Table,
|
||||
BigInteger, String, Boolean, Text,
|
||||
and_, func, select)
|
||||
from sqlalchemy.engine import Engine, RowProxy
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.orm import relationship, Query
|
||||
from sqlalchemy.sql.base import ImmutableColumnCollection
|
||||
from typing import Dict, Optional, List
|
||||
from typing import Dict, Optional, List, Iterable
|
||||
import json
|
||||
|
||||
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
|
||||
@@ -30,7 +29,6 @@ from .base import Base
|
||||
|
||||
|
||||
class Portal(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "portal"
|
||||
|
||||
# Telegram chat information
|
||||
@@ -50,11 +48,41 @@ class Portal(Base):
|
||||
about = Column(String, nullable=True)
|
||||
photo_id = Column(String, nullable=True)
|
||||
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
|
||||
try:
|
||||
(tgid, tg_receiver, peer_type, megagroup, mxid, config,
|
||||
username, title, about, photo_id) = next(rows)
|
||||
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
|
||||
megagroup=megagroup, mxid=mxid, config=config, username=username,
|
||||
title=title, about=about, photo_id=photo_id)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver))
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(cls.c.mxid == mxid)
|
||||
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(cls.c.username == username)
|
||||
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver)
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(
|
||||
tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
|
||||
megagroup=self.megagroup, mxid=self.mxid, config=self.config, username=self.username,
|
||||
title=self.title, about=self.about, photo_id=self.photo_id))
|
||||
|
||||
|
||||
class Message(Base):
|
||||
db = None # type: Engine
|
||||
t = None # type: Table
|
||||
c = None # type: ImmutableColumnCollection
|
||||
__tablename__ = "message"
|
||||
|
||||
mxid = Column(String) # type: MatrixEventID
|
||||
@@ -64,11 +92,11 @@ class Message(Base):
|
||||
|
||||
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
|
||||
|
||||
@staticmethod
|
||||
def _one_or_none(rows: RowProxy) -> Optional['Message']:
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['Message']:
|
||||
try:
|
||||
mxid, mx_room, tgid, tg_space = next(rows)
|
||||
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
|
||||
return cls(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@@ -79,9 +107,7 @@ class Message(Base):
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
|
||||
rows = cls.db.execute(cls.t.select()
|
||||
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
|
||||
return cls._one_or_none(rows)
|
||||
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
|
||||
|
||||
@classmethod
|
||||
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
|
||||
@@ -96,9 +122,9 @@ class Message(Base):
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
|
||||
) -> Optional['Message']:
|
||||
rows = cls.db.execute(cls.t.select().where(
|
||||
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
|
||||
return cls._one_or_none(rows)
|
||||
return cls._select_one_or_none(and_(cls.c.mxid == mxid,
|
||||
cls.c.mx_room == mx_room,
|
||||
cls.c.tg_space == tg_space))
|
||||
|
||||
@classmethod
|
||||
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
|
||||
@@ -112,36 +138,16 @@ class Message(Base):
|
||||
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
|
||||
.values(**values))
|
||||
|
||||
def update(self, **values) -> None:
|
||||
for key, value in values.items():
|
||||
setattr(self, key, value)
|
||||
self.update_by_tgid(self.tgid, self.tg_space, **values)
|
||||
|
||||
def delete(self) -> None:
|
||||
self.db.execute(self.t.delete().where(
|
||||
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
|
||||
tg_space=self.tg_space))
|
||||
|
||||
|
||||
class UserPortal(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "user_portal"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
|
||||
primary_key=True) # type: TelegramID
|
||||
portal = Column(Integer, primary_key=True) # type: TelegramID
|
||||
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
|
||||
|
||||
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
|
||||
("portal.tgid", "portal.tg_receiver"),
|
||||
onupdate="CASCADE", ondelete="CASCADE"),)
|
||||
|
||||
|
||||
class User(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "user"
|
||||
|
||||
mxid = Column(String, primary_key=True) # type: MatrixUserID
|
||||
@@ -154,11 +160,66 @@ class User(Base):
|
||||
) # type: List[Contact]
|
||||
portals = relationship("Portal", secondary="user_portal")
|
||||
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
|
||||
try:
|
||||
mxid, tgid, tg_username, tg_phone, saved_contacts = next(rows)
|
||||
return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
|
||||
saved_contacts=saved_contacts)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_all(cls) -> Iterable['User']:
|
||||
rows = cls.db.execute(cls.t.select())
|
||||
for row in rows:
|
||||
mxid, tgid, tg_username, tg_phone, saved_contacts = row
|
||||
yield cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
|
||||
saved_contacts=saved_contacts)
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
|
||||
return cls._select_one_or_none(cls.c.tgid == tgid)
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['User']:
|
||||
return cls._select_one_or_none(cls.c.mxid == mxid)
|
||||
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str) -> Optional['User']:
|
||||
return cls._select_one_or_none(cls.c.username == username)
|
||||
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return self.c.mxid == self.mxid
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(
|
||||
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
|
||||
saved_contacts=self.saved_contacts))
|
||||
|
||||
|
||||
class UserPortal(Base):
|
||||
__tablename__ = "user_portal"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
|
||||
primary_key=True) # type: TelegramID
|
||||
portal = Column(Integer, primary_key=True) # type: TelegramID
|
||||
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
|
||||
|
||||
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
|
||||
("portal.tgid", "portal.tg_receiver"),
|
||||
onupdate="CASCADE", ondelete="CASCADE"),)
|
||||
|
||||
|
||||
class Contact(Base):
|
||||
__tablename__ = "contact"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
|
||||
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
|
||||
|
||||
|
||||
class RoomState(Base):
|
||||
db = None # type: Engine
|
||||
t = None # type: Table
|
||||
c = None # type: ImmutableColumnCollection
|
||||
__tablename__ = "mx_room_state"
|
||||
|
||||
room_id = Column(String, primary_key=True) # type: MatrixRoomID
|
||||
@@ -177,18 +238,17 @@ class RoomState(Base):
|
||||
rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id))
|
||||
try:
|
||||
room_id, power_levels_text = next(rows)
|
||||
return RoomState(room_id=room_id, power_levels=(json.loads(power_levels_text)
|
||||
if power_levels_text else None))
|
||||
return cls(room_id=room_id, power_levels=(json.loads(power_levels_text)
|
||||
if power_levels_text else None))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def update(self) -> None:
|
||||
self.db.execute(self.t.update()
|
||||
.where(self.c.room_id == self.room_id)
|
||||
.values(power_levels=self._power_levels_text))
|
||||
return super().update(power_levels=self._power_levels_text)
|
||||
|
||||
def delete(self) -> None:
|
||||
self.db.execute(self.t.delete().where(self.c.room_id == self.room_id))
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return self.c.room_id == self.room_id
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(room_id=self.room_id,
|
||||
@@ -196,9 +256,6 @@ class RoomState(Base):
|
||||
|
||||
|
||||
class UserProfile(Base):
|
||||
db = None # type: Engine
|
||||
t = None # type: Table
|
||||
c = None # type: ImmutableColumnCollection
|
||||
__tablename__ = "mx_user_profile"
|
||||
|
||||
room_id = Column(String, primary_key=True) # type: MatrixRoomID
|
||||
@@ -220,8 +277,8 @@ class UserProfile(Base):
|
||||
cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id)))
|
||||
try:
|
||||
room_id, user_id, membership, displayname, avatar_url = next(rows)
|
||||
return UserProfile(room_id=room_id, user_id=user_id, membership=membership,
|
||||
displayname=displayname, avatar_url=avatar_url)
|
||||
return cls(room_id=room_id, user_id=user_id, membership=membership,
|
||||
displayname=displayname, avatar_url=avatar_url)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@@ -230,14 +287,12 @@ class UserProfile(Base):
|
||||
cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id))
|
||||
|
||||
def update(self) -> None:
|
||||
self.db.execute(self.t.update()
|
||||
.where(and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id))
|
||||
.values(membership=self.membership, displayname=self.displayname,
|
||||
avatar_url=self.avatar_url))
|
||||
super().update(membership=self.membership, displayname=self.displayname,
|
||||
avatar_url=self.avatar_url)
|
||||
|
||||
def delete(self) -> None:
|
||||
self.db.execute(self.t.delete().where(and_(self.c.room_id == self.room_id,
|
||||
self.c.user_id == self.user_id)))
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
|
||||
@@ -246,14 +301,6 @@ class UserProfile(Base):
|
||||
avatar_url=self.avatar_url))
|
||||
|
||||
|
||||
class Contact(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "contact"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
|
||||
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
|
||||
|
||||
|
||||
class Puppet(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "puppet"
|
||||
@@ -278,9 +325,6 @@ class BotChat(Base):
|
||||
|
||||
|
||||
class TelegramFile(Base):
|
||||
db = None # type: Engine
|
||||
t = None # type: Table
|
||||
c = None # type: ImmutableColumnCollection
|
||||
__tablename__ = "telegram_file"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
@@ -302,8 +346,8 @@ class TelegramFile(Base):
|
||||
thumb = None
|
||||
if thumb_id:
|
||||
thumb = cls.get(thumb_id)
|
||||
return TelegramFile(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
|
||||
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
|
||||
return cls(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
|
||||
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@@ -316,8 +360,7 @@ class TelegramFile(Base):
|
||||
|
||||
def init(db_session, db_engine) -> None:
|
||||
query = db_session.query_property()
|
||||
for table in (Portal, Message, UserPortal, User, Puppet, BotChat, TelegramFile, UserProfile,
|
||||
RoomState):
|
||||
for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState):
|
||||
table.query = query
|
||||
table.db = db_engine
|
||||
table.t = table.__table__
|
||||
|
||||
+15
-26
@@ -31,7 +31,6 @@ import json
|
||||
import re
|
||||
|
||||
import magic
|
||||
from sqlalchemy import orm
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from telethon.tl.functions.messages import (
|
||||
@@ -89,7 +88,6 @@ InviteList = Union[MatrixUserID, List[MatrixUserID]]
|
||||
|
||||
class Portal:
|
||||
log = logging.getLogger("mau.portal") # type: logging.Logger
|
||||
db = None # type: orm.Session
|
||||
az = None # type: AppService
|
||||
bot = None # type: Bot
|
||||
loop = None # type: asyncio.AbstractEventLoop
|
||||
@@ -1255,8 +1253,7 @@ class Portal:
|
||||
self.tg_receiver = self.tgid
|
||||
self.by_tgid[self.tgid_full] = self
|
||||
await self.update_info(source, entity)
|
||||
self.db.add(self.db_instance)
|
||||
self.save()
|
||||
self.db_instance.insert()
|
||||
|
||||
if self.bot and self.bot.tgid in invites:
|
||||
self.bot.add_chat(self.tgid, self.peer_type)
|
||||
@@ -1842,15 +1839,13 @@ class Portal:
|
||||
del self.by_tgid[self.tgid_full]
|
||||
except KeyError:
|
||||
pass
|
||||
self.tgid = new_id
|
||||
self.tg_receiver = new_id
|
||||
existing = self.by_tgid[self.tgid_full]
|
||||
existing = self.by_tgid[(new_id, new_id)]
|
||||
if existing:
|
||||
existing.delete()
|
||||
self.db_instance.update(tgid=new_id, tg_receiver=new_id)
|
||||
self.tgid = new_id
|
||||
self.tg_receiver = new_id
|
||||
self.by_tgid[self.tgid_full] = self
|
||||
self.db_instance.tgid = self.tgid
|
||||
self.db_instance.tg_receiver = self.tg_receiver
|
||||
self.save()
|
||||
|
||||
def migrate_and_save_matrix(self, new_id: MatrixRoomID) -> None:
|
||||
try:
|
||||
@@ -1858,17 +1853,13 @@ class Portal:
|
||||
except KeyError:
|
||||
pass
|
||||
self.mxid = new_id
|
||||
self.db_instance.update(mxid=self.mxid)
|
||||
self.by_mxid[self.mxid] = self
|
||||
self.save()
|
||||
|
||||
def save(self) -> None:
|
||||
self.db_instance.mxid = self.mxid
|
||||
self.db_instance.username = self.username
|
||||
self.db_instance.title = self.title
|
||||
self.db_instance.about = self.about
|
||||
self.db_instance.photo_id = self.photo_id
|
||||
self.db_instance.config = json.dumps(self.local_config)
|
||||
self.db.commit()
|
||||
self.db_instance.update(mxid=self.mxid, username=self.username, title=self.title,
|
||||
about=self.about, photo_id=self.photo_id,
|
||||
config=json.dumps(self.local_config))
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
@@ -1880,8 +1871,7 @@ class Portal:
|
||||
except KeyError:
|
||||
pass
|
||||
if self._db_instance:
|
||||
self.db.delete(self._db_instance)
|
||||
self.db.commit()
|
||||
self._db_instance.delete()
|
||||
self.deleted = True
|
||||
|
||||
@classmethod
|
||||
@@ -1902,7 +1892,7 @@ class Portal:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
portal = DBPortal.query.filter(DBPortal.mxid == mxid).one_or_none()
|
||||
portal = DBPortal.get_by_mxid(mxid)
|
||||
if portal:
|
||||
return cls.from_db(portal)
|
||||
|
||||
@@ -1924,7 +1914,7 @@ class Portal:
|
||||
if portal.username and portal.username.lower() == username.lower():
|
||||
return portal
|
||||
|
||||
dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
|
||||
dbportal = DBPortal.get_by_username(username)
|
||||
if dbportal:
|
||||
return cls.from_db(dbportal)
|
||||
|
||||
@@ -1940,14 +1930,13 @@ class Portal:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
portal = DBPortal.query.get(tgid_full)
|
||||
portal = DBPortal.get_by_tgid(tgid, tg_receiver)
|
||||
if portal:
|
||||
return cls.from_db(portal)
|
||||
|
||||
if peer_type:
|
||||
portal = Portal(tgid, peer_type=peer_type, tg_receiver=tg_receiver)
|
||||
cls.db.add(portal.db_instance)
|
||||
cls.db.commit()
|
||||
portal.db_instance.insert()
|
||||
return portal
|
||||
|
||||
return None
|
||||
@@ -1987,7 +1976,7 @@ class Portal:
|
||||
|
||||
def init(context: Context) -> None:
|
||||
global config
|
||||
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core
|
||||
Portal.az, _, config, Portal.loop, Portal.bot = context.core
|
||||
Portal.max_initial_member_sync = config["bridge.max_initial_member_sync"]
|
||||
Portal.sync_channel_members = config["bridge.sync_channel_members"]
|
||||
Portal.sync_matrix_state = config["bridge.sync_matrix_state"]
|
||||
|
||||
+15
-23
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Dict, List, Match, NewType, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, TYPE_CHECKING
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
@@ -101,20 +101,19 @@ class User(AbstractUser):
|
||||
return self.displayname
|
||||
|
||||
@property
|
||||
def db_contacts(self) -> List[DBContact]:
|
||||
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
|
||||
for puppet in self.contacts]
|
||||
def db_contacts(self) -> Iterable[DBContact]:
|
||||
return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts)
|
||||
|
||||
@db_contacts.setter
|
||||
def db_contacts(self, contacts: List[DBContact]) -> None:
|
||||
def db_contacts(self, contacts: Iterable[DBContact]) -> None:
|
||||
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
|
||||
|
||||
@property
|
||||
def db_portals(self) -> List[DBPortal]:
|
||||
return [portal.db_instance for portal in self.portals.values() if not portal.deleted]
|
||||
def db_portals(self) -> Iterable[DBPortal]:
|
||||
return (portal.db_instance for portal in self.portals.values() if not portal.deleted)
|
||||
|
||||
@db_portals.setter
|
||||
def db_portals(self, portals: List[DBPortal]) -> None:
|
||||
def db_portals(self, portals: Iterable[DBPortal]) -> None:
|
||||
self.portals = {
|
||||
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
|
||||
portal.tg_receiver)
|
||||
@@ -135,13 +134,8 @@ class User(AbstractUser):
|
||||
portals=self.db_portals)
|
||||
|
||||
def save(self) -> None:
|
||||
self.db_instance.tgid = self.tgid
|
||||
self.db_instance.tg_username = self.username
|
||||
self.db_instance.tg_phone = self.phone
|
||||
self.db_instance.contacts = self.db_contacts
|
||||
self.db_instance.saved_contacts = self.saved_contacts
|
||||
self.db_instance.portals = self.db_portals
|
||||
self.db.commit()
|
||||
self.db_instance.update(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone,
|
||||
saved_contacts=self.saved_contacts)
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
@@ -150,8 +144,7 @@ class User(AbstractUser):
|
||||
except KeyError:
|
||||
pass
|
||||
if self._db_instance:
|
||||
self.db.delete(self._db_instance)
|
||||
self.db.commit()
|
||||
self._db_instance.delete()
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_user: DBUser) -> 'User':
|
||||
@@ -358,15 +351,14 @@ class User(AbstractUser):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
user = DBUser.query.get(mxid)
|
||||
user = DBUser.get_by_mxid(mxid)
|
||||
if user:
|
||||
user = cls.from_db(user)
|
||||
return user
|
||||
|
||||
if create:
|
||||
user = cls(mxid)
|
||||
cls.db.add(user.db_instance)
|
||||
cls.db.commit()
|
||||
user.db_instance.insert()
|
||||
return user
|
||||
|
||||
return None
|
||||
@@ -378,7 +370,7 @@ class User(AbstractUser):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
user = DBUser.query.filter(DBUser.tgid == tgid).one_or_none()
|
||||
user = DBUser.get_by_tgid(tgid)
|
||||
if user:
|
||||
user = cls.from_db(user)
|
||||
return user
|
||||
@@ -394,7 +386,7 @@ class User(AbstractUser):
|
||||
if user.username and user.username.lower() == username.lower():
|
||||
return user
|
||||
|
||||
puppet = DBUser.query.filter(DBUser.tg_username == username).one_or_none()
|
||||
puppet = DBUser.get_by_username(username)
|
||||
if puppet:
|
||||
return cls.from_db(puppet)
|
||||
|
||||
@@ -406,5 +398,5 @@ def init(context: 'Context') -> List[Awaitable['User']]:
|
||||
global config
|
||||
config = context.config
|
||||
|
||||
users = [User.from_db(user) for user in DBUser.query.all()]
|
||||
users = [User.from_db(user) for user in DBUser.get_all()]
|
||||
return [user.ensure_started() for user in users if user.tgid]
|
||||
|
||||
Reference in New Issue
Block a user