Move Matrix state cache to main database. Fixes #159
This commit is contained in:
@@ -0,0 +1,126 @@
|
||||
"""Move state store to main database
|
||||
|
||||
Revision ID: 6ca3d74d51e4
|
||||
Revises: 2228d49c383f
|
||||
Create Date: 2018-06-26 21:31:26.911307
|
||||
|
||||
"""
|
||||
from alembic import context, op
|
||||
import sqlalchemy.orm as orm
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
import re
|
||||
|
||||
from mautrix_telegram.config import Config
|
||||
from mautrix_telegram.base import Base
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6ca3d74d51e4"
|
||||
down_revision = "2228d49c383f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
class RoomState(Base):
|
||||
query = None
|
||||
__tablename__ = "mx_room_state"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
room_id = sa.Column(sa.String, primary_key=True)
|
||||
power_levels = sa.Column("power_levels", sa.Text, nullable=True)
|
||||
|
||||
|
||||
class UserProfile(Base):
|
||||
query = None
|
||||
__tablename__ = "mx_user_profile"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
room_id = sa.Column(sa.String, primary_key=True)
|
||||
user_id = sa.Column(sa.String, primary_key=True)
|
||||
membership = sa.Column(sa.String, nullable=False, default="leave")
|
||||
displayname = sa.Column(sa.String, nullable=True)
|
||||
avatar_url = sa.Column(sa.String, nullable=True)
|
||||
|
||||
|
||||
class Puppet(Base):
|
||||
query = None
|
||||
__tablename__ = "puppet"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
displayname = sa.Column(sa.String, nullable=True)
|
||||
displayname_source = sa.Column(sa.Integer, nullable=True)
|
||||
username = sa.Column(sa.String, nullable=True)
|
||||
photo_id = sa.Column(sa.String, nullable=True)
|
||||
is_bot = sa.Column(sa.Boolean, nullable=True)
|
||||
matrix_registered = sa.Column(sa.Boolean, nullable=False, default=False)
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column("puppet", sa.Column("matrix_registered", sa.Boolean(), nullable=False,
|
||||
server_default=sa.sql.expression.false()))
|
||||
op.create_table("mx_room_state",
|
||||
sa.Column("room_id", sa.String(), nullable=False),
|
||||
sa.Column("power_levels", sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("room_id"))
|
||||
op.create_table("mx_user_profile",
|
||||
sa.Column("room_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=False),
|
||||
sa.Column("membership", sa.String(), nullable=False,
|
||||
default="leave"),
|
||||
sa.Column("displayname", sa.String(), nullable=True),
|
||||
sa.Column("avatar_url", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("room_id", "user_id"))
|
||||
|
||||
conn = op.get_bind()
|
||||
session = orm.sessionmaker(bind=conn)
|
||||
session = orm.scoping.scoped_session(session)
|
||||
Puppet.query = session.query_property()
|
||||
|
||||
with open("mx-state.json") as file:
|
||||
data = json.load(file)
|
||||
if not data:
|
||||
return
|
||||
registrations = data.get("registrations", [])
|
||||
|
||||
mxtg_config_path = context.get_x_argument(as_dictionary=True).get("config", "config.yaml")
|
||||
mxtg_config = Config(mxtg_config_path, None, None)
|
||||
mxtg_config.load()
|
||||
|
||||
username_template = mxtg_config.get("bridge.username_template", "telegram_{userid}")
|
||||
hs_domain = mxtg_config["homeserver.domain"]
|
||||
localpart = username_template.format(userid="(.+)")
|
||||
mxid_regex = re.compile(f"@{localpart}:{hs_domain}")
|
||||
for user in registrations:
|
||||
match = mxid_regex.match(user)
|
||||
if not match:
|
||||
continue
|
||||
|
||||
puppet = Puppet.query.get(match.group(1))
|
||||
if not puppet:
|
||||
continue
|
||||
|
||||
puppet.matrix_registered = True
|
||||
session.merge(puppet)
|
||||
session.commit()
|
||||
|
||||
user_profiles = [UserProfile(room_id=room, user_id=user,
|
||||
membership=member.get("membership", "leave"),
|
||||
displayname=member.get("displayname", None),
|
||||
avatar_url=member.get("avatar_url", None))
|
||||
for room, members in data.get("members", {}).items()
|
||||
for user, member in members.items()]
|
||||
session.add_all(user_profiles)
|
||||
session.commit()
|
||||
|
||||
room_state = [RoomState(room_id=room, power_levels=json.dumps(levels))
|
||||
for room, levels in data.get("power_levels", {}).items()]
|
||||
session.add_all(room_state)
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("mx_user_profile")
|
||||
op.drop_table("mx_room_state")
|
||||
with op.batch_alter_table("puppet") as batch_op:
|
||||
batch_op.drop_column("matrix_registered")
|
||||
@@ -38,6 +38,7 @@ from .puppet import init as init_puppet
|
||||
from .formatter import init as init_formatter
|
||||
from .public import PublicBridgeWebsite
|
||||
from .context import Context
|
||||
from .sqlstatestore import SQLStateStore
|
||||
|
||||
log = logging.getLogger("mau")
|
||||
time_formatter = logging.Formatter("[%(asctime)s] [%(levelname)s@%(name)s] %(message)s")
|
||||
@@ -87,10 +88,11 @@ telethon_session_container = AlchemySessionContainer(engine=db_engine, session=d
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
state_store = SQLStateStore(db_session)
|
||||
appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
|
||||
config["appservice.as_token"], config["appservice.hs_token"],
|
||||
config["appservice.bot_username"], log="mau.as", loop=loop,
|
||||
verify_ssl=config["homeserver.verify_ssl"])
|
||||
verify_ssl=config["homeserver.verify_ssl"], state_store=state_store)
|
||||
|
||||
context = Context(appserv, db_session, config, loop, None, None, telethon_session_container)
|
||||
|
||||
|
||||
+48
-1
@@ -15,8 +15,10 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
||||
BigInteger, String, Boolean)
|
||||
BigInteger, String, Boolean, Text)
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.orm import relationship
|
||||
import json
|
||||
|
||||
from .base import Base
|
||||
|
||||
@@ -80,6 +82,48 @@ class User(Base):
|
||||
portals = relationship("Portal", secondary="user_portal")
|
||||
|
||||
|
||||
class RoomState(Base):
|
||||
query = None
|
||||
__tablename__ = "mx_room_state"
|
||||
|
||||
room_id = Column(String, primary_key=True)
|
||||
_power_levels_text = Column("power_levels", Text, nullable=True)
|
||||
_power_levels_json = None
|
||||
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# super().__init__(*args, **kwargs)
|
||||
# self._power_levels_json = None
|
||||
|
||||
@property
|
||||
def power_levels(self):
|
||||
if not self._power_levels_json and self._power_levels_text:
|
||||
self._power_levels_json = json.loads(self._power_levels_text)
|
||||
return self._power_levels_json or {}
|
||||
|
||||
@power_levels.setter
|
||||
def power_levels(self, val):
|
||||
self._power_levels_json = val
|
||||
self._power_levels_text = json.dumps(val)
|
||||
|
||||
|
||||
class UserProfile(Base):
|
||||
query = None
|
||||
__tablename__ = "mx_user_profile"
|
||||
|
||||
room_id = Column(String, primary_key=True)
|
||||
user_id = Column(String, primary_key=True)
|
||||
membership = Column(String, nullable=False, default="leave")
|
||||
displayname = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"membership": self.membership,
|
||||
"displayname": self.displayname,
|
||||
"avatar_url": self.avatar_url,
|
||||
}
|
||||
|
||||
|
||||
class Contact(Base):
|
||||
query = None
|
||||
__tablename__ = "contact"
|
||||
@@ -98,6 +142,7 @@ class Puppet(Base):
|
||||
username = Column(String, nullable=True)
|
||||
photo_id = Column(String, nullable=True)
|
||||
is_bot = Column(Boolean, nullable=True)
|
||||
matrix_registered = Column(Boolean, nullable=False, server_default=expression.false())
|
||||
|
||||
|
||||
# Fucking Telegram not telling bots what chats they are in 3:<
|
||||
@@ -132,3 +177,5 @@ def init(db_session):
|
||||
Puppet.query = db_session.query_property()
|
||||
BotChat.query = db_session.query_property()
|
||||
TelegramFile.query = db_session.query_property()
|
||||
UserProfile.query = db_session.query_property()
|
||||
RoomState.query = db_session.query_property()
|
||||
|
||||
@@ -36,7 +36,7 @@ class Puppet:
|
||||
cache = {}
|
||||
|
||||
def __init__(self, id=None, username=None, displayname=None, displayname_source=None,
|
||||
photo_id=None, is_bot=None, db_instance=None):
|
||||
photo_id=None, is_bot=None, is_registered=False, db_instance=None):
|
||||
self.id = id
|
||||
self.mxid = self.get_mxid_from_id(self.id)
|
||||
|
||||
@@ -45,6 +45,7 @@ class Puppet:
|
||||
self.displayname_source = displayname_source
|
||||
self.photo_id = photo_id
|
||||
self.is_bot = is_bot
|
||||
self.is_registered = is_registered
|
||||
self._db_instance = db_instance
|
||||
|
||||
self.intent = self.az.intent.user(self.mxid)
|
||||
@@ -67,13 +68,13 @@ class Puppet:
|
||||
def new_db_instance(self):
|
||||
return DBPuppet(id=self.id, username=self.username, displayname=self.displayname,
|
||||
displayname_source=self.displayname_source, photo_id=self.photo_id,
|
||||
is_bot=self.is_bot)
|
||||
is_bot=self.is_bot, matrix_registered=self.is_registered)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_puppet):
|
||||
return Puppet(db_puppet.id, db_puppet.username, db_puppet.displayname,
|
||||
db_puppet.displayname_source, db_puppet.photo_id, db_puppet.is_bot,
|
||||
db_instance=db_puppet)
|
||||
db_puppet.matrix_registered, db_instance=db_puppet)
|
||||
|
||||
def save(self):
|
||||
self.db_instance.username = self.username
|
||||
@@ -81,6 +82,7 @@ class Puppet:
|
||||
self.db_instance.displayname_source = self.displayname_source
|
||||
self.db_instance.photo_id = self.photo_id
|
||||
self.db_instance.is_bot = self.is_bot
|
||||
self.db_instance.matrix_registered = self.is_registered
|
||||
self.db.commit()
|
||||
|
||||
def similarity(self, query):
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
# -*- coding: future_fstrings -*-
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
import json
|
||||
|
||||
from mautrix_appservice import StateStore
|
||||
|
||||
from . import puppet as pu
|
||||
from .db import RoomState, UserProfile
|
||||
|
||||
|
||||
class SQLStateStore(StateStore):
|
||||
def __init__(self, db):
|
||||
super().__init__()
|
||||
self.db = db
|
||||
|
||||
def is_registered(self, user: str) -> bool:
|
||||
puppet = pu.Puppet.get_by_mxid(user)
|
||||
return puppet.is_registered if puppet else False
|
||||
|
||||
def registered(self, user: str):
|
||||
puppet = pu.Puppet.get_by_mxid(user)
|
||||
if puppet:
|
||||
puppet.is_registered = True
|
||||
puppet.save()
|
||||
|
||||
def update_state(self, event: dict):
|
||||
event_type = event["type"]
|
||||
if event_type == "m.room.power_levels":
|
||||
self.set_power_levels(event["room_id"], event["content"])
|
||||
elif event_type == "m.room.member":
|
||||
self.set_member(event["room_id"], event["state_key"], event["content"])
|
||||
|
||||
def get_member(self, room: str, user: str) -> dict:
|
||||
profile = UserProfile.query.get((room, user))
|
||||
if profile:
|
||||
return profile.dict()
|
||||
return {}
|
||||
|
||||
def set_member(self, room: str, user: str, member: dict):
|
||||
profile = UserProfile(room_id=room, user_id=user,
|
||||
membership=member.get("membership", "leave"),
|
||||
displayname=member.get("displayname", None),
|
||||
avatar_url=member.get("avatar_url", None))
|
||||
self.db.merge(profile)
|
||||
self.db.commit()
|
||||
|
||||
def set_membership(self, room: str, user: str, membership: str):
|
||||
profile = UserProfile.query.get((room, user))
|
||||
if not profile:
|
||||
profile = UserProfile(room_id=room, user_id=user, membership=membership)
|
||||
self.db.add(profile)
|
||||
else:
|
||||
profile.membership = membership
|
||||
self.db.commit()
|
||||
|
||||
def has_power_levels(self, room: str) -> bool:
|
||||
room = RoomState.query.get(room)
|
||||
return room and room._power_levels_text
|
||||
|
||||
def get_power_levels(self, room: str) -> dict:
|
||||
return RoomState.query.get(room).power_levels
|
||||
|
||||
def set_power_level(self, room: str, user: str, level: int):
|
||||
room_state = RoomState.query.get(room)
|
||||
if not room_state:
|
||||
room_state = RoomState(room)
|
||||
self.db.add(room_state)
|
||||
|
||||
power_levels = room_state.power_levels
|
||||
if not power_levels:
|
||||
power_levels = {
|
||||
"users": {},
|
||||
"events": {},
|
||||
}
|
||||
power_levels[room]["users"][user] = level
|
||||
room_state.power_levels = power_levels
|
||||
self.db.commit()
|
||||
|
||||
def set_power_levels(self, room: str, content: dict):
|
||||
state = RoomState.query.get(room)
|
||||
if not state:
|
||||
state = RoomState(room_id=room)
|
||||
self.db.add(state)
|
||||
state.power_levels = content
|
||||
self.db.commit()
|
||||
Reference in New Issue
Block a user