Stop using SQLAlchemy ORM everywhere

This commit is contained in:
Tulir Asokan
2019-02-14 00:06:45 +02:00
parent 8ef82abe9d
commit 9174186442
5 changed files with 29 additions and 25 deletions
+1 -1
View File
@@ -112,7 +112,7 @@ if config["appservice.provisioning.enabled"]:
context.provisioning_api = provisioning_api
with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
init_db(db_session, db_engine)
init_db(db_engine)
init_abstract_user(context)
context.bot = init_bot(context)
context.mx = MatrixHandler(context)
+6 -10
View File
@@ -56,7 +56,7 @@ class Bot(AbstractUser):
self.username = None # type: str
self.is_relaybot = True # type: bool
self.is_bot = True # type: bool
self.chats = {chat.id: chat.type for chat in BotChat.query.all()} # type: Dict[int, str]
self.chats = {chat.id: chat.type for chat in BotChat.all()} # type: Dict[int, str]
self.tg_whitelist = [] # type: List[int]
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
or False) # type: bool
@@ -114,23 +114,19 @@ class Bot(AbstractUser):
def unregister_portal(self, portal: po.Portal) -> None:
self.remove_chat(portal.tgid)
def add_chat(self, chat_id: int, chat_type: str) -> None:
def add_chat(self, chat_id: TelegramID, chat_type: str) -> None:
if chat_id not in self.chats:
self.chats[chat_id] = chat_type
self.db.add(BotChat(id=TelegramID(chat_id), type=chat_type))
self.db.commit()
BotChat(id=TelegramID(chat_id), type=chat_type).insert()
def remove_chat(self, chat_id: int) -> None:
def remove_chat(self, chat_id: TelegramID) -> None:
try:
del self.chats[chat_id]
except KeyError:
pass
existing_chat = BotChat.query.get(chat_id)
if existing_chat:
self.db.delete(existing_chat)
self.db.commit()
BotChat.delete(chat_id)
async def _can_use_commands(self, chat: TypePeer, tgid: int) -> bool:
async def _can_use_commands(self, chat: TypePeer, tgid: TelegramID) -> bool:
if tgid in self.tg_whitelist:
return True
+2 -3
View File
@@ -25,10 +25,9 @@ from .user import User, UserPortal, Contact
from .user_profile import UserProfile
def init(db_session, db_engine) -> None:
BotChat.query = db_session.query_property()
def init(db_engine) -> None:
for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile,
RoomState):
RoomState, BotChat):
table.db = db_engine
table.t = table.__table__
table.c = table.t.c
+16 -2
View File
@@ -14,8 +14,9 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Iterable
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import Query
from ..types import TelegramID
from .base import Base
@@ -23,7 +24,20 @@ from .base import Base
# Fucking Telegram not telling bots what chats they are in 3:<
class BotChat(Base):
query = None # type: Query
__tablename__ = "bot_chat"
id = Column(Integer, primary_key=True) # type: TelegramID
type = Column(String, nullable=False)
@classmethod
def delete(cls, id: TelegramID) -> None:
cls.db.execute(cls.t.delete().where(cls.c.id == id))
@classmethod
def all(cls) -> Iterable['BotChat']:
rows = cls.db.execute(cls.t.select())
for row in rows:
id, type = row
yield cls(id=id, type=type)
def insert(self) -> None:
self.db.execute(self.t.insert().values(id=self.id, type=self.type))
+4 -9
View File
@@ -295,15 +295,10 @@ class Puppet:
db_instance=db_puppet)
def save(self) -> None:
self.db_instance.access_token = self.access_token
self.db_instance.custom_mxid = self.custom_mxid
self.db_instance.username = self.username
self.db_instance.displayname = self.displayname
self.db_instance.displayname_source = self.displayname_source
self.db_instance.photo_id = self.photo_id
self.db_instance.is_bot = self.is_bot
self.db_instance.matrix_registered = self.is_registered
self.db.commit()
self.db_instance.update(access_token=self.access_token, custom_mxid=self.custom_mxid,
username=self.username, displayname=self.displayname,
displayname_source=self.displayname_source, photo_id=self.photo_id,
is_bot=self.is_bot, matrix_registered=self.is_registered)
# endregion
# region Info updating