Add more type hints
This commit is contained in:
@@ -14,15 +14,25 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Tuple, Union, Dict
|
||||
from io import BytesIO
|
||||
import time
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
import magic
|
||||
from sqlalchemy import orm
|
||||
from sqlalchemy.exc import IntegrityError, InvalidRequestError
|
||||
from sqlalchemy.orm.exc import FlushError
|
||||
|
||||
from telethon.tl.types import (Document, FileLocation, InputFileLocation,
|
||||
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
|
||||
from telethon.errors import *
|
||||
from mautrix_appservice import IntentAPI
|
||||
|
||||
from ..tgclient import MautrixTelegramClient
|
||||
from ..db import TelegramFile as DBTelegramFile
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
@@ -36,20 +46,18 @@ try:
|
||||
except ImportError:
|
||||
VideoFileClip = random = string = os = mimetypes = None
|
||||
|
||||
from telethon.tl.types import (Document, FileLocation, InputFileLocation,
|
||||
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
|
||||
from telethon.errors import *
|
||||
log = logging.getLogger("mau.util") # type: logging.Logger
|
||||
|
||||
from ..db import TelegramFile as DBTelegramFile
|
||||
|
||||
log = logging.getLogger("mau.util")
|
||||
TypeLocation = Union[Document, InputDocumentFileLocation, FileLocation, InputFileLocation]
|
||||
|
||||
|
||||
def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=None):
|
||||
def convert_image(file: bytes, source_mime: str = "image/webp", target_type: str = "png",
|
||||
thumbnail_to: Optional[Tuple[int, int]] = None
|
||||
) -> Tuple[str, bytes, Optional[int], Optional[int]]:
|
||||
if not Image:
|
||||
return source_mime, file, None, None
|
||||
try:
|
||||
image = Image.open(BytesIO(file)).convert("RGBA")
|
||||
image = Image.open(BytesIO(file)).convert("RGBA") # type: Image.Image
|
||||
if thumbnail_to:
|
||||
image.thumbnail(thumbnail_to, Image.ANTIALIAS)
|
||||
new_file = BytesIO()
|
||||
@@ -61,13 +69,14 @@ def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_t
|
||||
return source_mime, file, None, None
|
||||
|
||||
|
||||
def _temp_file_name(ext):
|
||||
def _temp_file_name(ext: str) -> str:
|
||||
return ("/tmp/mxtg-video-"
|
||||
+ "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
|
||||
+ ext)
|
||||
|
||||
|
||||
def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024, 720)):
|
||||
def _read_video_thumbnail(data: bytes, video_ext: str = "mp4", frame_ext: str = "png",
|
||||
max_size: Tuple[int, int] = (1024, 720)) -> Tuple[bytes, int, int]:
|
||||
# We don't have any way to read the video from memory, so save it to disk.
|
||||
temp_file = _temp_file_name(video_ext)
|
||||
with open(temp_file, "wb") as file:
|
||||
@@ -90,21 +99,21 @@ def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024
|
||||
return thumbnail_file.getvalue(), w, h
|
||||
|
||||
|
||||
def _location_to_id(location):
|
||||
def _location_to_id(location: TypeLocation) -> str:
|
||||
if isinstance(location, (Document, InputDocumentFileLocation)):
|
||||
return f"{location.id}-{location.version}"
|
||||
elif isinstance(location, (FileLocation, InputFileLocation)):
|
||||
return f"{location.volume_id}-{location.local_id}"
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mime):
|
||||
async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
|
||||
thumbnail_loc: TypeLocation, video: bytes,
|
||||
mime: str) -> Optional[DBTelegramFile]:
|
||||
if not Image or not VideoFileClip:
|
||||
return None
|
||||
|
||||
id = _location_to_id(thumbnail_loc)
|
||||
if not id:
|
||||
loc_id = _location_to_id(thumbnail_loc)
|
||||
if not loc_id:
|
||||
return None
|
||||
|
||||
video_ext = mimetypes.guess_extension(mime)
|
||||
@@ -121,36 +130,40 @@ async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mim
|
||||
|
||||
content_uri = await intent.upload_file(file, mime_type)
|
||||
|
||||
return DBTelegramFile(id=id, mxc=content_uri, mime_type=mime_type,
|
||||
return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
|
||||
was_converted=False, timestamp=int(time.time()), size=len(file),
|
||||
width=width, height=height)
|
||||
|
||||
|
||||
transfer_locks = {}
|
||||
transfer_locks_lock = asyncio.Lock()
|
||||
transfer_locks = {} # type: Dict[str, asyncio.Lock]
|
||||
|
||||
|
||||
async def transfer_file_to_matrix(db, client, intent, location, thumbnail=None, is_sticker=False):
|
||||
id = _location_to_id(location)
|
||||
if not id:
|
||||
async def transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient, intent: IntentAPI,
|
||||
location: TypeLocation, thumbnail: Optional[TypeLocation] = None,
|
||||
is_sticker: bool = False) -> Optional[DBTelegramFile]:
|
||||
location_id = _location_to_id(location)
|
||||
if not location_id:
|
||||
return None
|
||||
|
||||
db_file = DBTelegramFile.query.get(id)
|
||||
db_file = DBTelegramFile.query.get(location_id)
|
||||
if db_file:
|
||||
return db_file
|
||||
|
||||
async with transfer_locks_lock:
|
||||
try:
|
||||
lock = transfer_locks[id]
|
||||
except KeyError:
|
||||
lock = asyncio.Lock()
|
||||
transfer_locks[id] = lock
|
||||
try:
|
||||
lock = transfer_locks[location_id]
|
||||
except KeyError:
|
||||
lock = asyncio.Lock()
|
||||
transfer_locks[location_id] = lock
|
||||
async with lock:
|
||||
return await _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker)
|
||||
return await _unlocked_transfer_file_to_matrix(db, client, intent, location_id, location,
|
||||
thumbnail, is_sticker)
|
||||
|
||||
|
||||
async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker):
|
||||
db_file = DBTelegramFile.query.get(id)
|
||||
async def _unlocked_transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient,
|
||||
intent: IntentAPI, loc_id: str, location: TypeLocation,
|
||||
thumbnail: Optional[TypeLocation],
|
||||
is_sticker: bool) -> Optional[DBTelegramFile]:
|
||||
db_file = DBTelegramFile.query.get(loc_id)
|
||||
if db_file:
|
||||
return db_file
|
||||
|
||||
@@ -167,15 +180,16 @@ async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, th
|
||||
|
||||
image_converted = False
|
||||
if mime_type == "image/webp":
|
||||
new_mime_type, file, width, height = convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=(
|
||||
256, 256) if is_sticker else None)
|
||||
new_mime_type, file, width, height = convert_image(
|
||||
file, source_mime="image/webp", target_type="png",
|
||||
thumbnail_to=(256, 256) if is_sticker else None)
|
||||
image_converted = new_mime_type != mime_type
|
||||
mime_type = new_mime_type
|
||||
thumbnail = None
|
||||
|
||||
content_uri = await intent.upload_file(file, mime_type)
|
||||
|
||||
db_file = DBTelegramFile(id=id, mxc=content_uri,
|
||||
db_file = DBTelegramFile(id=loc_id, mxc=content_uri,
|
||||
mime_type=mime_type, was_converted=image_converted,
|
||||
timestamp=int(time.time()), size=len(file),
|
||||
width=width, height=height)
|
||||
|
||||
Reference in New Issue
Block a user