Add user+portal-specific lock for sending/receiving messages of authenticated users. Fixes #108

This commit is contained in:
Tulir Asokan
2018-04-28 21:20:21 +03:00
parent e231c3ec9a
commit 780edd7e57
2 changed files with 72 additions and 38 deletions
+66 -26
View File
@@ -76,6 +76,8 @@ class Portal:
self._dedup_mxid = {}
self._dedup_action = deque()
self._send_locks = {}
if tgid:
self.by_tgid[self.tgid_full] = self
if mxid:
@@ -634,11 +636,32 @@ class Portal:
message, entities = None, None
return message, entities
async def _handle_matrix_text(self, client, message, reply_to):
message, entities = await self._matrix_event_to_entities(client, message)
return await client.send_message(self.peer, message, entities=entities, reply_to=reply_to)
def require_send_lock(self, id):
if id is None:
return None
try:
return self._send_locks[id]
except KeyError:
self._send_locks[id] = asyncio.Lock()
return self._send_locks[id]
async def _handle_matrix_file(self, client, message, reply_to):
def optional_send_lock(self, id):
if id is None:
return None
try:
return self._send_locks[id]
except KeyError:
return None
async def _handle_matrix_text(self, sender_id, event_id, space, client, message, reply_to):
message, entities = await self._matrix_event_to_entities(client, message)
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_message(self.peer, message, entities=entities, reply_to=reply_to)
self._add_telegram_message_to_db(event_id, space, response)
async def _handle_matrix_file(self, sender_id, event_id, space, client, message, reply_to):
file = await self.main_intent.download_file(message["url"])
info = message["info"]
@@ -651,11 +674,14 @@ class Portal:
attributes.append(DocumentAttributeImageSize(w=info["w"], h=info["h"]))
caption = message["body"] if message["body"] != file_name else None
return await client.send_file(self.peer, file, mime, caption=caption,
attributes=attributes, file_name=file_name,
reply_to=reply_to)
async def _handle_matrix_location(self, client, message, reply_to):
media = await client.upload_file(file, mime, attributes, file_name)
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_media(self.peer, media, reply_to=reply_to, caption=caption)
self._add_telegram_message_to_db(event_id, space, response)
async def _handle_matrix_location(self, sender_id, event_id, space, client, message, reply_to):
try:
lat, long = message["geo_uri"][len("geo:"):].split(",")
lat, long = float(lat), float(long)
@@ -664,11 +690,26 @@ class Portal:
return None
message, entities = await self._matrix_event_to_entities(client, message)
media = MessageMediaGeo(geo=GeoPoint(lat, long))
return await client.send_media(self.peer, media, reply_to=reply_to, caption=message,
entities=entities)
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_media(self.peer, media, reply_to=reply_to, caption=message,
entities=entities)
self._add_telegram_message_to_db(event_id, space, response)
def _add_telegram_message_to_db(self, event_id, space, response):
self.log.debug("Handled Matrix message: %s", response)
self.is_duplicate(response, (event_id, space))
self.db.add(DBMessage(
tgid=response.id,
tg_space=space,
mx_room=self.mxid,
mxid=event_id))
self.db.commit()
async def handle_matrix_message(self, sender, message, event_id):
client = sender.client if sender.logged_in else self.bot.client
sender_id = sender.tgid if sender.logged_in else self.bot.tgid
space = (self.tgid if self.peer_type == "channel" # Channels have their own ID space
else (sender.tgid if sender.logged_in else self.bot.tgid))
reply_to = formatter.matrix_reply_to_telegram(message, space, room_id=self.mxid)
@@ -678,26 +719,13 @@ class Portal:
type = message["msgtype"]
if type == "m.text" or (self.bridge_notices and type == "m.notice"):
response = await self._handle_matrix_text(client, message, reply_to)
await self._handle_matrix_text(sender_id, event_id, space, client, message, reply_to)
elif type == "m.location":
response = await self._handle_matrix_location(client, message, reply_to)
await self._handle_matrix_location(sender_id, event_id, space, client, message, reply_to)
elif type in ("m.image", "m.file", "m.audio", "m.video"):
response = await self._handle_matrix_file(client, message, reply_to)
await self._handle_matrix_file(sender_id, event_id, space, client, message, reply_to)
else:
self.log.debug("Unhandled Matrix event: %s", message)
response = None
if not response:
return
self.log.debug("Handled Matrix message: %s", response)
self.is_duplicate(response, (event_id, space))
self.db.add(DBMessage(
tgid=response.id,
tg_space=space,
mx_room=self.mxid,
mxid=event_id))
self.db.commit()
async def handle_matrix_pin(self, sender, pinned_message):
if self.peer_type != "channel":
@@ -1073,7 +1101,13 @@ class Portal:
self.log.debug("Edits as replies disabled, ignoring edit event...")
return
lock = self.optional_send_lock(sender.tgid if sender else None)
if lock:
async with lock:
pass
tg_space = self.tgid if self.peer_type == "channel" else source.tgid
temporary_identifier = f"${random.randint(1000000000000,9999999999999)}TGBRIDGEDITEMP"
duplicate_found = self.is_duplicate(evt, (temporary_identifier, tg_space), force_hash=True)
if duplicate_found:
@@ -1111,6 +1145,11 @@ class Portal:
if not self.mxid:
await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=False)
lock = self.optional_send_lock(sender.tgid if sender else None)
if lock:
async with lock:
pass
tg_space = self.tgid if self.peer_type == "channel" else source.tgid
temporary_identifier = f"${random.randint(1000000000000,9999999999999)}TGBRIDGETEMP"
@@ -1122,6 +1161,7 @@ class Portal:
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space))
self.db.commit()
return
allowed_media = (MessageMediaPhoto, MessageMediaDocument, MessageMediaGeo)
media = evt.media if hasattr(evt, "media") and isinstance(evt.media,
allowed_media) else None
+6 -12
View File
@@ -51,29 +51,23 @@ class MautrixTelegramClient(TelegramClient):
return self._get_response_message(request, result)
async def send_file(self, entity, file, mime_type=None, caption=None, entities=None,
attributes=None, file_name=None, reply_to=None, **kwargs):
entity = await self.get_input_entity(entity)
reply_to = self._get_message_id(reply_to)
file_handle = await self.upload_file(file, file_name=file_name, use_cache=False)
async def upload_file(self, file, mime_type=None, attributes=None, file_name=None):
file_handle = await super().upload_file(file, file_name=file_name, use_cache=False)
if mime_type == "image/png" or mime_type == "image/jpeg":
media = InputMediaUploadedPhoto(file_handle)
return InputMediaUploadedPhoto(file_handle)
else:
attributes = attributes or []
attr_dict = {type(attr): attr for attr in attributes}
media = InputMediaUploadedDocument(
return InputMediaUploadedDocument(
file=file_handle,
mime_type=mime_type or "application/octet-stream",
attributes=list(attr_dict.values()))
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [],
reply_to_msg_id=reply_to)
return self._get_response_message(request, await self(request))
async def send_media(self, entity, media, caption=None, entities=None, reply_to=None):
entity = await self.get_input_entity(entity)
reply_to = self._get_message_id(reply_to)
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [],
reply_to_msg_id=reply_to)
return self._get_response_message(request, await self(request))