Add support for multiple pins
This commit is contained in:
@@ -25,8 +25,8 @@ from telethon.network import (ConnectionTcpMTProxyRandomizedIntermediate, Connec
|
||||
Connection)
|
||||
from telethon.tl.patched import MessageService, Message
|
||||
from telethon.tl.types import (
|
||||
Channel, Chat, MessageActionChannelMigrateFrom, PeerUser, TypeUpdate, UpdateChatPinnedMessage,
|
||||
UpdateChannelPinnedMessage, UpdateChatParticipantAdmin, UpdateChatParticipants, PeerChat,
|
||||
Channel, Chat, MessageActionChannelMigrateFrom, PeerUser, TypeUpdate, UpdatePinnedMessages,
|
||||
UpdatePinnedChannelMessages, UpdateChatParticipantAdmin, UpdateChatParticipants, PeerChat,
|
||||
UpdateChatUserTyping, UpdateDeleteChannelMessages, UpdateNewMessage, UpdateDeleteMessages,
|
||||
UpdateEditChannelMessage, UpdateEditMessage, UpdateNewChannelMessage, UpdateReadHistoryOutbox,
|
||||
UpdateShortChatMessage, UpdateShortMessage, UpdateUserName, UpdateUserPhoto, UpdateUserStatus,
|
||||
@@ -252,7 +252,7 @@ class AbstractUser(ABC):
|
||||
await self.update_admin(update)
|
||||
elif isinstance(update, UpdateChatParticipants):
|
||||
await self.update_participants(update)
|
||||
elif isinstance(update, (UpdateChannelPinnedMessage, UpdateChatPinnedMessage)):
|
||||
elif isinstance(update, (UpdatePinnedMessages, UpdatePinnedChannelMessages)):
|
||||
await self.update_pinned_messages(update)
|
||||
elif isinstance(update, (UpdateUserName, UpdateUserPhoto)):
|
||||
await self.update_others_info(update)
|
||||
@@ -263,14 +263,15 @@ class AbstractUser(ABC):
|
||||
else:
|
||||
self.log.trace("Unhandled update: %s", update)
|
||||
|
||||
async def update_pinned_messages(self, update: Union[UpdateChannelPinnedMessage,
|
||||
UpdateChatPinnedMessage]) -> None:
|
||||
if isinstance(update, UpdateChatPinnedMessage):
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
|
||||
async def update_pinned_messages(self, update: Union[UpdatePinnedMessages,
|
||||
UpdatePinnedChannelMessages]) -> None:
|
||||
if isinstance(update, UpdatePinnedMessages):
|
||||
portal = po.Portal.get_by_entity(update.peer, receiver_id=self.tgid)
|
||||
else:
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.channel_id))
|
||||
if portal and portal.mxid:
|
||||
await portal.receive_telegram_pin_id(update.id, self.tgid)
|
||||
await portal.receive_telegram_pin_ids(update.messages, self.tgid,
|
||||
remove=not update.pinned)
|
||||
|
||||
@staticmethod
|
||||
async def update_participants(update: UpdateChatParticipants) -> None:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Iterator
|
||||
from typing import Optional, Iterator, List
|
||||
|
||||
from sqlalchemy import Column, UniqueConstraint, Integer, String, and_, func, desc, select
|
||||
|
||||
@@ -51,6 +51,12 @@ class Message(Base):
|
||||
return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_space == tg_space,
|
||||
cls.c.edit_index == edit_index)
|
||||
|
||||
@classmethod
|
||||
def get_first_by_tgids(cls, tgids: List[TelegramID], tg_space: TelegramID
|
||||
) -> Iterator['Message']:
|
||||
return cls._select_all(cls.c.tgid.in_(tgids), cls.c.tg_space == tg_space,
|
||||
cls.c.edit_index == 0)
|
||||
|
||||
@classmethod
|
||||
def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int:
|
||||
rows = cls.db.execute(select([func.count(cls.c.tg_space)])
|
||||
@@ -77,6 +83,12 @@ class Message(Base):
|
||||
return cls._select_one_or_none(cls.c.mxid == mxid, cls.c.mx_room == mx_room,
|
||||
cls.c.tg_space == tg_space)
|
||||
|
||||
@classmethod
|
||||
def get_by_mxids(cls, mxids: List[EventID], mx_room: RoomID, tg_space: TelegramID
|
||||
) -> Iterator['Message']:
|
||||
return cls._select_all(cls.c.mxid.in_(mxids), cls.c.mx_room == mx_room,
|
||||
cls.c.tg_space == tg_space)
|
||||
|
||||
@classmethod
|
||||
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, s_edit_index: int,
|
||||
**values) -> None:
|
||||
|
||||
@@ -283,13 +283,12 @@ class MatrixHandler(BaseMatrixHandler):
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||
if await sender.has_full_access(allow_bot=True) and portal:
|
||||
events = new_events - old_events
|
||||
if len(events) > 0:
|
||||
# New event pinned, set that as pinned in Telegram.
|
||||
await portal.handle_matrix_pin(sender, EventID(events.pop()), event_id)
|
||||
elif len(new_events) == 0:
|
||||
# All pinned events removed, remove pinned event in Telegram.
|
||||
await portal.handle_matrix_pin(sender, None, event_id)
|
||||
if not new_events:
|
||||
await portal.handle_matrix_unpin_all(sender)
|
||||
else:
|
||||
changes = {event_id: event_id in new_events
|
||||
for event_id in new_events ^ old_events}
|
||||
await portal.handle_matrix_pin(sender, changes, event_id)
|
||||
|
||||
@staticmethod
|
||||
async def handle_room_upgrade(room_id: RoomID, sender: UserID, new_room_id: RoomID,
|
||||
|
||||
@@ -104,6 +104,7 @@ class BasePortal(MautrixBasePortal, ABC):
|
||||
|
||||
dedup: PortalDedup
|
||||
send_lock: PortalSendLock
|
||||
_pin_lock: asyncio.Lock
|
||||
|
||||
_db_instance: DBPortal
|
||||
_main_intent: Optional[IntentAPI]
|
||||
@@ -138,6 +139,7 @@ class BasePortal(MautrixBasePortal, ABC):
|
||||
|
||||
self.dedup = PortalDedup(self)
|
||||
self.send_lock = PortalSendLock()
|
||||
self._pin_lock = asyncio.Lock()
|
||||
|
||||
if tgid:
|
||||
self.by_tgid[self.tgid_full] = self
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHECKING
|
||||
from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, Set, TYPE_CHECKING
|
||||
from html import escape as escape_html
|
||||
from string import Template
|
||||
from abc import ABC
|
||||
@@ -22,11 +22,10 @@ import magic
|
||||
|
||||
from telethon.tl.functions.messages import (EditChatPhotoRequest, EditChatTitleRequest,
|
||||
UpdatePinnedMessageRequest, SetTypingRequest,
|
||||
EditChatAboutRequest)
|
||||
EditChatAboutRequest, UnpinAllMessagesRequest)
|
||||
from telethon.tl.functions.channels import EditPhotoRequest, EditTitleRequest, JoinChannelRequest
|
||||
from telethon.errors import (ChatNotModifiedError, PhotoExtInvalidError,
|
||||
PhotoInvalidDimensionsError, PhotoSaveFileInvalidError,
|
||||
RPCError)
|
||||
from telethon.errors import (ChatNotModifiedError, PhotoExtInvalidError, MessageIdInvalidError,
|
||||
PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, RPCError)
|
||||
from telethon.tl.patched import Message, MessageService
|
||||
from telethon.tl.types import (
|
||||
DocumentAttributeFilename, DocumentAttributeImageSize, GeoPoint,
|
||||
@@ -432,23 +431,23 @@ class PortalMatrix(BasePortal, ABC):
|
||||
else:
|
||||
self.log.trace("Unhandled Matrix event: %s", content)
|
||||
|
||||
async def handle_matrix_pin(self, sender: 'u.User', pinned_message: Optional[EventID],
|
||||
async def handle_matrix_unpin_all(self, sender: 'u.User', pin_event_id: EventID) -> None:
|
||||
await sender.client(UnpinAllMessagesRequest(peer=self.peer))
|
||||
await self._send_delivery_receipt(pin_event_id)
|
||||
|
||||
async def handle_matrix_pin(self, sender: 'u.User', changes: Dict[EventID, bool],
|
||||
pin_event_id: EventID) -> None:
|
||||
if self.peer_type != "chat" and self.peer_type != "channel":
|
||||
return
|
||||
try:
|
||||
if not pinned_message:
|
||||
await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=0))
|
||||
else:
|
||||
tg_space = self.tgid if self.peer_type == "channel" else sender.tgid
|
||||
message = DBMessage.get_by_mxid(pinned_message, self.mxid, tg_space)
|
||||
if message is None:
|
||||
self.log.warning(f"Could not find pinned {pinned_message} in {self.mxid}")
|
||||
return
|
||||
await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=message.tgid))
|
||||
await self._send_delivery_receipt(pin_event_id)
|
||||
except ChatNotModifiedError:
|
||||
pass
|
||||
tg_space = self.tgid if self.peer_type == "channel" else sender.tgid
|
||||
ids = {msg.mxid: msg.tgid
|
||||
for msg in DBMessage.get_by_mxids(list(changes.keys()),
|
||||
mx_room=self.mxid, tg_space=tg_space)}
|
||||
for event_id, pinned in changes.items():
|
||||
try:
|
||||
await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=ids[event_id],
|
||||
unpin=not pinned))
|
||||
except (ChatNotModifiedError, MessageIdInvalidError, KeyError):
|
||||
pass
|
||||
await self._send_delivery_receipt(pin_event_id)
|
||||
|
||||
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: EventID,
|
||||
redaction_event_id: EventID) -> None:
|
||||
|
||||
@@ -694,13 +694,20 @@ class PortalTelegram(BasePortal, ABC):
|
||||
levels.users[puppet.mxid] = 50
|
||||
await self.main_intent.set_power_levels(self.mxid, levels)
|
||||
|
||||
async def receive_telegram_pin_id(self, msg_id: TelegramID, receiver: TelegramID) -> None:
|
||||
tg_space = receiver if self.peer_type != "channel" else self.tgid
|
||||
message = DBMessage.get_one_by_tgid(msg_id, tg_space) if msg_id != 0 else None
|
||||
if message:
|
||||
await self.main_intent.set_pinned_messages(self.mxid, [message.mxid])
|
||||
else:
|
||||
await self.main_intent.set_pinned_messages(self.mxid, [])
|
||||
async def receive_telegram_pin_ids(self, msg_ids: List[TelegramID], receiver: TelegramID,
|
||||
remove: bool) -> None:
|
||||
async with self._pin_lock:
|
||||
tg_space = receiver if self.peer_type != "channel" else self.tgid
|
||||
previously_pinned = await self.main_intent.get_pinned_messages(self.mxid)
|
||||
currently_pinned_dict = {event_id: True for event_id in previously_pinned}
|
||||
for message in DBMessage.get_first_by_tgids(msg_ids, tg_space):
|
||||
if remove:
|
||||
currently_pinned_dict.pop(message.mxid, None)
|
||||
else:
|
||||
currently_pinned_dict[message.mxid] = True
|
||||
currently_pinned = list(currently_pinned_dict.keys())
|
||||
if currently_pinned != previously_pinned:
|
||||
await self.main_intent.set_pinned_messages(self.mxid, currently_pinned)
|
||||
|
||||
async def set_telegram_admins_enabled(self, enabled: bool) -> None:
|
||||
level = 50 if enabled else 10
|
||||
|
||||
Reference in New Issue
Block a user