Add parallel file upload too
This commit is contained in:
@@ -17,7 +17,6 @@ from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHEC
|
||||
from html import escape as escape_html
|
||||
from string import Template
|
||||
from abc import ABC
|
||||
import mimetypes
|
||||
|
||||
import magic
|
||||
|
||||
@@ -32,7 +31,7 @@ from telethon.tl.types import (
|
||||
DocumentAttributeFilename, DocumentAttributeImageSize, GeoPoint,
|
||||
InputChatUploadedPhoto, MessageActionChatEditPhoto, MessageMediaGeo,
|
||||
SendMessageCancelAction, SendMessageTypingAction, TypeInputPeer, TypeMessageEntity,
|
||||
UpdateNewMessage, InputMediaUploadedDocument)
|
||||
UpdateNewMessage, InputMediaUploadedDocument, InputMediaUploadedPhoto)
|
||||
|
||||
from mautrix.types import (EventID, RoomID, UserID, ContentURI, MessageType, MessageEventContent,
|
||||
TextMessageEventContent, MediaMessageEventContent, Format,
|
||||
@@ -41,7 +40,7 @@ from mautrix.bridge import BasePortal as MautrixBasePortal
|
||||
|
||||
from ..types import TelegramID
|
||||
from ..db import Message as DBMessage
|
||||
from ..util import sane_mimetypes
|
||||
from ..util import sane_mimetypes, parallel_transfer_to_telegram
|
||||
from ..context import Context
|
||||
from .. import puppet as p, user as u, formatter, util
|
||||
from .base import BasePortal
|
||||
@@ -250,28 +249,42 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
|
||||
async def _handle_matrix_file(self, sender_id: TelegramID, event_id: EventID,
|
||||
space: TelegramID, client: 'MautrixTelegramClient',
|
||||
content: MediaMessageEventContent, reply_to: TelegramID) -> None:
|
||||
file = await self.main_intent.download_media(content.url)
|
||||
|
||||
mime = content.info.mimetype
|
||||
w, h = content.info.width, content.info.height
|
||||
file_name = content["net.maunium.telegram.internal.filename"]
|
||||
max_image_size = config["bridge.image_as_file_size"] * 1000 ** 2
|
||||
|
||||
if content.msgtype == MessageType.STICKER:
|
||||
if mime != "image/gif":
|
||||
mime, file, w, h = util.convert_image(file, source_mime=mime, target_type="webp")
|
||||
else:
|
||||
# Remove sticker description
|
||||
file_name = "sticker.gif"
|
||||
if config["bridge.parallel_file_transfer"]:
|
||||
file_handle, file_size = await parallel_transfer_to_telegram(client, self.main_intent,
|
||||
content.url, 0)
|
||||
else:
|
||||
file = await self.main_intent.download_media(content.url)
|
||||
|
||||
if content.msgtype == MessageType.STICKER:
|
||||
if mime != "image/gif":
|
||||
mime, file, w, h = util.convert_image(file, source_mime=mime,
|
||||
target_type="webp")
|
||||
else:
|
||||
# Remove sticker description
|
||||
file_name = "sticker.gif"
|
||||
|
||||
file_handle = await client.upload_file(file)
|
||||
file_size = len(file)
|
||||
|
||||
file_handle.name = file_name
|
||||
|
||||
attributes = [DocumentAttributeFilename(file_name=file_name)]
|
||||
if w and h:
|
||||
attributes.append(DocumentAttributeImageSize(w, h))
|
||||
|
||||
if (mime == "image/png" or mime == "image/jpeg") and file_size < max_image_size:
|
||||
media = InputMediaUploadedPhoto(file_handle)
|
||||
else:
|
||||
media = InputMediaUploadedDocument(file=file_handle, attributes=attributes,
|
||||
mime_type=mime or "application/octet-stream")
|
||||
|
||||
caption = content.body if content.body != file_name else None
|
||||
|
||||
media = await client.upload_file_direct(
|
||||
file, mime, attributes, file_name,
|
||||
max_image_size=config["bridge.image_as_file_size"] * 1000 ** 2)
|
||||
async with self.send_lock(sender_id):
|
||||
if await self._matrix_document_edit(client, content, space, caption, media, event_id):
|
||||
return
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .file_transfer import transfer_file_to_matrix, convert_image
|
||||
from .parallel_file_transfer import parallel_transfer_to_telegram
|
||||
from .format_duration import format_duration
|
||||
from .recursive_dict import recursive_del, recursive_set, recursive_get
|
||||
from .color_log import ColorFormatter
|
||||
|
||||
@@ -13,22 +13,28 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict
|
||||
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple
|
||||
from collections import defaultdict
|
||||
import hashlib
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import math
|
||||
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
|
||||
InputPhotoFileLocation, InputPeerPhotoFileLocation)
|
||||
InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
|
||||
InputFileBig, InputFile)
|
||||
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
|
||||
from telethon.tl.functions.upload import GetFileRequest
|
||||
from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
|
||||
SaveBigFilePartRequest)
|
||||
from telethon.network import MTProtoSender
|
||||
from telethon.crypto import AuthKey
|
||||
from telethon import utils
|
||||
from telethon import utils, helpers
|
||||
|
||||
from mautrix.appservice import IntentAPI
|
||||
from mautrix.types import ContentURI
|
||||
|
||||
from ..tgclient import MautrixTelegramClient
|
||||
from ..db import TelegramFile as DBTelegramFile
|
||||
@@ -39,7 +45,7 @@ TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLoca
|
||||
InputFileLocation, InputPhotoFileLocation]
|
||||
|
||||
|
||||
class Sender:
|
||||
class DownloadSender:
|
||||
sender: MTProtoSender
|
||||
request: GetFileRequest
|
||||
remaining: int
|
||||
@@ -47,7 +53,7 @@ class Sender:
|
||||
|
||||
def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
|
||||
stride: int, count: int) -> None:
|
||||
log.debug(f"Creating sender with {offset=} {limit=} {stride=} {count=}")
|
||||
log.debug(f"Creating download sender with {offset=} {limit=} {stride=} {count=}")
|
||||
self.sender = sender
|
||||
self.request = GetFileRequest(file, offset=offset, limit=limit)
|
||||
self.stride = stride
|
||||
@@ -56,7 +62,6 @@ class Sender:
|
||||
async def next(self) -> Optional[bytes]:
|
||||
if not self.remaining:
|
||||
return None
|
||||
log.debug(f"Sending {self.request!s}")
|
||||
result = await self.sender.send(self.request)
|
||||
self.remaining -= 1
|
||||
self.request.offset += self.stride
|
||||
@@ -66,23 +71,76 @@ class Sender:
|
||||
return self.sender.disconnect()
|
||||
|
||||
|
||||
class ParallelDownloader:
|
||||
class UploadSender:
    """Upload a strided subset of a file's parts over one MTProto connection.

    The sender starts at part ``index`` and advances its part number by
    ``stride`` after every part it sends. Parts are sent one at a time:
    :meth:`next` waits for the previously scheduled part before scheduling
    the new one as a background task.
    """
    sender: MTProtoSender
    request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
    part_count: int
    stride: int
    previous: Optional[asyncio.Task]
    loop: asyncio.AbstractEventLoop

    def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int,
                 stride: int, loop: asyncio.AbstractEventLoop) -> None:
        """Prepare the reusable save-part request for this connection.

        :param sender: The connection used to send the requests.
        :param file_id: The ID of the file being uploaded.
        :param part_count: Total number of parts in the file.
        :param big: Whether to use ``SaveBigFilePartRequest`` instead of
            ``SaveFilePartRequest``.
        :param index: The first part number this sender uploads.
        :param stride: How far the part number advances after each send.
        :param loop: The event loop used to schedule the send tasks.
        """
        log.debug(
            f"Creating upload sender with {file_id=} {part_count=} {big=} {index=} {stride=}")
        self.sender = sender
        self.part_count = part_count
        # The payload is filled in by _next(); the same request object is reused for every part.
        if big:
            self.request = SaveBigFilePartRequest(file_id, index, part_count, b"")
        else:
            self.request = SaveFilePartRequest(file_id, index, b"")
        self.stride = stride
        self.previous = None
        self.loop = loop

    async def next(self, data: bytes) -> None:
        """Schedule ``data`` as this sender's next part.

        Waits for the previously scheduled part (if any) first, so any
        exception it raised surfaces here.
        """
        if self.previous:
            await self.previous
        self.previous = self.loop.create_task(self._next(data))

    async def _next(self, data: bytes) -> None:
        """Send ``data`` as the current part and advance the part number."""
        self.request.bytes = data
        log.debug(f"Sending file part {self.request.file_part}/{self.part_count}"
                  f" with {len(data)} bytes")
        await self.sender.send(self.request)
        self.request.file_part += self.stride

    async def disconnect(self) -> None:
        """Wait for the pending part, then close the connection.

        The connection is closed even when the pending part failed; the
        failure still propagates to the caller afterwards.
        """
        try:
            if self.previous:
                await self.previous
        finally:
            # Without the finally, a failed upload would leave the connection open.
            await self.sender.disconnect()
|
||||
|
||||
|
||||
class ParallelTransferrer:
|
||||
client: MautrixTelegramClient
|
||||
loop: asyncio.AbstractEventLoop
|
||||
dc_id: int
|
||||
senders: Optional[List[Sender]]
|
||||
senders: Optional[List[Union[DownloadSender, UploadSender]]]
|
||||
auth_key: AuthKey
|
||||
upload_ticker: int
|
||||
|
||||
def __init__(self, client: MautrixTelegramClient, dc_id: int) -> None:
|
||||
def __init__(self, client: MautrixTelegramClient, dc_id: Optional[int] = None) -> None:
|
||||
self.client = client
|
||||
self.loop = self.client.loop
|
||||
self.dc_id = dc_id
|
||||
self.exported = dc_id and self.client.session.dc_id != dc_id
|
||||
self.auth_key = self.client.session.auth_key if not self.exported else None
|
||||
self.dc_id = dc_id or self.client.session.dc_id
|
||||
self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id
|
||||
else self.client.session.auth_key)
|
||||
self.senders = None
|
||||
self.upload_ticker = 0
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
await asyncio.gather(*[sender.disconnect() for sender in self.senders])
|
||||
self.senders = None
|
||||
|
||||
async def _init(self, connections: int, file: TypeLocation, part_count: int, part_size: int
|
||||
) -> None:
|
||||
@staticmethod
|
||||
def _get_connection_count(file_size: int, max_count: int = 20,
|
||||
full_size: int = 100 * 1024 * 1024) -> int:
|
||||
if file_size > full_size:
|
||||
return max_count
|
||||
return math.ceil((file_size / full_size) * max_count)
|
||||
|
||||
async def _init_download(self, connections: int, file: TypeLocation, part_count: int,
|
||||
part_size: int) -> None:
|
||||
minimum, remainder = divmod(part_count, connections)
|
||||
|
||||
def get_part_count() -> int:
|
||||
@@ -92,21 +150,38 @@ class ParallelDownloader:
|
||||
return minimum + 1
|
||||
return minimum
|
||||
|
||||
# The first cross-DC sender will export+import the authorization, so we always create it
|
||||
# before creating any other senders.
|
||||
self.senders = [
|
||||
await self._create_sender(file, 0, part_size, connections * part_size,
|
||||
get_part_count()),
|
||||
*await asyncio.gather(*[
|
||||
self._create_sender(file, i, part_size, connections * part_size, get_part_count())
|
||||
for i in range(1, connections)
|
||||
])
|
||||
await self._create_download_sender(file, 0, part_size, connections * part_size,
|
||||
get_part_count()),
|
||||
*await asyncio.gather(
|
||||
*[self._create_download_sender(file, i, part_size, connections * part_size,
|
||||
get_part_count())
|
||||
for i in range(1, connections)])
|
||||
]
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
await asyncio.gather(*[sender.disconnect() for sender in self.senders])
|
||||
self.senders = None
|
||||
async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int,
                                  stride: int,
                                  part_count: int) -> DownloadSender:
    """Open a new connection and wrap it in a DownloadSender.

    The sender's starting offset is ``index * part_size``.
    """
    connection = await self._create_sender()
    start_offset = index * part_size
    return DownloadSender(connection, file, start_offset, part_size, stride, part_count)
