From 73a6ad2cf267dc112aa86390e6b72251d95bd62c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 27 Oct 2019 02:43:29 +0300 Subject: [PATCH] Add parallel file upload too --- mautrix_telegram/portal/matrix.py | 41 ++-- mautrix_telegram/util/__init__.py | 1 + .../util/parallel_file_transfer.py | 196 ++++++++++++++---- 3 files changed, 189 insertions(+), 49 deletions(-) diff --git a/mautrix_telegram/portal/matrix.py b/mautrix_telegram/portal/matrix.py index a7479472..f59d2222 100644 --- a/mautrix_telegram/portal/matrix.py +++ b/mautrix_telegram/portal/matrix.py @@ -17,7 +17,6 @@ from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHEC from html import escape as escape_html from string import Template from abc import ABC -import mimetypes import magic @@ -32,7 +31,7 @@ from telethon.tl.types import ( DocumentAttributeFilename, DocumentAttributeImageSize, GeoPoint, InputChatUploadedPhoto, MessageActionChatEditPhoto, MessageMediaGeo, SendMessageCancelAction, SendMessageTypingAction, TypeInputPeer, TypeMessageEntity, - UpdateNewMessage, InputMediaUploadedDocument) + UpdateNewMessage, InputMediaUploadedDocument, InputMediaUploadedPhoto) from mautrix.types import (EventID, RoomID, UserID, ContentURI, MessageType, MessageEventContent, TextMessageEventContent, MediaMessageEventContent, Format, @@ -41,7 +40,7 @@ from mautrix.bridge import BasePortal as MautrixBasePortal from ..types import TelegramID from ..db import Message as DBMessage -from ..util import sane_mimetypes +from ..util import sane_mimetypes, parallel_transfer_to_telegram from ..context import Context from .. 
import puppet as p, user as u, formatter, util from .base import BasePortal @@ -250,28 +249,42 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC): async def _handle_matrix_file(self, sender_id: TelegramID, event_id: EventID, space: TelegramID, client: 'MautrixTelegramClient', content: MediaMessageEventContent, reply_to: TelegramID) -> None: - file = await self.main_intent.download_media(content.url) - mime = content.info.mimetype w, h = content.info.width, content.info.height file_name = content["net.maunium.telegram.internal.filename"] + max_image_size = config["bridge.image_as_file_size"] * 1000 ** 2 - if content.msgtype == MessageType.STICKER: - if mime != "image/gif": - mime, file, w, h = util.convert_image(file, source_mime=mime, target_type="webp") - else: - # Remove sticker description - file_name = "sticker.gif" + if config["bridge.parallel_file_transfer"]: + file_handle, file_size = await parallel_transfer_to_telegram(client, self.main_intent, + content.url, 0) + else: + file = await self.main_intent.download_media(content.url) + + if content.msgtype == MessageType.STICKER: + if mime != "image/gif": + mime, file, w, h = util.convert_image(file, source_mime=mime, + target_type="webp") + else: + # Remove sticker description + file_name = "sticker.gif" + + file_handle = await client.upload_file(file) + file_size = len(file) + + file_handle.name = file_name attributes = [DocumentAttributeFilename(file_name=file_name)] if w and h: attributes.append(DocumentAttributeImageSize(w, h)) + if (mime == "image/png" or mime == "image/jpeg") and file_size < max_image_size: + media = InputMediaUploadedPhoto(file_handle) + else: + media = InputMediaUploadedDocument(file=file_handle, attributes=attributes, + mime_type=mime or "application/octet-stream") + caption = content.body if content.body != file_name else None - media = await client.upload_file_direct( - file, mime, attributes, file_name, - max_image_size=config["bridge.image_as_file_size"] * 1000 ** 2) async 
with self.send_lock(sender_id): if await self._matrix_document_edit(client, content, space, caption, media, event_id): return diff --git a/mautrix_telegram/util/__init__.py b/mautrix_telegram/util/__init__.py index 727224bb..b2bfa88e 100644 --- a/mautrix_telegram/util/__init__.py +++ b/mautrix_telegram/util/__init__.py @@ -1,4 +1,5 @@ from .file_transfer import transfer_file_to_matrix, convert_image +from .parallel_file_transfer import parallel_transfer_to_telegram from .format_duration import format_duration from .recursive_dict import recursive_del, recursive_set, recursive_get from .color_log import ColorFormatter diff --git a/mautrix_telegram/util/parallel_file_transfer.py b/mautrix_telegram/util/parallel_file_transfer.py index 447eda3c..a2258972 100644 --- a/mautrix_telegram/util/parallel_file_transfer.py +++ b/mautrix_telegram/util/parallel_file_transfer.py @@ -13,22 +13,28 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>.
-from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict +from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple from collections import defaultdict +import hashlib import asyncio import logging import time import math +from aiohttp import ClientResponse + from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation, - InputPhotoFileLocation, InputPeerPhotoFileLocation) + InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile, + InputFileBig, InputFile) from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest -from telethon.tl.functions.upload import GetFileRequest +from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest, + SaveBigFilePartRequest) from telethon.network import MTProtoSender from telethon.crypto import AuthKey -from telethon import utils +from telethon import utils, helpers from mautrix.appservice import IntentAPI +from mautrix.types import ContentURI from ..tgclient import MautrixTelegramClient from ..db import TelegramFile as DBTelegramFile @@ -39,7 +45,7 @@ TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLoca InputFileLocation, InputPhotoFileLocation] -class Sender: +class DownloadSender: sender: MTProtoSender request: GetFileRequest remaining: int @@ -47,7 +53,7 @@ class Sender: def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int, stride: int, count: int) -> None: - log.debug(f"Creating sender with {offset=} {limit=} {stride=} {count=}") + log.debug(f"Creating download sender with {offset=} {limit=} {stride=} {count=}") self.sender = sender self.request = GetFileRequest(file, offset=offset, limit=limit) self.stride = stride @@ -56,7 +62,6 @@ class Sender: async def next(self) -> Optional[bytes]: if not self.remaining: return None - log.debug(f"Sending {self.request!s}") result = await self.sender.send(self.request) self.remaining -= 1 
self.request.offset += self.stride @@ -66,23 +71,76 @@ class Sender: return self.sender.disconnect() -class ParallelDownloader: +class UploadSender: + sender: MTProtoSender + request: Union[SaveFilePartRequest, SaveBigFilePartRequest] + part_count: int + stride: int + previous: Optional[asyncio.Task] + loop: asyncio.AbstractEventLoop + + def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int, + stride: int, loop: asyncio.AbstractEventLoop) -> None: + log.debug( + f"Creating upload sender with {file_id=} {part_count=} {big=} {index=} {stride=}") + self.sender = sender + self.part_count = part_count + if big: + self.request = SaveBigFilePartRequest(file_id, index, part_count, b"") + else: + self.request = SaveFilePartRequest(file_id, index, b"") + self.stride = stride + self.previous = None + self.loop = loop + + async def next(self, data: bytes) -> None: + if self.previous: + await self.previous + self.previous = self.loop.create_task(self._next(data)) + + async def _next(self, data: bytes) -> None: + self.request.bytes = data + log.debug(f"Sending file part {self.request.file_part}/{self.part_count}" + f" with {len(data)} bytes") + await self.sender.send(self.request) + self.request.file_part += self.stride + + async def disconnect(self) -> None: + if self.previous: + await self.previous + return await self.sender.disconnect() + + +class ParallelTransferrer: client: MautrixTelegramClient loop: asyncio.AbstractEventLoop dc_id: int - senders: Optional[List[Sender]] + senders: Optional[List[Union[DownloadSender, UploadSender]]] auth_key: AuthKey + upload_ticker: int - def __init__(self, client: MautrixTelegramClient, dc_id: int) -> None: + def __init__(self, client: MautrixTelegramClient, dc_id: Optional[int] = None) -> None: self.client = client self.loop = self.client.loop - self.dc_id = dc_id - self.exported = dc_id and self.client.session.dc_id != dc_id - self.auth_key = self.client.session.auth_key if not self.exported else 
None + self.dc_id = dc_id or self.client.session.dc_id + self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id + else self.client.session.auth_key) + self.senders = None + self.upload_ticker = 0 + + async def _cleanup(self) -> None: + await asyncio.gather(*[sender.disconnect() for sender in self.senders]) self.senders = None - async def _init(self, connections: int, file: TypeLocation, part_count: int, part_size: int - ) -> None: + @staticmethod + def _get_connection_count(file_size: int, max_count: int = 20, + full_size: int = 100 * 1024 * 1024) -> int: + if file_size > full_size: + return max_count + return math.ceil((file_size / full_size) * max_count) + + async def _init_download(self, connections: int, file: TypeLocation, part_count: int, + part_size: int) -> None: minimum, remainder = divmod(part_count, connections) def get_part_count() -> int: @@ -92,21 +150,38 @@ class ParallelDownloader: return minimum + 1 return minimum + # The first cross-DC sender will export+import the authorization, so we always create it + # before creating any other senders. 
self.senders = [ - await self._create_sender(file, 0, part_size, connections * part_size, - get_part_count()), - *await asyncio.gather(*[ - self._create_sender(file, i, part_size, connections * part_size, get_part_count()) - for i in range(1, connections) - ]) + await self._create_download_sender(file, 0, part_size, connections * part_size, + get_part_count()), + *await asyncio.gather( + *[self._create_download_sender(file, i, part_size, connections * part_size, + get_part_count()) + for i in range(1, connections)]) ] - async def _cleanup(self) -> None: - await asyncio.gather(*[sender.disconnect() for sender in self.senders]) - self.senders = None + async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int, + stride: int, + part_count: int) -> DownloadSender: + return DownloadSender(await self._create_sender(), file, index * part_size, part_size, + stride, part_count) - async def _create_sender(self, file: TypeLocation, index: int, part_size: int, stride: int, - part_count: int) -> Sender: + async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool + ) -> None: + self.senders = [ + await self._create_upload_sender(file_id, part_count, big, 0, connections), + *await asyncio.gather( + *[self._create_upload_sender(file_id, part_count, big, i, connections) + for i in range(1, connections)]) + ] + + async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int, + stride: int) -> UploadSender: + return UploadSender(await self._create_sender(), file_id, part_count, big, index, stride, + loop=self.loop) + + async def _create_sender(self) -> MTProtoSender: dc = await self.client._get_dc(self.dc_id) sender = MTProtoSender(self.auth_key, self.loop, loggers=self.client._log) await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id, @@ -120,14 +195,23 @@ class ParallelDownloader: )) await sender.send(req) self.auth_key = sender.auth_key - return Sender(sender, file, index * 
part_size, part_size, stride, part_count) + return sender - @staticmethod - def _get_connection_count(file_size: int, max_count: int = 20, - full_size: int = 100 * 1024 * 1024) -> int: - if file_size > full_size: - return max_count - return math.ceil((file_size / full_size) * max_count) + async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None, + connection_count: Optional[int] = None) -> Tuple[int, int, bool]: + connection_count = connection_count or self._get_connection_count(file_size) + part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024 + part_count = (file_size + part_size - 1) // part_size + is_large = file_size > 10 * 1024 * 1024 + await self._init_upload(connection_count, file_id, part_count, is_large) + return part_size, part_count, is_large + + async def upload(self, part: bytes) -> None: + await self.senders[self.upload_ticker].next(part) + self.upload_ticker = (self.upload_ticker + 1) % len(self.senders) + + async def finish_upload(self) -> None: + await self._cleanup() async def download(self, file: TypeLocation, file_size: int, part_size_kb: Optional[float] = None, @@ -137,7 +221,7 @@ class ParallelDownloader: part_count = math.ceil(file_size / part_size) log.debug("Starting parallel download: " f"{connection_count} {part_size} {part_count} {file!s}") - await self._init(connection_count, file, part_count, part_size) + await self._init_download(connection_count, file, part_count, part_size) part = 0 while part < part_count: @@ -167,9 +251,51 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int dc_id, location = utils.get_input_location(location) # We lock the transfers because telegram has connection count limits async with parallel_transfer_locks[parallel_id]: - downloader = ParallelDownloader(client, dc_id) + downloader = ParallelTransferrer(client, dc_id) content_uri = await intent.upload_media(downloader.download(location, size), mime_type=mime_type, 
filename=filename, size=size) return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type, was_converted=False, timestamp=int(time.time()), size=size, width=None, height=None) + + +async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response: ClientResponse + ) -> Tuple[TypeInputFile, int]: + file_id = helpers.generate_random_long() + file_size = response.content_length + + hash_md5 = hashlib.md5() + uploader = ParallelTransferrer(client) + part_size, part_count, is_large = await uploader.init_upload(file_id, file_size) + buffer = bytearray() + async for data in response.content: + if not is_large: + hash_md5.update(data) + if len(buffer) == 0 and len(data) == part_size: + await uploader.upload(data) + continue + new_len = len(buffer) + len(data) + if new_len >= part_size: + cutoff = part_size - len(buffer) + buffer.extend(data[:cutoff]) + await uploader.upload(bytes(buffer)) + buffer.clear() + buffer.extend(data[cutoff:]) + else: + buffer.extend(data) + if len(buffer) > 0: + await uploader.upload(bytes(buffer)) + await uploader.finish_upload() + if is_large: + return InputFileBig(file_id, part_count, "upload"), file_size + else: + return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size + + +async def parallel_transfer_to_telegram(client: MautrixTelegramClient, intent: IntentAPI, + uri: ContentURI, parallel_id: int + ) -> Tuple[TypeInputFile, int]: + url = intent.api.get_download_url(uri) + async with parallel_transfer_locks[parallel_id]: + async with intent.api.session.get(url) as response: + return await _internal_transfer_to_telegram(client, response)