From 5930b2e3bb0b07907d629598295786b60e66c439 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 19 Feb 2018 20:35:34 +0200 Subject: [PATCH] Stop using db.merge() in most places --- mautrix_telegram/portal.py | 43 ++++++++++++++++++++++---------------- mautrix_telegram/puppet.py | 26 ++++++++++++++++------- mautrix_telegram/user.py | 32 ++++++++++++++++++---------- 3 files changed, 64 insertions(+), 37 deletions(-) diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index 7cc88c6d..88ead3c8 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -50,7 +50,7 @@ class Portal: by_tgid = {} def __init__(self, tgid, peer_type, tg_receiver=None, mxid=None, username=None, title=None, - about=None, photo_id=None, save_to_cache=True): + about=None, photo_id=None, db_instance=None): self.mxid = mxid self.tgid = tgid self.tg_receiver = tg_receiver or tgid @@ -59,6 +59,8 @@ class Portal: self.title = title self.about = about self.photo_id = photo_id + self._db_instance = db_instance + self._main_intent = None self._room_create_lock = asyncio.Lock() @@ -66,11 +68,10 @@ class Portal: self._dedup_mxid = {} self._dedup_action = deque() - if save_to_cache: - if tgid: - self.by_tgid[self.tgid_full] = self - if mxid: - self.by_mxid[mxid] = self + if tgid: + self.by_tgid[self.tgid_full] = self + if mxid: + self.by_mxid[mxid] = self @property def tgid_full(self): @@ -1047,13 +1048,16 @@ class Portal: # endregion # region Database conversion - def to_db(self, merge=True): - portal = DBPortal(tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type, + @property + def db_instance(self): + if not self._db_instance: + self._db_instance = self.new_db_instance() + return self._db_instance + + def new_db_instance(self): + return DBPortal(tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type, mxid=self.mxid, username=self.username, title=self.title, about=self.about, photo_id=self.photo_id) - if merge: - return self.db.merge(portal) - return portal def migrate_and_save(self, new_id): existing = DBPortal.query.get(self.tgid_full) @@ -1069,7 +1073,11 @@ class Portal: self.save() def save(self): - self.to_db() + self.db_instance.mxid = self.mxid + self.db_instance.username = self.username + self.db_instance.title = self.title + self.db_instance.about = self.about + self.db_instance.photo_id = self.photo_id self.db.commit() def delete(self): @@ -1078,7 +1086,7 @@ class Portal: del self.by_mxid[self.mxid] except KeyError: pass - self.db.delete(self.to_db()) + self.db.delete(self.db_instance) self.db.commit() @classmethod @@ -1086,7 +1094,8 @@ class Portal: return Portal(tgid=db_portal.tgid, tg_receiver=db_portal.tg_receiver, peer_type=db_portal.peer_type, mxid=db_portal.mxid, username=db_portal.username, title=db_portal.title, - about=db_portal.about, photo_id=db_portal.photo_id) + about=db_portal.about, photo_id=db_portal.photo_id, + db_instance=db_portal) # endregion # region Class instance lookup @@ -1118,11 +1127,9 @@ class Portal: return cls.from_db(portal) if peer_type: - portal = Portal(tgid, peer_type=peer_type, tg_receiver=tg_receiver, - save_to_cache=False) - cls.db.add(portal.to_db(merge=False)) + portal = Portal(tgid, peer_type=peer_type, tg_receiver=tg_receiver) + cls.db.add(portal.db_instance) cls.db.commit() - cls.by_tgid[portal.tgid_full] = portal return portal return None diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py index bfb9c28c..d41c9e82 100644 --- a/mautrix_telegram/puppet.py +++ b/mautrix_telegram/puppet.py @@ -35,13 +35,15 @@ class Puppet: hs_domain = None cache = {} - def __init__(self, id=None, username=None, displayname=None, photo_id=None): + def __init__(self, id=None, username=None, displayname=None, photo_id=None, db_instance=None): self.id = id self.mxid = self.get_mxid_from_id(self.id) self.username = username self.displayname = displayname self.photo_id = photo_id + self._db_instance = db_instance + self.intent = self.az.intent.user(self.mxid) self.cache[id] = self @@ -50,17 +52,25 @@ class Puppet: def tgid(self): return self.id - def to_db(self): - return self.db.merge( - DBPuppet(id=self.id, username=self.username, displayname=self.displayname, - photo_id=self.photo_id)) + @property + def db_instance(self): + if not self._db_instance: + self._db_instance = self.new_db_instance() + return self._db_instance + + def new_db_instance(self): + return DBPuppet(id=self.id, username=self.username, displayname=self.displayname, + photo_id=self.photo_id) @classmethod def from_db(cls, db_puppet): - return Puppet(db_puppet.id, db_puppet.username, db_puppet.displayname, db_puppet.photo_id) + return Puppet(db_puppet.id, db_puppet.username, db_puppet.displayname, db_puppet.photo_id, + db_instance=db_puppet) def save(self): - self.to_db() + self.db_instance.username = self.username + self.db_instance.displayname = self.displayname + self.db_instance.photo_id = self.photo_id self.db.commit() def similarity(self, query): @@ -142,7 +152,7 @@ class Puppet: if create: puppet = cls(id) - cls.db.add(puppet.to_db()) + cls.db.add(puppet.db_instance) cls.db.commit() return puppet diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index 4ecd64f6..dd77893b 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -36,7 +36,7 @@ class User(AbstractUser): by_tgid = {} def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0, - db_portals=None): + db_portals=None, db_instance=None): super().__init__() self.mxid = mxid self.tgid = tgid @@ -46,6 +46,7 @@ class User(AbstractUser): self.db_contacts = db_contacts self.portals = {} self.db_portals = db_portals + self._db_instance = db_instance self.command_status = None @@ -87,7 +88,7 @@ class User(AbstractUser): @property def db_portals(self): - return [portal.to_db(merge=False) for _, portal in self.portals.items()] + return [portal.db_instance for portal in self.portals.values()] @db_portals.setter def db_portals(self, portals): @@ -100,14 +101,23 @@ class User(AbstractUser): # region Database conversion - def to_db(self): - return self.db.merge( - DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username, - contacts=self.db_contacts, saved_contacts=self.saved_contacts, - portals=self.db_portals)) + @property + def db_instance(self): + if not self._db_instance: + self._db_instance = self.new_db_instance() + return self._db_instance + + def new_db_instance(self): + return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username, + contacts=self.db_contacts, saved_contacts=self.saved_contacts, + portals=self.db_portals) def save(self): - self.to_db() + self.db_instance.tgid = self.tgid + self.db_instance.username = self.username + self.db_instance.contacts = self.db_contacts + self.db_instance.saved_contacts = self.saved_contacts + self.db_instance.portals = self.db_portals self.db.commit() def delete(self): @@ -116,13 +126,13 @@ class User(AbstractUser): del self.by_tgid[self.tgid] except KeyError: pass - self.db.delete(self.to_db()) + self.db.delete(self.db_instance) self.db.commit() @classmethod def from_db(cls, db_user): return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts, - db_user.saved_contacts, db_user.portals) + db_user.saved_contacts, db_user.portals, db_instance=db_user) # endregion # region Telegram connection management @@ -277,7 +287,7 @@ class User(AbstractUser): if create: user = cls(mxid) - cls.db.add(user.to_db()) + cls.db.add(user.db_instance) cls.db.commit() return user