diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index 86f4fe6c..8a64207c 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -76,6 +76,8 @@ class Portal: self._dedup_mxid = {} self._dedup_action = deque() + self._send_locks = {} + if tgid: self.by_tgid[self.tgid_full] = self if mxid: @@ -634,11 +636,32 @@ class Portal: message, entities = None, None return message, entities - async def _handle_matrix_text(self, client, message, reply_to): - message, entities = await self._matrix_event_to_entities(client, message) - return await client.send_message(self.peer, message, entities=entities, reply_to=reply_to) + def require_send_lock(self, id): + if id is None: + return None + try: + return self._send_locks[id] + except KeyError: + self._send_locks[id] = asyncio.Lock() + return self._send_locks[id] - async def _handle_matrix_file(self, client, message, reply_to): + def optional_send_lock(self, id): + if id is None: + return None + try: + return self._send_locks[id] + except KeyError: + return None + + async def _handle_matrix_text(self, sender_id, event_id, space, client, message, reply_to): + message, entities = await self._matrix_event_to_entities(client, message) + + lock = self.require_send_lock(sender_id) + async with lock: + response = await client.send_message(self.peer, message, entities=entities, reply_to=reply_to) + self._add_telegram_message_to_db(event_id, space, response) + + async def _handle_matrix_file(self, sender_id, event_id, space, client, message, reply_to): file = await self.main_intent.download_file(message["url"]) info = message["info"] @@ -651,11 +674,14 @@ class Portal: attributes.append(DocumentAttributeImageSize(w=info["w"], h=info["h"])) caption = message["body"] if message["body"] != file_name else None - return await client.send_file(self.peer, file, mime, caption=caption, - attributes=attributes, file_name=file_name, - reply_to=reply_to) - async def _handle_matrix_location(self, client, message, reply_to): + media = await client.upload_file(file, mime, attributes, file_name) + lock = self.require_send_lock(sender_id) + async with lock: + response = await client.send_media(self.peer, media, reply_to=reply_to, caption=caption) + self._add_telegram_message_to_db(event_id, space, response) + + async def _handle_matrix_location(self, sender_id, event_id, space, client, message, reply_to): try: lat, long = message["geo_uri"][len("geo:"):].split(",") lat, long = float(lat), float(long) @@ -664,11 +690,26 @@ class Portal: return None message, entities = await self._matrix_event_to_entities(client, message) media = MessageMediaGeo(geo=GeoPoint(lat, long)) - return await client.send_media(self.peer, media, reply_to=reply_to, caption=message, - entities=entities) + + lock = self.require_send_lock(sender_id) + async with lock: + response = await client.send_media(self.peer, media, reply_to=reply_to, caption=message, + entities=entities) + self._add_telegram_message_to_db(event_id, space, response) + + def _add_telegram_message_to_db(self, event_id, space, response): + self.log.debug("Handled Matrix message: %s", response) + self.is_duplicate(response, (event_id, space)) + self.db.add(DBMessage( + tgid=response.id, + tg_space=space, + mx_room=self.mxid, + mxid=event_id)) + self.db.commit() async def handle_matrix_message(self, sender, message, event_id): client = sender.client if sender.logged_in else self.bot.client + sender_id = sender.tgid if sender.logged_in else self.bot.tgid space = (self.tgid if self.peer_type == "channel" # Channels have their own ID space else (sender.tgid if sender.logged_in else self.bot.tgid)) reply_to = formatter.matrix_reply_to_telegram(message, space, room_id=self.mxid) @@ -678,26 +719,13 @@ class Portal: type = message["msgtype"] if type == "m.text" or (self.bridge_notices and type == "m.notice"): - response = await self._handle_matrix_text(client, message, reply_to) + await self._handle_matrix_text(sender_id, event_id, space, client, message, reply_to) elif type == "m.location": - response = await self._handle_matrix_location(client, message, reply_to) + await self._handle_matrix_location(sender_id, event_id, space, client, message, reply_to) elif type in ("m.image", "m.file", "m.audio", "m.video"): - response = await self._handle_matrix_file(client, message, reply_to) + await self._handle_matrix_file(sender_id, event_id, space, client, message, reply_to) else: self.log.debug("Unhandled Matrix event: %s", message) - response = None - - if not response: - return - - self.log.debug("Handled Matrix message: %s", response) - self.is_duplicate(response, (event_id, space)) - self.db.add(DBMessage( - tgid=response.id, - tg_space=space, - mx_room=self.mxid, - mxid=event_id)) - self.db.commit() async def handle_matrix_pin(self, sender, pinned_message): if self.peer_type != "channel": @@ -1073,7 +1101,13 @@ class Portal: self.log.debug("Edits as replies disabled, ignoring edit event...") return + lock = self.optional_send_lock(sender.tgid if sender else None) + if lock: + async with lock: + pass + tg_space = self.tgid if self.peer_type == "channel" else source.tgid + temporary_identifier = f"${random.randint(1000000000000,9999999999999)}TGBRIDGEDITEMP" duplicate_found = self.is_duplicate(evt, (temporary_identifier, tg_space), force_hash=True) if duplicate_found: @@ -1111,6 +1145,11 @@ class Portal: if not self.mxid: await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=False) + lock = self.optional_send_lock(sender.tgid if sender else None) + if lock: + async with lock: + pass + tg_space = self.tgid if self.peer_type == "channel" else source.tgid temporary_identifier = f"${random.randint(1000000000000,9999999999999)}TGBRIDGETEMP" @@ -1122,6 +1161,7 @@ class Portal: DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space)) self.db.commit() return + allowed_media = (MessageMediaPhoto, MessageMediaDocument, MessageMediaGeo) media = evt.media if hasattr(evt, "media") and isinstance(evt.media, allowed_media) else None diff --git a/mautrix_telegram/tgclient.py b/mautrix_telegram/tgclient.py index 8c1cc660..406bd559 100644 --- a/mautrix_telegram/tgclient.py +++ b/mautrix_telegram/tgclient.py @@ -51,29 +51,23 @@ class MautrixTelegramClient(TelegramClient): return self._get_response_message(request, result) - async def send_file(self, entity, file, mime_type=None, caption=None, entities=None, - attributes=None, file_name=None, reply_to=None, **kwargs): - entity = await self.get_input_entity(entity) - reply_to = self._get_message_id(reply_to) - - file_handle = await self.upload_file(file, file_name=file_name, use_cache=False) + async def upload_file(self, file, mime_type=None, attributes=None, file_name=None): + file_handle = await super().upload_file(file, file_name=file_name, use_cache=False) if mime_type == "image/png" or mime_type == "image/jpeg": - media = InputMediaUploadedPhoto(file_handle) + return InputMediaUploadedPhoto(file_handle) else: attributes = attributes or [] attr_dict = {type(attr): attr for attr in attributes} - media = InputMediaUploadedDocument( + return InputMediaUploadedDocument( file=file_handle, mime_type=mime_type or "application/octet-stream", attributes=list(attr_dict.values())) - request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [], - reply_to_msg_id=reply_to) - return self._get_response_message(request, await self(request)) - async def send_media(self, entity, media, caption=None, entities=None, reply_to=None): + entity = await self.get_input_entity(entity) + reply_to = self._get_message_id(reply_to) request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [], reply_to_msg_id=reply_to) return self._get_response_message(request, await self(request))