Start moving portals and users to SQLAlchemy Core
This commit is contained in:
+15
-23
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Dict, List, Match, NewType, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, TYPE_CHECKING
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
@@ -101,20 +101,19 @@ class User(AbstractUser):
|
||||
return self.displayname
|
||||
|
||||
@property
|
||||
def db_contacts(self) -> List[DBContact]:
|
||||
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
|
||||
for puppet in self.contacts]
|
||||
def db_contacts(self) -> Iterable[DBContact]:
|
||||
return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts)
|
||||
|
||||
@db_contacts.setter
|
||||
def db_contacts(self, contacts: List[DBContact]) -> None:
|
||||
def db_contacts(self, contacts: Iterable[DBContact]) -> None:
|
||||
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
|
||||
|
||||
@property
|
||||
def db_portals(self) -> List[DBPortal]:
|
||||
return [portal.db_instance for portal in self.portals.values() if not portal.deleted]
|
||||
def db_portals(self) -> Iterable[DBPortal]:
|
||||
return (portal.db_instance for portal in self.portals.values() if not portal.deleted)
|
||||
|
||||
@db_portals.setter
|
||||
def db_portals(self, portals: List[DBPortal]) -> None:
|
||||
def db_portals(self, portals: Iterable[DBPortal]) -> None:
|
||||
self.portals = {
|
||||
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
|
||||
portal.tg_receiver)
|
||||
@@ -135,13 +134,8 @@ class User(AbstractUser):
|
||||
portals=self.db_portals)
|
||||
|
||||
def save(self) -> None:
|
||||
self.db_instance.tgid = self.tgid
|
||||
self.db_instance.tg_username = self.username
|
||||
self.db_instance.tg_phone = self.phone
|
||||
self.db_instance.contacts = self.db_contacts
|
||||
self.db_instance.saved_contacts = self.saved_contacts
|
||||
self.db_instance.portals = self.db_portals
|
||||
self.db.commit()
|
||||
self.db_instance.update(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone,
|
||||
saved_contacts=self.saved_contacts)
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
@@ -150,8 +144,7 @@ class User(AbstractUser):
|
||||
except KeyError:
|
||||
pass
|
||||
if self._db_instance:
|
||||
self.db.delete(self._db_instance)
|
||||
self.db.commit()
|
||||
self._db_instance.delete()
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_user: DBUser) -> 'User':
|
||||
@@ -358,15 +351,14 @@ class User(AbstractUser):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
user = DBUser.query.get(mxid)
|
||||
user = DBUser.get_by_mxid(mxid)
|
||||
if user:
|
||||
user = cls.from_db(user)
|
||||
return user
|
||||
|
||||
if create:
|
||||
user = cls(mxid)
|
||||
cls.db.add(user.db_instance)
|
||||
cls.db.commit()
|
||||
user.db_instance.insert()
|
||||
return user
|
||||
|
||||
return None
|
||||
@@ -378,7 +370,7 @@ class User(AbstractUser):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
user = DBUser.query.filter(DBUser.tgid == tgid).one_or_none()
|
||||
user = DBUser.get_by_tgid(tgid)
|
||||
if user:
|
||||
user = cls.from_db(user)
|
||||
return user
|
||||
@@ -394,7 +386,7 @@ class User(AbstractUser):
|
||||
if user.username and user.username.lower() == username.lower():
|
||||
return user
|
||||
|
||||
puppet = DBUser.query.filter(DBUser.tg_username == username).one_or_none()
|
||||
puppet = DBUser.get_by_username(username)
|
||||
if puppet:
|
||||
return cls.from_db(puppet)
|
||||
|
||||
@@ -406,5 +398,5 @@ def init(context: 'Context') -> List[Awaitable['User']]:
|
||||
global config
|
||||
config = context.config
|
||||
|
||||
users = [User.from_db(user) for user in DBUser.query.all()]
|
||||
users = [User.from_db(user) for user in DBUser.get_all()]
|
||||
return [user.ensure_started() for user in users if user.tgid]
|
||||
|
||||
Reference in New Issue
Block a user