Start moving portals and users to SQLAlchemy Core

This commit is contained in:
Tulir Asokan
2019-02-12 01:19:12 +02:00
parent c028e1befc
commit 53489e7356
5 changed files with 216 additions and 127 deletions
+15 -23
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Match, NewType, Optional, Tuple, TYPE_CHECKING
from typing import Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, TYPE_CHECKING
import logging
import asyncio
import re
@@ -101,20 +101,19 @@ class User(AbstractUser):
return self.displayname
@property
def db_contacts(self) -> List[DBContact]:
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
for puppet in self.contacts]
def db_contacts(self) -> Iterable[DBContact]:
return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts)
@db_contacts.setter
def db_contacts(self, contacts: List[DBContact]) -> None:
def db_contacts(self, contacts: Iterable[DBContact]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
@property
def db_portals(self) -> List[DBPortal]:
return [portal.db_instance for portal in self.portals.values() if not portal.deleted]
def db_portals(self) -> Iterable[DBPortal]:
return (portal.db_instance for portal in self.portals.values() if not portal.deleted)
@db_portals.setter
def db_portals(self, portals: List[DBPortal]) -> None:
def db_portals(self, portals: Iterable[DBPortal]) -> None:
self.portals = {
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
portal.tg_receiver)
@@ -135,13 +134,8 @@ class User(AbstractUser):
portals=self.db_portals)
def save(self) -> None:
self.db_instance.tgid = self.tgid
self.db_instance.tg_username = self.username
self.db_instance.tg_phone = self.phone
self.db_instance.contacts = self.db_contacts
self.db_instance.saved_contacts = self.saved_contacts
self.db_instance.portals = self.db_portals
self.db.commit()
self.db_instance.update(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone,
saved_contacts=self.saved_contacts)
def delete(self) -> None:
try:
@@ -150,8 +144,7 @@ class User(AbstractUser):
except KeyError:
pass
if self._db_instance:
self.db.delete(self._db_instance)
self.db.commit()
self._db_instance.delete()
@classmethod
def from_db(cls, db_user: DBUser) -> 'User':
@@ -358,15 +351,14 @@ class User(AbstractUser):
except KeyError:
pass
user = DBUser.query.get(mxid)
user = DBUser.get_by_mxid(mxid)
if user:
user = cls.from_db(user)
return user
if create:
user = cls(mxid)
cls.db.add(user.db_instance)
cls.db.commit()
user.db_instance.insert()
return user
return None
@@ -378,7 +370,7 @@ class User(AbstractUser):
except KeyError:
pass
user = DBUser.query.filter(DBUser.tgid == tgid).one_or_none()
user = DBUser.get_by_tgid(tgid)
if user:
user = cls.from_db(user)
return user
@@ -394,7 +386,7 @@ class User(AbstractUser):
if user.username and user.username.lower() == username.lower():
return user
puppet = DBUser.query.filter(DBUser.tg_username == username).one_or_none()
puppet = DBUser.get_by_username(username)
if puppet:
return cls.from_db(puppet)
@@ -406,5 +398,5 @@ def init(context: 'Context') -> List[Awaitable['User']]:
global config
config = context.config
users = [User.from_db(user) for user in DBUser.query.all()]
users = [User.from_db(user) for user in DBUser.get_all()]
return [user.ensure_started() for user in users if user.tgid]