|
||||
|
||||
async def _create_sender(self, file: TypeLocation, index: int, part_size: int, stride: int,
|
||||
part_count: int) -> Sender:
|
||||
async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool
                       ) -> None:
    """Create the upload senders and store them in ``self.senders``.

    The sender for index 0 is fully created before the rest are created
    concurrently. Every sender uses ``connections`` as its stride.
    """
    first_sender = await self._create_upload_sender(file_id, part_count, big, 0, connections)
    other_senders = await asyncio.gather(*(
        self._create_upload_sender(file_id, part_count, big, index, connections)
        for index in range(1, connections)
    ))
    self.senders = [first_sender, *other_senders]
|
||||
|
||||
async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int,
                                stride: int) -> UploadSender:
    """Open a new connection and wrap it in an UploadSender bound to this transferrer's loop."""
    connection = await self._create_sender()
    return UploadSender(connection, file_id, part_count, big, index, stride, loop=self.loop)
|
||||
|
||||
async def _create_sender(self) -> MTProtoSender:
|
||||
dc = await self.client._get_dc(self.dc_id)
|
||||
sender = MTProtoSender(self.auth_key, self.loop, loggers=self.client._log)
|
||||
await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
|
||||
@@ -120,14 +195,23 @@ class ParallelDownloader:
|
||||
))
|
||||
await sender.send(req)
|
||||
self.auth_key = sender.auth_key
|
||||
return Sender(sender, file, index * part_size, part_size, stride, part_count)
|
||||
return sender
|
||||
|
||||
@staticmethod
|
||||
def _get_connection_count(file_size: int, max_count: int = 20,
|
||||
full_size: int = 100 * 1024 * 1024) -> int:
|
||||
if file_size > full_size:
|
||||
return max_count
|
||||
return math.ceil((file_size / full_size) * max_count)
|
||||
async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None,
|
||||
connection_count: Optional[int] = None) -> Tuple[int, int, bool]:
|
||||
connection_count = connection_count or self._get_connection_count(file_size)
|
||||
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
|
||||
part_count = (file_size + part_size - 1) // part_size
|
||||
is_large = file_size > 10 * 1024 * 1024
|
||||
await self._init_upload(connection_count, file_id, part_count, is_large)
|
||||
return part_size, part_count, is_large
|
||||
|
||||
async def upload(self, part: bytes) -> None:
|
||||
await self.senders[self.upload_ticker].next(part)
|
||||
self.upload_ticker = (self.upload_ticker + 1) % len(self.senders)
|
||||
|
||||
async def finish_upload(self) -> None:
    """Finish the upload by running ``_cleanup()`` on this transferrer."""
    await self._cleanup()
