Blacken and isort code

This commit is contained in:
Tulir Asokan
2021-12-21 01:36:24 +02:00
parent f2af17d359
commit 6d25e9687e
55 changed files with 3752 additions and 2018 deletions
+156 -78
View File
@@ -13,34 +13,45 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple, cast
from __future__ import annotations
from typing import AsyncGenerator, Awaitable, Union, cast
from collections import defaultdict
import hashlib
import asyncio
import hashlib
import logging
import time
import math
import time
from aiohttp import ClientResponse
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
InputFileBig, InputFile)
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions import InvokeWithLayerRequest
from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
SaveBigFilePartRequest)
from telethon.tl.alltlobjects import LAYER
from telethon.network import MTProtoSender
from telethon import helpers, utils
from telethon.crypto import AuthKey
from telethon import utils, helpers
from telethon.network import MTProtoSender
from telethon.tl.alltlobjects import LAYER
from telethon.tl.functions import InvokeWithLayerRequest
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions.upload import (
GetFileRequest,
SaveBigFilePartRequest,
SaveFilePartRequest,
)
from telethon.tl.types import (
Document,
InputDocumentFileLocation,
InputFile,
InputFileBig,
InputFileLocation,
InputPeerPhotoFileLocation,
InputPhotoFileLocation,
TypeInputFile,
)
from mautrix.appservice import IntentAPI
from mautrix.types import ContentURI, EncryptedFile
from mautrix.util.logging import TraceLogger
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile
from ..tgclient import MautrixTelegramClient
try:
from mautrix.crypto.attachments import async_encrypt_attachment
@@ -49,8 +60,13 @@ except ImportError:
log: TraceLogger = cast(TraceLogger, logging.getLogger("mau.util"))
TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation,
InputFileLocation, InputPhotoFileLocation]
TypeLocation = Union[
Document,
InputDocumentFileLocation,
InputPeerPhotoFileLocation,
InputFileLocation,
InputPhotoFileLocation,
]
class DownloadSender:
@@ -59,14 +75,21 @@ class DownloadSender:
remaining: int
stride: int
def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
stride: int, count: int) -> None:
def __init__(
self,
sender: MTProtoSender,
file: TypeLocation,
offset: int,
limit: int,
stride: int,
count: int,
) -> None:
self.sender = sender
self.request = GetFileRequest(file, offset=offset, limit=limit)
self.stride = stride
self.remaining = count
async def next(self) -> Optional[bytes]:
async def next(self) -> bytes | None:
if not self.remaining:
return None
result = await self.sender.send(self.request)
@@ -80,14 +103,22 @@ class DownloadSender:
class UploadSender:
sender: MTProtoSender
request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
request: SaveFilePartRequest < SaveBigFilePartRequest
part_count: int
stride: int
previous: Optional[asyncio.Task]
previous: asyncio.Task | None
loop: asyncio.AbstractEventLoop
def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int,
stride: int, loop: asyncio.AbstractEventLoop) -> None:
def __init__(
self,
sender: MTProtoSender,
file_id: int,
part_count: int,
big: bool,
index: int,
stride: int,
loop: asyncio.AbstractEventLoop,
) -> None:
self.sender = sender
self.part_count = part_count
if big:
@@ -105,8 +136,10 @@ class UploadSender:
async def _next(self, data: bytes) -> None:
self.request.bytes = data
log.trace(f"Sending file part {self.request.file_part}/{self.part_count}"
f" with {len(data)} bytes")
log.trace(
f"Sending file part {self.request.file_part}/{self.part_count}"
f" with {len(data)} bytes"
)
await self.sender.send(self.request)
self.request.file_part += self.stride
@@ -120,16 +153,17 @@ class ParallelTransferrer:
client: MautrixTelegramClient
loop: asyncio.AbstractEventLoop
dc_id: int
senders: Optional[List[Union[DownloadSender, UploadSender]]]
senders: list[DownloadSender | UploadSender] | None
auth_key: AuthKey
upload_ticker: int
def __init__(self, client: MautrixTelegramClient, dc_id: Optional[int] = None) -> None:
def __init__(self, client: MautrixTelegramClient, dc_id: int | None = None) -> None:
self.client = client
self.loop = self.client.loop
self.dc_id = dc_id or self.client.session.dc_id
self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id
else self.client.session.auth_key)
self.auth_key = (
None if dc_id and self.client.session.dc_id != dc_id else self.client.session.auth_key
)
self.senders = None
self.upload_ticker = 0
@@ -138,14 +172,16 @@ class ParallelTransferrer:
self.senders = None
@staticmethod
def _get_connection_count(file_size: int, max_count: int = 20,
full_size: int = 100 * 1024 * 1024) -> int:
def _get_connection_count(
file_size: int, max_count: int = 20, full_size: int = 100 * 1024 * 1024
) -> int:
if file_size > full_size:
return max_count
return math.ceil((file_size / full_size) * max_count)
async def _init_download(self, connections: int, file: TypeLocation, part_count: int,
part_size: int) -> None:
async def _init_download(
self, connections: int, file: TypeLocation, part_count: int, part_size: int
) -> None:
minimum, remainder = divmod(part_count, connections)
def get_part_count() -> int:
@@ -158,52 +194,72 @@ class ParallelTransferrer:
# The first cross-DC sender will export+import the authorization, so we always create it
# before creating any other senders.
self.senders = [
await self._create_download_sender(file, 0, part_size, connections * part_size,
get_part_count()),
await self._create_download_sender(
file, 0, part_size, connections * part_size, get_part_count()
),
*await asyncio.gather(
*(self._create_download_sender(file, i, part_size, connections * part_size,
get_part_count())
for i in range(1, connections)))
*(
self._create_download_sender(
file, i, part_size, connections * part_size, get_part_count()
)
for i in range(1, connections)
)
),
]
async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int,
stride: int,
part_count: int) -> DownloadSender:
return DownloadSender(await self._create_sender(), file, index * part_size, part_size,
stride, part_count)
async def _create_download_sender(
self, file: TypeLocation, index: int, part_size: int, stride: int, part_count: int
) -> DownloadSender:
return DownloadSender(
await self._create_sender(), file, index * part_size, part_size, stride, part_count
)
async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool
) -> None:
async def _init_upload(
self, connections: int, file_id: int, part_count: int, big: bool
) -> None:
self.senders = [
await self._create_upload_sender(file_id, part_count, big, 0, connections),
*await asyncio.gather(
*(self._create_upload_sender(file_id, part_count, big, i, connections)
for i in range(1, connections)))
*(
self._create_upload_sender(file_id, part_count, big, i, connections)
for i in range(1, connections)
)
),
]
async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int,
stride: int) -> UploadSender:
return UploadSender(await self._create_sender(), file_id, part_count, big, index, stride,
loop=self.loop)
async def _create_upload_sender(
self, file_id: int, part_count: int, big: bool, index: int, stride: int
) -> UploadSender:
return UploadSender(
await self._create_sender(), file_id, part_count, big, index, stride, loop=self.loop
)
async def _create_sender(self) -> MTProtoSender:
dc = await self.client._get_dc(self.dc_id)
sender = MTProtoSender(self.auth_key, loggers=self.client._log)
await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
loggers=self.client._log,
proxy=self.client._proxy))
await sender.connect(
self.client._connection(
dc.ip_address, dc.port, dc.id, loggers=self.client._log, proxy=self.client._proxy
)
)
if not self.auth_key:
log.debug(f"Exporting auth to DC {self.dc_id}")
auth = await self.client(ExportAuthorizationRequest(self.dc_id))
self.client._init_request.query = ImportAuthorizationRequest(id=auth.id,
bytes=auth.bytes)
self.client._init_request.query = ImportAuthorizationRequest(
id=auth.id, bytes=auth.bytes
)
req = InvokeWithLayerRequest(LAYER, self.client._init_request)
await sender.send(req)
self.auth_key = sender.auth_key
return sender
async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None,
connection_count: Optional[int] = None) -> Tuple[int, int, bool]:
async def init_upload(
self,
file_id: int,
file_size: int,
part_size_kb: float | None = None,
connection_count: int | None = None,
) -> tuple[int, int, bool]:
connection_count = connection_count or self._get_connection_count(file_size)
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
part_count = (file_size + part_size - 1) // part_size
@@ -218,14 +274,19 @@ class ParallelTransferrer:
async def finish_upload(self) -> None:
await self._cleanup()
async def download(self, file: TypeLocation, file_size: int,
part_size_kb: Optional[float] = None,
connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]:
async def download(
self,
file: TypeLocation,
file_size: int,
part_size_kb: float | None = None,
connection_count: int | None = None,
) -> AsyncGenerator[bytes, None]:
connection_count = connection_count or self._get_connection_count(file_size)
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
part_count = math.ceil(file_size / part_size)
log.debug("Starting parallel download: "
f"{connection_count} {part_size} {part_count} {file!s}")
log.debug(
f"Starting parallel download: {connection_count} {part_size} {part_count} {file!s}"
)
await self._init_download(connection_count, file, part_count, part_size)
part = 0
@@ -245,12 +306,18 @@ class ParallelTransferrer:
await self._cleanup()
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
parallel_transfer_locks: defaultdict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
loc_id: str, location: TypeLocation, filename: str,
encrypt: bool, parallel_id: int) -> DBTelegramFile:
async def parallel_transfer_to_matrix(
client: MautrixTelegramClient,
intent: IntentAPI,
loc_id: str,
location: TypeLocation,
filename: str,
encrypt: bool,
parallel_id: int,
) -> DBTelegramFile:
size = location.size
mime_type = location.mime_type
dc_id, location = utils.get_input_location(location)
@@ -261,6 +328,7 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int
decryption_info = None
up_mime_type = mime_type
if encrypt and async_encrypt_attachment:
async def encrypted(stream):
nonlocal decryption_info
async for chunk in async_encrypt_attachment(stream):
@@ -271,17 +339,27 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int
data = encrypted(data)
up_mime_type = "application/octet-stream"
content_uri = await intent.upload_media(data, mime_type=up_mime_type, filename=filename,
size=size if not encrypt else None)
content_uri = await intent.upload_media(
data, mime_type=up_mime_type, filename=filename, size=size if not encrypt else None
)
if decryption_info:
decryption_info.url = content_uri
return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
was_converted=False, timestamp=int(time.time()), size=size,
width=None, height=None, decryption_info=decryption_info)
return DBTelegramFile(
id=loc_id,
mxc=content_uri,
mime_type=mime_type,
was_converted=False,
timestamp=int(time.time()),
size=size,
width=None,
height=None,
decryption_info=decryption_info,
)
async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response: ClientResponse
) -> Tuple[TypeInputFile, int]:
async def _internal_transfer_to_telegram(
client: MautrixTelegramClient, response: ClientResponse
) -> tuple[TypeInputFile, int]:
file_id = helpers.generate_random_long()
file_size = response.content_length
@@ -313,9 +391,9 @@ async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response
return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size
async def parallel_transfer_to_telegram(client: MautrixTelegramClient, intent: IntentAPI,
uri: ContentURI, parallel_id: int
) -> Tuple[TypeInputFile, int]:
async def parallel_transfer_to_telegram(
client: MautrixTelegramClient, intent: IntentAPI, uri: ContentURI, parallel_id: int
) -> tuple[TypeInputFile, int]:
url = intent.api.get_download_url(uri)
async with parallel_transfer_locks[parallel_id]:
async with intent.api.session.get(url) as response: