From 48a624bd07ea4aa00c6ec63e8a2d58ced12af104 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Aug 2020 00:11:44 +0300 Subject: [PATCH] Re-add custom get_users method to avoid expensive API calls --- mautrix_telegram/portal/metadata.py | 67 ++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 11 deletions(-) diff --git a/mautrix_telegram/portal/metadata.py b/mautrix_telegram/portal/metadata.py index 0cbe914c..019d8090 100644 --- a/mautrix_telegram/portal/metadata.py +++ b/mautrix_telegram/portal/metadata.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Optional, Tuple, Union, Dict, Any, TYPE_CHECKING +from typing import List, Optional, Iterable, Union, Dict, Any, TYPE_CHECKING from abc import ABC import asyncio @@ -766,20 +766,65 @@ class PortalMetadata(BasePortal, ABC): return True return False + @staticmethod + def _filter_participants(users: List[TypeUser], participants: List[TypeParticipant] + ) -> Iterable[TypeUser]: + participant_map = {part.user_id: part for part in participants} + for user in users: + try: + user.participant = participant_map[user.id] + except KeyError: + pass + else: + yield user + + async def _get_channel_users(self, user: 'AbstractUser', entity: InputChannel, limit: int + ) -> List[TypeUser]: + if 0 < limit <= 200: + response = await user.client(GetParticipantsRequest( + entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0)) + return list(self._filter_participants(response.users, response.participants)) + elif limit > 200 or limit == -1: + users: List[TypeUser] = [] + offset = 0 + remaining_quota = limit if limit > 0 else 1000000 + query = (ChannelParticipantsSearch("") if limit == -1 + else ChannelParticipantsRecent()) + while True: + if remaining_quota <= 0: + break + response = await user.client(GetParticipantsRequest( + entity, query, offset=offset, limit=min(remaining_quota, 200), hash=0)) + if not response.users: + break + users += self._filter_participants(response.users, response.participants) + offset += len(response.participants) + remaining_quota -= len(response.participants) + return users + async def _get_users(self, user: 'AbstractUser', entity: Union[TypeInputPeer, InputUser, TypeChat, TypeUser, InputChannel] ) -> List[TypeUser]: - if self.peer_type == "user": + if self.peer_type == "chat": + chat = await user.client(GetFullChatRequest(chat_id=self.tgid)) + return list(self._filter_participants(chat.users, + chat.full_chat.participants.participants)) + elif self.peer_type == "channel": + if not self.megagroup and not self.sync_channel_members: + return [] + + limit = self.max_initial_member_sync + if limit == 0: + return [] + + try: + return await self._get_channel_users(user, entity, limit) + except ChatAdminRequiredError: + return [] + elif self.peer_type == "user": return [entity] - limit = self.max_initial_member_sync - if limit == 0: - return [] - elif limit < 0: - limit = None - try: - return await user.client.get_participants(entity, limit=limit) - except ChatAdminRequiredError: - return [] + else: + raise RuntimeError(f"Unexpected peer type {self.peer_type}") # endregion