From 4a2bb3d7fcd65dbcee70a7dcaffa3e6902e23c6b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Dec 2018 23:32:22 +0200 Subject: [PATCH] Switch state store to SQLAlchemy core --- mautrix_telegram/db.py | 95 +++++++++++++++++++++++-------- mautrix_telegram/sqlstatestore.py | 18 +++--- 2 files changed, 79 insertions(+), 34 deletions(-) diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py index 5da04161..4fad2d2e 100644 --- a/mautrix_telegram/db.py +++ b/mautrix_telegram/db.py @@ -156,31 +156,49 @@ class User(Base): class RoomState(Base): - query = None # type: Query + db = None # type: Engine + t = None # type: Table + c = None # type: ImmutableColumnCollection __tablename__ = "mx_room_state" room_id = Column(String, primary_key=True) # type: MatrixRoomID - _power_levels_text = Column("power_levels", Text, nullable=True) - _power_levels_json = {} # type: Dict + power_levels = Column("power_levels", Text, nullable=True) # type: Optional[Dict] + + @property + def _power_levels_text(self) -> Optional[str]: + return json.dumps(self.power_levels) if self.power_levels else None @property def has_power_levels(self) -> bool: - return bool(self._power_levels_text) + return bool(self.power_levels) - @property - def power_levels(self) -> Dict: - if not self._power_levels_json and self._power_levels_text: - self._power_levels_json = json.loads(self._power_levels_text) - return self._power_levels_json + @classmethod + def get(cls, room_id: MatrixRoomID) -> Optional['RoomState']: + rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id)) + try: + room_id, power_levels_text = next(rows) + return RoomState(room_id=room_id, power_levels=(json.loads(power_levels_text) + if power_levels_text else None)) + except StopIteration: + return None - @power_levels.setter - def power_levels(self, val: Dict) -> None: - self._power_levels_json = val - self._power_levels_text = json.dumps(val) + def update(self) -> None: + self.db.execute(self.t.update() + .where(self.c.room_id == self.room_id) + .values(power_levels=self._power_levels_text)) + + def delete(self) -> None: + self.db.execute(self.t.delete().where(self.c.room_id == self.room_id)) + + def insert(self) -> None: + self.db.execute(self.t.insert().values(room_id=self.room_id, + power_levels=self._power_levels_text)) class UserProfile(Base): - query = None # type: Query + db = None # type: Engine + t = None # type: Table + c = None # type: ImmutableColumnCollection __tablename__ = "mx_user_profile" room_id = Column(String, primary_key=True) # type: MatrixRoomID @@ -196,6 +214,37 @@ class UserProfile(Base): "avatar_url": self.avatar_url, } + @classmethod + def get(cls, room_id: MatrixRoomID, user_id: MatrixUserID) -> Optional['UserProfile']: + rows = cls.db.execute( + cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id))) + try: + room_id, user_id, membership, displayname, avatar_url = next(rows) + return UserProfile(room_id=room_id, user_id=user_id, membership=membership, + displayname=displayname, avatar_url=avatar_url) + except StopIteration: + return None + + @classmethod + def delete_all(cls, room_id: MatrixRoomID) -> None: + cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id)) + + def update(self) -> None: + self.db.execute(self.t.update() + .where(and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)) + .values(membership=self.membership, displayname=self.displayname, + avatar_url=self.avatar_url)) + + def delete(self) -> None: + self.db.execute(self.t.delete().where(and_(self.c.room_id == self.room_id, + self.c.user_id == self.user_id))) + + def insert(self) -> None: + self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id, + membership=self.membership, + displayname=self.displayname, + avatar_url=self.avatar_url)) + class Contact(Base): query = None # type: Query @@ -245,14 +294,10 @@ class TelegramFile(Base): def init(db_session, db_engine) -> None: - Portal.query = db_session.query_property() - Message.db = db_engine - Message.t = Message.__table__ - Message.c = Message.t.c - UserPortal.query = db_session.query_property() - User.query = db_session.query_property() - Puppet.query = db_session.query_property() - BotChat.query = db_session.query_property() - TelegramFile.query = db_session.query_property() - UserProfile.query = db_session.query_property() - RoomState.query = db_session.query_property() + query = db_session.query_property() + for table in (Portal, Message, UserPortal, User, Puppet, BotChat, TelegramFile, UserProfile, + RoomState): + table.query = query + table.db = db_engine + table.t = table.__table__ + table.c = table.t.c diff --git a/mautrix_telegram/sqlstatestore.py b/mautrix_telegram/sqlstatestore.py index ee5b3609..f5ced3bd 100644 --- a/mautrix_telegram/sqlstatestore.py +++ b/mautrix_telegram/sqlstatestore.py @@ -59,13 +59,12 @@ class SQLStateStore(StateStore): except KeyError: pass - profile = UserProfile.query.get(key) + profile = UserProfile.get(*key) if profile: self.profile_cache[key] = profile elif create: - profile = UserProfile(room_id=room_id, user_id=user_id) - self.db.add(profile) - self.db.commit() + profile = UserProfile(room_id=room_id, user_id=user_id, membership="leave") + profile.insert() self.profile_cache[key] = profile return profile @@ -77,7 +76,7 @@ class SQLStateStore(StateStore): profile.membership = member.get("membership", profile.membership or "leave") profile.displayname = member.get("displayname", profile.displayname) profile.avatar_url = member.get("avatar_url", profile.avatar_url) - self.db.commit() + profile.update() def set_membership(self, room: MatrixRoomID, user: MatrixUserID, membership: str) -> None: self.set_member(room, user, { @@ -90,16 +89,17 @@ class SQLStateStore(StateStore): except KeyError: pass - room = RoomState.query.get(room_id) + room = RoomState.get(room_id) if room: self.room_state_cache[room_id] = room elif create: room = RoomState(room_id=room_id) + room.insert() self.room_state_cache[room_id] = room return room def has_power_levels(self, room: MatrixRoomID) -> bool: - return self._get_room_state(room).has_power_levels + return bool(self._get_room_state(room).power_levels) def get_power_levels(self, room: MatrixRoomID) -> Dict: return self._get_room_state(room).power_levels @@ -114,9 +114,9 @@ class SQLStateStore(StateStore): } power_levels[room]["users"][user] = level room_state.power_levels = power_levels - self.db.commit() + room_state.update() def set_power_levels(self, room: MatrixRoomID, content: Dict) -> None: state = self._get_room_state(room) state.power_levels = content - self.db.commit() + state.update()