Move some utility methods from portal to separate files

This commit is contained in:
Tulir Asokan
2021-12-21 15:27:10 +02:00
parent 7595b9c015
commit 2615e11e34
8 changed files with 296 additions and 226 deletions
+30 -223
View File
@@ -21,7 +21,6 @@ from typing import (
AsyncGenerator,
Awaitable,
Callable,
Iterable,
List,
NamedTuple,
Union,
@@ -52,7 +51,6 @@ from telethon.tl.functions.channels import (
CreateChannelRequest,
EditPhotoRequest,
EditTitleRequest,
GetParticipantsRequest,
InviteToChannelRequest,
JoinChannelRequest,
UpdateUsernameRequest,
@@ -74,16 +72,8 @@ from telethon.tl.patched import Message, MessageService
from telethon.tl.types import (
Channel,
ChannelFull,
ChannelParticipantAdmin,
ChannelParticipantBanned,
ChannelParticipantCreator,
ChannelParticipantsRecent,
ChannelParticipantsSearch,
Chat,
ChatBannedRights,
ChatFull,
ChatParticipantAdmin,
ChatParticipantCreator,
ChatPhoto,
ChatPhotoEmpty,
Document,
@@ -190,12 +180,12 @@ from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
from mautrix.util.simple_lock import SimpleLock
from mautrix.util.simple_template import SimpleTemplate
from . import abstract_user as au, formatter, puppet as p, user as u, util
from . import abstract_user as au, formatter, portal_util as putil, puppet as p, user as u, util
from .config import Config
from .db import Message as DBMessage, Portal as DBPortal, TelegramFile as DBTelegramFile
from .tgclient import MautrixTelegramClient
from .types import TelegramID
from .util import parallel_transfer_to_telegram, sane_mimetypes
from .util import sane_mimetypes
try:
from mautrix.crypto.attachments import decrypt_attachment
@@ -210,7 +200,6 @@ StateBridge = EventType.find("m.bridge", EventType.Class.STATE)
StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STATE)
InviteList = Union[UserID, List[UserID]]
TypeParticipant = Union[TypeChatParticipant, TypeChannelParticipant]
UpdateTyping = Union[UpdateUserTyping, UpdateChatUserTyping, UpdateChannelUserTyping]
TypeChatPhoto = Union[ChatPhoto, ChatPhotoEmpty, Photo, PhotoEmpty]
MediaHandler = Callable[["au.AbstractUser", IntentAPI, Message, RelatesTo], Awaitable[EventID]]
@@ -256,8 +245,8 @@ class Portal(DBPortal, BasePortal):
alias: RoomAlias | None
dedup: util.PortalDedup
send_lock: util.PortalSendLock
dedup: putil.PortalDedup
send_lock: putil.PortalSendLock
_pin_lock: asyncio.Lock
_main_intent: IntentAPI | None
@@ -301,8 +290,8 @@ class Portal(DBPortal, BasePortal):
self.backfill_method_lock = asyncio.Lock()
self.backfill_leave = None
self.dedup = util.PortalDedup(self)
self.send_lock = util.PortalSendLock()
self.dedup = putil.PortalDedup(self)
self.send_lock = putil.PortalSendLock()
self._pin_lock = asyncio.Lock()
self._room_create_lock = asyncio.Lock()
@@ -370,8 +359,6 @@ class Portal(DBPortal, BasePortal):
return self.tgid not in self.filter_list
return True
# endregion
@classmethod
def init_cls(cls, bridge: "TelegramBridge") -> None:
BasePortal.bridge = bridge
@@ -397,11 +384,12 @@ class Portal(DBPortal, BasePortal):
)
NotificationDisabler.puppet_cls = p.Puppet
NotificationDisabler.config_enabled = cls.config["bridge.backfill.disable_notifications"]
util.PortalDedup.dedup_pre_db_check = cls.config["bridge.deduplication.pre_db_check"]
util.PortalDedup.dedup_cache_queue_length = cls.config[
putil.PortalDedup.dedup_pre_db_check = cls.config["bridge.deduplication.pre_db_check"]
putil.PortalDedup.dedup_cache_queue_length = cls.config[
"bridge.deduplication.cache_queue_length"
]
# endregion
# region Matrix -> Telegram metadata
async def get_telegram_users_in_matrix_room(
@@ -522,7 +510,7 @@ class Portal(DBPortal, BasePortal):
levels = await self.main_intent.get_power_levels(self.mxid)
if levels.get_user_level(self.main_intent.mxid) == 100:
levels = self._get_base_power_levels(levels, entity)
levels = putil.get_base_power_levels(self, levels, entity)
await self.main_intent.set_power_levels(self.mxid, levels)
await self.handle_matrix_power_levels(source, levels.users, {}, None)
await self.update_bridge_info()
@@ -740,7 +728,7 @@ class Portal(DBPortal, BasePortal):
# TODO? properly handle existing room aliases
await self.main_intent.remove_room_alias(alias)
power_levels = self._get_base_power_levels(entity=entity)
power_levels = putil.get_base_power_levels(self, entity=entity)
users = None
if not direct:
users = await self._get_users(user, entity)
@@ -749,7 +737,7 @@ class Portal(DBPortal, BasePortal):
invites += extra_invites
for invite in extra_invites:
power_levels.users.setdefault(invite, 100)
await self._participants_to_power_levels(users, power_levels)
await putil.participants_to_power_levels(self, users, power_levels)
elif self.bot and self.tg_receiver == self.bot.tgid:
invites = self.config["bridge.relaybot.private_chat.invite"]
for invite in invites:
@@ -846,133 +834,26 @@ class Portal(DBPortal, BasePortal):
return self.mxid
def _get_base_power_levels(
self, levels: PowerLevelStateEventContent = None, entity: TypeChat = None
) -> PowerLevelStateEventContent:
levels = levels or PowerLevelStateEventContent()
if self.peer_type == "user":
overrides = self.config["bridge.initial_power_level_overrides.user"]
levels.ban = overrides.get("ban", 100)
levels.kick = overrides.get("kick", 100)
levels.invite = overrides.get("invite", 100)
levels.redact = overrides.get("redact", 0)
levels.events[EventType.ROOM_NAME] = 0
levels.events[EventType.ROOM_AVATAR] = 0
levels.events[EventType.ROOM_TOPIC] = 0
levels.state_default = overrides.get("state_default", 0)
levels.users_default = overrides.get("users_default", 0)
levels.events_default = overrides.get("events_default", 0)
else:
overrides = self.config["bridge.initial_power_level_overrides.group"]
dbr = entity.default_banned_rights
if not dbr:
self.log.debug(f"default_banned_rights is None in {entity}")
dbr = ChatBannedRights(
invite_users=True,
change_info=True,
pin_messages=True,
send_stickers=False,
send_messages=False,
until_date=None,
)
levels.ban = overrides.get("ban", 50)
levels.kick = overrides.get("kick", 50)
levels.redact = overrides.get("redact", 50)
levels.invite = overrides.get("invite", 50 if dbr.invite_users else 0)
levels.events[EventType.ROOM_ENCRYPTION] = 50 if self.matrix.e2ee else 99
levels.events[EventType.ROOM_TOMBSTONE] = 99
levels.events[EventType.ROOM_NAME] = 50 if dbr.change_info else 0
levels.events[EventType.ROOM_AVATAR] = 50 if dbr.change_info else 0
levels.events[EventType.ROOM_TOPIC] = 50 if dbr.change_info else 0
levels.events[EventType.ROOM_PINNED_EVENTS] = 50 if dbr.pin_messages else 0
levels.events[EventType.ROOM_POWER_LEVELS] = 75
levels.events[EventType.ROOM_HISTORY_VISIBILITY] = 75
levels.events[EventType.STICKER] = 50 if dbr.send_stickers else levels.events_default
levels.state_default = overrides.get("state_default", 50)
levels.users_default = overrides.get("users_default", 0)
levels.events_default = overrides.get(
"events_default",
50
if (
self.peer_type == "channel"
and not entity.megagroup
or entity.default_banned_rights.send_messages
)
else 0,
)
for evt_type, value in overrides.get("events", {}).items():
levels.events[EventType.find(evt_type)] = value
levels.users = overrides.get("users", {})
if self.main_intent.mxid not in levels.users:
levels.users[self.main_intent.mxid] = 100
return levels
@classmethod
def _get_level_from_participant(
cls, participant: TypeParticipant, levels: PowerLevelStateEventContent
) -> int:
# TODO use the power level requirements to get better precision in channels
if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)):
return levels.state_default or 50
elif isinstance(participant, (ChatParticipantCreator, ChannelParticipantCreator)):
return levels.get_user_level(cls.az.bot_mxid) - 5
return levels.users_default or 0
@staticmethod
def _participant_to_power_levels(
levels: PowerLevelStateEventContent,
user: u.User | p.Puppet,
new_level: int,
bot_level: int,
) -> bool:
new_level = min(new_level, bot_level)
user_level = levels.get_user_level(user.mxid)
if user_level != new_level and user_level < bot_level:
levels.users[user.mxid] = new_level
return True
return False
async def _participants_to_power_levels(
self, users: list[TypeUser | TypeParticipant], levels: PowerLevelStateEventContent
) -> bool:
bot_level = levels.get_user_level(self.main_intent.mxid)
if bot_level < levels.get_event_level(EventType.ROOM_POWER_LEVELS):
return False
changed = False
admin_power_level = min(75 if self.peer_type == "channel" else 50, bot_level)
if levels.get_event_level(EventType.ROOM_POWER_LEVELS) != admin_power_level:
changed = True
levels.events[EventType.ROOM_POWER_LEVELS] = admin_power_level
for user in users:
# The User objects we get from TelegramClient.get_participants have a custom
# participant property
participant = getattr(user, "participant", user)
puppet = await p.Puppet.get_by_tgid(TelegramID(participant.user_id))
user = await u.User.get_by_tgid(TelegramID(participant.user_id))
new_level = self._get_level_from_participant(participant, levels)
if user:
await user.register_portal(self)
changed = (
self._participant_to_power_levels(levels, user, new_level, bot_level)
or changed
)
if puppet:
changed = (
self._participant_to_power_levels(levels, puppet, new_level, bot_level)
or changed
)
return changed
async def _get_users(
self,
user: au.AbstractUser,
entity: TypeInputPeer | InputUser | TypeChat | TypeUser | InputChannel,
) -> list[TypeUser]:
if self.peer_type == "channel" and not self.megagroup and not self.sync_channel_members:
return []
limit = self.max_initial_member_sync
if limit == 0:
return []
return await putil.get_users(user.client, self.tgid, entity, limit, self.peer_type)
async def update_power_levels(
self, users: list[TypeUser | TypeParticipant], levels: PowerLevelStateEventContent = None
self,
users: list[TypeUser | TypeChatParticipant | TypeChannelParticipant],
levels: PowerLevelStateEventContent = None
) -> None:
if not levels:
levels = await self.main_intent.get_power_levels(self.mxid)
if await self._participants_to_power_levels(users, levels):
if await putil.participants_to_power_levels(self, users, levels):
await self.main_intent.set_power_levels(self.mxid, levels)
async def _add_bot_chat(self, bot: User) -> None:
@@ -1235,80 +1116,6 @@ class Portal(DBPortal, BasePortal):
return True
return False
@staticmethod
def _filter_participants(
users: list[TypeUser], participants: list[TypeParticipant]
) -> Iterable[TypeUser]:
participant_map = {
part.user_id: part
for part in participants
if not isinstance(part, ChannelParticipantBanned)
}
for user in users:
try:
user.participant = participant_map[user.id]
except KeyError:
pass
else:
yield user
async def _get_channel_users(
self, user: au.AbstractUser, entity: InputChannel, limit: int
) -> list[TypeUser]:
if 0 < limit <= 200:
response = await user.client(
GetParticipantsRequest(
entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0
)
)
return list(self._filter_participants(response.users, response.participants))
elif limit > 200 or limit == -1:
users: list[TypeUser] = []
offset = 0
remaining_quota = limit if limit > 0 else 1000000
query = ChannelParticipantsSearch("") if limit == -1 else ChannelParticipantsRecent()
while True:
if remaining_quota <= 0:
break
response = await user.client(
GetParticipantsRequest(
entity, query, offset=offset, limit=min(remaining_quota, 200), hash=0
)
)
if not response.users:
break
users += self._filter_participants(response.users, response.participants)
offset += len(response.participants)
remaining_quota -= len(response.participants)
return users
async def _get_users(
self,
user: au.AbstractUser,
entity: TypeInputPeer | InputUser | TypeChat | TypeUser | InputChannel,
) -> list[TypeUser]:
limit = self.max_initial_member_sync
if self.peer_type == "chat":
chat = await user.client(GetFullChatRequest(chat_id=self.tgid))
return list(
self._filter_participants(chat.users, chat.full_chat.participants.participants)
)[:limit]
elif self.peer_type == "channel":
if not self.megagroup and not self.sync_channel_members:
return []
if limit == 0:
return []
try:
return await self._get_channel_users(user, entity, limit)
except ChatAdminRequiredError:
return []
elif self.peer_type == "user":
return [entity]
else:
raise RuntimeError(f"Unexpected peer type {self.peer_type}")
# endregion
# region Matrix -> Telegram bridging
@@ -1589,7 +1396,7 @@ class Portal(DBPortal, BasePortal):
max_image_size = self.config["bridge.image_as_file_size"] * 1000 ** 2
if self.config["bridge.parallel_file_transfer"] and content.url:
file_handle, file_size = await parallel_transfer_to_telegram(
file_handle, file_size = await util.parallel_transfer_to_telegram(
client, self.main_intent, content.url, sender_id
)
else:
@@ -2405,7 +2212,7 @@ class Portal(DBPortal, BasePortal):
async def _handle_telegram_dice(
self, _: au.AbstractUser, intent: IntentAPI, evt: Message, relates_to: RelatesTo
) -> EventID:
content = util.make_dice_event_content(evt.media)
content = putil.make_dice_event_content(evt.media)
content.relates_to = relates_to
content.external_url = self._get_external_url(evt)
await intent.set_typing(self.mxid, is_typing=False)
@@ -2461,7 +2268,7 @@ class Portal(DBPortal, BasePortal):
async def _handle_telegram_contact(
self, source: au.AbstractUser, intent: IntentAPI, evt: Message, relates_to: RelatesTo
) -> EventID:
content = await util.make_contact_event_content(source, evt.media)
content = await putil.make_contact_event_content(source, evt.media)
content.relates_to = relates_to
content.external_url = self._get_external_url(evt)
+5
View File
@@ -0,0 +1,5 @@
from .deduplication import PortalDedup
from .media_fallback import make_contact_event_content, make_dice_event_content
from .participants import get_users
from .power_levels import get_base_power_levels, participants_to_power_levels
from .send_lock import PortalSendLock
@@ -0,0 +1,107 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Iterable
from telethon.errors import ChatAdminRequiredError
from telethon.tl.functions.channels import GetParticipantsRequest
from telethon.tl.functions.messages import GetFullChatRequest
from telethon.tl.types import (
ChannelParticipantBanned,
ChannelParticipantsRecent,
ChannelParticipantsSearch,
InputChannel,
InputUser,
TypeChannelParticipant,
TypeChat,
TypeChatParticipant,
TypeInputPeer,
TypeUser,
)
from ..tgclient import MautrixTelegramClient
def _filter_participants(
users: list[TypeUser], participants: list[TypeChatParticipant | TypeChannelParticipant]
) -> Iterable[TypeUser]:
participant_map = {
part.user_id: part
for part in participants
if not isinstance(part, ChannelParticipantBanned)
}
for user in users:
try:
user.participant = participant_map[user.id]
except KeyError:
pass
else:
yield user
async def _get_channel_users(
client: MautrixTelegramClient, entity: InputChannel, limit: int
) -> list[TypeUser]:
if 0 < limit <= 200:
response = await client(
GetParticipantsRequest(
entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0
)
)
return list(_filter_participants(response.users, response.participants))
elif limit > 200 or limit == -1:
users: list[TypeUser] = []
offset = 0
remaining_quota = limit if limit > 0 else 1000000
query = ChannelParticipantsSearch("") if limit == -1 else ChannelParticipantsRecent()
while True:
if remaining_quota <= 0:
break
response = await client(
GetParticipantsRequest(
entity, query, offset=offset, limit=min(remaining_quota, 200), hash=0
)
)
if not response.users:
break
users += _filter_participants(response.users, response.participants)
offset += len(response.participants)
remaining_quota -= len(response.participants)
return users
async def get_users(
client: MautrixTelegramClient,
tgid: int,
entity: TypeInputPeer | InputUser | TypeChat | TypeUser | InputChannel,
limit: int,
peer_type: str,
) -> list[TypeUser]:
if peer_type == "chat":
chat = await client(GetFullChatRequest(chat_id=tgid))
return list(_filter_participants(chat.users, chat.full_chat.participants.participants))[
:limit
]
elif peer_type == "channel":
try:
return await _get_channel_users(client, entity, limit)
except ChatAdminRequiredError:
return []
elif peer_type == "user":
return [entity]
else:
raise RuntimeError(f"Unexpected peer type {peer_type}")
@@ -0,0 +1,154 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from telethon.tl.types import (
ChannelParticipantAdmin,
ChannelParticipantCreator,
ChatBannedRights,
ChatParticipantAdmin,
ChatParticipantCreator,
TypeChannelParticipant,
TypeChat,
TypeChatParticipant,
TypeUser,
)
from mautrix.types import EventType, PowerLevelStateEventContent as PowerLevelContent, UserID
from .. import portal as po, puppet as pu, user as u
from ..types import TelegramID
def get_base_power_levels(
portal: po.Portal, levels: PowerLevelContent = None, entity: TypeChat = None
) -> PowerLevelContent:
levels = levels or PowerLevelContent()
if portal.peer_type == "user":
overrides = portal.config["bridge.initial_power_level_overrides.user"]
levels.ban = overrides.get("ban", 100)
levels.kick = overrides.get("kick", 100)
levels.invite = overrides.get("invite", 100)
levels.redact = overrides.get("redact", 0)
levels.events[EventType.ROOM_NAME] = 0
levels.events[EventType.ROOM_AVATAR] = 0
levels.events[EventType.ROOM_TOPIC] = 0
levels.state_default = overrides.get("state_default", 0)
levels.users_default = overrides.get("users_default", 0)
levels.events_default = overrides.get("events_default", 0)
else:
overrides = portal.config["bridge.initial_power_level_overrides.group"]
dbr = entity.default_banned_rights
if not dbr:
portal.log.debug(f"default_banned_rights is None in {entity}")
dbr = ChatBannedRights(
invite_users=True,
change_info=True,
pin_messages=True,
send_stickers=False,
send_messages=False,
until_date=None,
)
levels.ban = overrides.get("ban", 50)
levels.kick = overrides.get("kick", 50)
levels.redact = overrides.get("redact", 50)
levels.invite = overrides.get("invite", 50 if dbr.invite_users else 0)
levels.events[EventType.ROOM_ENCRYPTION] = 50 if portal.matrix.e2ee else 99
levels.events[EventType.ROOM_TOMBSTONE] = 99
levels.events[EventType.ROOM_NAME] = 50 if dbr.change_info else 0
levels.events[EventType.ROOM_AVATAR] = 50 if dbr.change_info else 0
levels.events[EventType.ROOM_TOPIC] = 50 if dbr.change_info else 0
levels.events[EventType.ROOM_PINNED_EVENTS] = 50 if dbr.pin_messages else 0
levels.events[EventType.ROOM_POWER_LEVELS] = 75
levels.events[EventType.ROOM_HISTORY_VISIBILITY] = 75
levels.events[EventType.STICKER] = 50 if dbr.send_stickers else levels.events_default
levels.state_default = overrides.get("state_default", 50)
levels.users_default = overrides.get("users_default", 0)
levels.events_default = overrides.get(
"events_default",
50
if (
portal.peer_type == "channel"
and not entity.megagroup
or entity.default_banned_rights.send_messages
)
else 0,
)
for evt_type, value in overrides.get("events", {}).items():
levels.events[EventType.find(evt_type)] = value
levels.users = overrides.get("users", {})
if portal.main_intent.mxid not in levels.users:
levels.users[portal.main_intent.mxid] = 100
return levels
async def participants_to_power_levels(
portal: po.Portal,
users: list[TypeUser | TypeChatParticipant | TypeChannelParticipant],
levels: PowerLevelContent,
) -> bool:
bot_level = levels.get_user_level(portal.main_intent.mxid)
if bot_level < levels.get_event_level(EventType.ROOM_POWER_LEVELS):
return False
changed = False
admin_power_level = min(75 if portal.peer_type == "channel" else 50, bot_level)
if levels.get_event_level(EventType.ROOM_POWER_LEVELS) != admin_power_level:
changed = True
levels.events[EventType.ROOM_POWER_LEVELS] = admin_power_level
for user in users:
# The User objects we get from TelegramClient.get_participants have a custom
# participant property
participant = getattr(user, "participant", user)
puppet = await pu.Puppet.get_by_tgid(TelegramID(participant.user_id))
user = await u.User.get_by_tgid(TelegramID(participant.user_id))
new_level = _get_level_from_participant(portal.az.bot_mxid, participant, levels)
if user:
await user.register_portal(portal)
changed = _participant_to_power_levels(levels, user, new_level, bot_level) or changed
if puppet:
changed = _participant_to_power_levels(levels, puppet, new_level, bot_level) or changed
return changed
def _get_level_from_participant(
bot_mxid: UserID,
participant: TypeUser | TypeChatParticipant | TypeChannelParticipant,
levels: PowerLevelContent,
) -> int:
# TODO use the power level requirements to get better precision in channels
if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)):
return levels.state_default or 50
elif isinstance(participant, (ChatParticipantCreator, ChannelParticipantCreator)):
return levels.get_user_level(bot_mxid) - 5
return levels.users_default or 0
def _participant_to_power_levels(
levels: PowerLevelContent,
user: u.User | pu.Puppet,
new_level: int,
bot_level: int,
) -> bool:
new_level = min(new_level, bot_level)
user_level = levels.get_user_level(user.mxid)
if user_level != new_level and user_level < bot_level:
levels.users[user.mxid] = new_level
return True
return False
-3
View File
@@ -1,7 +1,4 @@
from .color_log import ColorFormatter
from .deduplication import PortalDedup
from .file_transfer import convert_image, transfer_file_to_matrix
from .media_fallback import make_contact_event_content, make_dice_event_content
from .parallel_file_transfer import parallel_transfer_to_telegram
from .recursive_dict import recursive_del, recursive_get, recursive_set
from .send_lock import PortalSendLock