Compare commits

...

70 Commits

Author SHA1 Message Date
Tulir Asokan 1994ce38eb Bump version to 0.4.0 2018-11-28 02:10:37 +02:00
Tulir Asokan 9aad6de823 Bump version to 0.4.0rc2 2018-11-15 22:46:36 +02:00
Tulir Asokan 3d3afdb645 Fix bug in 82d7e78455 2018-11-15 22:45:48 +02:00
Tulir Asokan 983f5001ab Bump version to 0.4.0rc1 2018-11-15 22:27:25 +02:00
Tulir Asokan a80fdf0990 Fix bug in 720210ac08 2018-11-15 22:25:49 +02:00
Tulir Asokan 82d7e78455 Handle kicking puppets separately. Fixes #191 2018-11-15 11:57:02 +02:00
Tulir Asokan d514b929b3 Automatically log out when logging in with a user someone logged in with previously. Fixes #198 2018-11-15 11:45:46 +02:00
Tulir Asokan 720210ac08 Check if client is connected before checking if authorized. Fixes #215 2018-11-15 11:45:36 +02:00
Tulir Asokan 2dfc05db5f Fall back to get_dialogs if get_entity fails. Fixes #229 2018-11-15 11:20:43 +02:00
Tulir Asokan d551934ec1 Fix command suggestion when trying to bridge non-whitelisted chat 2018-11-01 01:55:54 +02:00
Tulir Asokan bac1e30cf0 Fix Matrix->Telegram code blocks without language. Fixes #240 2018-10-27 19:22:04 +03:00
Tulir Asokan 8fdb2c4e57 Merge pull request #239 from tulir/sqlalchemy-core
Port Message table to SQLAlchemy Core
2018-10-21 00:32:14 +03:00
Tulir Asokan 8da1fb78b8 Handle aiohttp errors in syncer. Fixes #210 2018-10-21 00:09:37 +03:00
Tulir Asokan cea8163366 Only match integers in puppet mxid regex. Fixes #234 2018-10-21 00:08:02 +03:00
Tulir Asokan 388e4f8601 Port Message table to SQLAlchemy Core 2018-10-20 23:11:10 +03:00
Tulir Asokan 2756873c53 Add SIGINT/SIGTERM handler 2018-10-20 21:21:26 +03:00
Tulir Asokan a770e1f67e Merge pull request #237 from turt2live/travis/fix-chat-id-request
Don't try permission checks on rooms that aren't bridged
2018-10-20 14:56:27 +03:00
Tulir Asokan f8c844c4c0 Add flag to enable alchemysession core mode 2018-10-20 14:46:26 +03:00
Travis Ralston 7f23d4cf68 Don't try permission checks on rooms that aren't bridged
This is the proper way to fix https://github.com/tulir/mautrix-telegram/pull/235
2018-10-19 19:31:58 -06:00
Tulir Asokan 247c75191b Merge pull request #226 from turt2live/travis/bridge-info
Add provisioning route for getting misc bridge info
2018-10-08 14:02:19 +03:00
Travis Ralston 4f3e1b4fe6 Fix errors in spec.yaml 2018-10-08 01:16:29 -06:00
Travis Ralston 6291e92ed7 Remove extraneous fstring 2018-10-08 01:15:49 -06:00
Tulir Asokan 5054afcbb5 Fix Python 3.5 compatibility 2018-10-02 14:51:54 +03:00
Tulir Asokan 980e0d6ef7 Send captions as second message by default. Fixes #233 2018-09-29 10:56:04 +03:00
Tulir Asokan 2f6147f325 Fix notice bridging exceptions 2018-09-29 01:35:30 +03:00
Tulir Asokan 56fb88b75e Use mxids instead of localparts as default displaynames and fix name add/remove message. Fixes #228 2018-09-29 00:59:02 +03:00
Tulir Asokan 24bdda8ca1 Reorganize formatter utils and add more blue text 2018-09-28 18:39:57 +03:00
Tulir Asokan c38e46fc2a Fix linebreaks in pre blocks 2018-09-28 17:15:57 +03:00
Tulir Asokan 916cc3746d Fix block tag newlines and allow <strike>. Fixes #232 2018-09-28 17:06:42 +03:00
Tulir Asokan a32bc2985a Show phone number when username doesn't exist. Fixes #213 2018-09-28 02:46:02 +03:00
Tulir Asokan 8d982b4615 Bump minimum mautrix-appservice version. Fixes #217 2018-09-28 02:22:54 +03:00
Tulir Asokan 10e77707d0 Fix HTML escaping in command reply markdown parser 2018-09-28 02:18:41 +03:00
Tulir Asokan b0fe208768 Add missing await to portal.set_typing 2018-09-28 01:18:39 +03:00
Tulir Asokan b44d6d2d90 Fix minor things and type hints 2018-09-28 01:02:09 +03:00
Tulir Asokan 828047e272 Split TelegramMessage helper to separate file 2018-09-28 00:49:37 +03:00
Tulir Asokan a9cb1bf518 Fix linebreak handling in lxml parser and add better bullets
Fixes #218
2018-09-28 00:45:37 +03:00
Tulir Asokan d71f421981 Use <pre> for multiline MessageEntityCode entities 2018-09-26 00:24:04 +03:00
Tulir Asokan 26e947992e Merge pull request #231 from tulir/room-specific-settings
Add room specific config
2018-09-25 00:47:44 +03:00
Tulir Asokan 78e4804774 Fix minor things and improve code style 2018-09-25 00:47:16 +03:00
Tulir Asokan 5ccd1bc2fe Fix bugs and switch to commonmark for command replies 2018-09-25 00:26:02 +03:00
Tulir Asokan f758884c75 Fix example config and add alembic migration 2018-09-24 23:41:18 +03:00
Tulir Asokan 9d2d34a25c Add command to update room-specific config 2018-09-24 17:44:00 +03:00
Tulir Asokan fc23461445 Add room specific settings. Probably broken 2018-09-24 16:01:16 +03:00
Tulir Asokan 5253504df9 Update setup.py classifiers 2018-09-24 01:26:02 +03:00
Tulir Asokan dd270b862e Fix handling capitalized file extensions. Fixes #156 2018-09-24 01:25:51 +03:00
Travis Ralston 5bc1362493 Add provisioning route for getting misc bridge info
Currently only the relay bot's username is exposed here.
2018-09-19 22:44:27 -06:00
Tulir Asokan 96a0c923c2 Merge pull request #225 from turt2live/travis/unbridge-info
Add a flag to indicate if the requesting user can unbridge the portal
2018-09-17 01:24:17 +03:00
Travis Ralston 23bb2871fd Add a flag to indicate if the requesting user can unbridge the portal 2018-09-16 16:16:33 -06:00
Tulir Asokan d4ea5f8b38 Improve type hints and set version to 0.4.0+dev 2018-09-10 01:14:12 +03:00
Tulir Asokan 4b2cdc3d39 Add missing command status clear 2018-09-10 00:14:35 +03:00
Tulir Asokan 4c54d9c9ea Fix previous commit (ref #219) and update catch_up config comment 2018-09-10 00:11:13 +03:00
Tulir Asokan 9541d5eceb Don't bridge messages from unbridged chats received by bot (ref #219) 2018-09-09 01:26:22 +03:00
Tulir Asokan c9c1023ece Merge pull request #223 from turt2live/patch-1
Allow negative numbers in /connect
2018-09-09 01:14:38 +03:00
Travis Ralston cb2073eb8b Allow negative numbers in /connect 2018-09-08 16:14:00 -06:00
Tulir Asokan d35104aea6 Fix incorrect type hint 2018-09-05 10:55:12 +03:00
Tulir Asokan ad342f2ca4 Ignore old log files too 2018-09-01 19:08:29 +03:00
Tulir Asokan 29541ff520 Pass logging a copy of the config to stop editing. Fixes #216 2018-09-01 14:07:44 +03:00
Tulir Asokan 6a1c160608 Await set_presence. Fixes #209 2018-09-01 14:03:13 +03:00
Tulir Asokan 731c802fcd Only import deque in type checking mode to fix 3.5 runtime support 2018-08-30 19:03:22 +03:00
Tulir Asokan b6f15934f2 Fix conversational command handling 2018-08-30 13:32:04 +03:00
Tulir Asokan 068449c59c Update ROADMAP.md 2018-08-24 09:47:54 +03:00
Tulir Asokan 4f36a2c7c1 Simplify displayname similarity calculation 2018-08-17 00:06:37 +03:00
Tulir Asokan bb04231880 Fix bugs in migrations 2018-08-17 00:06:02 +03:00
Tulir Asokan 1ef790ce31 Merge pull request #206 from V02460/master
Add type annotations
2018-08-15 10:18:39 +03:00
Kai A. Hiller 81531235bc Replace double quote type annotations with single quotes 2018-08-09 14:36:14 +02:00
Kai A. Hiller 66683151ec Make SearchResult a NewType and make its List explicit 2018-08-09 14:23:18 +02:00
Kai A. Hiller e751d140f2 Change case of new types 2018-08-09 14:11:41 +02:00
Kai A. Hiller 0f8009b1e9 Add missing type hints and fix most type errors except for Optionals. 2018-08-09 03:31:04 +02:00
Kai A. Hiller 01e153662e Replace star imports with literal values 2018-08-09 02:42:48 +02:00
Kai A. Hiller 08dd5b5b15 Add None return type to functions 2018-08-09 02:42:47 +02:00
52 changed files with 1862 additions and 1094 deletions
+1 -1
View File
@@ -7,5 +7,5 @@ __pycache__
config.yaml config.yaml
registration.yaml registration.yaml
*.log *.log*
*.db *.db
+4 -4
View File
@@ -4,9 +4,9 @@
* [x] Message content (text, formatting, files, etc..) * [x] Message content (text, formatting, files, etc..)
* [x] Message redactions * [x] Message redactions
* [ ] ‡ Message history * [ ] ‡ Message history
* [ ] Presence * [x] Presence
* [ ] Typing notifications * [x] Typing notifications
* [ ] Read receipts * [x] Read receipts
* [x] Pinning messages * [x] Pinning messages
* [x] Power level * [x] Power level
* [x] Normal chats * [x] Normal chats
@@ -46,7 +46,7 @@
* [x] When receiving invite or message * [x] When receiving invite or message
* [x] Private chat creation by inviting Matrix puppet of Telegram user to new room * [x] Private chat creation by inviting Matrix puppet of Telegram user to new room
* [x] Option to use bot to relay messages for unauthenticated Matrix users * [x] Option to use bot to relay messages for unauthenticated Matrix users
* [ ] Option to use own Matrix account for messages sent from other Telegram clients * [x] Option to use own Matrix account for messages sent from other Telegram clients
* [ ] ‡ Calls (hard, not yet supported by Telethon) * [ ] ‡ Calls (hard, not yet supported by Telethon)
† Information not automatically sent from source, i.e. implementation may not be possible † Information not automatically sent from source, i.e. implementation may not be possible
@@ -21,4 +21,5 @@ def upgrade():
def downgrade(): def downgrade():
op.drop_column('puppet', 'is_bot') with op.batch_alter_table("puppet") as batch_op:
batch_op.drop_column('is_bot')
@@ -20,4 +20,5 @@ def upgrade():
def downgrade(): def downgrade():
op.drop_column('portal', 'megagroup') with op.batch_alter_table("portal") as batch_op:
batch_op.drop_column('megagroup')
@@ -72,10 +72,17 @@ def upgrade():
sa.Column("avatar_url", sa.String(), nullable=True), sa.Column("avatar_url", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("room_id", "user_id")) sa.PrimaryKeyConstraint("room_id", "user_id"))
try:
migrate_state_store()
except Exception as e:
print("Failed to migrate state store:", e)
print("Migrating the state store isn't required, but you can retry by alembic downgrading "
"to revision 2228d49c383f and upgrading again.")
def migrate_state_store():
conn = op.get_bind() conn = op.get_bind()
session = orm.sessionmaker(bind=conn) session = orm.sessionmaker(bind=conn)() # type: orm.Session
session = orm.scoping.scoped_session(session)
Puppet.query = session.query_property()
try: try:
with open("mx-state.json") as file: with open("mx-state.json") as file:
@@ -99,7 +106,7 @@ def upgrade():
if not match: if not match:
continue continue
puppet = Puppet.query.get(match.group(1)) puppet = session.query(Puppet).get(match.group(1))
if not puppet: if not puppet:
continue continue
@@ -22,4 +22,5 @@ def upgrade():
def downgrade(): def downgrade():
op.drop_column('telegram_file', 'timestamp') with op.batch_alter_table("telegram_file") as batch_op:
batch_op.drop_column('timestamp')
@@ -0,0 +1,25 @@
"""Add phone number field to users
Revision ID: a9119be92164
Revises: b54929c22c86
Create Date: 2018-09-28 02:38:40.626282
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a9119be92164"
down_revision = "b54929c22c86"
branch_labels = None
depends_on = None
def upgrade():
op.add_column("user", sa.Column("tg_phone", sa.String(), nullable=True))
def downgrade():
with op.batch_alter_table("user") as batch_op:
batch_op.drop_column("tg_phone")
@@ -0,0 +1,25 @@
"""Add portal-specific config
Revision ID: b54929c22c86
Revises: d5f7b8b4b456
Create Date: 2018-09-24 23:40:33.528710
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b54929c22c86"
down_revision = "d5f7b8b4b456"
branch_labels = None
depends_on = None
def upgrade():
op.add_column("portal", sa.Column("config", sa.Text(), nullable=True))
def downgrade():
with op.batch_alter_table("portal") as batch_op:
batch_op.drop_column("config")
@@ -20,4 +20,5 @@ def upgrade():
def downgrade(): def downgrade():
op.drop_column('puppet', 'displayname_source') with op.batch_alter_table("puppet") as batch_op:
batch_op.drop_column('displayname_source')
+27 -17
View File
@@ -27,6 +27,9 @@ appservice:
# SQLite: sqlite:///filename.db # SQLite: sqlite:///filename.db
# Postgres: postgres://username:password@hostname/dbname # Postgres: postgres://username:password@hostname/dbname
database: sqlite:///mautrix-telegram.db database: sqlite:///mautrix-telegram.db
# Whether or not to use SQLAlchemy Core for common database actions. Use if the bridge is
# being bottlenecked on ORM commits. Only supported with PostgreSQL.
sqlalchemy_core_mode: false
# Public part of web server for out-of-Matrix interaction with the bridge. # Public part of web server for out-of-Matrix interaction with the bridge.
# Used for things like login if the user wants to make sure the 2FA password isn't stored in # Used for things like login if the user wants to make sure the 2FA password isn't stored in
@@ -95,15 +98,6 @@ bridge:
- username - username
- phone number - phone number
# Show message editing as a reply to the original message.
# If this is false, message edits are not shown at all, as Matrix does not support editing yet.
edits_as_replies: false
# Highlight changed/added parts in edits. Requires lxml.
highlight_edits: false
# Whether or not Matrix bot messages (type m.notice) should be bridged.
bridge_notices: true
# Whether to bridge Telegram bot messages as m.notices or m.texts.
bot_messages_as_notices: true
# Maximum number of members to sync per portal when starting up. Other members will be # Maximum number of members to sync per portal when starting up. Other members will be
# synced when they send messages. The maximum is 10000, after which the Telegram server # synced when they send messages. The maximum is 10000, after which the Telegram server
# will not send any more members. # will not send any more members.
@@ -119,21 +113,16 @@ bridge:
# Allow logging in within Matrix. If false, the only way to log in is using the out-of-Matrix # Allow logging in within Matrix. If false, the only way to log in is using the out-of-Matrix
# login website (see appservice.public config section) # login website (see appservice.public config section)
allow_matrix_login: true allow_matrix_login: true
# Use inline images instead of m.image to make rich captions possible.
# N.B. Inline images are not supported on all clients (e.g. Riot iOS).
inline_images: false
# Whether or not to bridge plaintext highlights. # Whether or not to bridge plaintext highlights.
# Only enable this if your displayname_template has some static part that the bridge can use to # Only enable this if your displayname_template has some static part that the bridge can use to
# reliably identify what is a plaintext highlight. # reliably identify what is a plaintext highlight.
plaintext_highlights: false plaintext_highlights: false
# Highlight changed/added parts in edits. Requires lxml.
highlight_edits: false
# Whether or not to make portals of publicly joinable channels/supergroups publicly joinable on Matrix. # Whether or not to make portals of publicly joinable channels/supergroups publicly joinable on Matrix.
public_portals: true public_portals: true
# Whether to send stickers as the new native m.sticker type or normal m.images.
# Old versions of Riot don't support the new type at all.
# Remember that proper sticker support always requires Pillow to convert webp into png.
native_stickers: true
# Whether or not to fetch and handle Telegram updates at startup from the time the bridge was down. # Whether or not to fetch and handle Telegram updates at startup from the time the bridge was down.
# WARNING: Probably buggy, might get stuck in infinite loop. # Currently only works for private chats and normal groups.
catch_up: false catch_up: false
# Whether or not to use /sync to get presence, read receipts and typing notifications when using # Whether or not to use /sync to get presence, read receipts and typing notifications when using
# your own Matrix account as the Matrix puppet for your Telegram account. # your own Matrix account as the Matrix puppet for your Telegram account.
@@ -149,6 +138,27 @@ bridge:
# You might need to increase this on high-traffic bridge instances. # You might need to increase this on high-traffic bridge instances.
cache_queue_length: 20 cache_queue_length: 20
# Show message editing as a reply to the original message.
# If this is false, message edits are not shown at all, as Matrix does not support editing yet.
edits_as_replies: false
bridge_notices:
# Whether or not Matrix bot messages (type m.notice) should be bridged.
default: false
# List of user IDs for whom the previous flag is flipped.
# e.g. if bridge_notices.default is false, notices from other users will not be bridged, but
# notices from users listed here will be bridged.
exceptions:
- "@importantbot:example.com"
# Whether to bridge Telegram bot messages as m.notices or m.texts.
bot_messages_as_notices: true
# Use inline images instead of a separate message for the caption.
# N.B. Inline images are not supported on all clients (e.g. Riot iOS).
inline_images: false
# Whether to send stickers as the new native m.sticker type or normal m.images.
# Old versions of Riot don't support the new type at all.
# Remember that proper sticker support always requires Pillow to convert webp into png.
native_stickers: true
# The formats to use when sending messages to Telegram via the relay bot. # The formats to use when sending messages to Telegram via the relay bot.
# #
# Telegram doesn't have built-in emotes, so the m.emote format is also used for non-relaybot users. # Telegram doesn't have built-in emotes, so the m.emote format is also used for non-relaybot users.
+1 -1
View File
@@ -1,2 +1,2 @@
__version__ = "0.3.0" __version__ = "0.4.0"
__author__ = "Tulir Asokan <tulir@maunium.net>" __author__ = "Tulir Asokan <tulir@maunium.net>"
+15 -5
View File
@@ -14,11 +14,13 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from typing import Coroutine, List
import argparse import argparse
import asyncio import asyncio
import logging.config import logging.config
import sys import sys
import copy
import signal
from sqlalchemy import orm from sqlalchemy import orm
import sqlalchemy as sql import sqlalchemy as sql
@@ -66,7 +68,7 @@ if args.generate_registration:
print(f"Registration generated and saved to {config.registration_path}") print(f"Registration generated and saved to {config.registration_path}")
sys.exit(0) sys.exit(0)
logging.config.dictConfig(config["logging"]) logging.config.dictConfig(copy.deepcopy(config["logging"]))
log = logging.getLogger("mau.init") # type: logging.Logger log = logging.getLogger("mau.init") # type: logging.Logger
log.debug(f"Initializing mautrix-telegram {__version__}") log.debug(f"Initializing mautrix-telegram {__version__}")
@@ -78,6 +80,11 @@ Base.metadata.bind = db_engine
session_container = AlchemySessionContainer(engine=db_engine, session=db_session, session_container = AlchemySessionContainer(engine=db_engine, session=db_session,
table_base=Base, table_prefix="telethon_", table_base=Base, table_prefix="telethon_",
manage_tables=False) manage_tables=False)
if config["appservice.sqlalchemy_core_mode"]:
try:
session_container.core_mode = True
except AttributeError:
log.error("Current version of teleton-session-sqlalchemy does not support core mode")
loop = asyncio.get_event_loop() # type: asyncio.AbstractEventLoop loop = asyncio.get_event_loop() # type: asyncio.AbstractEventLoop
@@ -106,7 +113,7 @@ if config["appservice.provisioning.enabled"]:
context.provisioning_api = provisioning_api context.provisioning_api = provisioning_api
with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start: with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
init_db(db_session) init_db(db_session, db_engine)
init_abstract_user(context) init_abstract_user(context)
context.bot = init_bot(context) context.bot = init_bot(context)
context.mx = MatrixHandler(context) context.mx = MatrixHandler(context)
@@ -115,18 +122,21 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
startup_actions = (init_puppet(context) + startup_actions = (init_puppet(context) +
init_user(context) + init_user(context) +
[start, [start,
context.mx.init_as_bot()]) context.mx.init_as_bot()]) # type: List[Coroutine]
if context.bot: if context.bot:
startup_actions.append(context.bot.start()) startup_actions.append(context.bot.start())
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
try: try:
log.debug("Initialization complete, running startup actions") log.debug("Initialization complete, running startup actions")
loop.run_until_complete(asyncio.gather(*startup_actions, loop=loop)) loop.run_until_complete(asyncio.gather(*startup_actions, loop=loop))
log.debug("Startup actions complete, now running forever") log.debug("Startup actions complete, now running forever")
loop.run_forever() loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
log.debug("Keyboard interrupt received, stopping clients") log.debug("Interrupt received, stopping clients")
loop.run_until_complete( loop.run_until_complete(
asyncio.gather(*[user.stop() for user in User.by_tgid.values()], loop=loop)) asyncio.gather(*[user.stop() for user in User.by_tgid.values()], loop=loop))
log.debug("Clients stopped, shutting down") log.debug("Clients stopped, shutting down")
+63 -48
View File
@@ -35,6 +35,7 @@ from alchemysession import AlchemySessionContainer
from . import portal as po, puppet as pu, __version__ from . import portal as po, puppet as pu, __version__
from .db import Message as DBMessage from .db import Message as DBMessage
from .types import TelegramID, MatrixUserID
from .tgclient import MautrixTelegramClient from .tgclient import MautrixTelegramClient
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -60,17 +61,18 @@ class AbstractUser(ABC):
bot = None # type: Bot bot = None # type: Bot
ignore_incoming_bot_events = True # type: bool ignore_incoming_bot_events = True # type: bool
def __init__(self): def __init__(self) -> None:
self.is_admin = False # type: bool self.is_admin = False # type: bool
self.matrix_puppet_whitelisted = False # type: bool self.matrix_puppet_whitelisted = False # type: bool
self.puppet_whitelisted = False # type: bool self.puppet_whitelisted = False # type: bool
self.whitelisted = False # type: bool self.whitelisted = False # type: bool
self.relaybot_whitelisted = False # type: bool self.relaybot_whitelisted = False # type: bool
self.client = None # type: MautrixTelegramClient self.client = None # type: MautrixTelegramClient
self.tgid = None # type: int self.tgid = None # type: TelegramID
self.mxid = None # type: str self.mxid = None # type: MatrixUserID
self.is_relaybot = False # type: bool self.is_relaybot = False # type: bool
self.is_bot = False # type: bool self.is_bot = False # type: bool
self.relaybot = None # type: Optional[Bot]
@property @property
def connected(self) -> bool: def connected(self) -> bool:
@@ -93,7 +95,7 @@ class AbstractUser(ABC):
config["telegram.proxy.rdns"], config["telegram.proxy.rdns"],
config["telegram.proxy.username"], config["telegram.proxy.password"]) config["telegram.proxy.username"], config["telegram.proxy.password"])
def _init_client(self): def _init_client(self) -> None:
self.log.debug(f"Initializing client for {self.name}") self.log.debug(f"Initializing client for {self.name}")
device = f"{platform.system()} {platform.release()}" device = f"{platform.system()} {platform.release()}"
sysversion = MautrixTelegramClient.__version__ sysversion = MautrixTelegramClient.__version__
@@ -114,18 +116,18 @@ class AbstractUser(ABC):
return False return False
@abstractmethod @abstractmethod
async def post_login(self): async def post_login(self) -> None:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def register_portal(self, portal: po.Portal): def register_portal(self, portal: po.Portal) -> None:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def unregister_portal(self, portal: po.Portal): def unregister_portal(self, portal: po.Portal) -> None:
raise NotImplementedError() raise NotImplementedError()
async def _update_catch(self, update: TypeUpdate): async def _update_catch(self, update: TypeUpdate) -> None:
try: try:
if not await self.update(update): if not await self.update(update):
await self._update(update) await self._update(update)
@@ -147,21 +149,21 @@ class AbstractUser(ABC):
raise NotImplementedError() raise NotImplementedError()
async def is_logged_in(self) -> bool: async def is_logged_in(self) -> bool:
return self.client and await self.client.is_user_authorized() return self.client and self.client.is_connected() and await self.client.is_user_authorized()
async def has_full_access(self, allow_bot: bool = False) -> bool: async def has_full_access(self, allow_bot: bool = False) -> bool:
return (self.puppet_whitelisted return (self.puppet_whitelisted
and (not self.is_bot or allow_bot) and (not self.is_bot or allow_bot)
and await self.is_logged_in()) and await self.is_logged_in())
async def start(self, delete_unless_authenticated: bool = False) -> "AbstractUser": async def start(self, delete_unless_authenticated: bool = False) -> 'AbstractUser':
if not self.client: if not self.client:
self._init_client() self._init_client()
await self.client.connect() await self.client.connect()
self.log.debug("%s connected: %s", self.mxid, self.connected) self.log.debug("%s connected: %s", self.mxid, self.connected)
return self return self
async def ensure_started(self, even_if_no_session=False) -> "AbstractUser": async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
if not self.puppet_whitelisted: if not self.puppet_whitelisted:
return self return self
self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)", self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)",
@@ -175,13 +177,13 @@ class AbstractUser(ABC):
await self.start(delete_unless_authenticated=not even_if_no_session) await self.start(delete_unless_authenticated=not even_if_no_session)
return self return self
async def stop(self): async def stop(self) -> None:
await self.client.disconnect() await self.client.disconnect()
self.client = None self.client = None
# region Telegram update handling # region Telegram update handling
async def _update(self, update: TypeUpdate): async def _update(self, update: TypeUpdate) -> None:
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage, if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)): UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)):
await self.update_message(update) await self.update_message(update)
@@ -207,55 +209,63 @@ class AbstractUser(ABC):
self.log.debug("Unhandled update: %s", update) self.log.debug("Unhandled update: %s", update)
@staticmethod @staticmethod
async def update_pinned_messages(update: UpdateChannelPinnedMessage): async def update_pinned_messages(update: UpdateChannelPinnedMessage) -> None:
portal = po.Portal.get_by_tgid(update.channel_id) portal = po.Portal.get_by_tgid(TelegramID(update.channel_id))
if portal and portal.mxid: if portal and portal.mxid:
await portal.receive_telegram_pin_id(update.id) await portal.receive_telegram_pin_id(update.id)
@staticmethod @staticmethod
async def update_participants(update: UpdateChatParticipants): async def update_participants(update: UpdateChatParticipants) -> None:
portal = po.Portal.get_by_tgid(update.participants.chat_id) portal = po.Portal.get_by_tgid(TelegramID(update.participants.chat_id))
if portal and portal.mxid: if portal and portal.mxid:
await portal.update_telegram_participants(update.participants.participants) await portal.update_telegram_participants(update.participants.participants)
async def update_read_receipt(self, update: UpdateReadHistoryOutbox): async def update_read_receipt(self, update: UpdateReadHistoryOutbox) -> None:
if not isinstance(update.peer, PeerUser): if not isinstance(update.peer, PeerUser):
self.log.debug("Unexpected read receipt peer: %s", update.peer) self.log.debug("Unexpected read receipt peer: %s", update.peer)
return return
portal = po.Portal.get_by_tgid(update.peer.user_id, self.tgid) portal = po.Portal.get_by_tgid(TelegramID(update.peer.user_id), self.tgid)
if not portal or not portal.mxid: if not portal or not portal.mxid:
return return
# We check that these are user read receipts, so tg_space is always the user ID. # We check that these are user read receipts, so tg_space is always the user ID.
message = DBMessage.query.get((update.max_id, self.tgid)) message = DBMessage.get_by_tgid(update.max_id, self.tgid)
if not message: if not message:
return return
puppet = pu.Puppet.get(update.peer.user_id) puppet = pu.Puppet.get(TelegramID(update.peer.user_id))
await puppet.intent.mark_read(portal.mxid, message.mxid) await puppet.intent.mark_read(portal.mxid, message.mxid)
async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]): async def update_admin(self,
update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]) -> None:
# TODO duplication not checked # TODO duplication not checked
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat") portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
if not portal or not portal.mxid:
return
if isinstance(update, UpdateChatAdmins): if isinstance(update, UpdateChatAdmins):
await portal.set_telegram_admins_enabled(update.enabled) await portal.set_telegram_admins_enabled(update.enabled)
elif isinstance(update, UpdateChatParticipantAdmin): elif isinstance(update, UpdateChatParticipantAdmin):
await portal.set_telegram_admin(update.user_id) await portal.set_telegram_admin(TelegramID(update.user_id))
else: else:
self.log.warning("Unexpected admin status update: %s", update) self.log.warning("Unexpected admin status update: %s", update)
async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]): async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]) -> None:
if isinstance(update, UpdateUserTyping): if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user") portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
else: else:
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat") portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
sender = pu.Puppet.get(update.user_id)
if not portal or not portal.mxid:
return
sender = pu.Puppet.get(TelegramID(update.user_id))
await portal.handle_telegram_typing(sender, update) await portal.handle_telegram_typing(sender, update)
async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]): async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]) -> None:
# TODO duplication not checked # TODO duplication not checked
puppet = pu.Puppet.get(update.user_id) puppet = pu.Puppet.get(TelegramID(update.user_id))
if isinstance(update, UpdateUserName): if isinstance(update, UpdateUserName):
if await puppet.update_displayname(self, update): if await puppet.update_displayname(self, update):
puppet.save() puppet.save()
@@ -265,8 +275,8 @@ class AbstractUser(ABC):
else: else:
self.log.warning("Unexpected other user info update: %s", update) self.log.warning("Unexpected other user info update: %s", update)
async def update_status(self, update: UpdateUserStatus): async def update_status(self, update: UpdateUserStatus) -> None:
puppet = pu.Puppet.get(update.user_id) puppet = pu.Puppet.get(TelegramID(update.user_id))
if isinstance(update.status, UserStatusOnline): if isinstance(update.status, UserStatusOnline):
await puppet.default_mxid_intent.set_presence("online") await puppet.default_mxid_intent.set_presence("online")
elif isinstance(update.status, UserStatusOffline): elif isinstance(update.status, UserStatusOffline):
@@ -279,10 +289,10 @@ class AbstractUser(ABC):
Optional[pu.Puppet], Optional[pu.Puppet],
Optional[po.Portal]]: Optional[po.Portal]]:
if isinstance(update, UpdateShortChatMessage): if isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat") portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
sender = pu.Puppet.get(update.from_id) sender = pu.Puppet.get(TelegramID(update.from_id))
elif isinstance(update, UpdateShortMessage): elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user") portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
sender = pu.Puppet.get(self.tgid if update.out else update.user_id) sender = pu.Puppet.get(self.tgid if update.out else update.user_id)
elif isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage, elif isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage,
UpdateEditMessage, UpdateEditChannelMessage)): UpdateEditMessage, UpdateEditChannelMessage)):
@@ -300,7 +310,7 @@ class AbstractUser(ABC):
return update, sender, portal return update, sender, portal
@staticmethod @staticmethod
async def _try_redact(portal: po.Portal, message: DBMessage): async def _try_redact(portal: po.Portal, message: DBMessage) -> None:
if not portal: if not portal:
return return
try: try:
@@ -308,40 +318,45 @@ class AbstractUser(ABC):
except MatrixRequestError: except MatrixRequestError:
pass pass
async def delete_message(self, update: UpdateDeleteMessages): async def delete_message(self, update: UpdateDeleteMessages) -> None:
if len(update.messages) > MAX_DELETIONS: if len(update.messages) > MAX_DELETIONS:
return return
for message in update.messages: for message in update.messages:
message = DBMessage.query.get((message, self.tgid)) message = DBMessage.get_by_tgid(TelegramID(message), self.tgid)
if not message: if not message:
continue continue
self.db.delete(message) message.delete()
number_left = DBMessage.query.filter(DBMessage.mxid == message.mxid, number_left = DBMessage.count_spaces_by_mxid(message.mxid, message.mx_room)
DBMessage.mx_room == message.mx_room).count()
if number_left == 0: if number_left == 0:
portal = po.Portal.get_by_mxid(message.mx_room) portal = po.Portal.get_by_mxid(message.mx_room)
await self._try_redact(portal, message) await self._try_redact(portal, message)
self.db.commit() self.db.commit()
async def delete_channel_message(self, update: UpdateDeleteChannelMessages): async def delete_channel_message(self, update: UpdateDeleteChannelMessages) -> None:
if len(update.messages) > MAX_DELETIONS: if len(update.messages) > MAX_DELETIONS:
return return
portal = po.Portal.get_by_tgid(update.channel_id) portal = po.Portal.get_by_tgid(TelegramID(update.channel_id))
if not portal: if not portal:
return return
for message in update.messages: for message in update.messages:
message = DBMessage.query.get((message, portal.tgid)) message = DBMessage.get_by_tgid(TelegramID(message), portal.tgid)
if not message: if not message:
continue continue
self.db.delete(message) message.delete()
await self._try_redact(portal, message) await self._try_redact(portal, message)
self.db.commit() self.db.commit()
async def update_message(self, original_update: UpdateMessage): async def update_message(self, original_update: UpdateMessage) -> None:
update, sender, portal = self.get_message_details(original_update) update, sender, portal = self.get_message_details(original_update)
if self.is_bot and not portal.mxid:
self.log.debug(f"Ignoring message received by bot in unbridged chat %s",
portal.tgid_log)
return
if self.ignore_incoming_bot_events and self.bot and sender.id == self.bot.tgid: if self.ignore_incoming_bot_events and self.bot and sender.id == self.bot.tgid:
self.log.debug(f"Ignoring relaybot-sent message %s to %s", update, portal.tgid_log) self.log.debug(f"Ignoring relaybot-sent message %s to %s", update, portal.tgid_log)
return return
@@ -369,9 +384,9 @@ class AbstractUser(ABC):
# endregion # endregion
def init(context: "Context"): def init(context: "Context") -> None:
global config, MAX_DELETIONS global config, MAX_DELETIONS
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core
AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"] AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"]
AbstractUser.session_container = context.session_container AbstractUser.session_container = context.session_container
MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10) MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10)
+59 -47
View File
@@ -14,21 +14,27 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable, Pattern, Dict, TYPE_CHECKING from typing import Awaitable, Callable, Dict, List, Optional, Pattern, TYPE_CHECKING
import logging import logging
import re import re
from telethon.tl.types import * from telethon.tl.types import (
ChannelParticipantAdmin, ChannelParticipantCreator, ChatForbidden, ChatParticipantAdmin,
ChatParticipantCreator, InputChannel, InputUser, Message, MessageActionChatAddUser,
MessageActionChatDeleteUser, MessageEntityBotCommand, MessageService, PeerChannel, PeerChat,
TypePeer, UpdateNewChannelMessage, UpdateNewMessage)
from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
from telethon.errors import ChannelInvalidError, ChannelPrivateError from telethon.errors import ChannelInvalidError, ChannelPrivateError
from .types import MatrixUserID
from .abstract_user import AbstractUser from .abstract_user import AbstractUser
from .db import BotChat from .db import BotChat
from . import puppet as pu, portal as po, user as u from . import puppet as pu, portal as po, user as u
if TYPE_CHECKING: if TYPE_CHECKING:
from .config import Config from .config import Config
from .context import Context
config = None # type: Config config = None # type: Config
@@ -39,7 +45,7 @@ class Bot(AbstractUser):
log = logging.getLogger("mau.bot") # type: logging.Logger log = logging.getLogger("mau.bot") # type: logging.Logger
mxid_regex = re.compile("@.+:.+") # type: Pattern mxid_regex = re.compile("@.+:.+") # type: Pattern
def __init__(self, token: str): def __init__(self, token: str) -> None:
super().__init__() super().__init__()
self.token = token # type: str self.token = token # type: str
self.puppet_whitelisted = True # type: bool self.puppet_whitelisted = True # type: bool
@@ -53,46 +59,46 @@ class Bot(AbstractUser):
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"] self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
or False) # type: bool or False) # type: bool
async def init_permissions(self): async def init_permissions(self) -> None:
whitelist = config["bridge.relaybot.whitelist"] or [] whitelist = config["bridge.relaybot.whitelist"] or []
for id in whitelist: for user_id in whitelist:
if isinstance(id, str): if isinstance(user_id, str):
entity = await self.client.get_input_entity(id) entity = await self.client.get_input_entity(user_id)
if isinstance(entity, InputUser): if isinstance(entity, InputUser):
id = entity.user_id user_id = entity.user_id
else: else:
id = None user_id = None
if isinstance(id, int): if isinstance(user_id, int):
self.tg_whitelist.append(id) self.tg_whitelist.append(user_id)
async def start(self, delete_unless_authenticated: bool = False) -> "Bot": async def start(self, delete_unless_authenticated: bool = False) -> 'Bot':
await super().start(delete_unless_authenticated) await super().start(delete_unless_authenticated)
if not await self.is_logged_in(): if not await self.is_logged_in():
await self.client.sign_in(bot_token=self.token) await self.client.sign_in(bot_token=self.token)
await self.post_login() await self.post_login()
return self return self
async def post_login(self): async def post_login(self) -> None:
await self.init_permissions() await self.init_permissions()
info = await self.client.get_me() info = await self.client.get_me()
self.tgid = info.id self.tgid = info.id
self.username = info.username self.username = info.username
self.mxid = pu.Puppet.get_mxid_from_id(self.tgid) self.mxid = pu.Puppet.get_mxid_from_id(self.tgid)
chat_ids = [id for id, type in self.chats.items() if type == "chat"] chat_ids = [chat_id for chat_id, chat_type in self.chats.items() if chat_type == "chat"]
response = await self.client(GetChatsRequest(chat_ids)) response = await self.client(GetChatsRequest(chat_ids))
for chat in response.chats: for chat in response.chats:
if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated: if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated:
self.remove_chat(chat.id) self.remove_chat(chat.id)
channel_ids = [InputChannel(id, 0) channel_ids = [InputChannel(chat_id, 0)
for id, type in self.chats.items() for chat_id, chat_type in self.chats.items()
if type == "channel"] if chat_type == "channel"]
for id in channel_ids: for channel_id in channel_ids:
try: try:
await self.client(GetChannelsRequest([id])) await self.client(GetChannelsRequest([channel_id]))
except (ChannelPrivateError, ChannelInvalidError): except (ChannelPrivateError, ChannelInvalidError):
self.remove_chat(id.channel_id) self.remove_chat(channel_id.channel_id)
if config["bridge.catch_up"]: if config["bridge.catch_up"]:
try: try:
@@ -100,24 +106,24 @@ class Bot(AbstractUser):
except Exception: except Exception:
self.log.exception("Failed to run catch_up() for bot") self.log.exception("Failed to run catch_up() for bot")
def register_portal(self, portal: po.Portal): def register_portal(self, portal: po.Portal) -> None:
self.add_chat(portal.tgid, portal.peer_type) self.add_chat(portal.tgid, portal.peer_type)
def unregister_portal(self, portal: po.Portal): def unregister_portal(self, portal: po.Portal) -> None:
self.remove_chat(portal.tgid) self.remove_chat(portal.tgid)
def add_chat(self, id: int, type: str): def add_chat(self, chat_id: int, chat_type: str) -> None:
if id not in self.chats: if chat_id not in self.chats:
self.chats[id] = type self.chats[chat_id] = chat_type
self.db.add(BotChat(id=id, type=type)) self.db.add(BotChat(id=chat_id, type=chat_type))
self.db.commit() self.db.commit()
def remove_chat(self, id: int): def remove_chat(self, chat_id: int) -> None:
try: try:
del self.chats[id] del self.chats[chat_id]
except KeyError: except KeyError:
pass pass
existing_chat = BotChat.query.get(id) existing_chat = BotChat.query.get(chat_id)
if existing_chat: if existing_chat:
self.db.delete(existing_chat) self.db.delete(existing_chat)
self.db.commit() self.db.commit()
@@ -141,6 +147,7 @@ class Bot(AbstractUser):
for p in participants: for p in participants:
if p.user_id == tgid: if p.user_id == tgid:
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin)) return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
return False
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool: async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
if not await self._can_use_commands(event.to_id, event.from_id): if not await self._can_use_commands(event.to_id, event.from_id):
@@ -148,7 +155,7 @@ class Bot(AbstractUser):
return False return False
return True return True
async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc): async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc) -> None:
if not config["bridge.relaybot.authless_portals"]: if not config["bridge.relaybot.authless_portals"]:
return await reply("This bridge doesn't allow portal creation from Telegram.") return await reply("This bridge doesn't allow portal creation from Telegram.")
@@ -164,15 +171,16 @@ class Bot(AbstractUser):
return await reply( return await reply(
"Portal is not public. Use `/invite <mxid>` to get an invite.") "Portal is not public. Use `/invite <mxid>` to get an invite.")
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, mxid: str): async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc,
if len(mxid) == 0: mxid_input: MatrixUserID) -> Message:
if len(mxid_input) == 0:
return await reply("Usage: `/invite <mxid>`") return await reply("Usage: `/invite <mxid>`")
elif not portal.mxid: elif not portal.mxid:
return await reply("Portal does not have Matrix room. " return await reply("Portal does not have Matrix room. "
"Create one with /portal first.") "Create one with /portal first.")
if not self.mxid_regex.match(mxid): if not self.mxid_regex.match(mxid_input):
return await reply("That doesn't look like a Matrix ID.") return await reply("That doesn't look like a Matrix ID.")
user = await u.User.get_by_mxid(mxid).ensure_started() user = await u.User.get_by_mxid(MatrixUserID(mxid_input)).ensure_started()
if not user.relaybot_whitelisted: if not user.relaybot_whitelisted:
return await reply("That user is not whitelisted to use the bridge.") return await reply("That user is not whitelisted to use the bridge.")
elif await user.is_logged_in(): elif await user.is_logged_in():
@@ -183,7 +191,8 @@ class Bot(AbstractUser):
await portal.main_intent.invite(portal.mxid, user.mxid) await portal.main_intent.invite(portal.mxid, user.mxid)
return await reply(f"Invited `{user.mxid}` to the portal.") return await reply(f"Invited `{user.mxid}` to the portal.")
def handle_command_id(self, message: Message, reply: ReplyFunc): @staticmethod
def handle_command_id(message: Message, reply: ReplyFunc) -> Awaitable[Message]:
# Provide the prefixed ID to the user so that the user wouldn't need to specify whether the # Provide the prefixed ID to the user so that the user wouldn't need to specify whether the
# chat is a normal group or a supergroup/channel when using the ID. # chat is a normal group or a supergroup/channel when using the ID.
if isinstance(message.to_id, PeerChannel): if isinstance(message.to_id, PeerChannel):
@@ -205,8 +214,8 @@ class Bot(AbstractUser):
return False return False
async def handle_command(self, message: Message): async def handle_command(self, message: Message) -> None:
def reply(reply_text): def reply(reply_text: str) -> Awaitable[Message]:
return self.client.send_message(message.to_id, reply_text, reply_to=message.id) return self.client.send_message(message.to_id, reply_text, reply_to=message.id)
text = message.message text = message.message
@@ -227,36 +236,39 @@ class Bot(AbstractUser):
mxid = text[text.index(" ") + 1:] mxid = text[text.index(" ") + 1:]
except ValueError: except ValueError:
mxid = "" mxid = ""
await self.handle_command_invite(portal, reply, mxid=mxid) await self.handle_command_invite(portal, reply, mxid_input=mxid)
def handle_service_message(self, message: MessageService): def handle_service_message(self, message: MessageService) -> None:
to_id = message.to_id to_id = message.to_id
if isinstance(to_id, PeerChannel): if isinstance(to_id, PeerChannel):
to_id = to_id.channel_id to_id = to_id.channel_id
type = "channel" chat_type = "channel"
elif isinstance(to_id, PeerChat): elif isinstance(to_id, PeerChat):
to_id = to_id.chat_id to_id = to_id.chat_id
type = "chat" chat_type = "chat"
else: else:
return return
action = message.action action = message.action
if isinstance(action, MessageActionChatAddUser) and self.tgid in action.users: if isinstance(action, MessageActionChatAddUser) and self.tgid in action.users:
self.add_chat(to_id, type) self.add_chat(to_id, chat_type)
elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid: elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid:
self.remove_chat(to_id) self.remove_chat(to_id)
async def update(self, update): async def update(self, update) -> bool:
if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)): if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
return return False
if isinstance(update.message, MessageService): if isinstance(update.message, MessageService):
return self.handle_service_message(update.message) self.handle_service_message(update.message)
return False
is_command = (isinstance(update.message, Message) is_command = (isinstance(update.message, Message)
and update.message.entities and len(update.message.entities) > 0 and update.message.entities and len(update.message.entities) > 0
and isinstance(update.message.entities[0], MessageEntityBotCommand)) and isinstance(update.message.entities[0], MessageEntityBotCommand))
if is_command: if is_command:
return await self.handle_command(update.message) await self.handle_command(update.message)
return True
return False
def is_in_chat(self, peer_id) -> bool: def is_in_chat(self, peer_id) -> bool:
return peer_id in self.chats return peer_id in self.chats
@@ -266,7 +278,7 @@ class Bot(AbstractUser):
return "bot" return "bot"
def init(context) -> Optional[Bot]: def init(context: 'Context') -> Optional[Bot]:
global config global config
config = context.config config = context.config
token = config["telegram.bot_token"] token = config["telegram.bot_token"]
+39 -22
View File
@@ -14,20 +14,24 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict from typing import Any, Dict, Optional
import asyncio import asyncio
from telethon.errors import * from telethon.errors import (
AccessTokenExpiredError, AccessTokenInvalidError, FirstNameInvalidError, FloodWaitError,
PasswordHashInvalidError, PhoneCodeExpiredError, PhoneCodeInvalidError,
PhoneNumberAppSignupForbiddenError, PhoneNumberBannedError, PhoneNumberFloodError,
PhoneNumberOccupiedError, PhoneNumberUnoccupiedError, SessionPasswordNeededError)
from . import command_handler, CommandEvent, SECTION_AUTH from . import command_handler, CommandEvent, SECTION_AUTH
from .. import puppet as pu from .. import puppet as pu, user as u
from ..util import format_duration from ..util import format_duration
@command_handler(needs_auth=False, @command_handler(needs_auth=False,
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Check if you're logged into Telegram.") help_text="Check if you're logged into Telegram.")
async def ping(evt: CommandEvent): async def ping(evt: CommandEvent) -> Optional[Dict]:
me = await evt.sender.client.get_me() if await evt.sender.is_logged_in() else None me = await evt.sender.client.get_me() if await evt.sender.is_logged_in() else None
if me: if me:
return await evt.reply(f"You're logged in as @{me.username}") return await evt.reply(f"You're logged in as @{me.username}")
@@ -38,7 +42,7 @@ async def ping(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Get the info of the message relay Telegram bot.") help_text="Get the info of the message relay Telegram bot.")
async def ping_bot(evt: CommandEvent): async def ping_bot(evt: CommandEvent) -> Optional[Dict]:
if not evt.tgbot: if not evt.tgbot:
return await evt.reply("Telegram message relay bot not configured.") return await evt.reply("Telegram message relay bot not configured.")
bot_info = await evt.tgbot.client.get_me() bot_info = await evt.tgbot.client.get_me()
@@ -53,19 +57,19 @@ async def ping_bot(evt: CommandEvent):
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Revert your Telegram account's Matrix puppet to use the default Matrix " help_text="Revert your Telegram account's Matrix puppet to use the default Matrix "
"account.") "account.")
async def logout_matrix(evt: CommandEvent): async def logout_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid) puppet = pu.Puppet.get(evt.sender.tgid)
if not puppet.is_real_user: if not puppet.is_real_user:
return await evt.reply("You are not logged in with your Matrix account.") return await evt.reply("You are not logged in with your Matrix account.")
await puppet.switch_mxid(None, None) await puppet.switch_mxid(None, None)
await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.") return await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
@command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True, @command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True,
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Replace your Telegram account's Matrix puppet with your own Matrix " help_text="Replace your Telegram account's Matrix puppet with your own Matrix "
"account") "account")
async def login_matrix(evt: CommandEvent): async def login_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid) puppet = pu.Puppet.get(evt.sender.tgid)
if puppet.is_real_user: if puppet.is_real_user:
return await evt.reply("You have already logged in with your Matrix account. " return await evt.reply("You have already logged in with your Matrix account. "
@@ -96,7 +100,7 @@ async def login_matrix(evt: CommandEvent):
return await evt.reply("This bridge instance has been configured to not allow logging in.") return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def enter_matrix_token(evt: CommandEvent): async def enter_matrix_token(evt: CommandEvent) -> Dict:
evt.sender.command_status = None evt.sender.command_status = None
puppet = pu.Puppet.get(evt.sender.tgid) puppet = pu.Puppet.get(evt.sender.tgid)
@@ -105,10 +109,11 @@ async def enter_matrix_token(evt: CommandEvent):
"Log out with `$cmdprefix+sp logout-matrix` first.") "Log out with `$cmdprefix+sp logout-matrix` first.")
resp = await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid) resp = await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
if resp == 2: if resp == pu.PuppetError.OnlyLoginSelf:
return await evt.reply("You can only log in as your own Matrix user.") return await evt.reply("You can only log in as your own Matrix user.")
elif resp == 1: elif resp == pu.PuppetError.InvalidAccessToken:
return await evt.reply("Failed to verify access token.") return await evt.reply("Failed to verify access token.")
assert resp == pu.PuppetError.Success, "Encountered an unhandled PuppetError."
return await evt.reply( return await evt.reply(
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.") f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
@@ -117,7 +122,7 @@ async def enter_matrix_token(evt: CommandEvent):
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_args="<_phone_> <_full name_>", help_args="<_phone_> <_full name_>",
help_text="Register to Telegram") help_text="Register to Telegram")
async def register(evt: CommandEvent): async def register(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in(): if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.") return await evt.reply("You are already logged in.")
elif len(evt.args) < 1: elif len(evt.args) < 1:
@@ -134,9 +139,10 @@ async def register(evt: CommandEvent):
"action": "Register", "action": "Register",
"full_name": full_name, "full_name": full_name,
}) })
return None
async def enter_code_register(evt: CommandEvent): async def enter_code_register(evt: CommandEvent) -> Dict:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp <code>`") return await evt.reply("**Usage:** `$cmdprefix+sp <code>`")
try: try:
@@ -165,9 +171,9 @@ async def enter_code_register(evt: CommandEvent):
@command_handler(needs_auth=False, management_only=True, @command_handler(needs_auth=False, management_only=True,
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Get instructions on how to log in.") help_text="Get instructions on how to log in.")
async def login(evt: CommandEvent): async def login(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in(): if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.") return await evt.reply(f"You are already logged in as {evt.sender.human_tg_id}.")
allow_matrix_login = evt.config.get("bridge.allow_matrix_login", True) allow_matrix_login = evt.config.get("bridge.allow_matrix_login", True)
if allow_matrix_login: if allow_matrix_login:
@@ -196,7 +202,8 @@ async def login(evt: CommandEvent):
return await evt.reply("This bridge instance has been configured to not allow logging in.") return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, str]): async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, Any]
) -> Dict:
ok = False ok = False
try: try:
await evt.sender.ensure_started(even_if_no_session=True) await evt.sender.ensure_started(even_if_no_session=True)
@@ -228,7 +235,7 @@ async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[s
@command_handler(needs_auth=False) @command_handler(needs_auth=False)
async def enter_phone_or_token(evt: CommandEvent): async def enter_phone_or_token(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token <phone-or-token>`") return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token <phone-or-token>`")
elif not evt.config.get("bridge.allow_matrix_login", True): elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -248,10 +255,11 @@ async def enter_phone_or_token(evt: CommandEvent):
"next": enter_code, "next": enter_code,
"action": "Login", "action": "Login",
}) })
return None
@command_handler(needs_auth=False) @command_handler(needs_auth=False)
async def enter_code(evt: CommandEvent): async def enter_code(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-code <code>`") return await evt.reply("**Usage:** `$cmdprefix+sp enter-code <code>`")
elif not evt.config.get("bridge.allow_matrix_login", True): elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -263,10 +271,11 @@ async def enter_code(evt: CommandEvent):
evt.log.exception("Error sending phone code") evt.log.exception("Error sending phone code")
return await evt.reply("Unhandled exception while sending code. " return await evt.reply("Unhandled exception while sending code. "
"Check console for more details.") "Check console for more details.")
return None
@command_handler(needs_auth=False) @command_handler(needs_auth=False)
async def enter_password(evt: CommandEvent): async def enter_password(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-password <password>`") return await evt.reply("**Usage:** `$cmdprefix+sp enter-password <password>`")
elif not evt.config.get("bridge.allow_matrix_login", True): elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -282,15 +291,23 @@ async def enter_password(evt: CommandEvent):
evt.log.exception("Error sending password") evt.log.exception("Error sending password")
return await evt.reply("Unhandled exception while sending password. " return await evt.reply("Unhandled exception while sending password. "
"Check console for more details.") "Check console for more details.")
return None
async def sign_in(evt: CommandEvent, **sign_in_info): async def sign_in(evt: CommandEvent, **sign_in_info) -> Dict:
try: try:
await evt.sender.ensure_started(even_if_no_session=True) await evt.sender.ensure_started(even_if_no_session=True)
user = await evt.sender.client.sign_in(**sign_in_info) user = await evt.sender.client.sign_in(**sign_in_info)
existing_user = u.User.get_by_tgid(user.id)
if existing_user and existing_user != evt.sender:
await existing_user.log_out()
await evt.reply(f"[{existing_user.displayname}]"
f"(https://matrix.to/#/{existing_user.mxid})"
" was logged out from the account.")
asyncio.ensure_future(evt.sender.post_login(user), loop=evt.loop) asyncio.ensure_future(evt.sender.post_login(user), loop=evt.loop)
evt.sender.command_status = None evt.sender.command_status = None
return await evt.reply(f"Successfully logged in as @{user.username}") name = f"@{user.username}" if user.username else f"+{user.phone}"
return await evt.reply(f"Successfully logged in as {name}")
except PhoneCodeExpiredError: except PhoneCodeExpiredError:
return await evt.reply("Phone code expired. Try again with `$cmdprefix+sp login`.") return await evt.reply("Phone code expired. Try again with `$cmdprefix+sp login`.")
except PhoneCodeInvalidError: except PhoneCodeInvalidError:
@@ -309,7 +326,7 @@ async def sign_in(evt: CommandEvent, **sign_in_info):
@command_handler(needs_auth=True, @command_handler(needs_auth=True,
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Log out from Telegram.") help_text="Log out from Telegram.")
async def logout(evt: CommandEvent): async def logout(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.log_out(): if await evt.sender.log_out():
return await evt.reply("Logged out successfully.") return await evt.reply("Logged out successfully.")
return await evt.reply("Failed to log out.") return await evt.reply("Failed to log out.")
+23 -21
View File
@@ -14,26 +14,27 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, List from typing import Dict, List, NewType, Optional, Tuple, Union
from mautrix_appservice import MatrixRequestError, IntentAPI from mautrix_appservice import MatrixRequestError, IntentAPI
from ..types import MatrixRoomID, MatrixUserID
from . import command_handler, CommandEvent, SECTION_ADMIN from . import command_handler, CommandEvent, SECTION_ADMIN
from .. import puppet as pu, portal as po from .. import puppet as pu, portal as po
ManagementRoomList = List[Tuple[str, str]] ManagementRoom = NewType('ManagementRoom', Tuple[MatrixRoomID, MatrixUserID])
RoomIDList = List[str]
async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList, async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[MatrixRoomID],
List["po.Portal"], List["po.Portal"]]: List['po.Portal'], List['po.Portal']]:
management_rooms = [] # type: ManagementRoomList management_rooms = [] # type: List[ManagementRoom]
unidentified_rooms = [] # type: RoomIDList unidentified_rooms = [] # type: List[MatrixRoomID]
portals = [] # type: List[po.Portal] portals = [] # type: List[po.Portal]
empty_portals = [] # type: List[po.Portal] empty_portals = [] # type: List[po.Portal]
rooms = await intent.get_joined_rooms() rooms = await intent.get_joined_rooms()
for room in rooms: for room_str in rooms:
room = MatrixRoomID(room_str)
portal = po.Portal.get_by_mxid(room) portal = po.Portal.get_by_mxid(room)
if not portal: if not portal:
try: try:
@@ -41,11 +42,11 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
except MatrixRequestError: except MatrixRequestError:
members = [] members = []
if len(members) == 2: if len(members) == 2:
other_member = members[0] if members[0] != intent.mxid else members[1] other_member = MatrixUserID(members[0] if members[0] != intent.mxid else members[1])
if pu.Puppet.get_id_from_mxid(other_member): if pu.Puppet.get_id_from_mxid(other_member):
unidentified_rooms.append(room) unidentified_rooms.append(room)
else: else:
management_rooms.append((room, other_member)) management_rooms.append(ManagementRoom((room, other_member)))
else: else:
unidentified_rooms.append(room) unidentified_rooms.append(room)
else: else:
@@ -61,7 +62,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
@command_handler(needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms", @command_handler(needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms",
help_section=SECTION_ADMIN, help_section=SECTION_ADMIN,
help_text="Clean up unused portal/management rooms.") help_text="Clean up unused portal/management rooms.")
async def clean_rooms(evt: CommandEvent): async def clean_rooms(evt: CommandEvent) -> Optional[Dict]:
management_rooms, unidentified_rooms, portals, empty_portals = await _find_rooms(evt.az.intent) management_rooms, unidentified_rooms, portals, empty_portals = await _find_rooms(evt.az.intent)
reply = ["#### Management rooms (M)"] reply = ["#### Management rooms (M)"]
@@ -106,13 +107,14 @@ async def clean_rooms(evt: CommandEvent):
return await evt.reply("\n".join(reply)) return await evt.reply("\n".join(reply))
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList, async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
unidentified_rooms: RoomIDList, portals: List["po.Portal"], unidentified_rooms: List[MatrixRoomID], portals: List["po.Portal"],
empty_portals: List["po.Portal"]): empty_portals: List["po.Portal"]) -> None:
command = evt.args[0] command = evt.args[0]
rooms_to_clean = [] rooms_to_clean = [] # type: List[Union[po.Portal, MatrixRoomID]]
if command == "clean-recommended": if command == "clean-recommended":
rooms_to_clean = empty_portals + unidentified_rooms rooms_to_clean += empty_portals
rooms_to_clean += unidentified_rooms
elif command == "clean-groups": elif command == "clean-groups":
if len(evt.args) < 2: if len(evt.args) < 2:
return await evt.reply("**Usage:** `$cmdprefix+sp clean-groups [M][A][U][I]") return await evt.reply("**Usage:** `$cmdprefix+sp clean-groups [M][A][U][I]")
@@ -127,9 +129,9 @@ async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
rooms_to_clean += empty_portals rooms_to_clean += empty_portals
elif command == "clean-range": elif command == "clean-range":
try: try:
range = evt.args[1] clean_range = evt.args[1]
group, range = range[0], range[1:] group, clean_range = clean_range[0], clean_range[1:]
start, end = range.split("-") start, end = clean_range.split("-")
start, end = int(start), int(end) start, end = int(start), int(end)
if group == "M": if group == "M":
group = [room_id for (room_id, user_id) in management_rooms] group = [room_id for (room_id, user_id) in management_rooms]
@@ -158,7 +160,7 @@ async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
"`$cmdprefix+sp confirm-clean`.") "`$cmdprefix+sp confirm-clean`.")
async def execute_room_cleanup(evt, rooms_to_clean): async def execute_room_cleanup(evt, rooms_to_clean: List[Union[po.Portal, MatrixRoomID]]) -> None:
if len(evt.args) > 0 and evt.args[0] == "confirm-clean": if len(evt.args) > 0 and evt.args[0] == "confirm-clean":
await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. " await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. "
"This might take a while.") "This might take a while.")
@@ -167,7 +169,7 @@ async def execute_room_cleanup(evt, rooms_to_clean):
if isinstance(room, po.Portal): if isinstance(room, po.Portal):
await room.cleanup_and_delete() await room.cleanup_and_delete()
cleaned += 1 cleaned += 1
elif isinstance(room, str): elif isinstance(room, str): # str is aliased by MatrixRoomID
await po.Portal.cleanup_room(evt.az.intent, room, message="Room deleted") await po.Portal.cleanup_room(evt.az.intent, room, message="Room deleted")
cleaned += 1 cleaned += 1
evt.sender.command_status = None evt.sender.command_status = None
+57 -24
View File
@@ -14,19 +14,19 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Callable, Optional from typing import Awaitable, Callable, Dict, List, NamedTuple, Optional
from collections import namedtuple import commonmark
import markdown
import logging import logging
from telethon.errors import FloodWaitError from telethon.errors import FloodWaitError
from ..types import MatrixRoomID
from ..util import format_duration from ..util import format_duration
from .. import user as u, context as c from .. import user as u, context as c
command_handlers = {} # type: Dict[str, CommandHandler] command_handlers = {} # type: Dict[str, CommandHandler]
HelpSection = namedtuple("HelpSection", "name order description") HelpSection = NamedTuple('HelpSection', [('name', str), ('order', int), ('description', str)])
SECTION_GENERAL = HelpSection("General", 0, "") SECTION_GENERAL = HelpSection("General", 0, "")
SECTION_AUTH = HelpSection("Authentication", 10, "") SECTION_AUTH = HelpSection("Authentication", 10, "")
@@ -36,9 +36,30 @@ SECTION_MISC = HelpSection("Miscellaneous", 40, "")
SECTION_ADMIN = HelpSection("Administration", 50, "") SECTION_ADMIN = HelpSection("Administration", 50, "")
class HtmlEscapingRenderer(commonmark.HtmlRenderer):
def __init__(self, allow_html: bool = False):
super().__init__()
self.allow_html = allow_html
def lit(self, s):
if self.allow_html:
return super().lit(s)
return super().lit(s.replace("<", "&lt;").replace(">", "&gt;"))
def image(self, node, entering):
prev = self.allow_html
self.allow_html = True
super().image(node, entering)
self.allow_html = prev
md_parser = commonmark.Parser()
md_renderer = HtmlEscapingRenderer()
class CommandEvent: class CommandEvent:
def __init__(self, processor: "CommandProcessor", room: str, sender: u.User, command: str, def __init__(self, processor: 'CommandProcessor', room: MatrixRoomID, sender: u.User,
args: List[str], is_management: bool, is_portal: bool): command: str, args: List[str], is_management: bool, is_portal: bool) -> None:
self.az = processor.az self.az = processor.az
self.log = processor.log self.log = processor.log
self.loop = processor.loop self.loop = processor.loop
@@ -53,23 +74,25 @@ class CommandEvent:
self.is_management = is_management self.is_management = is_management
self.is_portal = is_portal self.is_portal = is_portal
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True): def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
) -> Awaitable[Dict]:
message = message.replace("$cmdprefix+sp ", message = message.replace("$cmdprefix+sp ",
"" if self.is_management else f"{self.command_prefix} ") "" if self.is_management else f"{self.command_prefix} ")
message = message.replace("$cmdprefix", self.command_prefix) message = message.replace("$cmdprefix", self.command_prefix)
html = None html = None
if render_markdown: if render_markdown:
html = markdown.markdown(message, safe_mode="escape" if allow_html else False) md_renderer.allow_html = allow_html
html = md_renderer.render(md_parser.parse(message))
elif allow_html: elif allow_html:
html = message html = message
return self.az.intent.send_notice(self.room_id, message, html=html) return self.az.intent.send_notice(self.room_id, message, html=html)
class CommandHandler: class CommandHandler:
def __init__(self, handler: Callable[[CommandEvent], None], needs_auth: bool, def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool,
needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool, needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool,
management_only: bool, name: str, help_text: str, help_args: str, management_only: bool, name: str, help_text: str, help_args: str,
help_section: HelpSection): help_section: HelpSection) -> None:
self._handler = handler self._handler = handler
self.needs_auth = needs_auth self.needs_auth = needs_auth
self.needs_puppeting = needs_puppeting self.needs_puppeting = needs_puppeting
@@ -103,7 +126,8 @@ class CommandHandler:
(not self.needs_admin or is_admin) and (not self.needs_admin or is_admin) and
(not self.needs_auth or is_logged_in)) (not self.needs_auth or is_logged_in))
async def __call__(self, evt: CommandEvent): async def __call__(self, evt: CommandEvent
) -> Dict:
error = await self.get_permission_error(evt) error = await self.get_permission_error(evt)
if error is not None: if error is not None:
return await evt.reply(error) return await evt.reply(error)
@@ -118,13 +142,21 @@ class CommandHandler:
return f"**{self.name}** {self._help_args} - {self._help_text}" return f"**{self.name}** {self._help_args} - {self._help_text}"
def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, needs_auth=True, def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] = None, *,
needs_puppeting=True, needs_matrix_puppeting=False, needs_admin=False, needs_auth: bool = True,
management_only=False, name=None, help_text="", help_args="", needs_puppeting: bool = True,
help_section=None): needs_matrix_puppeting: bool = False,
needs_admin: bool = False,
management_only: bool = False,
name: Optional[str] = None,
help_text: str = "",
help_args: str = "",
help_section: HelpSection = None
) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[Dict]]]],
CommandHandler]:
input_name = name input_name = name
def decorator(func: Callable[[CommandEvent], None]): def decorator(func: Callable[[CommandEvent], Awaitable[Optional[Dict]]]) -> CommandHandler:
name = input_name or func.__name__.replace("_", "-") name = input_name or func.__name__.replace("_", "-")
handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting, handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting,
needs_admin, management_only, name, help_text, help_args, needs_admin, management_only, name, help_text, help_args,
@@ -138,27 +170,27 @@ def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, n
class CommandProcessor: class CommandProcessor:
log = logging.getLogger("mau.commands") log = logging.getLogger("mau.commands")
def __init__(self, context: c.Context): def __init__(self, context: c.Context) -> None:
self.az, self.db, self.config, self.loop, self.tgbot = context self.az, self.db, self.config, self.loop, self.tgbot = context.core
self.public_website = context.public_website self.public_website = context.public_website
self.command_prefix = self.config["bridge.command_prefix"] self.command_prefix = self.config["bridge.command_prefix"]
async def handle(self, room: str, sender: u.User, command: str, args: List[str], async def handle(self, room: MatrixRoomID, sender: u.User, command: str, args: List[str],
is_management: bool, is_portal: bool): is_management: bool, is_portal: bool) -> Optional[Dict]:
evt = CommandEvent(self, room, sender, command, args, is_management, is_portal) evt = CommandEvent(self, room, sender, command, args, is_management, is_portal)
orig_command = command orig_command = command
command = command.lower() command = command.lower()
try: try:
command = command_handlers[command] handler = command_handlers[command]
except KeyError: except KeyError:
if sender.command_status and "next" in sender.command_status: if sender.command_status and "next" in sender.command_status:
args.insert(0, orig_command) args.insert(0, orig_command)
evt.command = "" evt.command = ""
command = sender.command_status["next"] handler = sender.command_status["next"]
else: else:
command = command_handlers["unknown-command"] handler = command_handlers["unknown-command"]
try: try:
await command(evt) await handler(evt)
except FloodWaitError as e: except FloodWaitError as e:
return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}") return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}")
except Exception: except Exception:
@@ -166,3 +198,4 @@ class CommandProcessor:
f"{evt.command} {' '.join(args)} from {sender.mxid}") f"{evt.command} {' '.join(args)} from {sender.mxid}")
return await evt.reply("Unhandled error while handling command. " return await evt.reply("Unhandled error while handling command. "
"Check logs for more details.") "Check logs for more details.")
return None
+17 -14
View File
@@ -14,46 +14,49 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Tuple
from . import command_handler, CommandEvent, _command_handlers, SECTION_GENERAL from . import command_handler, CommandEvent, _command_handlers, SECTION_GENERAL
from .handler import HelpSection
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL, help_section=SECTION_GENERAL,
help_text="Cancel an ongoing action (such as login)") help_text="Cancel an ongoing action (such as login)")
def cancel(evt: CommandEvent): async def cancel(evt: CommandEvent) -> Optional[Dict]:
if evt.sender.command_status: if evt.sender.command_status:
action = evt.sender.command_status["action"] action = evt.sender.command_status["action"]
evt.sender.command_status = None evt.sender.command_status = None
return evt.reply(f"{action} cancelled.") return await evt.reply(f"{action} cancelled.")
else: else:
return evt.reply("No ongoing command.") return await evt.reply("No ongoing command.")
@command_handler(needs_auth=False, needs_puppeting=False) @command_handler(needs_auth=False, needs_puppeting=False)
def unknown_command(evt: CommandEvent): async def unknown_command(evt: CommandEvent) -> Optional[Dict]:
return evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.") return await evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
help_cache = {} help_cache = {} # type: Dict[Tuple[bool, bool, bool, bool, bool], str]
async def _get_help_text(evt: CommandEvent): async def _get_help_text(evt: CommandEvent) -> str:
cache_key = (evt.is_management, evt.sender.puppet_whitelisted, cache_key = (evt.is_management, evt.sender.puppet_whitelisted,
evt.sender.matrix_puppet_whitelisted, evt.sender.is_admin, evt.sender.matrix_puppet_whitelisted, evt.sender.is_admin,
await evt.sender.is_logged_in()) await evt.sender.is_logged_in())
if cache_key not in help_cache: if cache_key not in help_cache:
help = {} help_sections = {} # type: Dict[HelpSection, List[str]]
for handler in _command_handlers.values(): for handler in _command_handlers.values():
if handler.has_help and handler.has_permission(*cache_key): if handler.has_help and handler.has_permission(*cache_key):
help.setdefault(handler.help_section, []) help_sections.setdefault(handler.help_section, [])
help[handler.help_section].append(handler.help + " ") help_sections[handler.help_section].append(handler.help + " ")
help = sorted(help.items(), key=lambda item: item[0].order) help_sorted = sorted(help_sections.items(), key=lambda item: item[0].order)
help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help] help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help_sorted]
help_cache[cache_key] = "\n".join(help) help_cache[cache_key] = "\n".join(help)
return help_cache[cache_key] return help_cache[cache_key]
def _get_management_status(evt: CommandEvent): def _get_management_status(evt: CommandEvent) -> str:
if evt.is_management: if evt.is_management:
return "This is a management room: prefixing commands with `$cmdprefix` is not required." return "This is a management room: prefixing commands with `$cmdprefix` is not required."
elif evt.is_portal: elif evt.is_portal:
@@ -65,5 +68,5 @@ def _get_management_status(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL, help_section=SECTION_GENERAL,
help_text="Show this help message.") help_text="Show this help message.")
async def help(evt: CommandEvent): async def help(evt: CommandEvent) -> Optional[Dict]:
return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt)) return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt))
+169 -45
View File
@@ -14,14 +14,18 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Callable from typing import Dict, Callable, Optional, Tuple, Coroutine, Awaitable
from io import StringIO
import asyncio import asyncio
from telethon.errors import * from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
UsernameNotModifiedError, UsernameOccupiedError)
from telethon.tl.types import ChatForbidden, ChannelForbidden from telethon.tl.types import ChatForbidden, ChannelForbidden
from mautrix_appservice import MatrixRequestError, IntentAPI from mautrix_appservice import MatrixRequestError, IntentAPI
from .. import portal as po, user as u from ..types import MatrixRoomID, TelegramID
from ..config import yaml
from .. import portal as po, user as u, util
from . import (command_handler, CommandEvent, from . import (command_handler, CommandEvent,
SECTION_ADMIN, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT) SECTION_ADMIN, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT)
@@ -30,7 +34,7 @@ from . import (command_handler, CommandEvent,
help_section=SECTION_ADMIN, help_section=SECTION_ADMIN,
help_args="<_level_> [_mxid_]", help_args="<_level_> [_mxid_]",
help_text="Set a temporary power level without affecting Telegram.") help_text="Set a temporary power level without affecting Telegram.")
async def set_power_level(evt: CommandEvent): async def set_power_level(evt: CommandEvent) -> Dict:
try: try:
level = int(evt.args[0]) level = int(evt.args[0])
except KeyError: except KeyError:
@@ -45,11 +49,12 @@ async def set_power_level(evt: CommandEvent):
except MatrixRequestError: except MatrixRequestError:
evt.log.exception("Failed to set power level.") evt.log.exception("Failed to set power level.")
return await evt.reply("Failed to set power level.") return await evt.reply("Failed to set power level.")
return {}
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT, @command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Get a Telegram invite link to the current chat.") help_text="Get a Telegram invite link to the current chat.")
async def invite_link(evt: CommandEvent): async def invite_link(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id) portal = po.Portal.get_by_mxid(evt.room_id)
if not portal: if not portal:
return await evt.reply("This is not a portal room.") return await evt.reply("This is not a portal room.")
@@ -66,7 +71,8 @@ async def invite_link(evt: CommandEvent):
return await evt.reply("You don't have the permission to create an invite link.") return await evt.reply("You don't have the permission to create an invite link.")
async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50): async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50
) -> bool:
if sender.is_admin: if sender.is_admin:
return True return True
# Make sure the state store contains the power levels. # Make sure the state store contains the power levels.
@@ -80,23 +86,26 @@ async def user_has_power_level(room: str, intent, sender: u.User, event: str, de
async def _get_portal_and_check_permission(evt: CommandEvent, permission: str, async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
action: Optional[str] = None): action: Optional[str] = None
room_id = evt.args[0] if len(evt.args) > 0 else evt.room_id ) -> Optional[po.Portal]:
room_id = MatrixRoomID(evt.args[0]) if len(evt.args) > 0 else evt.room_id
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
that_this = "This" if room_id == evt.room_id else "That" that_this = "This" if room_id == evt.room_id else "That"
return await evt.reply(f"{that_this} is not a portal room."), False await evt.reply(f"{that_this} is not a portal room.")
return None
if not await user_has_power_level(portal.mxid, evt.az.intent, evt.sender, permission): if not await user_has_power_level(portal.mxid, evt.az.intent, evt.sender, permission):
action = action or f"{permission.replace('_', ' ')}s" action = action or f"{permission.replace('_', ' ')}s"
return await evt.reply(f"You do not have the permissions to {action} that portal."), False await evt.reply(f"You do not have the permissions to {action} that portal.")
return portal, True return None
return portal
def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str, def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str,
completed_message: str): completed_message: str) -> Dict:
async def post_confirm(confirm): async def post_confirm(confirm) -> Optional[Dict]:
confirm.sender.command_status = None confirm.sender.command_status = None
if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}": if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}":
await function() await function()
@@ -104,6 +113,7 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
return await confirm.reply(completed_message) return await confirm.reply(completed_message)
else: else:
return await confirm.reply(f"{action} cancelled.") return await confirm.reply(f"{action} cancelled.")
return None
return { return {
"next": post_confirm, "next": post_confirm,
@@ -116,10 +126,10 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
help_text="Remove all users from the current portal room and forget the portal. " help_text="Remove all users from the current portal room and forget the portal. "
"Only works for group chats; to delete a private chat portal, simply " "Only works for group chats; to delete a private chat portal, simply "
"leave the room.") "leave the room.")
async def delete_portal(evt: CommandEvent): async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
portal, ok = await _get_portal_and_check_permission(evt, "unbridge") portal = await _get_portal_and_check_permission(evt, "unbridge")
if not ok: if not portal:
return return None
evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid, evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid,
portal.cleanup_and_delete, "delete", portal.cleanup_and_delete, "delete",
@@ -137,10 +147,10 @@ async def delete_portal(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_PORTAL_MANAGEMENT, help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Remove puppets from the current portal room and forget the portal.") help_text="Remove puppets from the current portal room and forget the portal.")
async def unbridge(evt: CommandEvent): async def unbridge(evt: CommandEvent) -> Optional[Dict]:
portal, ok = await _get_portal_and_check_permission(evt, "unbridge") portal = await _get_portal_and_check_permission(evt, "unbridge")
if not ok: if not portal:
return return None
evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid, evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid,
portal.unbridge, "unbridge", portal.unbridge, "unbridge",
@@ -156,11 +166,11 @@ async def unbridge(evt: CommandEvent):
help_text="Bridge the current Matrix room to the Telegram chat with the given " help_text="Bridge the current Matrix room to the Telegram chat with the given "
"ID. The ID must be the prefixed version that you get with the `/id` " "ID. The ID must be the prefixed version that you get with the `/id` "
"command of the Telegram-side bot.") "command of the Telegram-side bot.")
async def bridge(evt: CommandEvent): async def bridge(evt: CommandEvent) -> Dict:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** " return await evt.reply("**Usage:** "
"`$cmdprefix+sp bridge <Telegram chat ID> [Matrix room ID]`") "`$cmdprefix+sp bridge <Telegram chat ID> [Matrix room ID]`")
room_id = evt.args[1] if len(evt.args) > 1 else evt.room_id room_id = MatrixRoomID(evt.args[1]) if len(evt.args) > 1 else evt.room_id
that_this = "This" if room_id == evt.room_id else "That" that_this = "This" if room_id == evt.room_id else "That"
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
@@ -171,12 +181,12 @@ async def bridge(evt: CommandEvent):
return await evt.reply(f"You do not have the permissions to bridge {that_this} room.") return await evt.reply(f"You do not have the permissions to bridge {that_this} room.")
# The /id bot command provides the prefixed ID, so we assume # The /id bot command provides the prefixed ID, so we assume
tgid = evt.args[0] tgid_str = evt.args[0]
if tgid.startswith("-100"): if tgid_str.startswith("-100"):
tgid = int(tgid[4:]) tgid = TelegramID(int(tgid_str[4:]))
peer_type = "channel" peer_type = "channel"
elif tgid.startswith("-"): elif tgid_str.startswith("-"):
tgid = -int(tgid) tgid = TelegramID(-int(tgid_str))
peer_type = "chat" peer_type = "chat"
else: else:
return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n" return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n"
@@ -188,7 +198,7 @@ async def bridge(evt: CommandEvent):
if not portal.allow_bridging(): if not portal.allow_bridging():
return await evt.reply("This bridge doesn't allow bridging that Telegram chat.\n" return await evt.reply("This bridge doesn't allow bridging that Telegram chat.\n"
"If you're the bridge admin, try " "If you're the bridge admin, try "
"`$cmdprefix+sp whitelist <Telegram chat ID>` first.") "`$cmdprefix+sp filter whitelist <Telegram chat ID>` first.")
if portal.mxid: if portal.mxid:
has_portal_message = ( has_portal_message = (
"That Telegram chat already has a portal at " "That Telegram chat already has a portal at "
@@ -222,7 +232,8 @@ async def bridge(evt: CommandEvent):
"chat to this room, use `$cmdprefix+sp continue`") "chat to this room, use `$cmdprefix+sp continue`")
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"): async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"
) -> Tuple[bool, Optional[Coroutine[None, None, None]]]:
if not portal.mxid: if not portal.mxid:
await evt.reply("The portal seems to have lost its Matrix room between you" await evt.reply("The portal seems to have lost its Matrix room between you"
"calling `$cmdprefix+sp bridge` and this command.\n\n" "calling `$cmdprefix+sp bridge` and this command.\n\n"
@@ -245,7 +256,7 @@ async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Porta
return False, None return False, None
async def confirm_bridge(evt: CommandEvent): async def confirm_bridge(evt: CommandEvent) -> Optional[Dict]:
status = evt.sender.command_status status = evt.sender.command_status
try: try:
portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"]) portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"])
@@ -258,7 +269,7 @@ async def confirm_bridge(evt: CommandEvent):
if "mxid" in status: if "mxid" in status:
ok, coro = await cleanup_old_portal_while_bridging(evt, portal) ok, coro = await cleanup_old_portal_while_bridging(evt, portal)
if not ok: if not ok:
return return None
elif coro: elif coro:
asyncio.ensure_future(coro, loop=evt.loop) asyncio.ensure_future(coro, loop=evt.loop)
await evt.reply("Cleaning up previous portal room...") await evt.reply("Cleaning up previous portal room...")
@@ -271,6 +282,7 @@ async def confirm_bridge(evt: CommandEvent):
return await evt.reply("Please use `$cmdprefix+sp continue` to confirm the bridging or " return await evt.reply("Please use `$cmdprefix+sp continue` to confirm the bridging or "
"`$cmdprefix+sp cancel` to cancel.") "`$cmdprefix+sp cancel` to cancel.")
evt.sender.command_status = None
is_logged_in = await evt.sender.is_logged_in() is_logged_in = await evt.sender.is_logged_in()
user = evt.sender if is_logged_in else evt.tgbot user = evt.sender if is_logged_in else evt.tgbot
try: try:
@@ -302,7 +314,7 @@ async def confirm_bridge(evt: CommandEvent):
return await evt.reply("Bridging complete. Portal synchronization should begin momentarily.") return await evt.reply("Bridging complete. Portal synchronization should begin momentarily.")
async def get_initial_state(intent: IntentAPI, room_id: str): async def get_initial_state(intent: IntentAPI, room_id: str) -> Tuple[str, str, Dict]:
state = await intent.get_room_state(room_id) state = await intent.get_room_state(room_id)
title = None title = None
about = None about = None
@@ -328,7 +340,7 @@ async def get_initial_state(intent: IntentAPI, room_id: str):
help_text="Create a Telegram chat of the given type for the current Matrix room. " help_text="Create a Telegram chat of the given type for the current Matrix room. "
"The type is either `group`, `supergroup` or `channel` (defaults to " "The type is either `group`, `supergroup` or `channel` (defaults to "
"`group`).") "`group`).")
async def create(evt: CommandEvent): async def create(evt: CommandEvent) -> Dict:
type = evt.args[0] if len(evt.args) > 0 else "group" type = evt.args[0] if len(evt.args) > 0 else "group"
if type not in {"chat", "group", "supergroup", "channel"}: if type not in {"chat", "group", "supergroup", "channel"}:
return await evt.reply( return await evt.reply(
@@ -363,7 +375,7 @@ async def create(evt: CommandEvent):
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT, @command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Upgrade a normal Telegram group to a supergroup.") help_text="Upgrade a normal Telegram group to a supergroup.")
async def upgrade(evt: CommandEvent): async def upgrade(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id) portal = po.Portal.get_by_mxid(evt.room_id)
if not portal: if not portal:
return await evt.reply("This is not a portal room.") return await evt.reply("This is not a portal room.")
@@ -381,11 +393,122 @@ async def upgrade(evt: CommandEvent):
return await evt.reply(e.args[0]) return await evt.reply(e.args[0])
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="View or change per-portal settings.",
help_args="<`help`|_subcommand_> [...]")
async def config(evt: CommandEvent) -> None:
cmd = evt.args[0].lower() if len(evt.args) > 0 else "help"
if cmd not in ("view", "defaults", "set", "unset", "add", "del"):
await config_help(evt)
return
elif cmd == "defaults":
await config_defaults(evt)
return
portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
await evt.reply("This is not a portal room.")
return
elif cmd == "view":
await config_view(evt, portal)
return
key = evt.args[1] if len(evt.args) > 1 else None
value = yaml.load(" ".join(evt.args[2:])) if len(evt.args) > 2 else None
if cmd == "set":
await config_set(evt, portal, key, value)
elif cmd == "unset":
await config_unset(evt, portal, key)
elif cmd == "add" or cmd == "del":
await config_add_del(evt, portal, key, value, cmd)
else:
return
portal.save()
def config_help(evt: CommandEvent) -> Awaitable[Dict]:
return evt.reply("""**Usage:** `$cmdprefix config <subcommand> [...]`. Subcommands:
* **help** - View this help text.
* **view** - View the current config data.
* **defaults** - View the default config values.
* **set** <_key_> <_value_> - Set a config value.
* **unset** <_key_> - Remove a config value.
* **add** <_key_> <_value_> - Add a value to an array.
* **del** <_key_> <_value_> - Remove a value from an array.
""")
def config_view(evt: CommandEvent, portal: po.Portal) -> Awaitable[Dict]:
stream = StringIO()
yaml.dump(portal.local_config, stream)
return evt.reply(f"Room-specific config:\n\n```yaml\n{stream.getvalue()}```")
def config_defaults(evt: CommandEvent) -> Awaitable[Dict]:
stream = StringIO()
yaml.dump({
"edits_as_replies": evt.config["bridge.edits_as_replies"],
"bridge_notices": {
"default": evt.config["bridge.bridge_notices.default"],
"exceptions": evt.config["bridge.bridge_notices.exceptions"],
},
"bot_messages_as_notices": evt.config["bridge.bot_messages_as_notices"],
"inline_images": evt.config["bridge.inline_images"],
"native_stickers": evt.config["bridge.native_stickers"],
"message_formats": evt.config["bridge.message_formats"],
"state_event_formats": evt.config["bridge.state_event_formats"],
}, stream)
return evt.reply(f"Bridge instance wide config:\n\n```yaml\n{stream.getvalue()}```")
def config_set(evt: CommandEvent, portal: po.Portal, key: str, value: str) -> Awaitable[Dict]:
if not key or value is None:
return evt.reply(f"**Usage:** `$cmdprefix+sp config set <key> <value>`")
elif util.recursive_set(portal.local_config, key, value):
return evt.reply(f"Successfully set the value of `{key}` to `{value}`.")
else:
return evt.reply(f"Failed to set value of `{key}`. "
"Does the path contain non-map types?")
def config_unset(evt: CommandEvent, portal: po.Portal, key: str) -> Awaitable[Dict]:
if not key:
return evt.reply(f"**Usage:** `$cmdprefix+sp config unset <key>`")
elif util.recursive_del(portal.local_config, key):
return evt.reply(f"Successfully deleted `{key}` from config.")
else:
return evt.reply(f"`{key}` not found in config.")
def config_add_del(evt: CommandEvent, portal: po.Portal, key: str, value: str, cmd: str
) -> Awaitable[Dict]:
if not key or value is None:
return evt.reply(f"**Usage:** `$cmdprefix+sp config {cmd} <key> <value>`")
arr = util.recursive_get(portal.local_config, key)
if not arr:
return evt.reply(f"`{key}` not found in config. "
f"Maybe do `$cmdprefix+sp config set {key} []` first?")
elif not isinstance(arr, list):
return evt.reply("`{key}` does not seem to be an array.")
elif cmd == "add":
if value in arr:
return evt.reply(f"The array at `{key}` already contains `{value}`.")
arr.append(value)
return evt.reply(f"Successfully added `{value}` to the array at `{key}`")
else:
if value not in arr:
return evt.reply(f"The array at `{key}` does not contain `{value}`.")
arr.remove(value)
return evt.reply(f"Successfully removed `{value}` from the array at `{key}`")
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT, @command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_args="<_name_|`-`>", help_args="<_name_|`-`>",
help_text="Change the username of a supergroup/channel. " help_text="Change the username of a supergroup/channel. "
"To disable, use a dash (`-`) as the name.") "To disable, use a dash (`-`) as the name.")
async def group_name(evt: CommandEvent): async def group_name(evt: CommandEvent) -> Dict:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp group-name <name/->`") return await evt.reply("**Usage:** `$cmdprefix+sp group-name <name/->`")
@@ -421,7 +544,7 @@ async def group_name(evt: CommandEvent):
help_args="<`whitelist`|`blacklist`>", help_args="<`whitelist`|`blacklist`>",
help_text="Change whether the bridge will allow or disallow bridging rooms by " help_text="Change whether the bridge will allow or disallow bridging rooms by "
"default.") "default.")
async def filter_mode(evt: CommandEvent): async def filter_mode(evt: CommandEvent) -> Dict:
try: try:
mode = evt.args[0] mode = evt.args[0]
if mode not in ("whitelist", "blacklist"): if mode not in ("whitelist", "blacklist"):
@@ -446,19 +569,19 @@ async def filter_mode(evt: CommandEvent):
help_section=SECTION_ADMIN, help_section=SECTION_ADMIN,
help_args="<`whitelist`|`blacklist`> <_chat ID_>", help_args="<`whitelist`|`blacklist`> <_chat ID_>",
help_text="Allow or disallow bridging a specific chat.") help_text="Allow or disallow bridging a specific chat.")
async def filter(evt: CommandEvent): async def filter(evt: CommandEvent) -> Optional[Dict]:
try: try:
action = evt.args[0] action = evt.args[0]
if action not in ("whitelist", "blacklist", "add", "remove"): if action not in ("whitelist", "blacklist", "add", "remove"):
raise ValueError() raise ValueError()
id = evt.args[1] id_str = evt.args[1]
if id.startswith("-100"): if id_str.startswith("-100"):
id = int(id[4:]) id = int(id_str[4:])
elif id.startswith("-"): elif id_str.startswith("-"):
id = int(id[1:]) id = int(id_str[1:])
else: else:
id = int(id) id = int(id_str)
except (IndexError, ValueError): except (IndexError, ValueError):
return await evt.reply("**Usage:** `$cmdprefix+sp filter <whitelist/blacklist> <chat ID>`") return await evt.reply("**Usage:** `$cmdprefix+sp filter <whitelist/blacklist> <chat ID>`")
@@ -471,7 +594,7 @@ async def filter(evt: CommandEvent):
if action in ("blacklist", "whitelist"): if action in ("blacklist", "whitelist"):
action = "add" if mode == action else "remove" action = "add" if mode == action else "remove"
def save(): def save() -> None:
evt.config["bridge.filter.list"] = list evt.config["bridge.filter.list"] = list
evt.config.save() evt.config.save()
po.Portal.filter_list = list po.Portal.filter_list = list
@@ -488,3 +611,4 @@ async def filter(evt: CommandEvent):
list.remove(id) list.remove(id)
save() save()
return await evt.reply(f"Chat ID removed from {mode}.") return await evt.reply(f"Chat ID removed from {mode}.")
return None
+14 -8
View File
@@ -14,8 +14,13 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from telethon.errors import * from typing import Awaitable, Dict, List, Optional, Tuple
import re
from telethon.errors import (
InviteHashInvalidError, InviteHashExpiredError, UserAlreadyParticipantError)
from telethon.tl.types import User as TLUser from telethon.tl.types import User as TLUser
from telethon.tl.types import TypeUpdates
from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest
from telethon.tl.functions.channels import JoinChannelRequest from telethon.tl.functions.channels import JoinChannelRequest
@@ -26,7 +31,7 @@ from . import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORT
@command_handler(help_section=SECTION_MISC, @command_handler(help_section=SECTION_MISC,
help_args="[_-r|--remote_] <_query_>", help_args="[_-r|--remote_] <_query_>",
help_text="Search your contacts or the Telegram servers for users.") help_text="Search your contacts or the Telegram servers for users.")
async def search(evt: CommandEvent): async def search(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`") return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`")
@@ -47,7 +52,7 @@ async def search(evt: CommandEvent):
"Minimum length of remote query is 5 characters.") "Minimum length of remote query is 5 characters.")
return await evt.reply("No results 3:") return await evt.reply("No results 3:")
reply = [] reply = [] # type: List[str]
if remote: if remote:
reply += ["**Results from Telegram server:**", ""] reply += ["**Results from Telegram server:**", ""]
else: else:
@@ -68,7 +73,7 @@ async def search(evt: CommandEvent):
"either the internal user ID, the username or the phone number. " "either the internal user ID, the username or the phone number. "
"**N.B.** The phone numbers you start chats with must already be in " "**N.B.** The phone numbers you start chats with must already be in "
"your contacts.") "your contacts.")
async def private_message(evt: CommandEvent): async def private_message(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`") return await evt.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`")
@@ -87,7 +92,7 @@ async def private_message(evt: CommandEvent):
f"{pu.Puppet.get_displayname(user, False)}") f"{pu.Puppet.get_displayname(user, False)}")
async def _join(evt: CommandEvent, arg: str): async def _join(evt: CommandEvent, arg: str) -> Tuple[Optional[TypeUpdates], Optional[Dict]]:
if arg.startswith("joinchat/"): if arg.startswith("joinchat/"):
invite_hash = arg[len("joinchat/"):] invite_hash = arg[len("joinchat/"):]
try: try:
@@ -110,7 +115,7 @@ async def _join(evt: CommandEvent, arg: str):
@command_handler(help_section=SECTION_CREATING_PORTALS, @command_handler(help_section=SECTION_CREATING_PORTALS,
help_args="<_link_>", help_args="<_link_>",
help_text="Join a chat with an invite link.") help_text="Join a chat with an invite link.")
async def join(evt: CommandEvent): async def join(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp join <invite link>`") return await evt.reply("**Usage:** `$cmdprefix+sp join <invite link>`")
@@ -121,7 +126,7 @@ async def join(evt: CommandEvent):
updates, _ = await _join(evt, arg.group(1)) updates, _ = await _join(evt, arg.group(1))
if not updates: if not updates:
return return None
for chat in updates.chats: for chat in updates.chats:
portal = po.Portal.get_by_entity(chat) portal = po.Portal.get_by_entity(chat)
@@ -132,12 +137,13 @@ async def join(evt: CommandEvent):
await evt.reply(f"Creating room for {chat.title}... This might take a while.") await evt.reply(f"Creating room for {chat.title}... This might take a while.")
await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid]) await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid])
return await evt.reply(f"Created room for {portal.title}") return await evt.reply(f"Created room for {portal.title}")
return None
@command_handler(help_section=SECTION_MISC, @command_handler(help_section=SECTION_MISC,
help_args="[`chats`|`contacts`|`me`]", help_args="[`chats`|`contacts`|`me`]",
help_text="Synchronize your chat portals, contacts and/or own info.") help_text="Synchronize your chat portals, contacts and/or own info.")
async def sync(evt: CommandEvent): async def sync(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) > 0: if len(evt.args) > 0:
sync_only = evt.args[0] sync_only = evt.args[0]
if sync_only not in ("chats", "contacts", "me"): if sync_only not in ("chats", "contacts", "me"):
+44 -28
View File
@@ -14,23 +14,34 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Any, Optional from typing import Any, Dict, Optional, Tuple
from ruamel.yaml import YAML from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap from ruamel.yaml.comments import CommentedMap
import random import random
import string import string
yaml = YAML() yaml = YAML() # type: YAML
yaml.indent(4) yaml.indent(4)
class DictWithRecursion: class DictWithRecursion:
def __init__(self, data: CommentedMap = None): def __init__(self, data: Optional[CommentedMap] = None) -> None:
self._data = data or CommentedMap() # type: CommentedMap self._data = data or CommentedMap() # type: CommentedMap
@staticmethod
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
if '.' not in key:
return key, None
key, next_key = key.split('.', 1)
if len(key) > 0 and key[0] == "[":
end_index = next_key.index("]")
key = key[1:] + "." + next_key[:end_index]
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
return key, next_key
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any: def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
if '.' in key: key, next_key = self._parse_key(key)
key, next_key = key.split('.', 1) if next_key is not None:
next_data = data.get(key, CommentedMap()) next_data = data.get(key, CommentedMap())
return self._recursive_get(next_data, next_key, default_value) return self._recursive_get(next_data, next_key, default_value)
return data.get(key, default_value) return data.get(key, default_value)
@@ -46,40 +57,38 @@ class DictWithRecursion:
def __contains__(self, key: str) -> bool: def __contains__(self, key: str) -> bool:
return self[key] is not None return self[key] is not None
def _recursive_set(self, data: CommentedMap, key: str, value: Any): def _recursive_set(self, data: CommentedMap, key: str, value: Any) -> None:
if '.' in key: key, next_key = self._parse_key(key)
key, next_key = key.split('.', 1) if next_key is not None:
if key not in data: if key not in data:
data[key] = CommentedMap() data[key] = CommentedMap()
next_data = data.get(key, CommentedMap()) next_data = data.get(key, CommentedMap())
self._recursive_set(next_data, next_key, value) return self._recursive_set(next_data, next_key, value)
return
data[key] = value data[key] = value
def set(self, key: str, value: Any, allow_recursion: bool = True): def set(self, key: str, value: Any, allow_recursion: bool = True) -> None:
if allow_recursion and '.' in key: if allow_recursion and '.' in key:
self._recursive_set(self._data, key, value) self._recursive_set(self._data, key, value)
return return
self._data[key] = value self._data[key] = value
def __setitem__(self, key: str, value: Any): def __setitem__(self, key: str, value: Any) -> None:
self.set(key, value) self.set(key, value)
def _recursive_del(self, data: CommentedMap, key: str): def _recursive_del(self, data: CommentedMap, key: str) -> None:
if '.' in key: key, next_key = self._parse_key(key)
key, next_key = key.split('.', 1) if next_key is not None:
if key not in data: if key not in data:
return return
next_data = data[key] next_data = data[key]
self._recursive_del(next_data, next_key) return self._recursive_del(next_data, next_key)
return
try: try:
del data[key] del data[key]
del data.ca.items[key] del data.ca.items[key]
except KeyError: except KeyError:
pass pass
def delete(self, key: str, allow_recursion: bool = True): def delete(self, key: str, allow_recursion: bool = True) -> None:
if allow_recursion and '.' in key: if allow_recursion and '.' in key:
self._recursive_del(self._data, key) self._recursive_del(self._data, key)
return return
@@ -89,19 +98,19 @@ class DictWithRecursion:
except KeyError: except KeyError:
pass pass
def __delitem__(self, key: str): def __delitem__(self, key: str) -> None:
self.delete(key) self.delete(key)
class Config(DictWithRecursion): class Config(DictWithRecursion):
def __init__(self, path: str, registration_path: str, base_path: str): def __init__(self, path: str, registration_path: str, base_path: str) -> None:
super().__init__() super().__init__()
self.path = path # type: str self.path = path # type: str
self.registration_path = registration_path # type: str self.registration_path = registration_path # type: str
self.base_path = base_path # type: str self.base_path = base_path # type: str
self._registration = None # type: dict self._registration = None # type: Optional[Dict]
def load(self): def load(self) -> None:
with open(self.path, 'r') as stream: with open(self.path, 'r') as stream:
self._data = yaml.load(stream) self._data = yaml.load(stream)
@@ -113,7 +122,7 @@ class Config(DictWithRecursion):
pass pass
return None return None
def save(self): def save(self) -> None:
with open(self.path, 'w') as stream: with open(self.path, 'w') as stream:
yaml.dump(self._data, stream) yaml.dump(self._data, stream)
if self._registration and self.registration_path: if self._registration and self.registration_path:
@@ -124,16 +133,16 @@ class Config(DictWithRecursion):
def _new_token() -> str: def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
def update(self): def update(self) -> None:
base = self.load_base() base = self.load_base()
if not base: if not base:
return return
def copy(from_path, to_path=None): def copy(from_path, to_path=None) -> None:
if from_path in self: if from_path in self:
base[to_path or from_path] = self[from_path] base[to_path or from_path] = self[from_path]
def copy_dict(from_path, to_path=None, override_existing_map=True): def copy_dict(from_path, to_path=None, override_existing_map=True) -> None:
if from_path in self: if from_path in self:
to_path = to_path or from_path to_path = to_path or from_path
if override_existing_map or to_path not in base: if override_existing_map or to_path not in base:
@@ -156,6 +165,7 @@ class Config(DictWithRecursion):
copy("appservice.max_body_size") copy("appservice.max_body_size")
copy("appservice.database") copy("appservice.database")
copy("appservice.sqlalchemy_core_mode")
copy("appservice.public.enabled") copy("appservice.public.enabled")
copy("appservice.public.prefix") copy("appservice.public.prefix")
@@ -183,7 +193,13 @@ class Config(DictWithRecursion):
copy("bridge.edits_as_replies") copy("bridge.edits_as_replies")
copy("bridge.highlight_edits") copy("bridge.highlight_edits")
copy("bridge.bridge_notices") if isinstance(self["bridge.bridge_notices"], bool):
base["bridge.bridge_notices"] = {
"default": self["bridge.bridge_notices"],
"exceptions": ["@importantbot:example.com"],
}
else:
copy("bridge.bridge_notices")
copy("bridge.bot_messages_as_notices") copy("bridge.bot_messages_as_notices")
copy("bridge.max_initial_member_sync") copy("bridge.max_initial_member_sync")
copy("bridge.sync_channel_members") copy("bridge.sync_channel_members")
@@ -273,7 +289,7 @@ class Config(DictWithRecursion):
return self._get_permissions("*") return self._get_permissions("*")
def generate_registration(self): def generate_registration(self) -> None:
homeserver = self["homeserver.domain"] homeserver = self["homeserver.domain"]
username_format = self.get("bridge.username_template", "telegram_{userid}") \ username_format = self.get("bridge.username_template", "telegram_{userid}") \
+7 -8
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TYPE_CHECKING, Optional from typing import Optional, Tuple, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
import asyncio import asyncio
@@ -32,7 +32,8 @@ if TYPE_CHECKING:
class Context: class Context:
def __init__(self, az: "AppService", db: "scoped_session", config: "Config", def __init__(self, az: "AppService", db: "scoped_session", config: "Config",
loop: "asyncio.AbstractEventLoop", session_container: "AlchemySessionContainer"): loop: "asyncio.AbstractEventLoop", session_container: "AlchemySessionContainer"
) -> None:
self.az = az # type: AppService self.az = az # type: AppService
self.db = db # type: scoped_session self.db = db # type: scoped_session
self.config = config # type: Config self.config = config # type: Config
@@ -43,9 +44,7 @@ class Context:
self.public_website = None # type: PublicBridgeWebsite self.public_website = None # type: PublicBridgeWebsite
self.provisioning_api = None # type: ProvisioningAPI self.provisioning_api = None # type: ProvisioningAPI
def __iter__(self): @property
yield self.az def core(self) -> Tuple['AppService', 'scoped_session', 'Config',
yield self.db 'asyncio.AbstractEventLoop', Optional['Bot']]:
yield self.config return (self.az, self.db, self.config, self.loop, self.bot)
yield self.loop
yield self.bot
+107 -32
View File
@@ -15,11 +15,17 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text) BigInteger, String, Boolean, Text, Table,
and_, func, select)
from sqlalchemy.engine import Engine, RowProxy
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query from sqlalchemy.orm import relationship, Query
from sqlalchemy.sql.base import ImmutableColumnCollection
from typing import Dict, Optional, List
import json import json
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
from .types import TelegramID
from .base import Base from .base import Base
@@ -28,13 +34,15 @@ class Portal(Base):
__tablename__ = "portal" __tablename__ = "portal"
# Telegram chat information # Telegram chat information
tgid = Column(Integer, primary_key=True) tgid = Column(Integer, primary_key=True) # type: TelegramID
tg_receiver = Column(Integer, primary_key=True) tg_receiver = Column(Integer, primary_key=True) # type: TelegramID
peer_type = Column(String, nullable=False) peer_type = Column(String, nullable=False)
megagroup = Column(Boolean) megagroup = Column(Boolean)
# Matrix portal information # Matrix portal information
mxid = Column(String, unique=True, nullable=True) mxid = Column(String, unique=True, nullable=True) # type: Optional[MatrixRoomID]
config = Column(Text, nullable=True)
# Telegram chat metadata # Telegram chat metadata
username = Column(String, nullable=True) username = Column(String, nullable=True)
@@ -44,25 +52,88 @@ class Portal(Base):
class Message(Base): class Message(Base):
query = None # type: Query db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "message" __tablename__ = "message"
mxid = Column(String) mxid = Column(String) # type: MatrixEventID
mx_room = Column(String) mx_room = Column(String) # type: MatrixRoomID
tgid = Column(Integer, primary_key=True) tgid = Column(Integer, primary_key=True) # type: TelegramID
tg_space = Column(Integer, primary_key=True) tg_space = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),) __table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
@staticmethod
def _one_or_none(rows: RowProxy) -> Optional['Message']:
try:
mxid, mx_room, tgid, tg_space = next(rows)
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
except StopIteration:
return None
@staticmethod
def _all(rows: RowProxy) -> List['Message']:
return [Message(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3])
for row in rows]
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
rows = cls.db.execute(cls.t.select()
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
@classmethod
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
rows = cls.db.execute(select([func.count(cls.c.tg_space)])
.where(and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room)))
try:
count, = next(rows)
return count
except StopIteration:
return 0
@classmethod
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
) -> Optional['Message']:
rows = cls.db.execute(cls.t.select().where(
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
@classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
cls.db.execute(cls.t.update()
.where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space))
.values(**values))
@classmethod
def update_by_mxid(cls, s_mxid: MatrixEventID, s_mx_room: MatrixRoomID, **values) -> None:
cls.db.execute(cls.t.update()
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
def update(self, **values) -> None:
for key, value in values.items():
setattr(self, key, value)
self.update_by_tgid(self.tgid, self.tg_space, **values)
def delete(self) -> None:
self.db.execute(self.t.delete().where(
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
def insert(self) -> None:
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
tg_space=self.tg_space))
class UserPortal(Base): class UserPortal(Base):
query = None # type: Query query = None # type: Query
__tablename__ = "user_portal" __tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"), user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) primary_key=True) # type: TelegramID
portal = Column(Integer, primary_key=True) portal = Column(Integer, primary_key=True) # type: TelegramID
portal_receiver = Column(Integer, primary_key=True) portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"), __table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"), ("portal.tgid", "portal.tg_receiver"),
@@ -73,12 +144,14 @@ class User(Base):
query = None # type: Query query = None # type: Query
__tablename__ = "user" __tablename__ = "user"
mxid = Column(String, primary_key=True) mxid = Column(String, primary_key=True) # type: MatrixUserID
tgid = Column(Integer, nullable=True, unique=True) tgid = Column(Integer, nullable=True, unique=True) # type: Optional[TelegramID]
tg_username = Column(String, nullable=True) tg_username = Column(String, nullable=True)
tg_phone = Column(String, nullable=True)
saved_contacts = Column(Integer, default=0, nullable=False) saved_contacts = Column(Integer, default=0, nullable=False)
contacts = relationship("Contact", uselist=True, contacts = relationship("Contact", uselist=True,
cascade="save-update, merge, delete, delete-orphan") cascade="save-update, merge, delete, delete-orphan"
) # type: List[Contact]
portals = relationship("Portal", secondary="user_portal") portals = relationship("Portal", secondary="user_portal")
@@ -86,22 +159,22 @@ class RoomState(Base):
query = None # type: Query query = None # type: Query
__tablename__ = "mx_room_state" __tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True) room_id = Column(String, primary_key=True) # type: MatrixRoomID
_power_levels_text = Column("power_levels", Text, nullable=True) _power_levels_text = Column("power_levels", Text, nullable=True)
_power_levels_json = None _power_levels_json = {} # type: Dict
@property @property
def has_power_levels(self): def has_power_levels(self) -> bool:
return bool(self._power_levels_text) return bool(self._power_levels_text)
@property @property
def power_levels(self): def power_levels(self) -> Dict:
if not self._power_levels_json and self._power_levels_text: if not self._power_levels_json and self._power_levels_text:
self._power_levels_json = json.loads(self._power_levels_text) self._power_levels_json = json.loads(self._power_levels_text)
return self._power_levels_json or {} return self._power_levels_json
@power_levels.setter @power_levels.setter
def power_levels(self, val): def power_levels(self, val: Dict) -> None:
self._power_levels_json = val self._power_levels_json = val
self._power_levels_text = json.dumps(val) self._power_levels_text = json.dumps(val)
@@ -110,13 +183,13 @@ class UserProfile(Base):
query = None # type: Query query = None # type: Query
__tablename__ = "mx_user_profile" __tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True) room_id = Column(String, primary_key=True) # type: MatrixRoomID
user_id = Column(String, primary_key=True) user_id = Column(String, primary_key=True) # type: MatrixUserID
membership = Column(String, nullable=False, default="leave") membership = Column(String, nullable=False, default="leave")
displayname = Column(String, nullable=True) displayname = Column(String, nullable=True)
avatar_url = Column(String, nullable=True) avatar_url = Column(String, nullable=True)
def dict(self): def dict(self) -> Dict[str, str]:
return { return {
"membership": self.membership, "membership": self.membership,
"displayname": self.displayname, "displayname": self.displayname,
@@ -128,19 +201,19 @@ class Contact(Base):
query = None # type: Query query = None # type: Query
__tablename__ = "contact" __tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
class Puppet(Base): class Puppet(Base):
query = None # type: Query query = None # type: Query
__tablename__ = "puppet" __tablename__ = "puppet"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True) # type: TelegramID
custom_mxid = Column(String, nullable=True) custom_mxid = Column(String, nullable=True) # type: Optional[MatrixUserID]
access_token = Column(String, nullable=True) access_token = Column(String, nullable=True)
displayname = Column(String, nullable=True) displayname = Column(String, nullable=True)
displayname_source = Column(Integer, nullable=True) displayname_source = Column(Integer, nullable=True) # type: Optional[TelegramID]
username = Column(String, nullable=True) username = Column(String, nullable=True)
photo_id = Column(String, nullable=True) photo_id = Column(String, nullable=True)
is_bot = Column(Boolean, nullable=True) is_bot = Column(Boolean, nullable=True)
@@ -151,7 +224,7 @@ class Puppet(Base):
class BotChat(Base): class BotChat(Base):
query = None # type: Query query = None # type: Query
__tablename__ = "bot_chat" __tablename__ = "bot_chat"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True) # type: TelegramID
type = Column(String, nullable=False) type = Column(String, nullable=False)
@@ -171,9 +244,11 @@ class TelegramFile(Base):
thumbnail = relationship("TelegramFile", uselist=False) thumbnail = relationship("TelegramFile", uselist=False)
def init(db_session): def init(db_session, db_engine) -> None:
Portal.query = db_session.query_property() Portal.query = db_session.query_property()
Message.query = db_session.query_property() Message.db = db_engine
Message.t = Message.__table__
Message.c = Message.t.c
UserPortal.query = db_session.query_property() UserPortal.query = db_session.query_property()
User.query = db_session.query_property() User.query = db_session.query_property()
Puppet.query = db_session.query_property() Puppet.query = db_session.query_property()
+1 -1
View File
@@ -4,6 +4,6 @@ from .from_telegram import (telegram_reply_to_matrix, telegram_to_matrix, init_t
from .. import context as c from .. import context as c
def init(context: c.Context): def init(context: c.Context) -> None:
init_mx(context) init_mx(context)
init_tg(context) init_tg(context)
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, Callable, Pattern, Match, TYPE_CHECKING from typing import Optional, List, Tuple, Callable, Pattern, Match, TYPE_CHECKING, Dict, Any
import re import re
import logging import logging
@@ -22,6 +22,7 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M
TypeMessageEntity) TypeMessageEntity)
from ... import puppet as pu from ... import puppet as pu
from ...types import TelegramID, MatrixRoomID
from ...db import Message as DBMessage from ...db import Message as DBMessage
from ..util import (add_surrogates, remove_surrogates, trim_reply_fallback_html, from ..util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
trim_reply_fallback_text) trim_reply_fallback_text)
@@ -90,8 +91,8 @@ def matrix_to_telegram(html: str) -> ParsedMessage:
raise FormatError(f"Failed to convert Matrix format: {html}") from e raise FormatError(f"Failed to convert Matrix format: {html}") from e
def matrix_reply_to_telegram(content: dict, tg_space: int, room_id: Optional[str] = None def matrix_reply_to_telegram(content: Dict[str, Any], tg_space: TelegramID,
) -> Optional[int]: room_id: Optional[MatrixRoomID] = None) -> Optional[TelegramID]:
try: try:
reply = content["m.relates_to"]["m.in_reply_to"] reply = content["m.relates_to"]["m.in_reply_to"]
room_id = room_id or reply["room_id"] room_id = room_id or reply["room_id"]
@@ -104,9 +105,7 @@ def matrix_reply_to_telegram(content: dict, tg_space: int, room_id: Optional[str
pass pass
content["body"] = trim_reply_fallback_text(content["body"]) content["body"] = trim_reply_fallback_text(content["body"])
message = DBMessage.query.filter(DBMessage.mxid == event_id, message = DBMessage.get_by_mxid(event_id, room_id, tg_space)
DBMessage.tg_space == tg_space,
DBMessage.mx_room == room_id).one_or_none()
if message: if message:
return message.tgid return message.tgid
except KeyError: except KeyError:
@@ -147,7 +146,7 @@ def plain_mention_to_text() -> Tuple[List[TypeMessageEntity], Callable[[str], st
return entities, replacer return entities, replacer
def init_mx(context: "Context"): def init_mx(context: "Context") -> None:
global plain_mention_regex, should_bridge_plaintext_highlights global plain_mention_regex, should_bridge_plaintext_highlights
config = context.config config = context.config
dn_template = config.get("bridge.displayname_template", "{displayname} (Telegram)") dn_template = config.get("bridge.displayname_template", "{displayname} (Telegram)")
@@ -22,10 +22,15 @@ from telethon.tl.types import TypeMessageEntity
class MatrixParserCommon: class MatrixParserCommon:
mention_regex = re.compile("https://matrix.to/#/(@.+:.+)") # type: Pattern mention_regex = re.compile("https://matrix.to/#/(@.+:.+)") # type: Pattern
room_regex = re.compile("https://matrix.to/#/(#.+:.+)") # type: Pattern room_regex = re.compile("https://matrix.to/#/(#.+:.+)") # type: Pattern
block_tags = ("br", "p", "pre", "blockquote", block_tags = ("p", "pre", "blockquote",
"ol", "ul", "li", "ol", "ul", "li",
"h1", "h2", "h3", "h4", "h5", "h6", "h1", "h2", "h3", "h4", "h5", "h6",
"div", "hr", "table") # type: Tuple[str, ...] "div", "hr", "table") # type: Tuple[str, ...]
list_bullets = ("", "", "", "") # type: Tuple[str, ...]
@classmethod
def list_bullet(cls, depth: int) -> str:
return cls.list_bullets[(depth - 1) % len(cls.list_bullets)] + " "
ParsedMessage = Tuple[str, List[TypeMessageEntity]] ParsedMessage = Tuple[str, List[TypeMessageEntity]]
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Optional, List, Tuple, Type, Dict, Any, Deque, Match) from typing import (Optional, List, Tuple, Type, Dict, Any, TYPE_CHECKING, Match)
from html import unescape from html import unescape
from html.parser import HTMLParser from html.parser import HTMLParser
from collections import deque from collections import deque
@@ -26,9 +26,13 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M
MessageEntityBotCommand, TypeMessageEntity) MessageEntityBotCommand, TypeMessageEntity)
from ... import user as u, puppet as pu, portal as po from ... import user as u, puppet as pu, portal as po
from ...types import MatrixUserID
from ..util import html_to_unicode from ..util import html_to_unicode
from .parser_common import MatrixParserCommon, ParsedMessage from .parser_common import MatrixParserCommon, ParsedMessage
if TYPE_CHECKING:
from typing import Deque
def parse_html(html: str) -> ParsedMessage: def parse_html(html: str) -> ParsedMessage:
parser = MatrixParser() parser = MatrixParser()
@@ -52,7 +56,7 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]: ) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]:
mention = self.mention_regex.match(url) # type: Match mention = self.mention_regex.match(url) # type: Match
if mention: if mention:
mxid = mention.group(1) mxid = MatrixUserID(mention.group(1))
user = (pu.Puppet.get_by_mxid(mxid) user = (pu.Puppet.get_by_mxid(mxid)
or u.User.get_by_mxid(mxid, create=False)) or u.User.get_by_mxid(mxid, create=False))
if not user: if not user:
@@ -80,12 +84,12 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
args["url"] = url args["url"] = url
return MessageEntityTextUrl, None return MessageEntityTextUrl, None
def handle_starttag(self, tag: str, attrs: List[Tuple[str, str]]): def handle_starttag(self, tag: str, attrs_list: List[Tuple[str, str]]):
self._open_tags.appendleft(tag) self._open_tags.appendleft(tag)
self._open_tags_meta.appendleft(0) self._open_tags_meta.appendleft(0)
attrs = dict(attrs) attrs = dict(attrs_list)
entity_type = None # type: type(TypeMessageEntity) entity_type = None # type: Optional[Type[TypeMessageEntity]]
args = {} # type: Dict[str, Any] args = {} # type: Dict[str, Any]
if tag in ("strong", "b"): if tag in ("strong", "b"):
entity_type = MessageEntityBold entity_type = MessageEntityBold
@@ -119,7 +123,7 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
self._open_tags_meta.popleft() self._open_tags_meta.popleft()
self._open_tags_meta.appendleft(url) self._open_tags_meta.appendleft(url)
if tag in self.block_tags and ("blockquote" not in self._open_tags or tag == "br"): if (tag in self.block_tags and ("blockquote" not in self._open_tags)) or tag == "br":
self._newline() self._newline()
if entity_type and tag not in self._building_entities: if entity_type and tag not in self._building_entities:
@@ -198,7 +202,8 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
else: else:
prefix = int(math.log(n, 10)) * 3 * " " + 4 * " " prefix = int(math.log(n, 10)) * 3 * " " + 4 * " "
else: else:
prefix = "* " if self._list_entry_is_new else 3 * " " prefix = (self.list_bullet(self._open_tags.count('ul'))
if self._list_entry_is_new else 3 * " ")
if not self._list_entry_is_new and not self._line_is_new: if not self._list_entry_is_new and not self._line_is_new:
prefix = "" prefix = ""
extra_offset += len(indent) + len(prefix) extra_offset += len(indent) + len(prefix)
@@ -14,166 +14,49 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, Union, Callable from typing import List, Tuple
from lxml import html from lxml import html
from telethon.tl.types import (MessageEntityMention as Mention, from telethon.tl.types import (MessageEntityMention as Mention, MessageEntityBotCommand as Command,
MessageEntityMentionName as MentionName, MessageEntityEmail as Email, MessageEntityMentionName as MentionName, MessageEntityEmail as Email,
MessageEntityUrl as URL, MessageEntityTextUrl as TextURL, MessageEntityUrl as URL, MessageEntityTextUrl as TextURL,
MessageEntityBold as Bold, MessageEntityItalic as Italic, MessageEntityBold as Bold, MessageEntityItalic as Italic,
MessageEntityCode as Code, MessageEntityPre as Pre, MessageEntityCode as Code, MessageEntityPre as Pre)
MessageEntityBotCommand as Command, TypeMessageEntity,
InputMessageEntityMentionName as InputMentionName)
from ... import user as u, puppet as pu, portal as po from ... import user as u, puppet as pu, portal as po
from ...types import MatrixUserID
from ..util import html_to_unicode from ..util import html_to_unicode
from .parser_common import MatrixParserCommon, ParsedMessage from .parser_common import MatrixParserCommon, ParsedMessage
from .telegram_message import TelegramMessage, Entity, offset_length_multiply
def parse_html(html: str) -> ParsedMessage: def parse_html(input_html: str) -> ParsedMessage:
return MatrixParser.parse(html) return MatrixParser.parse(input_html)
class Entity: class RecursionContext:
@staticmethod def __init__(self, strip_linebreaks: bool = True, ul_depth: int = 0):
def copy(entity: TypeMessageEntity) -> Optional[TypeMessageEntity]: self.strip_linebreaks = strip_linebreaks # type: bool
if not entity: self.ul_depth = ul_depth # type: int
return None self._inited = True # type: bool
kwargs = {
"offset": entity.offset,
"length": entity.length,
}
if isinstance(entity, Pre):
kwargs["language"] = entity.language
elif isinstance(entity, TextURL):
kwargs["url"] = entity.url
elif isinstance(entity, (MentionName, InputMentionName)):
kwargs["user_id"] = entity.user_id
return entity.__class__(**kwargs)
@classmethod def __setattr__(self, key, value):
def adjust(cls, entity: Union[TypeMessageEntity, List[TypeMessageEntity]], if getattr(self, "_inited", False) is True:
func: Callable[[TypeMessageEntity], None] raise TypeError("'RecursionContext' object is immutable")
) -> Union[Optional[TypeMessageEntity], List[TypeMessageEntity]]: super(RecursionContext, self).__setattr__(key, value)
if isinstance(entity, list):
return [Entity.adjust(element, func) for element in entity if entity]
elif not entity:
return None
entity = cls.copy(entity)
func(entity)
if entity.offset < 0:
entity.length += entity.offset
entity.offset = 0
return entity
def enter_list(self) -> 'RecursionContext':
return RecursionContext(strip_linebreaks=self.strip_linebreaks, ul_depth=self.ul_depth + 1)
def offset_diff(amount: int): def enter_code_block(self) -> 'RecursionContext':
def func(entity: TypeMessageEntity): return RecursionContext(strip_linebreaks=False, ul_depth=self.ul_depth)
entity.offset += amount
return func
def offset_length_multiply(amount: int):
def func(entity: TypeMessageEntity):
entity.offset *= amount
entity.length *= amount
return func
class TelegramMessage:
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None):
self.text = text # type: str
self.entities = entities or [] # type: List[TypeMessageEntity]
def offset_entities(self, offset: int) -> "TelegramMessage":
def apply_offset(entity: TypeMessageEntity, inner_offset: int
) -> Optional[TypeMessageEntity]:
entity = Entity.copy(entity)
entity.offset += inner_offset
if entity.offset < 0:
entity.offset = 0
elif entity.offset > len(self.text):
return None
elif entity.offset + entity.length > len(self.text):
entity.length = len(self.text) - entity.offset
return entity
self.entities = [apply_offset(entity, offset) for entity in self.entities if entity]
self.entities = [x for x in self.entities if x is not None]
return self
def append(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
self.entities += Entity.adjust(msg.entities, offset_diff(len(self.text)))
self.text += msg.text
return self
def prepend(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
self.entities = msg.entities + Entity.adjust(self.entities, offset_diff(len(msg.text)))
self.text = msg.text + self.text
return self
def format(self, entity_type: type(TypeMessageEntity), offset: int = None, length: int = None,
**kwargs) -> "TelegramMessage":
self.entities.append(entity_type(offset=offset or 0,
length=length if length is not None else len(self.text),
**kwargs))
return self
def concat(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
return TelegramMessage().append(self, *args)
def trim(self) -> "TelegramMessage":
orig_len = len(self.text)
self.text = self.text.lstrip()
diff = orig_len - len(self.text)
self.text = self.text.rstrip()
self.offset_entities(-diff)
return self
def split(self, separator, max_items: int = 0) -> List["TelegramMessage"]:
text_parts = self.text.split(separator, max_items - 1)
output = [] # type: List[TelegramMessage]
offset = 0
for part in text_parts:
msg = TelegramMessage(part)
for entity in self.entities:
start_in_range = len(part) > entity.offset - offset >= 0
end_in_range = len(part) >= entity.offset - offset + entity.length > 0
if start_in_range and end_in_range:
msg.entities.append(Entity.adjust(entity, offset_diff(-offset)))
output.append(msg)
offset += len(part)
offset += len(separator)
return output
@staticmethod
def join(items: List[Union[str, "TelegramMessage"]], separator: str = " ") -> "TelegramMessage":
main = TelegramMessage()
for msg in items:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
main.entities += Entity.adjust(msg.entities, offset_diff(len(main.text)))
main.text += msg.text + separator
main.text = main.text[:-len(separator)]
return main
class MatrixParser(MatrixParserCommon): class MatrixParser(MatrixParserCommon):
@classmethod @classmethod
def list_to_tmessage(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage: def list_to_tmessage(cls, node: html.HtmlElement, ctx: RecursionContext) -> TelegramMessage:
ordered = node.tag == "ol" ordered = node.tag == "ol"
tagged_children = cls.node_to_tagged_tmessages(node, strip_linebreaks) tagged_children = cls.node_to_tagged_tmessages(node, ctx)
counter = 1 counter = 1
indent_length = 0 indent_length = 0
if ordered: if ordered:
@@ -194,7 +77,7 @@ class MatrixParser(MatrixParserCommon):
prefix = f"{counter}. " prefix = f"{counter}. "
counter += 1 counter += 1
else: else:
prefix = "" prefix = cls.list_bullet(ctx.ul_depth)
child = child.prepend(prefix) child = child.prepend(prefix)
parts = child.split("\n") parts = child.split("\n")
parts = parts[:1] + [part.prepend(indent) for part in parts[1:]] parts = parts[:1] + [part.prepend(indent) for part in parts[1:]]
@@ -203,41 +86,43 @@ class MatrixParser(MatrixParserCommon):
return TelegramMessage.join(children, "\n") return TelegramMessage.join(children, "\n")
@classmethod @classmethod
def blockquote_to_tmessage(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage: def blockquote_to_tmessage(cls, node: html.HtmlElement, ctx: RecursionContext
msg = cls.tag_aware_parse_node(node, strip_linebreaks) ) -> TelegramMessage:
msg = cls.tag_aware_parse_node(node, ctx)
children = msg.trim().split("\n") children = msg.trim().split("\n")
children = [child.prepend("> ") for child in children] children = [child.prepend("> ") for child in children]
return TelegramMessage.join(children, "\n") return TelegramMessage.join(children, "\n")
@classmethod @classmethod
def header_to_tmessage(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage: def header_to_tmessage(cls, node: html.HtmlElement, ctx: RecursionContext) -> TelegramMessage:
children = cls.node_to_tmessages(node, strip_linebreaks) children = cls.node_to_tmessages(node, ctx)
length = int(node.tag[1]) length = int(node.tag[1])
prefix = "#" * length + " " prefix = "#" * length + " "
return TelegramMessage.join(children, "").prepend(prefix) return TelegramMessage.join(children, "").prepend(prefix).format(Bold)
@classmethod @classmethod
def basic_format_to_tmessage(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage: def basic_format_to_tmessage(cls, node: html.HtmlElement, ctx: RecursionContext
msg = cls.tag_aware_parse_node(node, strip_linebreaks) ) -> TelegramMessage:
msg = cls.tag_aware_parse_node(node, ctx)
if node.tag in ("b", "strong"): if node.tag in ("b", "strong"):
msg.format(Bold) msg.format(Bold)
elif node.tag in ("i", "em"): elif node.tag in ("i", "em"):
msg.format(Italic) msg.format(Italic)
elif node.tag == "command": elif node.tag == "command":
msg.format(Command) msg.format(Command)
elif node.tag in ("s", "del"): elif node.tag in ("s", "strike", "del"):
msg.text = html_to_unicode(msg.text, "\u0336") msg.text = html_to_unicode(msg.text, "\u0336")
elif node.tag in ("u", "ins"): elif node.tag in ("u", "ins"):
msg.text = html_to_unicode(msg.text, "\u0332") msg.text = html_to_unicode(msg.text, "\u0332")
if node.tag in ("s", "del", "u", "ins"): if node.tag in ("s", "strike", "del", "u", "ins"):
msg.entities = Entity.adjust(msg.entities, offset_length_multiply(2)) msg.entities = Entity.adjust(msg.entities, offset_length_multiply(2))
return msg return msg
@classmethod @classmethod
def link_to_tstring(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage: def link_to_tstring(cls, node: html.HtmlElement, ctx: RecursionContext) -> TelegramMessage:
msg = cls.tag_aware_parse_node(node, strip_linebreaks) msg = cls.tag_aware_parse_node(node, ctx)
href = node.attrib.get("href", "") href = node.attrib.get("href", "")
if not href: if not href:
return msg return msg
@@ -247,7 +132,7 @@ class MatrixParser(MatrixParserCommon):
mention = cls.mention_regex.match(href) mention = cls.mention_regex.match(href)
if mention: if mention:
mxid = mention.group(1) mxid = MatrixUserID(mention.group(1))
user = (pu.Puppet.get_by_mxid(mxid) user = (pu.Puppet.get_by_mxid(mxid)
or u.User.get_by_mxid(mxid, create=False)) or u.User.get_by_mxid(mxid, create=False))
if not user: if not user:
@@ -271,73 +156,81 @@ class MatrixParser(MatrixParserCommon):
else msg.format(TextURL, url=href)) else msg.format(TextURL, url=href))
@classmethod @classmethod
def node_to_tmessage(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage: def node_to_tmessage(cls, node: html.HtmlElement, ctx: RecursionContext) -> TelegramMessage:
if node.tag == "blockquote": if node.tag == "blockquote":
return cls.blockquote_to_tmessage(node, strip_linebreaks) return cls.blockquote_to_tmessage(node, ctx)
elif node.tag in ("ol", "ul"): elif node.tag == "ol":
return cls.list_to_tmessage(node, strip_linebreaks) return cls.list_to_tmessage(node, ctx)
elif node.tag == "ul":
return cls.list_to_tmessage(node, ctx.enter_list())
elif node.tag in ("h1", "h2", "h3", "h4", "h5", "h6"): elif node.tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
return cls.header_to_tmessage(node, strip_linebreaks) return cls.header_to_tmessage(node, ctx)
elif node.tag == "br": elif node.tag == "br":
return TelegramMessage("\n") return TelegramMessage("\n")
elif node.tag in ("b", "strong", "i", "em", "s", "del", "u", "ins", "command"): elif node.tag in ("b", "strong", "i", "em", "s", "del", "u", "ins", "command"):
return cls.basic_format_to_tmessage(node, strip_linebreaks) return cls.basic_format_to_tmessage(node, ctx)
elif node.tag == "a": elif node.tag == "a":
return cls.link_to_tstring(node, strip_linebreaks) return cls.link_to_tstring(node, ctx)
elif node.tag == "p": elif node.tag == "p":
return cls.tag_aware_parse_node(node, strip_linebreaks).append("\n") return cls.tag_aware_parse_node(node, ctx).append("\n")
elif node.tag == "pre": elif node.tag == "pre":
lang = "" lang = ""
try: try:
if node[0].tag == "code": if node[0].tag == "code":
lang = node[0].attrib["class"][len("language-"):]
node = node[0] node = node[0]
lang = node.attrib["class"][len("language-"):]
except (IndexError, KeyError): except (IndexError, KeyError):
pass pass
return cls.parse_node(node, strip_linebreaks=False).format(Pre, language=lang) return cls.parse_node(node, ctx.enter_code_block()).format(Pre, language=lang)
elif node.tag == "code": elif node.tag == "code":
return cls.parse_node(node, strip_linebreaks=False).format(Code) return cls.parse_node(node, ctx.enter_code_block()).format(Code)
return cls.tag_aware_parse_node(node, strip_linebreaks) return cls.tag_aware_parse_node(node, ctx)
@staticmethod @staticmethod
def text_to_tmessage(text: str, strip_linebreaks: bool = True) -> TelegramMessage: def text_to_tmessage(text: str, ctx: RecursionContext) -> TelegramMessage:
if strip_linebreaks: if ctx.strip_linebreaks:
text = text.replace("\n", "") text = text.replace("\n", "")
return TelegramMessage(text) return TelegramMessage(text)
@classmethod @classmethod
def node_to_tagged_tmessages(cls, node: html.HtmlElement, strip_linebreaks: bool = True def node_to_tagged_tmessages(cls, node: html.HtmlElement, ctx: RecursionContext
) -> List[Tuple[TelegramMessage, str]]: ) -> List[Tuple[TelegramMessage, str]]:
output = [] output = []
if node.text: if node.text:
output.append((cls.text_to_tmessage(node.text, strip_linebreaks), "text")) output.append((cls.text_to_tmessage(node.text, ctx), "text"))
for child in node: for child in node:
output.append((cls.node_to_tmessage(child, strip_linebreaks), child.tag)) output.append((cls.node_to_tmessage(child, ctx), child.tag))
if child.tail: if child.tail:
output.append((cls.text_to_tmessage(child.tail, strip_linebreaks), "text")) output.append((cls.text_to_tmessage(child.tail, ctx), "text"))
return output return output
@classmethod @classmethod
def node_to_tmessages(cls, node: html.HtmlElement, strip_linebreaks) -> List[TelegramMessage]: def node_to_tmessages(cls, node: html.HtmlElement, ctx: RecursionContext
return [msg for (msg, tag) in cls.node_to_tagged_tmessages(node, strip_linebreaks)] ) -> List[TelegramMessage]:
return [msg for (msg, tag) in cls.node_to_tagged_tmessages(node, ctx)]
@classmethod @classmethod
def tag_aware_parse_node(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage: def tag_aware_parse_node(cls, node: html.HtmlElement, ctx: RecursionContext
msgs = cls.node_to_tagged_tmessages(node, strip_linebreaks) ) -> TelegramMessage:
msgs = cls.node_to_tagged_tmessages(node, ctx)
output = TelegramMessage() output = TelegramMessage()
prev_was_block = False
for msg, tag in msgs: for msg, tag in msgs:
if tag in cls.block_tags: if tag in cls.block_tags:
msg = msg.append("\n").prepend("\n") msg = msg.append("\n")
if not prev_was_block:
msg = msg.prepend("\n")
prev_was_block = True
output = output.append(msg) output = output.append(msg)
return output.trim() return output.trim()
@classmethod @classmethod
def parse_node(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage: def parse_node(cls, node: html.HtmlElement, ctx: RecursionContext) -> TelegramMessage:
return TelegramMessage.join(cls.node_to_tmessages(node, strip_linebreaks)) return TelegramMessage.join(cls.node_to_tmessages(node, ctx))
@classmethod @classmethod
def parse(cls, data: str) -> ParsedMessage: def parse(cls, data: str) -> ParsedMessage:
document = html.fromstring(f"<html>{data}</html>") document = html.fromstring(f"<html>{data}</html>")
msg = cls.parse_node(document, strip_linebreaks=True) msg = cls.parse_node(document, RecursionContext())
return msg.text, msg.entities return msg.text, msg.entities
@@ -0,0 +1,157 @@
# -*- coding: future_fstrings -*-
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, List, Optional, Sequence, Type, Union
from telethon.tl.types import (MessageEntityMentionName as MentionName,
MessageEntityTextUrl as TextURL, MessageEntityPre as Pre,
TypeMessageEntity, InputMessageEntityMentionName as InputMentionName)
class Entity:
@staticmethod
def copy(entity: TypeMessageEntity) -> Optional[TypeMessageEntity]:
if not entity:
return None
kwargs = {
"offset": entity.offset,
"length": entity.length,
}
if isinstance(entity, Pre):
kwargs["language"] = entity.language
elif isinstance(entity, TextURL):
kwargs["url"] = entity.url
elif isinstance(entity, (MentionName, InputMentionName)):
kwargs["user_id"] = entity.user_id
return entity.__class__(**kwargs)
@classmethod
def adjust(cls, entity: Union[TypeMessageEntity, List[TypeMessageEntity]],
func: Callable[[TypeMessageEntity], None]
) -> Union[Optional[TypeMessageEntity], List[TypeMessageEntity]]:
if isinstance(entity, list):
return [Entity.adjust(element, func) for element in entity if entity]
elif not entity:
return None
entity = cls.copy(entity)
func(entity)
if entity.offset < 0:
entity.length += entity.offset
entity.offset = 0
return entity
def offset_diff(amount: int) -> Callable[[TypeMessageEntity], None]:
def func(entity: TypeMessageEntity) -> None:
entity.offset += amount
return func
def offset_length_multiply(amount: int) -> Callable[[TypeMessageEntity], None]:
def func(entity: TypeMessageEntity) -> None:
entity.offset *= amount
entity.length *= amount
return func
class TelegramMessage:
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None) -> None:
self.text = text # type: str
self.entities = entities or [] # type: List[TypeMessageEntity]
def offset_entities(self, offset: int) -> 'TelegramMessage':
def apply_offset(entity: TypeMessageEntity, inner_offset: int
) -> Optional[TypeMessageEntity]:
entity = Entity.copy(entity)
entity.offset += inner_offset
if entity.offset < 0:
entity.offset = 0
elif entity.offset > len(self.text):
return None
elif entity.offset + entity.length > len(self.text):
entity.length = len(self.text) - entity.offset
return entity
self.entities = [apply_offset(entity, offset) for entity in self.entities if entity]
self.entities = [x for x in self.entities if x is not None]
return self
def append(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
self.entities += Entity.adjust(msg.entities, offset_diff(len(self.text)))
self.text += msg.text
return self
def prepend(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
self.entities = msg.entities + Entity.adjust(self.entities, offset_diff(len(msg.text)))
self.text = msg.text + self.text
return self
def format(self, entity_type: Type[TypeMessageEntity], offset: int = None, length: int = None,
**kwargs) -> 'TelegramMessage':
self.entities.append(entity_type(offset=offset or 0,
length=length if length is not None else len(self.text),
**kwargs))
return self
def concat(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
return TelegramMessage().append(self, *args)
def trim(self) -> 'TelegramMessage':
orig_len = len(self.text)
self.text = self.text.lstrip()
diff = orig_len - len(self.text)
self.text = self.text.rstrip()
self.offset_entities(-diff)
return self
def split(self, separator, max_items: int = 0) -> List['TelegramMessage']:
text_parts = self.text.split(separator, max_items - 1)
output = [] # type: List[TelegramMessage]
offset = 0
for part in text_parts:
msg = TelegramMessage(part)
for entity in self.entities:
start_in_range = len(part) > entity.offset - offset >= 0
end_in_range = len(part) >= entity.offset - offset + entity.length > 0
if start_in_range and end_in_range:
msg.entities.append(Entity.adjust(entity, offset_diff(-offset)))
output.append(msg)
offset += len(part)
offset += len(separator)
return output
@staticmethod
def join(items: Sequence[Union[str, 'TelegramMessage']],
separator: str = " ") -> 'TelegramMessage':
main = TelegramMessage()
for msg in items:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
main.entities += Entity.adjust(msg.entities, offset_diff(len(main.text)))
main.text += msg.text + separator
main.text = main.text[:-len(separator)]
return main
+28 -23
View File
@@ -14,21 +14,23 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, TYPE_CHECKING from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from html import escape from html import escape
import logging import logging
import re import re
from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, MessageEntityUrl,
MessageEntityEmail, MessageEntityUrl, MessageEntityTextUrl, MessageEntityEmail, MessageEntityTextUrl, MessageEntityBold,
MessageEntityBold, MessageEntityItalic, MessageEntityCode, MessageEntityItalic, MessageEntityCode, MessageEntityPre,
MessageEntityPre, MessageEntityBotCommand, Message, PeerChannel, MessageEntityBotCommand, MessageEntityHashtag, MessageEntityCashtag,
MessageEntityHashtag, TypeMessageEntity, MessageFwdHeader, PeerUser) MessageEntityPhone, TypeMessageEntity, Message, PeerChannel,
MessageFwdHeader, PeerUser)
from mautrix_appservice import MatrixRequestError from mautrix_appservice import MatrixRequestError
from mautrix_appservice.intent_api import IntentAPI from mautrix_appservice.intent_api import IntentAPI
from .. import user as u, puppet as pu, portal as po from .. import user as u, puppet as pu, portal as po
from ..types import TelegramID
from ..db import Message as DBMessage from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html, from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
trim_reply_fallback_text, unicode_to_html) trim_reply_fallback_text, unicode_to_html)
@@ -40,19 +42,19 @@ if TYPE_CHECKING:
try: try:
from lxml.html.diff import htmldiff from lxml.html.diff import htmldiff
except ImportError: except ImportError:
htmldiff = None # type: function htmldiff = None # type: ignore
log = logging.getLogger("mau.fmt.tg") # type: logging.Logger log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
should_highlight_edits = False # type: bool should_highlight_edits = False # type: bool
def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict: def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
if evt.reply_to_msg_id: if evt.reply_to_msg_id:
space = (evt.to_id.channel_id space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
else source.tgid) else source.tgid)
msg = DBMessage.query.get((evt.reply_to_msg_id, space)) msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
if msg: if msg:
return { return {
"m.in_reply_to": { "m.in_reply_to": {
@@ -75,7 +77,7 @@ async def _add_forward_header(source, text: str, html: Optional[str],
fwd_from_html = f"<a href='https://matrix.to/#/{user.mxid}'>{fwd_from_text}</a>" fwd_from_html = f"<a href='https://matrix.to/#/{user.mxid}'>{fwd_from_text}</a>"
if not fwd_from_text: if not fwd_from_text:
puppet = pu.Puppet.get(fwd_from.from_id, create=False) puppet = pu.Puppet.get(TelegramID(fwd_from.from_id), create=False)
if puppet and puppet.displayname: if puppet and puppet.displayname:
fwd_from_text = puppet.displayname or puppet.mxid fwd_from_text = puppet.displayname or puppet.mxid
fwd_from_html = f"<a href='https://matrix.to/#/{puppet.mxid}'>{fwd_from_text}</a>" fwd_from_html = f"<a href='https://matrix.to/#/{puppet.mxid}'>{fwd_from_text}</a>"
@@ -116,13 +118,13 @@ def highlight_edits(new_html: str, old_html: str) -> str:
async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message, async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
relates_to: dict, main_intent: IntentAPI, is_edit: bool relates_to: Dict, main_intent: IntentAPI, is_edit: bool
) -> Tuple[str, str]: ) -> Tuple[str, str]:
space = (evt.to_id.channel_id space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
else source.tgid) else source.tgid)
msg = DBMessage.query.get((evt.reply_to_msg_id, space)) msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
if not msg: if not msg:
return text, html return text, html
@@ -177,10 +179,10 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
async def telegram_to_matrix(evt: Message, source: "AbstractUser", async def telegram_to_matrix(evt: Message, source: "AbstractUser",
main_intent: Optional[IntentAPI] = None, main_intent: Optional[IntentAPI] = None,
is_edit: bool = False, prefix_text: Optional[str] = None, is_edit: bool = False, prefix_text: Optional[str] = None,
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]: prefix_html: Optional[str] = None) -> Tuple[str, str, Dict]:
text = add_surrogates(evt.message) text = add_surrogates(evt.message)
html = _telegram_entities_to_matrix_catch(text, evt.entities) if evt.entities else None html = _telegram_entities_to_matrix_catch(text, evt.entities) if evt.entities else None
relates_to = {} relates_to = {} # type: Dict
if prefix_html: if prefix_html:
html = prefix_html + (html or escape(text)) html = prefix_html + (html or escape(text))
@@ -217,6 +219,7 @@ def _telegram_entities_to_matrix_catch(text: str, entities: List[TypeMessageEnti
"message=%s\n" "message=%s\n"
"entities=%s", "entities=%s",
text, entities) text, entities)
return "[failed conversion in _telegram_entities_to_matrix]"
def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -> str: def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -> str:
@@ -239,21 +242,23 @@ def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -
elif entity_type == MessageEntityItalic: elif entity_type == MessageEntityItalic:
html.append(f"<em>{entity_text}</em>") html.append(f"<em>{entity_text}</em>")
elif entity_type == MessageEntityCode: elif entity_type == MessageEntityCode:
html.append(f"<code>{entity_text}</code>") html.append(f"<pre><code>{entity_text}</code></pre>"
if "\n" in entity_text
else f"<code>{entity_text}</code>")
elif entity_type == MessageEntityPre: elif entity_type == MessageEntityPre:
skip_entity = _parse_pre(html, entity_text, entity.language) skip_entity = _parse_pre(html, entity_text, entity.language)
elif entity_type == MessageEntityMention: elif entity_type == MessageEntityMention:
skip_entity = _parse_mention(html, entity_text) skip_entity = _parse_mention(html, entity_text)
elif entity_type == MessageEntityMentionName: elif entity_type == MessageEntityMentionName:
skip_entity = _parse_name_mention(html, entity_text, entity.user_id) skip_entity = _parse_name_mention(html, entity_text, TelegramID(entity.user_id))
elif entity_type == MessageEntityEmail: elif entity_type == MessageEntityEmail:
html.append(f"<a href='mailto:{entity_text}'>{entity_text}</a>") html.append(f"<a href='mailto:{entity_text}'>{entity_text}</a>")
elif entity_type in {MessageEntityTextUrl, MessageEntityUrl}: elif entity_type in (MessageEntityTextUrl, MessageEntityUrl):
skip_entity = _parse_url(html, entity_text, skip_entity = _parse_url(html, entity_text,
entity.url if entity_type == MessageEntityTextUrl else None) entity.url if entity_type == MessageEntityTextUrl else None)
elif entity_type == MessageEntityBotCommand: elif entity_type == MessageEntityBotCommand:
html.append(f"<font color='blue'>!{entity_text[1:]}</font>") html.append(f"<font color='blue'>!{entity_text[1:]}</font>")
elif entity_type == MessageEntityHashtag: elif entity_type in (MessageEntityHashtag, MessageEntityCashtag, MessageEntityPhone):
html.append(f"<font color='blue'>{entity_text}</font>") html.append(f"<font color='blue'>{entity_text}</font>")
else: else:
skip_entity = True skip_entity = True
@@ -290,7 +295,7 @@ def _parse_mention(html: List[str], entity_text: str) -> bool:
return False return False
def _parse_name_mention(html: List[str], entity_text: str, user_id: int) -> bool: def _parse_name_mention(html: List[str], entity_text: str, user_id: TelegramID) -> bool:
user = u.User.get_by_tgid(user_id) user = u.User.get_by_tgid(user_id)
if user: if user:
mxid = user.mxid mxid = user.mxid
@@ -315,12 +320,12 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
message_link_match = message_link_regex.match(url) message_link_match = message_link_regex.match(url)
if message_link_match: if message_link_match:
group, msgid = message_link_match.groups() group, msgid_str = message_link_match.groups()
msgid = int(msgid) msgid = int(msgid_str)
portal = po.Portal.find_by_username(group) portal = po.Portal.find_by_username(group)
if portal: if portal:
message = DBMessage.query.get((msgid, portal.tgid)) message = DBMessage.get_by_tgid(TelegramID(msgid), portal.tgid)
if message: if message:
url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}" url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}"
@@ -328,6 +333,6 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
return False return False
def init_tg(context: "Context"): def init_tg(context: "Context") -> None:
global should_highlight_edits global should_highlight_edits
should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"] should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"]
+37 -34
View File
@@ -20,40 +20,6 @@ import struct
import re import re
# add_surrogates and remove_surrogates are unicode surrogate utility functions from Telethon.
# Licensed under the MIT license.
# https://github.com/LonamiWebs/Telethon/blob/master/telethon/extensions/markdown.py
def add_surrogates(text: Optional[str]) -> Optional[str]:
if text is None:
return None
return "".join("".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16-le")))
if (0x10000 <= ord(x) <= 0x10FFFF) else x for x in text)
def remove_surrogates(text: Optional[str]) -> Optional[str]:
if text is None:
return None
return text.encode("utf-16", "surrogatepass").decode("utf-16")
def trim_reply_fallback_text(text: str) -> str:
if not text.startswith("> ") or "\n" not in text:
return text
lines = text.split("\n")
while len(lines) > 0 and lines[0].startswith("> "):
lines.pop(0)
return "\n".join(lines)
html_reply_fallback_regex = re.compile("^<mx-reply>"
r"[\s\S]+?"
"</mx-reply>") # type: Pattern
def trim_reply_fallback_html(html: str) -> str:
return html_reply_fallback_regex.sub("", html)
def unicode_to_html(text: str, html: str, ctrl: str, tag: str) -> str: def unicode_to_html(text: str, html: str, ctrl: str, tag: str) -> str:
if ctrl not in text: if ctrl not in text:
return html return html
@@ -84,3 +50,40 @@ def unicode_to_html(text: str, html: str, ctrl: str, tag: str) -> str:
def html_to_unicode(text: str, ctrl: str) -> str: def html_to_unicode(text: str, ctrl: str) -> str:
return ctrl.join(text) + ctrl return ctrl.join(text) + ctrl
# add_surrogates and remove_surrogates are unicode surrogate utility functions from Telethon.
# Licensed under the MIT license.
# https://github.com/LonamiWebs/Telethon/blob/7cce7aa3e4c6c7019a55530391b1761d33e5a04e/telethon/helpers.py
def add_surrogates(text: Optional[str]) -> Optional[str]:
if text is None:
return None
return "".join("".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16-le")))
if (0x10000 <= ord(x) <= 0x10FFFF) else x for x in text)
def remove_surrogates(text: Optional[str]) -> Optional[str]:
if text is None:
return None
return text.encode("utf-16", "surrogatepass").decode("utf-16")
# trim_reply_fallback_text, html_reply_fallback_regex and trim_reply_fallback_html are Matrix
# reply fallback utility functions.
# You may copy and use them under any OSI-approved license.
def trim_reply_fallback_text(text: str) -> str:
if not text.startswith("> ") or "\n" not in text:
return text
lines = text.split("\n")
while len(lines) > 0 and lines[0].startswith("> "):
lines.pop(0)
return "\n".join(lines)
html_reply_fallback_regex = re.compile("^<mx-reply>"
r"[\s\S]+?"
"</mx-reply>") # type: Pattern
def trim_reply_fallback_html(html: str) -> str:
return html_reply_fallback_regex.sub("", html)
+61 -42
View File
@@ -14,27 +14,31 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Tuple, Set, Match from typing import Dict, List, Match, Optional, Set, Tuple, TYPE_CHECKING
import logging import logging
import asyncio import asyncio
import re import re
from mautrix_appservice import MatrixRequestError, IntentError from mautrix_appservice import MatrixRequestError, IntentError
from .types import MatrixEvent, MatrixEventID, MatrixRoomID, MatrixUserID
from . import user as u, portal as po, puppet as pu, commands as com from . import user as u, portal as po, puppet as pu, commands as com
if TYPE_CHECKING:
from .context import Context
class MatrixHandler: class MatrixHandler:
log = logging.getLogger("mau.mx") # type: logging.Logger log = logging.getLogger("mau.mx") # type: logging.Logger
def __init__(self, context): def __init__(self, context: 'Context') -> None:
self.az, self.db, self.config, _, self.tgbot = context self.az, self.db, self.config, _, self.tgbot = context.core
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
self.previously_typing = [] # type: List[str] self.previously_typing = [] # type: List[MatrixUserID]
self.az.matrix_event_handler(self.handle_event) self.az.matrix_event_handler(self.handle_event)
async def init_as_bot(self): async def init_as_bot(self) -> None:
displayname = self.config["appservice.bot_displayname"] displayname = self.config["appservice.bot_displayname"]
if displayname: if displayname:
try: try:
@@ -50,7 +54,8 @@ class MatrixHandler:
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.log.exception("TimeoutError when trying to set avatar") self.log.exception("TimeoutError when trying to set avatar")
async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User): async def handle_puppet_invite(self, room_id: MatrixRoomID, puppet: pu.Puppet, inviter: u.User
) -> None:
intent = puppet.default_mxid_intent intent = puppet.default_mxid_intent
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}") self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in(): if not await inviter.is_logged_in():
@@ -80,6 +85,7 @@ class MatrixHandler:
await intent.join_room(room_id) await intent.join_room(room_id)
portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user") portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
# TODO: if portal is None:
if portal.mxid: if portal.mxid:
try: try:
await intent.invite(portal.mxid, inviter.mxid) await intent.invite(portal.mxid, inviter.mxid)
@@ -95,13 +101,13 @@ class MatrixHandler:
portal.mxid = room_id portal.mxid = room_id
portal.save() portal.save()
inviter.register_portal(portal) inviter.register_portal(portal)
await intent.send_notice(room_id, "po.Portal to private chat created.") await intent.send_notice(room_id, "Portal to private chat created.")
else: else:
await intent.join_room(room_id) await intent.join_room(room_id)
await intent.send_notice(room_id, "This puppet will remain inactive until a " await intent.send_notice(room_id, "This puppet will remain inactive until a "
"Telegram chat is created for this room.") "Telegram chat is created for this room.")
async def accept_bot_invite(self, room_id: str, inviter: u.User): async def accept_bot_invite(self, room_id: MatrixRoomID, inviter: u.User) -> None:
tries = 0 tries = 0
while tries < 5: while tries < 5:
try: try:
@@ -120,15 +126,19 @@ class MatrixHandler:
if not inviter.whitelisted: if not inviter.whitelisted:
await self.az.intent.send_notice( await self.az.intent.send_notice(
room_id, text=None, room_id, text="",
html="You are not whitelisted to use this bridge.<br/><br/>" html="You are not whitelisted to use this bridge.<br/><br/>"
"If you are the owner of this bridge, see the " "If you are the owner of this bridge, see the "
"<code>bridge.permissions</code> section in your config file.") "<code>bridge.permissions</code> section in your config file.")
await self.az.intent.leave_room(room_id) await self.az.intent.leave_room(room_id)
async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str): async def handle_invite(self, room_id: MatrixRoomID, user_id: MatrixUserID,
inviter_mxid: MatrixUserID) -> None:
self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}") self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started() inviter = u.User.get_by_mxid(inviter_mxid)
if inviter is None:
self.log.exception("Failed to find user with Matrix ID {inviter_mxid}")
await inviter.ensure_started()
if user_id == self.az.bot_mxid: if user_id == self.az.bot_mxid:
return await self.accept_bot_invite(room_id, inviter) return await self.accept_bot_invite(room_id, inviter)
elif not inviter.whitelisted: elif not inviter.whitelisted:
@@ -150,7 +160,8 @@ class MatrixHandler:
# The rest can probably be ignored # The rest can probably be ignored
async def handle_join(self, room_id: str, user_id: str, event_id: str): async def handle_join(self, room_id: MatrixRoomID, user_id: MatrixUserID,
event_id: MatrixEventID) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
@@ -171,7 +182,8 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot: if await user.is_logged_in() or portal.has_bot:
await portal.join_matrix(user, event_id) await portal.join_matrix(user, event_id)
async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str): async def handle_part(self, room_id: MatrixRoomID, user_id: MatrixUserID,
sender_mxid: MatrixUserID, event_id: MatrixEventID) -> None:
self.log.debug(f"{user_id} left {room_id}") self.log.debug(f"{user_id} left {room_id}")
sender = u.User.get_by_mxid(sender_mxid, create=False) sender = u.User.get_by_mxid(sender_mxid, create=False)
@@ -184,8 +196,10 @@ class MatrixHandler:
return return
puppet = pu.Puppet.get_by_mxid(user_id) puppet = pu.Puppet.get_by_mxid(user_id)
if sender and puppet: if puppet:
await portal.leave_matrix(puppet, sender, event_id) if sender:
await portal.kick_matrix(puppet, sender)
return
user = u.User.get_by_mxid(user_id, create=False) user = u.User.get_by_mxid(user_id, create=False)
if not user: if not user:
@@ -194,7 +208,7 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot: if await user.is_logged_in() or portal.has_bot:
await portal.leave_matrix(user, sender, event_id) await portal.leave_matrix(user, sender, event_id)
def is_command(self, message: dict) -> Tuple[bool, str]: def is_command(self, message: Dict) -> Tuple[bool, str]:
text = message.get("body", "") text = message.get("body", "")
prefix = self.config["bridge.command_prefix"] prefix = self.config["bridge.command_prefix"]
is_command = text.startswith(prefix) is_command = text.startswith(prefix)
@@ -202,9 +216,10 @@ class MatrixHandler:
text = text[len(prefix) + 1:] text = text[len(prefix) + 1:]
return is_command, text return is_command, text
async def handle_message(self, room, sender, message, event_id): async def handle_message(self, room: MatrixRoomID, sender_id: MatrixUserID, message: Dict,
event_id: MatrixEventID) -> None:
is_command, text = self.is_command(message) is_command, text = self.is_command(message)
sender = await u.User.get_by_mxid(sender).ensure_started() sender = await u.User.get_by_mxid(sender_id).ensure_started()
if not sender.relaybot_whitelisted: if not sender.relaybot_whitelisted:
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:" self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
" u.User is not whitelisted.") " u.User is not whitelisted.")
@@ -237,7 +252,8 @@ class MatrixHandler:
is_portal=portal is not None) is_portal=portal is not None)
@staticmethod @staticmethod
async def handle_redaction(room_id: str, sender_mxid: str, event_id: str): async def handle_redaction(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
event_id: MatrixEventID) -> None:
sender = await u.User.get_by_mxid(sender_mxid).ensure_started() sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if not sender.relaybot_whitelisted: if not sender.relaybot_whitelisted:
return return
@@ -249,14 +265,16 @@ class MatrixHandler:
await portal.handle_matrix_deletion(sender, event_id) await portal.handle_matrix_deletion(sender, event_id)
@staticmethod @staticmethod
async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict): async def handle_power_levels(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
new: Dict, old: Dict) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started() sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal: if await sender.has_full_access(allow_bot=True) and portal:
await portal.handle_matrix_power_levels(sender, new["users"], old["users"]) await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
@staticmethod @staticmethod
async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, content: dict): async def handle_room_meta(evt_type: str, room_id: MatrixRoomID, sender_mxid: MatrixUserID,
content: dict) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started() sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal: if await sender.has_full_access(allow_bot=True) and portal:
@@ -270,8 +288,8 @@ class MatrixHandler:
await handler(sender, content[content_key]) await handler(sender, content[content_key])
@staticmethod @staticmethod
async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str], async def handle_room_pin(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
old_events: Set[str]): new_events: Set[str], old_events: Set[str]) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started() sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal: if await sender.has_full_access(allow_bot=True) and portal:
@@ -284,8 +302,8 @@ class MatrixHandler:
await portal.handle_matrix_pin(sender, None) await portal.handle_matrix_pin(sender, None)
@staticmethod @staticmethod
async def handle_name_change(room_id: str, user_id: str, displayname: str, async def handle_name_change(room_id: MatrixRoomID, user_id: MatrixUserID, displayname: str,
prev_displayname: str, event_id: str): prev_displayname: str, event_id: MatrixEventID) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal or not portal.has_bot: if not portal or not portal.has_bot:
return return
@@ -295,13 +313,14 @@ class MatrixHandler:
await portal.name_change_matrix(user, displayname, prev_displayname, event_id) await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
@staticmethod @staticmethod
def parse_read_receipts(content: dict) -> Dict[str, str]: def parse_read_receipts(content: Dict) -> Dict[MatrixUserID, MatrixEventID]:
return {user_id: event_id return {user_id: event_id
for event_id, receipts in content.items() for event_id, receipts in content.items()
for user_id in receipts.get("m.read", {})} for user_id in receipts.get("m.read", {})}
@staticmethod @staticmethod
async def handle_read_receipts(room_id: str, receipts: Dict[str, str]): async def handle_read_receipts(room_id: MatrixRoomID,
receipts: Dict[MatrixUserID, MatrixEventID]) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
@@ -313,13 +332,13 @@ class MatrixHandler:
await portal.mark_read(user, event_id) await portal.mark_read(user, event_id)
@staticmethod @staticmethod
async def handle_presence(user_id: str, presence: str): async def handle_presence(user_id: MatrixUserID, presence: str) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in(): if not await user.is_logged_in():
return return
await user.set_presence(presence == "online") await user.set_presence(presence == "online")
async def handle_typing(self, room_id: str, now_typing: List[str]): async def handle_typing(self, room_id: MatrixRoomID, now_typing: List[MatrixUserID]) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
@@ -338,38 +357,38 @@ class MatrixHandler:
self.previously_typing = now_typing self.previously_typing = now_typing
def filter_matrix_event(self, event: dict): def filter_matrix_event(self, event: MatrixEvent) -> bool:
sender = event.get("sender", None) sender = event.get("sender", None)
if not sender: if not sender:
return False return False
return (sender == self.az.bot_mxid return (sender == self.az.bot_mxid
or pu.Puppet.get_id_from_mxid(sender) is not None) or pu.Puppet.get_id_from_mxid(sender) is not None)
async def try_handle_event(self, evt: dict): async def try_handle_event(self, evt: MatrixEvent) -> None:
try: try:
await self.handle_event(evt) await self.handle_event(evt)
except Exception: except Exception:
self.log.exception("Error handling manually received Matrix event") self.log.exception("Error handling manually received Matrix event")
async def handle_event(self, evt: dict): async def handle_event(self, evt: MatrixEvent) -> None:
if self.filter_matrix_event(evt): if self.filter_matrix_event(evt):
return return
self.log.debug("Received event: %s", evt) self.log.debug("Received event: %s", evt)
evt_type = evt.get("type", "m.unknown") # type: str evt_type = evt.get("type", "m.unknown") # type: str
room_id = evt.get("room_id", None) # type: str room_id = evt.get("room_id", None) # type: Optional[MatrixRoomID]
event_id = evt.get("event_id", None) # type: str event_id = evt.get("event_id", None) # type: Optional[MatrixEventID]
sender = evt.get("sender", None) # type: str sender = evt.get("sender", None) # type: Optional[MatrixUserID]
content = evt.get("content", {}) # type: dict content = evt.get("content", {}) # type: Dict
if evt_type == "m.room.member": if evt_type == "m.room.member":
state_key = evt["state_key"] # type: str state_key = evt["state_key"] # type: MatrixUserID
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: Dict
membership = content.get("membership", "") # type: str membership = content.get("membership", "") # type: str
prev_membership = prev_content.get("membership", "leave") # type: str prev_membership = prev_content.get("membership", "leave") # type: str
if membership == prev_membership: if membership == prev_membership:
match = re.compile("@(.+):(.+)").match(state_key) # type: Match match = re.compile("@(.+):(.+)").match(state_key) # type: Match
localpart = match.group(1) # type: str mxid = match.group(0) # type: str
displayname = content.get("displayname", localpart) # type: str displayname = content.get("displayname", None) or mxid # type: str
prev_displayname = prev_content.get("displayname", localpart) # type: str prev_displayname = prev_content.get("displayname", None) or mxid # type: str
if displayname != prev_displayname: if displayname != prev_displayname:
await self.handle_name_change(room_id, state_key, displayname, await self.handle_name_change(room_id, state_key, displayname,
prev_displayname, event_id) prev_displayname, event_id)
@@ -386,7 +405,7 @@ class MatrixHandler:
elif evt_type == "m.room.redaction": elif evt_type == "m.room.redaction":
await self.handle_redaction(room_id, sender, evt["redacts"]) await self.handle_redaction(room_id, sender, evt["redacts"])
elif evt_type == "m.room.power_levels": elif evt_type == "m.room.power_levels":
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict prev_content = evt.get("unsigned", {}).get("prev_content", {})
await self.handle_power_levels(room_id, sender, evt["content"], prev_content) await self.handle_power_levels(room_id, sender, evt["content"], prev_content)
elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"): elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
await self.handle_room_meta(evt_type, room_id, sender, evt["content"]) await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
+316 -212
View File
File diff suppressed because it is too large Load Diff
+112 -86
View File
@@ -14,17 +14,20 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING from typing import Awaitable, Coroutine, Dict, List, Optional, Pattern, TYPE_CHECKING
from difflib import SequenceMatcher from difflib import SequenceMatcher
import re from enum import Enum
import logging from aiohttp import ServerDisconnectedError
import asyncio import asyncio
import logging
import re
from sqlalchemy import orm from sqlalchemy import orm
from telethon.tl.types import UserProfilePhoto from telethon.tl.types import UserProfilePhoto, User, FileLocation
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
from .types import MatrixUserID, TelegramID
from .db import Puppet as DBPuppet from .db import Puppet as DBPuppet
from . import util from . import util
@@ -32,6 +35,10 @@ if TYPE_CHECKING:
from .matrix import MatrixHandler from .matrix import MatrixHandler
from .config import Config from .config import Config
from .context import Context from .context import Context
from . import user as u
from .abstract_user import AbstractUser
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
config = None # type: Config config = None # type: Config
@@ -45,87 +52,101 @@ class Puppet:
mxid_regex = None # type: Pattern mxid_regex = None # type: Pattern
username_template = None # type: str username_template = None # type: str
hs_domain = None # type: str hs_domain = None # type: str
cache = {} # type: Dict[str, Puppet] cache = {} # type: Dict[TelegramID, Puppet]
by_custom_mxid = {} # type: Dict[str, Puppet] by_custom_mxid = {} # type: Dict[str, Puppet]
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None, def __init__(self,
displayname=None, displayname_source=None, photo_id=None, is_bot=None, id: TelegramID,
is_registered=False, db_instance=None): access_token: Optional[str] = None,
self.id = id custom_mxid: Optional[MatrixUserID] = None,
self.access_token = access_token username: Optional[str] = None,
self.custom_mxid = custom_mxid displayname: Optional[str] = None,
self.is_real_user = self.custom_mxid and self.access_token displayname_source: Optional[TelegramID] = None,
self.default_mxid = self.get_mxid_from_id(self.id) photo_id: Optional[str] = None,
self.mxid = self.custom_mxid or self.default_mxid is_bot: bool = False,
is_registered: bool = False,
db_instance: Optional[DBPuppet] = None) -> None:
self.id = id # type: TelegramID
self.access_token = access_token # type: Optional[str]
self.custom_mxid = custom_mxid # type: Optional[MatrixUserID]
self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserID
self.username = username self.username = username # type: Optional[str]
self.displayname = displayname self.displayname = displayname # type: Optional[str]
self.displayname_source = displayname_source self.displayname_source = displayname_source # type: Optional[TelegramID]
self.photo_id = photo_id self.photo_id = photo_id # type: Optional[str]
self.is_bot = is_bot self.is_bot = is_bot # type: bool
self.is_registered = is_registered self.is_registered = is_registered # type: bool
self._db_instance = db_instance self._db_instance = db_instance # type: Optional[DBPuppet]
self.default_mxid_intent = self.az.intent.user(self.default_mxid) self.default_mxid_intent = self.az.intent.user(self.default_mxid)
self.intent = None # type: IntentAPI self.intent = self._fresh_intent() # type: IntentAPI
self.refresh_intents()
self.cache[id] = self self.cache[id] = self
if self.custom_mxid: if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self self.by_custom_mxid[self.custom_mxid] = self
@property @property
def tgid(self): def mxid(self) -> MatrixUserID:
return self.custom_mxid or self.default_mxid
@property
def tgid(self) -> TelegramID:
return self.id return self.id
@property
def is_real_user(self) -> bool:
""" Is True when the puppet is a real Matrix user. """
return bool(self.custom_mxid and self.access_token)
@staticmethod @staticmethod
async def is_logged_in(): async def is_logged_in() -> bool:
""" Is True if the puppet is logged in. """
return True return True
# region Custom puppet management # region Custom puppet management
def refresh_intents(self): def _fresh_intent(self) -> IntentAPI:
self.is_real_user = self.custom_mxid and self.access_token return (self.az.intent.user(self.custom_mxid, self.access_token)
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token) if self.is_real_user else self.default_mxid_intent)
if self.is_real_user else self.default_mxid_intent)
async def switch_mxid(self, access_token, mxid): async def switch_mxid(self, access_token: Optional[str],
mxid: Optional[MatrixUserID]) -> PuppetError:
prev_mxid = self.custom_mxid prev_mxid = self.custom_mxid
self.custom_mxid = mxid self.custom_mxid = mxid
self.access_token = access_token self.access_token = access_token
self.refresh_intents() self.intent = self._fresh_intent()
err = await self.init_custom_mxid() err = await self.init_custom_mxid()
if err != 0: if err != PuppetError.Success:
return err return err
try: try:
del self.by_custom_mxid[prev_mxid] del self.by_custom_mxid[prev_mxid] # type: ignore
except KeyError: except KeyError:
pass pass
self.mxid = self.custom_mxid or self.default_mxid
if self.mxid != self.default_mxid: if self.mxid != self.default_mxid:
self.by_custom_mxid[self.mxid] = self self.by_custom_mxid[self.mxid] = self
await self.leave_rooms_with_default_user() await self.leave_rooms_with_default_user()
self.save() self.save()
return 0 return PuppetError.Success
async def init_custom_mxid(self): async def init_custom_mxid(self) -> PuppetError:
if not self.is_real_user: if not self.is_real_user:
return 0 return PuppetError.Success
mxid = await self.intent.whoami() mxid = await self.intent.whoami()
if not mxid or mxid != self.custom_mxid: if not mxid or mxid != self.custom_mxid:
self.custom_mxid = None self.custom_mxid = None
self.access_token = None self.access_token = None
self.refresh_intents() self.intent = self._fresh_intent()
if mxid != self.custom_mxid: if mxid != self.custom_mxid:
return 2 return PuppetError.OnlyLoginSelf
return 1 return PuppetError.InvalidAccessToken
if config["bridge.sync_with_custom_puppets"]: if config["bridge.sync_with_custom_puppets"]:
asyncio.ensure_future(self.sync(), loop=self.loop) asyncio.ensure_future(self.sync(), loop=self.loop)
return 0 return PuppetError.Success
async def leave_rooms_with_default_user(self): async def leave_rooms_with_default_user(self) -> None:
for room_id in await self.default_mxid_intent.get_joined_rooms(): for room_id in await self.default_mxid_intent.get_joined_rooms():
try: try:
await self.default_mxid_intent.leave_room(room_id) await self.default_mxid_intent.leave_room(room_id)
@@ -159,7 +180,7 @@ class Puppet:
}, },
}) })
def filter_events(self, events): def filter_events(self, events: List[Dict]) -> List:
new_events = [] new_events = []
for event in events: for event in events:
evt_type = event.get("type", None) evt_type = event.get("type", None)
@@ -186,28 +207,28 @@ class Puppet:
new_events.append(event) new_events.append(event)
return new_events return new_events
def handle_sync(self, presence, ephemeral): def handle_sync(self, presence: List, ephemeral: Dict) -> None:
presence = [self.mx.try_handle_event(event) for event in presence] presence_events = [self.mx.try_handle_event(event) for event in presence]
for room_id, events in ephemeral.items(): for room_id, events in ephemeral.items():
for event in events: for event in events:
event["room_id"] = room_id event["room_id"] = room_id
ephemeral = [self.mx.try_handle_event(event) ephemeral_events = [self.mx.try_handle_event(event)
for events in ephemeral.values() for events in ephemeral.values()
for event in self.filter_events(events)] for event in self.filter_events(events)]
events = ephemeral + presence events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
coro = asyncio.gather(*events, loop=self.loop) coro = asyncio.gather(*events, loop=self.loop)
asyncio.ensure_future(coro, loop=self.loop) asyncio.ensure_future(coro, loop=self.loop)
async def sync(self): async def sync(self) -> None:
try: try:
await self._sync() await self._sync()
except Exception: except Exception:
self.log.exception("Fatal error syncing") self.log.exception("Fatal error syncing")
async def _sync(self): async def _sync(self) -> None:
if not self.is_real_user: if not self.is_real_user:
self.log.warning("Called sync() for non-custom puppet.") self.log.warning("Called sync() for non-custom puppet.")
return return
@@ -220,16 +241,17 @@ class Puppet:
while access_token_at_start == self.access_token: while access_token_at_start == self.access_token:
try: try:
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch, sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch,
set_presence="offline") set_presence="offline") # type: Dict
errors = 0 errors = 0
if next_batch is not None: if next_batch is not None:
presence = sync_resp.get("presence", {}).get("events", []) presence = sync_resp.get("presence", {}).get("events", []) # type: List
ephemeral = {room: data.get("ephemeral", {}).get("events", []) ephemeral = {room: data.get("ephemeral", {}).get("events", [])
for room, data for room, data
in sync_resp.get("rooms", {}).get("join", {}).items()} in sync_resp.get("rooms", {}).get("join", {}).items()
} # type: Dict
self.handle_sync(presence, ephemeral) self.handle_sync(presence, ephemeral)
next_batch = sync_resp.get("next_batch", None) next_batch = sync_resp.get("next_batch", None)
except MatrixRequestError as e: except (MatrixRequestError, ServerDisconnectedError) as e:
wait = min(errors, 11) ** 2 wait = min(errors, 11) ** 2
self.log.warning(f"Syncer for {custom_mxid} errored: {e}. " self.log.warning(f"Syncer for {custom_mxid} errored: {e}. "
f"Waiting for {wait} seconds...") f"Waiting for {wait} seconds...")
@@ -241,25 +263,25 @@ class Puppet:
# region DB conversion # region DB conversion
@property @property
def db_instance(self): def db_instance(self) -> DBPuppet:
if not self._db_instance: if not self._db_instance:
self._db_instance = self.new_db_instance() self._db_instance = self.new_db_instance()
return self._db_instance return self._db_instance
def new_db_instance(self): def new_db_instance(self) -> DBPuppet:
return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid, return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid,
username=self.username, displayname=self.displayname, username=self.username, displayname=self.displayname,
displayname_source=self.displayname_source, photo_id=self.photo_id, displayname_source=self.displayname_source, photo_id=self.photo_id,
is_bot=self.is_bot, matrix_registered=self.is_registered) is_bot=self.is_bot, matrix_registered=self.is_registered)
@classmethod @classmethod
def from_db(cls, db_puppet): def from_db(cls, db_puppet: DBPuppet) -> 'Puppet':
return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid, return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid,
db_puppet.username, db_puppet.displayname, db_puppet.displayname_source, db_puppet.username, db_puppet.displayname, db_puppet.displayname_source,
db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered, db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered,
db_instance=db_puppet) db_instance=db_puppet)
def save(self): def save(self) -> None:
self.db_instance.access_token = self.access_token self.db_instance.access_token = self.access_token
self.db_instance.custom_mxid = self.custom_mxid self.db_instance.custom_mxid = self.custom_mxid
self.db_instance.username = self.username self.db_instance.username = self.username
@@ -272,16 +294,16 @@ class Puppet:
# endregion # endregion
# region Info updating # region Info updating
def similarity(self, query): def similarity(self, query: str) -> int:
username_similarity = (SequenceMatcher(None, self.username, query).ratio() username_similarity = (SequenceMatcher(None, self.username, query).ratio()
if self.username else 0) if self.username else 0)
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio() displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
if self.displayname else 0) if self.displayname else 0)
similarity = max(username_similarity, displayname_similarity) similarity = max(username_similarity, displayname_similarity)
return round(similarity * 1000) / 10 return int(round(similarity * 100))
@staticmethod @staticmethod
def get_displayname(info, enable_format=True): def get_displayname(info: User, enable_format: bool = True) -> str:
data = { data = {
"phone number": info.phone if hasattr(info, "phone") else None, "phone number": info.phone if hasattr(info, "phone") else None,
"username": info.username, "username": info.username,
@@ -308,7 +330,7 @@ class Puppet:
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format( return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
displayname=name) displayname=name)
async def update_info(self, source, info): async def update_info(self, source: 'AbstractUser', info: User) -> None:
changed = False changed = False
if self.username != info.username: if self.username != info.username:
self.username = info.username self.username = info.username
@@ -323,12 +345,12 @@ class Puppet:
if changed: if changed:
self.save() self.save()
async def update_displayname(self, source, info): async def update_displayname(self, source: 'AbstractUser', info: User) -> bool:
ignore_source = (not source.is_relaybot ignore_source = (not source.is_relaybot
and self.displayname_source is not None and self.displayname_source is not None
and self.displayname_source != source.tgid) and self.displayname_source != source.tgid)
if ignore_source: if ignore_source:
return return False
displayname = self.get_displayname(info) displayname = self.get_displayname(info)
if displayname != self.displayname: if displayname != self.displayname:
@@ -339,8 +361,9 @@ class Puppet:
elif source.is_relaybot or self.displayname_source is None: elif source.is_relaybot or self.displayname_source is None:
self.displayname_source = source.tgid self.displayname_source = source.tgid
return True return True
return False
async def update_avatar(self, source, photo): async def update_avatar(self, source: 'AbstractUser', photo: FileLocation) -> bool:
photo_id = f"{photo.volume_id}-{photo.local_id}" photo_id = f"{photo.volume_id}-{photo.local_id}"
if self.photo_id != photo_id: if self.photo_id != photo_id:
file = await util.transfer_file_to_matrix(self.db, source.client, file = await util.transfer_file_to_matrix(self.db, source.client,
@@ -355,7 +378,7 @@ class Puppet:
# region Getters # region Getters
@classmethod @classmethod
def get(cls, tgid, create=True) -> "Optional[Puppet]": def get(cls, tgid: TelegramID, create: bool = True) -> Optional['Puppet']:
try: try:
return cls.cache[tgid] return cls.cache[tgid]
except KeyError: except KeyError:
@@ -374,12 +397,15 @@ class Puppet:
return None return None
@classmethod @classmethod
def get_by_mxid(cls, mxid, create=True) -> "Optional[Puppet]": def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['Puppet']:
tgid = cls.get_id_from_mxid(mxid) tgid = cls.get_id_from_mxid(mxid)
return cls.get(tgid, create) if tgid else None if tgid:
return cls.get(tgid, create)
return None
@classmethod @classmethod
def get_by_custom_mxid(cls, mxid): def get_by_custom_mxid(cls, mxid: MatrixUserID) -> Optional['Puppet']:
if not mxid: if not mxid:
raise ValueError("Matrix ID can't be empty") raise ValueError("Matrix ID can't be empty")
@@ -396,25 +422,25 @@ class Puppet:
return None return None
@classmethod @classmethod
def get_all_with_custom_mxid(cls): def get_all_with_custom_mxid(cls) -> List['Puppet']:
return [cls.by_custom_mxid[puppet.mxid] return [cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet) else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()] for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
@classmethod @classmethod
def get_id_from_mxid(cls, mxid): def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]:
match = cls.mxid_regex.match(mxid) match = cls.mxid_regex.match(mxid)
if match: if match:
return int(match.group(1)) return TelegramID(int(match.group(1)))
return None return None
@classmethod @classmethod
def get_mxid_from_id(cls, tgid): def get_mxid_from_id(cls, tgid: TelegramID) -> MatrixUserID:
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}" return MatrixUserID(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
@classmethod @classmethod
def find_by_username(cls, username) -> "Optional[Puppet]": def find_by_username(cls, username: str) -> Optional['Puppet']:
if not username: if not username:
return None return None
@@ -422,14 +448,14 @@ class Puppet:
if puppet.username and puppet.username.lower() == username.lower(): if puppet.username and puppet.username.lower() == username.lower():
return puppet return puppet
puppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none() dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
if puppet: if dbpuppet:
return cls.from_db(puppet) return cls.from_db(dbpuppet)
return None return None
@classmethod @classmethod
def find_by_displayname(cls, displayname) -> "Optional[Puppet]": def find_by_displayname(cls, displayname: str) -> Optional['Puppet']:
if not displayname: if not displayname:
return None return None
@@ -437,20 +463,20 @@ class Puppet:
if puppet.displayname and puppet.displayname == displayname: if puppet.displayname and puppet.displayname == displayname:
return puppet return puppet
puppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none() dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
if puppet: if dbpuppet:
return cls.from_db(puppet) return cls.from_db(dbpuppet)
return None return None
# endregion # endregion
def init(context: "Context") -> List[Awaitable[int]]: def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
global config global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}") Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"] Puppet.hs_domain = config["homeserver"]["domain"]
Puppet.mxid_regex = re.compile( Puppet.mxid_regex = re.compile(
f"@{Puppet.username_template.format(userid='(.+)')}:{Puppet.hs_domain}") f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}")
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()] return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
@@ -32,15 +32,15 @@ TelematrixBase.metadata.bind = telematrix_db_engine
chat_links = telematrix.query(ChatLink).all() chat_links = telematrix.query(ChatLink).all()
tg_users = telematrix.query(TgUser).all() tg_users = telematrix.query(TgUser).all()
mx_users = telematrix.query(MatrixUser).all() mx_users = telematrix.query(MatrixUser).all()
messages = telematrix.query(TMMessage).all() tm_messages = telematrix.query(TMMessage).all()
telematrix.close() telematrix.close()
telematrix_db_engine.dispose() telematrix_db_engine.dispose()
portals = {} portals = {} # Dict[int, Portal]
chats = {} chats = {} # Dict[int, BotChat]
messages = {} messages = {} # Dict[str, Message]
puppets = {} puppets = {} # Dict[int, Puppet]
for chat_link in chat_links: for chat_link in chat_links:
if type(chat_link.tg_room) is str: if type(chat_link.tg_room) is str:
@@ -65,11 +65,12 @@ for chat_link in chat_links:
portals[chat_link.tg_room] = portal portals[chat_link.tg_room] = portal
chats[tgid] = bot_chat chats[tgid] = bot_chat
for tm_msg in messages: for tm_msg in tm_messages:
try: try:
portal = portals[tm_msg.tg_group_id] portal = portals[tm_msg.tg_group_id]
except KeyError: except KeyError:
print("Found message entry %d in unlinked chat %d, ignoring..." % (tm_msg.tg_message_id, tm_msg.tg_group_id)) print("Found message entry %d in unlinked chat %d, ignoring..." % (tm_msg.tg_message_id,
tm_msg.tg_group_id))
continue continue
tg_space = portal.tgid if portal.peer_type == "channel" else args.bot_id tg_space = portal.tgid if portal.peer_type == "channel" else args.bot_id
message = Message(mxid=tm_msg.matrix_event_id, mx_room=tm_msg.matrix_room_id, message = Message(mxid=tm_msg.matrix_event_id, mx_room=tm_msg.matrix_room_id,
@@ -77,7 +78,8 @@ for tm_msg in messages:
messages[tm_msg.matrix_event_id] = message messages[tm_msg.matrix_event_id] = message
for user in tg_users: for user in tg_users:
puppets[user.tg_id] = Puppet(id=user.tg_id, displayname=user.name, displayname_source=args.bot_id) puppets[user.tg_id] = Puppet(id=user.tg_id, displayname=user.name,
displayname_source=args.bot_id)
for k, v in portals.items(): for k, v in portals.items():
mxtg.add(v) mxtg.add(v)
@@ -5,7 +5,7 @@ Base = declarative_base()
class ChatLink(Base): class ChatLink(Base):
__tablename__ = 'chat_link' __tablename__ = "chat_link"
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
matrix_room = sa.Column(sa.String) matrix_room = sa.Column(sa.String)
@@ -14,7 +14,7 @@ class ChatLink(Base):
class TgUser(Base): class TgUser(Base):
__tablename__ = 'tg_user' __tablename__ = "tg_user"
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
tg_id = sa.Column(sa.BigInteger) tg_id = sa.Column(sa.BigInteger)
@@ -23,7 +23,7 @@ class TgUser(Base):
class MatrixUser(Base): class MatrixUser(Base):
__tablename__ = 'matrix_user' __tablename__ = "matrix_user"
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
matrix_id = sa.Column(sa.String) matrix_id = sa.Column(sa.String)
+15 -13
View File
@@ -20,37 +20,39 @@ from sqlalchemy import orm
from mautrix_appservice import StateStore from mautrix_appservice import StateStore
from .types import MatrixUserID, MatrixRoomID
from . import puppet as pu from . import puppet as pu
from .db import RoomState, UserProfile from .db import RoomState, UserProfile
class SQLStateStore(StateStore): class SQLStateStore(StateStore):
def __init__(self, db): def __init__(self, db: orm.Session) -> None:
super().__init__() super().__init__()
self.db = db # type: orm.Session self.db = db # type: orm.Session
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile] self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
self.room_state_cache = {} # type: Dict[str, RoomState] self.room_state_cache = {} # type: Dict[str, RoomState]
@staticmethod @staticmethod
def is_registered(user: str) -> bool: def is_registered(user: MatrixUserID) -> bool:
puppet = pu.Puppet.get_by_mxid(user) puppet = pu.Puppet.get_by_mxid(user)
return puppet.is_registered if puppet else False return puppet.is_registered if puppet else False
@staticmethod @staticmethod
def registered(user: str): def registered(user: MatrixUserID) -> None:
puppet = pu.Puppet.get_by_mxid(user) puppet = pu.Puppet.get_by_mxid(user)
if puppet: if puppet:
puppet.is_registered = True puppet.is_registered = True
puppet.save() puppet.save()
def update_state(self, event: dict): def update_state(self, event: Dict) -> None:
event_type = event["type"] event_type = event["type"]
if event_type == "m.room.power_levels": if event_type == "m.room.power_levels":
self.set_power_levels(event["room_id"], event["content"]) self.set_power_levels(event["room_id"], event["content"])
elif event_type == "m.room.member": elif event_type == "m.room.member":
self.set_member(event["room_id"], event["state_key"], event["content"]) self.set_member(event["room_id"], event["state_key"], event["content"])
def _get_user_profile(self, room_id: str, user_id: str, create: bool = True) -> UserProfile: def _get_user_profile(self, room_id: MatrixRoomID, user_id: MatrixUserID, create: bool = True
) -> UserProfile:
key = (room_id, user_id) key = (room_id, user_id)
try: try:
return self.profile_cache[key] return self.profile_cache[key]
@@ -67,22 +69,22 @@ class SQLStateStore(StateStore):
self.profile_cache[key] = profile self.profile_cache[key] = profile
return profile return profile
def get_member(self, room: str, user: str) -> dict: def get_member(self, room: MatrixRoomID, user: MatrixUserID) -> Dict:
return self._get_user_profile(room, user).dict() return self._get_user_profile(room, user).dict()
def set_member(self, room: str, user: str, member: dict): def set_member(self, room: MatrixRoomID, user: MatrixUserID, member: Dict) -> None:
profile = self._get_user_profile(room, user) profile = self._get_user_profile(room, user)
profile.membership = member.get("membership", profile.membership or "leave") profile.membership = member.get("membership", profile.membership or "leave")
profile.displayname = member.get("displayname", profile.displayname) profile.displayname = member.get("displayname", profile.displayname)
profile.avatar_url = member.get("avatar_url", profile.avatar_url) profile.avatar_url = member.get("avatar_url", profile.avatar_url)
self.db.commit() self.db.commit()
def set_membership(self, room: str, user: str, membership: str): def set_membership(self, room: MatrixRoomID, user: MatrixUserID, membership: str) -> None:
self.set_member(room, user, { self.set_member(room, user, {
"membership": membership, "membership": membership,
}) })
def _get_room_state(self, room_id: str, create: bool = True) -> RoomState: def _get_room_state(self, room_id: MatrixRoomID, create: bool = True) -> RoomState:
try: try:
return self.room_state_cache[room_id] return self.room_state_cache[room_id]
except KeyError: except KeyError:
@@ -96,13 +98,13 @@ class SQLStateStore(StateStore):
self.room_state_cache[room_id] = room self.room_state_cache[room_id] = room
return room return room
def has_power_levels(self, room: str) -> bool: def has_power_levels(self, room: MatrixRoomID) -> bool:
return self._get_room_state(room).has_power_levels return self._get_room_state(room).has_power_levels
def get_power_levels(self, room: str) -> dict: def get_power_levels(self, room: MatrixRoomID) -> Dict:
return self._get_room_state(room).power_levels return self._get_room_state(room).power_levels
def set_power_level(self, room: str, user: str, level: int): def set_power_level(self, room: MatrixRoomID, user: MatrixUserID, level: int) -> None:
room_state = self._get_room_state(room) room_state = self._get_room_state(room)
power_levels = room_state.power_levels power_levels = room_state.power_levels
if not power_levels: if not power_levels:
@@ -114,7 +116,7 @@ class SQLStateStore(StateStore):
room_state.power_levels = power_levels room_state.power_levels = power_levels
self.db.commit() self.db.commit()
def set_power_levels(self, room: str, content: dict): def set_power_levels(self, room: MatrixRoomID, content: Dict) -> None:
state = self._get_room_state(room) state = self._get_room_state(room)
state.power_levels = content state.power_levels = content
self.db.commit() self.db.commit()
+5 -1
View File
@@ -14,9 +14,13 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Union, Optional
from telethon import TelegramClient, utils from telethon import TelegramClient, utils
from telethon.tl.functions.messages import SendMediaRequest from telethon.tl.functions.messages import SendMediaRequest
from telethon.tl.types import * from telethon.tl.types import (
InputMediaUploadedDocument, InputMediaUploadedPhoto, TypeDocumentAttribute, TypeInputMedia,
TypeInputPeer, TypeMessageEntity, TypeMessageMedia, TypePeer)
from telethon.tl import custom from telethon.tl import custom
+9
View File
@@ -0,0 +1,9 @@
from typing import Dict, NewType
MatrixUserID = NewType('MatrixUserID', str)
MatrixRoomID = NewType('MatrixRoomID', str)
MatrixEventID = NewType('MatrixEventID', str)
MatrixEvent = NewType('MatrixEvent', Dict)
TelegramID = NewType('TelegramID', int)
+76 -56
View File
@@ -14,18 +14,21 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING from typing import Coroutine, Dict, List, Match, NewType, Optional, Tuple, cast, TYPE_CHECKING
import logging import logging
import asyncio import asyncio
import re import re
from telethon.tl.types import * from telethon.tl.types import (
TypeUpdate, UpdateNewMessage, UpdateNewChannelMessage, PeerUser,
UpdateShortChatMessage, UpdateShortMessage)
from telethon.tl.types import User as TLUser from telethon.tl.types import User as TLUser
from telethon.tl.types.contacts import ContactsNotModified from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError from mautrix_appservice import MatrixRequestError
from .types import MatrixUserID, TelegramID
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
from .abstract_user import AbstractUser from .abstract_user import AbstractUser
from . import portal as po, puppet as pu from . import portal as po, puppet as pu
@@ -36,7 +39,7 @@ if TYPE_CHECKING:
config = None # type: Config config = None # type: Config
SearchResults = List[Tuple["pu.Puppet", int]] SearchResult = NewType('SearchResult', Tuple['pu.Puppet', int])
class User(AbstractUser): class User(AbstractUser):
@@ -44,23 +47,26 @@ class User(AbstractUser):
by_mxid = {} # type: Dict[str, User] by_mxid = {} # type: Dict[str, User]
by_tgid = {} # type: Dict[int, User] by_tgid = {} # type: Dict[int, User]
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None, def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0, username: Optional[str] = None, phone: Optional[str] = None,
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None, db_contacts: Optional[List[DBContact]] = None,
db_instance: Optional[DBUser] = None): saved_contacts: int = 0, is_bot: bool = False,
db_portals: Optional[List[DBPortal]] = None,
db_instance: Optional[DBUser] = None) -> None:
super().__init__() super().__init__()
self.mxid = mxid # type: str self.mxid = mxid # type: MatrixUserID
self.tgid = tgid # type: int self.tgid = tgid # type: TelegramID
self.is_bot = is_bot # type: bool self.is_bot = is_bot # type: bool
self.username = username # type: str self.username = username # type: str
self.phone = phone # type: str
self.contacts = [] # type: List[pu.Puppet] self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact] self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[Tuple[int, int], po.Portal] self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
self.db_portals = db_portals # type: List[DBPortal] self.db_portals = db_portals or [] # type: List[DBPortal]
self._db_instance = db_instance # type: DBUser self._db_instance = db_instance # type: Optional[DBUser]
self.command_status = None # type: dict self.command_status = None # type: Dict
(self.relaybot_whitelisted, (self.relaybot_whitelisted,
self.whitelisted, self.whitelisted,
@@ -82,6 +88,10 @@ class User(AbstractUser):
match = re.compile("@(.+):(.+)").match(self.mxid) # type: Match match = re.compile("@(.+):(.+)").match(self.mxid) # type: Match
return match.group(1) return match.group(1)
@property
def human_tg_id(self) -> str:
return f"@{self.username}" if self.username else f"+{self.phone}" or None
# TODO replace with proper displayname getting everywhere # TODO replace with proper displayname getting everywhere
@property @property
def displayname(self) -> str: def displayname(self) -> str:
@@ -93,7 +103,7 @@ class User(AbstractUser):
for puppet in self.contacts] for puppet in self.contacts]
@db_contacts.setter @db_contacts.setter
def db_contacts(self, contacts: List[DBContact]): def db_contacts(self, contacts: List[DBContact]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else [] self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
@property @property
@@ -101,10 +111,12 @@ class User(AbstractUser):
return [portal.db_instance for portal in self.portals.values() if not portal.deleted] return [portal.db_instance for portal in self.portals.values() if not portal.deleted]
@db_portals.setter @db_portals.setter
def db_portals(self, portals: List[DBPortal]): def db_portals(self, portals: List[DBPortal]) -> None:
self.portals = {(portal.tgid, portal.tg_receiver): self.portals = {
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver) (portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
for portal in portals} if portals else {} portal.tg_receiver)
for portal in portals
} if portals else {}
# region Database conversion # region Database conversion
@@ -116,18 +128,19 @@ class User(AbstractUser):
def new_db_instance(self) -> DBUser: def new_db_instance(self) -> DBUser:
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username, return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0, contacts=self.db_contacts, saved_contacts=self.saved_contacts,
portals=self.db_portals) portals=self.db_portals)
def save(self): def save(self) -> None:
self.db_instance.tgid = self.tgid self.db_instance.tgid = self.tgid
self.db_instance.username = self.username self.db_instance.tg_username = self.username
self.db_instance.tg_phone = self.phone
self.db_instance.contacts = self.db_contacts self.db_instance.contacts = self.db_contacts
self.db_instance.saved_contacts = self.saved_contacts or 0 self.db_instance.saved_contacts = self.saved_contacts
self.db_instance.portals = self.db_portals self.db_instance.portals = self.db_portals
self.db.commit() self.db.commit()
def delete(self): def delete(self) -> None:
try: try:
del self.by_mxid[self.mxid] del self.by_mxid[self.mxid]
del self.by_tgid[self.tgid] del self.by_tgid[self.tgid]
@@ -138,14 +151,15 @@ class User(AbstractUser):
self.db.commit() self.db.commit()
@classmethod @classmethod
def from_db(cls, db_user: DBUser) -> "User": def from_db(cls, db_user: DBUser) -> 'User':
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts, return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.tg_phone,
False, db_user.saved_contacts, db_user.portals, db_instance=db_user) db_user.contacts, db_user.saved_contacts, False, db_user.portals,
db_instance=db_user)
# endregion # endregion
# region Telegram connection management # region Telegram connection management
async def start(self, delete_unless_authenticated: bool = False) -> "User": async def start(self, delete_unless_authenticated: bool = False) -> 'User':
await super().start() await super().start()
if await self.is_logged_in(): if await self.is_logged_in():
self.log.debug(f"Ensuring post_login() for {self.name}") self.log.debug(f"Ensuring post_login() for {self.name}")
@@ -156,7 +170,7 @@ class User(AbstractUser):
self.client.session.delete() self.client.session.delete()
return self return self
async def post_login(self, info: TLUser = None): async def post_login(self, info: TLUser = None) -> None:
try: try:
await self.update_info(info) await self.update_info(info)
if not self.is_bot: if not self.is_bot:
@@ -167,9 +181,9 @@ class User(AbstractUser):
except Exception: except Exception:
self.log.exception("Failed to run post-login functions for %s", self.mxid) self.log.exception("Failed to run post-login functions for %s", self.mxid)
async def update(self, update: TypeUpdate): async def update(self, update: TypeUpdate) -> bool:
if not self.is_bot: if not self.is_bot:
return return False
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)): if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
message = update.message message = update.message
@@ -179,26 +193,28 @@ class User(AbstractUser):
else: else:
portal = po.Portal.get_by_entity(message.to_id, receiver_id=self.tgid) portal = po.Portal.get_by_entity(message.to_id, receiver_id=self.tgid)
elif isinstance(update, UpdateShortChatMessage): elif isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat") portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
elif isinstance(update, UpdateShortMessage): elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user") portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
else: else:
return return False
self.register_portal(portal) if portal:
self.register_portal(portal)
return True
# endregion # endregion
# region Telegram actions that need custom methods # region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]": def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']:
return super().ensure_started(even_if_no_session) return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session))
def set_presence(self, online: bool = True): async def set_presence(self, online: bool = True) -> None:
if self.is_bot: if not self.is_bot:
return await self.client(UpdateStatusRequest(offline=not online))
return self.client(UpdateStatusRequest(offline=not online))
async def update_info(self, info: TLUser = None): async def update_info(self, info: TLUser = None) -> None:
info = info or await self.client.get_me() info = info or await self.client.get_me()
changed = False changed = False
if self.is_bot != info.bot: if self.is_bot != info.bot:
@@ -207,13 +223,16 @@ class User(AbstractUser):
if self.username != info.username: if self.username != info.username:
self.username = info.username self.username = info.username
changed = True changed = True
if self.phone != info.phone:
self.phone = info.phone
changed = True
if self.tgid != info.id: if self.tgid != info.id:
self.tgid = info.id self.tgid = info.id
self.by_tgid[self.tgid] = self self.by_tgid[self.tgid] = self
if changed: if changed:
self.save() self.save()
async def log_out(self): async def log_out(self) -> bool:
puppet = pu.Puppet.get(self.tgid) puppet = pu.Puppet.get(self.tgid)
if puppet.is_real_user: if puppet.is_real_user:
await puppet.switch_mxid(None, None) await puppet.switch_mxid(None, None)
@@ -241,28 +260,29 @@ class User(AbstractUser):
return True return True
def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45 def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
) -> SearchResults: ) -> List[SearchResult]:
results = [] # type: SearchResults results = [] # type: List[SearchResult]
for contact in self.contacts: for contact in self.contacts:
similarity = contact.similarity(query) similarity = contact.similarity(query)
if similarity >= min_similarity: if similarity >= min_similarity:
results.append((contact, similarity)) results.append(SearchResult((contact, similarity)))
results.sort(key=lambda tup: tup[1], reverse=True) results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results] return results[0:max_results]
async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults: async def _search_remote(self, query: str, max_results: int = 5) -> List[SearchResult]:
if len(query) < 5: if len(query) < 5:
return [] return []
server_results = await self.client(SearchRequest(q=query, limit=max_results)) server_results = await self.client(SearchRequest(q=query, limit=max_results))
results = [] # type: SearchResults results = [] # type: List[SearchResult]
for user in server_results.users: for user in server_results.users:
puppet = pu.Puppet.get(user.id) puppet = pu.Puppet.get(user.id)
await puppet.update_info(self, user) await puppet.update_info(self, user)
results.append((puppet, puppet.similarity(query))) results.append(SearchResult((puppet, puppet.similarity(query))))
results.sort(key=lambda tup: tup[1], reverse=True) results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results] return results[0:max_results]
async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]: async def search(self, query: str, force_remote: bool = False
) -> Tuple[List[SearchResult], bool]:
if force_remote: if force_remote:
return await self._search_remote(query), True return await self._search_remote(query), True
@@ -272,7 +292,7 @@ class User(AbstractUser):
return await self._search_remote(query), True return await self._search_remote(query), True
async def sync_dialogs(self, synchronous_create: bool = False): async def sync_dialogs(self, synchronous_create: bool = False) -> None:
creators = [] creators = []
for entity in await self.get_dialogs(limit=30): for entity in await self.get_dialogs(limit=30):
portal = po.Portal.get_by_entity(entity) portal = po.Portal.get_by_entity(entity)
@@ -283,7 +303,7 @@ class User(AbstractUser):
self.save() self.save()
await asyncio.gather(*creators, loop=self.loop) await asyncio.gather(*creators, loop=self.loop)
def register_portal(self, portal: po.Portal): def register_portal(self, portal: po.Portal) -> None:
try: try:
if self.portals[portal.tgid_full] == portal: if self.portals[portal.tgid_full] == portal:
return return
@@ -292,7 +312,7 @@ class User(AbstractUser):
self.portals[portal.tgid_full] = portal self.portals[portal.tgid_full] = portal
self.save() self.save()
def unregister_portal(self, portal: po.Portal): def unregister_portal(self, portal: po.Portal) -> None:
try: try:
del self.portals[portal.tgid_full] del self.portals[portal.tgid_full]
self.save() self.save()
@@ -309,7 +329,7 @@ class User(AbstractUser):
acc = (acc * 20261 + id) & 0xffffffff acc = (acc * 20261 + id) & 0xffffffff
return acc & 0x7fffffff return acc & 0x7fffffff
async def sync_contacts(self): async def sync_contacts(self) -> None:
response = await self.client(GetContactsRequest(hash=self._hash_contacts())) response = await self.client(GetContactsRequest(hash=self._hash_contacts()))
if isinstance(response, ContactsNotModified): if isinstance(response, ContactsNotModified):
return return
@@ -326,7 +346,7 @@ class User(AbstractUser):
# region Class instance lookup # region Class instance lookup
@classmethod @classmethod
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]": def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['User']:
if not mxid: if not mxid:
raise ValueError("Matrix ID can't be empty") raise ValueError("Matrix ID can't be empty")
@@ -349,7 +369,7 @@ class User(AbstractUser):
return None return None
@classmethod @classmethod
def get_by_tgid(cls, tgid: int) -> "Optional[User]": def get_by_tgid(cls, tgid: int) -> Optional['User']:
try: try:
return cls.by_tgid[tgid] return cls.by_tgid[tgid]
except KeyError: except KeyError:
@@ -363,7 +383,7 @@ class User(AbstractUser):
return None return None
@classmethod @classmethod
def find_by_username(cls, username: str) -> "Optional[User]": def find_by_username(cls, username: str) -> Optional['User']:
if not username: if not username:
return None return None
@@ -379,7 +399,7 @@ class User(AbstractUser):
# endregion # endregion
def init(context: "Context") -> List[Awaitable[User]]: def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser]
global config global config
config = context.config config = context.config
+1
View File
@@ -1,3 +1,4 @@
from .file_transfer import transfer_file_to_matrix, convert_image from .file_transfer import transfer_file_to_matrix, convert_image
from .format_duration import format_duration from .format_duration import format_duration
from .signed_token import sign_token, verify_token from .signed_token import sign_token, verify_token
from .recursive_dict import recursive_del, recursive_set, recursive_get
+2 -1
View File
@@ -27,7 +27,8 @@ from sqlalchemy.orm.exc import FlushError
from telethon.tl.types import (Document, FileLocation, InputFileLocation, from telethon.tl.types import (Document, FileLocation, InputFileLocation,
InputDocumentFileLocation, PhotoSize, PhotoCachedSize) InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
from telethon.errors import * from telethon.errors import (AuthBytesInvalidError, AuthKeyInvalidError, LocationInvalidError,
SecurityError)
from mautrix_appservice import IntentAPI from mautrix_appservice import IntentAPI
from ..tgclient import MautrixTelegramClient from ..tgclient import MautrixTelegramClient
+2 -2
View File
@@ -17,10 +17,10 @@
def format_duration(seconds: int) -> str: def format_duration(seconds: int) -> str:
def pluralize(count, singular): def pluralize(count: int, singular: str) -> str:
return singular if count == 1 else singular + "s" return singular if count == 1 else singular + "s"
def include(count, word): def include(count: int, word: str) -> str:
return f"{count} {pluralize(count, word)}" if count > 0 else "" return f"{count} {pluralize(count, word)}" if count > 0 else ""
minutes, seconds = divmod(seconds, 60) minutes, seconds = divmod(seconds, 60)
+54
View File
@@ -0,0 +1,54 @@
# -*- coding: future_fstrings -*-
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Any
from ..config import DictWithRecursion
def recursive_set(data: Dict[str, Any], key: str, value: Any) -> bool:
key, next_key = DictWithRecursion._parse_key(key)
if next_key is not None:
if key not in data:
data[key] = {}
next_data = data.get(key, {})
if not isinstance(next_data, dict):
return False
return recursive_set(next_data, next_key, value)
data[key] = value
return True
def recursive_get(data: Dict[str, Any], key: str) -> Any:
key, next_key = DictWithRecursion._parse_key(key)
if next_key is not None:
next_data = data.get(key, None)
if not next_data:
return None
return recursive_get(next_data, next_key)
return data.get(key, None)
def recursive_del(data: Dict[str, any], key: str) -> bool:
key, next_key = DictWithRecursion._parse_key(key)
if next_key is not None:
if key not in data:
return False
next_data = data.get(key, {})
return recursive_del(next_data, next_key)
if key in data:
del data[key]
return True
return False
+6 -6
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from typing import Dict, Optional
import json import json
import base64 import base64
import hashlib import hashlib
@@ -28,13 +28,13 @@ def _get_checksum(key: str, payload: bytes) -> str:
return checksum return checksum
def sign_token(key: str, payload: dict) -> str: def sign_token(key: str, payload: Dict) -> str:
payload = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")) payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
checksum = _get_checksum(key, payload) checksum = _get_checksum(key, payload_b64)
return f"{checksum}:{payload.decode('utf-8')}" return f"{checksum}:{payload_b64.decode('utf-8')}"
def verify_token(key: str, data: str) -> Optional[dict]: def verify_token(key: str, data: str) -> Optional[Dict]:
if not data: if not data:
return None return None
+43 -27
View File
@@ -15,6 +15,9 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from abc import abstractmethod from abc import abstractmethod
from typing import Optional
from aiohttp import web
import abc import abc
import asyncio import asyncio
import logging import logging
@@ -23,27 +26,30 @@ from telethon.errors import *
from ...commands.auth import enter_password from ...commands.auth import enter_password
from ...util import format_duration from ...util import format_duration
from ...puppet import Puppet from ...puppet import Puppet, PuppetError
from ...user import User from ...user import User
class AuthAPI(abc.ABC): class AuthAPI(abc.ABC):
log = logging.getLogger("mau.web.auth") log = logging.getLogger("mau.web.auth") # type: logging.Logger
def __init__(self, loop): def __init__(self, loop: asyncio.AbstractEventLoop):
self.loop = loop # type: asyncio.AbstractEventLoop self.loop = loop # type: asyncio.AbstractEventLoop
@abstractmethod @abstractmethod
def get_login_response(self, status=200, state="", username="", mxid="", message="", error="", def get_login_response(self, status: int = 200, state: str = "", username: str = "",
errcode=""): phone: str = "", human_tg_id: str = "", mxid: str = "",
message: str = "", error: str = "", errcode: str = "") -> web.Response:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def get_mx_login_response(self, status=200, state="", username="", mxid="", message="", def get_mx_login_response(self, status: int = 200, state: str = "", username: str = "",
error="", errcode=""): phone: str = "", human_tg_id: str = "", mxid: str = "",
message: str = "", error: str = "", errcode: str = ""
) -> web.Response:
raise NotImplementedError() raise NotImplementedError()
async def post_matrix_token(self, user: User, token): async def post_matrix_token(self, user: User, token: str) -> web.Response:
puppet = Puppet.get(user.tgid) puppet = Puppet.get(user.tgid)
if puppet.is_real_user: if puppet.is_real_user:
return self.get_mx_login_response(state="already-logged-in", status=409, return self.get_mx_login_response(state="already-logged-in", status=409,
@@ -51,20 +57,21 @@ class AuthAPI(abc.ABC):
"account.", errcode="already-logged-in") "account.", errcode="already-logged-in")
resp = await puppet.switch_mxid(token, user.mxid) resp = await puppet.switch_mxid(token, user.mxid)
if resp == 2: if resp == PuppetError.OnlyLoginSelf:
return self.get_mx_login_response(status=403, errcode="only-login-self", return self.get_mx_login_response(status=403, errcode="only-login-self",
error="You can only log in as your own Matrix user.") error="You can only log in as your own Matrix user.")
elif resp == 1: elif resp == PuppetError.InvalidAccessToken:
return self.get_mx_login_response(status=401, errcode="invalid-access-token", return self.get_mx_login_response(status=401, errcode="invalid-access-token",
error="Failed to verify access token.") error="Failed to verify access token.")
assert resp == PuppetError.Success, "Encountered an unhandled PuppetError."
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in") return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
async def post_matrix_password(self, user, password): async def post_matrix_password(self, user: User, password: str) -> web.Response:
return self.get_mx_login_response(mxid=user.mxid, status=501, error="Not yet implemented", return self.get_mx_login_response(mxid=user.mxid, status=501, error="Not yet implemented",
errcode="not-yet-implemented") errcode="not-yet-implemented")
async def post_login_phone(self, user, phone): async def post_login_phone(self, user: User, phone: str) -> web.Response:
try: try:
await user.client.sign_in(phone or "+123") await user.client.sign_in(phone or "+123")
return self.get_login_response(mxid=user.mxid, state="code", status=200, return self.get_login_response(mxid=user.mxid, state="code", status=200,
@@ -101,14 +108,22 @@ class AuthAPI(abc.ABC):
errcode="unknown_error", errcode="unknown_error",
error="Internal server error while requesting code.") error="Internal server error while requesting code.")
async def post_login_token(self, user, token): async def postprocess_login(self, user: User, user_info) -> None:
existing_user = User.get_by_tgid(user_info.id)
if existing_user and existing_user != user:
await existing_user.log_out()
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
if user.command_status and user.command_status["action"] == "Login":
user.command_status = None
async def post_login_token(self, user: User, token: str) -> web.Response:
try: try:
user_info = await user.client.sign_in(bot_token=token) user_info = await user.client.sign_in(bot_token=token)
asyncio.ensure_future(user.post_login(user_info), loop=self.loop) await self.postprocess_login(user, user_info)
if user.command_status and user.command_status["action"] == "Login":
user.command_status = None
return self.get_login_response(mxid=user.mxid, state="logged-in", status=200, return self.get_login_response(mxid=user.mxid, state="logged-in", status=200,
username=user_info.username) username=user_info.username, phone=None,
human_tg_id=f"@{user_info.username}")
except AccessTokenInvalidError: except AccessTokenInvalidError:
return self.get_login_response(mxid=user.mxid, state="token", status=401, return self.get_login_response(mxid=user.mxid, state="token", status=401,
errcode="bot_token_invalid", errcode="bot_token_invalid",
@@ -122,14 +137,15 @@ class AuthAPI(abc.ABC):
return self.get_login_response(mxid=user.mxid, state="token", status=500, return self.get_login_response(mxid=user.mxid, state="token", status=500,
error="Internal server error while sending token.") error="Internal server error while sending token.")
async def post_login_code(self, user, code, password_in_data): async def post_login_code(self, user: User, code: int, password_in_data: bool
) -> Optional[web.Response]:
try: try:
user_info = await user.client.sign_in(code=code) user_info = await user.client.sign_in(code=code)
asyncio.ensure_future(user.post_login(user_info), loop=self.loop) await self.postprocess_login(user, user_info)
if user.command_status and user.command_status["action"] == "Login": human_tg_id = f"@{user_info.username}" if user_info.username else f"+{user_info.phone}"
user.command_status = None
return self.get_login_response(mxid=user.mxid, state="logged-in", status=200, return self.get_login_response(mxid=user.mxid, state="logged-in", status=200,
username=user_info.username) username=user_info.username, phone=user_info.phone,
human_tg_id=human_tg_id)
except PhoneCodeInvalidError: except PhoneCodeInvalidError:
return self.get_login_response(mxid=user.mxid, state="code", status=401, return self.get_login_response(mxid=user.mxid, state="code", status=401,
errcode="phone_code_invalid", errcode="phone_code_invalid",
@@ -155,14 +171,14 @@ class AuthAPI(abc.ABC):
errcode="unknown_error", errcode="unknown_error",
error="Internal server error while sending code.") error="Internal server error while sending code.")
async def post_login_password(self, user, password): async def post_login_password(self, user: User, password: str) -> web.Response:
try: try:
user_info = await user.client.sign_in(password=password) user_info = await user.client.sign_in(password=password)
asyncio.ensure_future(user.post_login(user_info), loop=self.loop) await self.postprocess_login(user, user_info)
if user.command_status and user.command_status["action"] == "Login (password entry)": human_tg_id = f"@{user_info.username}" if user_info.username else f"+{user_info.phone}"
user.command_status = None
return self.get_login_response(mxid=user.mxid, state="logged-in", status=200, return self.get_login_response(mxid=user.mxid, state="logged-in", status=200,
username=user_info.username) username=user_info.username, phone=user_info.phone,
human_tg_id=human_tg_id)
except PasswordEmptyError: except PasswordEmptyError:
return self.get_login_response(mxid=user.mxid, state="password", status=400, return self.get_login_response(mxid=user.mxid, state="password", status=400,
errcode="password_empty", errcode="password_empty",
+39 -20
View File
@@ -15,7 +15,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web from aiohttp import web
from typing import Tuple, Optional, Callable, Awaitable, TYPE_CHECKING from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING
import asyncio import asyncio
import logging import logging
import json import json
@@ -24,6 +24,7 @@ from telethon.utils import get_peer_id, resolve_id
from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat
from mautrix_appservice import AppService, MatrixRequestError, IntentError from mautrix_appservice import AppService, MatrixRequestError, IntentError
from ...types import MatrixUserID, TelegramID
from ...user import User from ...user import User
from ...portal import Portal from ...portal import Portal
from ...commands.portal import user_has_power_level, get_initial_state from ...commands.portal import user_has_power_level, get_initial_state
@@ -34,20 +35,21 @@ if TYPE_CHECKING:
class ProvisioningAPI(AuthAPI): class ProvisioningAPI(AuthAPI):
log = logging.getLogger("mau.web.provisioning") log = logging.getLogger("mau.web.provisioning") # type: logging.Logger
def __init__(self, context: "Context"): def __init__(self, context: "Context") -> None:
super().__init__(context.loop) super().__init__(context.loop)
self.secret = context.config["appservice.provisioning.shared_secret"] self.secret = context.config["appservice.provisioning.shared_secret"] # type: str
self.az = context.az # type: AppService self.az = context.az # type: AppService
self.context = context # type: Context self.context = context # type: Context
self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware]) self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware]
) # type: web.Application
portal_prefix = "/portal/{mxid:![^/]+}" portal_prefix = "/portal/{mxid:![^/]+}"
self.app.router.add_route("GET", f"{portal_prefix}", self.get_portal_by_mxid) self.app.router.add_route("GET", f"{portal_prefix}", self.get_portal_by_mxid)
self.app.router.add_route("GET", "/portal/{tgid:-[0-9]+}", self.get_portal_by_tgid) self.app.router.add_route("GET", "/portal/{tgid:-[0-9]+}", self.get_portal_by_tgid)
self.app.router.add_route("POST", portal_prefix + "/connect/{chat_id:[0-9]+}", self.app.router.add_route("POST", portal_prefix + "/connect/{chat_id:-[0-9]+}",
self.connect_chat) self.connect_chat)
self.app.router.add_route("POST", f"{portal_prefix}/create", self.create_chat) self.app.router.add_route("POST", f"{portal_prefix}/create", self.create_chat)
self.app.router.add_route("POST", f"{portal_prefix}/disconnect", self.disconnect_chat) self.app.router.add_route("POST", f"{portal_prefix}/disconnect", self.disconnect_chat)
@@ -62,6 +64,8 @@ class ProvisioningAPI(AuthAPI):
self.app.router.add_route("POST", f"{user_prefix}/login/send_code", self.send_code) self.app.router.add_route("POST", f"{user_prefix}/login/send_code", self.send_code)
self.app.router.add_route("POST", f"{user_prefix}/login/send_password", self.send_password) self.app.router.add_route("POST", f"{user_prefix}/login/send_password", self.send_password)
self.app.router.add_route("GET", "/bridge", self.bridge_info)
async def get_portal_by_mxid(self, request: web.Request) -> web.Response: async def get_portal_by_mxid(self, request: web.Request) -> web.Response:
err = self.check_authorization(request) err = self.check_authorization(request)
if err is not None: if err is not None:
@@ -72,6 +76,8 @@ class ProvisioningAPI(AuthAPI):
if not portal: if not portal:
return self.get_error_response(404, "portal_not_found", return self.get_error_response(404, "portal_not_found",
"Portal with given Matrix ID not found.") "Portal with given Matrix ID not found.")
user, _ = await self.get_user(request.query.get("user_id", None), expect_logged_in=None,
require_puppeting=False)
return web.json_response({ return web.json_response({
"mxid": portal.mxid, "mxid": portal.mxid,
"chat_id": get_peer_id(portal.peer), "chat_id": get_peer_id(portal.peer),
@@ -80,6 +86,7 @@ class ProvisioningAPI(AuthAPI):
"about": portal.about, "about": portal.about,
"username": portal.username, "username": portal.username,
"megagroup": portal.megagroup, "megagroup": portal.megagroup,
"can_unbridge": (await portal.can_user_perform(user, "unbridge")) if user else False,
}) })
async def get_portal_by_tgid(self, request: web.Request) -> web.Response: async def get_portal_by_tgid(self, request: web.Request) -> web.Response:
@@ -96,6 +103,8 @@ class ProvisioningAPI(AuthAPI):
if not portal: if not portal:
return self.get_error_response(404, "portal_not_found", return self.get_error_response(404, "portal_not_found",
"Portal to given Telegram chat not found.") "Portal to given Telegram chat not found.")
user, _ = await self.get_user(request.query.get("user_id", None), expect_logged_in=None,
require_puppeting=False)
return web.json_response({ return web.json_response({
"mxid": portal.mxid, "mxid": portal.mxid,
"chat_id": get_peer_id(portal.peer), "chat_id": get_peer_id(portal.peer),
@@ -104,6 +113,7 @@ class ProvisioningAPI(AuthAPI):
"about": portal.about, "about": portal.about,
"username": portal.username, "username": portal.username,
"megagroup": portal.megagroup, "megagroup": portal.megagroup,
"can_unbridge": (await portal.can_user_perform(user, "unbridge")) if user else False,
}) })
async def connect_chat(self, request: web.Request) -> web.Response: async def connect_chat(self, request: web.Request) -> web.Response:
@@ -118,10 +128,10 @@ class ProvisioningAPI(AuthAPI):
chat_id = request.match_info["chat_id"] chat_id = request.match_info["chat_id"]
if chat_id.startswith("-100"): if chat_id.startswith("-100"):
tgid = int(chat_id[4:]) tgid = TelegramID(int(chat_id[4:]))
peer_type = "channel" peer_type = "channel"
elif chat_id.startswith("-"): elif chat_id.startswith("-"):
tgid = -int(chat_id) tgid = TelegramID(-int(chat_id))
peer_type = "chat" peer_type = "chat"
else: else:
return self.get_error_response(400, "tgid_invalid", "Invalid Telegram chat ID.") return self.get_error_response(400, "tgid_invalid", "Invalid Telegram chat ID.")
@@ -153,14 +163,14 @@ class ProvisioningAPI(AuthAPI):
"Matrix room.") "Matrix room.")
is_logged_in = user is not None and await user.is_logged_in() is_logged_in = user is not None and await user.is_logged_in()
user = user if is_logged_in else self.context.bot acting_user = user if is_logged_in else self.context.bot
if not user: if not acting_user:
return self.get_login_response(status=403, errcode="not_logged_in", return self.get_login_response(status=403, errcode="not_logged_in",
error="You are not logged in and there is no relay bot.") error="You are not logged in and there is no relay bot.")
entity = None # type: Optional[TypeChat] entity = None # type: Optional[TypeChat]
try: try:
entity = await user.client.get_entity(portal.peer) entity = await acting_user.client.get_entity(portal.peer)
except Exception: except Exception:
self.log.exception("Failed to get_entity(%s) for manual bridging.", portal.peer) self.log.exception("Failed to get_entity(%s) for manual bridging.", portal.peer)
@@ -351,8 +361,14 @@ class ProvisioningAPI(AuthAPI):
return err return err
await user.log_out() await user.log_out()
async def bridge_info(self, request: web.Request) -> web.Response:
return web.json_response({
"relaybot_username": self.context.bot.username,
}, status=200)
@staticmethod @staticmethod
async def error_middleware(_, handler) -> Callable[[web.Request], Awaitable[web.Response]]: async def error_middleware(_, handler: Callable[[web.Request], Awaitable[web.Response]]
) -> Callable[[web.Request], Awaitable[web.Response]]:
async def middleware_handler(request: web.Request) -> web.Response: async def middleware_handler(request: web.Request) -> web.Response:
try: try:
return await handler(request) return await handler(request)
@@ -371,16 +387,18 @@ class ProvisioningAPI(AuthAPI):
"errcode": errcode, "errcode": errcode,
}, status=status) }, status=status)
def get_mx_login_response(self, status=200, state="", username="", mxid="", message="", def get_mx_login_response(self, status=200, state="", username="", phone="", human_tg_id="",
error="", errcode=""): mxid="", message="", error="", errcode=""):
raise NotImplementedError() raise NotImplementedError()
def get_login_response(self, status=200, state="", username="", mxid="", message="", error="", def get_login_response(self, status=200, state="", username="", phone: str = "",
errcode="") -> web.Response: human_tg_id: str = "", mxid="", message="", error="", errcode=""
if username: ) -> web.Response:
if username or phone:
resp = { resp = {
"state": "logged-in", "state": "logged-in",
"username": username, "username": username,
"phone": phone,
} }
elif message: elif message:
resp = { resp = {
@@ -411,7 +429,7 @@ class ProvisioningAPI(AuthAPI):
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
async def get_user(self, mxid: str, expect_logged_in: Optional[bool] = False, async def get_user(self, mxid: MatrixUserID, expect_logged_in: Optional[bool] = False,
require_puppeting: bool = True, require_user: bool = True require_puppeting: bool = True, require_user: bool = True
) -> Tuple[Optional[User], Optional[web.Response]]: ) -> Tuple[Optional[User], Optional[web.Response]]:
if not mxid: if not mxid:
@@ -427,7 +445,8 @@ class ProvisioningAPI(AuthAPI):
if expect_logged_in is not None: if expect_logged_in is not None:
logged_in = await user.is_logged_in() logged_in = await user.is_logged_in()
if not expect_logged_in and logged_in: if not expect_logged_in and logged_in:
return user, self.get_login_response(username=user.username, status=409, return user, self.get_login_response(username=user.username, phone=user.phone,
status=409,
error="You are already logged in.", error="You are already logged in.",
errcode="already_logged_in") errcode="already_logged_in")
elif expect_logged_in and not logged_in: elif expect_logged_in and not logged_in:
@@ -439,7 +458,7 @@ class ProvisioningAPI(AuthAPI):
expect_logged_in: Optional[bool] = False, expect_logged_in: Optional[bool] = False,
require_puppeting: bool = False, require_puppeting: bool = False,
want_data: bool = True, want_data: bool = True,
) -> (Tuple[Optional[dict], ) -> (Tuple[Optional[Dict],
Optional[User], Optional[User],
Optional[web.Response]]): Optional[web.Response]]):
err = self.check_authorization(request) err = self.check_authorization(request)
+35 -1
View File
@@ -22,8 +22,23 @@ tags:
- name: User info - name: User info
- name: Authentication - name: Authentication
- name: Bridging - name: Bridging
- name: Misc
paths: paths:
/bridge:
get:
operationId: get_bridge
summary: Get the bridge's information
tags: [Misc]
responses:
200:
description: The bridge information
schema:
type: object
properties:
relaybot_username:
type: string
description: The relay bot's username on Telegram
/portal/{room_id}: /portal/{room_id}:
get: get:
operationId: get_portal operationId: get_portal
@@ -57,6 +72,11 @@ paths:
required: true required: true
type: string type: string
pattern: "![^/]+" pattern: "![^/]+"
- name: user_id
in: query
description: Optional Matrix user ID to check if the user has permissions to do bridging.
required: false
type: string
/portal/{chat_id}: /portal/{chat_id}:
get: get:
operationId: get_portal_by_tgid operationId: get_portal_by_tgid
@@ -102,6 +122,11 @@ paths:
required: true required: true
type: integer type: integer
pattern: "-[0-9]+" pattern: "-[0-9]+"
- name: user_id
in: query
description: Optional Matrix user ID to check if the user has permissions to do bridging.
required: false
type: string
/portal/{room_id}/connect/{chat_id}: /portal/{room_id}/connect/{chat_id}:
post: post:
operationId: connect_portal operationId: connect_portal
@@ -706,6 +731,9 @@ responses:
username: username:
type: string type: string
description: The Telegram username the user is logged in as. description: The Telegram username the user is logged in as.
phone:
type: string
description: The phone number of the account the user is logged into.
BadRequest: BadRequest:
description: Invalid JSON. description: Invalid JSON.
schema: schema:
@@ -790,7 +818,7 @@ definitions:
example: A. example: A.
phone: phone:
type: string type: string
example: +123456789 example: 123456789
is_bot: is_bot:
type: boolean type: boolean
example: false example: false
@@ -829,6 +857,9 @@ definitions:
type: string type: string
about: about:
type: string type: string
can_unbridge:
type: boolean
description: If a user ID was provided with the request, this will indicate whether or not the user can unbridge the room.
AuthSuccess: AuthSuccess:
type: object type: object
@@ -845,6 +876,9 @@ definitions:
username: username:
type: string type: string
description: The Telegram username the user is logged in as. Only applicable if state=logged-in description: The Telegram username the user is logged in as. Only applicable if state=logged-in
phone:
type: string
description: The phone number of the account the user logged into. Only applicable if state=logged-in
HumanReadableError: HumanReadableError:
type: string type: string
+31 -25
View File
@@ -14,14 +14,17 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from aiohttp import web from aiohttp import web
from mako.template import Template from mako.template import Template
import pkg_resources import pkg_resources
import asyncio
import logging import logging
import random import random
import string import string
import time import time
from ...types import MatrixUserID
from ...util import sign_token, verify_token from ...util import sign_token, verify_token
from ...user import User from ...user import User
from ...puppet import Puppet from ...puppet import Puppet
@@ -29,20 +32,20 @@ from ..common import AuthAPI
class PublicBridgeWebsite(AuthAPI): class PublicBridgeWebsite(AuthAPI):
log = logging.getLogger("mau.web.public") log = logging.getLogger("mau.web.public") # type: logging.Logger
def __init__(self, loop): def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop) super().__init__(loop)
self.secret_key = "".join( self.secret_key = "".join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) # type: str
self.login = Template( self.login = Template(pkg_resources.resource_string(
pkg_resources.resource_string("mautrix_telegram", "web/public/login.html.mako")) "mautrix_telegram", "web/public/login.html.mako")) # type: Template
self.mx_login = Template( self.mx_login = Template(pkg_resources.resource_string(
pkg_resources.resource_string("mautrix_telegram", "web/public/matrix-login.html.mako")) "mautrix_telegram", "web/public/matrix-login.html.mako")) # type: Template
self.app = web.Application(loop=loop) self.app = web.Application(loop=loop) # type: web.Application
self.app.router.add_route("GET", "/login", self.get_login) self.app.router.add_route("GET", "/login", self.get_login)
self.app.router.add_route("POST", "/login", self.post_login) self.app.router.add_route("POST", "/login", self.post_login)
self.app.router.add_route("GET", "/matrix-login", self.get_matrix_login) self.app.router.add_route("GET", "/matrix-login", self.get_matrix_login)
@@ -50,21 +53,21 @@ class PublicBridgeWebsite(AuthAPI):
self.app.router.add_static("/", pkg_resources.resource_filename("mautrix_telegram", self.app.router.add_static("/", pkg_resources.resource_filename("mautrix_telegram",
"web/public/")) "web/public/"))
def make_token(self, mxid, endpoint="/login", expires_in=900): def make_token(self, mxid: str, endpoint: str = "/login", expires_in: int = 900) -> str:
return sign_token(self.secret_key, { return sign_token(self.secret_key, {
"mxid": mxid, "mxid": mxid,
"endpoint": endpoint, "endpoint": endpoint,
"expiry": int(time.time()) + expires_in, "expiry": int(time.time()) + expires_in,
}) })
def verify_token(self, token, endpoint="/login"): def verify_token(self, token: str, endpoint: str = "/login") -> Optional[MatrixUserID]:
token = verify_token(self.secret_key, token) token = verify_token(self.secret_key, token)
if token and (token.get("expiry", 0) > int(time.time()) and if token and (token.get("expiry", 0) > int(time.time()) and
token.get("endpoint", None) == endpoint): token.get("endpoint", None) == endpoint):
return token.get("mxid", None) return MatrixUserID(token.get("mxid", None))
return None return None
async def get_login(self, request): async def get_login(self, request: web.Request) -> web.Response:
state = "bot_token" if request.rel_url.query.get("mode", "") == "bot" else "request" state = "bot_token" if request.rel_url.query.get("mode", "") == "bot" else "request"
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login") mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login")
@@ -81,9 +84,9 @@ class PublicBridgeWebsite(AuthAPI):
if not await user.is_logged_in(): if not await user.is_logged_in():
return self.get_login_response(mxid=user.mxid, state=state) return self.get_login_response(mxid=user.mxid, state=state)
return self.get_login_response(mxid=user.mxid, username=user.username) return self.get_login_response(mxid=user.mxid, human_tg_id=user.human_tg_id)
async def get_matrix_login(self, request): async def get_matrix_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/matrix-login") mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/matrix-login")
if not mxid: if not mxid:
return self.get_mx_login_response(status=401, state="invalid-token") return self.get_mx_login_response(status=401, state="invalid-token")
@@ -105,19 +108,22 @@ class PublicBridgeWebsite(AuthAPI):
return self.get_mx_login_response(mxid=user.mxid) return self.get_mx_login_response(mxid=user.mxid)
def get_login_response(self, status=200, state="", username="", mxid="", message="", error="", def get_login_response(self, status: int = 200, state: str = "", username: str = "",
errcode=""): phone: str = "", human_tg_id: str = "", mxid: str = "",
message: str = "", error: str = "", errcode: str = "") -> web.Response:
return web.Response(status=status, content_type="text/html", return web.Response(status=status, content_type="text/html",
text=self.login.render(username=username, state=state, error=error, text=self.login.render(human_tg_id=human_tg_id, state=state,
message=message, mxid=mxid)) error=error, message=message, mxid=mxid))
def get_mx_login_response(self, status=200, state="", username="", mxid="", message="", def get_mx_login_response(self, status: int = 200, state: str = "", username: str = "",
error="", errcode=""): phone: str = "", human_tg_id: str = "", mxid: str = "",
message: str = "", error: str = "", errcode: str = ""
) -> web.Response:
return web.Response(status=status, content_type="text/html", return web.Response(status=status, content_type="text/html",
text=self.mx_login.render(username=username, state=state, error=error, text=self.mx_login.render(human_tg_id=human_tg_id, state=state,
message=message, mxid=mxid)) error=error, message=message, mxid=mxid))
async def post_matrix_login(self, request): async def post_matrix_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/matrix-login") mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/matrix-login")
if not mxid: if not mxid:
return self.get_mx_login_response(status=401, state="invalid-token") return self.get_mx_login_response(status=401, state="invalid-token")
@@ -140,7 +146,7 @@ class PublicBridgeWebsite(AuthAPI):
error="You must provide an access token or " error="You must provide an access token or "
"password.") "password.")
async def post_login(self, request): async def post_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login") mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login")
if not mxid: if not mxid:
return self.get_login_response(status=401, state="invalid-token") return self.get_login_response(status=401, state="invalid-token")
@@ -152,7 +158,7 @@ class PublicBridgeWebsite(AuthAPI):
return self.get_login_response(mxid=user.mxid, error="You are not whitelisted.", return self.get_login_response(mxid=user.mxid, error="You are not whitelisted.",
status=403) status=403)
elif await user.is_logged_in(): elif await user.is_logged_in():
return self.get_login_response(mxid=user.mxid, username=user.username) return self.get_login_response(mxid=user.mxid, human_tg_id=user.human_tg_id)
await user.ensure_started(even_if_no_session=True) await user.ensure_started(even_if_no_session=True)
+4 -4
View File
@@ -51,25 +51,25 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
</head> </head>
<body> <body>
<main class="container"> <main class="container">
% if username: % if human_tg_id:
% if state == "logged-in": % if state == "logged-in":
<h1>Logged in successfully!</h1> <h1>Logged in successfully!</h1>
<p> <p>
Logged in as @${username}. Logged in as ${human_tg_id}.
You can now close this page. You can now close this page.
You should be invited to Telegram portals on Matrix momentarily. You should be invited to Telegram portals on Matrix momentarily.
</p> </p>
% elif state == "bot-logged-in": % elif state == "bot-logged-in":
<h1>Logged in successfully!</h1> <h1>Logged in successfully!</h1>
<p> <p>
Logged in as @${username}. Logged in as ${human_tg_id}.
You can now close this page. You can now close this page.
You should be invited to Telegram portals on Matrix momentarily. You should be invited to Telegram portals on Matrix momentarily.
</p> </p>
% else: % else:
<h1>You're already logged in!</h1> <h1>You're already logged in!</h1>
<p> <p>
You're logged in as @${username}. You're logged in as ${human_tg_id}.
</p> </p>
<p> <p>
If you want to log in with another account, log out using the <code>logout</code> If you want to log in with another account, log out using the <code>logout</code>
+1 -1
View File
@@ -4,7 +4,7 @@ ruamel.yaml
python-magic python-magic
SQLAlchemy SQLAlchemy
alembic alembic
Markdown commonmark
future-fstrings future-fstrings
telethon telethon
telethon-session-sqlalchemy telethon-session-sqlalchemy
+3 -2
View File
@@ -27,10 +27,10 @@ setuptools.setup(
install_requires=[ install_requires=[
"aiohttp>=3.0.1,<4", "aiohttp>=3.0.1,<4",
"mautrix-appservice>=0.3.6,<0.4.0", "mautrix-appservice>=0.3.7,<0.4.0",
"SQLAlchemy>=1.2.3,<2", "SQLAlchemy>=1.2.3,<2",
"alembic>=1.0.0,<2", "alembic>=1.0.0,<2",
"Markdown>=2.6.11,<3", "commonmark>=0.8.1,<1",
"ruamel.yaml>=0.15.35,<0.16", "ruamel.yaml>=0.15.35,<0.16",
"future-fstrings>=0.4.2", "future-fstrings>=0.4.2",
"python-magic>=0.4.15,<0.5", "python-magic>=0.4.15,<0.5",
@@ -43,6 +43,7 @@ setuptools.setup(
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
"Topic :: Communications :: Chat", "Topic :: Communications :: Chat",
"Framework :: AsyncIO",
"Programming Language :: Python", "Programming Language :: Python",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.5",