", help="the path to save the generated registration to")
-args = parser.parse_args()
-config = Config(args.config, args.registration, args.base_config, os.environ)
-config.load()
-config.update()
+class TelegramBridge(Bridge):
+ name = "mautrix-telegram"
+ command = "python -m mautrix-telegram"
+ description = "A Matrix-Telegram puppeting bridge."
+ real_user_content_key = "net.maunium.telegram.puppet"
+ version = __version__
+ config_class = Config
+ matrix_class = MatrixHandler
+ state_store_class = SQLStateStore
-if args.generate_registration:
- config.generate_registration()
- config.save()
- print(f"Registration generated and saved to {config.registration_path}")
- sys.exit(0)
+ config: Config
+ session_container: AlchemySessionContainer
+ bot: Bot
-logging.config.dictConfig(copy.deepcopy(config["logging"]))
-log: logging.Logger = logging.getLogger("mau.init")
-log.debug(f"Initializing mautrix-telegram {__version__}")
+ def prepare_db(self) -> None:
+ super().prepare_db()
+ init_db(self.db)
+ self.session_container = AlchemySessionContainer(
+ engine=self.db, table_base=Base, session=False,
+ table_prefix="telethon_", manage_tables=False)
-db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db")
-Base.metadata.bind = db_engine
+ def prepare_bridge(self) -> None:
+ self.bot = init_bot(self.config)
+ context = Context(self.az, self.config, self.loop, self.session_container, self.bot)
-session_container = AlchemySessionContainer(engine=db_engine, table_base=Base, session=False,
- table_prefix="telethon_", manage_tables=False)
-session_container.core_mode = True
+ if self.config["appservice.public.enabled"]:
+ public_website = PublicBridgeWebsite(self.loop)
+ self.az.app.add_subapp(self.config["appservice.public.prefix"], public_website.app)
+ context.public_website = public_website
-try:
- import uvloop
+ if self.config["appservice.provisioning.enabled"]:
+ provisioning_api = ProvisioningAPI(context)
+ self.az.app.add_subapp(self.config["appservice.provisioning.prefix"],
+ provisioning_api.app)
+ context.provisioning_api = provisioning_api
- asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
- log.debug("Using uvloop for asyncio")
-except ImportError:
- pass
+ self.matrix = context.mx = MatrixHandler(context)
-loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
+ if self.config["metrics.enabled"]:
+ if prometheus:
+ prometheus.start_http_server(self.config["metrics.listen_port"])
+ else:
+ self.log.warn("Metrics are enabled in the config, "
+ "but prometheus_client is not installed.")
-state_store = SQLStateStore()
-mebibyte = 1024 ** 2
-appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
- config["appservice.as_token"], config["appservice.hs_token"],
- config["appservice.bot_username"], log="mau.as", loop=loop,
- verify_ssl=config["homeserver.verify_ssl"], state_store=state_store,
- real_user_content_key="net.maunium.telegram.puppet",
- aiohttp_params={
- "client_max_size": config["appservice.max_body_size"] * mebibyte
- })
-bot = init_bot(config)
-context = Context(appserv, config, loop, session_container, bot)
+ init_abstract_user(context)
+ init_formatter(context)
+ init_portal(context)
+ puppet_startup = init_puppet(context)
+ user_startup = init_user(context)
+ self.startup_actions = chain(puppet_startup, user_startup,
+ [self.bot.start] if self.bot else [])
-if config["appservice.public.enabled"]:
- public_website = PublicBridgeWebsite(loop)
- appserv.app.add_subapp(config["appservice.public.prefix"] or "/public", public_website.app)
- context.public_website = public_website
-
-if config["appservice.provisioning.enabled"]:
- provisioning_api = ProvisioningAPI(context)
- appserv.app.add_subapp(config["appservice.provisioning.prefix"] or "/_matrix/provisioning",
- provisioning_api.app)
- context.provisioning_api = provisioning_api
-
-context.mx = MatrixHandler(context)
-
-if config["metrics.enabled"]:
- if prometheus:
- prometheus.start_http_server(config["metrics.listen_port"])
- else:
- log.warn("Metrics are enabled in the config, but prometheus_client is not installed.")
-
-with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
- start_ts = time()
- init_db(db_engine)
- init_abstract_user(context)
- init_formatter(context)
- init_portal(context)
- startup_actions: List[Awaitable[Any]] = (init_puppet(context) +
- init_user(context) +
- [start, context.mx.init_as_bot()])
-
- if context.bot:
- startup_actions.append(context.bot.start())
-
- signal.signal(signal.SIGINT, signal.default_int_handler)
- signal.signal(signal.SIGTERM, signal.default_int_handler)
-
- end_ts = time()
- try:
- log.debug(f"Initialization complete in {round(end_ts - start_ts, 2)} seconds,"
- " running startup actions")
- start_ts = time()
- loop.run_until_complete(asyncio.gather(*startup_actions, loop=loop))
- end_ts = time()
- log.debug(f"Startup actions complete in {round(end_ts - start_ts, 2)} seconds,"
- " now running forever")
- loop.run_forever()
- except KeyboardInterrupt:
- log.debug("Interrupt received, stopping clients")
- loop.run_until_complete(
- asyncio.gather(*[user.stop() for user in User.by_tgid.values()], loop=loop))
- log.debug("Clients stopped, shutting down")
- sys.exit(0)
- except Exception as e:
- log.exception("Unexpected error")
- sys.exit(1)
+ async def stop(self) -> None:
+ self.shutdown_actions = [user.stop() for user in User.by_tgid.values()]
+ await super().stop()
diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py
index cb60c07a..2a7f849a 100644
--- a/mautrix_telegram/abstract_user.py
+++ b/mautrix_telegram/abstract_user.py
@@ -65,7 +65,7 @@ class AbstractUser(ABC):
loop: asyncio.AbstractEventLoop = None
log: logging.Logger
az: AppService
- bot: 'Bot'
+ relaybot: Optional['Bot']
ignore_incoming_bot_events: bool = True
client: Optional[MautrixTelegramClient]
@@ -76,7 +76,6 @@ class AbstractUser(ABC):
is_bot: bool
is_relaybot: bool
- relaybot: Optional['Bot']
puppet_whitelisted: bool
whitelisted: bool
@@ -404,7 +403,7 @@ class AbstractUser(ABC):
portal.tgid_log)
return
- if self.ignore_incoming_bot_events and self.bot and sender.id == self.bot.tgid:
+ if self.ignore_incoming_bot_events and self.relaybot and sender.id == self.relaybot.tgid:
self.log.debug(f"Ignoring relaybot-sent message %s to %s", update, portal.tgid_log)
return
diff --git a/mautrix_telegram/config.py b/mautrix_telegram/config.py
index f712fffa..276e4e6a 100644
--- a/mautrix_telegram/config.py
+++ b/mautrix_telegram/config.py
@@ -13,157 +13,33 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Any, Dict, Optional, Tuple
-from ruamel.yaml import YAML
+from typing import Any, Dict, List, NamedTuple
from ruamel.yaml.comments import CommentedMap
import random
import string
+import os
-yaml: YAML = YAML()
-yaml.indent(4)
+from mautrix.types import UserID
+from mautrix.client import Client
+from mautrix.bridge.config import BaseBridgeConfig, ConfigUpdateHelper
+
+Permissions = NamedTuple("Permissions", relaybot=bool, user=bool, puppeting=bool,
+ matrix_puppeting=bool, admin=bool, level=str)
-class DictWithRecursion:
- _data: CommentedMap
-
- def __init__(self, data: Optional[CommentedMap] = None) -> None:
- self._data = data or CommentedMap()
-
- @staticmethod
- def _parse_key(key: str) -> Tuple[str, Optional[str]]:
- if '.' not in key:
- return key, None
- key, next_key = key.split('.', 1)
- if len(key) > 0 and key[0] == "[":
- end_index = next_key.index("]")
- key = key[1:] + "." + next_key[:end_index]
- next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
- return key, next_key
-
- def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
- key, next_key = self._parse_key(key)
- if next_key is not None:
- next_data = data.get(key, CommentedMap())
- return self._recursive_get(next_data, next_key, default_value)
- return data.get(key, default_value)
-
- def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any:
- if allow_recursion and '.' in key:
- return self._recursive_get(self._data, key, default_value)
- return self._data.get(key, default_value)
-
- def __getitem__(self, key: str) -> Any:
- return self.get(key, None)
-
- def __contains__(self, key: str) -> bool:
- return self[key] is not None
-
- def _recursive_set(self, data: CommentedMap, key: str, value: Any) -> None:
- key, next_key = self._parse_key(key)
- if next_key is not None:
- if key not in data:
- data[key] = CommentedMap()
- next_data = data.get(key, CommentedMap())
- return self._recursive_set(next_data, next_key, value)
- data[key] = value
-
- def set(self, key: str, value: Any, allow_recursion: bool = True) -> None:
- if allow_recursion and '.' in key:
- self._recursive_set(self._data, key, value)
- return
- self._data[key] = value
-
- def __setitem__(self, key: str, value: Any) -> None:
- self.set(key, value)
-
- def _recursive_del(self, data: CommentedMap, key: str) -> None:
- key, next_key = self._parse_key(key)
- if next_key is not None:
- if key not in data:
- return
- next_data = data[key]
- return self._recursive_del(next_data, next_key)
- try:
- del data[key]
- del data.ca.items[key]
- except KeyError:
- pass
-
- def delete(self, key: str, allow_recursion: bool = True) -> None:
- if allow_recursion and '.' in key:
- self._recursive_del(self._data, key)
- return
- try:
- del self._data[key]
- del self._data.ca.items[key]
- except KeyError:
- pass
-
- def __delitem__(self, key: str) -> None:
- self.delete(key)
-
-
-class Config(DictWithRecursion):
- path: str
- registration_path: str
- base_path: str
- _registration: Optional[Dict[str, Any]]
- _overrides: Dict[str, Any]
-
- def __init__(self, path: str, registration_path: str, base_path: str,
- overrides: Dict[str, Any] = None) -> None:
- super().__init__()
- self.path = path
- self.registration_path = registration_path
- self.base_path = base_path
- self._registration = None
- self._overrides = overrides or {}
-
+class Config(BaseBridgeConfig):
def __getitem__(self, key: str) -> Any:
try:
- return self._overrides[f"MAUTRIX_TELEGRAM_{key.replace('.', '_').upper()}"]
+ return os.environ[f"MAUTRIX_TELEGRAM_{key.replace('.', '_').upper()}"]
except KeyError:
return super().__getitem__(key)
- def load(self) -> None:
- with open(self.path, 'r') as stream:
- self._data = yaml.load(stream)
-
- def load_base(self) -> Optional[DictWithRecursion]:
- try:
- with open(self.base_path, 'r') as stream:
- return DictWithRecursion(yaml.load(stream))
- except OSError:
- pass
- return None
-
- def save(self) -> None:
- with open(self.path, 'w') as stream:
- yaml.dump(self._data, stream)
- if self._registration and self.registration_path:
- with open(self.registration_path, 'w') as stream:
- yaml.dump(self._registration, stream)
-
@staticmethod
def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
- def update(self) -> None:
- base = self.load_base()
- if not base:
- return
-
- def copy(from_path, to_path=None) -> None:
- if from_path in self:
- base[to_path or from_path] = self[from_path]
-
- def copy_dict(from_path, to_path=None, override_existing_map=True) -> None:
- if from_path in self:
- to_path = to_path or from_path
- if override_existing_map or to_path not in base:
- base[to_path] = CommentedMap()
- for key, value in self[from_path].items():
- base[to_path][key] = value
+ def do_update(self, helper: ConfigUpdateHelper) -> None:
+ copy, copy_dict, base = helper
copy("homeserver.address")
copy("homeserver.domain")
@@ -309,58 +185,43 @@ class Config(DictWithRecursion):
else:
copy("logging")
- self._data = base._data
- self.save()
-
- def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, bool, bool, bool]:
+ def _get_permissions(self, key: str) -> Permissions:
level = self["bridge.permissions"].get(key, "")
admin = level == "admin"
matrix_puppeting = level == "full" or admin
puppeting = level == "puppeting" or matrix_puppeting
user = level == "user" or puppeting
relaybot = level == "relaybot" or user
- return relaybot, user, puppeting, matrix_puppeting, admin, level
+ return Permissions(relaybot, user, puppeting, matrix_puppeting, admin, level)
- def get_permissions(self, mxid: str) -> Tuple[bool, bool, bool, bool, bool, bool]:
- permissions = self["bridge.permissions"] or {}
+ def get_permissions(self, mxid: UserID) -> Permissions:
+ permissions = self["bridge.permissions"]
if mxid in permissions:
return self._get_permissions(mxid)
- homeserver = mxid[mxid.index(":") + 1:]
+ _, homeserver = Client.parse_user_id(mxid)
if homeserver in permissions:
return self._get_permissions(homeserver)
return self._get_permissions("*")
- def generate_registration(self) -> None:
+ @property
+ def namespaces(self) -> Dict[str, List[Dict[str, Any]]]:
homeserver = self["homeserver.domain"]
- username_format = self.get("bridge.username_template",
- "telegram_{userid}").format(userid=".+")
- alias_format = self.get("bridge.alias_template",
- "telegram_{groupname}").format(groupname=".+")
+ username_format = self["bridge.username_template"].format(userid=".+")
+ alias_format = self["bridge.alias_template"].format(groupname=".+")
+ group_id = ({"group_id": self["appservice.community_id"]}
+ if self["appservice.community_id"] else {})
- self.set("appservice.as_token", self._new_token())
- self.set("appservice.hs_token", self._new_token())
-
- self._registration = {
- "id": self["appservice.id"] or "telegram",
- "as_token": self["appservice.as_token"],
- "hs_token": self["appservice.hs_token"],
- "namespaces": {
- "users": [{
- "exclusive": True,
- "regex": f"@{username_format}:{homeserver}"
- }],
- "aliases": [{
- "exclusive": True,
- "regex": f"#{alias_format}:{homeserver}"
- }]
- },
- "url": self["appservice.address"],
- "sender_localpart": self["appservice.bot_username"],
- "rate_limited": False
+ return {
+ "users": [{
+ "exclusive": True,
+ "regex": f"@{username_format}:{homeserver}",
+ **group_id,
+ }],
+ "aliases": [{
+ "exclusive": True,
+ "regex": f"#{alias_format}:{homeserver}",
+ }]
}
- if self["appservice.community_id"]:
- self._registration["namespaces"]["users"][0]["group_id"] = self[
- "appservice.community_id"]
diff --git a/mautrix_telegram/context.py b/mautrix_telegram/context.py
index 1735d322..e1a1de2f 100644
--- a/mautrix_telegram/context.py
+++ b/mautrix_telegram/context.py
@@ -15,11 +15,12 @@
# along with this program. If not, see .
from typing import Optional, Tuple, TYPE_CHECKING
-if TYPE_CHECKING:
- import asyncio
+import asyncio
- from alchemysession import AlchemySessionContainer
- from mautrix_appservice import AppService
+from alchemysession import AlchemySessionContainer
+from mautrix_appservice import AppService
+
+if TYPE_CHECKING:
from .web import PublicBridgeWebsite, ProvisioningAPI
from .config import Config
@@ -28,17 +29,17 @@ if TYPE_CHECKING:
class Context:
- az: 'AppService'
+ az: AppService
config: 'Config'
- loop: 'asyncio.AbstractEventLoop'
+ loop: asyncio.AbstractEventLoop
bot: Optional['Bot']
mx: Optional['MatrixHandler']
- session_container: 'AlchemySessionContainer'
+ session_container: AlchemySessionContainer
public_website: Optional['PublicBridgeWebsite']
provisioning_api: Optional['ProvisioningAPI']
- def __init__(self, az: 'AppService', config: 'Config', loop: 'asyncio.AbstractEventLoop',
- session_container: 'AlchemySessionContainer', bot: Optional['Bot']) -> None:
+ def __init__(self, az: AppService, config: 'Config', loop: asyncio.AbstractEventLoop,
+ session_container: AlchemySessionContainer, bot: Optional['Bot']) -> None:
self.az = az
self.config = config
self.loop = loop
@@ -49,5 +50,5 @@ class Context:
self.provisioning_api = None
@property
- def core(self) -> Tuple['AppService', 'Config', 'asyncio.AbstractEventLoop', Optional['Bot']]:
+ def core(self) -> Tuple[AppService, 'Config', asyncio.AbstractEventLoop, Optional['Bot']]:
return self.az, self.config, self.loop, self.bot
diff --git a/mautrix_telegram/db/room_state.py b/mautrix_telegram/db/room_state.py
deleted file mode 100644
index e61f0bd2..00000000
--- a/mautrix_telegram/db/room_state.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# mautrix-telegram - A Matrix-Telegram puppeting bridge
-# Copyright (C) 2019 Tulir Asokan
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-from sqlalchemy import Column, String, Text
-from typing import Dict, Optional
-import json
-
-from ..types import MatrixRoomID
-from .base import Base
-
-
-class RoomState(Base):
- __tablename__ = "mx_room_state"
-
- room_id = Column(String, primary_key=True) # type: MatrixRoomID
- power_levels = Column("power_levels", Text, nullable=True) # type: Optional[Dict]
-
- @property
- def _power_levels_text(self) -> Optional[str]:
- return json.dumps(self.power_levels) if self.power_levels else None
-
- @property
- def has_power_levels(self) -> bool:
- return bool(self.power_levels)
-
- @classmethod
- def get(cls, room_id: MatrixRoomID) -> Optional['RoomState']:
- rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id))
- try:
- room_id, power_levels_text = next(rows)
- return cls(room_id=room_id, power_levels=(json.loads(power_levels_text)
- if power_levels_text else None))
- except StopIteration:
- return None
-
- def update(self) -> None:
- with self.db.begin() as conn:
- conn.execute(self.t.update()
- .where(self.c.room_id == self.room_id)
- .values(power_levels=self._power_levels_text))
-
- @property
- def _edit_identity(self):
- return self.c.room_id == self.room_id
-
- def insert(self) -> None:
- with self.db.begin() as conn:
- conn.execute(self.t.insert().values(room_id=self.room_id,
- power_levels=self._power_levels_text))
diff --git a/mautrix_telegram/db/user_profile.py b/mautrix_telegram/db/user_profile.py
deleted file mode 100644
index c99bb186..00000000
--- a/mautrix_telegram/db/user_profile.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# mautrix-telegram - A Matrix-Telegram puppeting bridge
-# Copyright (C) 2019 Tulir Asokan
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-from sqlalchemy import Column, String, and_
-from typing import Dict, Optional
-
-from ..types import MatrixUserID, MatrixRoomID
-from .base import Base
-
-
-class UserProfile(Base):
- __tablename__ = "mx_user_profile"
-
- room_id = Column(String, primary_key=True) # type: MatrixRoomID
- user_id = Column(String, primary_key=True) # type: MatrixUserID
- membership = Column(String, nullable=False, default="leave")
- displayname = Column(String, nullable=True)
- avatar_url = Column(String, nullable=True)
-
- def dict(self) -> Dict[str, str]:
- return {
- "membership": self.membership,
- "displayname": self.displayname,
- "avatar_url": self.avatar_url,
- }
-
- @classmethod
- def get(cls, room_id: MatrixRoomID, user_id: MatrixUserID) -> Optional['UserProfile']:
- rows = cls.db.execute(
- cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id)))
- try:
- room_id, user_id, membership, displayname, avatar_url = next(rows)
- return cls(room_id=room_id, user_id=user_id, membership=membership,
- displayname=displayname, avatar_url=avatar_url)
- except StopIteration:
- return None
-
- @classmethod
- def delete_all(cls, room_id: MatrixRoomID) -> None:
- with cls.db.begin() as conn:
- conn.execute(cls.t.delete().where(cls.c.room_id == room_id))
-
- def update(self) -> None:
- super().update(membership=self.membership, displayname=self.displayname,
- avatar_url=self.avatar_url)
-
- @property
- def _edit_identity(self):
- return and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)
-
- def insert(self) -> None:
- with self.db.begin() as conn:
- conn.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
- membership=self.membership,
- displayname=self.displayname,
- avatar_url=self.avatar_url))
diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py
index 661da6cc..68319f3b 100644
--- a/mautrix_telegram/matrix.py
+++ b/mautrix_telegram/matrix.py
@@ -13,66 +13,59 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, List, Match, Optional, Set, Tuple, TYPE_CHECKING
-import logging
-import asyncio
+from typing import Dict, Match, Optional, Set, Tuple, Union, Iterable, TYPE_CHECKING
import time
import re
-from mautrix_appservice import MatrixRequestError, IntentError
+from mautrix.bridge import BaseMatrixHandler
+from mautrix.types import (Event, EventType, RoomID, UserID, EventID, ReceiptEvent, ReceiptType,
+ ReceiptEventContent, PresenceEvent, PresenceState, TypingEvent,
+ MessageEvent, StateEvent, RedactionEvent, RoomNameStateEventContent,
+ RoomAvatarStateEventContent, RoomTopicStateEventContent,
+ MemberStateEventContent)
+from mautrix.errors import MatrixError
-from .types import MatrixEvent, MatrixEventID, MatrixRoomID, MatrixUserID
from . import user as u, portal as po, puppet as pu, commands as com
if TYPE_CHECKING:
from .context import Context
- from .config import Config
from .bot import Bot
- from mautrix_appservice import AppService
try:
from prometheus_client import Histogram
- EVENT_TIME = Histogram("matrix_event", "Time spent processing Matrix events",
- ["event_type"])
+ EVENT_TIME = Histogram("matrix_event", "Time spent processing Matrix events", ["event_type"])
except ImportError:
Histogram = None
EVENT_TIME = None
-class MatrixHandler:
- log: logging.Logger = logging.getLogger("mau.mx")
- az: 'AppService'
- config: 'Config'
+RoomMetaStateEventContent = Union[RoomNameStateEventContent, RoomAvatarStateEventContent,
+ RoomTopicStateEventContent]
+
+
+class MatrixHandler(BaseMatrixHandler):
bot: 'Bot'
commands: 'com.CommandProcessor'
- previously_typing: Dict[MatrixRoomID, Set[MatrixUserID]]
+ previously_typing: Dict[RoomID, Set[UserID]]
def __init__(self, context: 'Context') -> None:
- self.az, self.config, _, self.tgbot = context.core
- self.commands = com.CommandProcessor(context)
+ super(MatrixHandler, self).__init__(context.az, context.config, loop=context.loop,
+ command_processor=com.CommandProcessor(context))
+ self.bot = context.bot
self.previously_typing = {}
- self.az.matrix_event_handler(self.handle_event)
+ async def get_user(self, user_id: UserID) -> 'u.User':
+ return await u.User.get_by_mxid(user_id).ensure_started()
- async def init_as_bot(self) -> None:
- displayname = self.config["appservice.bot_displayname"]
- if displayname:
- try:
- await self.az.intent.set_display_name(
- displayname if displayname != "remove" else "")
- except asyncio.TimeoutError:
- self.log.exception("TimeoutError when trying to set displayname")
+ async def get_portal(self, room_id: RoomID) -> 'po.Portal':
+ return po.Portal.get_by_mxid(room_id)
- avatar = self.config["appservice.bot_avatar"]
- if avatar:
- try:
- await self.az.intent.set_avatar(avatar if avatar != "remove" else "")
- except asyncio.TimeoutError:
- self.log.exception("TimeoutError when trying to set avatar")
+ async def get_puppet(self, user_id: UserID) -> 'pu.Puppet':
+ return pu.Puppet.get_by_mxid(user_id)
- async def handle_puppet_invite(self, room_id: MatrixRoomID, puppet: pu.Puppet, inviter: u.User
- ) -> None:
+ async def handle_puppet_invite(self, room_id: RoomID, puppet: pu.Puppet, inviter: u.User,
+ event_id: EventID) -> None:
intent = puppet.default_mxid_intent
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in():
@@ -90,7 +83,7 @@ class MatrixHandler:
return
try:
members = await self.az.intent.get_room_members(room_id)
- except MatrixRequestError:
+ except MatrixError:
members = []
if self.az.bot_mxid not in members:
if len(members) > 1:
@@ -113,7 +106,7 @@ class MatrixHandler:
""))
await intent.leave_room(room_id)
return
- except MatrixRequestError:
+ except MatrixError:
pass
portal.mxid = room_id
portal.save()
@@ -124,67 +117,25 @@ class MatrixHandler:
await intent.send_notice(room_id, "This puppet will remain inactive until a "
"Telegram chat is created for this room.")
- async def accept_bot_invite(self, room_id: MatrixRoomID, inviter: u.User) -> None:
- tries = 0
- while tries < 5:
- try:
- await self.az.intent.join_room(room_id)
- break
- except (IntentError, MatrixRequestError):
- tries += 1
- wait_for_seconds = (tries + 1) * 10
- if tries < 5:
- self.log.exception(f"Failed to join room {room_id} with bridge bot, "
- f"retrying in {wait_for_seconds} seconds...")
- await asyncio.sleep(wait_for_seconds)
- else:
- self.log.exception("Failed to join room {room}, giving up.")
- return
-
- if not inviter.whitelisted:
- await self.az.intent.send_notice(
- room_id,
- text="You are not whitelisted to use this bridge.\n\n"
- "If you are the owner of this bridge, see the "
- "`bridge.permissions` section in your config file.",
- html="You are not whitelisted to use this bridge.
"
- "If you are the owner of this bridge, see the "
- "bridge.permissions section in your config file.
")
- await self.az.intent.leave_room(room_id)
-
+ async def send_welcome_message(self, room_id: RoomID, inviter: 'u.User', event_id: EventID
+ ) -> None:
try:
is_management = len(await self.az.intent.get_room_members(room_id)) == 2
- except MatrixRequestError:
- is_management = False
+ except MatrixError:
+ # The AS bot is not in the room.
+ return
cmd_prefix = self.commands.command_prefix
text = html = "Hello, I'm a Telegram bridge bot. "
if is_management and inviter.puppet_whitelisted and not await inviter.is_logged_in():
text += f"Use `{cmd_prefix} help` for help or `{cmd_prefix} login` to log in."
html += (f"Use {cmd_prefix} help for help"
f" or {cmd_prefix} login to log in.")
- pass
else:
text += f"Use `{cmd_prefix} help` for help."
html += f"Use {cmd_prefix} help for help."
await self.az.intent.send_notice(room_id, text=text, html=html)
- async def handle_invite(self, room_id: MatrixRoomID, user_id: MatrixUserID,
- inviter_mxid: MatrixUserID) -> None:
- self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
- inviter = u.User.get_by_mxid(inviter_mxid)
- if inviter is None:
- self.log.exception("Failed to find user with Matrix ID {inviter_mxid}")
- await inviter.ensure_started()
- if user_id == self.az.bot_mxid:
- return await self.accept_bot_invite(room_id, inviter)
- elif not inviter.whitelisted:
- return
-
- puppet = pu.Puppet.get_by_mxid(user_id)
- if puppet:
- await self.handle_puppet_invite(room_id, puppet, inviter)
- return
-
+ async def handle_invite(self, room_id: RoomID, user_id: UserID, inviter: 'u.User') -> None:
user = u.User.get_by_mxid(user_id, create=False)
if not user:
return
@@ -194,10 +145,8 @@ class MatrixHandler:
await portal.invite_telegram(inviter, user)
return
- # The rest can probably be ignored
-
- async def handle_join(self, room_id: MatrixRoomID, user_id: MatrixUserID,
- event_id: MatrixEventID) -> None:
+ async def handle_join(self, room_id: RoomID, user_id: UserID,
+ event_id: EventID) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started()
portal = po.Portal.get_by_mxid(room_id)
@@ -218,11 +167,11 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot:
await portal.join_matrix(user, event_id)
- async def handle_part(self, room_id: MatrixRoomID, user_id: MatrixUserID,
- sender_mxid: MatrixUserID, event_id: MatrixEventID) -> None:
+ async def handle_raw_leave(self, room_id: RoomID, user_id: UserID, sender_id: UserID,
+ reason: str, event_id: EventID) -> None:
self.log.debug(f"{user_id} left {room_id}")
- sender = u.User.get_by_mxid(sender_mxid, create=False)
+ sender = u.User.get_by_mxid(sender_id, create=False)
if not sender:
return
await sender.ensure_started()
@@ -233,98 +182,67 @@ class MatrixHandler:
puppet = pu.Puppet.get_by_mxid(user_id)
if puppet:
- if sender:
- await portal.kick_matrix(puppet, sender)
+ await portal.kick_matrix(puppet, sender)
return
user = u.User.get_by_mxid(user_id, create=False)
if not user:
return
await user.ensure_started()
- if await user.is_logged_in() or portal.has_bot:
- await portal.leave_matrix(user, sender, event_id)
-
- def is_command(self, message: Dict) -> Tuple[bool, str]:
- text = message.get("body", "")
- prefix = self.config["bridge.command_prefix"]
- is_command = text.startswith(prefix)
- if is_command:
- text = text[len(prefix) + 1:].lstrip()
- return is_command, text
-
- async def handle_message(self, room: MatrixRoomID, sender_id: MatrixUserID, message: Dict,
- event_id: MatrixEventID) -> None:
- is_command, text = self.is_command(message)
- sender = await u.User.get_by_mxid(sender_id).ensure_started()
- if not sender.relaybot_whitelisted:
- self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
- " User is not whitelisted.")
- return
- self.log.debug(f"Received Matrix event \"{message}\" from {sender} in {room}")
-
- portal = po.Portal.get_by_mxid(room)
- if not is_command and portal and (await sender.is_logged_in() or portal.has_bot):
- await portal.handle_matrix_message(sender, message, event_id)
- return
-
- if not sender.whitelisted or message.get("msgtype", "m.unknown") != "m.text":
- return
-
- try:
- is_management = len(await self.az.intent.get_room_members(room)) == 2
- except MatrixRequestError:
- # The AS bot is not in the room.
- return
-
- if is_command or is_management:
- try:
- command, arguments = text.split(" ", 1)
- args = arguments.split(" ")
- except ValueError:
- # Not enough values to unpack, i.e. no arguments
- command = text
- args = []
- await self.commands.handle(room, event_id, sender, command, args, is_management,
- is_portal=portal is not None)
+ if sender_id != user_id:
+ await portal.kick_matrix(user, sender)
+ else:
+ await portal.leave_matrix(user, event_id)
@staticmethod
- async def handle_redaction(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
- event_id: MatrixEventID) -> None:
- sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
+ async def allow_message(user: 'u.User') -> bool:
+ return user.relaybot_whitelisted
+
+ @staticmethod
+ async def allow_command(user: 'u.User') -> bool:
+ return user.whitelisted
+
+ @staticmethod
+ async def allow_bridging_message(user: 'u.User', portal: 'po.Portal') -> bool:
+ return await user.is_logged_in() or portal.has_bot
+
+ @staticmethod
+ async def handle_redaction(evt: RedactionEvent) -> None:
+ sender = await u.User.get_by_mxid(evt.sender).ensure_started()
if not sender.relaybot_whitelisted:
return
- portal = po.Portal.get_by_mxid(room_id)
+ portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
return
- await portal.handle_matrix_deletion(sender, event_id)
+ await portal.handle_matrix_deletion(sender, evt.redacts)
@staticmethod
- async def handle_power_levels(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
- new: Dict, old: Dict) -> None:
- portal = po.Portal.get_by_mxid(room_id)
- sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
+ async def handle_power_levels(evt: StateEvent) -> None:
+ portal = po.Portal.get_by_mxid(evt.event_id)
+ sender = await u.User.get_by_mxid(evt.sender).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
- await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
+ await portal.handle_matrix_power_levels(sender, evt.content.users,
+ evt.unsigned.prev_content.users)
@staticmethod
- async def handle_room_meta(evt_type: str, room_id: MatrixRoomID, sender_mxid: MatrixUserID,
- content: dict) -> None:
+ async def handle_room_meta(evt_type: EventType, room_id: RoomID, sender_mxid: UserID,
+ content: RoomMetaStateEventContent) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
handler, content_key = {
- "m.room.name": (portal.handle_matrix_title, "name"),
- "m.room.topic": (portal.handle_matrix_about, "topic"),
- "m.room.avatar": (portal.handle_matrix_avatar, "url"),
+ EventType.ROOM_NAME: (portal.handle_matrix_title, "name"),
+ EventType.ROOM_TOPIC: (portal.handle_matrix_about, "topic"),
+ EventType.ROOM_AVATAR: (portal.handle_matrix_avatar, "url"),
}[evt_type]
if content_key not in content:
return
await handler(sender, content[content_key])
@staticmethod
- async def handle_room_pin(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
+ async def handle_room_pin(room_id: RoomID, sender_mxid: UserID,
new_events: Set[str], old_events: Set[str]) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
@@ -332,55 +250,61 @@ class MatrixHandler:
events = new_events - old_events
if len(events) > 0:
# New event pinned, set that as pinned in Telegram.
- await portal.handle_matrix_pin(sender, MatrixEventID(events.pop()))
+ await portal.handle_matrix_pin(sender, EventID(events.pop()))
elif len(new_events) == 0:
# All pinned events removed, remove pinned event in Telegram.
await portal.handle_matrix_pin(sender, None)
@staticmethod
- async def handle_room_upgrade(room_id: MatrixRoomID, new_room_id: MatrixRoomID) -> None:
+ async def handle_room_upgrade(room_id: RoomID, new_room_id: RoomID) -> None:
portal = po.Portal.get_by_mxid(room_id)
if portal:
await portal.handle_matrix_upgrade(new_room_id)
@staticmethod
- async def handle_name_change(room_id: MatrixRoomID, user_id: MatrixUserID, displayname: str,
- prev_displayname: str, event_id: MatrixEventID) -> None:
+ async def handle_member_info_change(room_id: RoomID, user_id: UserID,
+ profile: MemberStateEventContent,
+ prev_profile: MemberStateEventContent,
+ event_id: EventID) -> None:
+ if profile.displayname == prev_profile.displayname:
+ return
+
portal = po.Portal.get_by_mxid(room_id)
if not portal or not portal.has_bot:
return
user = await u.User.get_by_mxid(user_id).ensure_started()
if await user.needs_relaybot(portal):
- await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
+ await portal.name_change_matrix(user, profile.displayname, prev_profile.displayname,
+ event_id)
@staticmethod
- def parse_read_receipts(content: Dict) -> Dict[MatrixUserID, MatrixEventID]:
- return {user_id: event_id
+ def parse_read_receipts(content: ReceiptEventContent) -> Iterable[Tuple[UserID, EventID]]:
+ return ((user_id, event_id)
for event_id, receipts in content.items()
- for user_id in receipts.get("m.read", {})}
+ for user_id in receipts.get(ReceiptType.READ, {}))
@staticmethod
- async def handle_read_receipts(room_id: MatrixRoomID,
- receipts: Dict[MatrixUserID, MatrixEventID]) -> None:
+ async def handle_read_receipts(room_id: RoomID, receipts: Iterable[Tuple[UserID, EventID]]
+ ) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
- for user_id, event_id in receipts.items():
+ for user_id, event_id in receipts:
user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in():
continue
await portal.mark_read(user, event_id)
@staticmethod
- async def handle_presence(user_id: MatrixUserID, presence: str) -> None:
+ async def handle_presence(user_id: UserID, presence: PresenceState) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in():
return
- await user.set_presence(presence == "online")
+ await user.set_presence(presence == PresenceState.ONLINE)
- async def handle_typing(self, room_id: MatrixRoomID, now_typing: Set[MatrixUserID]) -> None:
+ async def handle_typing(self, room_id: RoomID, now_typing: Set[UserID]) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
@@ -401,86 +325,44 @@ class MatrixHandler:
self.previously_typing[room_id] = now_typing
- def filter_matrix_event(self, event: MatrixEvent) -> bool:
- sender = event.get("sender", None)
- if not sender:
- return False
- return (sender == self.az.bot_mxid
- or pu.Puppet.get_id_from_mxid(sender) is not None)
+ def filter_matrix_event(self, evt: Event) -> bool:
+ if not isinstance(evt, (MessageEvent, StateEvent)):
+ return True
+ return evt.sender and (evt.sender == self.az.bot_mxid
+ or pu.Puppet.get_id_from_mxid(evt.sender) is not None)
- async def try_handle_ephemeral_event(self, evt: MatrixEvent) -> None:
- try:
- await self.handle_ephemeral_event(evt)
- except Exception:
- self.log.exception("Error handling manually received Matrix event")
+ async def handle_ephemeral_event(self, evt: Union[ReceiptEvent, PresenceEvent, TypingEvent]
+ ) -> None:
+ if evt.type == EventType.RECEIPT:
+ await self.handle_read_receipts(evt.room_id, self.parse_read_receipts(evt.content))
+ elif evt.type == EventType.PRESENCE:
+ await self.handle_presence(evt.sender, evt.content.presence)
+ elif evt.type == EventType.TYPING:
+ await self.handle_typing(evt.room_id, set(evt.content.user_ids))
- async def handle_ephemeral_event(self, evt: MatrixEvent) -> None:
- evt_type: str = evt.get("type", "m.unknown")
- room_id: Optional[MatrixRoomID] = evt.get("room_id", None)
- sender: Optional[MatrixUserID] = evt.get("sender", None)
- content: Dict = evt.get("content", {})
- if evt_type == "m.receipt":
- await self.handle_read_receipts(room_id, self.parse_read_receipts(content))
- elif evt_type == "m.presence":
- await self.handle_presence(sender, content.get("presence", "offline"))
- elif evt_type == "m.typing":
- await self.handle_typing(room_id, set(content.get("user_ids", [])))
+ async def handle_event(self, evt: Event) -> None:
+ if evt.type == EventType.ROOM_REDACTION:
+ await self.handle_redaction(evt)
- async def handle_event(self, evt: MatrixEvent) -> None:
- if self.filter_matrix_event(evt):
- return
- start_time = time.time()
- self.log.debug("Received event: %s", evt)
- evt_type: str = evt.get("type", "m.unknown")
- room_id: Optional[MatrixRoomID] = evt.get("room_id", None)
- event_id: Optional[MatrixEventID] = evt.get("event_id", None)
- sender: Optional[MatrixUserID] = evt.get("sender", None)
- state_key = evt.get("state_key", None)
- content: Dict = evt.get("content", {})
- if state_key is not None:
- if evt_type == "m.room.member":
- prev_content: Dict = evt.get("unsigned", {}).get("prev_content", {})
- membership: str = content.get("membership", "")
- prev_membership: str = prev_content.get("membership", "leave")
- if membership == prev_membership:
- match: Match = re.compile("@(.+):(.+)").match(state_key)
- mxid: str = match.group(0)
- displayname: str = content.get("displayname", None) or mxid
- prev_displayname: str = prev_content.get("displayname", None) or mxid
- if displayname != prev_displayname:
- await self.handle_name_change(room_id, state_key, displayname,
- prev_displayname, event_id)
- elif membership == "invite":
- await self.handle_invite(room_id, state_key, sender)
- elif prev_membership == "join" and membership == "leave":
- await self.handle_part(room_id, state_key, sender, event_id)
- elif membership == "join":
- await self.handle_join(room_id, state_key, event_id)
- elif evt_type == "m.room.power_levels":
- prev_content = evt.get("unsigned", {}).get("prev_content", {})
- await self.handle_power_levels(room_id, sender, evt["content"], prev_content)
- elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
- await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
- elif evt_type == "m.room.pinned_events":
- new_events = set(evt["content"]["pinned"])
- try:
- old_events = set(evt["unsigned"]["prev_content"]["pinned"])
- except KeyError:
- old_events = set()
- await self.handle_room_pin(room_id, sender, new_events, old_events)
- elif evt_type == "m.room.tombstone":
- await self.handle_room_upgrade(room_id, evt["content"]["replacement_room"])
- else:
- return
- else:
- if evt_type in ("m.room.message", "m.sticker"):
- if evt_type != "m.room.message":
- content["msgtype"] = evt_type
- await self.handle_message(room_id, sender, content, event_id)
- elif evt_type == "m.room.redaction":
- await self.handle_redaction(room_id, sender, evt["redacts"])
- else:
- return
+ async def handle_state_event(self, evt: StateEvent) -> None:
+ if evt.type == EventType.ROOM_POWER_LEVELS:
+ await self.handle_power_levels(evt)
+ elif evt.type in (EventType.ROOM_NAME, EventType.ROOM_AVATAR, EventType.ROOM_TOPIC):
+ await self.handle_room_meta(evt.type, evt.room_id, evt.sender, evt.content)
+ elif evt.type == EventType.ROOM_PINNED_EVENTS:
+ new_events = set(evt.content.pinned)
+ try:
+ old_events = set(evt.unsigned.prev_content.pinned)
+ except (KeyError, ValueError, TypeError, AttributeError):
+ old_events = set()
+ await self.handle_room_pin(evt.room_id, evt.sender, new_events, old_events)
+ elif evt.type == EventType.ROOM_TOMBSTONE:
+ await self.handle_room_upgrade(evt.room_id, evt.content.replacement_room)
- if EVENT_TIME:
- EVENT_TIME.labels(event_type=evt_type).observe(time.time() - start_time)
+ # async def handle_event(self, evt: MatrixEvent) -> None:
+ # if self.filter_matrix_event(evt):
+ # return
+ # start_time = time.time()
+ #
+ # if EVENT_TIME:
+ # EVENT_TIME.labels(event_type=evt_type).observe(time.time() - start_time)
diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py
index d62c2602..cf2342b4 100644
--- a/mautrix_telegram/portal.py
+++ b/mautrix_telegram/portal.py
@@ -864,23 +864,32 @@ class Portal:
else:
await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid))
- async def kick_matrix(self, user: Union['u.User', 'p.Puppet'], source: 'u.User') -> None:
+ async def kick_matrix(self, user: Union['u.User', 'p.Puppet'], source: 'u.User',
+ ban: bool = False) -> None:
if user.tgid == source.tgid:
return
+ if isinstance(user, u.User) and await user.needs_relaybot(self):
+ if not self.bot:
+ return
+ # TODO kick and ban message
+ return
if await source.needs_relaybot(self):
+ if not self.has_bot:
+ return
source = self.bot
+ target = await user.get_input_entity(source)
if self.peer_type == "chat":
- await source.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=user.tgid))
+ await source.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=target))
elif self.peer_type == "channel":
channel = await self.get_input_entity(source)
- rights = ChatBannedRights(datetime.fromtimestamp(0), True)
- await source.client(EditBannedRequest(channel=channel,
- user_id=user.tgid,
- banned_rights=rights))
+ await source.client.edit_permissions(channel, target, view_messages=False)
+ if not ban:
+ await source.client.edit_permissions(channel, target, view_messages=True)
- async def leave_matrix(self, user: 'u.User', source: 'u.User',
- event_id: MatrixEventID) -> None:
+ async def leave_matrix(self, user: 'u.User', event_id: MatrixEventID) -> None:
if await user.needs_relaybot(self):
+ if not self.has_bot:
+ return
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message("leave", user)
if not message:
@@ -900,8 +909,6 @@ class Portal:
del self.by_mxid[self.mxid]
except KeyError:
pass
- elif source and source.tgid != user.tgid:
- await self.kick_matrix(user, source)
elif self.peer_type == "chat":
await user.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=InputUserSelf()))
elif self.peer_type == "channel":
diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py
index 7d232436..85d7f326 100644
--- a/mautrix_telegram/puppet.py
+++ b/mautrix_telegram/puppet.py
@@ -521,7 +521,7 @@ class Puppet:
# endregion
-def init(context: 'Context') -> List[Awaitable[Any]]: # [None, None, PuppetError]
+def init(context: 'Context') -> Iterable[Awaitable[Any]]:
global config
Puppet.az, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx
@@ -529,4 +529,4 @@ def init(context: 'Context') -> List[Awaitable[Any]]: # [None, None, PuppetErro
Puppet.hs_domain = config["homeserver"]["domain"]
Puppet.mxid_regex = re.compile(
f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}")
- return [puppet.init_custom_mxid() for puppet in Puppet.all_with_custom_mxid()]
+ return (puppet.init_custom_mxid() for puppet in Puppet.all_with_custom_mxid())
diff --git a/mautrix_telegram/sqlstatestore.py b/mautrix_telegram/sqlstatestore.py
index cd7e4ab0..3b8c91ad 100644
--- a/mautrix_telegram/sqlstatestore.py
+++ b/mautrix_telegram/sqlstatestore.py
@@ -13,109 +13,26 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, Tuple
+from mautrix.types import UserID
+from mautrix.bridge.db import SQLStateStore as BaseSQLStateStore
-from mautrix_appservice import StateStore
-
-from .types import MatrixUserID, MatrixRoomID
from . import puppet as pu
-from .db import RoomState, UserProfile
-class SQLStateStore(StateStore):
- profile_cache: Dict[Tuple[str, str], UserProfile]
- room_state_cache: Dict[str, RoomState]
+class SQLStateStore(BaseSQLStateStore):
+ def is_registered(self, user_id: UserID) -> bool:
+ puppet = pu.Puppet.get_by_mxid(user_id, create=False)
+ if puppet:
+ return puppet.is_registered
+ custom_puppet = pu.Puppet.get_by_custom_mxid(user_id)
+ if custom_puppet:
+ return True
+ return super().is_registered(user_id)
- def __init__(self) -> None:
- super().__init__()
- self.profile_cache = {}
- self.room_state_cache = {}
-
- @staticmethod
- def is_registered(user: MatrixUserID) -> bool:
- puppet = pu.Puppet.get_by_mxid(user)
- return puppet.is_registered if puppet else False
-
- @staticmethod
- def registered(user: MatrixUserID) -> None:
- puppet = pu.Puppet.get_by_mxid(user)
+ def registered(self, user_id: UserID) -> None:
+ puppet = pu.Puppet.get_by_mxid(user_id, create=True)
if puppet:
puppet.is_registered = True
puppet.save()
-
- def update_state(self, event: Dict) -> None:
- event_type = event["type"]
- if event_type == "m.room.power_levels":
- self.set_power_levels(event["room_id"], event["content"])
- elif event_type == "m.room.member":
- self.set_member(event["room_id"], event["state_key"], event["content"])
-
- def _get_user_profile(self, room_id: MatrixRoomID, user_id: MatrixUserID, create: bool = True
- ) -> UserProfile:
- key = (room_id, user_id)
- try:
- return self.profile_cache[key]
- except KeyError:
- pass
-
- profile = UserProfile.get(*key)
- if profile:
- self.profile_cache[key] = profile
- elif create:
- profile = UserProfile(room_id=room_id, user_id=user_id, membership="leave")
- profile.insert()
- self.profile_cache[key] = profile
- return profile
-
- def get_member(self, room: MatrixRoomID, user: MatrixUserID) -> Dict:
- return self._get_user_profile(room, user).dict()
-
- def set_member(self, room: MatrixRoomID, user: MatrixUserID, member: Dict) -> None:
- profile = self._get_user_profile(room, user)
- profile.membership = member.get("membership", profile.membership or "leave")
- profile.displayname = member.get("displayname", profile.displayname)
- profile.avatar_url = member.get("avatar_url", profile.avatar_url)
- profile.update()
-
- def set_membership(self, room: MatrixRoomID, user: MatrixUserID, membership: str) -> None:
- self.set_member(room, user, {
- "membership": membership,
- })
-
- def _get_room_state(self, room_id: MatrixRoomID, create: bool = True) -> RoomState:
- try:
- return self.room_state_cache[room_id]
- except KeyError:
- pass
-
- room = RoomState.get(room_id)
- if room:
- self.room_state_cache[room_id] = room
- elif create:
- room = RoomState(room_id=room_id)
- room.insert()
- self.room_state_cache[room_id] = room
- return room
-
- def has_power_levels(self, room: MatrixRoomID) -> bool:
- return bool(self._get_room_state(room).power_levels)
-
- def get_power_levels(self, room: MatrixRoomID) -> Dict:
- return self._get_room_state(room).power_levels
-
- def set_power_level(self, room: MatrixRoomID, user: MatrixUserID, level: int) -> None:
- room_state = self._get_room_state(room)
- power_levels = room_state.power_levels
- if not power_levels:
- power_levels = {
- "users": {},
- "events": {},
- }
- power_levels[room]["users"][user] = level
- room_state.power_levels = power_levels
- room_state.update()
-
- def set_power_levels(self, room: MatrixRoomID, content: Dict) -> None:
- state = self._get_room_state(room)
- state.power_levels = content
- state.update()
+ else:
+ super().registered(user_id)
diff --git a/mautrix_telegram/types.py b/mautrix_telegram/types.py
index 15cc7094..a1871b48 100644
--- a/mautrix_telegram/types.py
+++ b/mautrix_telegram/types.py
@@ -1,9 +1,3 @@
from typing import Dict, NewType
-MatrixUserID = NewType('MatrixUserID', str)
-MatrixRoomID = NewType('MatrixRoomID', str)
-MatrixEventID = NewType('MatrixEventID', str)
-
-MatrixEvent = NewType('MatrixEvent', Dict)
-
TelegramID = NewType('TelegramID', int)
diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py
index cfab854e..eacea399 100644
--- a/mautrix_telegram/user.py
+++ b/mautrix_telegram/user.py
@@ -331,7 +331,7 @@ class User(AbstractUser):
async def needs_relaybot(self, portal: po.Portal) -> bool:
return not await self.is_logged_in() or (
- (portal.has_bot or self.bot) and portal.tgid_full not in self.portals)
+ (portal.has_bot or self.is_bot) and portal.tgid_full not in self.portals)
def _hash_contacts(self) -> int:
acc = 0
@@ -408,9 +408,8 @@ class User(AbstractUser):
# endregion
-def init(context: 'Context') -> List[Awaitable['User']]:
+def init(context: 'Context') -> Iterable[Awaitable['User']]:
global config
config = context.config
- users = [User.from_db(user) for user in DBUser.all()]
- return [user.ensure_started() for user in users if user.tgid]
+ return (User.from_db(db_user).ensure_started() for db_user in DBUser.all() if db_user.tgid)
diff --git a/requirements.txt b/requirements.txt
index e9e5ccc2..6cd3af4d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
aiohttp
-mautrix-appservice
+mautrix
ruamel.yaml
python-magic
SQLAlchemy
diff --git a/setup.py b/setup.py
index c3e20fe6..82e26bcb 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@ setuptools.setup(
install_requires=[
"aiohttp>=3.0.1,<4",
- "mautrix-appservice>=0.3.11,<0.4.0",
+ "mautrix>=0.4.0.dev46,<0.5",
"SQLAlchemy>=1.2.3,<2",
"alembic>=1.0.0,<2",
"commonmark>=0.8.1,<1",