Blacken and isort code

This commit is contained in:
Tulir Asokan
2021-12-21 01:36:24 +02:00
parent f2af17d359
commit 6d25e9687e
55 changed files with 3752 additions and 2018 deletions
+5 -5
View File
@@ -1,7 +1,7 @@
from .file_transfer import transfer_file_to_matrix, convert_image
from .parallel_file_transfer import parallel_transfer_to_telegram
from .recursive_dict import recursive_del, recursive_set, recursive_get
from .color_log import ColorFormatter
from .send_lock import PortalSendLock
from .deduplication import PortalDedup
from .media_fallback import make_dice_event_content, make_contact_event_content
from .file_transfer import convert_image, transfer_file_to_matrix
from .media_fallback import make_contact_event_content, make_dice_event_content
from .parallel_file_transfer import parallel_transfer_to_telegram
from .recursive_dict import recursive_del, recursive_get, recursive_set
from .send_lock import PortalSendLock
+12 -6
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -13,8 +13,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.logging.color import (ColorFormatter as BaseColorFormatter,
PREFIX, MXID_COLOR, RESET)
from mautrix.util.logging.color import (
MXID_COLOR,
PREFIX,
RESET,
ColorFormatter as BaseColorFormatter,
)
TELETHON_COLOR = PREFIX + "35;1m" # magenta
TELETHON_MODULE_COLOR = PREFIX + "35m"
@@ -24,7 +28,9 @@ class ColorFormatter(BaseColorFormatter):
def _color_name(self, module: str) -> str:
if module.startswith("telethon"):
prefix, user_id, module = module.split(".", 2)
return (f"{TELETHON_COLOR}{prefix}{RESET}."
f"{MXID_COLOR}{user_id}{RESET}."
f"{TELETHON_MODULE_COLOR}{module}{RESET}")
return (
f"{TELETHON_COLOR}{prefix}{RESET}."
f"{MXID_COLOR}{user_id}{RESET}."
f"{TELETHON_MODULE_COLOR}{module}{RESET}"
)
return super()._color_name(module)
+35 -27
View File
@@ -13,22 +13,29 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Deque, Dict, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Tuple
from collections import deque
import hashlib
from telethon.tl.patched import Message, MessageService
from telethon.tl.types import (MessageMediaContact, MessageMediaDocument, MessageMediaGeo,
MessageMediaPhoto, TypeMessage, TypeUpdates, UpdateNewMessage,
UpdateNewChannelMessage)
from telethon.tl.types import (
MessageMediaContact,
MessageMediaDocument,
MessageMediaGeo,
MessageMediaPhoto,
TypeMessage,
TypeUpdates,
UpdateNewChannelMessage,
UpdateNewMessage,
)
from mautrix.types import EventID
from .. import portal as po
from ..types import TelegramID
if TYPE_CHECKING:
from ..portal import Portal
DedupMXID = Tuple[EventID, TelegramID]
@@ -36,12 +43,12 @@ class PortalDedup:
pre_db_check: bool = False
cache_queue_length: int = 20
_dedup: Deque[str]
_dedup_mxid: Dict[str, DedupMXID]
_dedup_action: Deque[str]
_portal: 'Portal'
_dedup: deque[str]
_dedup_mxid: dict[str, DedupMXID]
_dedup_action: deque[str]
_portal: po.Portal
def __init__(self, portal: 'Portal') -> None:
def __init__(self, portal: po.Portal) -> None:
self._dedup = deque()
self._dedup_mxid = {}
self._dedup_action = deque()
@@ -49,7 +56,7 @@ class PortalDedup:
@property
def _always_force_hash(self) -> bool:
return self._portal.peer_type == 'chat'
return self._portal.peer_type == "chat"
@staticmethod
def _hash_event(event: TypeMessage) -> str:
@@ -73,10 +80,7 @@ class PortalDedup:
}[type(event.media)](event.media)
except KeyError:
pass
return hashlib.md5("-"
.join(str(a) for a in hash_content)
.encode("utf-8")
).hexdigest()
return hashlib.md5("-".join(str(a) for a in hash_content).encode("utf-8")).hexdigest()
def check_action(self, event: TypeMessage) -> bool:
evt_hash = self._hash_event(event) if self._always_force_hash else event.id
@@ -89,9 +93,13 @@ class PortalDedup:
self._dedup_action.popleft()
return False
def update(self, event: TypeMessage, mxid: DedupMXID = None,
expected_mxid: Optional[DedupMXID] = None, force_hash: bool = False
) -> Optional[DedupMXID]:
def update(
self,
event: TypeMessage,
mxid: DedupMXID = None,
expected_mxid: DedupMXID | None = None,
force_hash: bool = False,
) -> DedupMXID | None:
evt_hash = self._hash_event(event) if self._always_force_hash or force_hash else event.id
try:
found_mxid = self._dedup_mxid[evt_hash]
@@ -103,11 +111,10 @@ class PortalDedup:
self._dedup_mxid[evt_hash] = mxid
return None
def check(self, event: TypeMessage, mxid: DedupMXID = None, force_hash: bool = False
) -> Optional[DedupMXID]:
evt_hash = (self._hash_event(event)
if self._always_force_hash or force_hash
else event.id)
def check(
self, event: TypeMessage, mxid: DedupMXID = None, force_hash: bool = False
) -> DedupMXID | None:
evt_hash = self._hash_event(event) if self._always_force_hash or force_hash else event.id
if evt_hash in self._dedup:
return self._dedup_mxid[evt_hash]
@@ -120,7 +127,8 @@ class PortalDedup:
def register_outgoing_actions(self, response: TypeUpdates) -> None:
for update in response.updates:
check_dedup = (isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage))
and isinstance(update.message, MessageService))
check_dedup = isinstance(
update, (UpdateNewMessage, UpdateNewChannelMessage)
) and isinstance(update.message, MessageService)
if check_dedup:
self.check(update.message)
+149 -68
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -13,27 +13,40 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Tuple, Union, Dict
from __future__ import annotations
from typing import Optional, Union
from io import BytesIO
import time
import logging
import asyncio
import tempfile
import magic
from asyncpg import UniqueViolationError
from sqlite3 import IntegrityError
import asyncio
import logging
import tempfile
import time
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
TypePhotoSize, PhotoSize, PhotoCachedSize, InputPhotoFileLocation,
InputPeerPhotoFileLocation)
from telethon.errors import (AuthBytesInvalidError, AuthKeyInvalidError, LocationInvalidError,
SecurityError, FileIdInvalidError)
from asyncpg import UniqueViolationError
from telethon.errors import (
AuthBytesInvalidError,
AuthKeyInvalidError,
FileIdInvalidError,
LocationInvalidError,
SecurityError,
)
from telethon.tl.types import (
Document,
InputDocumentFileLocation,
InputFileLocation,
InputPeerPhotoFileLocation,
InputPhotoFileLocation,
PhotoCachedSize,
PhotoSize,
TypePhotoSize,
)
import magic
from mautrix.appservice import IntentAPI
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile
from ..tgclient import MautrixTelegramClient
from ..util import sane_mimetypes
from .parallel_file_transfer import parallel_transfer_to_matrix
from .tgs_converter import convert_tgs_to
@@ -55,13 +68,21 @@ except ImportError:
log: logging.Logger = logging.getLogger("mau.util")
TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation,
InputFileLocation, InputPhotoFileLocation]
TypeLocation = Union[
Document,
InputDocumentFileLocation,
InputPeerPhotoFileLocation,
InputFileLocation,
InputPhotoFileLocation,
]
def convert_image(file: bytes, source_mime: str = "image/webp", target_type: str = "png",
thumbnail_to: Optional[Tuple[int, int]] = None
) -> Tuple[str, bytes, Optional[int], Optional[int]]:
def convert_image(
file: bytes,
source_mime: str = "image/webp",
target_type: str = "png",
thumbnail_to: tuple[int, int] | None = None,
) -> tuple[str, bytes, int | None, int | None]:
if not Image:
return source_mime, file, None, None
try:
@@ -77,8 +98,12 @@ def convert_image(file: bytes, source_mime: str = "image/webp", target_type: str
return source_mime, file, None, None
def _read_video_thumbnail(data: bytes, video_ext: str = "mp4", frame_ext: str = "png",
max_size: Tuple[int, int] = (1024, 720)) -> Tuple[bytes, int, int]:
def _read_video_thumbnail(
data: bytes,
video_ext: str = "mp4",
frame_ext: str = "png",
max_size: tuple[int, int] = (1024, 720),
) -> tuple[bytes, int, int]:
with tempfile.NamedTemporaryFile(prefix="mxtg_video_", suffix=f".{video_ext}") as file:
# We don't have any way to read the video from memory, so save it to disk.
file.write(data)
@@ -109,11 +134,17 @@ def _location_to_id(location: TypeLocation) -> str:
return str(location.photo_id)
async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
thumbnail_loc: TypeLocation, mime_type: str, encrypt: bool,
video: Optional[bytes], custom_data: Optional[bytes] = None,
width: Optional[int] = None, height: [int] = None
) -> Optional[DBTelegramFile]:
async def transfer_thumbnail_to_matrix(
client: MautrixTelegramClient,
intent: IntentAPI,
thumbnail_loc: TypeLocation,
mime_type: str,
encrypt: bool,
video: bytes | None,
custom_data: bytes | None = None,
width: int | None = None,
height: int | None = None,
) -> DBTelegramFile | None:
if not Image or not VideoFileClip:
return None
@@ -151,28 +182,45 @@ async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: In
if decryption_info:
decryption_info.url = content_uri
db_file = DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
was_converted=False, timestamp=int(time.time()), size=len(file),
width=width, height=height, decryption_info=decryption_info)
db_file = DBTelegramFile(
id=loc_id,
mxc=content_uri,
mime_type=mime_type,
was_converted=False,
timestamp=int(time.time()),
size=len(file),
width=width,
height=height,
decryption_info=decryption_info,
)
try:
await db_file.insert()
except (UniqueViolationError, IntegrityError) as e:
log.exception(f"{e.__class__.__name__} while saving transferred file thumbnail data. "
"This was probably caused by two simultaneous transfers of the same file, "
"and might (but probably won't) cause problems with thumbnails or something.")
log.exception(
f"{e.__class__.__name__} while saving transferred file thumbnail data. "
"This was probably caused by two simultaneous transfers of the same file, "
"and might (but probably won't) cause problems with thumbnails or something."
)
return db_file
transfer_locks: Dict[str, asyncio.Lock] = {}
transfer_locks: dict[str, asyncio.Lock] = {}
TypeThumbnail = Optional[Union[TypeLocation, TypePhotoSize]]
async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
location: TypeLocation, thumbnail: TypeThumbnail = None, *,
is_sticker: bool = False, tgs_convert: Optional[dict] = None,
filename: Optional[str] = None, encrypt: bool = False,
parallel_id: Optional[int] = None) -> Optional[DBTelegramFile]:
async def transfer_file_to_matrix(
client: MautrixTelegramClient,
intent: IntentAPI,
location: TypeLocation,
thumbnail: TypeThumbnail = None,
*,
is_sticker: bool = False,
tgs_convert: dict | None = None,
filename: str | None = None,
encrypt: bool = False,
parallel_id: int | None = None,
) -> DBTelegramFile | None:
location_id = _location_to_id(location)
if not location_id:
return None
@@ -187,17 +235,32 @@ async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentA
lock = asyncio.Lock()
transfer_locks[location_id] = lock
async with lock:
return await _unlocked_transfer_file_to_matrix(client, intent, location_id, location,
thumbnail, is_sticker, tgs_convert,
filename, encrypt, parallel_id)
return await _unlocked_transfer_file_to_matrix(
client,
intent,
location_id,
location,
thumbnail,
is_sticker,
tgs_convert,
filename,
encrypt,
parallel_id,
)
async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
loc_id: str, location: TypeLocation,
thumbnail: TypeThumbnail, is_sticker: bool,
tgs_convert: Optional[dict], filename: Optional[str],
encrypt: bool, parallel_id: Optional[int]
) -> Optional[DBTelegramFile]:
async def _unlocked_transfer_file_to_matrix(
client: MautrixTelegramClient,
intent: IntentAPI,
loc_id: str,
location: TypeLocation,
thumbnail: TypeThumbnail,
is_sticker: bool,
tgs_convert: dict | None,
filename: str | None,
encrypt: bool,
parallel_id: int | None,
) -> DBTelegramFile | None:
db_file = await DBTelegramFile.get(loc_id)
if db_file:
return db_file
@@ -205,8 +268,9 @@ async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, inten
converted_anim = None
if parallel_id and isinstance(location, Document) and (not is_sticker or not tgs_convert):
db_file = await parallel_transfer_to_matrix(client, intent, loc_id, location, filename,
encrypt, parallel_id)
db_file = await parallel_transfer_to_matrix(
client, intent, loc_id, location, filename, encrypt, parallel_id
)
mime_type = location.mime_type
file = None
else:
@@ -223,12 +287,13 @@ async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, inten
image_converted = False
# A weird bug in alpine/magic makes it return application/octet-stream for gzips...
is_tgs = (mime_type == "application/gzip"
or (mime_type == "application/octet-stream"
and magic.from_buffer(file).startswith("gzip")))
is_tgs = mime_type == "application/gzip" or (
mime_type == "application/octet-stream" and magic.from_buffer(file).startswith("gzip")
)
if is_sticker and tgs_convert and is_tgs:
converted_anim = await convert_tgs_to(file, tgs_convert["target"],
**tgs_convert["args"])
converted_anim = await convert_tgs_to(
file, tgs_convert["target"], **tgs_convert["args"]
)
mime_type = converted_anim.mime
file = converted_anim.data
width, height = converted_anim.width, converted_anim.height
@@ -244,29 +309,45 @@ async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, inten
if decryption_info:
decryption_info.url = content_uri
db_file = DBTelegramFile(id=loc_id, mxc=content_uri, decryption_info=decryption_info,
mime_type=mime_type, was_converted=image_converted,
timestamp=int(time.time()), size=len(file),
width=width, height=height)
db_file = DBTelegramFile(
id=loc_id,
mxc=content_uri,
decryption_info=decryption_info,
mime_type=mime_type,
was_converted=image_converted,
timestamp=int(time.time()),
size=len(file),
width=width,
height=height,
)
if thumbnail and (mime_type.startswith("video/") or mime_type == "image/gif"):
if isinstance(thumbnail, (PhotoSize, PhotoCachedSize)):
thumbnail = thumbnail.location
try:
db_file.thumbnail = await transfer_thumbnail_to_matrix(client, intent, thumbnail,
video=file, mime_type=mime_type,
encrypt=encrypt)
db_file.thumbnail = await transfer_thumbnail_to_matrix(
client, intent, thumbnail, video=file, mime_type=mime_type, encrypt=encrypt
)
except FileIdInvalidError:
log.warning(f"Failed to transfer thumbnail for {thumbnail!s}", exc_info=True)
elif converted_anim and converted_anim.thumbnail_data:
db_file.thumbnail = await transfer_thumbnail_to_matrix(
client, intent, location, video=None, encrypt=encrypt,
custom_data=converted_anim.thumbnail_data, mime_type=converted_anim.thumbnail_mime,
width=converted_anim.width, height=converted_anim.height)
client,
intent,
location,
video=None,
encrypt=encrypt,
custom_data=converted_anim.thumbnail_data,
mime_type=converted_anim.thumbnail_mime,
width=converted_anim.width,
height=converted_anim.height,
)
try:
await db_file.insert()
except (UniqueViolationError, IntegrityError) as e:
log.exception(f"{e.__class__.__name__} while saving transferred file data. "
"This was probably caused by two simultaneous transfers of the same file, "
"and should not cause any problems.")
log.exception(
f"{e.__class__.__name__} while saving transferred file data. "
"This was probably caused by two simultaneous transfers of the same file, "
"and should not cause any problems."
)
return db_file
+8 -7
View File
@@ -17,11 +17,11 @@ from __future__ import annotations
import html
from telethon.tl.types import MessageMediaDice, MessageMediaContact, PeerUser
from telethon.tl.types import MessageMediaContact, MessageMediaDice, PeerUser
from mautrix.types import TextMessageEventContent, MessageType, Format
from mautrix.types import Format, MessageType, TextMessageEventContent
from .. import puppet as pu, abstract_user as au
from .. import abstract_user as au, puppet as pu
from ..types import TelegramID
try:
@@ -36,7 +36,7 @@ def _format_dice(roll: MessageMediaDice) -> str:
0: "\U0001F36B", # "🍫",
1: "\U0001F352", # "🍒",
2: "\U0001F34B", # "🍋",
3: "7\ufe0f\u20e3" # "7️⃣",
3: "7\ufe0f\u20e3", # "7️⃣",
}
res = roll.value - 1
slot1, slot2, slot3 = emojis[res % 4], emojis[res // 4 % 4], emojis[res // 16]
@@ -82,11 +82,12 @@ def make_dice_event_content(roll: MessageMediaDice) -> TextMessageEventContent:
"\U0001F3C0": " Basketball throw",
"\U0001F3B0": " Slot machine",
"\U0001F3B3": " Bowling",
"\u26BD": " Football kick"
"\u26BD": " Football kick",
}
text = f"{roll.emoticon}{emoji_text.get(roll.emoticon, '')} result: {_format_dice(roll)}"
content = TextMessageEventContent(msgtype=MessageType.TEXT, format=Format.HTML, body=text,
formatted_body=f"<h4>{text}</h4>")
content = TextMessageEventContent(
msgtype=MessageType.TEXT, format=Format.HTML, body=text, formatted_body=f"<h4>{text}</h4>"
)
content["net.maunium.telegram.dice"] = {"emoticon": roll.emoticon, "value": roll.value}
return content
+156 -78
View File
@@ -13,34 +13,45 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple, cast
from __future__ import annotations
from typing import AsyncGenerator, Awaitable, Union, cast
from collections import defaultdict
import hashlib
import asyncio
import hashlib
import logging
import time
import math
import time
from aiohttp import ClientResponse
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
InputFileBig, InputFile)
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions import InvokeWithLayerRequest
from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
SaveBigFilePartRequest)
from telethon.tl.alltlobjects import LAYER
from telethon.network import MTProtoSender
from telethon import helpers, utils
from telethon.crypto import AuthKey
from telethon import utils, helpers
from telethon.network import MTProtoSender
from telethon.tl.alltlobjects import LAYER
from telethon.tl.functions import InvokeWithLayerRequest
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions.upload import (
GetFileRequest,
SaveBigFilePartRequest,
SaveFilePartRequest,
)
from telethon.tl.types import (
Document,
InputDocumentFileLocation,
InputFile,
InputFileBig,
InputFileLocation,
InputPeerPhotoFileLocation,
InputPhotoFileLocation,
TypeInputFile,
)
from mautrix.appservice import IntentAPI
from mautrix.types import ContentURI, EncryptedFile
from mautrix.util.logging import TraceLogger
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile
from ..tgclient import MautrixTelegramClient
try:
from mautrix.crypto.attachments import async_encrypt_attachment
@@ -49,8 +60,13 @@ except ImportError:
log: TraceLogger = cast(TraceLogger, logging.getLogger("mau.util"))
TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation,
InputFileLocation, InputPhotoFileLocation]
TypeLocation = Union[
Document,
InputDocumentFileLocation,
InputPeerPhotoFileLocation,
InputFileLocation,
InputPhotoFileLocation,
]
class DownloadSender:
@@ -59,14 +75,21 @@ class DownloadSender:
remaining: int
stride: int
def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
stride: int, count: int) -> None:
def __init__(
self,
sender: MTProtoSender,
file: TypeLocation,
offset: int,
limit: int,
stride: int,
count: int,
) -> None:
self.sender = sender
self.request = GetFileRequest(file, offset=offset, limit=limit)
self.stride = stride
self.remaining = count
async def next(self) -> Optional[bytes]:
async def next(self) -> bytes | None:
if not self.remaining:
return None
result = await self.sender.send(self.request)
@@ -80,14 +103,22 @@ class DownloadSender:
class UploadSender:
sender: MTProtoSender
request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
request: SaveFilePartRequest < SaveBigFilePartRequest
part_count: int
stride: int
previous: Optional[asyncio.Task]
previous: asyncio.Task | None
loop: asyncio.AbstractEventLoop
def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int,
stride: int, loop: asyncio.AbstractEventLoop) -> None:
def __init__(
self,
sender: MTProtoSender,
file_id: int,
part_count: int,
big: bool,
index: int,
stride: int,
loop: asyncio.AbstractEventLoop,
) -> None:
self.sender = sender
self.part_count = part_count
if big:
@@ -105,8 +136,10 @@ class UploadSender:
async def _next(self, data: bytes) -> None:
self.request.bytes = data
log.trace(f"Sending file part {self.request.file_part}/{self.part_count}"
f" with {len(data)} bytes")
log.trace(
f"Sending file part {self.request.file_part}/{self.part_count}"
f" with {len(data)} bytes"
)
await self.sender.send(self.request)
self.request.file_part += self.stride
@@ -120,16 +153,17 @@ class ParallelTransferrer:
client: MautrixTelegramClient
loop: asyncio.AbstractEventLoop
dc_id: int
senders: Optional[List[Union[DownloadSender, UploadSender]]]
senders: list[DownloadSender | UploadSender] | None
auth_key: AuthKey
upload_ticker: int
def __init__(self, client: MautrixTelegramClient, dc_id: Optional[int] = None) -> None:
def __init__(self, client: MautrixTelegramClient, dc_id: int | None = None) -> None:
self.client = client
self.loop = self.client.loop
self.dc_id = dc_id or self.client.session.dc_id
self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id
else self.client.session.auth_key)
self.auth_key = (
None if dc_id and self.client.session.dc_id != dc_id else self.client.session.auth_key
)
self.senders = None
self.upload_ticker = 0
@@ -138,14 +172,16 @@ class ParallelTransferrer:
self.senders = None
@staticmethod
def _get_connection_count(file_size: int, max_count: int = 20,
full_size: int = 100 * 1024 * 1024) -> int:
def _get_connection_count(
file_size: int, max_count: int = 20, full_size: int = 100 * 1024 * 1024
) -> int:
if file_size > full_size:
return max_count
return math.ceil((file_size / full_size) * max_count)
async def _init_download(self, connections: int, file: TypeLocation, part_count: int,
part_size: int) -> None:
async def _init_download(
self, connections: int, file: TypeLocation, part_count: int, part_size: int
) -> None:
minimum, remainder = divmod(part_count, connections)
def get_part_count() -> int:
@@ -158,52 +194,72 @@ class ParallelTransferrer:
# The first cross-DC sender will export+import the authorization, so we always create it
# before creating any other senders.
self.senders = [
await self._create_download_sender(file, 0, part_size, connections * part_size,
get_part_count()),
await self._create_download_sender(
file, 0, part_size, connections * part_size, get_part_count()
),
*await asyncio.gather(
*(self._create_download_sender(file, i, part_size, connections * part_size,
get_part_count())
for i in range(1, connections)))
*(
self._create_download_sender(
file, i, part_size, connections * part_size, get_part_count()
)
for i in range(1, connections)
)
),
]
async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int,
stride: int,
part_count: int) -> DownloadSender:
return DownloadSender(await self._create_sender(), file, index * part_size, part_size,
stride, part_count)
async def _create_download_sender(
self, file: TypeLocation, index: int, part_size: int, stride: int, part_count: int
) -> DownloadSender:
return DownloadSender(
await self._create_sender(), file, index * part_size, part_size, stride, part_count
)
async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool
) -> None:
async def _init_upload(
self, connections: int, file_id: int, part_count: int, big: bool
) -> None:
self.senders = [
await self._create_upload_sender(file_id, part_count, big, 0, connections),
*await asyncio.gather(
*(self._create_upload_sender(file_id, part_count, big, i, connections)
for i in range(1, connections)))
*(
self._create_upload_sender(file_id, part_count, big, i, connections)
for i in range(1, connections)
)
),
]
async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int,
stride: int) -> UploadSender:
return UploadSender(await self._create_sender(), file_id, part_count, big, index, stride,
loop=self.loop)
async def _create_upload_sender(
self, file_id: int, part_count: int, big: bool, index: int, stride: int
) -> UploadSender:
return UploadSender(
await self._create_sender(), file_id, part_count, big, index, stride, loop=self.loop
)
async def _create_sender(self) -> MTProtoSender:
dc = await self.client._get_dc(self.dc_id)
sender = MTProtoSender(self.auth_key, loggers=self.client._log)
await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
loggers=self.client._log,
proxy=self.client._proxy))
await sender.connect(
self.client._connection(
dc.ip_address, dc.port, dc.id, loggers=self.client._log, proxy=self.client._proxy
)
)
if not self.auth_key:
log.debug(f"Exporting auth to DC {self.dc_id}")
auth = await self.client(ExportAuthorizationRequest(self.dc_id))
self.client._init_request.query = ImportAuthorizationRequest(id=auth.id,
bytes=auth.bytes)
self.client._init_request.query = ImportAuthorizationRequest(
id=auth.id, bytes=auth.bytes
)
req = InvokeWithLayerRequest(LAYER, self.client._init_request)
await sender.send(req)
self.auth_key = sender.auth_key
return sender
async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None,
connection_count: Optional[int] = None) -> Tuple[int, int, bool]:
async def init_upload(
self,
file_id: int,
file_size: int,
part_size_kb: float | None = None,
connection_count: int | None = None,
) -> tuple[int, int, bool]:
connection_count = connection_count or self._get_connection_count(file_size)
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
part_count = (file_size + part_size - 1) // part_size
@@ -218,14 +274,19 @@ class ParallelTransferrer:
async def finish_upload(self) -> None:
await self._cleanup()
async def download(self, file: TypeLocation, file_size: int,
part_size_kb: Optional[float] = None,
connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]:
async def download(
self,
file: TypeLocation,
file_size: int,
part_size_kb: float | None = None,
connection_count: int | None = None,
) -> AsyncGenerator[bytes, None]:
connection_count = connection_count or self._get_connection_count(file_size)
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
part_count = math.ceil(file_size / part_size)
log.debug("Starting parallel download: "
f"{connection_count} {part_size} {part_count} {file!s}")
log.debug(
f"Starting parallel download: {connection_count} {part_size} {part_count} {file!s}"
)
await self._init_download(connection_count, file, part_count, part_size)
part = 0
@@ -245,12 +306,18 @@ class ParallelTransferrer:
await self._cleanup()
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
parallel_transfer_locks: defaultdict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
loc_id: str, location: TypeLocation, filename: str,
encrypt: bool, parallel_id: int) -> DBTelegramFile:
async def parallel_transfer_to_matrix(
client: MautrixTelegramClient,
intent: IntentAPI,
loc_id: str,
location: TypeLocation,
filename: str,
encrypt: bool,
parallel_id: int,
) -> DBTelegramFile:
size = location.size
mime_type = location.mime_type
dc_id, location = utils.get_input_location(location)
@@ -261,6 +328,7 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int
decryption_info = None
up_mime_type = mime_type
if encrypt and async_encrypt_attachment:
async def encrypted(stream):
nonlocal decryption_info
async for chunk in async_encrypt_attachment(stream):
@@ -271,17 +339,27 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int
data = encrypted(data)
up_mime_type = "application/octet-stream"
content_uri = await intent.upload_media(data, mime_type=up_mime_type, filename=filename,
size=size if not encrypt else None)
content_uri = await intent.upload_media(
data, mime_type=up_mime_type, filename=filename, size=size if not encrypt else None
)
if decryption_info:
decryption_info.url = content_uri
return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
was_converted=False, timestamp=int(time.time()), size=size,
width=None, height=None, decryption_info=decryption_info)
return DBTelegramFile(
id=loc_id,
mxc=content_uri,
mime_type=mime_type,
was_converted=False,
timestamp=int(time.time()),
size=size,
width=None,
height=None,
decryption_info=decryption_info,
)
async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response: ClientResponse
) -> Tuple[TypeInputFile, int]:
async def _internal_transfer_to_telegram(
client: MautrixTelegramClient, response: ClientResponse
) -> tuple[TypeInputFile, int]:
file_id = helpers.generate_random_long()
file_size = response.content_length
@@ -313,9 +391,9 @@ async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response
return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size
async def parallel_transfer_to_telegram(client: MautrixTelegramClient, intent: IntentAPI,
uri: ContentURI, parallel_id: int
) -> Tuple[TypeInputFile, int]:
async def parallel_transfer_to_telegram(
client: MautrixTelegramClient, intent: IntentAPI, uri: ContentURI, parallel_id: int
) -> tuple[TypeInputFile, int]:
url = intent.api.get_download_url(uri)
async with parallel_transfer_locks[parallel_id]:
async with intent.api.session.get(url) as response:
+7 -5
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -13,12 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Any
from __future__ import annotations
from typing import Any
from mautrix.util.config import RecursiveDict
def recursive_set(data: Dict[str, Any], key: str, value: Any) -> bool:
def recursive_set(data: dict[str, Any], key: str, value: Any) -> bool:
key, next_key = RecursiveDict.parse_key(key)
if next_key is not None:
if key not in data:
@@ -31,7 +33,7 @@ def recursive_set(data: Dict[str, Any], key: str, value: Any) -> bool:
return True
def recursive_get(data: Dict[str, Any], key: str) -> Any:
def recursive_get(data: dict[str, Any], key: str) -> Any:
key, next_key = RecursiveDict.parse_key(key)
if next_key is not None:
next_data = data.get(key, None)
@@ -41,7 +43,7 @@ def recursive_get(data: Dict[str, Any], key: str) -> Any:
return data.get(key, None)
def recursive_del(data: Dict[str, any], key: str) -> bool:
def recursive_del(data: dict[str, any], key: str) -> bool:
key, next_key = RecursiveDict.parse_key(key)
if next_key is not None:
if key not in data:
+4 -4
View File
@@ -13,7 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict
from __future__ import annotations
from asyncio import Lock
from ..types import TelegramID
@@ -28,7 +29,7 @@ class FakeLock:
class PortalSendLock:
_send_locks: Dict[int, Lock]
_send_locks: dict[int, Lock]
_noop_lock: Lock = FakeLock()
def __init__(self) -> None:
@@ -40,5 +41,4 @@ class PortalSendLock:
try:
return self._send_locks[user_id]
except KeyError:
return (self._send_locks.setdefault(user_id, Lock())
if required else self._noop_lock)
return self._send_locks.setdefault(user_id, Lock()) if required else self._noop_lock
+102 -43
View File
@@ -14,11 +14,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Callable, Awaitable, Optional, Tuple, Any
from __future__ import annotations
from typing import Any, Awaitable, Callable
import asyncio.subprocess
import logging
import shutil
import os.path
import shutil
import tempfile
from attr import dataclass
@@ -30,17 +32,17 @@ log: logging.Logger = logging.getLogger("mau.util.tgs")
class ConvertedSticker:
mime: str
data: bytes
thumbnail_mime: Optional[str] = None
thumbnail_data: Optional[bytes] = None
thumbnail_mime: str | None = None
thumbnail_data: bytes | None = None
width: int = 0
height: int = 0
Converter = Callable[[bytes, int, int, Any], Awaitable[ConvertedSticker]]
converters: Dict[str, Converter] = {}
converters: dict[str, Converter] = {}
def abswhich(program: Optional[str]) -> Optional[str]:
def abswhich(program: str | None) -> str | None:
path = shutil.which(program)
return os.path.abspath(path) if path else None
@@ -49,77 +51,134 @@ lottieconverter = abswhich("lottieconverter")
ffmpeg = abswhich("ffmpeg")
if lottieconverter:
async def tgs_to_png(file: bytes, width: int, height: int, **_: Any) -> ConvertedSticker:
frame = 1
proc = await asyncio.create_subprocess_exec(lottieconverter, "-", "-", "png",
f"{width}x{height}", str(frame),
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE)
proc = await asyncio.create_subprocess_exec(
lottieconverter,
"-",
"-",
"png",
f"{width}x{height}",
str(frame),
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate(file)
if proc.returncode == 0:
return ConvertedSticker("image/png", stdout)
else:
log.error("lottieconverter error: " + (stderr.decode("utf-8") if stderr is not None
else f"unknown ({proc.returncode})"))
log.error(
"lottieconverter error: "
+ (
stderr.decode("utf-8")
if stderr is not None
else f"unknown ({proc.returncode})"
)
)
return ConvertedSticker("application/gzip", file)
async def tgs_to_gif(file: bytes, width: int, height: int, fps: int = 25,
**_: Any) -> ConvertedSticker:
proc = await asyncio.create_subprocess_exec(lottieconverter, "-", "-", "gif",
f"{width}x{height}", str(fps),
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE)
async def tgs_to_gif(
file: bytes, width: int, height: int, fps: int = 25, **_: Any
) -> ConvertedSticker:
proc = await asyncio.create_subprocess_exec(
lottieconverter,
"-",
"-",
"gif",
f"{width}x{height}",
str(fps),
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate(file)
if proc.returncode == 0:
return ConvertedSticker("image/gif", stdout)
else:
log.error("lottieconverter error: " + (stderr.decode("utf-8") if stderr is not None
else f"unknown ({proc.returncode})"))
log.error(
"lottieconverter error: "
+ (
stderr.decode("utf-8")
if stderr is not None
else f"unknown ({proc.returncode})"
)
)
return ConvertedSticker("application/gzip", file)
converters["png"] = tgs_to_png
converters["gif"] = tgs_to_gif
if lottieconverter and ffmpeg:
async def tgs_to_webm(file: bytes, width: int, height: int, fps: int = 30,
**_: Any) -> ConvertedSticker:
async def tgs_to_webm(
file: bytes, width: int, height: int, fps: int = 30, **_: Any
) -> ConvertedSticker:
with tempfile.TemporaryDirectory(prefix="tgs_") as tmpdir:
file_template = tmpdir + "/out_"
proc = await asyncio.create_subprocess_exec(lottieconverter, "-", file_template,
"pngs", f"{width}x{height}", str(fps),
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE)
proc = await asyncio.create_subprocess_exec(
lottieconverter,
"-",
file_template,
"pngs",
f"{width}x{height}",
str(fps),
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
)
_, stderr = await proc.communicate(file)
if proc.returncode == 0:
with open(f"{file_template}00.png", "rb") as first_frame_file:
first_frame_data = first_frame_file.read()
proc = await asyncio.create_subprocess_exec(ffmpeg, "-hide_banner", "-loglevel",
"error", "-framerate", str(fps),
"-pattern_type", "glob", "-i",
file_template + "*.png",
"-c:v", "libvpx-vp9", "-pix_fmt",
"yuva420p", "-f", "webm", "-",
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE)
proc = await asyncio.create_subprocess_exec(
ffmpeg,
"-hide_banner",
"-loglevel",
"error",
"-framerate",
str(fps),
"-pattern_type",
"glob",
"-i",
file_template + "*.png",
"-c:v",
"libvpx-vp9",
"-pix_fmt",
"yuva420p",
"-f",
"webm",
"-",
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode == 0:
return ConvertedSticker("video/webm", stdout, "image/png", first_frame_data)
else:
log.error("ffmpeg error: " + (stderr.decode("utf-8") if stderr is not None
else f"unknown ({proc.returncode})"))
log.error(
"ffmpeg error: "
+ (
stderr.decode("utf-8")
if stderr is not None
else f"unknown ({proc.returncode})"
)
)
else:
log.error("lottieconverter error: " + (stderr.decode("utf-8") if stderr is not None
else f"unknown ({proc.returncode})"))
log.error(
"lottieconverter error: "
+ (
stderr.decode("utf-8")
if stderr is not None
else f"unknown ({proc.returncode})"
)
)
return ConvertedSticker("application/gzip", file)
converters["webm"] = tgs_to_webm
async def convert_tgs_to(file: bytes, convert_to: str, width: int, height: int, **kwargs: Any
) -> ConvertedSticker:
async def convert_tgs_to(
file: bytes, convert_to: str, width: int, height: int, **kwargs: Any
) -> ConvertedSticker:
if convert_to in converters:
converter = converters[convert_to]
converted = await converter(file, width, height, **kwargs)