From 2a65ccc6748318e5ba6da752140ee0d438ec7051 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 15 Jul 2018 00:07:45 +0300 Subject: [PATCH] Cache RoomStates and UserProfiles --- mautrix_telegram/db.py | 6 +-- mautrix_telegram/sqlstatestore.py | 75 +++++++++++++++++++------------ 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py index 4709fbff..5393acad 100644 --- a/mautrix_telegram/db.py +++ b/mautrix_telegram/db.py @@ -90,9 +90,9 @@ class RoomState(Base): _power_levels_text = Column("power_levels", Text, nullable=True) _power_levels_json = None -# def __init__(self, *args, **kwargs): -# super().__init__(*args, **kwargs) -# self._power_levels_json = None + @property + def has_power_levels(self): + return bool(self._power_levels_text) @property def power_levels(self): diff --git a/mautrix_telegram/sqlstatestore.py b/mautrix_telegram/sqlstatestore.py index 1d9442e2..63b030d2 100644 --- a/mautrix_telegram/sqlstatestore.py +++ b/mautrix_telegram/sqlstatestore.py @@ -14,7 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import json +from typing import Dict, Tuple from mautrix_appservice import StateStore @@ -26,6 +26,8 @@ class SQLStateStore(StateStore): def __init__(self, db): super().__init__() self.db = db + self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile] + self.room_state_cache = {} # type: Dict[str, RoomState] def is_registered(self, user: str) -> bool: puppet = pu.Puppet.get_by_mxid(user) @@ -44,42 +46,60 @@ class SQLStateStore(StateStore): elif event_type == "m.room.member": self.set_member(event["room_id"], event["state_key"], event["content"]) - def get_member(self, room: str, user: str) -> dict: - profile = UserProfile.query.get((room, user)) + def _get_user_profile(self, room_id: str, user_id: str, create: bool = True) -> UserProfile: + key = (room_id, user_id) + try: + return self.profile_cache[key] + except KeyError: + pass + + profile = UserProfile.query.get(key) if profile: - return profile.dict() - return {} + self.profile_cache[key] = profile + elif create: + profile = UserProfile(room_id=room_id, user_id=user_id) + self.db.add(profile) + self.db.commit() + self.profile_cache[key] = profile + return profile + + def get_member(self, room: str, user: str) -> dict: + return self._get_user_profile(room, user).dict() def set_member(self, room: str, user: str, member: dict): - profile = UserProfile(room_id=room, user_id=user, - membership=member.get("membership", "leave"), - displayname=member.get("displayname", None), - avatar_url=member.get("avatar_url", None)) - self.db.merge(profile) + profile = self._get_user_profile(room, user) + profile.membership = member.get("membership", profile.membership or "leave") + profile.displayname = member.get("displayname", profile.displayname) + profile.avatar_url = member.get("avatar_url", profile.avatar_url) self.db.commit() def set_membership(self, room: str, user: str, membership: str): - profile = UserProfile.query.get((room, user)) - if not profile: - profile = UserProfile(room_id=room, user_id=user, membership=membership) - self.db.add(profile) - else: - profile.membership = membership - self.db.commit() + self.set_member(room, user, { + "membership": membership, + }) + + def _get_room_state(self, room_id: str, create: bool = True) -> RoomState: + try: + return self.room_state_cache[room_id] + except KeyError: + pass + + room = RoomState.query.get(room_id) + if room: + self.room_state_cache[room_id] = room + elif create: + room = RoomState(room_id=room_id) + self.room_state_cache[room_id] = room + return room def has_power_levels(self, room: str) -> bool: - room = RoomState.query.get(room) - return room and room._power_levels_text + return self._get_room_state(room).has_power_levels def get_power_levels(self, room: str) -> dict: - return RoomState.query.get(room).power_levels + return self._get_room_state(room).power_levels def set_power_level(self, room: str, user: str, level: int): - room_state = RoomState.query.get(room) - if not room_state: - room_state = RoomState(room) - self.db.add(room_state) - + room_state = self._get_room_state(room) power_levels = room_state.power_levels if not power_levels: power_levels = { @@ -91,9 +111,6 @@ class SQLStateStore(StateStore): self.db.commit() def set_power_levels(self, room: str, content: dict): - state = RoomState.query.get(room) - if not state: - state = RoomState(room_id=room) - self.db.add(state) + state = self._get_room_state(room) state.power_levels = content self.db.commit()