Stop using SQLAlchemy ORM everywhere
This commit is contained in:
@@ -112,7 +112,7 @@ if config["appservice.provisioning.enabled"]:
|
||||
context.provisioning_api = provisioning_api
|
||||
|
||||
with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
|
||||
init_db(db_session, db_engine)
|
||||
init_db(db_engine)
|
||||
init_abstract_user(context)
|
||||
context.bot = init_bot(context)
|
||||
context.mx = MatrixHandler(context)
|
||||
|
||||
+6
-10
@@ -56,7 +56,7 @@ class Bot(AbstractUser):
|
||||
self.username = None # type: str
|
||||
self.is_relaybot = True # type: bool
|
||||
self.is_bot = True # type: bool
|
||||
self.chats = {chat.id: chat.type for chat in BotChat.query.all()} # type: Dict[int, str]
|
||||
self.chats = {chat.id: chat.type for chat in BotChat.all()} # type: Dict[int, str]
|
||||
self.tg_whitelist = [] # type: List[int]
|
||||
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
|
||||
or False) # type: bool
|
||||
@@ -114,23 +114,19 @@ class Bot(AbstractUser):
|
||||
def unregister_portal(self, portal: po.Portal) -> None:
|
||||
self.remove_chat(portal.tgid)
|
||||
|
||||
def add_chat(self, chat_id: int, chat_type: str) -> None:
|
||||
def add_chat(self, chat_id: TelegramID, chat_type: str) -> None:
|
||||
if chat_id not in self.chats:
|
||||
self.chats[chat_id] = chat_type
|
||||
self.db.add(BotChat(id=TelegramID(chat_id), type=chat_type))
|
||||
self.db.commit()
|
||||
BotChat(id=TelegramID(chat_id), type=chat_type).insert()
|
||||
|
||||
def remove_chat(self, chat_id: int) -> None:
|
||||
def remove_chat(self, chat_id: TelegramID) -> None:
|
||||
try:
|
||||
del self.chats[chat_id]
|
||||
except KeyError:
|
||||
pass
|
||||
existing_chat = BotChat.query.get(chat_id)
|
||||
if existing_chat:
|
||||
self.db.delete(existing_chat)
|
||||
self.db.commit()
|
||||
BotChat.delete(chat_id)
|
||||
|
||||
async def _can_use_commands(self, chat: TypePeer, tgid: int) -> bool:
|
||||
async def _can_use_commands(self, chat: TypePeer, tgid: TelegramID) -> bool:
|
||||
if tgid in self.tg_whitelist:
|
||||
return True
|
||||
|
||||
|
||||
@@ -25,10 +25,9 @@ from .user import User, UserPortal, Contact
|
||||
from .user_profile import UserProfile
|
||||
|
||||
|
||||
def init(db_session, db_engine) -> None:
|
||||
BotChat.query = db_session.query_property()
|
||||
def init(db_engine) -> None:
|
||||
for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile,
|
||||
RoomState):
|
||||
RoomState, BotChat):
|
||||
table.db = db_engine
|
||||
table.t = table.__table__
|
||||
table.c = table.t.c
|
||||
|
||||
@@ -14,8 +14,9 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import Column, Integer, String
|
||||
from sqlalchemy.orm import Query
|
||||
|
||||
from ..types import TelegramID
|
||||
from .base import Base
|
||||
@@ -23,7 +24,20 @@ from .base import Base
|
||||
|
||||
# Fucking Telegram not telling bots what chats they are in 3:<
|
||||
class BotChat(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "bot_chat"
|
||||
id = Column(Integer, primary_key=True) # type: TelegramID
|
||||
type = Column(String, nullable=False)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, id: TelegramID) -> None:
|
||||
cls.db.execute(cls.t.delete().where(cls.c.id == id))
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable['BotChat']:
|
||||
rows = cls.db.execute(cls.t.select())
|
||||
for row in rows:
|
||||
id, type = row
|
||||
yield cls(id=id, type=type)
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(id=self.id, type=self.type))
|
||||
|
||||
@@ -295,15 +295,10 @@ class Puppet:
|
||||
db_instance=db_puppet)
|
||||
|
||||
def save(self) -> None:
|
||||
self.db_instance.access_token = self.access_token
|
||||
self.db_instance.custom_mxid = self.custom_mxid
|
||||
self.db_instance.username = self.username
|
||||
self.db_instance.displayname = self.displayname
|
||||
self.db_instance.displayname_source = self.displayname_source
|
||||
self.db_instance.photo_id = self.photo_id
|
||||
self.db_instance.is_bot = self.is_bot
|
||||
self.db_instance.matrix_registered = self.is_registered
|
||||
self.db.commit()
|
||||
self.db_instance.update(access_token=self.access_token, custom_mxid=self.custom_mxid,
|
||||
username=self.username, displayname=self.displayname,
|
||||
displayname_source=self.displayname_source, photo_id=self.photo_id,
|
||||
is_bot=self.is_bot, matrix_registered=self.is_registered)
|
||||
|
||||
# endregion
|
||||
# region Info updating
|
||||
|
||||
Reference in New Issue
Block a user