Compare commits

...

70 Commits

Author SHA1 Message Date
Tulir Asokan 1994ce38eb Bump version to 0.4.0 2018-11-28 02:10:37 +02:00
Tulir Asokan 9aad6de823 Bump version to 0.4.0rc2 2018-11-15 22:46:36 +02:00
Tulir Asokan 3d3afdb645 Fix bug in 82d7e78455 2018-11-15 22:45:48 +02:00
Tulir Asokan 983f5001ab Bump version to 0.4.0rc1 2018-11-15 22:27:25 +02:00
Tulir Asokan a80fdf0990 Fix bug in 720210ac08 2018-11-15 22:25:49 +02:00
Tulir Asokan 82d7e78455 Handle kicking puppets separately. Fixes #191 2018-11-15 11:57:02 +02:00
Tulir Asokan d514b929b3 Automatically log out when logging in with a user someone logged in with previously. Fixes #198 2018-11-15 11:45:46 +02:00
Tulir Asokan 720210ac08 Check if client is connected before checking if authorized. Fixes #215 2018-11-15 11:45:36 +02:00
Tulir Asokan 2dfc05db5f Fall back to get_dialogs if get_entity fails. Fixes #229 2018-11-15 11:20:43 +02:00
Tulir Asokan d551934ec1 Fix command suggestion when trying to bridge non-whitelisted chat 2018-11-01 01:55:54 +02:00
Tulir Asokan bac1e30cf0 Fix Matrix->Telegram code blocks without language. Fixes #240 2018-10-27 19:22:04 +03:00
Tulir Asokan 8fdb2c4e57 Merge pull request #239 from tulir/sqlalchemy-core
Port Message table to SQLAlchemy Core
2018-10-21 00:32:14 +03:00
Tulir Asokan 8da1fb78b8 Handle aiohttp errors in syncer. Fixes #210 2018-10-21 00:09:37 +03:00
Tulir Asokan cea8163366 Only match integers in puppet mxid regex. Fixes #234 2018-10-21 00:08:02 +03:00
Tulir Asokan 388e4f8601 Port Message table to SQLAlchemy Core 2018-10-20 23:11:10 +03:00
Tulir Asokan 2756873c53 Add SIGINT/SIGTERM handler 2018-10-20 21:21:26 +03:00
Tulir Asokan a770e1f67e Merge pull request #237 from turt2live/travis/fix-chat-id-request
Don't try permission checks on rooms that aren't bridged
2018-10-20 14:56:27 +03:00
Tulir Asokan f8c844c4c0 Add flag to enable alchemysession core mode 2018-10-20 14:46:26 +03:00
Travis Ralston 7f23d4cf68 Don't try permission checks on rooms that aren't bridged
This is the proper way to fix https://github.com/tulir/mautrix-telegram/pull/235
2018-10-19 19:31:58 -06:00
Tulir Asokan 247c75191b Merge pull request #226 from turt2live/travis/bridge-info
Add provisioning route for getting misc bridge info
2018-10-08 14:02:19 +03:00
Travis Ralston 4f3e1b4fe6 Fix errors in spec.yaml 2018-10-08 01:16:29 -06:00
Travis Ralston 6291e92ed7 Remove extraneous fstring 2018-10-08 01:15:49 -06:00
Tulir Asokan 5054afcbb5 Fix Python 3.5 compatibility 2018-10-02 14:51:54 +03:00
Tulir Asokan 980e0d6ef7 Send captions as second message by default. Fixes #233 2018-09-29 10:56:04 +03:00
Tulir Asokan 2f6147f325 Fix notice bridging exceptions 2018-09-29 01:35:30 +03:00
Tulir Asokan 56fb88b75e Use mxids instead of localparts as default displaynames and fix name add/remove message. Fixes #228 2018-09-29 00:59:02 +03:00
Tulir Asokan 24bdda8ca1 Reorganize formatter utils and add more blue text 2018-09-28 18:39:57 +03:00
Tulir Asokan c38e46fc2a Fix linebreaks in pre blocks 2018-09-28 17:15:57 +03:00
Tulir Asokan 916cc3746d Fix block tag newlines and allow <strike>. Fixes #232 2018-09-28 17:06:42 +03:00
Tulir Asokan a32bc2985a Show phone number when username doesn't exist. Fixes #213 2018-09-28 02:46:02 +03:00
Tulir Asokan 8d982b4615 Bump minimum mautrix-appservice version. Fixes #217 2018-09-28 02:22:54 +03:00
Tulir Asokan 10e77707d0 Fix HTML escaping in command reply markdown parser 2018-09-28 02:18:41 +03:00
Tulir Asokan b0fe208768 Add missing await to portal.set_typing 2018-09-28 01:18:39 +03:00
Tulir Asokan b44d6d2d90 Fix minor things and type hints 2018-09-28 01:02:09 +03:00
Tulir Asokan 828047e272 Split TelegramMessage helper to separate file 2018-09-28 00:49:37 +03:00
Tulir Asokan a9cb1bf518 Fix linebreak handling in lxml parser and add better bullets
Fixes #218
2018-09-28 00:45:37 +03:00
Tulir Asokan d71f421981 Use <pre> for multiline MessageEntityCode entities 2018-09-26 00:24:04 +03:00
Tulir Asokan 26e947992e Merge pull request #231 from tulir/room-specific-settings
Add room specific config
2018-09-25 00:47:44 +03:00
Tulir Asokan 78e4804774 Fix minor things and improve code style 2018-09-25 00:47:16 +03:00
Tulir Asokan 5ccd1bc2fe Fix bugs and switch to commonmark for command replies 2018-09-25 00:26:02 +03:00
Tulir Asokan f758884c75 Fix example config and add alembic migration 2018-09-24 23:41:18 +03:00
Tulir Asokan 9d2d34a25c Add command to update room-specific config 2018-09-24 17:44:00 +03:00
Tulir Asokan fc23461445 Add room specific settings. Probably broken 2018-09-24 16:01:16 +03:00
Tulir Asokan 5253504df9 Update setup.py classifiers 2018-09-24 01:26:02 +03:00
Tulir Asokan dd270b862e Fix handling capitalized file extensions. Fixes #156 2018-09-24 01:25:51 +03:00
Travis Ralston 5bc1362493 Add provisioning route for getting misc bridge info
Currently only the relay bot's username is exposed here.
2018-09-19 22:44:27 -06:00
Tulir Asokan 96a0c923c2 Merge pull request #225 from turt2live/travis/unbridge-info
Add a flag to indicate if the requesting user can unbridge the portal
2018-09-17 01:24:17 +03:00
Travis Ralston 23bb2871fd Add a flag to indicate if the requesting user can unbridge the portal 2018-09-16 16:16:33 -06:00
Tulir Asokan d4ea5f8b38 Improve type hints and set version to 0.4.0+dev 2018-09-10 01:14:12 +03:00
Tulir Asokan 4b2cdc3d39 Add missing command status clear 2018-09-10 00:14:35 +03:00
Tulir Asokan 4c54d9c9ea Fix previous commit (ref #219) and update catch_up config comment 2018-09-10 00:11:13 +03:00
Tulir Asokan 9541d5eceb Don't bridge messages from unbridged chats received by bot (ref #219) 2018-09-09 01:26:22 +03:00
Tulir Asokan c9c1023ece Merge pull request #223 from turt2live/patch-1
Allow negative numbers in /connect
2018-09-09 01:14:38 +03:00
Travis Ralston cb2073eb8b Allow negative numbers in /connect 2018-09-08 16:14:00 -06:00
Tulir Asokan d35104aea6 Fix incorrect type hint 2018-09-05 10:55:12 +03:00
Tulir Asokan ad342f2ca4 Ignore old log files too 2018-09-01 19:08:29 +03:00
Tulir Asokan 29541ff520 Pass logging a copy of the config to stop editing. Fixes #216 2018-09-01 14:07:44 +03:00
Tulir Asokan 6a1c160608 Await set_presence. Fixes #209 2018-09-01 14:03:13 +03:00
Tulir Asokan 731c802fcd Only import deque in type checking mode to fix 3.5 runtime support 2018-08-30 19:03:22 +03:00
Tulir Asokan b6f15934f2 Fix conversational command handling 2018-08-30 13:32:04 +03:00
Tulir Asokan 068449c59c Update ROADMAP.md 2018-08-24 09:47:54 +03:00
Tulir Asokan 4f36a2c7c1 Simplify displayname similarity calculation 2018-08-17 00:06:37 +03:00
Tulir Asokan bb04231880 Fix bugs in migrations 2018-08-17 00:06:02 +03:00
Tulir Asokan 1ef790ce31 Merge pull request #206 from V02460/master
Add type annotations
2018-08-15 10:18:39 +03:00
Kai A. Hiller 81531235bc Replace double quote type annotations with single quotes 2018-08-09 14:36:14 +02:00
Kai A. Hiller 66683151ec Make SearchResult a NewType and make its List explicit 2018-08-09 14:23:18 +02:00
Kai A. Hiller e751d140f2 Change case of new types 2018-08-09 14:11:41 +02:00
Kai A. Hiller 0f8009b1e9 Add missing type hints and fix most type errors except for Optionals. 2018-08-09 03:31:04 +02:00
Kai A. Hiller 01e153662e Replace star imports with literal values 2018-08-09 02:42:48 +02:00
Kai A. Hiller 08dd5b5b15 Add None return type to functions 2018-08-09 02:42:47 +02:00
52 changed files with 1862 additions and 1094 deletions
+1 -1
View File
@@ -7,5 +7,5 @@ __pycache__
config.yaml
registration.yaml
*.log
*.log*
*.db
+4 -4
View File
@@ -4,9 +4,9 @@
* [x] Message content (text, formatting, files, etc..)
* [x] Message redactions
* [ ] ‡ Message history
* [ ] Presence
* [ ] Typing notifications
* [ ] Read receipts
* [x] Presence
* [x] Typing notifications
* [x] Read receipts
* [x] Pinning messages
* [x] Power level
* [x] Normal chats
@@ -46,7 +46,7 @@
* [x] When receiving invite or message
* [x] Private chat creation by inviting Matrix puppet of Telegram user to new room
* [x] Option to use bot to relay messages for unauthenticated Matrix users
* [ ] Option to use own Matrix account for messages sent from other Telegram clients
* [x] Option to use own Matrix account for messages sent from other Telegram clients
* [ ] ‡ Calls (hard, not yet supported by Telethon)
† Information not automatically sent from source, i.e. implementation may not be possible
@@ -21,4 +21,5 @@ def upgrade():
def downgrade():
op.drop_column('puppet', 'is_bot')
with op.batch_alter_table("puppet") as batch_op:
batch_op.drop_column('is_bot')
@@ -20,4 +20,5 @@ def upgrade():
def downgrade():
op.drop_column('portal', 'megagroup')
with op.batch_alter_table("portal") as batch_op:
batch_op.drop_column('megagroup')
@@ -72,10 +72,17 @@ def upgrade():
sa.Column("avatar_url", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("room_id", "user_id"))
try:
migrate_state_store()
except Exception as e:
print("Failed to migrate state store:", e)
print("Migrating the state store isn't required, but you can retry by alembic downgrading "
"to revision 2228d49c383f and upgrading again.")
def migrate_state_store():
conn = op.get_bind()
session = orm.sessionmaker(bind=conn)
session = orm.scoping.scoped_session(session)
Puppet.query = session.query_property()
session = orm.sessionmaker(bind=conn)() # type: orm.Session
try:
with open("mx-state.json") as file:
@@ -99,7 +106,7 @@ def upgrade():
if not match:
continue
puppet = Puppet.query.get(match.group(1))
puppet = session.query(Puppet).get(match.group(1))
if not puppet:
continue
@@ -22,4 +22,5 @@ def upgrade():
def downgrade():
op.drop_column('telegram_file', 'timestamp')
with op.batch_alter_table("telegram_file") as batch_op:
batch_op.drop_column('timestamp')
@@ -0,0 +1,25 @@
"""Add phone number field to users
Revision ID: a9119be92164
Revises: b54929c22c86
Create Date: 2018-09-28 02:38:40.626282
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a9119be92164"
down_revision = "b54929c22c86"
branch_labels = None
depends_on = None
def upgrade():
op.add_column("user", sa.Column("tg_phone", sa.String(), nullable=True))
def downgrade():
with op.batch_alter_table("user") as batch_op:
batch_op.drop_column("tg_phone")
@@ -0,0 +1,25 @@
"""Add portal-specific config
Revision ID: b54929c22c86
Revises: d5f7b8b4b456
Create Date: 2018-09-24 23:40:33.528710
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b54929c22c86"
down_revision = "d5f7b8b4b456"
branch_labels = None
depends_on = None
def upgrade():
op.add_column("portal", sa.Column("config", sa.Text(), nullable=True))
def downgrade():
with op.batch_alter_table("portal") as batch_op:
batch_op.drop_column("config")
@@ -20,4 +20,5 @@ def upgrade():
def downgrade():
op.drop_column('puppet', 'displayname_source')
with op.batch_alter_table("puppet") as batch_op:
batch_op.drop_column('displayname_source')
+27 -17
View File
@@ -27,6 +27,9 @@ appservice:
# SQLite: sqlite:///filename.db
# Postgres: postgres://username:password@hostname/dbname
database: sqlite:///mautrix-telegram.db
# Whether or not to use SQLAlchemy Core for common database actions. Use if the bridge is
# being bottlenecked on ORM commits. Only supported with PostgreSQL.
sqlalchemy_core_mode: false
# Public part of web server for out-of-Matrix interaction with the bridge.
# Used for things like login if the user wants to make sure the 2FA password isn't stored in
@@ -95,15 +98,6 @@ bridge:
- username
- phone number
# Show message editing as a reply to the original message.
# If this is false, message edits are not shown at all, as Matrix does not support editing yet.
edits_as_replies: false
# Highlight changed/added parts in edits. Requires lxml.
highlight_edits: false
# Whether or not Matrix bot messages (type m.notice) should be bridged.
bridge_notices: true
# Whether to bridge Telegram bot messages as m.notices or m.texts.
bot_messages_as_notices: true
# Maximum number of members to sync per portal when starting up. Other members will be
# synced when they send messages. The maximum is 10000, after which the Telegram server
# will not send any more members.
@@ -119,21 +113,16 @@ bridge:
# Allow logging in within Matrix. If false, the only way to log in is using the out-of-Matrix
# login website (see appservice.public config section)
allow_matrix_login: true
# Use inline images instead of m.image to make rich captions possible.
# N.B. Inline images are not supported on all clients (e.g. Riot iOS).
inline_images: false
# Whether or not to bridge plaintext highlights.
# Only enable this if your displayname_template has some static part that the bridge can use to
# reliably identify what is a plaintext highlight.
plaintext_highlights: false
# Highlight changed/added parts in edits. Requires lxml.
highlight_edits: false
# Whether or not to make portals of publicly joinable channels/supergroups publicly joinable on Matrix.
public_portals: true
# Whether to send stickers as the new native m.sticker type or normal m.images.
# Old versions of Riot don't support the new type at all.
# Remember that proper sticker support always requires Pillow to convert webp into png.
native_stickers: true
# Whether or not to fetch and handle Telegram updates at startup from the time the bridge was down.
# WARNING: Probably buggy, might get stuck in infinite loop.
# Currently only works for private chats and normal groups.
catch_up: false
# Whether or not to use /sync to get presence, read receipts and typing notifications when using
# your own Matrix account as the Matrix puppet for your Telegram account.
@@ -149,6 +138,27 @@ bridge:
# You might need to increase this on high-traffic bridge instances.
cache_queue_length: 20
# Show message editing as a reply to the original message.
# If this is false, message edits are not shown at all, as Matrix does not support editing yet.
edits_as_replies: false
bridge_notices:
# Whether or not Matrix bot messages (type m.notice) should be bridged.
default: false
# List of user IDs for whom the previous flag is flipped.
# e.g. if bridge_notices.default is false, notices from other users will not be bridged, but
# notices from users listed here will be bridged.
exceptions:
- "@importantbot:example.com"
# Whether to bridge Telegram bot messages as m.notices or m.texts.
bot_messages_as_notices: true
# Use inline images instead of a separate message for the caption.
# N.B. Inline images are not supported on all clients (e.g. Riot iOS).
inline_images: false
# Whether to send stickers as the new native m.sticker type or normal m.images.
# Old versions of Riot don't support the new type at all.
# Remember that proper sticker support always requires Pillow to convert webp into png.
native_stickers: true
# The formats to use when sending messages to Telegram via the relay bot.
#
# Telegram doesn't have built-in emotes, so the m.emote format is also used for non-relaybot users.
+1 -1
View File
@@ -1,2 +1,2 @@
__version__ = "0.3.0"
__version__ = "0.4.0"
__author__ = "Tulir Asokan <tulir@maunium.net>"
+15 -5
View File
@@ -14,11 +14,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import Coroutine, List
import argparse
import asyncio
import logging.config
import sys
import copy
import signal
from sqlalchemy import orm
import sqlalchemy as sql
@@ -66,7 +68,7 @@ if args.generate_registration:
print(f"Registration generated and saved to {config.registration_path}")
sys.exit(0)
logging.config.dictConfig(config["logging"])
logging.config.dictConfig(copy.deepcopy(config["logging"]))
log = logging.getLogger("mau.init") # type: logging.Logger
log.debug(f"Initializing mautrix-telegram {__version__}")
@@ -78,6 +80,11 @@ Base.metadata.bind = db_engine
session_container = AlchemySessionContainer(engine=db_engine, session=db_session,
table_base=Base, table_prefix="telethon_",
manage_tables=False)
if config["appservice.sqlalchemy_core_mode"]:
try:
session_container.core_mode = True
except AttributeError:
log.error("Current version of teleton-session-sqlalchemy does not support core mode")
loop = asyncio.get_event_loop() # type: asyncio.AbstractEventLoop
@@ -106,7 +113,7 @@ if config["appservice.provisioning.enabled"]:
context.provisioning_api = provisioning_api
with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
init_db(db_session)
init_db(db_session, db_engine)
init_abstract_user(context)
context.bot = init_bot(context)
context.mx = MatrixHandler(context)
@@ -115,18 +122,21 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
startup_actions = (init_puppet(context) +
init_user(context) +
[start,
context.mx.init_as_bot()])
context.mx.init_as_bot()]) # type: List[Coroutine]
if context.bot:
startup_actions.append(context.bot.start())
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
try:
log.debug("Initialization complete, running startup actions")
loop.run_until_complete(asyncio.gather(*startup_actions, loop=loop))
log.debug("Startup actions complete, now running forever")
loop.run_forever()
except KeyboardInterrupt:
log.debug("Keyboard interrupt received, stopping clients")
log.debug("Interrupt received, stopping clients")
loop.run_until_complete(
asyncio.gather(*[user.stop() for user in User.by_tgid.values()], loop=loop))
log.debug("Clients stopped, shutting down")
+63 -48
View File
@@ -35,6 +35,7 @@ from alchemysession import AlchemySessionContainer
from . import portal as po, puppet as pu, __version__
from .db import Message as DBMessage
from .types import TelegramID, MatrixUserID
from .tgclient import MautrixTelegramClient
if TYPE_CHECKING:
@@ -60,17 +61,18 @@ class AbstractUser(ABC):
bot = None # type: Bot
ignore_incoming_bot_events = True # type: bool
def __init__(self):
def __init__(self) -> None:
self.is_admin = False # type: bool
self.matrix_puppet_whitelisted = False # type: bool
self.puppet_whitelisted = False # type: bool
self.whitelisted = False # type: bool
self.relaybot_whitelisted = False # type: bool
self.client = None # type: MautrixTelegramClient
self.tgid = None # type: int
self.mxid = None # type: str
self.tgid = None # type: TelegramID
self.mxid = None # type: MatrixUserID
self.is_relaybot = False # type: bool
self.is_bot = False # type: bool
self.relaybot = None # type: Optional[Bot]
@property
def connected(self) -> bool:
@@ -93,7 +95,7 @@ class AbstractUser(ABC):
config["telegram.proxy.rdns"],
config["telegram.proxy.username"], config["telegram.proxy.password"])
def _init_client(self):
def _init_client(self) -> None:
self.log.debug(f"Initializing client for {self.name}")
device = f"{platform.system()} {platform.release()}"
sysversion = MautrixTelegramClient.__version__
@@ -114,18 +116,18 @@ class AbstractUser(ABC):
return False
@abstractmethod
async def post_login(self):
async def post_login(self) -> None:
raise NotImplementedError()
@abstractmethod
def register_portal(self, portal: po.Portal):
def register_portal(self, portal: po.Portal) -> None:
raise NotImplementedError()
@abstractmethod
def unregister_portal(self, portal: po.Portal):
def unregister_portal(self, portal: po.Portal) -> None:
raise NotImplementedError()
async def _update_catch(self, update: TypeUpdate):
async def _update_catch(self, update: TypeUpdate) -> None:
try:
if not await self.update(update):
await self._update(update)
@@ -147,21 +149,21 @@ class AbstractUser(ABC):
raise NotImplementedError()
async def is_logged_in(self) -> bool:
return self.client and await self.client.is_user_authorized()
return self.client and self.client.is_connected() and await self.client.is_user_authorized()
async def has_full_access(self, allow_bot: bool = False) -> bool:
return (self.puppet_whitelisted
and (not self.is_bot or allow_bot)
and await self.is_logged_in())
async def start(self, delete_unless_authenticated: bool = False) -> "AbstractUser":
async def start(self, delete_unless_authenticated: bool = False) -> 'AbstractUser':
if not self.client:
self._init_client()
await self.client.connect()
self.log.debug("%s connected: %s", self.mxid, self.connected)
return self
async def ensure_started(self, even_if_no_session=False) -> "AbstractUser":
async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
if not self.puppet_whitelisted:
return self
self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)",
@@ -175,13 +177,13 @@ class AbstractUser(ABC):
await self.start(delete_unless_authenticated=not even_if_no_session)
return self
async def stop(self):
async def stop(self) -> None:
await self.client.disconnect()
self.client = None
# region Telegram update handling
async def _update(self, update: TypeUpdate):
async def _update(self, update: TypeUpdate) -> None:
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)):
await self.update_message(update)
@@ -207,55 +209,63 @@ class AbstractUser(ABC):
self.log.debug("Unhandled update: %s", update)
@staticmethod
async def update_pinned_messages(update: UpdateChannelPinnedMessage):
portal = po.Portal.get_by_tgid(update.channel_id)
async def update_pinned_messages(update: UpdateChannelPinnedMessage) -> None:
portal = po.Portal.get_by_tgid(TelegramID(update.channel_id))
if portal and portal.mxid:
await portal.receive_telegram_pin_id(update.id)
@staticmethod
async def update_participants(update: UpdateChatParticipants):
portal = po.Portal.get_by_tgid(update.participants.chat_id)
async def update_participants(update: UpdateChatParticipants) -> None:
portal = po.Portal.get_by_tgid(TelegramID(update.participants.chat_id))
if portal and portal.mxid:
await portal.update_telegram_participants(update.participants.participants)
async def update_read_receipt(self, update: UpdateReadHistoryOutbox):
async def update_read_receipt(self, update: UpdateReadHistoryOutbox) -> None:
if not isinstance(update.peer, PeerUser):
self.log.debug("Unexpected read receipt peer: %s", update.peer)
return
portal = po.Portal.get_by_tgid(update.peer.user_id, self.tgid)
portal = po.Portal.get_by_tgid(TelegramID(update.peer.user_id), self.tgid)
if not portal or not portal.mxid:
return
# We check that these are user read receipts, so tg_space is always the user ID.
message = DBMessage.query.get((update.max_id, self.tgid))
message = DBMessage.get_by_tgid(update.max_id, self.tgid)
if not message:
return
puppet = pu.Puppet.get(update.peer.user_id)
puppet = pu.Puppet.get(TelegramID(update.peer.user_id))
await puppet.intent.mark_read(portal.mxid, message.mxid)
async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]):
async def update_admin(self,
update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]) -> None:
# TODO duplication not checked
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
if not portal or not portal.mxid:
return
if isinstance(update, UpdateChatAdmins):
await portal.set_telegram_admins_enabled(update.enabled)
elif isinstance(update, UpdateChatParticipantAdmin):
await portal.set_telegram_admin(update.user_id)
await portal.set_telegram_admin(TelegramID(update.user_id))
else:
self.log.warning("Unexpected admin status update: %s", update)
async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]):
async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]) -> None:
if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
else:
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
sender = pu.Puppet.get(update.user_id)
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
if not portal or not portal.mxid:
return
sender = pu.Puppet.get(TelegramID(update.user_id))
await portal.handle_telegram_typing(sender, update)
async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]):
async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]) -> None:
# TODO duplication not checked
puppet = pu.Puppet.get(update.user_id)
puppet = pu.Puppet.get(TelegramID(update.user_id))
if isinstance(update, UpdateUserName):
if await puppet.update_displayname(self, update):
puppet.save()
@@ -265,8 +275,8 @@ class AbstractUser(ABC):
else:
self.log.warning("Unexpected other user info update: %s", update)
async def update_status(self, update: UpdateUserStatus):
puppet = pu.Puppet.get(update.user_id)
async def update_status(self, update: UpdateUserStatus) -> None:
puppet = pu.Puppet.get(TelegramID(update.user_id))
if isinstance(update.status, UserStatusOnline):
await puppet.default_mxid_intent.set_presence("online")
elif isinstance(update.status, UserStatusOffline):
@@ -279,10 +289,10 @@ class AbstractUser(ABC):
Optional[pu.Puppet],
Optional[po.Portal]]:
if isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
sender = pu.Puppet.get(update.from_id)
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
sender = pu.Puppet.get(TelegramID(update.from_id))
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
sender = pu.Puppet.get(self.tgid if update.out else update.user_id)
elif isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage,
UpdateEditMessage, UpdateEditChannelMessage)):
@@ -300,7 +310,7 @@ class AbstractUser(ABC):
return update, sender, portal
@staticmethod
async def _try_redact(portal: po.Portal, message: DBMessage):
async def _try_redact(portal: po.Portal, message: DBMessage) -> None:
if not portal:
return
try:
@@ -308,40 +318,45 @@ class AbstractUser(ABC):
except MatrixRequestError:
pass
async def delete_message(self, update: UpdateDeleteMessages):
async def delete_message(self, update: UpdateDeleteMessages) -> None:
if len(update.messages) > MAX_DELETIONS:
return
for message in update.messages:
message = DBMessage.query.get((message, self.tgid))
message = DBMessage.get_by_tgid(TelegramID(message), self.tgid)
if not message:
continue
self.db.delete(message)
number_left = DBMessage.query.filter(DBMessage.mxid == message.mxid,
DBMessage.mx_room == message.mx_room).count()
message.delete()
number_left = DBMessage.count_spaces_by_mxid(message.mxid, message.mx_room)
if number_left == 0:
portal = po.Portal.get_by_mxid(message.mx_room)
await self._try_redact(portal, message)
self.db.commit()
async def delete_channel_message(self, update: UpdateDeleteChannelMessages):
async def delete_channel_message(self, update: UpdateDeleteChannelMessages) -> None:
if len(update.messages) > MAX_DELETIONS:
return
portal = po.Portal.get_by_tgid(update.channel_id)
portal = po.Portal.get_by_tgid(TelegramID(update.channel_id))
if not portal:
return
for message in update.messages:
message = DBMessage.query.get((message, portal.tgid))
message = DBMessage.get_by_tgid(TelegramID(message), portal.tgid)
if not message:
continue
self.db.delete(message)
message.delete()
await self._try_redact(portal, message)
self.db.commit()
async def update_message(self, original_update: UpdateMessage):
async def update_message(self, original_update: UpdateMessage) -> None:
update, sender, portal = self.get_message_details(original_update)
if self.is_bot and not portal.mxid:
self.log.debug(f"Ignoring message received by bot in unbridged chat %s",
portal.tgid_log)
return
if self.ignore_incoming_bot_events and self.bot and sender.id == self.bot.tgid:
self.log.debug(f"Ignoring relaybot-sent message %s to %s", update, portal.tgid_log)
return
@@ -369,9 +384,9 @@ class AbstractUser(ABC):
# endregion
def init(context: "Context"):
def init(context: "Context") -> None:
global config, MAX_DELETIONS
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core
AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"]
AbstractUser.session_container = context.session_container
MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10)
+59 -47
View File
@@ -14,21 +14,27 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable, Pattern, Dict, TYPE_CHECKING
from typing import Awaitable, Callable, Dict, List, Optional, Pattern, TYPE_CHECKING
import logging
import re
from telethon.tl.types import *
from telethon.tl.types import (
ChannelParticipantAdmin, ChannelParticipantCreator, ChatForbidden, ChatParticipantAdmin,
ChatParticipantCreator, InputChannel, InputUser, Message, MessageActionChatAddUser,
MessageActionChatDeleteUser, MessageEntityBotCommand, MessageService, PeerChannel, PeerChat,
TypePeer, UpdateNewChannelMessage, UpdateNewMessage)
from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
from telethon.errors import ChannelInvalidError, ChannelPrivateError
from .types import MatrixUserID
from .abstract_user import AbstractUser
from .db import BotChat
from . import puppet as pu, portal as po, user as u
if TYPE_CHECKING:
from .config import Config
from .context import Context
config = None # type: Config
@@ -39,7 +45,7 @@ class Bot(AbstractUser):
log = logging.getLogger("mau.bot") # type: logging.Logger
mxid_regex = re.compile("@.+:.+") # type: Pattern
def __init__(self, token: str):
def __init__(self, token: str) -> None:
super().__init__()
self.token = token # type: str
self.puppet_whitelisted = True # type: bool
@@ -53,46 +59,46 @@ class Bot(AbstractUser):
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
or False) # type: bool
async def init_permissions(self):
async def init_permissions(self) -> None:
whitelist = config["bridge.relaybot.whitelist"] or []
for id in whitelist:
if isinstance(id, str):
entity = await self.client.get_input_entity(id)
for user_id in whitelist:
if isinstance(user_id, str):
entity = await self.client.get_input_entity(user_id)
if isinstance(entity, InputUser):
id = entity.user_id
user_id = entity.user_id
else:
id = None
if isinstance(id, int):
self.tg_whitelist.append(id)
user_id = None
if isinstance(user_id, int):
self.tg_whitelist.append(user_id)
async def start(self, delete_unless_authenticated: bool = False) -> "Bot":
async def start(self, delete_unless_authenticated: bool = False) -> 'Bot':
await super().start(delete_unless_authenticated)
if not await self.is_logged_in():
await self.client.sign_in(bot_token=self.token)
await self.post_login()
return self
async def post_login(self):
async def post_login(self) -> None:
await self.init_permissions()
info = await self.client.get_me()
self.tgid = info.id
self.username = info.username
self.mxid = pu.Puppet.get_mxid_from_id(self.tgid)
chat_ids = [id for id, type in self.chats.items() if type == "chat"]
chat_ids = [chat_id for chat_id, chat_type in self.chats.items() if chat_type == "chat"]
response = await self.client(GetChatsRequest(chat_ids))
for chat in response.chats:
if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated:
self.remove_chat(chat.id)
channel_ids = [InputChannel(id, 0)
for id, type in self.chats.items()
if type == "channel"]
for id in channel_ids:
channel_ids = [InputChannel(chat_id, 0)
for chat_id, chat_type in self.chats.items()
if chat_type == "channel"]
for channel_id in channel_ids:
try:
await self.client(GetChannelsRequest([id]))
await self.client(GetChannelsRequest([channel_id]))
except (ChannelPrivateError, ChannelInvalidError):
self.remove_chat(id.channel_id)
self.remove_chat(channel_id.channel_id)
if config["bridge.catch_up"]:
try:
@@ -100,24 +106,24 @@ class Bot(AbstractUser):
except Exception:
self.log.exception("Failed to run catch_up() for bot")
def register_portal(self, portal: po.Portal):
def register_portal(self, portal: po.Portal) -> None:
self.add_chat(portal.tgid, portal.peer_type)
def unregister_portal(self, portal: po.Portal):
def unregister_portal(self, portal: po.Portal) -> None:
self.remove_chat(portal.tgid)
def add_chat(self, id: int, type: str):
if id not in self.chats:
self.chats[id] = type
self.db.add(BotChat(id=id, type=type))
def add_chat(self, chat_id: int, chat_type: str) -> None:
if chat_id not in self.chats:
self.chats[chat_id] = chat_type
self.db.add(BotChat(id=chat_id, type=chat_type))
self.db.commit()
def remove_chat(self, id: int):
def remove_chat(self, chat_id: int) -> None:
try:
del self.chats[id]
del self.chats[chat_id]
except KeyError:
pass
existing_chat = BotChat.query.get(id)
existing_chat = BotChat.query.get(chat_id)
if existing_chat:
self.db.delete(existing_chat)
self.db.commit()
@@ -141,6 +147,7 @@ class Bot(AbstractUser):
for p in participants:
if p.user_id == tgid:
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
return False
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
if not await self._can_use_commands(event.to_id, event.from_id):
@@ -148,7 +155,7 @@ class Bot(AbstractUser):
return False
return True
async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc):
async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc) -> None:
if not config["bridge.relaybot.authless_portals"]:
return await reply("This bridge doesn't allow portal creation from Telegram.")
@@ -164,15 +171,16 @@ class Bot(AbstractUser):
return await reply(
"Portal is not public. Use `/invite <mxid>` to get an invite.")
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, mxid: str):
if len(mxid) == 0:
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc,
mxid_input: MatrixUserID) -> Message:
if len(mxid_input) == 0:
return await reply("Usage: `/invite <mxid>`")
elif not portal.mxid:
return await reply("Portal does not have Matrix room. "
"Create one with /portal first.")
if not self.mxid_regex.match(mxid):
if not self.mxid_regex.match(mxid_input):
return await reply("That doesn't look like a Matrix ID.")
user = await u.User.get_by_mxid(mxid).ensure_started()
user = await u.User.get_by_mxid(MatrixUserID(mxid_input)).ensure_started()
if not user.relaybot_whitelisted:
return await reply("That user is not whitelisted to use the bridge.")
elif await user.is_logged_in():
@@ -183,7 +191,8 @@ class Bot(AbstractUser):
await portal.main_intent.invite(portal.mxid, user.mxid)
return await reply(f"Invited `{user.mxid}` to the portal.")
def handle_command_id(self, message: Message, reply: ReplyFunc):
@staticmethod
def handle_command_id(message: Message, reply: ReplyFunc) -> Awaitable[Message]:
# Provide the prefixed ID to the user so that the user wouldn't need to specify whether the
# chat is a normal group or a supergroup/channel when using the ID.
if isinstance(message.to_id, PeerChannel):
@@ -205,8 +214,8 @@ class Bot(AbstractUser):
return False
async def handle_command(self, message: Message):
def reply(reply_text):
async def handle_command(self, message: Message) -> None:
def reply(reply_text: str) -> Awaitable[Message]:
return self.client.send_message(message.to_id, reply_text, reply_to=message.id)
text = message.message
@@ -227,36 +236,39 @@ class Bot(AbstractUser):
mxid = text[text.index(" ") + 1:]
except ValueError:
mxid = ""
await self.handle_command_invite(portal, reply, mxid=mxid)
await self.handle_command_invite(portal, reply, mxid_input=mxid)
def handle_service_message(self, message: MessageService):
def handle_service_message(self, message: MessageService) -> None:
to_id = message.to_id
if isinstance(to_id, PeerChannel):
to_id = to_id.channel_id
type = "channel"
chat_type = "channel"
elif isinstance(to_id, PeerChat):
to_id = to_id.chat_id
type = "chat"
chat_type = "chat"
else:
return
action = message.action
if isinstance(action, MessageActionChatAddUser) and self.tgid in action.users:
self.add_chat(to_id, type)
self.add_chat(to_id, chat_type)
elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid:
self.remove_chat(to_id)
async def update(self, update):
async def update(self, update) -> bool:
if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
return
return False
if isinstance(update.message, MessageService):
return self.handle_service_message(update.message)
self.handle_service_message(update.message)
return False
is_command = (isinstance(update.message, Message)
and update.message.entities and len(update.message.entities) > 0
and isinstance(update.message.entities[0], MessageEntityBotCommand))
if is_command:
return await self.handle_command(update.message)
await self.handle_command(update.message)
return True
return False
def is_in_chat(self, peer_id) -> bool:
return peer_id in self.chats
@@ -266,7 +278,7 @@ class Bot(AbstractUser):
return "bot"
def init(context) -> Optional[Bot]:
def init(context: 'Context') -> Optional[Bot]:
global config
config = context.config
token = config["telegram.bot_token"]
+39 -22
View File
@@ -14,20 +14,24 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict
from typing import Any, Dict, Optional
import asyncio
from telethon.errors import *
from telethon.errors import (
AccessTokenExpiredError, AccessTokenInvalidError, FirstNameInvalidError, FloodWaitError,
PasswordHashInvalidError, PhoneCodeExpiredError, PhoneCodeInvalidError,
PhoneNumberAppSignupForbiddenError, PhoneNumberBannedError, PhoneNumberFloodError,
PhoneNumberOccupiedError, PhoneNumberUnoccupiedError, SessionPasswordNeededError)
from . import command_handler, CommandEvent, SECTION_AUTH
from .. import puppet as pu
from .. import puppet as pu, user as u
from ..util import format_duration
@command_handler(needs_auth=False,
help_section=SECTION_AUTH,
help_text="Check if you're logged into Telegram.")
async def ping(evt: CommandEvent):
async def ping(evt: CommandEvent) -> Optional[Dict]:
me = await evt.sender.client.get_me() if await evt.sender.is_logged_in() else None
if me:
return await evt.reply(f"You're logged in as @{me.username}")
@@ -38,7 +42,7 @@ async def ping(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_AUTH,
help_text="Get the info of the message relay Telegram bot.")
async def ping_bot(evt: CommandEvent):
async def ping_bot(evt: CommandEvent) -> Optional[Dict]:
if not evt.tgbot:
return await evt.reply("Telegram message relay bot not configured.")
bot_info = await evt.tgbot.client.get_me()
@@ -53,19 +57,19 @@ async def ping_bot(evt: CommandEvent):
help_section=SECTION_AUTH,
help_text="Revert your Telegram account's Matrix puppet to use the default Matrix "
"account.")
async def logout_matrix(evt: CommandEvent):
async def logout_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid)
if not puppet.is_real_user:
return await evt.reply("You are not logged in with your Matrix account.")
await puppet.switch_mxid(None, None)
await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
return await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
@command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True,
help_section=SECTION_AUTH,
help_text="Replace your Telegram account's Matrix puppet with your own Matrix "
"account")
async def login_matrix(evt: CommandEvent):
async def login_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid)
if puppet.is_real_user:
return await evt.reply("You have already logged in with your Matrix account. "
@@ -96,7 +100,7 @@ async def login_matrix(evt: CommandEvent):
return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def enter_matrix_token(evt: CommandEvent):
async def enter_matrix_token(evt: CommandEvent) -> Dict:
evt.sender.command_status = None
puppet = pu.Puppet.get(evt.sender.tgid)
@@ -105,10 +109,11 @@ async def enter_matrix_token(evt: CommandEvent):
"Log out with `$cmdprefix+sp logout-matrix` first.")
resp = await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
if resp == 2:
if resp == pu.PuppetError.OnlyLoginSelf:
return await evt.reply("You can only log in as your own Matrix user.")
elif resp == 1:
elif resp == pu.PuppetError.InvalidAccessToken:
return await evt.reply("Failed to verify access token.")
assert resp == pu.PuppetError.Success, "Encountered an unhandled PuppetError."
return await evt.reply(
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
@@ -117,7 +122,7 @@ async def enter_matrix_token(evt: CommandEvent):
help_section=SECTION_AUTH,
help_args="<_phone_> <_full name_>",
help_text="Register to Telegram")
async def register(evt: CommandEvent):
async def register(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.")
elif len(evt.args) < 1:
@@ -134,9 +139,10 @@ async def register(evt: CommandEvent):
"action": "Register",
"full_name": full_name,
})
return None
async def enter_code_register(evt: CommandEvent):
async def enter_code_register(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp <code>`")
try:
@@ -165,9 +171,9 @@ async def enter_code_register(evt: CommandEvent):
@command_handler(needs_auth=False, management_only=True,
help_section=SECTION_AUTH,
help_text="Get instructions on how to log in.")
async def login(evt: CommandEvent):
async def login(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.")
return await evt.reply(f"You are already logged in as {evt.sender.human_tg_id}.")
allow_matrix_login = evt.config.get("bridge.allow_matrix_login", True)
if allow_matrix_login:
@@ -196,7 +202,8 @@ async def login(evt: CommandEvent):
return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, str]):
async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, Any]
) -> Dict:
ok = False
try:
await evt.sender.ensure_started(even_if_no_session=True)
@@ -228,7 +235,7 @@ async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[s
@command_handler(needs_auth=False)
async def enter_phone_or_token(evt: CommandEvent):
async def enter_phone_or_token(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token <phone-or-token>`")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -248,10 +255,11 @@ async def enter_phone_or_token(evt: CommandEvent):
"next": enter_code,
"action": "Login",
})
return None
@command_handler(needs_auth=False)
async def enter_code(evt: CommandEvent):
async def enter_code(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-code <code>`")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -263,10 +271,11 @@ async def enter_code(evt: CommandEvent):
evt.log.exception("Error sending phone code")
return await evt.reply("Unhandled exception while sending code. "
"Check console for more details.")
return None
@command_handler(needs_auth=False)
async def enter_password(evt: CommandEvent):
async def enter_password(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-password <password>`")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -282,15 +291,23 @@ async def enter_password(evt: CommandEvent):
evt.log.exception("Error sending password")
return await evt.reply("Unhandled exception while sending password. "
"Check console for more details.")
return None
async def sign_in(evt: CommandEvent, **sign_in_info):
async def sign_in(evt: CommandEvent, **sign_in_info) -> Dict:
try:
await evt.sender.ensure_started(even_if_no_session=True)
user = await evt.sender.client.sign_in(**sign_in_info)
existing_user = u.User.get_by_tgid(user.id)
if existing_user and existing_user != evt.sender:
await existing_user.log_out()
await evt.reply(f"[{existing_user.displayname}]"
f"(https://matrix.to/#/{existing_user.mxid})"
" was logged out from the account.")
asyncio.ensure_future(evt.sender.post_login(user), loop=evt.loop)
evt.sender.command_status = None
return await evt.reply(f"Successfully logged in as @{user.username}")
name = f"@{user.username}" if user.username else f"+{user.phone}"
return await evt.reply(f"Successfully logged in as {name}")
except PhoneCodeExpiredError:
return await evt.reply("Phone code expired. Try again with `$cmdprefix+sp login`.")
except PhoneCodeInvalidError:
@@ -309,7 +326,7 @@ async def sign_in(evt: CommandEvent, **sign_in_info):
@command_handler(needs_auth=True,
help_section=SECTION_AUTH,
help_text="Log out from Telegram.")
async def logout(evt: CommandEvent):
async def logout(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.log_out():
return await evt.reply("Logged out successfully.")
return await evt.reply("Failed to log out.")
+23 -21
View File
@@ -14,26 +14,27 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, List
from typing import Dict, List, NewType, Optional, Tuple, Union
from mautrix_appservice import MatrixRequestError, IntentAPI
from ..types import MatrixRoomID, MatrixUserID
from . import command_handler, CommandEvent, SECTION_ADMIN
from .. import puppet as pu, portal as po
ManagementRoomList = List[Tuple[str, str]]
RoomIDList = List[str]
ManagementRoom = NewType('ManagementRoom', Tuple[MatrixRoomID, MatrixUserID])
async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList,
List["po.Portal"], List["po.Portal"]]:
management_rooms = [] # type: ManagementRoomList
unidentified_rooms = [] # type: RoomIDList
async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[MatrixRoomID],
List['po.Portal'], List['po.Portal']]:
management_rooms = [] # type: List[ManagementRoom]
unidentified_rooms = [] # type: List[MatrixRoomID]
portals = [] # type: List[po.Portal]
empty_portals = [] # type: List[po.Portal]
rooms = await intent.get_joined_rooms()
for room in rooms:
for room_str in rooms:
room = MatrixRoomID(room_str)
portal = po.Portal.get_by_mxid(room)
if not portal:
try:
@@ -41,11 +42,11 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
except MatrixRequestError:
members = []
if len(members) == 2:
other_member = members[0] if members[0] != intent.mxid else members[1]
other_member = MatrixUserID(members[0] if members[0] != intent.mxid else members[1])
if pu.Puppet.get_id_from_mxid(other_member):
unidentified_rooms.append(room)
else:
management_rooms.append((room, other_member))
management_rooms.append(ManagementRoom((room, other_member)))
else:
unidentified_rooms.append(room)
else:
@@ -61,7 +62,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
@command_handler(needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms",
help_section=SECTION_ADMIN,
help_text="Clean up unused portal/management rooms.")
async def clean_rooms(evt: CommandEvent):
async def clean_rooms(evt: CommandEvent) -> Optional[Dict]:
management_rooms, unidentified_rooms, portals, empty_portals = await _find_rooms(evt.az.intent)
reply = ["#### Management rooms (M)"]
@@ -106,13 +107,14 @@ async def clean_rooms(evt: CommandEvent):
return await evt.reply("\n".join(reply))
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
unidentified_rooms: RoomIDList, portals: List["po.Portal"],
empty_portals: List["po.Portal"]):
async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
unidentified_rooms: List[MatrixRoomID], portals: List["po.Portal"],
empty_portals: List["po.Portal"]) -> None:
command = evt.args[0]
rooms_to_clean = []
rooms_to_clean = [] # type: List[Union[po.Portal, MatrixRoomID]]
if command == "clean-recommended":
rooms_to_clean = empty_portals + unidentified_rooms
rooms_to_clean += empty_portals
rooms_to_clean += unidentified_rooms
elif command == "clean-groups":
if len(evt.args) < 2:
return await evt.reply("**Usage:** `$cmdprefix+sp clean-groups [M][A][U][I]")
@@ -127,9 +129,9 @@ async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
rooms_to_clean += empty_portals
elif command == "clean-range":
try:
range = evt.args[1]
group, range = range[0], range[1:]
start, end = range.split("-")
clean_range = evt.args[1]
group, clean_range = clean_range[0], clean_range[1:]
start, end = clean_range.split("-")
start, end = int(start), int(end)
if group == "M":
group = [room_id for (room_id, user_id) in management_rooms]
@@ -158,7 +160,7 @@ async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
"`$cmdprefix+sp confirm-clean`.")
async def execute_room_cleanup(evt, rooms_to_clean):
async def execute_room_cleanup(evt, rooms_to_clean: List[Union[po.Portal, MatrixRoomID]]) -> None:
if len(evt.args) > 0 and evt.args[0] == "confirm-clean":
await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. "
"This might take a while.")
@@ -167,7 +169,7 @@ async def execute_room_cleanup(evt, rooms_to_clean):
if isinstance(room, po.Portal):
await room.cleanup_and_delete()
cleaned += 1
elif isinstance(room, str):
elif isinstance(room, str): # str is aliased by MatrixRoomID
await po.Portal.cleanup_room(evt.az.intent, room, message="Room deleted")
cleaned += 1
evt.sender.command_status = None
+57 -24
View File
@@ -14,19 +14,19 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Callable, Optional
from collections import namedtuple
import markdown
from typing import Awaitable, Callable, Dict, List, NamedTuple, Optional
import commonmark
import logging
from telethon.errors import FloodWaitError
from ..types import MatrixRoomID
from ..util import format_duration
from .. import user as u, context as c
command_handlers = {} # type: Dict[str, CommandHandler]
HelpSection = namedtuple("HelpSection", "name order description")
HelpSection = NamedTuple('HelpSection', [('name', str), ('order', int), ('description', str)])
SECTION_GENERAL = HelpSection("General", 0, "")
SECTION_AUTH = HelpSection("Authentication", 10, "")
@@ -36,9 +36,30 @@ SECTION_MISC = HelpSection("Miscellaneous", 40, "")
SECTION_ADMIN = HelpSection("Administration", 50, "")
class HtmlEscapingRenderer(commonmark.HtmlRenderer):
def __init__(self, allow_html: bool = False):
super().__init__()
self.allow_html = allow_html
def lit(self, s):
if self.allow_html:
return super().lit(s)
return super().lit(s.replace("<", "&lt;").replace(">", "&gt;"))
def image(self, node, entering):
prev = self.allow_html
self.allow_html = True
super().image(node, entering)
self.allow_html = prev
md_parser = commonmark.Parser()
md_renderer = HtmlEscapingRenderer()
class CommandEvent:
def __init__(self, processor: "CommandProcessor", room: str, sender: u.User, command: str,
args: List[str], is_management: bool, is_portal: bool):
def __init__(self, processor: 'CommandProcessor', room: MatrixRoomID, sender: u.User,
command: str, args: List[str], is_management: bool, is_portal: bool) -> None:
self.az = processor.az
self.log = processor.log
self.loop = processor.loop
@@ -53,23 +74,25 @@ class CommandEvent:
self.is_management = is_management
self.is_portal = is_portal
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True):
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
) -> Awaitable[Dict]:
message = message.replace("$cmdprefix+sp ",
"" if self.is_management else f"{self.command_prefix} ")
message = message.replace("$cmdprefix", self.command_prefix)
html = None
if render_markdown:
html = markdown.markdown(message, safe_mode="escape" if allow_html else False)
md_renderer.allow_html = allow_html
html = md_renderer.render(md_parser.parse(message))
elif allow_html:
html = message
return self.az.intent.send_notice(self.room_id, message, html=html)
class CommandHandler:
def __init__(self, handler: Callable[[CommandEvent], None], needs_auth: bool,
def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool,
needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool,
management_only: bool, name: str, help_text: str, help_args: str,
help_section: HelpSection):
help_section: HelpSection) -> None:
self._handler = handler
self.needs_auth = needs_auth
self.needs_puppeting = needs_puppeting
@@ -103,7 +126,8 @@ class CommandHandler:
(not self.needs_admin or is_admin) and
(not self.needs_auth or is_logged_in))
async def __call__(self, evt: CommandEvent):
async def __call__(self, evt: CommandEvent
) -> Dict:
error = await self.get_permission_error(evt)
if error is not None:
return await evt.reply(error)
@@ -118,13 +142,21 @@ class CommandHandler:
return f"**{self.name}** {self._help_args} - {self._help_text}"
def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, needs_auth=True,
needs_puppeting=True, needs_matrix_puppeting=False, needs_admin=False,
management_only=False, name=None, help_text="", help_args="",
help_section=None):
def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] = None, *,
needs_auth: bool = True,
needs_puppeting: bool = True,
needs_matrix_puppeting: bool = False,
needs_admin: bool = False,
management_only: bool = False,
name: Optional[str] = None,
help_text: str = "",
help_args: str = "",
help_section: HelpSection = None
) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[Dict]]]],
CommandHandler]:
input_name = name
def decorator(func: Callable[[CommandEvent], None]):
def decorator(func: Callable[[CommandEvent], Awaitable[Optional[Dict]]]) -> CommandHandler:
name = input_name or func.__name__.replace("_", "-")
handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting,
needs_admin, management_only, name, help_text, help_args,
@@ -138,27 +170,27 @@ def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, n
class CommandProcessor:
log = logging.getLogger("mau.commands")
def __init__(self, context: c.Context):
self.az, self.db, self.config, self.loop, self.tgbot = context
def __init__(self, context: c.Context) -> None:
self.az, self.db, self.config, self.loop, self.tgbot = context.core
self.public_website = context.public_website
self.command_prefix = self.config["bridge.command_prefix"]
async def handle(self, room: str, sender: u.User, command: str, args: List[str],
is_management: bool, is_portal: bool):
async def handle(self, room: MatrixRoomID, sender: u.User, command: str, args: List[str],
is_management: bool, is_portal: bool) -> Optional[Dict]:
evt = CommandEvent(self, room, sender, command, args, is_management, is_portal)
orig_command = command
command = command.lower()
try:
command = command_handlers[command]
handler = command_handlers[command]
except KeyError:
if sender.command_status and "next" in sender.command_status:
args.insert(0, orig_command)
evt.command = ""
command = sender.command_status["next"]
handler = sender.command_status["next"]
else:
command = command_handlers["unknown-command"]
handler = command_handlers["unknown-command"]
try:
await command(evt)
await handler(evt)
except FloodWaitError as e:
return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}")
except Exception:
@@ -166,3 +198,4 @@ class CommandProcessor:
f"{evt.command} {' '.join(args)} from {sender.mxid}")
return await evt.reply("Unhandled error while handling command. "
"Check logs for more details.")
return None
+17 -14
View File
@@ -14,46 +14,49 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Tuple
from . import command_handler, CommandEvent, _command_handlers, SECTION_GENERAL
from .handler import HelpSection
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL,
help_text="Cancel an ongoing action (such as login)")
def cancel(evt: CommandEvent):
async def cancel(evt: CommandEvent) -> Optional[Dict]:
if evt.sender.command_status:
action = evt.sender.command_status["action"]
evt.sender.command_status = None
return evt.reply(f"{action} cancelled.")
return await evt.reply(f"{action} cancelled.")
else:
return evt.reply("No ongoing command.")
return await evt.reply("No ongoing command.")
@command_handler(needs_auth=False, needs_puppeting=False)
def unknown_command(evt: CommandEvent):
return evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
async def unknown_command(evt: CommandEvent) -> Optional[Dict]:
return await evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
help_cache = {}
help_cache = {} # type: Dict[Tuple[bool, bool, bool, bool, bool], str]
async def _get_help_text(evt: CommandEvent):
async def _get_help_text(evt: CommandEvent) -> str:
cache_key = (evt.is_management, evt.sender.puppet_whitelisted,
evt.sender.matrix_puppet_whitelisted, evt.sender.is_admin,
await evt.sender.is_logged_in())
if cache_key not in help_cache:
help = {}
help_sections = {} # type: Dict[HelpSection, List[str]]
for handler in _command_handlers.values():
if handler.has_help and handler.has_permission(*cache_key):
help.setdefault(handler.help_section, [])
help[handler.help_section].append(handler.help + " ")
help = sorted(help.items(), key=lambda item: item[0].order)
help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help]
help_sections.setdefault(handler.help_section, [])
help_sections[handler.help_section].append(handler.help + " ")
help_sorted = sorted(help_sections.items(), key=lambda item: item[0].order)
help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help_sorted]
help_cache[cache_key] = "\n".join(help)
return help_cache[cache_key]
def _get_management_status(evt: CommandEvent):
def _get_management_status(evt: CommandEvent) -> str:
if evt.is_management:
return "This is a management room: prefixing commands with `$cmdprefix` is not required."
elif evt.is_portal:
@@ -65,5 +68,5 @@ def _get_management_status(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL,
help_text="Show this help message.")
async def help(evt: CommandEvent):
async def help(evt: CommandEvent) -> Optional[Dict]:
return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt))
+169 -45
View File
@@ -14,14 +14,18 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Callable
from typing import Dict, Callable, Optional, Tuple, Coroutine, Awaitable
from io import StringIO
import asyncio
from telethon.errors import *
from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
UsernameNotModifiedError, UsernameOccupiedError)
from telethon.tl.types import ChatForbidden, ChannelForbidden
from mautrix_appservice import MatrixRequestError, IntentAPI
from .. import portal as po, user as u
from ..types import MatrixRoomID, TelegramID
from ..config import yaml
from .. import portal as po, user as u, util
from . import (command_handler, CommandEvent,
SECTION_ADMIN, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT)
@@ -30,7 +34,7 @@ from . import (command_handler, CommandEvent,
help_section=SECTION_ADMIN,
help_args="<_level_> [_mxid_]",
help_text="Set a temporary power level without affecting Telegram.")
async def set_power_level(evt: CommandEvent):
async def set_power_level(evt: CommandEvent) -> Dict:
try:
level = int(evt.args[0])
except KeyError:
@@ -45,11 +49,12 @@ async def set_power_level(evt: CommandEvent):
except MatrixRequestError:
evt.log.exception("Failed to set power level.")
return await evt.reply("Failed to set power level.")
return {}
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Get a Telegram invite link to the current chat.")
async def invite_link(evt: CommandEvent):
async def invite_link(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
return await evt.reply("This is not a portal room.")
@@ -66,7 +71,8 @@ async def invite_link(evt: CommandEvent):
return await evt.reply("You don't have the permission to create an invite link.")
async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50):
async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50
) -> bool:
if sender.is_admin:
return True
# Make sure the state store contains the power levels.
@@ -80,23 +86,26 @@ async def user_has_power_level(room: str, intent, sender: u.User, event: str, de
async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
action: Optional[str] = None):
room_id = evt.args[0] if len(evt.args) > 0 else evt.room_id
action: Optional[str] = None
) -> Optional[po.Portal]:
room_id = MatrixRoomID(evt.args[0]) if len(evt.args) > 0 else evt.room_id
portal = po.Portal.get_by_mxid(room_id)
if not portal:
that_this = "This" if room_id == evt.room_id else "That"
return await evt.reply(f"{that_this} is not a portal room."), False
await evt.reply(f"{that_this} is not a portal room.")
return None
if not await user_has_power_level(portal.mxid, evt.az.intent, evt.sender, permission):
action = action or f"{permission.replace('_', ' ')}s"
return await evt.reply(f"You do not have the permissions to {action} that portal."), False
return portal, True
await evt.reply(f"You do not have the permissions to {action} that portal.")
return None
return portal
def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str,
completed_message: str):
async def post_confirm(confirm):
completed_message: str) -> Dict:
async def post_confirm(confirm) -> Optional[Dict]:
confirm.sender.command_status = None
if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}":
await function()
@@ -104,6 +113,7 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
return await confirm.reply(completed_message)
else:
return await confirm.reply(f"{action} cancelled.")
return None
return {
"next": post_confirm,
@@ -116,10 +126,10 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
help_text="Remove all users from the current portal room and forget the portal. "
"Only works for group chats; to delete a private chat portal, simply "
"leave the room.")
async def delete_portal(evt: CommandEvent):
portal, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok:
return
async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
portal = await _get_portal_and_check_permission(evt, "unbridge")
if not portal:
return None
evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid,
portal.cleanup_and_delete, "delete",
@@ -137,10 +147,10 @@ async def delete_portal(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Remove puppets from the current portal room and forget the portal.")
async def unbridge(evt: CommandEvent):
portal, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok:
return
async def unbridge(evt: CommandEvent) -> Optional[Dict]:
portal = await _get_portal_and_check_permission(evt, "unbridge")
if not portal:
return None
evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid,
portal.unbridge, "unbridge",
@@ -156,11 +166,11 @@ async def unbridge(evt: CommandEvent):
help_text="Bridge the current Matrix room to the Telegram chat with the given "
"ID. The ID must be the prefixed version that you get with the `/id` "
"command of the Telegram-side bot.")
async def bridge(evt: CommandEvent):
async def bridge(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** "
"`$cmdprefix+sp bridge <Telegram chat ID> [Matrix room ID]`")
room_id = evt.args[1] if len(evt.args) > 1 else evt.room_id
room_id = MatrixRoomID(evt.args[1]) if len(evt.args) > 1 else evt.room_id
that_this = "This" if room_id == evt.room_id else "That"
portal = po.Portal.get_by_mxid(room_id)
@@ -171,12 +181,12 @@ async def bridge(evt: CommandEvent):
return await evt.reply(f"You do not have the permissions to bridge {that_this} room.")
# The /id bot command provides the prefixed ID, so we assume
tgid = evt.args[0]
if tgid.startswith("-100"):
tgid = int(tgid[4:])
tgid_str = evt.args[0]
if tgid_str.startswith("-100"):
tgid = TelegramID(int(tgid_str[4:]))
peer_type = "channel"
elif tgid.startswith("-"):
tgid = -int(tgid)
elif tgid_str.startswith("-"):
tgid = TelegramID(-int(tgid_str))
peer_type = "chat"
else:
return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n"
@@ -188,7 +198,7 @@ async def bridge(evt: CommandEvent):
if not portal.allow_bridging():
return await evt.reply("This bridge doesn't allow bridging that Telegram chat.\n"
"If you're the bridge admin, try "
"`$cmdprefix+sp whitelist <Telegram chat ID>` first.")
"`$cmdprefix+sp filter whitelist <Telegram chat ID>` first.")
if portal.mxid:
has_portal_message = (
"That Telegram chat already has a portal at "
@@ -222,7 +232,8 @@ async def bridge(evt: CommandEvent):
"chat to this room, use `$cmdprefix+sp continue`")
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"):
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"
) -> Tuple[bool, Optional[Coroutine[None, None, None]]]:
if not portal.mxid:
await evt.reply("The portal seems to have lost its Matrix room between you"
"calling `$cmdprefix+sp bridge` and this command.\n\n"
@@ -245,7 +256,7 @@ async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Porta
return False, None
async def confirm_bridge(evt: CommandEvent):
async def confirm_bridge(evt: CommandEvent) -> Optional[Dict]:
status = evt.sender.command_status
try:
portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"])
@@ -258,7 +269,7 @@ async def confirm_bridge(evt: CommandEvent):
if "mxid" in status:
ok, coro = await cleanup_old_portal_while_bridging(evt, portal)
if not ok:
return
return None
elif coro:
asyncio.ensure_future(coro, loop=evt.loop)
await evt.reply("Cleaning up previous portal room...")
@@ -271,6 +282,7 @@ async def confirm_bridge(evt: CommandEvent):
return await evt.reply("Please use `$cmdprefix+sp continue` to confirm the bridging or "
"`$cmdprefix+sp cancel` to cancel.")
evt.sender.command_status = None
is_logged_in = await evt.sender.is_logged_in()
user = evt.sender if is_logged_in else evt.tgbot
try:
@@ -302,7 +314,7 @@ async def confirm_bridge(evt: CommandEvent):
return await evt.reply("Bridging complete. Portal synchronization should begin momentarily.")
async def get_initial_state(intent: IntentAPI, room_id: str):
async def get_initial_state(intent: IntentAPI, room_id: str) -> Tuple[str, str, Dict]:
state = await intent.get_room_state(room_id)
title = None
about = None
@@ -328,7 +340,7 @@ async def get_initial_state(intent: IntentAPI, room_id: str):
help_text="Create a Telegram chat of the given type for the current Matrix room. "
"The type is either `group`, `supergroup` or `channel` (defaults to "
"`group`).")
async def create(evt: CommandEvent):
async def create(evt: CommandEvent) -> Dict:
type = evt.args[0] if len(evt.args) > 0 else "group"
if type not in {"chat", "group", "supergroup", "channel"}:
return await evt.reply(
@@ -363,7 +375,7 @@ async def create(evt: CommandEvent):
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Upgrade a normal Telegram group to a supergroup.")
async def upgrade(evt: CommandEvent):
async def upgrade(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
return await evt.reply("This is not a portal room.")
@@ -381,11 +393,122 @@ async def upgrade(evt: CommandEvent):
return await evt.reply(e.args[0])
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="View or change per-portal settings.",
help_args="<`help`|_subcommand_> [...]")
async def config(evt: CommandEvent) -> None:
cmd = evt.args[0].lower() if len(evt.args) > 0 else "help"
if cmd not in ("view", "defaults", "set", "unset", "add", "del"):
await config_help(evt)
return
elif cmd == "defaults":
await config_defaults(evt)
return
portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
await evt.reply("This is not a portal room.")
return
elif cmd == "view":
await config_view(evt, portal)
return
key = evt.args[1] if len(evt.args) > 1 else None
value = yaml.load(" ".join(evt.args[2:])) if len(evt.args) > 2 else None
if cmd == "set":
await config_set(evt, portal, key, value)
elif cmd == "unset":
await config_unset(evt, portal, key)
elif cmd == "add" or cmd == "del":
await config_add_del(evt, portal, key, value, cmd)
else:
return
portal.save()
def config_help(evt: CommandEvent) -> Awaitable[Dict]:
return evt.reply("""**Usage:** `$cmdprefix config <subcommand> [...]`. Subcommands:
* **help** - View this help text.
* **view** - View the current config data.
* **defaults** - View the default config values.
* **set** <_key_> <_value_> - Set a config value.
* **unset** <_key_> - Remove a config value.
* **add** <_key_> <_value_> - Add a value to an array.
* **del** <_key_> <_value_> - Remove a value from an array.
""")
def config_view(evt: CommandEvent, portal: po.Portal) -> Awaitable[Dict]:
stream = StringIO()
yaml.dump(portal.local_config, stream)
return evt.reply(f"Room-specific config:\n\n```yaml\n{stream.getvalue()}```")
def config_defaults(evt: CommandEvent) -> Awaitable[Dict]:
stream = StringIO()
yaml.dump({
"edits_as_replies": evt.config["bridge.edits_as_replies"],
"bridge_notices": {
"default": evt.config["bridge.bridge_notices.default"],
"exceptions": evt.config["bridge.bridge_notices.exceptions"],
},
"bot_messages_as_notices": evt.config["bridge.bot_messages_as_notices"],
"inline_images": evt.config["bridge.inline_images"],
"native_stickers": evt.config["bridge.native_stickers"],
"message_formats": evt.config["bridge.message_formats"],
"state_event_formats": evt.config["bridge.state_event_formats"],
}, stream)
return evt.reply(f"Bridge instance wide config:\n\n```yaml\n{stream.getvalue()}```")
def config_set(evt: CommandEvent, portal: po.Portal, key: str, value: str) -> Awaitable[Dict]:
if not key or value is None:
return evt.reply(f"**Usage:** `$cmdprefix+sp config set <key> <value>`")
elif util.recursive_set(portal.local_config, key, value):
return evt.reply(f"Successfully set the value of `{key}` to `{value}`.")
else:
return evt.reply(f"Failed to set value of `{key}`. "
"Does the path contain non-map types?")
def config_unset(evt: CommandEvent, portal: po.Portal, key: str) -> Awaitable[Dict]:
if not key:
return evt.reply(f"**Usage:** `$cmdprefix+sp config unset <key>`")
elif util.recursive_del(portal.local_config, key):
return evt.reply(f"Successfully deleted `{key}` from config.")
else:
return evt.reply(f"`{key}` not found in config.")
def config_add_del(evt: CommandEvent, portal: po.Portal, key: str, value: str, cmd: str
) -> Awaitable[Dict]:
if not key or value is None:
return evt.reply(f"**Usage:** `$cmdprefix+sp config {cmd} <key> <value>`")
arr = util.recursive_get(portal.local_config, key)
if not arr:
return evt.reply(f"`{key}` not found in config. "
f"Maybe do `$cmdprefix+sp config set {key} []` first?")
elif not isinstance(arr, list):
return evt.reply("`{key}` does not seem to be an array.")
elif cmd == "add":
if value in arr:
return evt.reply(f"The array at `{key}` already contains `{value}`.")
arr.append(value)
return evt.reply(f"Successfully added `{value}` to the array at `{key}`")
else:
if value not in arr:
return evt.reply(f"The array at `{key}` does not contain `{value}`.")
arr.remove(value)
return evt.reply(f"Successfully removed `{value}` from the array at `{key}`")
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_args="<_name_|`-`>",
help_text="Change the username of a supergroup/channel. "
"To disable, use a dash (`-`) as the name.")
async def group_name(evt: CommandEvent):
async def group_name(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp group-name <name/->`")
@@ -421,7 +544,7 @@ async def group_name(evt: CommandEvent):
help_args="<`whitelist`|`blacklist`>",
help_text="Change whether the bridge will allow or disallow bridging rooms by "
"default.")
async def filter_mode(evt: CommandEvent):
async def filter_mode(evt: CommandEvent) -> Dict:
try:
mode = evt.args[0]
if mode not in ("whitelist", "blacklist"):
@@ -446,19 +569,19 @@ async def filter_mode(evt: CommandEvent):
help_section=SECTION_ADMIN,
help_args="<`whitelist`|`blacklist`> <_chat ID_>",
help_text="Allow or disallow bridging a specific chat.")
async def filter(evt: CommandEvent):
async def filter(evt: CommandEvent) -> Optional[Dict]:
try:
action = evt.args[0]
if action not in ("whitelist", "blacklist", "add", "remove"):
raise ValueError()
id = evt.args[1]
if id.startswith("-100"):
id = int(id[4:])
elif id.startswith("-"):
id = int(id[1:])
id_str = evt.args[1]
if id_str.startswith("-100"):
id = int(id_str[4:])
elif id_str.startswith("-"):
id = int(id_str[1:])
else:
id = int(id)
id = int(id_str)
except (IndexError, ValueError):
return await evt.reply("**Usage:** `$cmdprefix+sp filter <whitelist/blacklist> <chat ID>`")
@@ -471,7 +594,7 @@ async def filter(evt: CommandEvent):
if action in ("blacklist", "whitelist"):
action = "add" if mode == action else "remove"
def save():
def save() -> None:
evt.config["bridge.filter.list"] = list
evt.config.save()
po.Portal.filter_list = list
@@ -488,3 +611,4 @@ async def filter(evt: CommandEvent):
list.remove(id)
save()
return await evt.reply(f"Chat ID removed from {mode}.")
return None
+14 -8
View File
@@ -14,8 +14,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from telethon.errors import *
from typing import Awaitable, Dict, List, Optional, Tuple
import re
from telethon.errors import (
InviteHashInvalidError, InviteHashExpiredError, UserAlreadyParticipantError)
from telethon.tl.types import User as TLUser
from telethon.tl.types import TypeUpdates
from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest
from telethon.tl.functions.channels import JoinChannelRequest
@@ -26,7 +31,7 @@ from . import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORT
@command_handler(help_section=SECTION_MISC,
help_args="[_-r|--remote_] <_query_>",
help_text="Search your contacts or the Telegram servers for users.")
async def search(evt: CommandEvent):
async def search(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`")
@@ -47,7 +52,7 @@ async def search(evt: CommandEvent):
"Minimum length of remote query is 5 characters.")
return await evt.reply("No results 3:")
reply = []
reply = [] # type: List[str]
if remote:
reply += ["**Results from Telegram server:**", ""]
else:
@@ -68,7 +73,7 @@ async def search(evt: CommandEvent):
"either the internal user ID, the username or the phone number. "
"**N.B.** The phone numbers you start chats with must already be in "
"your contacts.")
async def private_message(evt: CommandEvent):
async def private_message(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`")
@@ -87,7 +92,7 @@ async def private_message(evt: CommandEvent):
f"{pu.Puppet.get_displayname(user, False)}")
async def _join(evt: CommandEvent, arg: str):
async def _join(evt: CommandEvent, arg: str) -> Tuple[Optional[TypeUpdates], Optional[Dict]]:
if arg.startswith("joinchat/"):
invite_hash = arg[len("joinchat/"):]
try:
@@ -110,7 +115,7 @@ async def _join(evt: CommandEvent, arg: str):
@command_handler(help_section=SECTION_CREATING_PORTALS,
help_args="<_link_>",
help_text="Join a chat with an invite link.")
async def join(evt: CommandEvent):
async def join(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp join <invite link>`")
@@ -121,7 +126,7 @@ async def join(evt: CommandEvent):
updates, _ = await _join(evt, arg.group(1))
if not updates:
return
return None
for chat in updates.chats:
portal = po.Portal.get_by_entity(chat)
@@ -132,12 +137,13 @@ async def join(evt: CommandEvent):
await evt.reply(f"Creating room for {chat.title}... This might take a while.")
await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid])
return await evt.reply(f"Created room for {portal.title}")
return None
@command_handler(help_section=SECTION_MISC,
help_args="[`chats`|`contacts`|`me`]",
help_text="Synchronize your chat portals, contacts and/or own info.")
async def sync(evt: CommandEvent):
async def sync(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) > 0:
sync_only = evt.args[0]
if sync_only not in ("chats", "contacts", "me"):
+44 -28
View File
@@ -14,23 +14,34 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Any, Optional
from typing import Any, Dict, Optional, Tuple
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
import random
import string
yaml = YAML()
yaml = YAML() # type: YAML
yaml.indent(4)
class DictWithRecursion:
def __init__(self, data: CommentedMap = None):
def __init__(self, data: Optional[CommentedMap] = None) -> None:
self._data = data or CommentedMap() # type: CommentedMap
@staticmethod
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
if '.' not in key:
return key, None
key, next_key = key.split('.', 1)
if len(key) > 0 and key[0] == "[":
end_index = next_key.index("]")
key = key[1:] + "." + next_key[:end_index]
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
return key, next_key
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
if '.' in key:
key, next_key = key.split('.', 1)
key, next_key = self._parse_key(key)
if next_key is not None:
next_data = data.get(key, CommentedMap())
return self._recursive_get(next_data, next_key, default_value)
return data.get(key, default_value)
@@ -46,40 +57,38 @@ class DictWithRecursion:
def __contains__(self, key: str) -> bool:
return self[key] is not None
def _recursive_set(self, data: CommentedMap, key: str, value: Any):
if '.' in key:
key, next_key = key.split('.', 1)
def _recursive_set(self, data: CommentedMap, key: str, value: Any) -> None:
key, next_key = self._parse_key(key)
if next_key is not None:
if key not in data:
data[key] = CommentedMap()
next_data = data.get(key, CommentedMap())
self._recursive_set(next_data, next_key, value)
return
return self._recursive_set(next_data, next_key, value)
data[key] = value
def set(self, key: str, value: Any, allow_recursion: bool = True):
def set(self, key: str, value: Any, allow_recursion: bool = True) -> None:
if allow_recursion and '.' in key:
self._recursive_set(self._data, key, value)
return
self._data[key] = value
def __setitem__(self, key: str, value: Any):
def __setitem__(self, key: str, value: Any) -> None:
self.set(key, value)
def _recursive_del(self, data: CommentedMap, key: str):
if '.' in key:
key, next_key = key.split('.', 1)
def _recursive_del(self, data: CommentedMap, key: str) -> None:
key, next_key = self._parse_key(key)
if next_key is not None:
if key not in data:
return
next_data = data[key]
self._recursive_del(next_data, next_key)
return
return self._recursive_del(next_data, next_key)
try:
del data[key]
del data.ca.items[key]
except KeyError:
pass
def delete(self, key: str, allow_recursion: bool = True):
def delete(self, key: str, allow_recursion: bool = True) -> None:
if allow_recursion and '.' in key:
self._recursive_del(self._data, key)
return
@@ -89,19 +98,19 @@ class DictWithRecursion:
except KeyError:
pass
def __delitem__(self, key: str):
def __delitem__(self, key: str) -> None:
self.delete(key)
class Config(DictWithRecursion):
def __init__(self, path: str, registration_path: str, base_path: str):
def __init__(self, path: str, registration_path: str, base_path: str) -> None:
super().__init__()
self.path = path # type: str
self.registration_path = registration_path # type: str
self.base_path = base_path # type: str
self._registration = None # type: dict
self._registration = None # type: Optional[Dict]
def load(self):
def load(self) -> None:
with open(self.path, 'r') as stream:
self._data = yaml.load(stream)
@@ -113,7 +122,7 @@ class Config(DictWithRecursion):
pass
return None
def save(self):
def save(self) -> None:
with open(self.path, 'w') as stream:
yaml.dump(self._data, stream)
if self._registration and self.registration_path:
@@ -124,16 +133,16 @@ class Config(DictWithRecursion):
def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
def update(self):
def update(self) -> None:
base = self.load_base()
if not base:
return
def copy(from_path, to_path=None):
def copy(from_path, to_path=None) -> None:
if from_path in self:
base[to_path or from_path] = self[from_path]
def copy_dict(from_path, to_path=None, override_existing_map=True):
def copy_dict(from_path, to_path=None, override_existing_map=True) -> None:
if from_path in self:
to_path = to_path or from_path
if override_existing_map or to_path not in base:
@@ -156,6 +165,7 @@ class Config(DictWithRecursion):
copy("appservice.max_body_size")
copy("appservice.database")
copy("appservice.sqlalchemy_core_mode")
copy("appservice.public.enabled")
copy("appservice.public.prefix")
@@ -183,7 +193,13 @@ class Config(DictWithRecursion):
copy("bridge.edits_as_replies")
copy("bridge.highlight_edits")
copy("bridge.bridge_notices")
if isinstance(self["bridge.bridge_notices"], bool):
base["bridge.bridge_notices"] = {
"default": self["bridge.bridge_notices"],
"exceptions": ["@importantbot:example.com"],
}
else:
copy("bridge.bridge_notices")
copy("bridge.bot_messages_as_notices")
copy("bridge.max_initial_member_sync")
copy("bridge.sync_channel_members")
@@ -273,7 +289,7 @@ class Config(DictWithRecursion):
return self._get_permissions("*")
def generate_registration(self):
def generate_registration(self) -> None:
homeserver = self["homeserver.domain"]
username_format = self.get("bridge.username_template", "telegram_{userid}") \
+7 -8
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TYPE_CHECKING, Optional
from typing import Optional, Tuple, TYPE_CHECKING
if TYPE_CHECKING:
import asyncio
@@ -32,7 +32,8 @@ if TYPE_CHECKING:
class Context:
def __init__(self, az: "AppService", db: "scoped_session", config: "Config",
loop: "asyncio.AbstractEventLoop", session_container: "AlchemySessionContainer"):
loop: "asyncio.AbstractEventLoop", session_container: "AlchemySessionContainer"
) -> None:
self.az = az # type: AppService
self.db = db # type: scoped_session
self.config = config # type: Config
@@ -43,9 +44,7 @@ class Context:
self.public_website = None # type: PublicBridgeWebsite
self.provisioning_api = None # type: ProvisioningAPI
def __iter__(self):
yield self.az
yield self.db
yield self.config
yield self.loop
yield self.bot
@property
def core(self) -> Tuple['AppService', 'scoped_session', 'Config',
'asyncio.AbstractEventLoop', Optional['Bot']]:
return (self.az, self.db, self.config, self.loop, self.bot)
+107 -32
View File
@@ -15,11 +15,17 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text)
BigInteger, String, Boolean, Text, Table,
and_, func, select)
from sqlalchemy.engine import Engine, RowProxy
from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query
from sqlalchemy.sql.base import ImmutableColumnCollection
from typing import Dict, Optional, List
import json
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
from .types import TelegramID
from .base import Base
@@ -28,13 +34,15 @@ class Portal(Base):
__tablename__ = "portal"
# Telegram chat information
tgid = Column(Integer, primary_key=True)
tg_receiver = Column(Integer, primary_key=True)
tgid = Column(Integer, primary_key=True) # type: TelegramID
tg_receiver = Column(Integer, primary_key=True) # type: TelegramID
peer_type = Column(String, nullable=False)
megagroup = Column(Boolean)
# Matrix portal information
mxid = Column(String, unique=True, nullable=True)
mxid = Column(String, unique=True, nullable=True) # type: Optional[MatrixRoomID]
config = Column(Text, nullable=True)
# Telegram chat metadata
username = Column(String, nullable=True)
@@ -44,25 +52,88 @@ class Portal(Base):
class Message(Base):
query = None # type: Query
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "message"
mxid = Column(String)
mx_room = Column(String)
tgid = Column(Integer, primary_key=True)
tg_space = Column(Integer, primary_key=True)
mxid = Column(String) # type: MatrixEventID
mx_room = Column(String) # type: MatrixRoomID
tgid = Column(Integer, primary_key=True) # type: TelegramID
tg_space = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
@staticmethod
def _one_or_none(rows: RowProxy) -> Optional['Message']:
try:
mxid, mx_room, tgid, tg_space = next(rows)
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
except StopIteration:
return None
@staticmethod
def _all(rows: RowProxy) -> List['Message']:
return [Message(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3])
for row in rows]
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
rows = cls.db.execute(cls.t.select()
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
@classmethod
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
rows = cls.db.execute(select([func.count(cls.c.tg_space)])
.where(and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room)))
try:
count, = next(rows)
return count
except StopIteration:
return 0
@classmethod
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
) -> Optional['Message']:
rows = cls.db.execute(cls.t.select().where(
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
@classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
cls.db.execute(cls.t.update()
.where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space))
.values(**values))
@classmethod
def update_by_mxid(cls, s_mxid: MatrixEventID, s_mx_room: MatrixRoomID, **values) -> None:
cls.db.execute(cls.t.update()
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
def update(self, **values) -> None:
for key, value in values.items():
setattr(self, key, value)
self.update_by_tgid(self.tgid, self.tg_space, **values)
def delete(self) -> None:
self.db.execute(self.t.delete().where(
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
def insert(self) -> None:
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
tg_space=self.tg_space))
class UserPortal(Base):
query = None # type: Query
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True)
portal = Column(Integer, primary_key=True)
portal_receiver = Column(Integer, primary_key=True)
primary_key=True) # type: TelegramID
portal = Column(Integer, primary_key=True) # type: TelegramID
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"),
@@ -73,12 +144,14 @@ class User(Base):
query = None # type: Query
__tablename__ = "user"
mxid = Column(String, primary_key=True)
tgid = Column(Integer, nullable=True, unique=True)
mxid = Column(String, primary_key=True) # type: MatrixUserID
tgid = Column(Integer, nullable=True, unique=True) # type: Optional[TelegramID]
tg_username = Column(String, nullable=True)
tg_phone = Column(String, nullable=True)
saved_contacts = Column(Integer, default=0, nullable=False)
contacts = relationship("Contact", uselist=True,
cascade="save-update, merge, delete, delete-orphan")
cascade="save-update, merge, delete, delete-orphan"
) # type: List[Contact]
portals = relationship("Portal", secondary="user_portal")
@@ -86,22 +159,22 @@ class RoomState(Base):
query = None # type: Query
__tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True)
room_id = Column(String, primary_key=True) # type: MatrixRoomID
_power_levels_text = Column("power_levels", Text, nullable=True)
_power_levels_json = None
_power_levels_json = {} # type: Dict
@property
def has_power_levels(self):
def has_power_levels(self) -> bool:
return bool(self._power_levels_text)
@property
def power_levels(self):
def power_levels(self) -> Dict:
if not self._power_levels_json and self._power_levels_text:
self._power_levels_json = json.loads(self._power_levels_text)
return self._power_levels_json or {}
return self._power_levels_json
@power_levels.setter
def power_levels(self, val):
def power_levels(self, val: Dict) -> None:
self._power_levels_json = val
self._power_levels_text = json.dumps(val)
@@ -110,13 +183,13 @@ class UserProfile(Base):
query = None # type: Query
__tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True)
user_id = Column(String, primary_key=True)
room_id = Column(String, primary_key=True) # type: MatrixRoomID
user_id = Column(String, primary_key=True) # type: MatrixUserID
membership = Column(String, nullable=False, default="leave")
displayname = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
def dict(self):
def dict(self) -> Dict[str, str]:
return {
"membership": self.membership,
"displayname": self.displayname,
@@ -128,19 +201,19 @@ class Contact(Base):
query = None # type: Query
__tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True)
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True)
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
class Puppet(Base):
query = None # type: Query
__tablename__ = "puppet"
id = Column(Integer, primary_key=True)
custom_mxid = Column(String, nullable=True)
id = Column(Integer, primary_key=True) # type: TelegramID
custom_mxid = Column(String, nullable=True) # type: Optional[MatrixUserID]
access_token = Column(String, nullable=True)
displayname = Column(String, nullable=True)
displayname_source = Column(Integer, nullable=True)
displayname_source = Column(Integer, nullable=True) # type: Optional[TelegramID]
username = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
is_bot = Column(Boolean, nullable=True)
@@ -151,7 +224,7 @@ class Puppet(Base):
class BotChat(Base):
query = None # type: Query
__tablename__ = "bot_chat"
id = Column(Integer, primary_key=True)
id = Column(Integer, primary_key=True) # type: TelegramID
type = Column(String, nullable=False)
@@ -171,9 +244,11 @@ class TelegramFile(Base):
thumbnail = relationship("TelegramFile", uselist=False)
def init(db_session):
def init(db_session, db_engine) -> None:
Portal.query = db_session.query_property()
Message.query = db_session.query_property()
Message.db = db_engine
Message.t = Message.__table__
Message.c = Message.t.c
UserPortal.query = db_session.query_property()
User.query = db_session.query_property()
Puppet.query = db_session.query_property()
+1 -1
View File
@@ -4,6 +4,6 @@ from .from_telegram import (telegram_reply_to_matrix, telegram_to_matrix, init_t
from .. import context as c
def init(context: c.Context):
def init(context: c.Context) -> None:
init_mx(context)
init_tg(context)
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, Callable, Pattern, Match, TYPE_CHECKING
from typing import Optional, List, Tuple, Callable, Pattern, Match, TYPE_CHECKING, Dict, Any
import re
import logging
@@ -22,6 +22,7 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M
TypeMessageEntity)
from ... import puppet as pu
from ...types import TelegramID, MatrixRoomID
from ...db import Message as DBMessage
from ..util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
trim_reply_fallback_text)
@@ -90,8 +91,8 @@ def matrix_to_telegram(html: str) -> ParsedMessage:
raise FormatError(f"Failed to convert Matrix format: {html}") from e
def matrix_reply_to_telegram(content: dict, tg_space: int, room_id: Optional[str] = None
) -> Optional[int]:
def matrix_reply_to_telegram(content: Dict[str, Any], tg_space: TelegramID,
room_id: Optional[MatrixRoomID] = None) -> Optional[TelegramID]:
try:
reply = content["m.relates_to"]["m.in_reply_to"]
room_id = room_id or reply["room_id"]
@@ -104,9 +105,7 @@ def matrix_reply_to_telegram(content: dict, tg_space: int, room_id: Optional[str
pass
content["body"] = trim_reply_fallback_text(content["body"])
message = DBMessage.query.filter(DBMessage.mxid == event_id,
DBMessage.tg_space == tg_space,
DBMessage.mx_room == room_id).one_or_none()
message = DBMessage.get_by_mxid(event_id, room_id, tg_space)
if message:
return message.tgid
except KeyError:
@@ -147,7 +146,7 @@ def plain_mention_to_text() -> Tuple[List[TypeMessageEntity], Callable[[str], st
return entities, replacer
def init_mx(context: "Context"):
def init_mx(context: "Context") -> None:
global plain_mention_regex, should_bridge_plaintext_highlights
config = context.config
dn_template = config.get("bridge.displayname_template", "{displayname} (Telegram)")
@@ -22,10 +22,15 @@ from telethon.tl.types import TypeMessageEntity
class MatrixParserCommon:
mention_regex = re.compile("https://matrix.to/#/(@.+:.+)") # type: Pattern
room_regex = re.compile("https://matrix.to/#/(#.+:.+)") # type: Pattern
block_tags = ("br", "p", "pre", "blockquote",
block_tags = ("p", "pre", "blockquote",
"ol", "ul", "li",
"h1", "h2", "h3", "h4", "h5", "h6",
"div", "hr", "table") # type: Tuple[str, ...]
list_bullets = ("", "", "", "") # type: Tuple[str, ...]
@classmethod
def list_bullet(cls, depth: int) -> str:
return cls.list_bullets[(depth - 1) % len(cls.list_bullets)] + " "
ParsedMessage = Tuple[str, List[TypeMessageEntity]]
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Optional, List, Tuple, Type, Dict, Any, Deque, Match)
from typing import (Optional, List, Tuple, Type, Dict, Any, TYPE_CHECKING, Match)
from html import unescape
from html.parser import HTMLParser
from collections import deque
@@ -26,9 +26,13 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M
MessageEntityBotCommand, TypeMessageEntity)
from ... import user as u, puppet as pu, portal as po
from ...types import MatrixUserID
from ..util import html_to_unicode
from .parser_common import MatrixParserCommon, ParsedMessage
if TYPE_CHECKING:
from typing import Deque
def parse_html(html: str) -> ParsedMessage:
parser = MatrixParser()
@@ -52,7 +56,7 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]:
mention = self.mention_regex.match(url) # type: Match
if mention:
mxid = mention.group(1)
mxid = MatrixUserID(mention.group(1))
user = (pu.Puppet.get_by_mxid(mxid)
or u.User.get_by_mxid(mxid, create=False))
if not user:
@@ -80,12 +84,12 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
args["url"] = url
return MessageEntityTextUrl, None
def handle_starttag(self, tag: str, attrs: List[Tuple[str, str]]):
def handle_starttag(self, tag: str, attrs_list: List[Tuple[str, str]]):
self._open_tags.appendleft(tag)
self._open_tags_meta.appendleft(0)
attrs = dict(attrs)
entity_type = None # type: type(TypeMessageEntity)
attrs = dict(attrs_list)
entity_type = None # type: Optional[Type[TypeMessageEntity]]
args = {} # type: Dict[str, Any]
if tag in ("strong", "b"):
entity_type = MessageEntityBold
@@ -119,7 +123,7 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
self._open_tags_meta.popleft()
self._open_tags_meta.appendleft(url)
if tag in self.block_tags and ("blockquote" not in self._open_tags or tag == "br"):
if (tag in self.block_tags and ("blockquote" not in self._open_tags)) or tag == "br":
self._newline()
if entity_type and tag not in self._building_entities:
@@ -198,7 +202,8 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
else:
prefix = int(math.log(n, 10)) * 3 * " " + 4 * " "
else:
prefix = "* " if self._list_entry_is_new else 3 * " "
prefix = (self.list_bullet(self._open_tags.count('ul'))
if self._list_entry_is_new else 3 * " ")
if not self._list_entry_is_new and not self._line_is_new:
prefix = ""
extra_offset += len(indent) + len(prefix)
@@ -14,166 +14,49 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, Union, Callable
from typing import List, Tuple
from lxml import html
from telethon.tl.types import (MessageEntityMention as Mention,
from telethon.tl.types import (MessageEntityMention as Mention, MessageEntityBotCommand as Command,
MessageEntityMentionName as MentionName, MessageEntityEmail as Email,
MessageEntityUrl as URL, MessageEntityTextUrl as TextURL,
MessageEntityBold as Bold, MessageEntityItalic as Italic,
MessageEntityCode as Code, MessageEntityPre as Pre,
MessageEntityBotCommand as Command, TypeMessageEntity,
InputMessageEntityMentionName as InputMentionName)
MessageEntityCode as Code, MessageEntityPre as Pre)
from ... import user as u, puppet as pu, portal as po
from ...types import MatrixUserID
from ..util import html_to_unicode
from .parser_common import MatrixParserCommon, ParsedMessage
from .telegram_message import TelegramMessage, Entity, offset_length_multiply
def parse_html(html: str) -> ParsedMessage:
return MatrixParser.parse(html)
def parse_html(input_html: str) -> ParsedMessage:
return MatrixParser.parse(input_html)
class Entity:
@staticmethod
def copy(entity: TypeMessageEntity) -> Optional[TypeMessageEntity]:
if not entity:
return None
kwargs = {
"offset": entity.offset,
"length": entity.length,
}
if isinstance(entity, Pre):
kwargs["language"] = entity.language
elif isinstance(entity, TextURL):
kwargs["url"] = entity.url
elif isinstance(entity, (MentionName, InputMentionName)):
kwargs["user_id"] = entity.user_id
return entity.__class__(**kwargs)
class RecursionContext:
def __init__(self, strip_linebreaks: bool = True, ul_depth: int = 0):
self.strip_linebreaks = strip_linebreaks # type: bool
self.ul_depth = ul_depth # type: int
self._inited = True # type: bool
@classmethod
def adjust(cls, entity: Union[TypeMessageEntity, List[TypeMessageEntity]],
func: Callable[[TypeMessageEntity], None]
) -> Union[Optional[TypeMessageEntity], List[TypeMessageEntity]]:
if isinstance(entity, list):
return [Entity.adjust(element, func) for element in entity if entity]
elif not entity:
return None
entity = cls.copy(entity)
func(entity)
if entity.offset < 0:
entity.length += entity.offset
entity.offset = 0
return entity
def __setattr__(self, key, value):
if getattr(self, "_inited", False) is True:
raise TypeError("'RecursionContext' object is immutable")
super(RecursionContext, self).__setattr__(key, value)
def enter_list(self) -> 'RecursionContext':
return RecursionContext(strip_linebreaks=self.strip_linebreaks, ul_depth=self.ul_depth + 1)
def offset_diff(amount: int):
def func(entity: TypeMessageEntity):
entity.offset += amount
return func
def offset_length_multiply(amount: int):
def func(entity: TypeMessageEntity):
entity.offset *= amount
entity.length *= amount
return func
class TelegramMessage:
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None):
self.text = text # type: str
self.entities = entities or [] # type: List[TypeMessageEntity]
def offset_entities(self, offset: int) -> "TelegramMessage":
def apply_offset(entity: TypeMessageEntity, inner_offset: int
) -> Optional[TypeMessageEntity]:
entity = Entity.copy(entity)
entity.offset += inner_offset
if entity.offset < 0:
entity.offset = 0
elif entity.offset > len(self.text):
return None
elif entity.offset + entity.length > len(self.text):
entity.length = len(self.text) - entity.offset
return entity
self.entities = [apply_offset(entity, offset) for entity in self.entities if entity]
self.entities = [x for x in self.entities if x is not None]
return self
def append(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
self.entities += Entity.adjust(msg.entities, offset_diff(len(self.text)))
self.text += msg.text
return self
def prepend(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
self.entities = msg.entities + Entity.adjust(self.entities, offset_diff(len(msg.text)))
self.text = msg.text + self.text
return self
def format(self, entity_type: type(TypeMessageEntity), offset: int = None, length: int = None,
**kwargs) -> "TelegramMessage":
self.entities.append(entity_type(offset=offset or 0,
length=length if length is not None else len(self.text),
**kwargs))
return self
def concat(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
return TelegramMessage().append(self, *args)
def trim(self) -> "TelegramMessage":
orig_len = len(self.text)
self.text = self.text.lstrip()
diff = orig_len - len(self.text)
self.text = self.text.rstrip()
self.offset_entities(-diff)
return self
def split(self, separator, max_items: int = 0) -> List["TelegramMessage"]:
text_parts = self.text.split(separator, max_items - 1)
output = [] # type: List[TelegramMessage]
offset = 0
for part in text_parts:
msg = TelegramMessage(part)
for entity in self.entities:
start_in_range = len(part) > entity.offset - offset >= 0
end_in_range = len(part) >= entity.offset - offset + entity.length > 0
if start_in_range and end_in_range:
msg.entities.append(Entity.adjust(entity, offset_diff(-offset)))
output.append(msg)
offset += len(part)
offset += len(separator)
return output
@staticmethod
def join(items: List[Union[str, "TelegramMessage"]], separator: str = " ") -> "TelegramMessage":
main = TelegramMessage()
for msg in items:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
main.entities += Entity.adjust(msg.entities, offset_diff(len(main.text)))
main.text += msg.text + separator
main.text = main.text[:-len(separator)]
return main
def enter_code_block(self) -> 'RecursionContext':
return RecursionContext(strip_linebreaks=False, ul_depth=self.ul_depth)
class MatrixParser(MatrixParserCommon):
@classmethod
def list_to_tmessage(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage:
def list_to_tmessage(cls, node: html.HtmlElement, ctx: RecursionContext) -> TelegramMessage:
ordered = node.tag == "ol"
tagged_children = cls.node_to_tagged_tmessages(node, strip_linebreaks)
tagged_children = cls.node_to_tagged_tmessages(node, ctx)
counter = 1
indent_length = 0
if ordered:
@@ -194,7 +77,7 @@ class MatrixParser(MatrixParserCommon):
prefix = f"{counter}. "
counter += 1
else:
prefix = ""
prefix = cls.list_bullet(ctx.ul_depth)
child = child.prepend(prefix)
parts = child.split("\n")
parts = parts[:1] + [part.prepend(indent) for part in parts[1:]]
@@ -203,41 +86,43 @@ class MatrixParser(MatrixParserCommon):
return TelegramMessage.join(children, "\n")
@classmethod
def blockquote_to_tmessage(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage:
msg = cls.tag_aware_parse_node(node, strip_linebreaks)
def blockquote_to_tmessage(cls, node: html.HtmlElement, ctx: RecursionContext
) -> TelegramMessage:
msg = cls.tag_aware_parse_node(node, ctx)
children = msg.trim().split("\n")
children = [child.prepend("> ") for child in children]
return TelegramMessage.join(children, "\n")
@classmethod
def header_to_tmessage(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage:
children = cls.node_to_tmessages(node, strip_linebreaks)
def header_to_tmessage(cls, node: html.HtmlElement, ctx: RecursionContext) -> TelegramMessage:
children = cls.node_to_tmessages(node, ctx)
length = int(node.tag[1])
prefix = "#" * length + " "
return TelegramMessage.join(children, "").prepend(prefix)
return TelegramMessage.join(children, "").prepend(prefix).format(Bold)
@classmethod
def basic_format_to_tmessage(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage:
msg = cls.tag_aware_parse_node(node, strip_linebreaks)
def basic_format_to_tmessage(cls, node: html.HtmlElement, ctx: RecursionContext
) -> TelegramMessage:
msg = cls.tag_aware_parse_node(node, ctx)
if node.tag in ("b", "strong"):
msg.format(Bold)
elif node.tag in ("i", "em"):
msg.format(Italic)
elif node.tag == "command":
msg.format(Command)
elif node.tag in ("s", "del"):
elif node.tag in ("s", "strike", "del"):
msg.text = html_to_unicode(msg.text, "\u0336")
elif node.tag in ("u", "ins"):
msg.text = html_to_unicode(msg.text, "\u0332")
if node.tag in ("s", "del", "u", "ins"):
if node.tag in ("s", "strike", "del", "u", "ins"):
msg.entities = Entity.adjust(msg.entities, offset_length_multiply(2))
return msg
@classmethod
def link_to_tstring(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage:
msg = cls.tag_aware_parse_node(node, strip_linebreaks)
def link_to_tstring(cls, node: html.HtmlElement, ctx: RecursionContext) -> TelegramMessage:
msg = cls.tag_aware_parse_node(node, ctx)
href = node.attrib.get("href", "")
if not href:
return msg
@@ -247,7 +132,7 @@ class MatrixParser(MatrixParserCommon):
mention = cls.mention_regex.match(href)
if mention:
mxid = mention.group(1)
mxid = MatrixUserID(mention.group(1))
user = (pu.Puppet.get_by_mxid(mxid)
or u.User.get_by_mxid(mxid, create=False))
if not user:
@@ -271,73 +156,81 @@ class MatrixParser(MatrixParserCommon):
else msg.format(TextURL, url=href))
@classmethod
def node_to_tmessage(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage:
def node_to_tmessage(cls, node: html.HtmlElement, ctx: RecursionContext) -> TelegramMessage:
if node.tag == "blockquote":
return cls.blockquote_to_tmessage(node, strip_linebreaks)
elif node.tag in ("ol", "ul"):
return cls.list_to_tmessage(node, strip_linebreaks)
return cls.blockquote_to_tmessage(node, ctx)
elif node.tag == "ol":
return cls.list_to_tmessage(node, ctx)
elif node.tag == "ul":
return cls.list_to_tmessage(node, ctx.enter_list())
elif node.tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
return cls.header_to_tmessage(node, strip_linebreaks)
return cls.header_to_tmessage(node, ctx)
elif node.tag == "br":
return TelegramMessage("\n")
elif node.tag in ("b", "strong", "i", "em", "s", "del", "u", "ins", "command"):
return cls.basic_format_to_tmessage(node, strip_linebreaks)
return cls.basic_format_to_tmessage(node, ctx)
elif node.tag == "a":
return cls.link_to_tstring(node, strip_linebreaks)
return cls.link_to_tstring(node, ctx)
elif node.tag == "p":
return cls.tag_aware_parse_node(node, strip_linebreaks).append("\n")
return cls.tag_aware_parse_node(node, ctx).append("\n")
elif node.tag == "pre":
lang = ""
try:
if node[0].tag == "code":
lang = node[0].attrib["class"][len("language-"):]
node = node[0]
lang = node.attrib["class"][len("language-"):]
except (IndexError, KeyError):
pass
return cls.parse_node(node, strip_linebreaks=False).format(Pre, language=lang)
return cls.parse_node(node, ctx.enter_code_block()).format(Pre, language=lang)
elif node.tag == "code":
return cls.parse_node(node, strip_linebreaks=False).format(Code)
return cls.tag_aware_parse_node(node, strip_linebreaks)
return cls.parse_node(node, ctx.enter_code_block()).format(Code)
return cls.tag_aware_parse_node(node, ctx)
@staticmethod
def text_to_tmessage(text: str, strip_linebreaks: bool = True) -> TelegramMessage:
if strip_linebreaks:
def text_to_tmessage(text: str, ctx: RecursionContext) -> TelegramMessage:
if ctx.strip_linebreaks:
text = text.replace("\n", "")
return TelegramMessage(text)
@classmethod
def node_to_tagged_tmessages(cls, node: html.HtmlElement, strip_linebreaks: bool = True
def node_to_tagged_tmessages(cls, node: html.HtmlElement, ctx: RecursionContext
) -> List[Tuple[TelegramMessage, str]]:
output = []
if node.text:
output.append((cls.text_to_tmessage(node.text, strip_linebreaks), "text"))
output.append((cls.text_to_tmessage(node.text, ctx), "text"))
for child in node:
output.append((cls.node_to_tmessage(child, strip_linebreaks), child.tag))
output.append((cls.node_to_tmessage(child, ctx), child.tag))
if child.tail:
output.append((cls.text_to_tmessage(child.tail, strip_linebreaks), "text"))
output.append((cls.text_to_tmessage(child.tail, ctx), "text"))
return output
@classmethod
def node_to_tmessages(cls, node: html.HtmlElement, strip_linebreaks) -> List[TelegramMessage]:
return [msg for (msg, tag) in cls.node_to_tagged_tmessages(node, strip_linebreaks)]
def node_to_tmessages(cls, node: html.HtmlElement, ctx: RecursionContext
) -> List[TelegramMessage]:
return [msg for (msg, tag) in cls.node_to_tagged_tmessages(node, ctx)]
@classmethod
def tag_aware_parse_node(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage:
msgs = cls.node_to_tagged_tmessages(node, strip_linebreaks)
def tag_aware_parse_node(cls, node: html.HtmlElement, ctx: RecursionContext
) -> TelegramMessage:
msgs = cls.node_to_tagged_tmessages(node, ctx)
output = TelegramMessage()
prev_was_block = False
for msg, tag in msgs:
if tag in cls.block_tags:
msg = msg.append("\n").prepend("\n")
msg = msg.append("\n")
if not prev_was_block:
msg = msg.prepend("\n")
prev_was_block = True
output = output.append(msg)
return output.trim()
@classmethod
def parse_node(cls, node: html.HtmlElement, strip_linebreaks) -> TelegramMessage:
return TelegramMessage.join(cls.node_to_tmessages(node, strip_linebreaks))
def parse_node(cls, node: html.HtmlElement, ctx: RecursionContext) -> TelegramMessage:
return TelegramMessage.join(cls.node_to_tmessages(node, ctx))
@classmethod
def parse(cls, data: str) -> ParsedMessage:
document = html.fromstring(f"<html>{data}</html>")
msg = cls.parse_node(document, strip_linebreaks=True)
msg = cls.parse_node(document, RecursionContext())
return msg.text, msg.entities
@@ -0,0 +1,157 @@
# -*- coding: future_fstrings -*-
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, List, Optional, Sequence, Type, Union
from telethon.tl.types import (MessageEntityMentionName as MentionName,
MessageEntityTextUrl as TextURL, MessageEntityPre as Pre,
TypeMessageEntity, InputMessageEntityMentionName as InputMentionName)
class Entity:
@staticmethod
def copy(entity: TypeMessageEntity) -> Optional[TypeMessageEntity]:
if not entity:
return None
kwargs = {
"offset": entity.offset,
"length": entity.length,
}
if isinstance(entity, Pre):
kwargs["language"] = entity.language
elif isinstance(entity, TextURL):
kwargs["url"] = entity.url
elif isinstance(entity, (MentionName, InputMentionName)):
kwargs["user_id"] = entity.user_id
return entity.__class__(**kwargs)
@classmethod
def adjust(cls, entity: Union[TypeMessageEntity, List[TypeMessageEntity]],
func: Callable[[TypeMessageEntity], None]
) -> Union[Optional[TypeMessageEntity], List[TypeMessageEntity]]:
if isinstance(entity, list):
return [Entity.adjust(element, func) for element in entity if entity]
elif not entity:
return None
entity = cls.copy(entity)
func(entity)
if entity.offset < 0:
entity.length += entity.offset
entity.offset = 0
return entity
def offset_diff(amount: int) -> Callable[[TypeMessageEntity], None]:
def func(entity: TypeMessageEntity) -> None:
entity.offset += amount
return func
def offset_length_multiply(amount: int) -> Callable[[TypeMessageEntity], None]:
def func(entity: TypeMessageEntity) -> None:
entity.offset *= amount
entity.length *= amount
return func
class TelegramMessage:
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None) -> None:
self.text = text # type: str
self.entities = entities or [] # type: List[TypeMessageEntity]
def offset_entities(self, offset: int) -> 'TelegramMessage':
def apply_offset(entity: TypeMessageEntity, inner_offset: int
) -> Optional[TypeMessageEntity]:
entity = Entity.copy(entity)
entity.offset += inner_offset
if entity.offset < 0:
entity.offset = 0
elif entity.offset > len(self.text):
return None
elif entity.offset + entity.length > len(self.text):
entity.length = len(self.text) - entity.offset
return entity
self.entities = [apply_offset(entity, offset) for entity in self.entities if entity]
self.entities = [x for x in self.entities if x is not None]
return self
def append(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
self.entities += Entity.adjust(msg.entities, offset_diff(len(self.text)))
self.text += msg.text
return self
def prepend(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
self.entities = msg.entities + Entity.adjust(self.entities, offset_diff(len(msg.text)))
self.text = msg.text + self.text
return self
def format(self, entity_type: Type[TypeMessageEntity], offset: int = None, length: int = None,
**kwargs) -> 'TelegramMessage':
self.entities.append(entity_type(offset=offset or 0,
length=length if length is not None else len(self.text),
**kwargs))
return self
def concat(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
return TelegramMessage().append(self, *args)
def trim(self) -> 'TelegramMessage':
orig_len = len(self.text)
self.text = self.text.lstrip()
diff = orig_len - len(self.text)
self.text = self.text.rstrip()
self.offset_entities(-diff)
return self
def split(self, separator, max_items: int = 0) -> List['TelegramMessage']:
text_parts = self.text.split(separator, max_items - 1)
output = [] # type: List[TelegramMessage]
offset = 0
for part in text_parts:
msg = TelegramMessage(part)
for entity in self.entities:
start_in_range = len(part) > entity.offset - offset >= 0
end_in_range = len(part) >= entity.offset - offset + entity.length > 0
if start_in_range and end_in_range:
msg.entities.append(Entity.adjust(entity, offset_diff(-offset)))
output.append(msg)
offset += len(part)
offset += len(separator)
return output
@staticmethod
def join(items: Sequence[Union[str, 'TelegramMessage']],
separator: str = " ") -> 'TelegramMessage':
main = TelegramMessage()
for msg in items:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
main.entities += Entity.adjust(msg.entities, offset_diff(len(main.text)))
main.text += msg.text + separator
main.text = main.text[:-len(separator)]
return main
+28 -23
View File
@@ -14,21 +14,23 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from html import escape
import logging
import re
from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
MessageEntityEmail, MessageEntityUrl, MessageEntityTextUrl,
MessageEntityBold, MessageEntityItalic, MessageEntityCode,
MessageEntityPre, MessageEntityBotCommand, Message, PeerChannel,
MessageEntityHashtag, TypeMessageEntity, MessageFwdHeader, PeerUser)
from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, MessageEntityUrl,
MessageEntityEmail, MessageEntityTextUrl, MessageEntityBold,
MessageEntityItalic, MessageEntityCode, MessageEntityPre,
MessageEntityBotCommand, MessageEntityHashtag, MessageEntityCashtag,
MessageEntityPhone, TypeMessageEntity, Message, PeerChannel,
MessageFwdHeader, PeerUser)
from mautrix_appservice import MatrixRequestError
from mautrix_appservice.intent_api import IntentAPI
from .. import user as u, puppet as pu, portal as po
from ..types import TelegramID
from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
trim_reply_fallback_text, unicode_to_html)
@@ -40,19 +42,19 @@ if TYPE_CHECKING:
try:
from lxml.html.diff import htmldiff
except ImportError:
htmldiff = None # type: function
htmldiff = None # type: ignore
log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
should_highlight_edits = False # type: bool
def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict:
def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
if evt.reply_to_msg_id:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
else source.tgid)
msg = DBMessage.query.get((evt.reply_to_msg_id, space))
msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
if msg:
return {
"m.in_reply_to": {
@@ -75,7 +77,7 @@ async def _add_forward_header(source, text: str, html: Optional[str],
fwd_from_html = f"<a href='https://matrix.to/#/{user.mxid}'>{fwd_from_text}</a>"
if not fwd_from_text:
puppet = pu.Puppet.get(fwd_from.from_id, create=False)
puppet = pu.Puppet.get(TelegramID(fwd_from.from_id), create=False)
if puppet and puppet.displayname:
fwd_from_text = puppet.displayname or puppet.mxid
fwd_from_html = f"<a href='https://matrix.to/#/{puppet.mxid}'>{fwd_from_text}</a>"
@@ -116,13 +118,13 @@ def highlight_edits(new_html: str, old_html: str) -> str:
async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
relates_to: dict, main_intent: IntentAPI, is_edit: bool
relates_to: Dict, main_intent: IntentAPI, is_edit: bool
) -> Tuple[str, str]:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
else source.tgid)
msg = DBMessage.query.get((evt.reply_to_msg_id, space))
msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
if not msg:
return text, html
@@ -177,10 +179,10 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
async def telegram_to_matrix(evt: Message, source: "AbstractUser",
main_intent: Optional[IntentAPI] = None,
is_edit: bool = False, prefix_text: Optional[str] = None,
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]:
prefix_html: Optional[str] = None) -> Tuple[str, str, Dict]:
text = add_surrogates(evt.message)
html = _telegram_entities_to_matrix_catch(text, evt.entities) if evt.entities else None
relates_to = {}
relates_to = {} # type: Dict
if prefix_html:
html = prefix_html + (html or escape(text))
@@ -217,6 +219,7 @@ def _telegram_entities_to_matrix_catch(text: str, entities: List[TypeMessageEnti
"message=%s\n"
"entities=%s",
text, entities)
return "[failed conversion in _telegram_entities_to_matrix]"
def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -> str:
@@ -239,21 +242,23 @@ def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -
elif entity_type == MessageEntityItalic:
html.append(f"<em>{entity_text}</em>")
elif entity_type == MessageEntityCode:
html.append(f"<code>{entity_text}</code>")
html.append(f"<pre><code>{entity_text}</code></pre>"
if "\n" in entity_text
else f"<code>{entity_text}</code>")
elif entity_type == MessageEntityPre:
skip_entity = _parse_pre(html, entity_text, entity.language)
elif entity_type == MessageEntityMention:
skip_entity = _parse_mention(html, entity_text)
elif entity_type == MessageEntityMentionName:
skip_entity = _parse_name_mention(html, entity_text, entity.user_id)
skip_entity = _parse_name_mention(html, entity_text, TelegramID(entity.user_id))
elif entity_type == MessageEntityEmail:
html.append(f"<a href='mailto:{entity_text}'>{entity_text}</a>")
elif entity_type in {MessageEntityTextUrl, MessageEntityUrl}:
elif entity_type in (MessageEntityTextUrl, MessageEntityUrl):
skip_entity = _parse_url(html, entity_text,
entity.url if entity_type == MessageEntityTextUrl else None)
elif entity_type == MessageEntityBotCommand:
html.append(f"<font color='blue'>!{entity_text[1:]}</font>")
elif entity_type == MessageEntityHashtag:
elif entity_type in (MessageEntityHashtag, MessageEntityCashtag, MessageEntityPhone):
html.append(f"<font color='blue'>{entity_text}</font>")
else:
skip_entity = True
@@ -290,7 +295,7 @@ def _parse_mention(html: List[str], entity_text: str) -> bool:
return False
def _parse_name_mention(html: List[str], entity_text: str, user_id: int) -> bool:
def _parse_name_mention(html: List[str], entity_text: str, user_id: TelegramID) -> bool:
user = u.User.get_by_tgid(user_id)
if user:
mxid = user.mxid
@@ -315,12 +320,12 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
message_link_match = message_link_regex.match(url)
if message_link_match:
group, msgid = message_link_match.groups()
msgid = int(msgid)
group, msgid_str = message_link_match.groups()
msgid = int(msgid_str)
portal = po.Portal.find_by_username(group)
if portal:
message = DBMessage.query.get((msgid, portal.tgid))
message = DBMessage.get_by_tgid(TelegramID(msgid), portal.tgid)
if message:
url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}"
@@ -328,6 +333,6 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
return False
def init_tg(context: "Context"):
def init_tg(context: "Context") -> None:
global should_highlight_edits
should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"]
+37 -34
View File
@@ -20,40 +20,6 @@ import struct
import re
# add_surrogates and remove_surrogates are unicode surrogate utility functions from Telethon.
# Licensed under the MIT license.
# https://github.com/LonamiWebs/Telethon/blob/master/telethon/extensions/markdown.py
def add_surrogates(text: Optional[str]) -> Optional[str]:
if text is None:
return None
return "".join("".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16-le")))
if (0x10000 <= ord(x) <= 0x10FFFF) else x for x in text)
def remove_surrogates(text: Optional[str]) -> Optional[str]:
if text is None:
return None
return text.encode("utf-16", "surrogatepass").decode("utf-16")
def trim_reply_fallback_text(text: str) -> str:
if not text.startswith("> ") or "\n" not in text:
return text
lines = text.split("\n")
while len(lines) > 0 and lines[0].startswith("> "):
lines.pop(0)
return "\n".join(lines)
html_reply_fallback_regex = re.compile("^<mx-reply>"
r"[\s\S]+?"
"</mx-reply>") # type: Pattern
def trim_reply_fallback_html(html: str) -> str:
return html_reply_fallback_regex.sub("", html)
def unicode_to_html(text: str, html: str, ctrl: str, tag: str) -> str:
if ctrl not in text:
return html
@@ -84,3 +50,40 @@ def unicode_to_html(text: str, html: str, ctrl: str, tag: str) -> str:
def html_to_unicode(text: str, ctrl: str) -> str:
return ctrl.join(text) + ctrl
# add_surrogates and remove_surrogates are unicode surrogate utility functions from Telethon.
# Licensed under the MIT license.
# https://github.com/LonamiWebs/Telethon/blob/7cce7aa3e4c6c7019a55530391b1761d33e5a04e/telethon/helpers.py
def add_surrogates(text: Optional[str]) -> Optional[str]:
if text is None:
return None
return "".join("".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16-le")))
if (0x10000 <= ord(x) <= 0x10FFFF) else x for x in text)
def remove_surrogates(text: Optional[str]) -> Optional[str]:
if text is None:
return None
return text.encode("utf-16", "surrogatepass").decode("utf-16")
# trim_reply_fallback_text, html_reply_fallback_regex and trim_reply_fallback_html are Matrix
# reply fallback utility functions.
# You may copy and use them under any OSI-approved license.
def trim_reply_fallback_text(text: str) -> str:
if not text.startswith("> ") or "\n" not in text:
return text
lines = text.split("\n")
while len(lines) > 0 and lines[0].startswith("> "):
lines.pop(0)
return "\n".join(lines)
html_reply_fallback_regex = re.compile("^<mx-reply>"
r"[\s\S]+?"
"</mx-reply>") # type: Pattern
def trim_reply_fallback_html(html: str) -> str:
return html_reply_fallback_regex.sub("", html)
+61 -42
View File
@@ -14,27 +14,31 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Tuple, Set, Match
from typing import Dict, List, Match, Optional, Set, Tuple, TYPE_CHECKING
import logging
import asyncio
import re
from mautrix_appservice import MatrixRequestError, IntentError
from .types import MatrixEvent, MatrixEventID, MatrixRoomID, MatrixUserID
from . import user as u, portal as po, puppet as pu, commands as com
if TYPE_CHECKING:
from .context import Context
class MatrixHandler:
log = logging.getLogger("mau.mx") # type: logging.Logger
def __init__(self, context):
self.az, self.db, self.config, _, self.tgbot = context
def __init__(self, context: 'Context') -> None:
self.az, self.db, self.config, _, self.tgbot = context.core
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
self.previously_typing = [] # type: List[str]
self.previously_typing = [] # type: List[MatrixUserID]
self.az.matrix_event_handler(self.handle_event)
async def init_as_bot(self):
async def init_as_bot(self) -> None:
displayname = self.config["appservice.bot_displayname"]
if displayname:
try:
@@ -50,7 +54,8 @@ class MatrixHandler:
except asyncio.TimeoutError:
self.log.exception("TimeoutError when trying to set avatar")
async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User):
async def handle_puppet_invite(self, room_id: MatrixRoomID, puppet: pu.Puppet, inviter: u.User
) -> None:
intent = puppet.default_mxid_intent
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in():
@@ -80,6 +85,7 @@ class MatrixHandler:
await intent.join_room(room_id)
portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
# TODO: if portal is None:
if portal.mxid:
try:
await intent.invite(portal.mxid, inviter.mxid)
@@ -95,13 +101,13 @@ class MatrixHandler:
portal.mxid = room_id
portal.save()
inviter.register_portal(portal)
await intent.send_notice(room_id, "po.Portal to private chat created.")
await intent.send_notice(room_id, "Portal to private chat created.")
else:
await intent.join_room(room_id)
await intent.send_notice(room_id, "This puppet will remain inactive until a "
"Telegram chat is created for this room.")
async def accept_bot_invite(self, room_id: str, inviter: u.User):
async def accept_bot_invite(self, room_id: MatrixRoomID, inviter: u.User) -> None:
tries = 0
while tries < 5:
try:
@@ -120,15 +126,19 @@ class MatrixHandler:
if not inviter.whitelisted:
await self.az.intent.send_notice(
room_id, text=None,
room_id, text="",
html="You are not whitelisted to use this bridge.<br/><br/>"
"If you are the owner of this bridge, see the "
"<code>bridge.permissions</code> section in your config file.")
await self.az.intent.leave_room(room_id)
async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str):
async def handle_invite(self, room_id: MatrixRoomID, user_id: MatrixUserID,
inviter_mxid: MatrixUserID) -> None:
self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started()
inviter = u.User.get_by_mxid(inviter_mxid)
if inviter is None:
self.log.exception("Failed to find user with Matrix ID {inviter_mxid}")
await inviter.ensure_started()
if user_id == self.az.bot_mxid:
return await self.accept_bot_invite(room_id, inviter)
elif not inviter.whitelisted:
@@ -150,7 +160,8 @@ class MatrixHandler:
# The rest can probably be ignored
async def handle_join(self, room_id: str, user_id: str, event_id: str):
async def handle_join(self, room_id: MatrixRoomID, user_id: MatrixUserID,
event_id: MatrixEventID) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started()
portal = po.Portal.get_by_mxid(room_id)
@@ -171,7 +182,8 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot:
await portal.join_matrix(user, event_id)
async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str):
async def handle_part(self, room_id: MatrixRoomID, user_id: MatrixUserID,
sender_mxid: MatrixUserID, event_id: MatrixEventID) -> None:
self.log.debug(f"{user_id} left {room_id}")
sender = u.User.get_by_mxid(sender_mxid, create=False)
@@ -184,8 +196,10 @@ class MatrixHandler:
return
puppet = pu.Puppet.get_by_mxid(user_id)
if sender and puppet:
await portal.leave_matrix(puppet, sender, event_id)
if puppet:
if sender:
await portal.kick_matrix(puppet, sender)
return
user = u.User.get_by_mxid(user_id, create=False)
if not user:
@@ -194,7 +208,7 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot:
await portal.leave_matrix(user, sender, event_id)
def is_command(self, message: dict) -> Tuple[bool, str]:
def is_command(self, message: Dict) -> Tuple[bool, str]:
text = message.get("body", "")
prefix = self.config["bridge.command_prefix"]
is_command = text.startswith(prefix)
@@ -202,9 +216,10 @@ class MatrixHandler:
text = text[len(prefix) + 1:]
return is_command, text
async def handle_message(self, room, sender, message, event_id):
async def handle_message(self, room: MatrixRoomID, sender_id: MatrixUserID, message: Dict,
event_id: MatrixEventID) -> None:
is_command, text = self.is_command(message)
sender = await u.User.get_by_mxid(sender).ensure_started()
sender = await u.User.get_by_mxid(sender_id).ensure_started()
if not sender.relaybot_whitelisted:
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
" u.User is not whitelisted.")
@@ -237,7 +252,8 @@ class MatrixHandler:
is_portal=portal is not None)
@staticmethod
async def handle_redaction(room_id: str, sender_mxid: str, event_id: str):
async def handle_redaction(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
event_id: MatrixEventID) -> None:
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if not sender.relaybot_whitelisted:
return
@@ -249,14 +265,16 @@ class MatrixHandler:
await portal.handle_matrix_deletion(sender, event_id)
@staticmethod
async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict):
async def handle_power_levels(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
new: Dict, old: Dict) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
@staticmethod
async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, content: dict):
async def handle_room_meta(evt_type: str, room_id: MatrixRoomID, sender_mxid: MatrixUserID,
content: dict) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
@@ -270,8 +288,8 @@ class MatrixHandler:
await handler(sender, content[content_key])
@staticmethod
async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str],
old_events: Set[str]):
async def handle_room_pin(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
new_events: Set[str], old_events: Set[str]) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
@@ -284,8 +302,8 @@ class MatrixHandler:
await portal.handle_matrix_pin(sender, None)
@staticmethod
async def handle_name_change(room_id: str, user_id: str, displayname: str,
prev_displayname: str, event_id: str):
async def handle_name_change(room_id: MatrixRoomID, user_id: MatrixUserID, displayname: str,
prev_displayname: str, event_id: MatrixEventID) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal or not portal.has_bot:
return
@@ -295,13 +313,14 @@ class MatrixHandler:
await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
@staticmethod
def parse_read_receipts(content: dict) -> Dict[str, str]:
def parse_read_receipts(content: Dict) -> Dict[MatrixUserID, MatrixEventID]:
return {user_id: event_id
for event_id, receipts in content.items()
for user_id in receipts.get("m.read", {})}
@staticmethod
async def handle_read_receipts(room_id: str, receipts: Dict[str, str]):
async def handle_read_receipts(room_id: MatrixRoomID,
receipts: Dict[MatrixUserID, MatrixEventID]) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
@@ -313,13 +332,13 @@ class MatrixHandler:
await portal.mark_read(user, event_id)
@staticmethod
async def handle_presence(user_id: str, presence: str):
async def handle_presence(user_id: MatrixUserID, presence: str) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in():
return
await user.set_presence(presence == "online")
async def handle_typing(self, room_id: str, now_typing: List[str]):
async def handle_typing(self, room_id: MatrixRoomID, now_typing: List[MatrixUserID]) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
@@ -338,38 +357,38 @@ class MatrixHandler:
self.previously_typing = now_typing
def filter_matrix_event(self, event: dict):
def filter_matrix_event(self, event: MatrixEvent) -> bool:
sender = event.get("sender", None)
if not sender:
return False
return (sender == self.az.bot_mxid
or pu.Puppet.get_id_from_mxid(sender) is not None)
async def try_handle_event(self, evt: dict):
async def try_handle_event(self, evt: MatrixEvent) -> None:
try:
await self.handle_event(evt)
except Exception:
self.log.exception("Error handling manually received Matrix event")
async def handle_event(self, evt: dict):
async def handle_event(self, evt: MatrixEvent) -> None:
if self.filter_matrix_event(evt):
return
self.log.debug("Received event: %s", evt)
evt_type = evt.get("type", "m.unknown") # type: str
room_id = evt.get("room_id", None) # type: str
event_id = evt.get("event_id", None) # type: str
sender = evt.get("sender", None) # type: str
content = evt.get("content", {}) # type: dict
room_id = evt.get("room_id", None) # type: Optional[MatrixRoomID]
event_id = evt.get("event_id", None) # type: Optional[MatrixEventID]
sender = evt.get("sender", None) # type: Optional[MatrixUserID]
content = evt.get("content", {}) # type: Dict
if evt_type == "m.room.member":
state_key = evt["state_key"] # type: str
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
state_key = evt["state_key"] # type: MatrixUserID
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: Dict
membership = content.get("membership", "") # type: str
prev_membership = prev_content.get("membership", "leave") # type: str
if membership == prev_membership:
match = re.compile("@(.+):(.+)").match(state_key) # type: Match
localpart = match.group(1) # type: str
displayname = content.get("displayname", localpart) # type: str
prev_displayname = prev_content.get("displayname", localpart) # type: str
mxid = match.group(0) # type: str
displayname = content.get("displayname", None) or mxid # type: str
prev_displayname = prev_content.get("displayname", None) or mxid # type: str
if displayname != prev_displayname:
await self.handle_name_change(room_id, state_key, displayname,
prev_displayname, event_id)
@@ -386,7 +405,7 @@ class MatrixHandler:
elif evt_type == "m.room.redaction":
await self.handle_redaction(room_id, sender, evt["redacts"])
elif evt_type == "m.room.power_levels":
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
prev_content = evt.get("unsigned", {}).get("prev_content", {})
await self.handle_power_levels(room_id, sender, evt["content"], prev_content)
elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
+316 -212
View File
File diff suppressed because it is too large Load Diff
+112 -86
View File
@@ -14,17 +14,20 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
from typing import Awaitable, Coroutine, Dict, List, Optional, Pattern, TYPE_CHECKING
from difflib import SequenceMatcher
import re
import logging
from enum import Enum
from aiohttp import ServerDisconnectedError
import asyncio
import logging
import re
from sqlalchemy import orm
from telethon.tl.types import UserProfilePhoto
from telethon.tl.types import UserProfilePhoto, User, FileLocation
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
from .types import MatrixUserID, TelegramID
from .db import Puppet as DBPuppet
from . import util
@@ -32,6 +35,10 @@ if TYPE_CHECKING:
from .matrix import MatrixHandler
from .config import Config
from .context import Context
from . import user as u
from .abstract_user import AbstractUser
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
config = None # type: Config
@@ -45,87 +52,101 @@ class Puppet:
mxid_regex = None # type: Pattern
username_template = None # type: str
hs_domain = None # type: str
cache = {} # type: Dict[str, Puppet]
cache = {} # type: Dict[TelegramID, Puppet]
by_custom_mxid = {} # type: Dict[str, Puppet]
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
is_registered=False, db_instance=None):
self.id = id
self.access_token = access_token
self.custom_mxid = custom_mxid
self.is_real_user = self.custom_mxid and self.access_token
self.default_mxid = self.get_mxid_from_id(self.id)
self.mxid = self.custom_mxid or self.default_mxid
def __init__(self,
id: TelegramID,
access_token: Optional[str] = None,
custom_mxid: Optional[MatrixUserID] = None,
username: Optional[str] = None,
displayname: Optional[str] = None,
displayname_source: Optional[TelegramID] = None,
photo_id: Optional[str] = None,
is_bot: bool = False,
is_registered: bool = False,
db_instance: Optional[DBPuppet] = None) -> None:
self.id = id # type: TelegramID
self.access_token = access_token # type: Optional[str]
self.custom_mxid = custom_mxid # type: Optional[MatrixUserID]
self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserID
self.username = username
self.displayname = displayname
self.displayname_source = displayname_source
self.photo_id = photo_id
self.is_bot = is_bot
self.is_registered = is_registered
self._db_instance = db_instance
self.username = username # type: Optional[str]
self.displayname = displayname # type: Optional[str]
self.displayname_source = displayname_source # type: Optional[TelegramID]
self.photo_id = photo_id # type: Optional[str]
self.is_bot = is_bot # type: bool
self.is_registered = is_registered # type: bool
self._db_instance = db_instance # type: Optional[DBPuppet]
self.default_mxid_intent = self.az.intent.user(self.default_mxid)
self.intent = None # type: IntentAPI
self.refresh_intents()
self.intent = self._fresh_intent() # type: IntentAPI
self.cache[id] = self
if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self
@property
def tgid(self):
def mxid(self) -> MatrixUserID:
return self.custom_mxid or self.default_mxid
@property
def tgid(self) -> TelegramID:
return self.id
@property
def is_real_user(self) -> bool:
""" Is True when the puppet is a real Matrix user. """
return bool(self.custom_mxid and self.access_token)
@staticmethod
async def is_logged_in():
async def is_logged_in() -> bool:
""" Is True if the puppet is logged in. """
return True
# region Custom puppet management
def refresh_intents(self):
self.is_real_user = self.custom_mxid and self.access_token
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
def _fresh_intent(self) -> IntentAPI:
return (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
async def switch_mxid(self, access_token, mxid):
async def switch_mxid(self, access_token: Optional[str],
mxid: Optional[MatrixUserID]) -> PuppetError:
prev_mxid = self.custom_mxid
self.custom_mxid = mxid
self.access_token = access_token
self.refresh_intents()
self.intent = self._fresh_intent()
err = await self.init_custom_mxid()
if err != 0:
if err != PuppetError.Success:
return err
try:
del self.by_custom_mxid[prev_mxid]
del self.by_custom_mxid[prev_mxid] # type: ignore
except KeyError:
pass
self.mxid = self.custom_mxid or self.default_mxid
if self.mxid != self.default_mxid:
self.by_custom_mxid[self.mxid] = self
await self.leave_rooms_with_default_user()
self.save()
return 0
return PuppetError.Success
async def init_custom_mxid(self):
async def init_custom_mxid(self) -> PuppetError:
if not self.is_real_user:
return 0
return PuppetError.Success
mxid = await self.intent.whoami()
if not mxid or mxid != self.custom_mxid:
self.custom_mxid = None
self.access_token = None
self.refresh_intents()
self.intent = self._fresh_intent()
if mxid != self.custom_mxid:
return 2
return 1
return PuppetError.OnlyLoginSelf
return PuppetError.InvalidAccessToken
if config["bridge.sync_with_custom_puppets"]:
asyncio.ensure_future(self.sync(), loop=self.loop)
return 0
return PuppetError.Success
async def leave_rooms_with_default_user(self):
async def leave_rooms_with_default_user(self) -> None:
for room_id in await self.default_mxid_intent.get_joined_rooms():
try:
await self.default_mxid_intent.leave_room(room_id)
@@ -159,7 +180,7 @@ class Puppet:
},
})
def filter_events(self, events):
def filter_events(self, events: List[Dict]) -> List:
new_events = []
for event in events:
evt_type = event.get("type", None)
@@ -186,28 +207,28 @@ class Puppet:
new_events.append(event)
return new_events
def handle_sync(self, presence, ephemeral):
presence = [self.mx.try_handle_event(event) for event in presence]
def handle_sync(self, presence: List, ephemeral: Dict) -> None:
presence_events = [self.mx.try_handle_event(event) for event in presence]
for room_id, events in ephemeral.items():
for event in events:
event["room_id"] = room_id
ephemeral = [self.mx.try_handle_event(event)
for events in ephemeral.values()
for event in self.filter_events(events)]
ephemeral_events = [self.mx.try_handle_event(event)
for events in ephemeral.values()
for event in self.filter_events(events)]
events = ephemeral + presence
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
coro = asyncio.gather(*events, loop=self.loop)
asyncio.ensure_future(coro, loop=self.loop)
async def sync(self):
async def sync(self) -> None:
try:
await self._sync()
except Exception:
self.log.exception("Fatal error syncing")
async def _sync(self):
async def _sync(self) -> None:
if not self.is_real_user:
self.log.warning("Called sync() for non-custom puppet.")
return
@@ -220,16 +241,17 @@ class Puppet:
while access_token_at_start == self.access_token:
try:
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch,
set_presence="offline")
set_presence="offline") # type: Dict
errors = 0
if next_batch is not None:
presence = sync_resp.get("presence", {}).get("events", [])
presence = sync_resp.get("presence", {}).get("events", []) # type: List
ephemeral = {room: data.get("ephemeral", {}).get("events", [])
for room, data
in sync_resp.get("rooms", {}).get("join", {}).items()}
in sync_resp.get("rooms", {}).get("join", {}).items()
} # type: Dict
self.handle_sync(presence, ephemeral)
next_batch = sync_resp.get("next_batch", None)
except MatrixRequestError as e:
except (MatrixRequestError, ServerDisconnectedError) as e:
wait = min(errors, 11) ** 2
self.log.warning(f"Syncer for {custom_mxid} errored: {e}. "
f"Waiting for {wait} seconds...")
@@ -241,25 +263,25 @@ class Puppet:
# region DB conversion
@property
def db_instance(self):
def db_instance(self) -> DBPuppet:
if not self._db_instance:
self._db_instance = self.new_db_instance()
return self._db_instance
def new_db_instance(self):
def new_db_instance(self) -> DBPuppet:
return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid,
username=self.username, displayname=self.displayname,
displayname_source=self.displayname_source, photo_id=self.photo_id,
is_bot=self.is_bot, matrix_registered=self.is_registered)
@classmethod
def from_db(cls, db_puppet):
def from_db(cls, db_puppet: DBPuppet) -> 'Puppet':
return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid,
db_puppet.username, db_puppet.displayname, db_puppet.displayname_source,
db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered,
db_instance=db_puppet)
def save(self):
def save(self) -> None:
self.db_instance.access_token = self.access_token
self.db_instance.custom_mxid = self.custom_mxid
self.db_instance.username = self.username
@@ -272,16 +294,16 @@ class Puppet:
# endregion
# region Info updating
def similarity(self, query):
def similarity(self, query: str) -> int:
username_similarity = (SequenceMatcher(None, self.username, query).ratio()
if self.username else 0)
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
if self.displayname else 0)
similarity = max(username_similarity, displayname_similarity)
return round(similarity * 1000) / 10
return int(round(similarity * 100))
@staticmethod
def get_displayname(info, enable_format=True):
def get_displayname(info: User, enable_format: bool = True) -> str:
data = {
"phone number": info.phone if hasattr(info, "phone") else None,
"username": info.username,
@@ -308,7 +330,7 @@ class Puppet:
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
displayname=name)
async def update_info(self, source, info):
async def update_info(self, source: 'AbstractUser', info: User) -> None:
changed = False
if self.username != info.username:
self.username = info.username
@@ -323,12 +345,12 @@ class Puppet:
if changed:
self.save()
async def update_displayname(self, source, info):
async def update_displayname(self, source: 'AbstractUser', info: User) -> bool:
ignore_source = (not source.is_relaybot
and self.displayname_source is not None
and self.displayname_source != source.tgid)
if ignore_source:
return
return False
displayname = self.get_displayname(info)
if displayname != self.displayname:
@@ -339,8 +361,9 @@ class Puppet:
elif source.is_relaybot or self.displayname_source is None:
self.displayname_source = source.tgid
return True
return False
async def update_avatar(self, source, photo):
async def update_avatar(self, source: 'AbstractUser', photo: FileLocation) -> bool:
photo_id = f"{photo.volume_id}-{photo.local_id}"
if self.photo_id != photo_id:
file = await util.transfer_file_to_matrix(self.db, source.client,
@@ -355,7 +378,7 @@ class Puppet:
# region Getters
@classmethod
def get(cls, tgid, create=True) -> "Optional[Puppet]":
def get(cls, tgid: TelegramID, create: bool = True) -> Optional['Puppet']:
try:
return cls.cache[tgid]
except KeyError:
@@ -374,12 +397,15 @@ class Puppet:
return None
@classmethod
def get_by_mxid(cls, mxid, create=True) -> "Optional[Puppet]":
def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['Puppet']:
tgid = cls.get_id_from_mxid(mxid)
return cls.get(tgid, create) if tgid else None
if tgid:
return cls.get(tgid, create)
return None
@classmethod
def get_by_custom_mxid(cls, mxid):
def get_by_custom_mxid(cls, mxid: MatrixUserID) -> Optional['Puppet']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -396,25 +422,25 @@ class Puppet:
return None
@classmethod
def get_all_with_custom_mxid(cls):
def get_all_with_custom_mxid(cls) -> List['Puppet']:
return [cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
@classmethod
def get_id_from_mxid(cls, mxid):
def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]:
match = cls.mxid_regex.match(mxid)
if match:
return int(match.group(1))
return TelegramID(int(match.group(1)))
return None
@classmethod
def get_mxid_from_id(cls, tgid):
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
def get_mxid_from_id(cls, tgid: TelegramID) -> MatrixUserID:
return MatrixUserID(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
@classmethod
def find_by_username(cls, username) -> "Optional[Puppet]":
def find_by_username(cls, username: str) -> Optional['Puppet']:
if not username:
return None
@@ -422,14 +448,14 @@ class Puppet:
if puppet.username and puppet.username.lower() == username.lower():
return puppet
puppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
if puppet:
return cls.from_db(puppet)
dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
if dbpuppet:
return cls.from_db(dbpuppet)
return None
@classmethod
def find_by_displayname(cls, displayname) -> "Optional[Puppet]":
def find_by_displayname(cls, displayname: str) -> Optional['Puppet']:
if not displayname:
return None
@@ -437,20 +463,20 @@ class Puppet:
if puppet.displayname and puppet.displayname == displayname:
return puppet
puppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
if puppet:
return cls.from_db(puppet)
dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
if dbpuppet:
return cls.from_db(dbpuppet)
return None
# endregion
def init(context: "Context") -> List[Awaitable[int]]:
def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"]
Puppet.mxid_regex = re.compile(
f"@{Puppet.username_template.format(userid='(.+)')}:{Puppet.hs_domain}")
f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}")
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
@@ -32,15 +32,15 @@ TelematrixBase.metadata.bind = telematrix_db_engine
chat_links = telematrix.query(ChatLink).all()
tg_users = telematrix.query(TgUser).all()
mx_users = telematrix.query(MatrixUser).all()
messages = telematrix.query(TMMessage).all()
tm_messages = telematrix.query(TMMessage).all()
telematrix.close()
telematrix_db_engine.dispose()
portals = {}
chats = {}
messages = {}
puppets = {}
portals = {} # Dict[int, Portal]
chats = {} # Dict[int, BotChat]
messages = {} # Dict[str, Message]
puppets = {} # Dict[int, Puppet]
for chat_link in chat_links:
if type(chat_link.tg_room) is str:
@@ -65,11 +65,12 @@ for chat_link in chat_links:
portals[chat_link.tg_room] = portal
chats[tgid] = bot_chat
for tm_msg in messages:
for tm_msg in tm_messages:
try:
portal = portals[tm_msg.tg_group_id]
except KeyError:
print("Found message entry %d in unlinked chat %d, ignoring..." % (tm_msg.tg_message_id, tm_msg.tg_group_id))
print("Found message entry %d in unlinked chat %d, ignoring..." % (tm_msg.tg_message_id,
tm_msg.tg_group_id))
continue
tg_space = portal.tgid if portal.peer_type == "channel" else args.bot_id
message = Message(mxid=tm_msg.matrix_event_id, mx_room=tm_msg.matrix_room_id,
@@ -77,7 +78,8 @@ for tm_msg in messages:
messages[tm_msg.matrix_event_id] = message
for user in tg_users:
puppets[user.tg_id] = Puppet(id=user.tg_id, displayname=user.name, displayname_source=args.bot_id)
puppets[user.tg_id] = Puppet(id=user.tg_id, displayname=user.name,
displayname_source=args.bot_id)
for k, v in portals.items():
mxtg.add(v)
@@ -5,7 +5,7 @@ Base = declarative_base()
class ChatLink(Base):
__tablename__ = 'chat_link'
__tablename__ = "chat_link"
id = sa.Column(sa.Integer, primary_key=True)
matrix_room = sa.Column(sa.String)
@@ -14,7 +14,7 @@ class ChatLink(Base):
class TgUser(Base):
__tablename__ = 'tg_user'
__tablename__ = "tg_user"
id = sa.Column(sa.Integer, primary_key=True)
tg_id = sa.Column(sa.BigInteger)
@@ -23,7 +23,7 @@ class TgUser(Base):
class MatrixUser(Base):
__tablename__ = 'matrix_user'
__tablename__ = "matrix_user"
id = sa.Column(sa.Integer, primary_key=True)
matrix_id = sa.Column(sa.String)
+15 -13
View File
@@ -20,37 +20,39 @@ from sqlalchemy import orm
from mautrix_appservice import StateStore
from .types import MatrixUserID, MatrixRoomID
from . import puppet as pu
from .db import RoomState, UserProfile
class SQLStateStore(StateStore):
def __init__(self, db):
def __init__(self, db: orm.Session) -> None:
super().__init__()
self.db = db # type: orm.Session
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
self.room_state_cache = {} # type: Dict[str, RoomState]
@staticmethod
def is_registered(user: str) -> bool:
def is_registered(user: MatrixUserID) -> bool:
puppet = pu.Puppet.get_by_mxid(user)
return puppet.is_registered if puppet else False
@staticmethod
def registered(user: str):
def registered(user: MatrixUserID) -> None:
puppet = pu.Puppet.get_by_mxid(user)
if puppet:
puppet.is_registered = True
puppet.save()
def update_state(self, event: dict):
def update_state(self, event: Dict) -> None:
event_type = event["type"]
if event_type == "m.room.power_levels":
self.set_power_levels(event["room_id"], event["content"])
elif event_type == "m.room.member":
self.set_member(event["room_id"], event["state_key"], event["content"])
def _get_user_profile(self, room_id: str, user_id: str, create: bool = True) -> UserProfile:
def _get_user_profile(self, room_id: MatrixRoomID, user_id: MatrixUserID, create: bool = True
) -> UserProfile:
key = (room_id, user_id)
try:
return self.profile_cache[key]
@@ -67,22 +69,22 @@ class SQLStateStore(StateStore):
self.profile_cache[key] = profile
return profile
def get_member(self, room: str, user: str) -> dict:
def get_member(self, room: MatrixRoomID, user: MatrixUserID) -> Dict:
return self._get_user_profile(room, user).dict()
def set_member(self, room: str, user: str, member: dict):
def set_member(self, room: MatrixRoomID, user: MatrixUserID, member: Dict) -> None:
profile = self._get_user_profile(room, user)
profile.membership = member.get("membership", profile.membership or "leave")
profile.displayname = member.get("displayname", profile.displayname)
profile.avatar_url = member.get("avatar_url", profile.avatar_url)
self.db.commit()
def set_membership(self, room: str, user: str, membership: str):
def set_membership(self, room: MatrixRoomID, user: MatrixUserID, membership: str) -> None:
self.set_member(room, user, {
"membership": membership,
})
def _get_room_state(self, room_id: str, create: bool = True) -> RoomState:
def _get_room_state(self, room_id: MatrixRoomID, create: bool = True) -> RoomState:
try:
return self.room_state_cache[room_id]
except KeyError:
@@ -96,13 +98,13 @@ class SQLStateStore(StateStore):
self.room_state_cache[room_id] = room
return room
def has_power_levels(self, room: str) -> bool:
def has_power_levels(self, room: MatrixRoomID) -> bool:
return self._get_room_state(room).has_power_levels
def get_power_levels(self, room: str) -> dict:
def get_power_levels(self, room: MatrixRoomID) -> Dict:
return self._get_room_state(room).power_levels
def set_power_level(self, room: str, user: str, level: int):
def set_power_level(self, room: MatrixRoomID, user: MatrixUserID, level: int) -> None:
room_state = self._get_room_state(room)
power_levels = room_state.power_levels
if not power_levels:
@@ -114,7 +116,7 @@ class SQLStateStore(StateStore):
room_state.power_levels = power_levels
self.db.commit()
def set_power_levels(self, room: str, content: dict):
def set_power_levels(self, room: MatrixRoomID, content: Dict) -> None:
state = self._get_room_state(room)
state.power_levels = content
self.db.commit()
+5 -1
View File
@@ -14,9 +14,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Union, Optional
from telethon import TelegramClient, utils
from telethon.tl.functions.messages import SendMediaRequest
from telethon.tl.types import *
from telethon.tl.types import (
InputMediaUploadedDocument, InputMediaUploadedPhoto, TypeDocumentAttribute, TypeInputMedia,
TypeInputPeer, TypeMessageEntity, TypeMessageMedia, TypePeer)
from telethon.tl import custom
+9
View File
@@ -0,0 +1,9 @@
from typing import Dict, NewType
MatrixUserID = NewType('MatrixUserID', str)
MatrixRoomID = NewType('MatrixRoomID', str)
MatrixEventID = NewType('MatrixEventID', str)
MatrixEvent = NewType('MatrixEvent', Dict)
TelegramID = NewType('TelegramID', int)
+76 -56
View File
@@ -14,18 +14,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING
from typing import Coroutine, Dict, List, Match, NewType, Optional, Tuple, cast, TYPE_CHECKING
import logging
import asyncio
import re
from telethon.tl.types import *
from telethon.tl.types import (
TypeUpdate, UpdateNewMessage, UpdateNewChannelMessage, PeerUser,
UpdateShortChatMessage, UpdateShortMessage)
from telethon.tl.types import User as TLUser
from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError
from .types import MatrixUserID, TelegramID
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
from .abstract_user import AbstractUser
from . import portal as po, puppet as pu
@@ -36,7 +39,7 @@ if TYPE_CHECKING:
config = None # type: Config
SearchResults = List[Tuple["pu.Puppet", int]]
SearchResult = NewType('SearchResult', Tuple['pu.Puppet', int])
class User(AbstractUser):
@@ -44,23 +47,26 @@ class User(AbstractUser):
by_mxid = {} # type: Dict[str, User]
by_tgid = {} # type: Dict[int, User]
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
db_instance: Optional[DBUser] = None):
def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
username: Optional[str] = None, phone: Optional[str] = None,
db_contacts: Optional[List[DBContact]] = None,
saved_contacts: int = 0, is_bot: bool = False,
db_portals: Optional[List[DBPortal]] = None,
db_instance: Optional[DBUser] = None) -> None:
super().__init__()
self.mxid = mxid # type: str
self.tgid = tgid # type: int
self.mxid = mxid # type: MatrixUserID
self.tgid = tgid # type: TelegramID
self.is_bot = is_bot # type: bool
self.username = username # type: str
self.phone = phone # type: str
self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
self.db_portals = db_portals # type: List[DBPortal]
self._db_instance = db_instance # type: DBUser
self.db_portals = db_portals or [] # type: List[DBPortal]
self._db_instance = db_instance # type: Optional[DBUser]
self.command_status = None # type: dict
self.command_status = None # type: Dict
(self.relaybot_whitelisted,
self.whitelisted,
@@ -82,6 +88,10 @@ class User(AbstractUser):
match = re.compile("@(.+):(.+)").match(self.mxid) # type: Match
return match.group(1)
@property
def human_tg_id(self) -> str:
return f"@{self.username}" if self.username else f"+{self.phone}" or None
# TODO replace with proper displayname getting everywhere
@property
def displayname(self) -> str:
@@ -93,7 +103,7 @@ class User(AbstractUser):
for puppet in self.contacts]
@db_contacts.setter
def db_contacts(self, contacts: List[DBContact]):
def db_contacts(self, contacts: List[DBContact]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
@property
@@ -101,10 +111,12 @@ class User(AbstractUser):
return [portal.db_instance for portal in self.portals.values() if not portal.deleted]
@db_portals.setter
def db_portals(self, portals: List[DBPortal]):
self.portals = {(portal.tgid, portal.tg_receiver):
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
for portal in portals} if portals else {}
def db_portals(self, portals: List[DBPortal]) -> None:
self.portals = {
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
portal.tg_receiver)
for portal in portals
} if portals else {}
# region Database conversion
@@ -116,18 +128,19 @@ class User(AbstractUser):
def new_db_instance(self) -> DBUser:
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
contacts=self.db_contacts, saved_contacts=self.saved_contacts,
portals=self.db_portals)
def save(self):
def save(self) -> None:
self.db_instance.tgid = self.tgid
self.db_instance.username = self.username
self.db_instance.tg_username = self.username
self.db_instance.tg_phone = self.phone
self.db_instance.contacts = self.db_contacts
self.db_instance.saved_contacts = self.saved_contacts or 0
self.db_instance.saved_contacts = self.saved_contacts
self.db_instance.portals = self.db_portals
self.db.commit()
def delete(self):
def delete(self) -> None:
try:
del self.by_mxid[self.mxid]
del self.by_tgid[self.tgid]
@@ -138,14 +151,15 @@ class User(AbstractUser):
self.db.commit()
@classmethod
def from_db(cls, db_user: DBUser) -> "User":
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
def from_db(cls, db_user: DBUser) -> 'User':
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.tg_phone,
db_user.contacts, db_user.saved_contacts, False, db_user.portals,
db_instance=db_user)
# endregion
# region Telegram connection management
async def start(self, delete_unless_authenticated: bool = False) -> "User":
async def start(self, delete_unless_authenticated: bool = False) -> 'User':
await super().start()
if await self.is_logged_in():
self.log.debug(f"Ensuring post_login() for {self.name}")
@@ -156,7 +170,7 @@ class User(AbstractUser):
self.client.session.delete()
return self
async def post_login(self, info: TLUser = None):
async def post_login(self, info: TLUser = None) -> None:
try:
await self.update_info(info)
if not self.is_bot:
@@ -167,9 +181,9 @@ class User(AbstractUser):
except Exception:
self.log.exception("Failed to run post-login functions for %s", self.mxid)
async def update(self, update: TypeUpdate):
async def update(self, update: TypeUpdate) -> bool:
if not self.is_bot:
return
return False
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
message = update.message
@@ -179,26 +193,28 @@ class User(AbstractUser):
else:
portal = po.Portal.get_by_entity(message.to_id, receiver_id=self.tgid)
elif isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
else:
return
return False
self.register_portal(portal)
if portal:
self.register_portal(portal)
return True
# endregion
# region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
return super().ensure_started(even_if_no_session)
def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']:
return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session))
def set_presence(self, online: bool = True):
if self.is_bot:
return
return self.client(UpdateStatusRequest(offline=not online))
async def set_presence(self, online: bool = True) -> None:
if not self.is_bot:
await self.client(UpdateStatusRequest(offline=not online))
async def update_info(self, info: TLUser = None):
async def update_info(self, info: TLUser = None) -> None:
info = info or await self.client.get_me()
changed = False
if self.is_bot != info.bot:
@@ -207,13 +223,16 @@ class User(AbstractUser):
if self.username != info.username:
self.username = info.username
changed = True
if self.phone != info.phone:
self.phone = info.phone
changed = True
if self.tgid != info.id:
self.tgid = info.id
self.by_tgid[self.tgid] = self
if changed:
self.save()
async def log_out(self):
async def log_out(self) -> bool:
puppet = pu.Puppet.get(self.tgid)
if puppet.is_real_user:
await puppet.switch_mxid(None, None)
@@ -241,28 +260,29 @@ class User(AbstractUser):
return True
def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
) -> SearchResults:
results = [] # type: SearchResults
) -> List[SearchResult]:
results = [] # type: List[SearchResult]
for contact in self.contacts:
similarity = contact.similarity(query)
if similarity >= min_similarity:
results.append((contact, similarity))
results.append(SearchResult((contact, similarity)))
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults:
async def _search_remote(self, query: str, max_results: int = 5) -> List[SearchResult]:
if len(query) < 5:
return []
server_results = await self.client(SearchRequest(q=query, limit=max_results))
results = [] # type: SearchResults
results = [] # type: List[SearchResult]
for user in server_results.users:
puppet = pu.Puppet.get(user.id)
await puppet.update_info(self, user)
results.append((puppet, puppet.similarity(query)))
results.append(SearchResult((puppet, puppet.similarity(query))))
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]:
async def search(self, query: str, force_remote: bool = False
) -> Tuple[List[SearchResult], bool]:
if force_remote:
return await self._search_remote(query), True
@@ -272,7 +292,7 @@ class User(AbstractUser):
return await self._search_remote(query), True
async def sync_dialogs(self, synchronous_create: bool = False):
async def sync_dialogs(self, synchronous_create: bool = False) -> None:
creators = []
for entity in await self.get_dialogs(limit=30):
portal = po.Portal.get_by_entity(entity)
@@ -283,7 +303,7 @@ class User(AbstractUser):
self.save()
await asyncio.gather(*creators, loop=self.loop)
def register_portal(self, portal: po.Portal):
def register_portal(self, portal: po.Portal) -> None:
try:
if self.portals[portal.tgid_full] == portal:
return
@@ -292,7 +312,7 @@ class User(AbstractUser):
self.portals[portal.tgid_full] = portal
self.save()
def unregister_portal(self, portal: po.Portal):
def unregister_portal(self, portal: po.Portal) -> None:
try:
del self.portals[portal.tgid_full]
self.save()
@@ -309,7 +329,7 @@ class User(AbstractUser):
acc = (acc * 20261 + id) & 0xffffffff
return acc & 0x7fffffff
async def sync_contacts(self):
async def sync_contacts(self) -> None:
response = await self.client(GetContactsRequest(hash=self._hash_contacts()))
if isinstance(response, ContactsNotModified):
return
@@ -326,7 +346,7 @@ class User(AbstractUser):
# region Class instance lookup
@classmethod
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['User']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -349,7 +369,7 @@ class User(AbstractUser):
return None
@classmethod
def get_by_tgid(cls, tgid: int) -> "Optional[User]":
def get_by_tgid(cls, tgid: int) -> Optional['User']:
try:
return cls.by_tgid[tgid]
except KeyError:
@@ -363,7 +383,7 @@ class User(AbstractUser):
return None
@classmethod
def find_by_username(cls, username: str) -> "Optional[User]":
def find_by_username(cls, username: str) -> Optional['User']:
if not username:
return None
@@ -379,7 +399,7 @@ class User(AbstractUser):
# endregion
def init(context: "Context") -> List[Awaitable[User]]:
def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser]
global config
config = context.config
+1
View File
@@ -1,3 +1,4 @@
from .file_transfer import transfer_file_to_matrix, convert_image
from .format_duration import format_duration
from .signed_token import sign_token, verify_token
from .recursive_dict import recursive_del, recursive_set, recursive_get
+2 -1
View File
@@ -27,7 +27,8 @@ from sqlalchemy.orm.exc import FlushError
from telethon.tl.types import (Document, FileLocation, InputFileLocation,
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
from telethon.errors import *
from telethon.errors import (AuthBytesInvalidError, AuthKeyInvalidError, LocationInvalidError,
SecurityError)
from mautrix_appservice import IntentAPI
from ..tgclient import MautrixTelegramClient
+2 -2
View File
@@ -17,10 +17,10 @@
def format_duration(seconds: int) -> str:
def pluralize(count, singular):
def pluralize(count: int, singular: str) -> str:
return singular if count == 1 else singular + "s"
def include(count, word):
def include(count: int, word: str) -> str:
return f"{count} {pluralize(count, word)}" if count > 0 else ""
minutes, seconds = divmod(seconds, 60)
+54
View File
@@ -0,0 +1,54 @@
# -*- coding: future_fstrings -*-
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Any
from ..config import DictWithRecursion
def recursive_set(data: Dict[str, Any], key: str, value: Any) -> bool:
key, next_key = DictWithRecursion._parse_key(key)
if next_key is not None:
if key not in data:
data[key] = {}
next_data = data.get(key, {})
if not isinstance(next_data, dict):
return False
return recursive_set(next_data, next_key, value)
data[key] = value
return True
def recursive_get(data: Dict[str, Any], key: str) -> Any:
key, next_key = DictWithRecursion._parse_key(key)
if next_key is not None:
next_data = data.get(key, None)
if not next_data:
return None
return recursive_get(next_data, next_key)
return data.get(key, None)
def recursive_del(data: Dict[str, any], key: str) -> bool:
key, next_key = DictWithRecursion._parse_key(key)
if next_key is not None:
if key not in data:
return False
next_data = data.get(key, {})
return recursive_del(next_data, next_key)
if key in data:
del data[key]
return True
return False
+6 -6
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import Dict, Optional
import json
import base64
import hashlib
@@ -28,13 +28,13 @@ def _get_checksum(key: str, payload: bytes) -> str:
return checksum
def sign_token(key: str, payload: dict) -> str:
payload = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
checksum = _get_checksum(key, payload)
return f"{checksum}:{payload.decode('utf-8')}"
def sign_token(key: str, payload: Dict) -> str:
payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
checksum = _get_checksum(key, payload_b64)
return f"{checksum}:{payload_b64.decode('utf-8')}"
def verify_token(key: str, data: str) -> Optional[dict]:
def verify_token(key: str, data: str) -> Optional[Dict]:
if not data:
return None
+43 -27
View File
@@ -15,6 +15,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from abc import abstractmethod
from typing import Optional
from aiohttp import web
import abc
import asyncio
import logging
@@ -23,27 +26,30 @@ from telethon.errors import *
from ...commands.auth import enter_password
from ...util import format_duration
from ...puppet import Puppet
from ...puppet import Puppet, PuppetError
from ...user import User
class AuthAPI(abc.ABC):
log = logging.getLogger("mau.web.auth")
log = logging.getLogger("mau.web.auth") # type: logging.Logger
def __init__(self, loop):
def __init__(self, loop: asyncio.AbstractEventLoop):
self.loop = loop # type: asyncio.AbstractEventLoop
@abstractmethod
def get_login_response(self, status=200, state="", username="", mxid="", message="", error="",
errcode=""):
def get_login_response(self, status: int = 200, state: str = "", username: str = "",
phone: str = "", human_tg_id: str = "", mxid: str = "",
message: str = "", error: str = "", errcode: str = "") -> web.Response:
raise NotImplementedError()
@abstractmethod
def get_mx_login_response(self, status=200, state="", username="", mxid="", message="",
error="", errcode=""):
def get_mx_login_response(self, status: int = 200, state: str = "", username: str = "",
phone: str = "", human_tg_id: str = "", mxid: str = "",
message: str = "", error: str = "", errcode: str = ""
) -> web.Response:
raise NotImplementedError()
async def post_matrix_token(self, user: User, token):
async def post_matrix_token(self, user: User, token: str) -> web.Response:
puppet = Puppet.get(user.tgid)
if puppet.is_real_user:
return self.get_mx_login_response(state="already-logged-in", status=409,
@@ -51,20 +57,21 @@ class AuthAPI(abc.ABC):
"account.", errcode="already-logged-in")
resp = await puppet.switch_mxid(token, user.mxid)
if resp == 2:
if resp == PuppetError.OnlyLoginSelf:
return self.get_mx_login_response(status=403, errcode="only-login-self",
error="You can only log in as your own Matrix user.")
elif resp == 1:
elif resp == PuppetError.InvalidAccessToken:
return self.get_mx_login_response(status=401, errcode="invalid-access-token",
error="Failed to verify access token.")
assert resp == PuppetError.Success, "Encountered an unhandled PuppetError."
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
async def post_matrix_password(self, user, password):
async def post_matrix_password(self, user: User, password: str) -> web.Response:
return self.get_mx_login_response(mxid=user.mxid, status=501, error="Not yet implemented",
errcode="not-yet-implemented")
async def post_login_phone(self, user, phone):
async def post_login_phone(self, user: User, phone: str) -> web.Response:
try:
await user.client.sign_in(phone or "+123")
return self.get_login_response(mxid=user.mxid, state="code", status=200,
@@ -101,14 +108,22 @@ class AuthAPI(abc.ABC):
errcode="unknown_error",
error="Internal server error while requesting code.")
async def post_login_token(self, user, token):
async def postprocess_login(self, user: User, user_info) -> None:
existing_user = User.get_by_tgid(user_info.id)
if existing_user and existing_user != user:
await existing_user.log_out()
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
if user.command_status and user.command_status["action"] == "Login":
user.command_status = None
async def post_login_token(self, user: User, token: str) -> web.Response:
try:
user_info = await user.client.sign_in(bot_token=token)
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
if user.command_status and user.command_status["action"] == "Login":
user.command_status = None
await self.postprocess_login(user, user_info)
return self.get_login_response(mxid=user.mxid, state="logged-in", status=200,
username=user_info.username)
username=user_info.username, phone=None,
human_tg_id=f"@{user_info.username}")
except AccessTokenInvalidError:
return self.get_login_response(mxid=user.mxid, state="token", status=401,
errcode="bot_token_invalid",
@@ -122,14 +137,15 @@ class AuthAPI(abc.ABC):
return self.get_login_response(mxid=user.mxid, state="token", status=500,
error="Internal server error while sending token.")
async def post_login_code(self, user, code, password_in_data):
async def post_login_code(self, user: User, code: int, password_in_data: bool
) -> Optional[web.Response]:
try:
user_info = await user.client.sign_in(code=code)
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
if user.command_status and user.command_status["action"] == "Login":
user.command_status = None
await self.postprocess_login(user, user_info)
human_tg_id = f"@{user_info.username}" if user_info.username else f"+{user_info.phone}"
return self.get_login_response(mxid=user.mxid, state="logged-in", status=200,
username=user_info.username)
username=user_info.username, phone=user_info.phone,
human_tg_id=human_tg_id)
except PhoneCodeInvalidError:
return self.get_login_response(mxid=user.mxid, state="code", status=401,
errcode="phone_code_invalid",
@@ -155,14 +171,14 @@ class AuthAPI(abc.ABC):
errcode="unknown_error",
error="Internal server error while sending code.")
async def post_login_password(self, user, password):
async def post_login_password(self, user: User, password: str) -> web.Response:
try:
user_info = await user.client.sign_in(password=password)
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
if user.command_status and user.command_status["action"] == "Login (password entry)":
user.command_status = None
await self.postprocess_login(user, user_info)
human_tg_id = f"@{user_info.username}" if user_info.username else f"+{user_info.phone}"
return self.get_login_response(mxid=user.mxid, state="logged-in", status=200,
username=user_info.username)
username=user_info.username, phone=user_info.phone,
human_tg_id=human_tg_id)
except PasswordEmptyError:
return self.get_login_response(mxid=user.mxid, state="password", status=400,
errcode="password_empty",
+39 -20
View File
@@ -15,7 +15,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from typing import Tuple, Optional, Callable, Awaitable, TYPE_CHECKING
from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING
import asyncio
import logging
import json
@@ -24,6 +24,7 @@ from telethon.utils import get_peer_id, resolve_id
from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat
from mautrix_appservice import AppService, MatrixRequestError, IntentError
from ...types import MatrixUserID, TelegramID
from ...user import User
from ...portal import Portal
from ...commands.portal import user_has_power_level, get_initial_state
@@ -34,20 +35,21 @@ if TYPE_CHECKING:
class ProvisioningAPI(AuthAPI):
log = logging.getLogger("mau.web.provisioning")
log = logging.getLogger("mau.web.provisioning") # type: logging.Logger
def __init__(self, context: "Context"):
def __init__(self, context: "Context") -> None:
super().__init__(context.loop)
self.secret = context.config["appservice.provisioning.shared_secret"]
self.secret = context.config["appservice.provisioning.shared_secret"] # type: str
self.az = context.az # type: AppService
self.context = context # type: Context
self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware])
self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware]
) # type: web.Application
portal_prefix = "/portal/{mxid:![^/]+}"
self.app.router.add_route("GET", f"{portal_prefix}", self.get_portal_by_mxid)
self.app.router.add_route("GET", "/portal/{tgid:-[0-9]+}", self.get_portal_by_tgid)
self.app.router.add_route("POST", portal_prefix + "/connect/{chat_id:[0-9]+}",
self.app.router.add_route("POST", portal_prefix + "/connect/{chat_id:-[0-9]+}",
self.connect_chat)
self.app.router.add_route("POST", f"{portal_prefix}/create", self.create_chat)
self.app.router.add_route("POST", f"{portal_prefix}/disconnect", self.disconnect_chat)
@@ -62,6 +64,8 @@ class ProvisioningAPI(AuthAPI):
self.app.router.add_route("POST", f"{user_prefix}/login/send_code", self.send_code)
self.app.router.add_route("POST", f"{user_prefix}/login/send_password", self.send_password)
self.app.router.add_route("GET", "/bridge", self.bridge_info)
async def get_portal_by_mxid(self, request: web.Request) -> web.Response:
err = self.check_authorization(request)
if err is not None:
@@ -72,6 +76,8 @@ class ProvisioningAPI(AuthAPI):
if not portal:
return self.get_error_response(404, "portal_not_found",
"Portal with given Matrix ID not found.")
user, _ = await self.get_user(request.query.get("user_id", None), expect_logged_in=None,
require_puppeting=False)
return web.json_response({
"mxid": portal.mxid,
"chat_id": get_peer_id(portal.peer),
@@ -80,6 +86,7 @@ class ProvisioningAPI(AuthAPI):
"about": portal.about,
"username": portal.username,
"megagroup": portal.megagroup,
"can_unbridge": (await portal.can_user_perform(user, "unbridge")) if user else False,
})
async def get_portal_by_tgid(self, request: web.Request) -> web.Response:
@@ -96,6 +103,8 @@ class ProvisioningAPI(AuthAPI):
if not portal:
return self.get_error_response(404, "portal_not_found",
"Portal to given Telegram chat not found.")
user, _ = await self.get_user(request.query.get("user_id", None), expect_logged_in=None,
require_puppeting=False)
return web.json_response({
"mxid": portal.mxid,
"chat_id": get_peer_id(portal.peer),
@@ -104,6 +113,7 @@ class ProvisioningAPI(AuthAPI):
"about": portal.about,
"username": portal.username,
"megagroup": portal.megagroup,
"can_unbridge": (await portal.can_user_perform(user, "unbridge")) if user else False,
})
async def connect_chat(self, request: web.Request) -> web.Response:
@@ -118,10 +128,10 @@ class ProvisioningAPI(AuthAPI):
chat_id = request.match_info["chat_id"]
if chat_id.startswith("-100"):
tgid = int(chat_id[4:])
tgid = TelegramID(int(chat_id[4:]))
peer_type = "channel"
elif chat_id.startswith("-"):
tgid = -int(chat_id)
tgid = TelegramID(-int(chat_id))
peer_type = "chat"
else:
return self.get_error_response(400, "tgid_invalid", "Invalid Telegram chat ID.")
@@ -153,14 +163,14 @@ class ProvisioningAPI(AuthAPI):
"Matrix room.")
is_logged_in = user is not None and await user.is_logged_in()
user = user if is_logged_in else self.context.bot
if not user:
acting_user = user if is_logged_in else self.context.bot
if not acting_user:
return self.get_login_response(status=403, errcode="not_logged_in",
error="You are not logged in and there is no relay bot.")
entity = None # type: Optional[TypeChat]
try:
entity = await user.client.get_entity(portal.peer)
entity = await acting_user.client.get_entity(portal.peer)
except Exception:
self.log.exception("Failed to get_entity(%s) for manual bridging.", portal.peer)
@@ -351,8 +361,14 @@ class ProvisioningAPI(AuthAPI):
return err
await user.log_out()
async def bridge_info(self, request: web.Request) -> web.Response:
return web.json_response({
"relaybot_username": self.context.bot.username,
}, status=200)
@staticmethod
async def error_middleware(_, handler) -> Callable[[web.Request], Awaitable[web.Response]]:
async def error_middleware(_, handler: Callable[[web.Request], Awaitable[web.Response]]
) -> Callable[[web.Request], Awaitable[web.Response]]:
async def middleware_handler(request: web.Request) -> web.Response:
try:
return await handler(request)
@@ -371,16 +387,18 @@ class ProvisioningAPI(AuthAPI):
"errcode": errcode,
}, status=status)
def get_mx_login_response(self, status=200, state="", username="", mxid="", message="",
error="", errcode=""):
def get_mx_login_response(self, status=200, state="", username="", phone="", human_tg_id="",
mxid="", message="", error="", errcode=""):
raise NotImplementedError()
def get_login_response(self, status=200, state="", username="", mxid="", message="", error="",
errcode="") -> web.Response:
if username:
def get_login_response(self, status=200, state="", username="", phone: str = "",
human_tg_id: str = "", mxid="", message="", error="", errcode=""
) -> web.Response:
if username or phone:
resp = {
"state": "logged-in",
"username": username,
"phone": phone,
}
elif message:
resp = {
@@ -411,7 +429,7 @@ class ProvisioningAPI(AuthAPI):
except json.JSONDecodeError:
return None
async def get_user(self, mxid: str, expect_logged_in: Optional[bool] = False,
async def get_user(self, mxid: MatrixUserID, expect_logged_in: Optional[bool] = False,
require_puppeting: bool = True, require_user: bool = True
) -> Tuple[Optional[User], Optional[web.Response]]:
if not mxid:
@@ -427,7 +445,8 @@ class ProvisioningAPI(AuthAPI):
if expect_logged_in is not None:
logged_in = await user.is_logged_in()
if not expect_logged_in and logged_in:
return user, self.get_login_response(username=user.username, status=409,
return user, self.get_login_response(username=user.username, phone=user.phone,
status=409,
error="You are already logged in.",
errcode="already_logged_in")
elif expect_logged_in and not logged_in:
@@ -439,7 +458,7 @@ class ProvisioningAPI(AuthAPI):
expect_logged_in: Optional[bool] = False,
require_puppeting: bool = False,
want_data: bool = True,
) -> (Tuple[Optional[dict],
) -> (Tuple[Optional[Dict],
Optional[User],
Optional[web.Response]]):
err = self.check_authorization(request)
+35 -1
View File
@@ -22,8 +22,23 @@ tags:
- name: User info
- name: Authentication
- name: Bridging
- name: Misc
paths:
/bridge:
get:
operationId: get_bridge
summary: Get the bridge's information
tags: [Misc]
responses:
200:
description: The bridge information
schema:
type: object
properties:
relaybot_username:
type: string
description: The relay bot's username on Telegram
/portal/{room_id}:
get:
operationId: get_portal
@@ -57,6 +72,11 @@ paths:
required: true
type: string
pattern: "![^/]+"
- name: user_id
in: query
description: Optional Matrix user ID to check if the user has permissions to do bridging.
required: false
type: string
/portal/{chat_id}:
get:
operationId: get_portal_by_tgid
@@ -102,6 +122,11 @@ paths:
required: true
type: integer
pattern: "-[0-9]+"
- name: user_id
in: query
description: Optional Matrix user ID to check if the user has permissions to do bridging.
required: false
type: string
/portal/{room_id}/connect/{chat_id}:
post:
operationId: connect_portal
@@ -706,6 +731,9 @@ responses:
username:
type: string
description: The Telegram username the user is logged in as.
phone:
type: string
description: The phone number of the account the user is logged into.
BadRequest:
description: Invalid JSON.
schema:
@@ -790,7 +818,7 @@ definitions:
example: A.
phone:
type: string
example: +123456789
example: 123456789
is_bot:
type: boolean
example: false
@@ -829,6 +857,9 @@ definitions:
type: string
about:
type: string
can_unbridge:
type: boolean
description: If a user ID was provided with the request, this will indicate whether or not the user can unbridge the room.
AuthSuccess:
type: object
@@ -845,6 +876,9 @@ definitions:
username:
type: string
description: The Telegram username the user is logged in as. Only applicable if state=logged-in
phone:
type: string
description: The phone number of the account the user logged into. Only applicable if state=logged-in
HumanReadableError:
type: string
+31 -25
View File
@@ -14,14 +14,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from aiohttp import web
from mako.template import Template
import pkg_resources
import asyncio
import logging
import random
import string
import time
from ...types import MatrixUserID
from ...util import sign_token, verify_token
from ...user import User
from ...puppet import Puppet
@@ -29,20 +32,20 @@ from ..common import AuthAPI
class PublicBridgeWebsite(AuthAPI):
log = logging.getLogger("mau.web.public")
log = logging.getLogger("mau.web.public") # type: logging.Logger
def __init__(self, loop):
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
self.secret_key = "".join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) # type: str
self.login = Template(
pkg_resources.resource_string("mautrix_telegram", "web/public/login.html.mako"))
self.login = Template(pkg_resources.resource_string(
"mautrix_telegram", "web/public/login.html.mako")) # type: Template
self.mx_login = Template(
pkg_resources.resource_string("mautrix_telegram", "web/public/matrix-login.html.mako"))
self.mx_login = Template(pkg_resources.resource_string(
"mautrix_telegram", "web/public/matrix-login.html.mako")) # type: Template
self.app = web.Application(loop=loop)
self.app = web.Application(loop=loop) # type: web.Application
self.app.router.add_route("GET", "/login", self.get_login)
self.app.router.add_route("POST", "/login", self.post_login)
self.app.router.add_route("GET", "/matrix-login", self.get_matrix_login)
@@ -50,21 +53,21 @@ class PublicBridgeWebsite(AuthAPI):
self.app.router.add_static("/", pkg_resources.resource_filename("mautrix_telegram",
"web/public/"))
def make_token(self, mxid, endpoint="/login", expires_in=900):
def make_token(self, mxid: str, endpoint: str = "/login", expires_in: int = 900) -> str:
return sign_token(self.secret_key, {
"mxid": mxid,
"endpoint": endpoint,
"expiry": int(time.time()) + expires_in,
})
def verify_token(self, token, endpoint="/login"):
def verify_token(self, token: str, endpoint: str = "/login") -> Optional[MatrixUserID]:
token = verify_token(self.secret_key, token)
if token and (token.get("expiry", 0) > int(time.time()) and
token.get("endpoint", None) == endpoint):
return token.get("mxid", None)
return MatrixUserID(token.get("mxid", None))
return None
async def get_login(self, request):
async def get_login(self, request: web.Request) -> web.Response:
state = "bot_token" if request.rel_url.query.get("mode", "") == "bot" else "request"
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login")
@@ -81,9 +84,9 @@ class PublicBridgeWebsite(AuthAPI):
if not await user.is_logged_in():
return self.get_login_response(mxid=user.mxid, state=state)
return self.get_login_response(mxid=user.mxid, username=user.username)
return self.get_login_response(mxid=user.mxid, human_tg_id=user.human_tg_id)
async def get_matrix_login(self, request):
async def get_matrix_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/matrix-login")
if not mxid:
return self.get_mx_login_response(status=401, state="invalid-token")
@@ -105,19 +108,22 @@ class PublicBridgeWebsite(AuthAPI):
return self.get_mx_login_response(mxid=user.mxid)
def get_login_response(self, status=200, state="", username="", mxid="", message="", error="",
errcode=""):
def get_login_response(self, status: int = 200, state: str = "", username: str = "",
phone: str = "", human_tg_id: str = "", mxid: str = "",
message: str = "", error: str = "", errcode: str = "") -> web.Response:
return web.Response(status=status, content_type="text/html",
text=self.login.render(username=username, state=state, error=error,
message=message, mxid=mxid))
text=self.login.render(human_tg_id=human_tg_id, state=state,
error=error, message=message, mxid=mxid))
def get_mx_login_response(self, status=200, state="", username="", mxid="", message="",
error="", errcode=""):
def get_mx_login_response(self, status: int = 200, state: str = "", username: str = "",
phone: str = "", human_tg_id: str = "", mxid: str = "",
message: str = "", error: str = "", errcode: str = ""
) -> web.Response:
return web.Response(status=status, content_type="text/html",
text=self.mx_login.render(username=username, state=state, error=error,
message=message, mxid=mxid))
text=self.mx_login.render(human_tg_id=human_tg_id, state=state,
error=error, message=message, mxid=mxid))
async def post_matrix_login(self, request):
async def post_matrix_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/matrix-login")
if not mxid:
return self.get_mx_login_response(status=401, state="invalid-token")
@@ -140,7 +146,7 @@ class PublicBridgeWebsite(AuthAPI):
error="You must provide an access token or "
"password.")
async def post_login(self, request):
async def post_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login")
if not mxid:
return self.get_login_response(status=401, state="invalid-token")
@@ -152,7 +158,7 @@ class PublicBridgeWebsite(AuthAPI):
return self.get_login_response(mxid=user.mxid, error="You are not whitelisted.",
status=403)
elif await user.is_logged_in():
return self.get_login_response(mxid=user.mxid, username=user.username)
return self.get_login_response(mxid=user.mxid, human_tg_id=user.human_tg_id)
await user.ensure_started(even_if_no_session=True)
+4 -4
View File
@@ -51,25 +51,25 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
</head>
<body>
<main class="container">
% if username:
% if human_tg_id:
% if state == "logged-in":
<h1>Logged in successfully!</h1>
<p>
Logged in as @${username}.
Logged in as ${human_tg_id}.
You can now close this page.
You should be invited to Telegram portals on Matrix momentarily.
</p>
% elif state == "bot-logged-in":
<h1>Logged in successfully!</h1>
<p>
Logged in as @${username}.
Logged in as ${human_tg_id}.
You can now close this page.
You should be invited to Telegram portals on Matrix momentarily.
</p>
% else:
<h1>You're already logged in!</h1>
<p>
You're logged in as @${username}.
You're logged in as ${human_tg_id}.
</p>
<p>
If you want to log in with another account, log out using the <code>logout</code>
+1 -1
View File
@@ -4,7 +4,7 @@ ruamel.yaml
python-magic
SQLAlchemy
alembic
Markdown
commonmark
future-fstrings
telethon
telethon-session-sqlalchemy
+3 -2
View File
@@ -27,10 +27,10 @@ setuptools.setup(
install_requires=[
"aiohttp>=3.0.1,<4",
"mautrix-appservice>=0.3.6,<0.4.0",
"mautrix-appservice>=0.3.7,<0.4.0",
"SQLAlchemy>=1.2.3,<2",
"alembic>=1.0.0,<2",
"Markdown>=2.6.11,<3",
"commonmark>=0.8.1,<1",
"ruamel.yaml>=0.15.35,<0.16",
"future-fstrings>=0.4.2",
"python-magic>=0.4.15,<0.5",
@@ -43,6 +43,7 @@ setuptools.setup(
"Development Status :: 4 - Beta",
"License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
"Topic :: Communications :: Chat",
"Framework :: AsyncIO",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.5",