Merge branch 'master' into rlottie

This commit is contained in:
Tulir Asokan
2019-10-27 15:37:33 +02:00
14 changed files with 460 additions and 122 deletions
+16 -13
View File
@@ -163,6 +163,10 @@ bridge:
image_as_file_size: 10
# Maximum size of Telegram documents in megabytes to bridge.
max_document_size: 100
# Enable experimental parallel file transfer, which makes uploads/downloads much faster by
# streaming from/to Matrix and using many connections for Telegram.
# Note that generating HQ thumbnails for videos is not possible with streamed transfers.
parallel_file_transfer: false
# Whether or not created rooms should have federation enabled.
# If false, created portal rooms will never be federated.
federate_rooms: true
@@ -207,20 +211,19 @@ bridge:
# Text msgtypes (m.text, m.notice and m.emote) support HTML, media msgtypes don't.
#
# Available variables:
# $sender_displayname - The display name of the sender (e.g. Example User)
# $sender_username - The username (Matrix ID localpart) of the sender (e.g. exampleuser)
# $sender_mxid - The Matrix ID of the sender (e.g. @exampleuser:example.com)
# $body - The plaintext body (file name for media msgtypes)
# $formatted_body - The message content as HTML (for text msgtypes)
# $sender_displayname - The display name of the sender (e.g. Example User)
# $sender_username - The username (Matrix ID localpart) of the sender (e.g. exampleuser)
# $sender_mxid - The Matrix ID of the sender (e.g. @exampleuser:example.com)
# $message - The message content
message_formats:
m.text: "<b>$sender_displayname</b>: $formatted_body"
m.notice: "<b>$sender_displayname</b>: $formatted_body"
m.emote: "* <b>$sender_displayname</b> $formatted_body"
m.file: "$sender_displayname sent a file: $body"
m.image: "$sender_displayname sent an image: $body"
m.audio: "$sender_displayname sent an audio file: $body"
m.video: "$sender_displayname sent a video: $body"
m.location: "$sender_displayname sent a location: $body"
m.text: "<b>$sender_displayname</b>: $message"
m.notice: "<b>$sender_displayname</b>: $message"
m.emote: "* <b>$sender_displayname</b> $message"
m.file: "<b>$sender_displayname</b> sent a file: $message"
m.image: "<b>$sender_displayname</b> sent an image: $message"
m.audio: "<b>$sender_displayname</b> sent an audio file: $message"
m.video: "<b>$sender_displayname</b> sent a video: $message"
m.location: "<b>$sender_displayname</b> sent a location: $message"
# Telegram doesn't have built-in emotes, this field specifies how m.emote's from authenticated
# users are sent to telegram. All fields in message_formats are supported. Additionally, the
# Telegram user info is available in the following variables:
+6 -5
View File
@@ -295,7 +295,7 @@ class AbstractUser(ABC):
async def update_admin(self, update: UpdateChatParticipantAdmin) -> None:
# TODO duplication not checked
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
if not portal or not portal.mxid:
return
@@ -305,7 +305,7 @@ class AbstractUser(ABC):
if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
else:
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
if not portal or not portal.mxid:
return
@@ -350,7 +350,7 @@ class AbstractUser(ABC):
Optional[pu.Puppet],
Optional[po.Portal]]:
if isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
sender = pu.Puppet.get(TelegramID(update.from_id))
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
@@ -410,9 +410,10 @@ class AbstractUser(ABC):
if not config["bridge.relaybot.private_chat.invite"]:
self.log.debug(f"Ignoring private message to bot from {sender.id}")
return
elif not portal.mxid:
elif not portal or not portal.mxid:
tgid_log = portal.tgid_log if portal else original_update.chat_id
self.log.debug(
f"Ignoring message received by bot in unbridged chat {portal.tgid_log}")
f"Ignoring message received by bot in unbridged chat {tgid_log}")
return
if self.ignore_incoming_bot_events and self.relaybot and sender.id == self.relaybot.tgid:
+1 -1
View File
@@ -155,7 +155,7 @@ async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
"next": lambda confirm: execute_room_cleanup(confirm, rooms_to_clean),
"action": "Room cleaning",
}
await evt.reply(f"To confirm cleaning up {len(rooms_to_clean)} rooms, type"
await evt.reply(f"To confirm cleaning up {len(rooms_to_clean)} rooms, type "
"`$cmdprefix+sp confirm-clean`.")
+6 -6
View File
@@ -18,7 +18,7 @@ from typing import Awaitable, Callable, List, Optional, NamedTuple, Any
from telethon.errors import FloodWaitError
from mautrix.types import RoomID, EventID
from mautrix.types import RoomID, EventID, MessageEventContent
from mautrix.bridge.commands import (HelpSection, CommandEvent as BaseCommandEvent,
CommandHandler as BaseCommandHandler,
CommandProcessor as BaseCommandProcessor,
@@ -42,10 +42,10 @@ class CommandEvent(BaseCommandEvent):
sender: u.User
def __init__(self, processor: 'CommandProcessor', room_id: RoomID, event_id: EventID,
sender: u.User, command: str, args: List[str], is_management: bool,
is_portal: bool) -> None:
super().__init__(processor, room_id, event_id, sender, command, args, is_management,
is_portal)
sender: u.User, command: str, args: List[str], content: MessageEventContent,
is_management: bool, is_portal: bool) -> None:
super().__init__(processor, room_id, event_id, sender, command, args, content,
is_management, is_portal)
self.bridge = processor.bridge
self.tgbot = processor.tgbot
self.config = processor.config
@@ -69,7 +69,7 @@ class CommandHandler(BaseCommandHandler):
def __init__(self, handler: Callable[[CommandEvent], Awaitable[EventID]],
management_only: bool, name: str, help_text: str, help_args: str,
help_section: HelpSection, needs_auth: bool, needs_puppeting: bool,
needs_matrix_puppeting: bool, needs_admin: bool,) -> None:
needs_matrix_puppeting: bool, needs_admin: bool) -> None:
super().__init__(handler, management_only, name, help_text, help_args, help_section,
needs_auth=needs_auth, needs_puppeting=needs_puppeting,
needs_matrix_puppeting=needs_matrix_puppeting, needs_admin=needs_admin)
+25 -9
View File
@@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, cast
import logging
import codecs
import base64
@@ -23,13 +23,13 @@ from telethon.errors import (InviteHashInvalidError, InviteHashExpiredError, Opt
UserAlreadyParticipantError, ChatIdInvalidError)
from telethon.tl.patched import Message
from telethon.tl.types import (User as TLUser, TypeUpdates, MessageMediaGame, MessageMediaPoll,
TypePeer)
TypeInputPeer)
from telethon.tl.types.messages import BotCallbackAnswer
from telethon.tl.functions.messages import (ImportChatInviteRequest, CheckChatInviteRequest,
GetBotCallbackAnswerRequest, SendVoteRequest)
from telethon.tl.functions.channels import JoinChannelRequest
from mautrix.types import EventID
from mautrix.types import EventID, Format
from ... import puppet as pu, portal as po
from ...abstract_user import AbstractUser
@@ -38,6 +38,22 @@ from ...types import TelegramID
from ...commands import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORTALS
@command_handler(needs_auth=False,
help_section=SECTION_MISC, help_args="<_caption_>",
help_text="Set a caption for the next image you send")
async def caption(evt: CommandEvent) -> EventID:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp caption <caption>`")
prefix = f"{evt.command_prefix} caption "
if evt.content.format == Format.HTML:
evt.content.formatted_body = evt.content.formatted_body.replace(prefix, "", 1)
evt.content.body = evt.content.body.replace(prefix, "", 1)
evt.sender.command_status = {"caption": evt.content}
return await evt.reply("Your next image or file will be sent with that caption. "
"Use `$cmdprefix+sp cancel` to cancel the caption.")
@command_handler(help_section=SECTION_MISC,
help_args="[_-r|--remote_] <_query_>",
help_text="Search your contacts or the Telegram servers for users.")
@@ -76,8 +92,7 @@ async def search(evt: CommandEvent) -> EventID:
return await evt.reply("\n".join(reply))
@command_handler(help_section=SECTION_CREATING_PORTALS,
help_args="<_identifier_>",
@command_handler(help_section=SECTION_CREATING_PORTALS, help_args="<_identifier_>",
help_text="Open a private chat with the given Telegram user. The identifier is "
"either the internal user ID, the username or the phone number. "
"**N.B.** The phone numbers you start chats with must already be in "
@@ -183,7 +198,7 @@ class MessageIDError(ValueError):
async def _parse_encoded_msgid(user: AbstractUser, enc_id: str, type_name: str
) -> Tuple[TypePeer, Message]:
) -> Tuple[TypeInputPeer, Message]:
try:
enc_id += (4 - len(enc_id) % 4) * "="
enc_id = base64.b64decode(enc_id)
@@ -212,7 +227,7 @@ async def _parse_encoded_msgid(user: AbstractUser, enc_id: str, type_name: str
msg = await user.client.get_messages(entity=peer, ids=msg_id)
if not msg:
raise MessageIDError(f"Invalid {type_name} ID (message not found)")
return peer, msg
return peer, cast(Message, msg)
@command_handler(help_section=SECTION_MISC,
@@ -234,12 +249,13 @@ async def play(evt: CommandEvent) -> EventID:
if not isinstance(msg.media, MessageMediaGame):
return await evt.reply("Invalid play ID (message doesn't look like a game)")
game = await evt.sender.client(GetBotCallbackAnswerRequest(peer=peer, msg_id=msg.id, game=True))
game = await evt.sender.client(
GetBotCallbackAnswerRequest(peer=peer, msg_id=msg.id, game=True))
if not isinstance(game, BotCallbackAnswer):
return await evt.reply("Game request response invalid")
return await evt.reply(f"Click [here]({game.url}) to play {msg.media.game.title}:\n\n"
f"{msg.media.game.description}")
f"{msg.media.game.description}")
@command_handler(help_section=SECTION_MISC,
+1
View File
@@ -101,6 +101,7 @@ class Config(BaseBridgeConfig):
copy("bridge.inline_images")
copy("bridge.image_as_file_size")
copy("bridge.max_document_size")
copy("bridge.parallel_file_transfer")
copy("bridge.federate_rooms")
copy("bridge.animated_sticker.target")
copy("bridge.animated_sticker.args")
+3 -3
View File
@@ -30,9 +30,9 @@ class TelegramFile(Base):
mime_type: str = Column(String)
was_converted: bool = Column(Boolean)
timestamp: int = Column(BigInteger)
size: int = Column(Integer, nullable=True)
width: int = Column(Integer, nullable=True)
height: int = Column(Integer, nullable=True)
size: Optional[int] = Column(Integer, nullable=True)
width: Optional[int] = Column(Integer, nullable=True)
height: Optional[int] = Column(Integer, nullable=True)
thumbnail_id: str = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True)
thumbnail: Optional['TelegramFile'] = None
+54 -48
View File
@@ -17,7 +17,6 @@ from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHEC
from html import escape as escape_html
from string import Template
from abc import ABC
import mimetypes
import magic
@@ -32,7 +31,7 @@ from telethon.tl.types import (
DocumentAttributeFilename, DocumentAttributeImageSize, GeoPoint,
InputChatUploadedPhoto, MessageActionChatEditPhoto, MessageMediaGeo,
SendMessageCancelAction, SendMessageTypingAction, TypeInputPeer, TypeMessageEntity,
UpdateNewMessage, InputMediaUploadedDocument)
UpdateNewMessage, InputMediaUploadedDocument, InputMediaUploadedPhoto)
from mautrix.types import (EventID, RoomID, UserID, ContentURI, MessageType, MessageEventContent,
TextMessageEventContent, MediaMessageEventContent, Format,
@@ -41,7 +40,7 @@ from mautrix.bridge import BasePortal as MautrixBasePortal
from ..types import TelegramID
from ..db import Message as DBMessage
from ..util import sane_mimetypes
from ..util import sane_mimetypes, parallel_transfer_to_telegram
from ..context import Context
from .. import puppet as p, user as u, formatter, util
from .base import BasePortal
@@ -57,19 +56,6 @@ config: Optional['Config'] = None
class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
@staticmethod
def _get_file_meta(body: str, mime: str) -> str:
try:
current_extension = body[body.rindex("."):].lower()
body = body[:body.rindex(".")]
if mimetypes.types_map[current_extension] == mime:
return body + current_extension
except (ValueError, KeyError):
pass
if mime:
return f"matrix_upload{sane_mimetypes.guess_extension(mime)}"
return ""
async def _get_state_change_message(self, event: str, user: 'u.User', **kwargs: Any
) -> Optional[str]:
tpl = self.get_config(f"state_event_formats.{event}")
@@ -183,7 +169,7 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
async def _apply_msg_format(self, sender: 'u.User', content: MessageEventContent
) -> None:
if isinstance(content, TextMessageEventContent) and content.format != Format.HTML:
if not isinstance(content, TextMessageEventContent) or content.format != Format.HTML:
content.format = Format.HTML
content.formatted_body = escape_html(content.body).replace("\n", "<br/>")
@@ -193,14 +179,9 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
tpl_args = dict(sender_mxid=sender.mxid,
sender_username=sender.mxid_localpart,
sender_displayname=escape_html(displayname),
body=content.body)
if isinstance(content, TextMessageEventContent):
tpl_args["formatted_body"] = content.formatted_body
tpl_args["message"] = content.formatted_body
content.formatted_body = Template(tpl).safe_substitute(tpl_args)
else:
tpl_args["message"] = content.body
content.body = Template(tpl).safe_substitute(tpl_args)
message=content.formatted_body,
body=content.body, formatted_body=content.formatted_body)
content.formatted_body = Template(tpl).safe_substitute(tpl_args)
async def _apply_emote_format(self, sender: 'u.User',
content: TextMessageEventContent) -> None:
@@ -262,42 +243,55 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
async def _handle_matrix_file(self, sender_id: TelegramID, event_id: EventID,
space: TelegramID, client: 'MautrixTelegramClient',
content: MediaMessageEventContent, reply_to: TelegramID) -> None:
file = await self.main_intent.download_media(content.url)
content: MediaMessageEventContent, reply_to: TelegramID,
caption: TextMessageEventContent = None) -> None:
mime = content.info.mimetype
w, h = content.info.width, content.info.height
file_name = content["net.maunium.telegram.internal.filename"]
max_image_size = config["bridge.image_as_file_size"] * 1000 ** 2
if content.msgtype == MessageType.STICKER:
if mime != "image/gif":
mime, file, w, h = util.convert_image(file, source_mime=mime, target_type="webp")
else:
# Remove sticker description
content["net.maunium.telegram.internal.filename"] = "sticker.gif"
content.body = ""
if config["bridge.parallel_file_transfer"]:
file_handle, file_size = await parallel_transfer_to_telegram(client, self.main_intent,
content.url, sender_id)
else:
file = await self.main_intent.download_media(content.url)
if content.msgtype == MessageType.STICKER:
if mime != "image/gif":
mime, file, w, h = util.convert_image(file, source_mime=mime,
target_type="webp")
else:
# Remove sticker description
file_name = "sticker.gif"
file_handle = await client.upload_file(file)
file_size = len(file)
file_handle.name = file_name
file_name = self._get_file_meta(content["net.maunium.telegram.internal.filename"], mime)
attributes = [DocumentAttributeFilename(file_name=file_name)]
if w and h:
attributes.append(DocumentAttributeImageSize(w, h))
caption = content.body if content.body.lower() != file_name.lower() else None
if (mime == "image/png" or mime == "image/jpeg") and file_size < max_image_size:
media = InputMediaUploadedPhoto(file_handle)
else:
media = InputMediaUploadedDocument(file=file_handle, attributes=attributes,
mime_type=mime or "application/octet-stream")
caption, entities = self._matrix_event_to_entities(caption) if caption else (None, None)
media = await client.upload_file_direct(
file, mime, attributes, file_name,
max_image_size=config["bridge.image_as_file_size"] * 1000 ** 2)
async with self.send_lock(sender_id):
if await self._matrix_document_edit(client, content, space, caption, media, event_id):
return
try:
response = await client.send_media(self.peer, media, reply_to=reply_to,
caption=caption)
caption=caption, entities=entities)
except (PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, PhotoExtInvalidError):
media = InputMediaUploadedDocument(file=media.file, mime_type=mime,
attributes=attributes)
response = await client.send_media(self.peer, media, reply_to=reply_to,
caption=caption)
caption=caption, entities=entities)
self._add_telegram_message_to_db(event_id, space, 0, response)
async def _matrix_document_edit(self, client: 'MautrixTelegramClient',
@@ -364,8 +358,20 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
else (sender.tgid if logged_in else self.bot.tgid))
reply_to = formatter.matrix_reply_to_telegram(content, space, room_id=self.mxid)
content["net.maunium.telegram.internal.filename"] = content.body
await self._pre_process_matrix_message(sender, not logged_in, content)
media = (MessageType.STICKER, MessageType.IMAGE, MessageType.FILE, MessageType.AUDIO,
MessageType.VIDEO)
caption_content = None
if content.msgtype in media:
content["net.maunium.telegram.internal.filename"] = content.body
try:
caption_content: MessageEventContent = sender.command_status["caption"]
caption_content.msgtype = content.msgtype
reply_to = reply_to or formatter.matrix_reply_to_telegram(caption_content, space,
room_id=self.mxid)
sender.command_status = None
except (KeyError, TypeError):
pass
await self._pre_process_matrix_message(sender, not logged_in, caption_content or content)
if content.msgtype == MessageType.NOTICE:
bridge_notices = self.get_config("bridge_notices.default")
@@ -378,9 +384,9 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
elif content.msgtype == MessageType.LOCATION:
await self._handle_matrix_location(sender_id, event_id, space, client, content,
reply_to)
elif content.msgtype in (MessageType.STICKER, MessageType.IMAGE, MessageType.FILE,
MessageType.AUDIO, MessageType.VIDEO):
await self._handle_matrix_file(sender_id, event_id, space, client, content, reply_to)
elif content.msgtype in media:
await self._handle_matrix_file(sender_id, event_id, space, client, content, reply_to,
caption_content)
else:
self.log.debug(f"Unhandled Matrix event: {content}")
+3 -1
View File
@@ -181,9 +181,11 @@ class PortalTelegram(BasePortal, ABC):
self.log.debug(f"Unsupported thumbnail type {type(thumb_size)}")
thumb_loc = None
thumb_size = None
parallel_id = source.tgid if config["bridge.parallel_file_transfer"] else None
file = await util.transfer_file_to_matrix(source.client, intent, document, thumb_loc,
is_sticker=attrs.is_sticker,
tgs_convert=config["bridge.animated_sticker"])
tgs_convert=config["bridge.animated_sticker"],
filename=attrs.name, parallel_id=parallel_id)
if not file:
return None
+1 -1
View File
@@ -219,7 +219,7 @@ class User(AbstractUser, BaseUser):
else:
portal = po.Portal.get_by_entity(message.to_id, receiver_id=self.tgid)
elif isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
else:
+1
View File
@@ -1,4 +1,5 @@
from .file_transfer import transfer_file_to_matrix, convert_image
from .parallel_file_transfer import parallel_transfer_to_telegram
from .format_duration import format_duration
from .recursive_dict import recursive_del, recursive_set, recursive_get
from .color_log import ColorFormatter
+44 -34
View File
@@ -34,6 +34,7 @@ from mautrix.appservice import IntentAPI
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile
from ..util import sane_mimetypes
from .parallel_file_transfer import parallel_transfer_to_matrix
try:
from PIL import Image
@@ -129,7 +130,7 @@ async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: In
return db_file
video_ext = sane_mimetypes.guess_extension(mime)
if VideoFileClip and video_ext:
if VideoFileClip and video_ext and video:
try:
file, width, height = _read_video_thumbnail(video, video_ext, frame_ext="png")
except OSError:
@@ -161,7 +162,8 @@ TypeThumbnail = Optional[Union[TypeLocation, TypePhotoSize]]
async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
location: TypeLocation, thumbnail: TypeThumbnail = None,
is_sticker: bool = False, tgs_convert: Optional[dict] = None
is_sticker: bool = False, tgs_convert: Optional[dict] = None,
filename: Optional[str] = None, parallel_id: Optional[int] = None
) -> Optional[DBTelegramFile]:
location_id = _location_to_id(location)
if not location_id:
@@ -178,53 +180,61 @@ async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentA
transfer_locks[location_id] = lock
async with lock:
return await _unlocked_transfer_file_to_matrix(client, intent, location_id, location,
thumbnail, is_sticker, tgs_convert)
thumbnail, is_sticker, tgs_convert,
filename, parallel_id)
async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
loc_id: str, location: TypeLocation,
thumbnail: TypeThumbnail, is_sticker: bool,
tgs_convert: Optional[dict]
tgs_convert: Optional[dict], filename: Optional[str],
parallel_id: Optional[int]
) -> Optional[DBTelegramFile]:
db_file = DBTelegramFile.get(loc_id)
if db_file:
return db_file
try:
file = await client.download_file(location)
except (LocationInvalidError, FileIdInvalidError):
return None
except (AuthBytesInvalidError, AuthKeyInvalidError, SecurityError) as e:
log.exception(f"{e.__class__.__name__} while downloading a file.")
return None
if parallel_id and isinstance(location, Document) and (not is_sticker or not tgs_convert):
db_file = await parallel_transfer_to_matrix(client, intent, loc_id, location, filename,
parallel_id)
mime_type = location.mime_type
file = None
else:
try:
file = await client.download_file(location)
except (LocationInvalidError, FileIdInvalidError):
return None
except (AuthBytesInvalidError, AuthKeyInvalidError, SecurityError) as e:
log.exception(f"{e.__class__.__name__} while downloading a file.")
return None
width, height = None, None
mime_type = magic.from_buffer(file, mime=True)
width, height = None, None
mime_type = magic.from_buffer(file, mime=True)
image_converted = False
# A weird bug in alpine/magic makes it return application/octet-stream for gzips...
if is_sticker and tgs_convert and (mime_type == "application/gzip" or (
mime_type == "application/octet-stream"
and magic.from_buffer(file).startswith("gzip"))):
mime_type, file, width, height = await convert_tgs_to(
file, tgs_convert["target"], **tgs_convert["args"])
thumbnail = None
image_converted = mime_type != "application/gzip"
image_converted = False
# A weird bug in alpine/magic makes it return application/octet-stream for gzips...
if is_sticker and tgs_convert and (mime_type == "application/gzip" or (
mime_type == "application/octet-stream"
and magic.from_buffer(file).startswith("gzip"))):
mime_type, file, width, height = await convert_tgs_to(
file, tgs_convert["target"], **tgs_convert["args"])
thumbnail = None
image_converted = mime_type != "application/gzip"
if mime_type == "image/webp":
new_mime_type, file, width, height = convert_image(
file, source_mime="image/webp", target_type="png",
thumbnail_to=(256, 256) if is_sticker else None)
image_converted = new_mime_type != mime_type
mime_type = new_mime_type
thumbnail = None
if mime_type == "image/webp":
new_mime_type, file, width, height = convert_image(
file, source_mime="image/webp", target_type="png",
thumbnail_to=(256, 256) if is_sticker else None)
image_converted = new_mime_type != mime_type
mime_type = new_mime_type
thumbnail = None
content_uri = await intent.upload_media(file, mime_type)
content_uri = await intent.upload_media(file, mime_type)
db_file = DBTelegramFile(id=loc_id, mxc=content_uri,
mime_type=mime_type, was_converted=image_converted,
timestamp=int(time.time()), size=len(file),
width=width, height=height)
db_file = DBTelegramFile(id=loc_id, mxc=content_uri,
mime_type=mime_type, was_converted=image_converted,
timestamp=int(time.time()), size=len(file),
width=width, height=height)
if thumbnail and (mime_type.startswith("video/") or mime_type == "image/gif"):
if isinstance(thumbnail, (PhotoSize, PhotoCachedSize)):
thumbnail = thumbnail.location
@@ -0,0 +1,298 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple
from collections import defaultdict
import hashlib
import asyncio
import logging
import time
import math
from aiohttp import ClientResponse
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
InputFileBig, InputFile)
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
SaveBigFilePartRequest)
from telethon.network import MTProtoSender
from telethon.crypto import AuthKey
from telethon import utils, helpers
from mautrix.appservice import IntentAPI
from mautrix.types import ContentURI
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile
log: logging.Logger = logging.getLogger("mau.util")
TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation,
InputFileLocation, InputPhotoFileLocation]
class DownloadSender:
sender: MTProtoSender
request: GetFileRequest
remaining: int
stride: int
def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
stride: int, count: int) -> None:
self.sender = sender
self.request = GetFileRequest(file, offset=offset, limit=limit)
self.stride = stride
self.remaining = count
async def next(self) -> Optional[bytes]:
if not self.remaining:
return None
result = await self.sender.send(self.request)
self.remaining -= 1
self.request.offset += self.stride
return result.bytes
def disconnect(self) -> Awaitable[None]:
return self.sender.disconnect()
class UploadSender:
sender: MTProtoSender
request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
part_count: int
stride: int
previous: Optional[asyncio.Task]
loop: asyncio.AbstractEventLoop
def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int,
stride: int, loop: asyncio.AbstractEventLoop) -> None:
self.sender = sender
self.part_count = part_count
if big:
self.request = SaveBigFilePartRequest(file_id, index, part_count, b"")
else:
self.request = SaveFilePartRequest(file_id, index, b"")
self.stride = stride
self.previous = None
self.loop = loop
async def next(self, data: bytes) -> None:
if self.previous:
await self.previous
self.previous = self.loop.create_task(self._next(data))
async def _next(self, data: bytes) -> None:
self.request.bytes = data
log.debug(f"Sending file part {self.request.file_part}/{self.part_count}"
f" with {len(data)} bytes")
await self.sender.send(self.request)
self.request.file_part += self.stride
async def disconnect(self) -> None:
if self.previous:
await self.previous
return await self.sender.disconnect()
class ParallelTransferrer:
client: MautrixTelegramClient
loop: asyncio.AbstractEventLoop
dc_id: int
senders: Optional[List[Union[DownloadSender, UploadSender]]]
auth_key: AuthKey
upload_ticker: int
def __init__(self, client: MautrixTelegramClient, dc_id: Optional[int] = None) -> None:
self.client = client
self.loop = self.client.loop
self.dc_id = dc_id or self.client.session.dc_id
self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id
else self.client.session.auth_key)
self.senders = None
self.upload_ticker = 0
async def _cleanup(self) -> None:
await asyncio.gather(*[sender.disconnect() for sender in self.senders])
self.senders = None
@staticmethod
def _get_connection_count(file_size: int, max_count: int = 20,
full_size: int = 100 * 1024 * 1024) -> int:
if file_size > full_size:
return max_count
return math.ceil((file_size / full_size) * max_count)
async def _init_download(self, connections: int, file: TypeLocation, part_count: int,
part_size: int) -> None:
minimum, remainder = divmod(part_count, connections)
def get_part_count() -> int:
nonlocal remainder
if remainder > 0:
remainder -= 1
return minimum + 1
return minimum
# The first cross-DC sender will export+import the authorization, so we always create it
# before creating any other senders.
self.senders = [
await self._create_download_sender(file, 0, part_size, connections * part_size,
get_part_count()),
*await asyncio.gather(
*[self._create_download_sender(file, i, part_size, connections * part_size,
get_part_count())
for i in range(1, connections)])
]
async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int,
stride: int,
part_count: int) -> DownloadSender:
return DownloadSender(await self._create_sender(), file, index * part_size, part_size,
stride, part_count)
async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool
) -> None:
self.senders = [
await self._create_upload_sender(file_id, part_count, big, 0, connections),
*await asyncio.gather(
*[self._create_upload_sender(file_id, part_count, big, i, connections)
for i in range(1, connections)])
]
async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int,
stride: int) -> UploadSender:
return UploadSender(await self._create_sender(), file_id, part_count, big, index, stride,
loop=self.loop)
async def _create_sender(self) -> MTProtoSender:
dc = await self.client._get_dc(self.dc_id)
sender = MTProtoSender(self.auth_key, self.loop, loggers=self.client._log)
await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
loop=self.loop, loggers=self.client._log,
proxy=self.client._proxy))
if not self.auth_key:
log.debug(f"Exporting auth to DC {self.dc_id}")
auth = await self.client(ExportAuthorizationRequest(self.dc_id))
req = self.client._init_with(ImportAuthorizationRequest(
id=auth.id, bytes=auth.bytes
))
await sender.send(req)
self.auth_key = sender.auth_key
return sender
async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None,
connection_count: Optional[int] = None) -> Tuple[int, int, bool]:
connection_count = connection_count or self._get_connection_count(file_size)
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
part_count = (file_size + part_size - 1) // part_size
is_large = file_size > 10 * 1024 * 1024
await self._init_upload(connection_count, file_id, part_count, is_large)
return part_size, part_count, is_large
async def upload(self, part: bytes) -> None:
await self.senders[self.upload_ticker].next(part)
self.upload_ticker = (self.upload_ticker + 1) % len(self.senders)
async def finish_upload(self) -> None:
await self._cleanup()
async def download(self, file: TypeLocation, file_size: int,
part_size_kb: Optional[float] = None,
connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]:
connection_count = connection_count or self._get_connection_count(file_size)
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
part_count = math.ceil(file_size / part_size)
log.debug("Starting parallel download: "
f"{connection_count} {part_size} {part_count} {file!s}")
await self._init_download(connection_count, file, part_count, part_size)
part = 0
while part < part_count:
tasks = []
for sender in self.senders:
tasks.append(self.loop.create_task(sender.next()))
for task in tasks:
data = await task
if not data:
break
yield data
part += 1
log.debug(f"Part {part} downloaded")
log.debug("Parallel download finished, cleaning up connections")
await self._cleanup()
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
loc_id: str, location: TypeLocation, filename: str,
parallel_id: int) -> DBTelegramFile:
size = location.size
mime_type = location.mime_type
dc_id, location = utils.get_input_location(location)
# We lock the transfers because telegram has connection count limits
async with parallel_transfer_locks[parallel_id]:
downloader = ParallelTransferrer(client, dc_id)
content_uri = await intent.upload_media(downloader.download(location, size),
mime_type=mime_type, filename=filename, size=size)
return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
was_converted=False, timestamp=int(time.time()), size=size,
width=None, height=None)
async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response: ClientResponse
) -> Tuple[TypeInputFile, int]:
file_id = helpers.generate_random_long()
file_size = response.content_length
hash_md5 = hashlib.md5()
uploader = ParallelTransferrer(client)
part_size, part_count, is_large = await uploader.init_upload(file_id, file_size)
buffer = bytearray()
async for data in response.content:
if not is_large:
hash_md5.update(data)
if len(buffer) == 0 and len(data) == part_size:
await uploader.upload(data)
continue
new_len = len(buffer) + len(data)
if new_len >= part_size:
cutoff = part_size - len(buffer)
buffer.extend(data[:cutoff])
await uploader.upload(bytes(buffer))
buffer.clear()
buffer.extend(data[cutoff:])
else:
buffer.extend(data)
if len(buffer) > 0:
await uploader.upload(bytes(buffer))
await uploader.finish_upload()
if is_large:
return InputFileBig(file_id, part_count, "upload"), file_size
else:
return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size
async def parallel_transfer_to_telegram(client: MautrixTelegramClient, intent: IntentAPI,
uri: ContentURI, parallel_id: int
) -> Tuple[TypeInputFile, int]:
url = intent.api.get_download_url(uri)
async with parallel_transfer_locks[parallel_id]:
async with intent.api.session.get(url) as response:
return await _internal_transfer_to_telegram(client, response)
+1 -1
View File
@@ -32,7 +32,7 @@ setuptools.setup(
install_requires=[
"aiohttp>=3.0.1,<4",
"mautrix>=0.4.0.dev71,<0.5",
"mautrix>=0.4.0.dev75,<0.5",
"SQLAlchemy>=1.2.3,<2",
"alembic>=1.0.0,<2",
"commonmark>=0.8.1,<0.10",