Blacken and isort code
This commit is contained in:
@@ -13,34 +13,45 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple, cast
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator, Awaitable, Union, cast
|
||||
from collections import defaultdict
|
||||
import hashlib
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import math
|
||||
import time
|
||||
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
|
||||
InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
|
||||
InputFileBig, InputFile)
|
||||
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
|
||||
from telethon.tl.functions import InvokeWithLayerRequest
|
||||
from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
|
||||
SaveBigFilePartRequest)
|
||||
from telethon.tl.alltlobjects import LAYER
|
||||
from telethon.network import MTProtoSender
|
||||
from telethon import helpers, utils
|
||||
from telethon.crypto import AuthKey
|
||||
from telethon import utils, helpers
|
||||
from telethon.network import MTProtoSender
|
||||
from telethon.tl.alltlobjects import LAYER
|
||||
from telethon.tl.functions import InvokeWithLayerRequest
|
||||
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
|
||||
from telethon.tl.functions.upload import (
|
||||
GetFileRequest,
|
||||
SaveBigFilePartRequest,
|
||||
SaveFilePartRequest,
|
||||
)
|
||||
from telethon.tl.types import (
|
||||
Document,
|
||||
InputDocumentFileLocation,
|
||||
InputFile,
|
||||
InputFileBig,
|
||||
InputFileLocation,
|
||||
InputPeerPhotoFileLocation,
|
||||
InputPhotoFileLocation,
|
||||
TypeInputFile,
|
||||
)
|
||||
|
||||
from mautrix.appservice import IntentAPI
|
||||
from mautrix.types import ContentURI, EncryptedFile
|
||||
from mautrix.util.logging import TraceLogger
|
||||
|
||||
from ..tgclient import MautrixTelegramClient
|
||||
from ..db import TelegramFile as DBTelegramFile
|
||||
from ..tgclient import MautrixTelegramClient
|
||||
|
||||
try:
|
||||
from mautrix.crypto.attachments import async_encrypt_attachment
|
||||
@@ -49,8 +60,13 @@ except ImportError:
|
||||
|
||||
log: TraceLogger = cast(TraceLogger, logging.getLogger("mau.util"))
|
||||
|
||||
TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation,
|
||||
InputFileLocation, InputPhotoFileLocation]
|
||||
TypeLocation = Union[
|
||||
Document,
|
||||
InputDocumentFileLocation,
|
||||
InputPeerPhotoFileLocation,
|
||||
InputFileLocation,
|
||||
InputPhotoFileLocation,
|
||||
]
|
||||
|
||||
|
||||
class DownloadSender:
|
||||
@@ -59,14 +75,21 @@ class DownloadSender:
|
||||
remaining: int
|
||||
stride: int
|
||||
|
||||
def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
|
||||
stride: int, count: int) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
sender: MTProtoSender,
|
||||
file: TypeLocation,
|
||||
offset: int,
|
||||
limit: int,
|
||||
stride: int,
|
||||
count: int,
|
||||
) -> None:
|
||||
self.sender = sender
|
||||
self.request = GetFileRequest(file, offset=offset, limit=limit)
|
||||
self.stride = stride
|
||||
self.remaining = count
|
||||
|
||||
async def next(self) -> Optional[bytes]:
|
||||
async def next(self) -> bytes | None:
|
||||
if not self.remaining:
|
||||
return None
|
||||
result = await self.sender.send(self.request)
|
||||
@@ -80,14 +103,22 @@ class DownloadSender:
|
||||
|
||||
class UploadSender:
|
||||
sender: MTProtoSender
|
||||
request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
|
||||
request: SaveFilePartRequest < SaveBigFilePartRequest
|
||||
part_count: int
|
||||
stride: int
|
||||
previous: Optional[asyncio.Task]
|
||||
previous: asyncio.Task | None
|
||||
loop: asyncio.AbstractEventLoop
|
||||
|
||||
def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int,
|
||||
stride: int, loop: asyncio.AbstractEventLoop) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
sender: MTProtoSender,
|
||||
file_id: int,
|
||||
part_count: int,
|
||||
big: bool,
|
||||
index: int,
|
||||
stride: int,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> None:
|
||||
self.sender = sender
|
||||
self.part_count = part_count
|
||||
if big:
|
||||
@@ -105,8 +136,10 @@ class UploadSender:
|
||||
|
||||
async def _next(self, data: bytes) -> None:
|
||||
self.request.bytes = data
|
||||
log.trace(f"Sending file part {self.request.file_part}/{self.part_count}"
|
||||
f" with {len(data)} bytes")
|
||||
log.trace(
|
||||
f"Sending file part {self.request.file_part}/{self.part_count}"
|
||||
f" with {len(data)} bytes"
|
||||
)
|
||||
await self.sender.send(self.request)
|
||||
self.request.file_part += self.stride
|
||||
|
||||
@@ -120,16 +153,17 @@ class ParallelTransferrer:
|
||||
client: MautrixTelegramClient
|
||||
loop: asyncio.AbstractEventLoop
|
||||
dc_id: int
|
||||
senders: Optional[List[Union[DownloadSender, UploadSender]]]
|
||||
senders: list[DownloadSender | UploadSender] | None
|
||||
auth_key: AuthKey
|
||||
upload_ticker: int
|
||||
|
||||
def __init__(self, client: MautrixTelegramClient, dc_id: Optional[int] = None) -> None:
|
||||
def __init__(self, client: MautrixTelegramClient, dc_id: int | None = None) -> None:
|
||||
self.client = client
|
||||
self.loop = self.client.loop
|
||||
self.dc_id = dc_id or self.client.session.dc_id
|
||||
self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id
|
||||
else self.client.session.auth_key)
|
||||
self.auth_key = (
|
||||
None if dc_id and self.client.session.dc_id != dc_id else self.client.session.auth_key
|
||||
)
|
||||
self.senders = None
|
||||
self.upload_ticker = 0
|
||||
|
||||
@@ -138,14 +172,16 @@ class ParallelTransferrer:
|
||||
self.senders = None
|
||||
|
||||
@staticmethod
|
||||
def _get_connection_count(file_size: int, max_count: int = 20,
|
||||
full_size: int = 100 * 1024 * 1024) -> int:
|
||||
def _get_connection_count(
|
||||
file_size: int, max_count: int = 20, full_size: int = 100 * 1024 * 1024
|
||||
) -> int:
|
||||
if file_size > full_size:
|
||||
return max_count
|
||||
return math.ceil((file_size / full_size) * max_count)
|
||||
|
||||
async def _init_download(self, connections: int, file: TypeLocation, part_count: int,
|
||||
part_size: int) -> None:
|
||||
async def _init_download(
|
||||
self, connections: int, file: TypeLocation, part_count: int, part_size: int
|
||||
) -> None:
|
||||
minimum, remainder = divmod(part_count, connections)
|
||||
|
||||
def get_part_count() -> int:
|
||||
@@ -158,52 +194,72 @@ class ParallelTransferrer:
|
||||
# The first cross-DC sender will export+import the authorization, so we always create it
|
||||
# before creating any other senders.
|
||||
self.senders = [
|
||||
await self._create_download_sender(file, 0, part_size, connections * part_size,
|
||||
get_part_count()),
|
||||
await self._create_download_sender(
|
||||
file, 0, part_size, connections * part_size, get_part_count()
|
||||
),
|
||||
*await asyncio.gather(
|
||||
*(self._create_download_sender(file, i, part_size, connections * part_size,
|
||||
get_part_count())
|
||||
for i in range(1, connections)))
|
||||
*(
|
||||
self._create_download_sender(
|
||||
file, i, part_size, connections * part_size, get_part_count()
|
||||
)
|
||||
for i in range(1, connections)
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int,
|
||||
stride: int,
|
||||
part_count: int) -> DownloadSender:
|
||||
return DownloadSender(await self._create_sender(), file, index * part_size, part_size,
|
||||
stride, part_count)
|
||||
async def _create_download_sender(
|
||||
self, file: TypeLocation, index: int, part_size: int, stride: int, part_count: int
|
||||
) -> DownloadSender:
|
||||
return DownloadSender(
|
||||
await self._create_sender(), file, index * part_size, part_size, stride, part_count
|
||||
)
|
||||
|
||||
async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool
|
||||
) -> None:
|
||||
async def _init_upload(
|
||||
self, connections: int, file_id: int, part_count: int, big: bool
|
||||
) -> None:
|
||||
self.senders = [
|
||||
await self._create_upload_sender(file_id, part_count, big, 0, connections),
|
||||
*await asyncio.gather(
|
||||
*(self._create_upload_sender(file_id, part_count, big, i, connections)
|
||||
for i in range(1, connections)))
|
||||
*(
|
||||
self._create_upload_sender(file_id, part_count, big, i, connections)
|
||||
for i in range(1, connections)
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int,
|
||||
stride: int) -> UploadSender:
|
||||
return UploadSender(await self._create_sender(), file_id, part_count, big, index, stride,
|
||||
loop=self.loop)
|
||||
async def _create_upload_sender(
|
||||
self, file_id: int, part_count: int, big: bool, index: int, stride: int
|
||||
) -> UploadSender:
|
||||
return UploadSender(
|
||||
await self._create_sender(), file_id, part_count, big, index, stride, loop=self.loop
|
||||
)
|
||||
|
||||
async def _create_sender(self) -> MTProtoSender:
|
||||
dc = await self.client._get_dc(self.dc_id)
|
||||
sender = MTProtoSender(self.auth_key, loggers=self.client._log)
|
||||
await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
|
||||
loggers=self.client._log,
|
||||
proxy=self.client._proxy))
|
||||
await sender.connect(
|
||||
self.client._connection(
|
||||
dc.ip_address, dc.port, dc.id, loggers=self.client._log, proxy=self.client._proxy
|
||||
)
|
||||
)
|
||||
if not self.auth_key:
|
||||
log.debug(f"Exporting auth to DC {self.dc_id}")
|
||||
auth = await self.client(ExportAuthorizationRequest(self.dc_id))
|
||||
self.client._init_request.query = ImportAuthorizationRequest(id=auth.id,
|
||||
bytes=auth.bytes)
|
||||
self.client._init_request.query = ImportAuthorizationRequest(
|
||||
id=auth.id, bytes=auth.bytes
|
||||
)
|
||||
req = InvokeWithLayerRequest(LAYER, self.client._init_request)
|
||||
await sender.send(req)
|
||||
self.auth_key = sender.auth_key
|
||||
return sender
|
||||
|
||||
async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None,
|
||||
connection_count: Optional[int] = None) -> Tuple[int, int, bool]:
|
||||
async def init_upload(
|
||||
self,
|
||||
file_id: int,
|
||||
file_size: int,
|
||||
part_size_kb: float | None = None,
|
||||
connection_count: int | None = None,
|
||||
) -> tuple[int, int, bool]:
|
||||
connection_count = connection_count or self._get_connection_count(file_size)
|
||||
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
|
||||
part_count = (file_size + part_size - 1) // part_size
|
||||
@@ -218,14 +274,19 @@ class ParallelTransferrer:
|
||||
async def finish_upload(self) -> None:
|
||||
await self._cleanup()
|
||||
|
||||
async def download(self, file: TypeLocation, file_size: int,
|
||||
part_size_kb: Optional[float] = None,
|
||||
connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]:
|
||||
async def download(
|
||||
self,
|
||||
file: TypeLocation,
|
||||
file_size: int,
|
||||
part_size_kb: float | None = None,
|
||||
connection_count: int | None = None,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
connection_count = connection_count or self._get_connection_count(file_size)
|
||||
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
|
||||
part_count = math.ceil(file_size / part_size)
|
||||
log.debug("Starting parallel download: "
|
||||
f"{connection_count} {part_size} {part_count} {file!s}")
|
||||
log.debug(
|
||||
f"Starting parallel download: {connection_count} {part_size} {part_count} {file!s}"
|
||||
)
|
||||
await self._init_download(connection_count, file, part_count, part_size)
|
||||
|
||||
part = 0
|
||||
@@ -245,12 +306,18 @@ class ParallelTransferrer:
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
|
||||
parallel_transfer_locks: defaultdict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
|
||||
|
||||
|
||||
async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
|
||||
loc_id: str, location: TypeLocation, filename: str,
|
||||
encrypt: bool, parallel_id: int) -> DBTelegramFile:
|
||||
async def parallel_transfer_to_matrix(
|
||||
client: MautrixTelegramClient,
|
||||
intent: IntentAPI,
|
||||
loc_id: str,
|
||||
location: TypeLocation,
|
||||
filename: str,
|
||||
encrypt: bool,
|
||||
parallel_id: int,
|
||||
) -> DBTelegramFile:
|
||||
size = location.size
|
||||
mime_type = location.mime_type
|
||||
dc_id, location = utils.get_input_location(location)
|
||||
@@ -261,6 +328,7 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int
|
||||
decryption_info = None
|
||||
up_mime_type = mime_type
|
||||
if encrypt and async_encrypt_attachment:
|
||||
|
||||
async def encrypted(stream):
|
||||
nonlocal decryption_info
|
||||
async for chunk in async_encrypt_attachment(stream):
|
||||
@@ -271,17 +339,27 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int
|
||||
|
||||
data = encrypted(data)
|
||||
up_mime_type = "application/octet-stream"
|
||||
content_uri = await intent.upload_media(data, mime_type=up_mime_type, filename=filename,
|
||||
size=size if not encrypt else None)
|
||||
content_uri = await intent.upload_media(
|
||||
data, mime_type=up_mime_type, filename=filename, size=size if not encrypt else None
|
||||
)
|
||||
if decryption_info:
|
||||
decryption_info.url = content_uri
|
||||
return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
|
||||
was_converted=False, timestamp=int(time.time()), size=size,
|
||||
width=None, height=None, decryption_info=decryption_info)
|
||||
return DBTelegramFile(
|
||||
id=loc_id,
|
||||
mxc=content_uri,
|
||||
mime_type=mime_type,
|
||||
was_converted=False,
|
||||
timestamp=int(time.time()),
|
||||
size=size,
|
||||
width=None,
|
||||
height=None,
|
||||
decryption_info=decryption_info,
|
||||
)
|
||||
|
||||
|
||||
async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response: ClientResponse
|
||||
) -> Tuple[TypeInputFile, int]:
|
||||
async def _internal_transfer_to_telegram(
|
||||
client: MautrixTelegramClient, response: ClientResponse
|
||||
) -> tuple[TypeInputFile, int]:
|
||||
file_id = helpers.generate_random_long()
|
||||
file_size = response.content_length
|
||||
|
||||
@@ -313,9 +391,9 @@ async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response
|
||||
return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size
|
||||
|
||||
|
||||
async def parallel_transfer_to_telegram(client: MautrixTelegramClient, intent: IntentAPI,
|
||||
uri: ContentURI, parallel_id: int
|
||||
) -> Tuple[TypeInputFile, int]:
|
||||
async def parallel_transfer_to_telegram(
|
||||
client: MautrixTelegramClient, intent: IntentAPI, uri: ContentURI, parallel_id: int
|
||||
) -> tuple[TypeInputFile, int]:
|
||||
url = intent.api.get_download_url(uri)
|
||||
async with parallel_transfer_locks[parallel_id]:
|
||||
async with intent.api.session.get(url) as response:
|
||||
|
||||
Reference in New Issue
Block a user