Switch state store to SQLAlchemy core

This commit is contained in:
Tulir Asokan
2018-12-19 23:32:22 +02:00
parent 65e0ebdb37
commit 4a2bb3d7fc
2 changed files with 79 additions and 34 deletions
+70 -25
View File
@@ -156,31 +156,49 @@ class User(Base):
class RoomState(Base):
query = None # type: Query
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True) # type: MatrixRoomID
_power_levels_text = Column("power_levels", Text, nullable=True)
_power_levels_json = {} # type: Dict
power_levels = Column("power_levels", Text, nullable=True) # type: Optional[Dict]
@property
def _power_levels_text(self) -> Optional[str]:
return json.dumps(self.power_levels) if self.power_levels else None
@property
def has_power_levels(self) -> bool:
return bool(self._power_levels_text)
return bool(self.power_levels)
@property
def power_levels(self) -> Dict:
if not self._power_levels_json and self._power_levels_text:
self._power_levels_json = json.loads(self._power_levels_text)
return self._power_levels_json
@classmethod
def get(cls, room_id: MatrixRoomID) -> Optional['RoomState']:
rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id))
try:
room_id, power_levels_text = next(rows)
return RoomState(room_id=room_id, power_levels=(json.loads(power_levels_text)
if power_levels_text else None))
except StopIteration:
return None
@power_levels.setter
def power_levels(self, val: Dict) -> None:
self._power_levels_json = val
self._power_levels_text = json.dumps(val)
def update(self) -> None:
self.db.execute(self.t.update()
.where(self.c.room_id == self.room_id)
.values(power_levels=self._power_levels_text))
def delete(self) -> None:
self.db.execute(self.t.delete().where(self.c.room_id == self.room_id))
def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id,
power_levels=self._power_levels_text))
class UserProfile(Base):
query = None # type: Query
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True) # type: MatrixRoomID
@@ -196,6 +214,37 @@ class UserProfile(Base):
"avatar_url": self.avatar_url,
}
@classmethod
def get(cls, room_id: MatrixRoomID, user_id: MatrixUserID) -> Optional['UserProfile']:
rows = cls.db.execute(
cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id)))
try:
room_id, user_id, membership, displayname, avatar_url = next(rows)
return UserProfile(room_id=room_id, user_id=user_id, membership=membership,
displayname=displayname, avatar_url=avatar_url)
except StopIteration:
return None
@classmethod
def delete_all(cls, room_id: MatrixRoomID) -> None:
cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id))
def update(self) -> None:
self.db.execute(self.t.update()
.where(and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id))
.values(membership=self.membership, displayname=self.displayname,
avatar_url=self.avatar_url))
def delete(self) -> None:
self.db.execute(self.t.delete().where(and_(self.c.room_id == self.room_id,
self.c.user_id == self.user_id)))
def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
membership=self.membership,
displayname=self.displayname,
avatar_url=self.avatar_url))
class Contact(Base):
query = None # type: Query
@@ -245,14 +294,10 @@ class TelegramFile(Base):
def init(db_session, db_engine) -> None:
Portal.query = db_session.query_property()
Message.db = db_engine
Message.t = Message.__table__
Message.c = Message.t.c
UserPortal.query = db_session.query_property()
User.query = db_session.query_property()
Puppet.query = db_session.query_property()
BotChat.query = db_session.query_property()
TelegramFile.query = db_session.query_property()
UserProfile.query = db_session.query_property()
RoomState.query = db_session.query_property()
query = db_session.query_property()
for table in (Portal, Message, UserPortal, User, Puppet, BotChat, TelegramFile, UserProfile,
RoomState):
table.query = query
table.db = db_engine
table.t = table.__table__
table.c = table.t.c