diff --git a/example-config.yaml b/example-config.yaml index 5f7e9753..61a444e1 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -163,6 +163,10 @@ bridge: image_as_file_size: 10 # Maximum size of Telegram documents in megabytes to bridge. max_document_size: 100 + # Enable experimental parallel file transfer, which makes uploads/downloads much faster by + # streaming from/to Matrix and using many connections for Telegram. + # Note that generating HQ thumbnails for videos is not possible with streamed transfers. + parallel_file_transfer: false # Whether or not created rooms should have federation enabled. # If false, created portal rooms will never be federated. federate_rooms: true @@ -207,20 +211,19 @@ bridge: # Text msgtypes (m.text, m.notice and m.emote) support HTML, media msgtypes don't. # # Available variables: - # $sender_displayname - The display name of the sender (e.g. Example User) - # $sender_username - The username (Matrix ID localpart) of the sender (e.g. exampleuser) - # $sender_mxid - The Matrix ID of the sender (e.g. @exampleuser:example.com) - # $body - The plaintext body (file name for media msgtypes) - # $formatted_body - The message content as HTML (for text msgtypes) + # $sender_displayname - The display name of the sender (e.g. Example User) + # $sender_username - The username (Matrix ID localpart) of the sender (e.g. exampleuser) + # $sender_mxid - The Matrix ID of the sender (e.g. @exampleuser:example.com) + # $message - The message content message_formats: - m.text: "$sender_displayname: $formatted_body" - m.notice: "$sender_displayname: $formatted_body" - m.emote: "* $sender_displayname $formatted_body" - m.file: "$sender_displayname sent a file: $body" - m.image: "$sender_displayname sent an image: $body" - m.audio: "$sender_displayname sent an audio file: $body" - m.video: "$sender_displayname sent a video: $body" - m.location: "$sender_displayname sent a location: $body" + m.text: "$sender_displayname: $message" + m.notice: "$sender_displayname: $message" + m.emote: "* $sender_displayname $message" + m.file: "$sender_displayname sent a file: $message" + m.image: "$sender_displayname sent an image: $message" + m.audio: "$sender_displayname sent an audio file: $message" + m.video: "$sender_displayname sent a video: $message" + m.location: "$sender_displayname sent a location: $message" # Telegram doesn't have built-in emotes, this field specifies how m.emote's from authenticated # users are sent to telegram. All fields in message_formats are supported. Additionally, the # Telegram user info is available in the following variables: diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py index a40badc9..c4cfd7e7 100644 --- a/mautrix_telegram/abstract_user.py +++ b/mautrix_telegram/abstract_user.py @@ -295,7 +295,7 @@ class AbstractUser(ABC): async def update_admin(self, update: UpdateChatParticipantAdmin) -> None: # TODO duplication not checked - portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat") + portal = po.Portal.get_by_tgid(TelegramID(update.chat_id)) if not portal or not portal.mxid: return @@ -305,7 +305,7 @@ class AbstractUser(ABC): if isinstance(update, UpdateUserTyping): portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user") else: - portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat") + portal = po.Portal.get_by_tgid(TelegramID(update.chat_id)) if not portal or not portal.mxid: return @@ -350,7 +350,7 @@ class AbstractUser(ABC): Optional[pu.Puppet], Optional[po.Portal]]: if isinstance(update, UpdateShortChatMessage): - portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat") + portal = po.Portal.get_by_tgid(TelegramID(update.chat_id)) sender = pu.Puppet.get(TelegramID(update.from_id)) elif isinstance(update, UpdateShortMessage): portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user") @@ -410,9 +410,10 @@ class AbstractUser(ABC): if not config["bridge.relaybot.private_chat.invite"]: self.log.debug(f"Ignoring private message to bot from {sender.id}") return - elif not portal.mxid: + elif not portal or not portal.mxid: + tgid_log = portal.tgid_log if portal else original_update.chat_id self.log.debug( - f"Ignoring message received by bot in unbridged chat {portal.tgid_log}") + f"Ignoring message received by bot in unbridged chat {tgid_log}") return if self.ignore_incoming_bot_events and self.relaybot and sender.id == self.relaybot.tgid: diff --git a/mautrix_telegram/commands/clean_rooms.py b/mautrix_telegram/commands/clean_rooms.py index 8182f746..93e3e649 100644 --- a/mautrix_telegram/commands/clean_rooms.py +++ b/mautrix_telegram/commands/clean_rooms.py @@ -155,7 +155,7 @@ async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom], "next": lambda confirm: execute_room_cleanup(confirm, rooms_to_clean), "action": "Room cleaning", } - await evt.reply(f"To confirm cleaning up {len(rooms_to_clean)} rooms, type" + await evt.reply(f"To confirm cleaning up {len(rooms_to_clean)} rooms, type " "`$cmdprefix+sp confirm-clean`.") diff --git a/mautrix_telegram/commands/handler.py b/mautrix_telegram/commands/handler.py index f50e59e5..cc728be6 100644 --- a/mautrix_telegram/commands/handler.py +++ b/mautrix_telegram/commands/handler.py @@ -18,7 +18,7 @@ from typing import Awaitable, Callable, List, Optional, NamedTuple, Any from telethon.errors import FloodWaitError -from mautrix.types import RoomID, EventID +from mautrix.types import RoomID, EventID, MessageEventContent from mautrix.bridge.commands import (HelpSection, CommandEvent as BaseCommandEvent, CommandHandler as BaseCommandHandler, CommandProcessor as BaseCommandProcessor, @@ -42,10 +42,10 @@ class CommandEvent(BaseCommandEvent): sender: u.User def __init__(self, processor: 'CommandProcessor', room_id: RoomID, event_id: EventID, - sender: u.User, command: str, args: List[str], is_management: bool, - is_portal: bool) -> None: - super().__init__(processor, room_id, event_id, sender, command, args, is_management, - is_portal) + sender: u.User, command: str, args: List[str], content: MessageEventContent, + is_management: bool, is_portal: bool) -> None: + super().__init__(processor, room_id, event_id, sender, command, args, content, + is_management, is_portal) self.bridge = processor.bridge self.tgbot = processor.tgbot self.config = processor.config @@ -69,7 +69,7 @@ class CommandHandler(BaseCommandHandler): def __init__(self, handler: Callable[[CommandEvent], Awaitable[EventID]], management_only: bool, name: str, help_text: str, help_args: str, help_section: HelpSection, needs_auth: bool, needs_puppeting: bool, - needs_matrix_puppeting: bool, needs_admin: bool,) -> None: + needs_matrix_puppeting: bool, needs_admin: bool) -> None: super().__init__(handler, management_only, name, help_text, help_args, help_section, needs_auth=needs_auth, needs_puppeting=needs_puppeting, needs_matrix_puppeting=needs_matrix_puppeting, needs_admin=needs_admin) diff --git a/mautrix_telegram/commands/telegram/misc.py b/mautrix_telegram/commands/telegram/misc.py index 60b12181..6d9fde77 100644 --- a/mautrix_telegram/commands/telegram/misc.py +++ b/mautrix_telegram/commands/telegram/misc.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, cast import logging import codecs import base64 @@ -23,13 +23,13 @@ from telethon.errors import (InviteHashInvalidError, InviteHashExpiredError, Opt UserAlreadyParticipantError, ChatIdInvalidError) from telethon.tl.patched import Message from telethon.tl.types import (User as TLUser, TypeUpdates, MessageMediaGame, MessageMediaPoll, - TypePeer) + TypeInputPeer) from telethon.tl.types.messages import BotCallbackAnswer from telethon.tl.functions.messages import (ImportChatInviteRequest, CheckChatInviteRequest, GetBotCallbackAnswerRequest, SendVoteRequest) from telethon.tl.functions.channels import JoinChannelRequest -from mautrix.types import EventID +from mautrix.types import EventID, Format from ... import puppet as pu, portal as po from ...abstract_user import AbstractUser @@ -38,6 +38,22 @@ from ...types import TelegramID from ...commands import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORTALS +@command_handler(needs_auth=False, + help_section=SECTION_MISC, help_args="<_caption_>", + help_text="Set a caption for the next image you send") +async def caption(evt: CommandEvent) -> EventID: + if len(evt.args) == 0: + return await evt.reply("**Usage:** `$cmdprefix+sp caption `") + + prefix = f"{evt.command_prefix} caption " + if evt.content.format == Format.HTML: + evt.content.formatted_body = evt.content.formatted_body.replace(prefix, "", 1) + evt.content.body = evt.content.body.replace(prefix, "", 1) + evt.sender.command_status = {"caption": evt.content} + return await evt.reply("Your next image or file will be sent with that caption. " + "Use `$cmdprefix+sp cancel` to cancel the caption.") + + @command_handler(help_section=SECTION_MISC, help_args="[_-r|--remote_] <_query_>", help_text="Search your contacts or the Telegram servers for users.") @@ -76,8 +92,7 @@ async def search(evt: CommandEvent) -> EventID: return await evt.reply("\n".join(reply)) -@command_handler(help_section=SECTION_CREATING_PORTALS, - help_args="<_identifier_>", +@command_handler(help_section=SECTION_CREATING_PORTALS, help_args="<_identifier_>", help_text="Open a private chat with the given Telegram user. The identifier is " "either the internal user ID, the username or the phone number. " "**N.B.** The phone numbers you start chats with must already be in " @@ -183,7 +198,7 @@ class MessageIDError(ValueError): async def _parse_encoded_msgid(user: AbstractUser, enc_id: str, type_name: str - ) -> Tuple[TypePeer, Message]: + ) -> Tuple[TypeInputPeer, Message]: try: enc_id += (4 - len(enc_id) % 4) * "=" enc_id = base64.b64decode(enc_id) @@ -212,7 +227,7 @@ async def _parse_encoded_msgid(user: AbstractUser, enc_id: str, type_name: str msg = await user.client.get_messages(entity=peer, ids=msg_id) if not msg: raise MessageIDError(f"Invalid {type_name} ID (message not found)") - return peer, msg + return peer, cast(Message, msg) @command_handler(help_section=SECTION_MISC, @@ -234,12 +249,13 @@ async def play(evt: CommandEvent) -> EventID: if not isinstance(msg.media, MessageMediaGame): return await evt.reply("Invalid play ID (message doesn't look like a game)") - game = await evt.sender.client(GetBotCallbackAnswerRequest(peer=peer, msg_id=msg.id, game=True)) + game = await evt.sender.client( + GetBotCallbackAnswerRequest(peer=peer, msg_id=msg.id, game=True)) if not isinstance(game, BotCallbackAnswer): return await evt.reply("Game request response invalid") return await evt.reply(f"Click [here]({game.url}) to play {msg.media.game.title}:\n\n" - f"{msg.media.game.description}") + f"{msg.media.game.description}") @command_handler(help_section=SECTION_MISC, diff --git a/mautrix_telegram/config.py b/mautrix_telegram/config.py index 9d3cd37c..e25c3d32 100644 --- a/mautrix_telegram/config.py +++ b/mautrix_telegram/config.py @@ -101,6 +101,7 @@ class Config(BaseBridgeConfig): copy("bridge.inline_images") copy("bridge.image_as_file_size") copy("bridge.max_document_size") + copy("bridge.parallel_file_transfer") copy("bridge.federate_rooms") copy("bridge.animated_sticker.target") copy("bridge.animated_sticker.args") diff --git a/mautrix_telegram/db/telegram_file.py b/mautrix_telegram/db/telegram_file.py index 4ac05293..efed516f 100644 --- a/mautrix_telegram/db/telegram_file.py +++ b/mautrix_telegram/db/telegram_file.py @@ -30,9 +30,9 @@ class TelegramFile(Base): mime_type: str = Column(String) was_converted: bool = Column(Boolean) timestamp: int = Column(BigInteger) - size: int = Column(Integer, nullable=True) - width: int = Column(Integer, nullable=True) - height: int = Column(Integer, nullable=True) + size: Optional[int] = Column(Integer, nullable=True) + width: Optional[int] = Column(Integer, nullable=True) + height: Optional[int] = Column(Integer, nullable=True) thumbnail_id: str = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True) thumbnail: Optional['TelegramFile'] = None diff --git a/mautrix_telegram/portal/matrix.py b/mautrix_telegram/portal/matrix.py index 32730da3..95c509e6 100644 --- a/mautrix_telegram/portal/matrix.py +++ b/mautrix_telegram/portal/matrix.py @@ -17,7 +17,6 @@ from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHEC from html import escape as escape_html from string import Template from abc import ABC -import mimetypes import magic @@ -32,7 +31,7 @@ from telethon.tl.types import ( DocumentAttributeFilename, DocumentAttributeImageSize, GeoPoint, InputChatUploadedPhoto, MessageActionChatEditPhoto, MessageMediaGeo, SendMessageCancelAction, SendMessageTypingAction, TypeInputPeer, TypeMessageEntity, - UpdateNewMessage, InputMediaUploadedDocument) + UpdateNewMessage, InputMediaUploadedDocument, InputMediaUploadedPhoto) from mautrix.types import (EventID, RoomID, UserID, ContentURI, MessageType, MessageEventContent, TextMessageEventContent, MediaMessageEventContent, Format, @@ -41,7 +40,7 @@ from mautrix.bridge import BasePortal as MautrixBasePortal from ..types import TelegramID from ..db import Message as DBMessage -from ..util import sane_mimetypes +from ..util import sane_mimetypes, parallel_transfer_to_telegram from ..context import Context from .. import puppet as p, user as u, formatter, util from .base import BasePortal @@ -57,19 +56,6 @@ config: Optional['Config'] = None class PortalMatrix(BasePortal, MautrixBasePortal, ABC): - @staticmethod - def _get_file_meta(body: str, mime: str) -> str: - try: - current_extension = body[body.rindex("."):].lower() - body = body[:body.rindex(".")] - if mimetypes.types_map[current_extension] == mime: - return body + current_extension - except (ValueError, KeyError): - pass - if mime: - return f"matrix_upload{sane_mimetypes.guess_extension(mime)}" - return "" - async def _get_state_change_message(self, event: str, user: 'u.User', **kwargs: Any ) -> Optional[str]: tpl = self.get_config(f"state_event_formats.{event}") @@ -183,7 +169,7 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC): async def _apply_msg_format(self, sender: 'u.User', content: MessageEventContent ) -> None: - if isinstance(content, TextMessageEventContent) and content.format != Format.HTML: + if not isinstance(content, TextMessageEventContent) or content.format != Format.HTML: content.format = Format.HTML content.formatted_body = escape_html(content.body).replace("\n", "
") @@ -193,14 +179,9 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC): tpl_args = dict(sender_mxid=sender.mxid, sender_username=sender.mxid_localpart, sender_displayname=escape_html(displayname), - body=content.body) - if isinstance(content, TextMessageEventContent): - tpl_args["formatted_body"] = content.formatted_body - tpl_args["message"] = content.formatted_body - content.formatted_body = Template(tpl).safe_substitute(tpl_args) - else: - tpl_args["message"] = content.body - content.body = Template(tpl).safe_substitute(tpl_args) + message=content.formatted_body, + body=content.body, formatted_body=content.formatted_body) + content.formatted_body = Template(tpl).safe_substitute(tpl_args) async def _apply_emote_format(self, sender: 'u.User', content: TextMessageEventContent) -> None: @@ -262,42 +243,55 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC): async def _handle_matrix_file(self, sender_id: TelegramID, event_id: EventID, space: TelegramID, client: 'MautrixTelegramClient', - content: MediaMessageEventContent, reply_to: TelegramID) -> None: - file = await self.main_intent.download_media(content.url) - + content: MediaMessageEventContent, reply_to: TelegramID, + caption: TextMessageEventContent = None) -> None: mime = content.info.mimetype - w, h = content.info.width, content.info.height + file_name = content["net.maunium.telegram.internal.filename"] + max_image_size = config["bridge.image_as_file_size"] * 1000 ** 2 - if content.msgtype == MessageType.STICKER: - if mime != "image/gif": - mime, file, w, h = util.convert_image(file, source_mime=mime, target_type="webp") - else: - # Remove sticker description - content["net.maunium.telegram.internal.filename"] = "sticker.gif" - content.body = "" + if config["bridge.parallel_file_transfer"]: + file_handle, file_size = await parallel_transfer_to_telegram(client, self.main_intent, + content.url, sender_id) + else: + file = await self.main_intent.download_media(content.url) + + if content.msgtype == MessageType.STICKER: + if mime != "image/gif": + mime, file, w, h = util.convert_image(file, source_mime=mime, + target_type="webp") + else: + # Remove sticker description + file_name = "sticker.gif" + + file_handle = await client.upload_file(file) + file_size = len(file) + + file_handle.name = file_name - file_name = self._get_file_meta(content["net.maunium.telegram.internal.filename"], mime) attributes = [DocumentAttributeFilename(file_name=file_name)] if w and h: attributes.append(DocumentAttributeImageSize(w, h)) - caption = content.body if content.body.lower() != file_name.lower() else None + if (mime == "image/png" or mime == "image/jpeg") and file_size < max_image_size: + media = InputMediaUploadedPhoto(file_handle) + else: + media = InputMediaUploadedDocument(file=file_handle, attributes=attributes, + mime_type=mime or "application/octet-stream") + + caption, entities = self._matrix_event_to_entities(caption) if caption else (None, None) - media = await client.upload_file_direct( - file, mime, attributes, file_name, - max_image_size=config["bridge.image_as_file_size"] * 1000 ** 2) async with self.send_lock(sender_id): if await self._matrix_document_edit(client, content, space, caption, media, event_id): return try: response = await client.send_media(self.peer, media, reply_to=reply_to, - caption=caption) + caption=caption, entities=entities) except (PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, PhotoExtInvalidError): media = InputMediaUploadedDocument(file=media.file, mime_type=mime, attributes=attributes) response = await client.send_media(self.peer, media, reply_to=reply_to, - caption=caption) + caption=caption, entities=entities) self._add_telegram_message_to_db(event_id, space, 0, response) async def _matrix_document_edit(self, client: 'MautrixTelegramClient', @@ -364,8 +358,20 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC): else (sender.tgid if logged_in else self.bot.tgid)) reply_to = formatter.matrix_reply_to_telegram(content, space, room_id=self.mxid) - content["net.maunium.telegram.internal.filename"] = content.body - await self._pre_process_matrix_message(sender, not logged_in, content) + media = (MessageType.STICKER, MessageType.IMAGE, MessageType.FILE, MessageType.AUDIO, + MessageType.VIDEO) + caption_content = None + if content.msgtype in media: + content["net.maunium.telegram.internal.filename"] = content.body + try: + caption_content: MessageEventContent = sender.command_status["caption"] + caption_content.msgtype = content.msgtype + reply_to = reply_to or formatter.matrix_reply_to_telegram(caption_content, space, + room_id=self.mxid) + sender.command_status = None + except (KeyError, TypeError): + pass + await self._pre_process_matrix_message(sender, not logged_in, caption_content or content) if content.msgtype == MessageType.NOTICE: bridge_notices = self.get_config("bridge_notices.default") @@ -378,9 +384,9 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC): elif content.msgtype == MessageType.LOCATION: await self._handle_matrix_location(sender_id, event_id, space, client, content, reply_to) - elif content.msgtype in (MessageType.STICKER, MessageType.IMAGE, MessageType.FILE, - MessageType.AUDIO, MessageType.VIDEO): - await self._handle_matrix_file(sender_id, event_id, space, client, content, reply_to) + elif content.msgtype in media: + await self._handle_matrix_file(sender_id, event_id, space, client, content, reply_to, + caption_content) else: self.log.debug(f"Unhandled Matrix event: {content}") diff --git a/mautrix_telegram/portal/telegram.py b/mautrix_telegram/portal/telegram.py index 3aa97b3a..f88d1380 100644 --- a/mautrix_telegram/portal/telegram.py +++ b/mautrix_telegram/portal/telegram.py @@ -181,9 +181,11 @@ class PortalTelegram(BasePortal, ABC): self.log.debug(f"Unsupported thumbnail type {type(thumb_size)}") thumb_loc = None thumb_size = None + parallel_id = source.tgid if config["bridge.parallel_file_transfer"] else None file = await util.transfer_file_to_matrix(source.client, intent, document, thumb_loc, is_sticker=attrs.is_sticker, - tgs_convert=config["bridge.animated_sticker"]) + tgs_convert=config["bridge.animated_sticker"], + filename=attrs.name, parallel_id=parallel_id) if not file: return None diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index df36fc5d..9c365656 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -219,7 +219,7 @@ class User(AbstractUser, BaseUser): else: portal = po.Portal.get_by_entity(message.to_id, receiver_id=self.tgid) elif isinstance(update, UpdateShortChatMessage): - portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat") + portal = po.Portal.get_by_tgid(TelegramID(update.chat_id)) elif isinstance(update, UpdateShortMessage): portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user") else: diff --git a/mautrix_telegram/util/__init__.py b/mautrix_telegram/util/__init__.py index 727224bb..b2bfa88e 100644 --- a/mautrix_telegram/util/__init__.py +++ b/mautrix_telegram/util/__init__.py @@ -1,4 +1,5 @@ from .file_transfer import transfer_file_to_matrix, convert_image +from .parallel_file_transfer import parallel_transfer_to_telegram from .format_duration import format_duration from .recursive_dict import recursive_del, recursive_set, recursive_get from .color_log import ColorFormatter diff --git a/mautrix_telegram/util/file_transfer.py b/mautrix_telegram/util/file_transfer.py index e89c3465..9028d9ce 100644 --- a/mautrix_telegram/util/file_transfer.py +++ b/mautrix_telegram/util/file_transfer.py @@ -34,6 +34,7 @@ from mautrix.appservice import IntentAPI from ..tgclient import MautrixTelegramClient from ..db import TelegramFile as DBTelegramFile from ..util import sane_mimetypes +from .parallel_file_transfer import parallel_transfer_to_matrix try: from PIL import Image @@ -129,7 +130,7 @@ async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: In return db_file video_ext = sane_mimetypes.guess_extension(mime) - if VideoFileClip and video_ext: + if VideoFileClip and video_ext and video: try: file, width, height = _read_video_thumbnail(video, video_ext, frame_ext="png") except OSError: @@ -161,7 +162,8 @@ TypeThumbnail = Optional[Union[TypeLocation, TypePhotoSize]] async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, location: TypeLocation, thumbnail: TypeThumbnail = None, - is_sticker: bool = False, tgs_convert: Optional[dict] = None + is_sticker: bool = False, tgs_convert: Optional[dict] = None, + filename: Optional[str] = None, parallel_id: Optional[int] = None ) -> Optional[DBTelegramFile]: location_id = _location_to_id(location) if not location_id: @@ -178,53 +180,61 @@ async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentA transfer_locks[location_id] = lock async with lock: return await _unlocked_transfer_file_to_matrix(client, intent, location_id, location, - thumbnail, is_sticker, tgs_convert) + thumbnail, is_sticker, tgs_convert, + filename, parallel_id) async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, loc_id: str, location: TypeLocation, thumbnail: TypeThumbnail, is_sticker: bool, - tgs_convert: Optional[dict] + tgs_convert: Optional[dict], filename: Optional[str], + parallel_id: Optional[int] ) -> Optional[DBTelegramFile]: db_file = DBTelegramFile.get(loc_id) if db_file: return db_file - try: - file = await client.download_file(location) - except (LocationInvalidError, FileIdInvalidError): - return None - except (AuthBytesInvalidError, AuthKeyInvalidError, SecurityError) as e: - log.exception(f"{e.__class__.__name__} while downloading a file.") - return None + if parallel_id and isinstance(location, Document) and (not is_sticker or not tgs_convert): + db_file = await parallel_transfer_to_matrix(client, intent, loc_id, location, filename, + parallel_id) + mime_type = location.mime_type + file = None + else: + try: + file = await client.download_file(location) + except (LocationInvalidError, FileIdInvalidError): + return None + except (AuthBytesInvalidError, AuthKeyInvalidError, SecurityError) as e: + log.exception(f"{e.__class__.__name__} while downloading a file.") + return None - width, height = None, None - mime_type = magic.from_buffer(file, mime=True) + width, height = None, None + mime_type = magic.from_buffer(file, mime=True) - image_converted = False - # A weird bug in alpine/magic makes it return application/octet-stream for gzips... - if is_sticker and tgs_convert and (mime_type == "application/gzip" or ( - mime_type == "application/octet-stream" - and magic.from_buffer(file).startswith("gzip"))): - mime_type, file, width, height = await convert_tgs_to( - file, tgs_convert["target"], **tgs_convert["args"]) - thumbnail = None - image_converted = mime_type != "application/gzip" + image_converted = False + # A weird bug in alpine/magic makes it return application/octet-stream for gzips... + if is_sticker and tgs_convert and (mime_type == "application/gzip" or ( + mime_type == "application/octet-stream" + and magic.from_buffer(file).startswith("gzip"))): + mime_type, file, width, height = await convert_tgs_to( + file, tgs_convert["target"], **tgs_convert["args"]) + thumbnail = None + image_converted = mime_type != "application/gzip" - if mime_type == "image/webp": - new_mime_type, file, width, height = convert_image( - file, source_mime="image/webp", target_type="png", - thumbnail_to=(256, 256) if is_sticker else None) - image_converted = new_mime_type != mime_type - mime_type = new_mime_type - thumbnail = None + if mime_type == "image/webp": + new_mime_type, file, width, height = convert_image( + file, source_mime="image/webp", target_type="png", + thumbnail_to=(256, 256) if is_sticker else None) + image_converted = new_mime_type != mime_type + mime_type = new_mime_type + thumbnail = None - content_uri = await intent.upload_media(file, mime_type) + content_uri = await intent.upload_media(file, mime_type) - db_file = DBTelegramFile(id=loc_id, mxc=content_uri, - mime_type=mime_type, was_converted=image_converted, - timestamp=int(time.time()), size=len(file), - width=width, height=height) + db_file = DBTelegramFile(id=loc_id, mxc=content_uri, + mime_type=mime_type, was_converted=image_converted, + timestamp=int(time.time()), size=len(file), + width=width, height=height) if thumbnail and (mime_type.startswith("video/") or mime_type == "image/gif"): if isinstance(thumbnail, (PhotoSize, PhotoCachedSize)): thumbnail = thumbnail.location diff --git a/mautrix_telegram/util/parallel_file_transfer.py b/mautrix_telegram/util/parallel_file_transfer.py new file mode 100644 index 00000000..507646aa --- /dev/null +++ b/mautrix_telegram/util/parallel_file_transfer.py @@ -0,0 +1,298 @@ +# mautrix-telegram - A Matrix-Telegram puppeting bridge +# Copyright (C) 2019 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple +from collections import defaultdict +import hashlib +import asyncio +import logging +import time +import math + +from aiohttp import ClientResponse + +from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation, + InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile, + InputFileBig, InputFile) +from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest +from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest, + SaveBigFilePartRequest) +from telethon.network import MTProtoSender +from telethon.crypto import AuthKey +from telethon import utils, helpers + +from mautrix.appservice import IntentAPI +from mautrix.types import ContentURI + +from ..tgclient import MautrixTelegramClient +from ..db import TelegramFile as DBTelegramFile + +log: logging.Logger = logging.getLogger("mau.util") + +TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation, + InputFileLocation, InputPhotoFileLocation] + + +class DownloadSender: + sender: MTProtoSender + request: GetFileRequest + remaining: int + stride: int + + def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int, + stride: int, count: int) -> None: + self.sender = sender + self.request = GetFileRequest(file, offset=offset, limit=limit) + self.stride = stride + self.remaining = count + + async def next(self) -> Optional[bytes]: + if not self.remaining: + return None + result = await self.sender.send(self.request) + self.remaining -= 1 + self.request.offset += self.stride + return result.bytes + + def disconnect(self) -> Awaitable[None]: + return self.sender.disconnect() + + +class UploadSender: + sender: MTProtoSender + request: Union[SaveFilePartRequest, SaveBigFilePartRequest] + part_count: int + stride: int + previous: Optional[asyncio.Task] + loop: asyncio.AbstractEventLoop + + def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int, + stride: int, loop: asyncio.AbstractEventLoop) -> None: + self.sender = sender + self.part_count = part_count + if big: + self.request = SaveBigFilePartRequest(file_id, index, part_count, b"") + else: + self.request = SaveFilePartRequest(file_id, index, b"") + self.stride = stride + self.previous = None + self.loop = loop + + async def next(self, data: bytes) -> None: + if self.previous: + await self.previous + self.previous = self.loop.create_task(self._next(data)) + + async def _next(self, data: bytes) -> None: + self.request.bytes = data + log.debug(f"Sending file part {self.request.file_part}/{self.part_count}" + f" with {len(data)} bytes") + await self.sender.send(self.request) + self.request.file_part += self.stride + + async def disconnect(self) -> None: + if self.previous: + await self.previous + return await self.sender.disconnect() + + +class ParallelTransferrer: + client: MautrixTelegramClient + loop: asyncio.AbstractEventLoop + dc_id: int + senders: Optional[List[Union[DownloadSender, UploadSender]]] + auth_key: AuthKey + upload_ticker: int + + def __init__(self, client: MautrixTelegramClient, dc_id: Optional[int] = None) -> None: + self.client = client + self.loop = self.client.loop + self.dc_id = dc_id or self.client.session.dc_id + self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id + else self.client.session.auth_key) + self.senders = None + self.upload_ticker = 0 + + async def _cleanup(self) -> None: + await asyncio.gather(*[sender.disconnect() for sender in self.senders]) + self.senders = None + + @staticmethod + def _get_connection_count(file_size: int, max_count: int = 20, + full_size: int = 100 * 1024 * 1024) -> int: + if file_size > full_size: + return max_count + return math.ceil((file_size / full_size) * max_count) + + async def _init_download(self, connections: int, file: TypeLocation, part_count: int, + part_size: int) -> None: + minimum, remainder = divmod(part_count, connections) + + def get_part_count() -> int: + nonlocal remainder + if remainder > 0: + remainder -= 1 + return minimum + 1 + return minimum + + # The first cross-DC sender will export+import the authorization, so we always create it + # before creating any other senders. + self.senders = [ + await self._create_download_sender(file, 0, part_size, connections * part_size, + get_part_count()), + *await asyncio.gather( + *[self._create_download_sender(file, i, part_size, connections * part_size, + get_part_count()) + for i in range(1, connections)]) + ] + + async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int, + stride: int, + part_count: int) -> DownloadSender: + return DownloadSender(await self._create_sender(), file, index * part_size, part_size, + stride, part_count) + + async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool + ) -> None: + self.senders = [ + await self._create_upload_sender(file_id, part_count, big, 0, connections), + *await asyncio.gather( + *[self._create_upload_sender(file_id, part_count, big, i, connections) + for i in range(1, connections)]) + ] + + async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int, + stride: int) -> UploadSender: + return UploadSender(await self._create_sender(), file_id, part_count, big, index, stride, + loop=self.loop) + + async def _create_sender(self) -> MTProtoSender: + dc = await self.client._get_dc(self.dc_id) + sender = MTProtoSender(self.auth_key, self.loop, loggers=self.client._log) + await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id, + loop=self.loop, loggers=self.client._log, + proxy=self.client._proxy)) + if not self.auth_key: + log.debug(f"Exporting auth to DC {self.dc_id}") + auth = await self.client(ExportAuthorizationRequest(self.dc_id)) + req = self.client._init_with(ImportAuthorizationRequest( + id=auth.id, bytes=auth.bytes + )) + await sender.send(req) + self.auth_key = sender.auth_key + return sender + + async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None, + connection_count: Optional[int] = None) -> Tuple[int, int, bool]: + connection_count = connection_count or self._get_connection_count(file_size) + part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024 + part_count = (file_size + part_size - 1) // part_size + is_large = file_size > 10 * 1024 * 1024 + await self._init_upload(connection_count, file_id, part_count, is_large) + return part_size, part_count, is_large + + async def upload(self, part: bytes) -> None: + await self.senders[self.upload_ticker].next(part) + self.upload_ticker = (self.upload_ticker + 1) % len(self.senders) + + async def finish_upload(self) -> None: + await self._cleanup() + + async def download(self, file: TypeLocation, file_size: int, + part_size_kb: Optional[float] = None, + connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]: + connection_count = connection_count or self._get_connection_count(file_size) + part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024 + part_count = math.ceil(file_size / part_size) + log.debug("Starting parallel download: " + f"{connection_count} {part_size} {part_count} {file!s}") + await self._init_download(connection_count, file, part_count, part_size) + + part = 0 + while part < part_count: + tasks = [] + for sender in self.senders: + tasks.append(self.loop.create_task(sender.next())) + for task in tasks: + data = await task + if not data: + break + yield data + part += 1 + log.debug(f"Part {part} downloaded") + + log.debug("Parallel download finished, cleaning up connections") + await self._cleanup() + + +parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) + + +async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, + loc_id: str, location: TypeLocation, filename: str, + parallel_id: int) -> DBTelegramFile: + size = location.size + mime_type = location.mime_type + dc_id, location = utils.get_input_location(location) + # We lock the transfers because telegram has connection count limits + async with parallel_transfer_locks[parallel_id]: + downloader = ParallelTransferrer(client, dc_id) + content_uri = await intent.upload_media(downloader.download(location, size), + mime_type=mime_type, filename=filename, size=size) + return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type, + was_converted=False, timestamp=int(time.time()), size=size, + width=None, height=None) + + +async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response: ClientResponse + ) -> Tuple[TypeInputFile, int]: + file_id = helpers.generate_random_long() + file_size = response.content_length + + hash_md5 = hashlib.md5() + uploader = ParallelTransferrer(client) + part_size, part_count, is_large = await uploader.init_upload(file_id, file_size) + buffer = bytearray() + async for data in response.content: + if not is_large: + hash_md5.update(data) + if len(buffer) == 0 and len(data) == part_size: + await uploader.upload(data) + continue + new_len = len(buffer) + len(data) + if new_len >= part_size: + cutoff = part_size - len(buffer) + buffer.extend(data[:cutoff]) + await uploader.upload(bytes(buffer)) + buffer.clear() + buffer.extend(data[cutoff:]) + else: + buffer.extend(data) + if len(buffer) > 0: + await uploader.upload(bytes(buffer)) + await uploader.finish_upload() + if is_large: + return InputFileBig(file_id, part_count, "upload"), file_size + else: + return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size + + +async def parallel_transfer_to_telegram(client: MautrixTelegramClient, intent: IntentAPI, + uri: ContentURI, parallel_id: int + ) -> Tuple[TypeInputFile, int]: + url = intent.api.get_download_url(uri) + async with parallel_transfer_locks[parallel_id]: + async with intent.api.session.get(url) as response: + return await _internal_transfer_to_telegram(client, response) diff --git a/setup.py b/setup.py index 51ee2ae8..1e21ba78 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ setuptools.setup( install_requires=[ "aiohttp>=3.0.1,<4", - "mautrix>=0.4.0.dev71,<0.5", + "mautrix>=0.4.0.dev75,<0.5", "SQLAlchemy>=1.2.3,<2", "alembic>=1.0.0,<2", "commonmark>=0.8.1,<0.10",