From 91741864421d7f11bae38cdd71d3650fbfd54e30 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 14 Feb 2019 00:06:45 +0200 Subject: [PATCH] Stop using SQLAlchemy ORM everywhere --- mautrix_telegram/__main__.py | 2 +- mautrix_telegram/bot.py | 16 ++++++---------- mautrix_telegram/db/__init__.py | 5 ++--- mautrix_telegram/db/bot_chat.py | 18 ++++++++++++++++-- mautrix_telegram/puppet.py | 13 ++++--------- 5 files changed, 29 insertions(+), 25 deletions(-) diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py index 10ad60ea..e70f6cfe 100644 --- a/mautrix_telegram/__main__.py +++ b/mautrix_telegram/__main__.py @@ -112,7 +112,7 @@ if config["appservice.provisioning.enabled"]: context.provisioning_api = provisioning_api with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start: - init_db(db_session, db_engine) + init_db(db_engine) init_abstract_user(context) context.bot = init_bot(context) context.mx = MatrixHandler(context) diff --git a/mautrix_telegram/bot.py b/mautrix_telegram/bot.py index c0399797..8718fcad 100644 --- a/mautrix_telegram/bot.py +++ b/mautrix_telegram/bot.py @@ -56,7 +56,7 @@ class Bot(AbstractUser): self.username = None # type: str self.is_relaybot = True # type: bool self.is_bot = True # type: bool - self.chats = {chat.id: chat.type for chat in BotChat.query.all()} # type: Dict[int, str] + self.chats = {chat.id: chat.type for chat in BotChat.all()} # type: Dict[int, str] self.tg_whitelist = [] # type: List[int] self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"] or False) # type: bool @@ -114,23 +114,19 @@ class Bot(AbstractUser): def unregister_portal(self, portal: po.Portal) -> None: self.remove_chat(portal.tgid) - def add_chat(self, chat_id: int, chat_type: str) -> None: + def add_chat(self, chat_id: TelegramID, chat_type: str) -> None: if chat_id not in self.chats: self.chats[chat_id] = chat_type - self.db.add(BotChat(id=TelegramID(chat_id), type=chat_type)) - self.db.commit() + BotChat(id=TelegramID(chat_id), type=chat_type).insert() - def remove_chat(self, chat_id: int) -> None: + def remove_chat(self, chat_id: TelegramID) -> None: try: del self.chats[chat_id] except KeyError: pass - existing_chat = BotChat.query.get(chat_id) - if existing_chat: - self.db.delete(existing_chat) - self.db.commit() + BotChat.delete(chat_id) - async def _can_use_commands(self, chat: TypePeer, tgid: int) -> bool: + async def _can_use_commands(self, chat: TypePeer, tgid: TelegramID) -> bool: if tgid in self.tg_whitelist: return True diff --git a/mautrix_telegram/db/__init__.py b/mautrix_telegram/db/__init__.py index af81d44c..053f7fa4 100644 --- a/mautrix_telegram/db/__init__.py +++ b/mautrix_telegram/db/__init__.py @@ -25,10 +25,9 @@ from .user import User, UserPortal, Contact from .user_profile import UserProfile -def init(db_session, db_engine) -> None: - BotChat.query = db_session.query_property() +def init(db_engine) -> None: for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile, - RoomState): + RoomState, BotChat): table.db = db_engine table.t = table.__table__ table.c = table.t.c diff --git a/mautrix_telegram/db/bot_chat.py b/mautrix_telegram/db/bot_chat.py index f675fed5..7667c0c0 100644 --- a/mautrix_telegram/db/bot_chat.py +++ b/mautrix_telegram/db/bot_chat.py @@ -14,8 +14,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Iterable + from sqlalchemy import Column, Integer, String -from sqlalchemy.orm import Query from ..types import TelegramID from .base import Base @@ -23,7 +24,20 @@ from .base import Base # Fucking Telegram not telling bots what chats they are in 3:< class BotChat(Base): - query = None # type: Query __tablename__ = "bot_chat" id = Column(Integer, primary_key=True) # type: TelegramID type = Column(String, nullable=False) + + @classmethod + def delete(cls, id: TelegramID) -> None: + cls.db.execute(cls.t.delete().where(cls.c.id == id)) + + @classmethod + def all(cls) -> Iterable['BotChat']: + rows = cls.db.execute(cls.t.select()) + for row in rows: + id, type = row + yield cls(id=id, type=type) + + def insert(self) -> None: + self.db.execute(self.t.insert().values(id=self.id, type=self.type)) diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py index e92b3cd9..b22d9210 100644 --- a/mautrix_telegram/puppet.py +++ b/mautrix_telegram/puppet.py @@ -295,15 +295,10 @@ class Puppet: db_instance=db_puppet) def save(self) -> None: - self.db_instance.access_token = self.access_token - self.db_instance.custom_mxid = self.custom_mxid - self.db_instance.username = self.username - self.db_instance.displayname = self.displayname - self.db_instance.displayname_source = self.displayname_source - self.db_instance.photo_id = self.photo_id - self.db_instance.is_bot = self.is_bot - self.db_instance.matrix_registered = self.is_registered - self.db.commit() + self.db_instance.update(access_token=self.access_token, custom_mxid=self.custom_mxid, + username=self.username, displayname=self.displayname, + displayname_source=self.displayname_source, photo_id=self.photo_id, + is_bot=self.is_bot, matrix_registered=self.is_registered) # endregion # region Info updating