diff --git a/example-config.yaml b/example-config.yaml index ca4517ac..5cfed8ea 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -163,6 +163,10 @@ bridge: image_as_file_size: 10 # Maximum size of Telegram documents in megabytes to bridge. max_document_size: 100 + # Enable experimental parallel file transfer, which makes uploads/downloads much faster by + # streaming from/to Matrix and using many connections for Telegram. + # Note that generating HQ thumbnails for videos is not possible with streamed transfers. + parallel_file_transfer: false # Whether or not created rooms should have federation enabled. # If false, created portal rooms will never be federated. federate_rooms: true diff --git a/mautrix_telegram/config.py b/mautrix_telegram/config.py index 3c3e8929..cfdc7d7b 100644 --- a/mautrix_telegram/config.py +++ b/mautrix_telegram/config.py @@ -101,6 +101,7 @@ class Config(BaseBridgeConfig): copy("bridge.inline_images") copy("bridge.image_as_file_size") copy("bridge.max_document_size") + copy("bridge.parallel_file_transfer") copy("bridge.federate_rooms") copy("bridge.bot_messages_as_notices") diff --git a/mautrix_telegram/db/telegram_file.py b/mautrix_telegram/db/telegram_file.py index 4ac05293..efed516f 100644 --- a/mautrix_telegram/db/telegram_file.py +++ b/mautrix_telegram/db/telegram_file.py @@ -30,9 +30,9 @@ class TelegramFile(Base): mime_type: str = Column(String) was_converted: bool = Column(Boolean) timestamp: int = Column(BigInteger) - size: int = Column(Integer, nullable=True) - width: int = Column(Integer, nullable=True) - height: int = Column(Integer, nullable=True) + size: Optional[int] = Column(Integer, nullable=True) + width: Optional[int] = Column(Integer, nullable=True) + height: Optional[int] = Column(Integer, nullable=True) thumbnail_id: str = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True) thumbnail: Optional['TelegramFile'] = None diff --git a/mautrix_telegram/portal/telegram.py 
b/mautrix_telegram/portal/telegram.py index 6ccf62cd..8ccd75c3 100644 --- a/mautrix_telegram/portal/telegram.py +++ b/mautrix_telegram/portal/telegram.py @@ -181,8 +181,10 @@ class PortalTelegram(BasePortal, ABC): self.log.debug(f"Unsupported thumbnail type {type(thumb_size)}") thumb_loc = None thumb_size = None + parallel_id = source.tgid if config["bridge.parallel_file_transfer"] else None file = await util.transfer_file_to_matrix(source.client, intent, document, thumb_loc, - is_sticker=attrs.is_sticker) + is_sticker=attrs.is_sticker, filename=attrs.name, + parallel_id=parallel_id) if not file: return None diff --git a/mautrix_telegram/util/file_transfer.py b/mautrix_telegram/util/file_transfer.py index 02a3e7ca..1f0402a6 100644 --- a/mautrix_telegram/util/file_transfer.py +++ b/mautrix_telegram/util/file_transfer.py @@ -33,6 +33,7 @@ from mautrix.appservice import IntentAPI from ..tgclient import MautrixTelegramClient from ..db import TelegramFile as DBTelegramFile from ..util import sane_mimetypes +from .parallel_file_transfer import parallel_transfer_to_matrix try: from PIL import Image @@ -126,7 +127,7 @@ async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: In return db_file video_ext = sane_mimetypes.guess_extension(mime) - if VideoFileClip and video_ext: + if VideoFileClip and video_ext and video: try: file, width, height = _read_video_thumbnail(video, video_ext, frame_ext="png") except OSError: @@ -158,7 +159,8 @@ TypeThumbnail = Optional[Union[TypeLocation, TypePhotoSize]] async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, location: TypeLocation, thumbnail: TypeThumbnail = None, - is_sticker: bool = False) -> Optional[DBTelegramFile]: + is_sticker: bool = False, filename: Optional[str] = None, + parallel_id: Optional[int] = None) -> Optional[DBTelegramFile]: location_id = _location_to_id(location) if not location_id: return None @@ -174,43 +176,52 @@ async def transfer_file_to_matrix(client: 
MautrixTelegramClient, intent: IntentA transfer_locks[location_id] = lock async with lock: return await _unlocked_transfer_file_to_matrix(client, intent, location_id, location, - thumbnail, is_sticker) + thumbnail, is_sticker, filename, + parallel_id) async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, loc_id: str, location: TypeLocation, - thumbnail: TypeThumbnail, is_sticker: bool + thumbnail: TypeThumbnail, is_sticker: bool, + filename: Optional[str], + parallel_id: Optional[int] = None ) -> Optional[DBTelegramFile]: db_file = DBTelegramFile.get(loc_id) if db_file: return db_file - try: - file = await client.download_file(location) - except (LocationInvalidError, FileIdInvalidError): - return None - except (AuthBytesInvalidError, AuthKeyInvalidError, SecurityError) as e: - log.exception(f"{e.__class__.__name__} while downloading a file.") - return None + if parallel_id and isinstance(location, Document): + db_file = await parallel_transfer_to_matrix(client, intent, loc_id, location, filename, + parallel_id) + mime_type = location.mime_type + file = None + else: + try: + file = await client.download_file(location) + except (LocationInvalidError, FileIdInvalidError): + return None + except (AuthBytesInvalidError, AuthKeyInvalidError, SecurityError) as e: + log.exception(f"{e.__class__.__name__} while downloading a file.") + return None - width, height = None, None - mime_type = magic.from_buffer(file, mime=True) + width, height = None, None + mime_type = magic.from_buffer(file, mime=True) - image_converted = False - if mime_type == "image/webp": - new_mime_type, file, width, height = convert_image( - file, source_mime="image/webp", target_type="png", - thumbnail_to=(256, 256) if is_sticker else None) - image_converted = new_mime_type != mime_type - mime_type = new_mime_type - thumbnail = None + image_converted = False + if mime_type == "image/webp": + new_mime_type, file, width, height = convert_image( + file, 
source_mime="image/webp", target_type="png", + thumbnail_to=(256, 256) if is_sticker else None) + image_converted = new_mime_type != mime_type + mime_type = new_mime_type + thumbnail = None - content_uri = await intent.upload_media(file, mime_type) + content_uri = await intent.upload_media(file, mime_type) - db_file = DBTelegramFile(id=loc_id, mxc=content_uri, - mime_type=mime_type, was_converted=image_converted, - timestamp=int(time.time()), size=len(file), - width=width, height=height) + db_file = DBTelegramFile(id=loc_id, mxc=content_uri, + mime_type=mime_type, was_converted=image_converted, + timestamp=int(time.time()), size=len(file), + width=width, height=height) if thumbnail and (mime_type.startswith("video/") or mime_type == "image/gif"): if isinstance(thumbnail, (PhotoSize, PhotoCachedSize)): thumbnail = thumbnail.location diff --git a/mautrix_telegram/util/parallel_file_transfer.py b/mautrix_telegram/util/parallel_file_transfer.py new file mode 100644 index 00000000..447eda3c --- /dev/null +++ b/mautrix_telegram/util/parallel_file_transfer.py @@ -0,0 +1,175 @@ +# mautrix-telegram - A Matrix-Telegram puppeting bridge +# Copyright (C) 2019 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
class Sender:
    """A single Telegram connection that fetches an interleaved subset of a file's parts.

    Every call to :meth:`next` sends the prepared ``GetFileRequest`` and then moves the request
    offset forward by ``stride`` bytes, so several senders with staggered start offsets cover
    the whole file between them.
    """
    sender: "MTProtoSender"
    request: "GetFileRequest"
    remaining: int  # Number of requests this sender still has to make.
    stride: int  # Byte distance between two consecutive requests made by this sender.

    def __init__(self, sender: "MTProtoSender", file: "TypeLocation", offset: int, limit: int,
                 stride: int, count: int) -> None:
        """
        Args:
            sender: The connected MTProto sender used to send the requests.
            file: The input location of the file to download.
            offset: Byte offset of the first part requested by this sender.
            limit: Number of bytes to ask for in each request.
            stride: Number of bytes the offset is advanced by after each request.
            count: How many requests this sender makes before :meth:`next` returns ``None``.
        """
        log.debug(f"Creating sender with {offset=} {limit=} {stride=} {count=}")
        self.sender = sender
        self.request = GetFileRequest(file, offset=offset, limit=limit)
        self.stride = stride
        self.remaining = count

    async def next(self) -> Optional[bytes]:
        """Download this sender's next part, or return ``None`` once its quota is used up."""
        if self.remaining <= 0:
            return None
        log.debug(f"Sending {self.request!s}")
        result = await self.sender.send(self.request)
        self.remaining -= 1
        # Skip over the parts that the other senders are responsible for.
        self.request.offset += self.stride
        return result.bytes

    def disconnect(self) -> Awaitable[None]:
        """Disconnect the underlying MTProto sender."""
        return self.sender.disconnect()


