Add parallel file upload too

This commit is contained in:
Tulir Asokan
2019-10-27 02:43:29 +03:00
parent 574312d7c5
commit 73a6ad2cf2
3 changed files with 189 additions and 49 deletions
+27 -14
View File
@@ -17,7 +17,6 @@ from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHEC
from html import escape as escape_html
from string import Template
from abc import ABC
import mimetypes
import magic
@@ -32,7 +31,7 @@ from telethon.tl.types import (
DocumentAttributeFilename, DocumentAttributeImageSize, GeoPoint,
InputChatUploadedPhoto, MessageActionChatEditPhoto, MessageMediaGeo,
SendMessageCancelAction, SendMessageTypingAction, TypeInputPeer, TypeMessageEntity,
UpdateNewMessage, InputMediaUploadedDocument)
UpdateNewMessage, InputMediaUploadedDocument, InputMediaUploadedPhoto)
from mautrix.types import (EventID, RoomID, UserID, ContentURI, MessageType, MessageEventContent,
TextMessageEventContent, MediaMessageEventContent, Format,
@@ -41,7 +40,7 @@ from mautrix.bridge import BasePortal as MautrixBasePortal
from ..types import TelegramID
from ..db import Message as DBMessage
from ..util import sane_mimetypes
from ..util import sane_mimetypes, parallel_transfer_to_telegram
from ..context import Context
from .. import puppet as p, user as u, formatter, util
from .base import BasePortal
@@ -250,28 +249,42 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
async def _handle_matrix_file(self, sender_id: TelegramID, event_id: EventID,
space: TelegramID, client: 'MautrixTelegramClient',
content: MediaMessageEventContent, reply_to: TelegramID) -> None:
file = await self.main_intent.download_media(content.url)
mime = content.info.mimetype
w, h = content.info.width, content.info.height
file_name = content["net.maunium.telegram.internal.filename"]
max_image_size = config["bridge.image_as_file_size"] * 1000 ** 2
if content.msgtype == MessageType.STICKER:
if mime != "image/gif":
mime, file, w, h = util.convert_image(file, source_mime=mime, target_type="webp")
else:
# Remove sticker description
file_name = "sticker.gif"
if config["bridge.parallel_file_transfer"]:
file_handle, file_size = await parallel_transfer_to_telegram(client, self.main_intent,
content.url, 0)
else:
file = await self.main_intent.download_media(content.url)
if content.msgtype == MessageType.STICKER:
if mime != "image/gif":
mime, file, w, h = util.convert_image(file, source_mime=mime,
target_type="webp")
else:
# Remove sticker description
file_name = "sticker.gif"
file_handle = await client.upload_file(file)
file_size = len(file)
file_handle.name = file_name
attributes = [DocumentAttributeFilename(file_name=file_name)]
if w and h:
attributes.append(DocumentAttributeImageSize(w, h))
if (mime == "image/png" or mime == "image/jpeg") and file_size < max_image_size:
media = InputMediaUploadedPhoto(file_handle)
else:
media = InputMediaUploadedDocument(file=file_handle, attributes=attributes,
mime_type=mime or "application/octet-stream")
caption = content.body if content.body != file_name else None
media = await client.upload_file_direct(
file, mime, attributes, file_name,
max_image_size=config["bridge.image_as_file_size"] * 1000 ** 2)
async with self.send_lock(sender_id):
if await self._matrix_document_edit(client, content, space, caption, media, event_id):
return
+1
View File
@@ -1,4 +1,5 @@
from .file_transfer import transfer_file_to_matrix, convert_image
from .parallel_file_transfer import parallel_transfer_to_telegram
from .format_duration import format_duration
from .recursive_dict import recursive_del, recursive_set, recursive_get
from .color_log import ColorFormatter
+161 -35
View File
@@ -13,22 +13,28 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple
from collections import defaultdict
import hashlib
import asyncio
import logging
import time
import math
from aiohttp import ClientResponse
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
InputPhotoFileLocation, InputPeerPhotoFileLocation)
InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
InputFileBig, InputFile)
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions.upload import GetFileRequest
from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
SaveBigFilePartRequest)
from telethon.network import MTProtoSender
from telethon.crypto import AuthKey
from telethon import utils
from telethon import utils, helpers
from mautrix.appservice import IntentAPI
from mautrix.types import ContentURI
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile
@@ -39,7 +45,7 @@ TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLoca
InputFileLocation, InputPhotoFileLocation]
class Sender:
class DownloadSender:
sender: MTProtoSender
request: GetFileRequest
remaining: int
@@ -47,7 +53,7 @@ class Sender:
def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
stride: int, count: int) -> None:
log.debug(f"Creating sender with {offset=} {limit=} {stride=} {count=}")
log.debug(f"Creating download sender with {offset=} {limit=} {stride=} {count=}")
self.sender = sender
self.request = GetFileRequest(file, offset=offset, limit=limit)
self.stride = stride
@@ -56,7 +62,6 @@ class Sender:
async def next(self) -> Optional[bytes]:
if not self.remaining:
return None
log.debug(f"Sending {self.request!s}")
result = await self.sender.send(self.request)
self.remaining -= 1
self.request.offset += self.stride
@@ -66,23 +71,76 @@ class Sender:
return self.sender.disconnect()
class ParallelDownloader:
class UploadSender:
sender: MTProtoSender
request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
part_count: int
stride: int
previous: Optional[asyncio.Task]
loop: asyncio.AbstractEventLoop
def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int,
stride: int, loop: asyncio.AbstractEventLoop) -> None:
log.debug(
f"Creating upload sender with {file_id=} {part_count=} {big=} {index=} {stride=}")
self.sender = sender
self.part_count = part_count
if big:
self.request = SaveBigFilePartRequest(file_id, index, part_count, b"")
else:
self.request = SaveFilePartRequest(file_id, index, b"")
self.stride = stride
self.previous = None
self.loop = loop
async def next(self, data: bytes) -> None:
if self.previous:
await self.previous
self.previous = self.loop.create_task(self._next(data))
async def _next(self, data: bytes) -> None:
self.request.bytes = data
log.debug(f"Sending file part {self.request.file_part}/{self.part_count}"
f" with {len(data)} bytes")
await self.sender.send(self.request)
self.request.file_part += self.stride
async def disconnect(self) -> None:
if self.previous:
await self.previous
return await self.sender.disconnect()
class ParallelTransferrer:
client: MautrixTelegramClient
loop: asyncio.AbstractEventLoop
dc_id: int
senders: Optional[List[Sender]]
senders: Optional[List[Union[DownloadSender, UploadSender]]]
auth_key: AuthKey
upload_ticker: int
def __init__(self, client: MautrixTelegramClient, dc_id: int) -> None:
def __init__(self, client: MautrixTelegramClient, dc_id: Optional[int] = None) -> None:
self.client = client
self.loop = self.client.loop
self.dc_id = dc_id
self.exported = dc_id and self.client.session.dc_id != dc_id
self.auth_key = self.client.session.auth_key if not self.exported else None
self.dc_id = dc_id or self.client.session.dc_id
self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id
else self.client.session.auth_key)
self.senders = None
self.upload_ticker = 0
async def _cleanup(self) -> None:
await asyncio.gather(*[sender.disconnect() for sender in self.senders])
self.senders = None
async def _init(self, connections: int, file: TypeLocation, part_count: int, part_size: int
) -> None:
@staticmethod
def _get_connection_count(file_size: int, max_count: int = 20,
full_size: int = 100 * 1024 * 1024) -> int:
if file_size > full_size:
return max_count
return math.ceil((file_size / full_size) * max_count)
async def _init_download(self, connections: int, file: TypeLocation, part_count: int,
part_size: int) -> None:
minimum, remainder = divmod(part_count, connections)
def get_part_count() -> int:
@@ -92,21 +150,38 @@ class ParallelDownloader:
return minimum + 1
return minimum
# The first cross-DC sender will export+import the authorization, so we always create it
# before creating any other senders.
self.senders = [
await self._create_sender(file, 0, part_size, connections * part_size,
get_part_count()),
*await asyncio.gather(*[
self._create_sender(file, i, part_size, connections * part_size, get_part_count())
for i in range(1, connections)
])
await self._create_download_sender(file, 0, part_size, connections * part_size,
get_part_count()),
*await asyncio.gather(
*[self._create_download_sender(file, i, part_size, connections * part_size,
get_part_count())
for i in range(1, connections)])
]
async def _cleanup(self) -> None:
await asyncio.gather(*[sender.disconnect() for sender in self.senders])
self.senders = None
async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int,
stride: int,
part_count: int) -> DownloadSender:
return DownloadSender(await self._create_sender(), file, index * part_size, part_size,
stride, part_count)
async def _create_sender(self, file: TypeLocation, index: int, part_size: int, stride: int,
part_count: int) -> Sender:
async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool
) -> None:
self.senders = [
await self._create_upload_sender(file_id, part_count, big, 0, connections),
*await asyncio.gather(
*[self._create_upload_sender(file_id, part_count, big, i, connections)
for i in range(1, connections)])
]
async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int,
stride: int) -> UploadSender:
return UploadSender(await self._create_sender(), file_id, part_count, big, index, stride,
loop=self.loop)
async def _create_sender(self) -> MTProtoSender:
dc = await self.client._get_dc(self.dc_id)
sender = MTProtoSender(self.auth_key, self.loop, loggers=self.client._log)
await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
@@ -120,14 +195,23 @@ class ParallelDownloader:
))
await sender.send(req)
self.auth_key = sender.auth_key
return Sender(sender, file, index * part_size, part_size, stride, part_count)
return sender
@staticmethod
def _get_connection_count(file_size: int, max_count: int = 20,
full_size: int = 100 * 1024 * 1024) -> int:
if file_size > full_size:
return max_count
return math.ceil((file_size / full_size) * max_count)
async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None,
connection_count: Optional[int] = None) -> Tuple[int, int, bool]:
connection_count = connection_count or self._get_connection_count(file_size)
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
part_count = (file_size + part_size - 1) // part_size
is_large = file_size > 10 * 1024 * 1024
await self._init_upload(connection_count, file_id, part_count, is_large)
return part_size, part_count, is_large
async def upload(self, part: bytes) -> None:
await self.senders[self.upload_ticker].next(part)
self.upload_ticker = (self.upload_ticker + 1) % len(self.senders)
async def finish_upload(self) -> None:
await self._cleanup()
async def download(self, file: TypeLocation, file_size: int,
part_size_kb: Optional[float] = None,
@@ -137,7 +221,7 @@ class ParallelDownloader:
part_count = math.ceil(file_size / part_size)
log.debug("Starting parallel download: "
f"{connection_count} {part_size} {part_count} {file!s}")
await self._init(connection_count, file, part_count, part_size)
await self._init_download(connection_count, file, part_count, part_size)
part = 0
while part < part_count:
@@ -167,9 +251,51 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int
dc_id, location = utils.get_input_location(location)
# We lock the transfers because telegram has connection count limits
async with parallel_transfer_locks[parallel_id]:
downloader = ParallelDownloader(client, dc_id)
downloader = ParallelTransferrer(client, dc_id)
content_uri = await intent.upload_media(downloader.download(location, size),
mime_type=mime_type, filename=filename, size=size)
return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
was_converted=False, timestamp=int(time.time()), size=size,
width=None, height=None)
async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response: ClientResponse
) -> Tuple[TypeInputFile, int]:
file_id = helpers.generate_random_long()
file_size = response.content_length
hash_md5 = hashlib.md5()
uploader = ParallelTransferrer(client)
part_size, part_count, is_large = await uploader.init_upload(file_id, file_size)
buffer = bytearray()
async for data in response.content:
if not is_large:
hash_md5.update(data)
if len(buffer) == 0 and len(data) == part_size:
await uploader.upload(data)
continue
new_len = len(buffer) + len(data)
if new_len >= part_size:
cutoff = part_size - len(buffer)
buffer.extend(data[:cutoff])
await uploader.upload(bytes(buffer))
buffer.clear()
buffer.extend(data[cutoff:])
else:
buffer.extend(data)
if len(buffer) > 0:
await uploader.upload(bytes(buffer))
await uploader.finish_upload()
if is_large:
return InputFileBig(file_id, part_count, "upload"), file_size
else:
return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size
async def parallel_transfer_to_telegram(client: MautrixTelegramClient, intent: IntentAPI,
uri: ContentURI, parallel_id: int
) -> Tuple[TypeInputFile, int]:
url = intent.api.get_download_url(uri)
async with parallel_transfer_locks[parallel_id]:
async with intent.api.session.get(url) as response:
return await _internal_transfer_to_telegram(client, response)