diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py
index dba567ec..cd363041 100644
--- a/mautrix_telegram/abstract_user.py
+++ b/mautrix_telegram/abstract_user.py
@@ -25,8 +25,8 @@ from telethon.network import (ConnectionTcpMTProxyRandomizedIntermediate, Connec
Connection)
from telethon.tl.patched import MessageService, Message
from telethon.tl.types import (
- Channel, Chat, MessageActionChannelMigrateFrom, PeerUser, TypeUpdate, UpdateChatPinnedMessage,
- UpdateChannelPinnedMessage, UpdateChatParticipantAdmin, UpdateChatParticipants, PeerChat,
+ Channel, Chat, MessageActionChannelMigrateFrom, PeerUser, TypeUpdate, UpdatePinnedMessages,
+ UpdatePinnedChannelMessages, UpdateChatParticipantAdmin, UpdateChatParticipants, PeerChat,
UpdateChatUserTyping, UpdateDeleteChannelMessages, UpdateNewMessage, UpdateDeleteMessages,
UpdateEditChannelMessage, UpdateEditMessage, UpdateNewChannelMessage, UpdateReadHistoryOutbox,
UpdateShortChatMessage, UpdateShortMessage, UpdateUserName, UpdateUserPhoto, UpdateUserStatus,
@@ -252,7 +252,7 @@ class AbstractUser(ABC):
await self.update_admin(update)
elif isinstance(update, UpdateChatParticipants):
await self.update_participants(update)
- elif isinstance(update, (UpdateChannelPinnedMessage, UpdateChatPinnedMessage)):
+ elif isinstance(update, (UpdatePinnedMessages, UpdatePinnedChannelMessages)):
await self.update_pinned_messages(update)
elif isinstance(update, (UpdateUserName, UpdateUserPhoto)):
await self.update_others_info(update)
@@ -263,14 +263,15 @@ class AbstractUser(ABC):
else:
self.log.trace("Unhandled update: %s", update)
- async def update_pinned_messages(self, update: Union[UpdateChannelPinnedMessage,
- UpdateChatPinnedMessage]) -> None:
- if isinstance(update, UpdateChatPinnedMessage):
- portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
+ async def update_pinned_messages(self, update: Union[UpdatePinnedMessages,
+ UpdatePinnedChannelMessages]) -> None:
+ if isinstance(update, UpdatePinnedMessages):
+ portal = po.Portal.get_by_entity(update.peer, receiver_id=self.tgid)
else:
portal = po.Portal.get_by_tgid(TelegramID(update.channel_id))
if portal and portal.mxid:
- await portal.receive_telegram_pin_id(update.id, self.tgid)
+ await portal.receive_telegram_pin_ids(update.messages, self.tgid,
+ remove=not update.pinned)
@staticmethod
async def update_participants(update: UpdateChatParticipants) -> None:
diff --git a/mautrix_telegram/db/message.py b/mautrix_telegram/db/message.py
index 8d6f5bd5..e39f2a72 100644
--- a/mautrix_telegram/db/message.py
+++ b/mautrix_telegram/db/message.py
@@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Optional, Iterator
+from typing import Optional, Iterator, List
from sqlalchemy import Column, UniqueConstraint, Integer, String, and_, func, desc, select
@@ -51,6 +51,12 @@ class Message(Base):
return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_space == tg_space,
cls.c.edit_index == edit_index)
+ @classmethod
+ def get_first_by_tgids(cls, tgids: List[TelegramID], tg_space: TelegramID
+ ) -> Iterator['Message']:
+ return cls._select_all(cls.c.tgid.in_(tgids), cls.c.tg_space == tg_space,
+ cls.c.edit_index == 0)
+
@classmethod
def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int:
rows = cls.db.execute(select([func.count(cls.c.tg_space)])
@@ -77,6 +83,12 @@ class Message(Base):
return cls._select_one_or_none(cls.c.mxid == mxid, cls.c.mx_room == mx_room,
cls.c.tg_space == tg_space)
+ @classmethod
+ def get_by_mxids(cls, mxids: List[EventID], mx_room: RoomID, tg_space: TelegramID
+ ) -> Iterator['Message']:
+ return cls._select_all(cls.c.mxid.in_(mxids), cls.c.mx_room == mx_room,
+ cls.c.tg_space == tg_space)
+
@classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, s_edit_index: int,
**values) -> None:
diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py
index 062cbe00..379e559e 100644
--- a/mautrix_telegram/matrix.py
+++ b/mautrix_telegram/matrix.py
@@ -283,13 +283,12 @@ class MatrixHandler(BaseMatrixHandler):
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
- events = new_events - old_events
- if len(events) > 0:
- # New event pinned, set that as pinned in Telegram.
- await portal.handle_matrix_pin(sender, EventID(events.pop()), event_id)
- elif len(new_events) == 0:
- # All pinned events removed, remove pinned event in Telegram.
- await portal.handle_matrix_pin(sender, None, event_id)
+ if not new_events:
+ await portal.handle_matrix_unpin_all(sender)
+ else:
+ changes = {event_id: event_id in new_events
+ for event_id in new_events ^ old_events}
+ await portal.handle_matrix_pin(sender, changes, event_id)
@staticmethod
async def handle_room_upgrade(room_id: RoomID, sender: UserID, new_room_id: RoomID,
diff --git a/mautrix_telegram/portal/base.py b/mautrix_telegram/portal/base.py
index 3f3d5106..a0ba7962 100644
--- a/mautrix_telegram/portal/base.py
+++ b/mautrix_telegram/portal/base.py
@@ -104,6 +104,7 @@ class BasePortal(MautrixBasePortal, ABC):
dedup: PortalDedup
send_lock: PortalSendLock
+ _pin_lock: asyncio.Lock
_db_instance: DBPortal
_main_intent: Optional[IntentAPI]
@@ -138,6 +139,7 @@ class BasePortal(MautrixBasePortal, ABC):
self.dedup = PortalDedup(self)
self.send_lock = PortalSendLock()
+ self._pin_lock = asyncio.Lock()
if tgid:
self.by_tgid[self.tgid_full] = self
diff --git a/mautrix_telegram/portal/matrix.py b/mautrix_telegram/portal/matrix.py
index e6adac82..3d601f25 100644
--- a/mautrix_telegram/portal/matrix.py
+++ b/mautrix_telegram/portal/matrix.py
@@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHECKING
+from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, Set, TYPE_CHECKING
from html import escape as escape_html
from string import Template
from abc import ABC
@@ -22,11 +22,10 @@ import magic
from telethon.tl.functions.messages import (EditChatPhotoRequest, EditChatTitleRequest,
UpdatePinnedMessageRequest, SetTypingRequest,
- EditChatAboutRequest)
+ EditChatAboutRequest, UnpinAllMessagesRequest)
from telethon.tl.functions.channels import EditPhotoRequest, EditTitleRequest, JoinChannelRequest
-from telethon.errors import (ChatNotModifiedError, PhotoExtInvalidError,
- PhotoInvalidDimensionsError, PhotoSaveFileInvalidError,
- RPCError)
+from telethon.errors import (ChatNotModifiedError, PhotoExtInvalidError, MessageIdInvalidError,
+ PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, RPCError)
from telethon.tl.patched import Message, MessageService
from telethon.tl.types import (
DocumentAttributeFilename, DocumentAttributeImageSize, GeoPoint,
@@ -432,23 +431,23 @@ class PortalMatrix(BasePortal, ABC):
else:
self.log.trace("Unhandled Matrix event: %s", content)
- async def handle_matrix_pin(self, sender: 'u.User', pinned_message: Optional[EventID],
+ async def handle_matrix_unpin_all(self, sender: 'u.User', pin_event_id: EventID) -> None:
+ await sender.client(UnpinAllMessagesRequest(peer=self.peer))
+ await self._send_delivery_receipt(pin_event_id)
+
+ async def handle_matrix_pin(self, sender: 'u.User', changes: Dict[EventID, bool],
pin_event_id: EventID) -> None:
- if self.peer_type != "chat" and self.peer_type != "channel":
- return
- try:
- if not pinned_message:
- await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=0))
- else:
- tg_space = self.tgid if self.peer_type == "channel" else sender.tgid
- message = DBMessage.get_by_mxid(pinned_message, self.mxid, tg_space)
- if message is None:
- self.log.warning(f"Could not find pinned {pinned_message} in {self.mxid}")
- return
- await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=message.tgid))
- await self._send_delivery_receipt(pin_event_id)
- except ChatNotModifiedError:
- pass
+ tg_space = self.tgid if self.peer_type == "channel" else sender.tgid
+ ids = {msg.mxid: msg.tgid
+ for msg in DBMessage.get_by_mxids(list(changes.keys()),
+ mx_room=self.mxid, tg_space=tg_space)}
+ for event_id, pinned in changes.items():
+ try:
+ await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=ids[event_id],
+ unpin=not pinned))
+ except (ChatNotModifiedError, MessageIdInvalidError, KeyError):
+ pass
+ await self._send_delivery_receipt(pin_event_id)
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: EventID,
redaction_event_id: EventID) -> None:
diff --git a/mautrix_telegram/portal/telegram.py b/mautrix_telegram/portal/telegram.py
index f74eb52e..0e3cfb07 100644
--- a/mautrix_telegram/portal/telegram.py
+++ b/mautrix_telegram/portal/telegram.py
@@ -694,13 +694,20 @@ class PortalTelegram(BasePortal, ABC):
levels.users[puppet.mxid] = 50
await self.main_intent.set_power_levels(self.mxid, levels)
- async def receive_telegram_pin_id(self, msg_id: TelegramID, receiver: TelegramID) -> None:
- tg_space = receiver if self.peer_type != "channel" else self.tgid
- message = DBMessage.get_one_by_tgid(msg_id, tg_space) if msg_id != 0 else None
- if message:
- await self.main_intent.set_pinned_messages(self.mxid, [message.mxid])
- else:
- await self.main_intent.set_pinned_messages(self.mxid, [])
+ async def receive_telegram_pin_ids(self, msg_ids: List[TelegramID], receiver: TelegramID,
+ remove: bool) -> None:
+ async with self._pin_lock:
+ tg_space = receiver if self.peer_type != "channel" else self.tgid
+ previously_pinned = await self.main_intent.get_pinned_messages(self.mxid)
+ currently_pinned_dict = {event_id: True for event_id in previously_pinned}
+ for message in DBMessage.get_first_by_tgids(msg_ids, tg_space):
+ if remove:
+ currently_pinned_dict.pop(message.mxid, None)
+ else:
+ currently_pinned_dict[message.mxid] = True
+ currently_pinned = list(currently_pinned_dict.keys())
+ if currently_pinned != previously_pinned:
+ await self.main_intent.set_pinned_messages(self.mxid, currently_pinned)
async def set_telegram_admins_enabled(self, enabled: bool) -> None:
level = 50 if enabled else 10