class ParallelDownloader:
    """Downloads one Telegram file over several connections and yields its parts in order."""
    client: "MautrixTelegramClient"
    loop: asyncio.AbstractEventLoop
    dc_id: int
    exported: bool
    senders: Optional[List[Sender]]
    auth_key: Optional["AuthKey"]

    def __init__(self, client: "MautrixTelegramClient", dc_id: int) -> None:
        """
        Args:
            client: The logged-in client whose session (and, if needed, exported
                authorization) is used for the extra connections.
            dc_id: The ID of the datacenter that hosts the file.
        """
        self.client = client
        self.loop = self.client.loop
        self.dc_id = dc_id
        # The authorization has to be exported when the file is on a DC other than the one the
        # client's session belongs to; otherwise the session's own auth key is reused.
        self.exported = bool(dc_id) and self.client.session.dc_id != dc_id
        self.auth_key = self.client.session.auth_key if not self.exported else None
        self.senders = None

    async def _init(self, connections: int, file: "TypeLocation", part_count: int,
                    part_size: int) -> None:
        """Open ``connections`` senders and split ``part_count`` parts between them.

        The parts are dealt out round-robin: sender ``i`` starts at ``i * part_size`` and
        advances by ``connections * part_size``. When the parts don't divide evenly, the first
        ``part_count % connections`` senders each take one extra part.

        Raises:
            Exception: Whatever error the first failing connection attempt raised. Senders that
                did connect stay in ``self.senders`` so that :meth:`_cleanup` can close them.
        """
        minimum, remainder = divmod(part_count, connections)
        quotas = [minimum + 1 if index < remainder else minimum
                  for index in range(connections)]
        stride = connections * part_size

        # The first sender is created on its own: if the authorization must be exported, this
        # stores the resulting auth key so that the remaining senders can simply reuse it.
        self.senders = [await self._create_sender(file, 0, part_size, stride, quotas[0])]
        results = await asyncio.gather(*[
            self._create_sender(file, index, part_size, stride, quotas[index])
            for index in range(1, connections)
        ], return_exceptions=True)
        # Keep every sender that connected before reporting a failure, so none of them leak.
        self.senders += [res for res in results if not isinstance(res, BaseException)]
        errors = [res for res in results if isinstance(res, BaseException)]
        if errors:
            raise errors[0]

    async def _cleanup(self) -> None:
        """Disconnect all senders that are currently open. Does nothing if there are none."""
        senders, self.senders = self.senders, None
        if senders:
            await asyncio.gather(*[sender.disconnect() for sender in senders])

    async def _create_sender(self, file: "TypeLocation", index: int, part_size: int,
                             stride: int, part_count: int) -> Sender:
        """Connect a new MTProto sender to the file's DC and wrap it in a :class:`Sender`.

        Args:
            file: The input location of the file to download.
            index: Position of this sender; its first request starts at ``index * part_size``.
            part_size: Number of bytes per request.
            stride: Number of bytes the sender skips ahead after each request.
            part_count: Number of parts this sender is responsible for.
        """
        dc = await self.client._get_dc(self.dc_id)
        sender = MTProtoSender(self.auth_key, self.loop, loggers=self.client._log)
        await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
                                                     loop=self.loop, loggers=self.client._log,
                                                     proxy=self.client._proxy))
        try:
            if not self.auth_key:
                log.debug(f"Exporting auth to DC {self.dc_id}")
                auth = await self.client(ExportAuthorizationRequest(self.dc_id))
                req = self.client._init_with(ImportAuthorizationRequest(
                    id=auth.id, bytes=auth.bytes
                ))
                await sender.send(req)
                self.auth_key = sender.auth_key
        except BaseException:
            # Don't leave a half-initialized connection open if the authorization failed.
            await sender.disconnect()
            raise
        return Sender(sender, file, index * part_size, part_size, stride, part_count)

    @staticmethod
    def _get_connection_count(file_size: int, max_count: int = 20,
                              full_size: int = 100 * 1024 * 1024) -> int:
        """Pick a connection count that grows linearly with the file size.

        Files larger than ``full_size`` bytes get ``max_count`` connections, smaller files get
        a proportional share rounded up. At least one connection is always returned, so that
        an empty file can't lead to a division by zero when the parts are split up.
        """
        if file_size > full_size:
            return max_count
        return max(1, math.ceil((file_size / full_size) * max_count))

    async def download(self, file: "TypeLocation", file_size: int,
                       part_size_kb: Optional[float] = None,
                       connection_count: Optional[int] = None
                       ) -> AsyncGenerator[bytes, None]:
        """Download ``file`` over parallel connections, yielding its parts in file order.

        Args:
            file: The input location of the file to download.
            file_size: Size of the file in bytes.
            part_size_kb: Size of each part in kilobytes. Defaults to the value that
                ``utils.get_appropriated_part_size`` picks for ``file_size``.
            connection_count: Number of connections to use. Defaults to
                :meth:`_get_connection_count` for ``file_size``.

        Yields:
            The downloaded parts, in the order they appear in the file.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        # part_size_kb may be a float; offsets and limits must be whole bytes.
        part_size = int((part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024)
        part_count = math.ceil(file_size / part_size)
        log.debug("Starting parallel download: "
                  f"{connection_count} {part_size} {part_count} {file!s}")

        tasks: List[asyncio.Task] = []
        try:
            await self._init(connection_count, file, part_count, part_size)

            part = 0
            while part < part_count:
                # Start one request per connection, then consume the results in sender order
                # so that the parts come out in file order.
                tasks = [self.loop.create_task(sender.next()) for sender in self.senders]
                for task in tasks:
                    data = await task
                    if not data:
                        # Either every part has been delivered, or Telegram returned less data
                        # than expected. Stop in both cases: continuing after a short read
                        # would loop forever once all senders have used up their quota.
                        if part < part_count:
                            log.warning(f"Parallel download ended early after {part} of "
                                        f"{part_count} parts")
                        return
                    yield data
                    part += 1
                    log.debug(f"Part {part} downloaded")
        finally:
            # Runs on success, on errors and when the consumer stops iterating early, so the
            # connections are always closed.
            unfinished = [task for task in tasks if not task.done()]
            for task in unfinished:
                task.cancel()
            if unfinished:
                await asyncio.gather(*unfinished, return_exceptions=True)
            log.debug("Parallel download finished, cleaning up connections")
            await self._cleanup()


# One lock per ``parallel_id``; a new lock is created the first time an ID is looked up.
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(asyncio.Lock)


async def parallel_transfer_to_matrix(client: "MautrixTelegramClient", intent: "IntentAPI",
                                      loc_id: str, location: "Document",
                                      filename: Optional[str], parallel_id: int
                                      ) -> "DBTelegramFile":
    """Stream a Telegram document to Matrix while it is being downloaded in parallel.

    Args:
        client: The Telegram client to download the document with.
        intent: The Matrix intent to upload the media with.
        loc_id: The ID to store the resulting database entry under.
        location: The document to transfer. Its ``size`` and ``mime_type`` are passed on to
            the upload as-is.
        filename: The file name to pass to the upload, if any.
        parallel_id: Key of the lock to hold during the transfer. Transfers that use the same
            key run one at a time.

    Returns:
        A new (not yet inserted) database entry for the uploaded file, without dimensions.
    """
    size = location.size
    mime_type = location.mime_type
    dc_id, input_location = utils.get_input_location(location)
    # We lock the transfers because telegram has connection count limits
    async with parallel_transfer_locks[parallel_id]:
        downloader = ParallelDownloader(client, dc_id)
        content_uri = await intent.upload_media(downloader.download(input_location, size),
                                                mime_type=mime_type, filename=filename,
                                                size=size)
    return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
                          was_converted=False, timestamp=int(time.time()), size=size,
                          width=None, height=None)