|
||||
|
||||
async def download(self, file: TypeLocation, file_size: int,
|
||||
part_size_kb: Optional[float] = None,
|
||||
@@ -137,7 +221,7 @@ class ParallelDownloader:
|
||||
part_count = math.ceil(file_size / part_size)
|
||||
log.debug("Starting parallel download: "
|
||||
f"{connection_count} {part_size} {part_count} {file!s}")
|
||||
await self._init(connection_count, file, part_count, part_size)
|
||||
await self._init_download(connection_count, file, part_count, part_size)
|
||||
|
||||
part = 0
|
||||
while part < part_count:
|
||||
@@ -167,9 +251,51 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int
|
||||
dc_id, location = utils.get_input_location(location)
|
||||
# We lock the transfers because telegram has connection count limits
|
||||
async with parallel_transfer_locks[parallel_id]:
|
||||
downloader = ParallelDownloader(client, dc_id)
|
||||
downloader = ParallelTransferrer(client, dc_id)
|
||||
content_uri = await intent.upload_media(downloader.download(location, size),
|
||||
mime_type=mime_type, filename=filename, size=size)
|
||||
return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
|
||||
was_converted=False, timestamp=int(time.time()), size=size,
|
||||
width=None, height=None)
|
||||
|
||||
|
||||
async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response: ClientResponse
                                         ) -> Tuple[TypeInputFile, int]:
    """Stream an HTTP response body to Telegram as a multi-part upload.

    The body is re-chunked into parts of exactly ``part_size`` bytes (the last
    part may be shorter) and handed to a ParallelTransferrer. For files that
    are not large, an MD5 of the whole body is computed along the way.

    :param client: The Telegram client used to create the transferrer.
    :param response: The HTTP response whose body is uploaded.
    :return: The input file handle (``InputFileBig`` for large files,
        ``InputFile`` with the MD5 hex digest otherwise) and the file size.
    """
    file_id = helpers.generate_random_long()
    # NOTE(review): content_length is presumably None when the server sends no
    # Content-Length header, which init_upload cannot handle - confirm with callers.
    file_size = response.content_length

    hash_md5 = hashlib.md5()
    uploader = ParallelTransferrer(client)
    part_size, part_count, is_large = await uploader.init_upload(file_id, file_size)
    buffer = bytearray()
    try:
        async for data in response.content:
            if not is_large:
                hash_md5.update(data)
            if not buffer and len(data) == part_size:
                # Fast path: the chunk is exactly one part, so skip the copy into the buffer.
                await uploader.upload(data)
                continue
            buffer.extend(data)
            # Drain every complete part. A single chunk may hold more than one
            # part's worth of data, so one cut per chunk is not enough.
            while len(buffer) >= part_size:
                await uploader.upload(bytes(buffer[:part_size]))
                del buffer[:part_size]
        if buffer:
            await uploader.upload(bytes(buffer))
    finally:
        # Always release the upload connections, including when a part failed.
        await uploader.finish_upload()
    if is_large:
        return InputFileBig(file_id, part_count, "upload"), file_size
    return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size
|
||||
|
||||
|
||||
async def parallel_transfer_to_telegram(client: MautrixTelegramClient, intent: IntentAPI,
                                        uri: ContentURI, parallel_id: int
                                        ) -> Tuple[TypeInputFile, int]:
    """Download a Matrix media URI and upload it to Telegram.

    The transfer runs while holding the ``parallel_transfer_locks`` entry for
    ``parallel_id``. Returns whatever ``_internal_transfer_to_telegram``
    returns for the HTTP response.
    """
    download_url = intent.api.get_download_url(uri)
    transfer_lock = parallel_transfer_locks[parallel_id]
    async with transfer_lock, intent.api.session.get(download_url) as response:
        uploaded = await _internal_transfer_to_telegram(client, response)
    return uploaded
|
||||
|
||||
Reference in New Issue
Block a user