diff --git a/example-config.yaml b/example-config.yaml
index 5f7e9753..61a444e1 100644
--- a/example-config.yaml
+++ b/example-config.yaml
@@ -163,6 +163,10 @@ bridge:
image_as_file_size: 10
# Maximum size of Telegram documents in megabytes to bridge.
max_document_size: 100
+ # Enable experimental parallel file transfer, which makes uploads/downloads much faster by
+ # streaming from/to Matrix and using many connections for Telegram.
+ # Note that generating HQ thumbnails for videos is not possible with streamed transfers.
+ parallel_file_transfer: false
# Whether or not created rooms should have federation enabled.
# If false, created portal rooms will never be federated.
federate_rooms: true
@@ -207,20 +211,19 @@ bridge:
# Text msgtypes (m.text, m.notice and m.emote) support HTML, media msgtypes don't.
#
# Available variables:
- # $sender_displayname - The display name of the sender (e.g. Example User)
- # $sender_username - The username (Matrix ID localpart) of the sender (e.g. exampleuser)
- # $sender_mxid - The Matrix ID of the sender (e.g. @exampleuser:example.com)
- # $body - The plaintext body (file name for media msgtypes)
- # $formatted_body - The message content as HTML (for text msgtypes)
+ # $sender_displayname - The display name of the sender (e.g. Example User)
+ # $sender_username - The username (Matrix ID localpart) of the sender (e.g. exampleuser)
+ # $sender_mxid - The Matrix ID of the sender (e.g. @exampleuser:example.com)
+ # $message - The message content
message_formats:
- m.text: "$sender_displayname: $formatted_body"
- m.notice: "$sender_displayname: $formatted_body"
- m.emote: "* $sender_displayname $formatted_body"
- m.file: "$sender_displayname sent a file: $body"
- m.image: "$sender_displayname sent an image: $body"
- m.audio: "$sender_displayname sent an audio file: $body"
- m.video: "$sender_displayname sent a video: $body"
- m.location: "$sender_displayname sent a location: $body"
+ m.text: "$sender_displayname: $message"
+ m.notice: "$sender_displayname: $message"
+ m.emote: "* $sender_displayname $message"
+ m.file: "$sender_displayname sent a file: $message"
+ m.image: "$sender_displayname sent an image: $message"
+ m.audio: "$sender_displayname sent an audio file: $message"
+ m.video: "$sender_displayname sent a video: $message"
+ m.location: "$sender_displayname sent a location: $message"
# Telegram doesn't have built-in emotes, this field specifies how m.emote's from authenticated
# users are sent to telegram. All fields in message_formats are supported. Additionally, the
# Telegram user info is available in the following variables:
diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py
index a40badc9..c4cfd7e7 100644
--- a/mautrix_telegram/abstract_user.py
+++ b/mautrix_telegram/abstract_user.py
@@ -295,7 +295,7 @@ class AbstractUser(ABC):
async def update_admin(self, update: UpdateChatParticipantAdmin) -> None:
# TODO duplication not checked
- portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
+ portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
if not portal or not portal.mxid:
return
@@ -305,7 +305,7 @@ class AbstractUser(ABC):
if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
else:
- portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
+ portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
if not portal or not portal.mxid:
return
@@ -350,7 +350,7 @@ class AbstractUser(ABC):
Optional[pu.Puppet],
Optional[po.Portal]]:
if isinstance(update, UpdateShortChatMessage):
- portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
+ portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
sender = pu.Puppet.get(TelegramID(update.from_id))
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
@@ -410,9 +410,10 @@ class AbstractUser(ABC):
if not config["bridge.relaybot.private_chat.invite"]:
self.log.debug(f"Ignoring private message to bot from {sender.id}")
return
- elif not portal.mxid:
+ elif not portal or not portal.mxid:
+ tgid_log = portal.tgid_log if portal else original_update.chat_id
self.log.debug(
- f"Ignoring message received by bot in unbridged chat {portal.tgid_log}")
+ f"Ignoring message received by bot in unbridged chat {tgid_log}")
return
if self.ignore_incoming_bot_events and self.relaybot and sender.id == self.relaybot.tgid:
diff --git a/mautrix_telegram/commands/clean_rooms.py b/mautrix_telegram/commands/clean_rooms.py
index 8182f746..93e3e649 100644
--- a/mautrix_telegram/commands/clean_rooms.py
+++ b/mautrix_telegram/commands/clean_rooms.py
@@ -155,7 +155,7 @@ async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
"next": lambda confirm: execute_room_cleanup(confirm, rooms_to_clean),
"action": "Room cleaning",
}
- await evt.reply(f"To confirm cleaning up {len(rooms_to_clean)} rooms, type"
+ await evt.reply(f"To confirm cleaning up {len(rooms_to_clean)} rooms, type "
"`$cmdprefix+sp confirm-clean`.")
diff --git a/mautrix_telegram/commands/handler.py b/mautrix_telegram/commands/handler.py
index f50e59e5..cc728be6 100644
--- a/mautrix_telegram/commands/handler.py
+++ b/mautrix_telegram/commands/handler.py
@@ -18,7 +18,7 @@ from typing import Awaitable, Callable, List, Optional, NamedTuple, Any
from telethon.errors import FloodWaitError
-from mautrix.types import RoomID, EventID
+from mautrix.types import RoomID, EventID, MessageEventContent
from mautrix.bridge.commands import (HelpSection, CommandEvent as BaseCommandEvent,
CommandHandler as BaseCommandHandler,
CommandProcessor as BaseCommandProcessor,
@@ -42,10 +42,10 @@ class CommandEvent(BaseCommandEvent):
sender: u.User
def __init__(self, processor: 'CommandProcessor', room_id: RoomID, event_id: EventID,
- sender: u.User, command: str, args: List[str], is_management: bool,
- is_portal: bool) -> None:
- super().__init__(processor, room_id, event_id, sender, command, args, is_management,
- is_portal)
+ sender: u.User, command: str, args: List[str], content: MessageEventContent,
+ is_management: bool, is_portal: bool) -> None:
+ super().__init__(processor, room_id, event_id, sender, command, args, content,
+ is_management, is_portal)
self.bridge = processor.bridge
self.tgbot = processor.tgbot
self.config = processor.config
@@ -69,7 +69,7 @@ class CommandHandler(BaseCommandHandler):
def __init__(self, handler: Callable[[CommandEvent], Awaitable[EventID]],
management_only: bool, name: str, help_text: str, help_args: str,
help_section: HelpSection, needs_auth: bool, needs_puppeting: bool,
- needs_matrix_puppeting: bool, needs_admin: bool,) -> None:
+ needs_matrix_puppeting: bool, needs_admin: bool) -> None:
super().__init__(handler, management_only, name, help_text, help_args, help_section,
needs_auth=needs_auth, needs_puppeting=needs_puppeting,
needs_matrix_puppeting=needs_matrix_puppeting, needs_admin=needs_admin)
diff --git a/mautrix_telegram/commands/telegram/misc.py b/mautrix_telegram/commands/telegram/misc.py
index 60b12181..6d9fde77 100644
--- a/mautrix_telegram/commands/telegram/misc.py
+++ b/mautrix_telegram/commands/telegram/misc.py
@@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, cast
import logging
import codecs
import base64
@@ -23,13 +23,13 @@ from telethon.errors import (InviteHashInvalidError, InviteHashExpiredError, Opt
UserAlreadyParticipantError, ChatIdInvalidError)
from telethon.tl.patched import Message
from telethon.tl.types import (User as TLUser, TypeUpdates, MessageMediaGame, MessageMediaPoll,
- TypePeer)
+ TypeInputPeer)
from telethon.tl.types.messages import BotCallbackAnswer
from telethon.tl.functions.messages import (ImportChatInviteRequest, CheckChatInviteRequest,
GetBotCallbackAnswerRequest, SendVoteRequest)
from telethon.tl.functions.channels import JoinChannelRequest
-from mautrix.types import EventID
+from mautrix.types import EventID, Format
from ... import puppet as pu, portal as po
from ...abstract_user import AbstractUser
@@ -38,6 +38,22 @@ from ...types import TelegramID
from ...commands import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORTALS
+@command_handler(needs_auth=False,
+ help_section=SECTION_MISC, help_args="<_caption_>",
+ help_text="Set a caption for the next image you send")
+async def caption(evt: CommandEvent) -> EventID:
+ if len(evt.args) == 0:
+ return await evt.reply("**Usage:** `$cmdprefix+sp caption
`")
+
+ prefix = f"{evt.command_prefix} caption "
+ if evt.content.format == Format.HTML:
+ evt.content.formatted_body = evt.content.formatted_body.replace(prefix, "", 1)
+ evt.content.body = evt.content.body.replace(prefix, "", 1)
+ evt.sender.command_status = {"caption": evt.content}
+ return await evt.reply("Your next image or file will be sent with that caption. "
+ "Use `$cmdprefix+sp cancel` to cancel the caption.")
+
+
@command_handler(help_section=SECTION_MISC,
help_args="[_-r|--remote_] <_query_>",
help_text="Search your contacts or the Telegram servers for users.")
@@ -76,8 +92,7 @@ async def search(evt: CommandEvent) -> EventID:
return await evt.reply("\n".join(reply))
-@command_handler(help_section=SECTION_CREATING_PORTALS,
- help_args="<_identifier_>",
+@command_handler(help_section=SECTION_CREATING_PORTALS, help_args="<_identifier_>",
help_text="Open a private chat with the given Telegram user. The identifier is "
"either the internal user ID, the username or the phone number. "
"**N.B.** The phone numbers you start chats with must already be in "
@@ -183,7 +198,7 @@ class MessageIDError(ValueError):
async def _parse_encoded_msgid(user: AbstractUser, enc_id: str, type_name: str
- ) -> Tuple[TypePeer, Message]:
+ ) -> Tuple[TypeInputPeer, Message]:
try:
enc_id += (4 - len(enc_id) % 4) * "="
enc_id = base64.b64decode(enc_id)
@@ -212,7 +227,7 @@ async def _parse_encoded_msgid(user: AbstractUser, enc_id: str, type_name: str
msg = await user.client.get_messages(entity=peer, ids=msg_id)
if not msg:
raise MessageIDError(f"Invalid {type_name} ID (message not found)")
- return peer, msg
+ return peer, cast(Message, msg)
@command_handler(help_section=SECTION_MISC,
@@ -234,12 +249,13 @@ async def play(evt: CommandEvent) -> EventID:
if not isinstance(msg.media, MessageMediaGame):
return await evt.reply("Invalid play ID (message doesn't look like a game)")
- game = await evt.sender.client(GetBotCallbackAnswerRequest(peer=peer, msg_id=msg.id, game=True))
+ game = await evt.sender.client(
+ GetBotCallbackAnswerRequest(peer=peer, msg_id=msg.id, game=True))
if not isinstance(game, BotCallbackAnswer):
return await evt.reply("Game request response invalid")
return await evt.reply(f"Click [here]({game.url}) to play {msg.media.game.title}:\n\n"
- f"{msg.media.game.description}")
+ f"{msg.media.game.description}")
@command_handler(help_section=SECTION_MISC,
diff --git a/mautrix_telegram/config.py b/mautrix_telegram/config.py
index 9d3cd37c..e25c3d32 100644
--- a/mautrix_telegram/config.py
+++ b/mautrix_telegram/config.py
@@ -101,6 +101,7 @@ class Config(BaseBridgeConfig):
copy("bridge.inline_images")
copy("bridge.image_as_file_size")
copy("bridge.max_document_size")
+ copy("bridge.parallel_file_transfer")
copy("bridge.federate_rooms")
copy("bridge.animated_sticker.target")
copy("bridge.animated_sticker.args")
diff --git a/mautrix_telegram/db/telegram_file.py b/mautrix_telegram/db/telegram_file.py
index 4ac05293..efed516f 100644
--- a/mautrix_telegram/db/telegram_file.py
+++ b/mautrix_telegram/db/telegram_file.py
@@ -30,9 +30,9 @@ class TelegramFile(Base):
mime_type: str = Column(String)
was_converted: bool = Column(Boolean)
timestamp: int = Column(BigInteger)
- size: int = Column(Integer, nullable=True)
- width: int = Column(Integer, nullable=True)
- height: int = Column(Integer, nullable=True)
+ size: Optional[int] = Column(Integer, nullable=True)
+ width: Optional[int] = Column(Integer, nullable=True)
+ height: Optional[int] = Column(Integer, nullable=True)
thumbnail_id: str = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True)
thumbnail: Optional['TelegramFile'] = None
diff --git a/mautrix_telegram/portal/matrix.py b/mautrix_telegram/portal/matrix.py
index 32730da3..95c509e6 100644
--- a/mautrix_telegram/portal/matrix.py
+++ b/mautrix_telegram/portal/matrix.py
@@ -17,7 +17,6 @@ from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHEC
from html import escape as escape_html
from string import Template
from abc import ABC
-import mimetypes
import magic
@@ -32,7 +31,7 @@ from telethon.tl.types import (
DocumentAttributeFilename, DocumentAttributeImageSize, GeoPoint,
InputChatUploadedPhoto, MessageActionChatEditPhoto, MessageMediaGeo,
SendMessageCancelAction, SendMessageTypingAction, TypeInputPeer, TypeMessageEntity,
- UpdateNewMessage, InputMediaUploadedDocument)
+ UpdateNewMessage, InputMediaUploadedDocument, InputMediaUploadedPhoto)
from mautrix.types import (EventID, RoomID, UserID, ContentURI, MessageType, MessageEventContent,
TextMessageEventContent, MediaMessageEventContent, Format,
@@ -41,7 +40,7 @@ from mautrix.bridge import BasePortal as MautrixBasePortal
from ..types import TelegramID
from ..db import Message as DBMessage
-from ..util import sane_mimetypes
+from ..util import sane_mimetypes, parallel_transfer_to_telegram
from ..context import Context
from .. import puppet as p, user as u, formatter, util
from .base import BasePortal
@@ -57,19 +56,6 @@ config: Optional['Config'] = None
class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
- @staticmethod
- def _get_file_meta(body: str, mime: str) -> str:
- try:
- current_extension = body[body.rindex("."):].lower()
- body = body[:body.rindex(".")]
- if mimetypes.types_map[current_extension] == mime:
- return body + current_extension
- except (ValueError, KeyError):
- pass
- if mime:
- return f"matrix_upload{sane_mimetypes.guess_extension(mime)}"
- return ""
-
async def _get_state_change_message(self, event: str, user: 'u.User', **kwargs: Any
) -> Optional[str]:
tpl = self.get_config(f"state_event_formats.{event}")
@@ -183,7 +169,7 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
async def _apply_msg_format(self, sender: 'u.User', content: MessageEventContent
) -> None:
- if isinstance(content, TextMessageEventContent) and content.format != Format.HTML:
+ if not isinstance(content, TextMessageEventContent) or content.format != Format.HTML:
content.format = Format.HTML
content.formatted_body = escape_html(content.body).replace("\n", "
")
@@ -193,14 +179,9 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
tpl_args = dict(sender_mxid=sender.mxid,
sender_username=sender.mxid_localpart,
sender_displayname=escape_html(displayname),
- body=content.body)
- if isinstance(content, TextMessageEventContent):
- tpl_args["formatted_body"] = content.formatted_body
- tpl_args["message"] = content.formatted_body
- content.formatted_body = Template(tpl).safe_substitute(tpl_args)
- else:
- tpl_args["message"] = content.body
- content.body = Template(tpl).safe_substitute(tpl_args)
+ message=content.formatted_body,
+ body=content.body, formatted_body=content.formatted_body)
+ content.formatted_body = Template(tpl).safe_substitute(tpl_args)
async def _apply_emote_format(self, sender: 'u.User',
content: TextMessageEventContent) -> None:
@@ -262,42 +243,55 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
async def _handle_matrix_file(self, sender_id: TelegramID, event_id: EventID,
space: TelegramID, client: 'MautrixTelegramClient',
- content: MediaMessageEventContent, reply_to: TelegramID) -> None:
- file = await self.main_intent.download_media(content.url)
-
+ content: MediaMessageEventContent, reply_to: TelegramID,
+ caption: TextMessageEventContent = None) -> None:
mime = content.info.mimetype
-
w, h = content.info.width, content.info.height
+ file_name = content["net.maunium.telegram.internal.filename"]
+ max_image_size = config["bridge.image_as_file_size"] * 1000 ** 2
- if content.msgtype == MessageType.STICKER:
- if mime != "image/gif":
- mime, file, w, h = util.convert_image(file, source_mime=mime, target_type="webp")
- else:
- # Remove sticker description
- content["net.maunium.telegram.internal.filename"] = "sticker.gif"
- content.body = ""
+ if config["bridge.parallel_file_transfer"]:
+ file_handle, file_size = await parallel_transfer_to_telegram(client, self.main_intent,
+ content.url, sender_id)
+ else:
+ file = await self.main_intent.download_media(content.url)
+
+ if content.msgtype == MessageType.STICKER:
+ if mime != "image/gif":
+ mime, file, w, h = util.convert_image(file, source_mime=mime,
+ target_type="webp")
+ else:
+ # Remove sticker description
+ file_name = "sticker.gif"
+
+ file_handle = await client.upload_file(file)
+ file_size = len(file)
+
+ file_handle.name = file_name
- file_name = self._get_file_meta(content["net.maunium.telegram.internal.filename"], mime)
attributes = [DocumentAttributeFilename(file_name=file_name)]
if w and h:
attributes.append(DocumentAttributeImageSize(w, h))
- caption = content.body if content.body.lower() != file_name.lower() else None
+ if (mime == "image/png" or mime == "image/jpeg") and file_size < max_image_size:
+ media = InputMediaUploadedPhoto(file_handle)
+ else:
+ media = InputMediaUploadedDocument(file=file_handle, attributes=attributes,
+ mime_type=mime or "application/octet-stream")
+
+ caption, entities = self._matrix_event_to_entities(caption) if caption else (None, None)
- media = await client.upload_file_direct(
- file, mime, attributes, file_name,
- max_image_size=config["bridge.image_as_file_size"] * 1000 ** 2)
async with self.send_lock(sender_id):
if await self._matrix_document_edit(client, content, space, caption, media, event_id):
return
try:
response = await client.send_media(self.peer, media, reply_to=reply_to,
- caption=caption)
+ caption=caption, entities=entities)
except (PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, PhotoExtInvalidError):
media = InputMediaUploadedDocument(file=media.file, mime_type=mime,
attributes=attributes)
response = await client.send_media(self.peer, media, reply_to=reply_to,
- caption=caption)
+ caption=caption, entities=entities)
self._add_telegram_message_to_db(event_id, space, 0, response)
async def _matrix_document_edit(self, client: 'MautrixTelegramClient',
@@ -364,8 +358,20 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
else (sender.tgid if logged_in else self.bot.tgid))
reply_to = formatter.matrix_reply_to_telegram(content, space, room_id=self.mxid)
- content["net.maunium.telegram.internal.filename"] = content.body
- await self._pre_process_matrix_message(sender, not logged_in, content)
+ media = (MessageType.STICKER, MessageType.IMAGE, MessageType.FILE, MessageType.AUDIO,
+ MessageType.VIDEO)
+ caption_content = None
+ if content.msgtype in media:
+ content["net.maunium.telegram.internal.filename"] = content.body
+ try:
+ caption_content: MessageEventContent = sender.command_status["caption"]
+ caption_content.msgtype = content.msgtype
+ reply_to = reply_to or formatter.matrix_reply_to_telegram(caption_content, space,
+ room_id=self.mxid)
+ sender.command_status = None
+ except (KeyError, TypeError):
+ pass
+ await self._pre_process_matrix_message(sender, not logged_in, caption_content or content)
if content.msgtype == MessageType.NOTICE:
bridge_notices = self.get_config("bridge_notices.default")
@@ -378,9 +384,9 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
elif content.msgtype == MessageType.LOCATION:
await self._handle_matrix_location(sender_id, event_id, space, client, content,
reply_to)
- elif content.msgtype in (MessageType.STICKER, MessageType.IMAGE, MessageType.FILE,
- MessageType.AUDIO, MessageType.VIDEO):
- await self._handle_matrix_file(sender_id, event_id, space, client, content, reply_to)
+ elif content.msgtype in media:
+ await self._handle_matrix_file(sender_id, event_id, space, client, content, reply_to,
+ caption_content)
else:
self.log.debug(f"Unhandled Matrix event: {content}")
diff --git a/mautrix_telegram/portal/telegram.py b/mautrix_telegram/portal/telegram.py
index 3aa97b3a..f88d1380 100644
--- a/mautrix_telegram/portal/telegram.py
+++ b/mautrix_telegram/portal/telegram.py
@@ -181,9 +181,11 @@ class PortalTelegram(BasePortal, ABC):
self.log.debug(f"Unsupported thumbnail type {type(thumb_size)}")
thumb_loc = None
thumb_size = None
+ parallel_id = source.tgid if config["bridge.parallel_file_transfer"] else None
file = await util.transfer_file_to_matrix(source.client, intent, document, thumb_loc,
is_sticker=attrs.is_sticker,
- tgs_convert=config["bridge.animated_sticker"])
+ tgs_convert=config["bridge.animated_sticker"],
+ filename=attrs.name, parallel_id=parallel_id)
if not file:
return None
diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py
index df36fc5d..9c365656 100644
--- a/mautrix_telegram/user.py
+++ b/mautrix_telegram/user.py
@@ -219,7 +219,7 @@ class User(AbstractUser, BaseUser):
else:
portal = po.Portal.get_by_entity(message.to_id, receiver_id=self.tgid)
elif isinstance(update, UpdateShortChatMessage):
- portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
+ portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
else:
diff --git a/mautrix_telegram/util/__init__.py b/mautrix_telegram/util/__init__.py
index 727224bb..b2bfa88e 100644
--- a/mautrix_telegram/util/__init__.py
+++ b/mautrix_telegram/util/__init__.py
@@ -1,4 +1,5 @@
from .file_transfer import transfer_file_to_matrix, convert_image
+from .parallel_file_transfer import parallel_transfer_to_telegram
from .format_duration import format_duration
from .recursive_dict import recursive_del, recursive_set, recursive_get
from .color_log import ColorFormatter
diff --git a/mautrix_telegram/util/file_transfer.py b/mautrix_telegram/util/file_transfer.py
index e89c3465..9028d9ce 100644
--- a/mautrix_telegram/util/file_transfer.py
+++ b/mautrix_telegram/util/file_transfer.py
@@ -34,6 +34,7 @@ from mautrix.appservice import IntentAPI
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile
from ..util import sane_mimetypes
+from .parallel_file_transfer import parallel_transfer_to_matrix
try:
from PIL import Image
@@ -129,7 +130,7 @@ async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: In
return db_file
video_ext = sane_mimetypes.guess_extension(mime)
- if VideoFileClip and video_ext:
+ if VideoFileClip and video_ext and video:
try:
file, width, height = _read_video_thumbnail(video, video_ext, frame_ext="png")
except OSError:
@@ -161,7 +162,8 @@ TypeThumbnail = Optional[Union[TypeLocation, TypePhotoSize]]
async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
location: TypeLocation, thumbnail: TypeThumbnail = None,
- is_sticker: bool = False, tgs_convert: Optional[dict] = None
+ is_sticker: bool = False, tgs_convert: Optional[dict] = None,
+ filename: Optional[str] = None, parallel_id: Optional[int] = None
) -> Optional[DBTelegramFile]:
location_id = _location_to_id(location)
if not location_id:
@@ -178,53 +180,61 @@ async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentA
transfer_locks[location_id] = lock
async with lock:
return await _unlocked_transfer_file_to_matrix(client, intent, location_id, location,
- thumbnail, is_sticker, tgs_convert)
+ thumbnail, is_sticker, tgs_convert,
+ filename, parallel_id)
async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
loc_id: str, location: TypeLocation,
thumbnail: TypeThumbnail, is_sticker: bool,
- tgs_convert: Optional[dict]
+ tgs_convert: Optional[dict], filename: Optional[str],
+ parallel_id: Optional[int]
) -> Optional[DBTelegramFile]:
db_file = DBTelegramFile.get(loc_id)
if db_file:
return db_file
- try:
- file = await client.download_file(location)
- except (LocationInvalidError, FileIdInvalidError):
- return None
- except (AuthBytesInvalidError, AuthKeyInvalidError, SecurityError) as e:
- log.exception(f"{e.__class__.__name__} while downloading a file.")
- return None
+ if parallel_id and isinstance(location, Document) and (not is_sticker or not tgs_convert):
+ db_file = await parallel_transfer_to_matrix(client, intent, loc_id, location, filename,
+ parallel_id)
+ mime_type = location.mime_type
+ file = None
+ else:
+ try:
+ file = await client.download_file(location)
+ except (LocationInvalidError, FileIdInvalidError):
+ return None
+ except (AuthBytesInvalidError, AuthKeyInvalidError, SecurityError) as e:
+ log.exception(f"{e.__class__.__name__} while downloading a file.")
+ return None
- width, height = None, None
- mime_type = magic.from_buffer(file, mime=True)
+ width, height = None, None
+ mime_type = magic.from_buffer(file, mime=True)
- image_converted = False
- # A weird bug in alpine/magic makes it return application/octet-stream for gzips...
- if is_sticker and tgs_convert and (mime_type == "application/gzip" or (
- mime_type == "application/octet-stream"
- and magic.from_buffer(file).startswith("gzip"))):
- mime_type, file, width, height = await convert_tgs_to(
- file, tgs_convert["target"], **tgs_convert["args"])
- thumbnail = None
- image_converted = mime_type != "application/gzip"
+ image_converted = False
+ # A weird bug in alpine/magic makes it return application/octet-stream for gzips...
+ if is_sticker and tgs_convert and (mime_type == "application/gzip" or (
+ mime_type == "application/octet-stream"
+ and magic.from_buffer(file).startswith("gzip"))):
+ mime_type, file, width, height = await convert_tgs_to(
+ file, tgs_convert["target"], **tgs_convert["args"])
+ thumbnail = None
+ image_converted = mime_type != "application/gzip"
- if mime_type == "image/webp":
- new_mime_type, file, width, height = convert_image(
- file, source_mime="image/webp", target_type="png",
- thumbnail_to=(256, 256) if is_sticker else None)
- image_converted = new_mime_type != mime_type
- mime_type = new_mime_type
- thumbnail = None
+ if mime_type == "image/webp":
+ new_mime_type, file, width, height = convert_image(
+ file, source_mime="image/webp", target_type="png",
+ thumbnail_to=(256, 256) if is_sticker else None)
+ image_converted = new_mime_type != mime_type
+ mime_type = new_mime_type
+ thumbnail = None
- content_uri = await intent.upload_media(file, mime_type)
+ content_uri = await intent.upload_media(file, mime_type)
- db_file = DBTelegramFile(id=loc_id, mxc=content_uri,
- mime_type=mime_type, was_converted=image_converted,
- timestamp=int(time.time()), size=len(file),
- width=width, height=height)
+ db_file = DBTelegramFile(id=loc_id, mxc=content_uri,
+ mime_type=mime_type, was_converted=image_converted,
+ timestamp=int(time.time()), size=len(file),
+ width=width, height=height)
if thumbnail and (mime_type.startswith("video/") or mime_type == "image/gif"):
if isinstance(thumbnail, (PhotoSize, PhotoCachedSize)):
thumbnail = thumbnail.location
diff --git a/mautrix_telegram/util/parallel_file_transfer.py b/mautrix_telegram/util/parallel_file_transfer.py
new file mode 100644
index 00000000..507646aa
--- /dev/null
+++ b/mautrix_telegram/util/parallel_file_transfer.py
@@ -0,0 +1,298 @@
+# mautrix-telegram - A Matrix-Telegram puppeting bridge
+# Copyright (C) 2019 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple
+from collections import defaultdict
+import hashlib
+import asyncio
+import logging
+import time
+import math
+
+from aiohttp import ClientResponse
+
+from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
+ InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
+ InputFileBig, InputFile)
+from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
+from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
+ SaveBigFilePartRequest)
+from telethon.network import MTProtoSender
+from telethon.crypto import AuthKey
+from telethon import utils, helpers
+
+from mautrix.appservice import IntentAPI
+from mautrix.types import ContentURI
+
+from ..tgclient import MautrixTelegramClient
+from ..db import TelegramFile as DBTelegramFile
+
+log: logging.Logger = logging.getLogger("mau.util")
+
+TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation,
+ InputFileLocation, InputPhotoFileLocation]
+
+
+class DownloadSender:
+ sender: MTProtoSender
+ request: GetFileRequest
+ remaining: int
+ stride: int
+
+ def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
+ stride: int, count: int) -> None:
+ self.sender = sender
+ self.request = GetFileRequest(file, offset=offset, limit=limit)
+ self.stride = stride
+ self.remaining = count
+
+ async def next(self) -> Optional[bytes]:
+ if not self.remaining:
+ return None
+ result = await self.sender.send(self.request)
+ self.remaining -= 1
+ self.request.offset += self.stride
+ return result.bytes
+
+ def disconnect(self) -> Awaitable[None]:
+ return self.sender.disconnect()
+
+
+class UploadSender:
+ sender: MTProtoSender
+ request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
+ part_count: int
+ stride: int
+ previous: Optional[asyncio.Task]
+ loop: asyncio.AbstractEventLoop
+
+ def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int,
+ stride: int, loop: asyncio.AbstractEventLoop) -> None:
+ self.sender = sender
+ self.part_count = part_count
+ if big:
+ self.request = SaveBigFilePartRequest(file_id, index, part_count, b"")
+ else:
+ self.request = SaveFilePartRequest(file_id, index, b"")
+ self.stride = stride
+ self.previous = None
+ self.loop = loop
+
+ async def next(self, data: bytes) -> None:
+ if self.previous:
+ await self.previous
+ self.previous = self.loop.create_task(self._next(data))
+
+ async def _next(self, data: bytes) -> None:
+ self.request.bytes = data
+ log.debug(f"Sending file part {self.request.file_part}/{self.part_count}"
+ f" with {len(data)} bytes")
+ await self.sender.send(self.request)
+ self.request.file_part += self.stride
+
+ async def disconnect(self) -> None:
+ if self.previous:
+ await self.previous
+ return await self.sender.disconnect()
+
+
+class ParallelTransferrer:
+ client: MautrixTelegramClient
+ loop: asyncio.AbstractEventLoop
+ dc_id: int
+ senders: Optional[List[Union[DownloadSender, UploadSender]]]
+ auth_key: AuthKey
+ upload_ticker: int
+
+ def __init__(self, client: MautrixTelegramClient, dc_id: Optional[int] = None) -> None:
+ self.client = client
+ self.loop = self.client.loop
+ self.dc_id = dc_id or self.client.session.dc_id
+ self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id
+ else self.client.session.auth_key)
+ self.senders = None
+ self.upload_ticker = 0
+
+ async def _cleanup(self) -> None:
+ await asyncio.gather(*[sender.disconnect() for sender in self.senders])
+ self.senders = None
+
+ @staticmethod
+ def _get_connection_count(file_size: int, max_count: int = 20,
+ full_size: int = 100 * 1024 * 1024) -> int:
+ if file_size > full_size:
+ return max_count
+ return math.ceil((file_size / full_size) * max_count)
+
+ async def _init_download(self, connections: int, file: TypeLocation, part_count: int,
+ part_size: int) -> None:
+ minimum, remainder = divmod(part_count, connections)
+
+ def get_part_count() -> int:
+ nonlocal remainder
+ if remainder > 0:
+ remainder -= 1
+ return minimum + 1
+ return minimum
+
+ # The first cross-DC sender will export+import the authorization, so we always create it
+ # before creating any other senders.
+ self.senders = [
+ await self._create_download_sender(file, 0, part_size, connections * part_size,
+ get_part_count()),
+ *await asyncio.gather(
+ *[self._create_download_sender(file, i, part_size, connections * part_size,
+ get_part_count())
+ for i in range(1, connections)])
+ ]
+
+ async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int,
+ stride: int,
+ part_count: int) -> DownloadSender:
+ return DownloadSender(await self._create_sender(), file, index * part_size, part_size,
+ stride, part_count)
+
+ async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool
+ ) -> None:
+ self.senders = [
+ await self._create_upload_sender(file_id, part_count, big, 0, connections),
+ *await asyncio.gather(
+ *[self._create_upload_sender(file_id, part_count, big, i, connections)
+ for i in range(1, connections)])
+ ]
+
+ async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int,
+ stride: int) -> UploadSender:
+ return UploadSender(await self._create_sender(), file_id, part_count, big, index, stride,
+ loop=self.loop)
+
+ async def _create_sender(self) -> MTProtoSender:
+ dc = await self.client._get_dc(self.dc_id)
+ sender = MTProtoSender(self.auth_key, self.loop, loggers=self.client._log)
+ await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
+ loop=self.loop, loggers=self.client._log,
+ proxy=self.client._proxy))
+ if not self.auth_key:
+ log.debug(f"Exporting auth to DC {self.dc_id}")
+ auth = await self.client(ExportAuthorizationRequest(self.dc_id))
+ req = self.client._init_with(ImportAuthorizationRequest(
+ id=auth.id, bytes=auth.bytes
+ ))
+ await sender.send(req)
+ self.auth_key = sender.auth_key
+ return sender
+
+ async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None,
+ connection_count: Optional[int] = None) -> Tuple[int, int, bool]:
+ connection_count = connection_count or self._get_connection_count(file_size)
+ part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
+ part_count = (file_size + part_size - 1) // part_size
+ is_large = file_size > 10 * 1024 * 1024
+ await self._init_upload(connection_count, file_id, part_count, is_large)
+ return part_size, part_count, is_large
+
+ async def upload(self, part: bytes) -> None:
+ await self.senders[self.upload_ticker].next(part)
+ self.upload_ticker = (self.upload_ticker + 1) % len(self.senders)
+
+ async def finish_upload(self) -> None:
+ await self._cleanup()
+
+ async def download(self, file: TypeLocation, file_size: int,
+ part_size_kb: Optional[float] = None,
+ connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]:
+ connection_count = connection_count or self._get_connection_count(file_size)
+ part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
+ part_count = math.ceil(file_size / part_size)
+ log.debug("Starting parallel download: "
+ f"{connection_count} {part_size} {part_count} {file!s}")
+ await self._init_download(connection_count, file, part_count, part_size)
+
+ part = 0
+ while part < part_count:
+ tasks = []
+ for sender in self.senders:
+ tasks.append(self.loop.create_task(sender.next()))
+ for task in tasks:
+ data = await task
+ if not data:
+ break
+ yield data
+ part += 1
+ log.debug(f"Part {part} downloaded")
+
+ log.debug("Parallel download finished, cleaning up connections")
+ await self._cleanup()
+
+
+parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
+
+
+async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
+ loc_id: str, location: TypeLocation, filename: str,
+ parallel_id: int) -> DBTelegramFile:
+ size = location.size
+ mime_type = location.mime_type
+ dc_id, location = utils.get_input_location(location)
+ # We lock the transfers because telegram has connection count limits
+ async with parallel_transfer_locks[parallel_id]:
+ downloader = ParallelTransferrer(client, dc_id)
+ content_uri = await intent.upload_media(downloader.download(location, size),
+ mime_type=mime_type, filename=filename, size=size)
+ return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
+ was_converted=False, timestamp=int(time.time()), size=size,
+ width=None, height=None)
+
+
+async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response: ClientResponse
+ ) -> Tuple[TypeInputFile, int]:
+ file_id = helpers.generate_random_long()
+ file_size = response.content_length
+
+ hash_md5 = hashlib.md5()
+ uploader = ParallelTransferrer(client)
+ part_size, part_count, is_large = await uploader.init_upload(file_id, file_size)
+ buffer = bytearray()
+ async for data in response.content:
+ if not is_large:
+ hash_md5.update(data)
+ if len(buffer) == 0 and len(data) == part_size:
+ await uploader.upload(data)
+ continue
+ new_len = len(buffer) + len(data)
+ if new_len >= part_size:
+ cutoff = part_size - len(buffer)
+ buffer.extend(data[:cutoff])
+ await uploader.upload(bytes(buffer))
+ buffer.clear()
+ buffer.extend(data[cutoff:])
+ else:
+ buffer.extend(data)
+ if len(buffer) > 0:
+ await uploader.upload(bytes(buffer))
+ await uploader.finish_upload()
+ if is_large:
+ return InputFileBig(file_id, part_count, "upload"), file_size
+ else:
+ return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size
+
+
+async def parallel_transfer_to_telegram(client: MautrixTelegramClient, intent: IntentAPI,
+ uri: ContentURI, parallel_id: int
+ ) -> Tuple[TypeInputFile, int]:
+ url = intent.api.get_download_url(uri)
+ async with parallel_transfer_locks[parallel_id]:
+ async with intent.api.session.get(url) as response:
+ return await _internal_transfer_to_telegram(client, response)
diff --git a/setup.py b/setup.py
index 51ee2ae8..1e21ba78 100644
--- a/setup.py
+++ b/setup.py
@@ -32,7 +32,7 @@ setuptools.setup(
install_requires=[
"aiohttp>=3.0.1,<4",
- "mautrix>=0.4.0.dev71,<0.5",
+ "mautrix>=0.4.0.dev75,<0.5",
"SQLAlchemy>=1.2.3,<2",
"alembic>=1.0.0,<2",
"commonmark>=0.8.1,<0.10",