From 0670c2b2bcc4edcd95ec3fdf03fc1f735edc16e9 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 21 Aug 2024 13:45:45 -0600 Subject: [PATCH] updates: add wrapper for API calls to update users Signed-off-by: Sumner Evans --- pkg/connector/api.go | 34 ++++++++++++++++ pkg/connector/backfill.go | 16 +++++--- pkg/connector/chatinfo.go | 70 +++++++++++++++++++-------------- pkg/connector/client.go | 15 ------- pkg/connector/directdownload.go | 56 ++++++++++++++++---------- pkg/connector/reactions.go | 6 ++- 6 files changed, 124 insertions(+), 73 deletions(-) create mode 100644 pkg/connector/api.go diff --git a/pkg/connector/api.go b/pkg/connector/api.go new file mode 100644 index 00000000..514aecef --- /dev/null +++ b/pkg/connector/api.go @@ -0,0 +1,34 @@ +package connector + +import ( + "context" + "fmt" + + "github.com/gotd/td/tg" +) + +type hasUpdates interface { + GetUsers() []tg.UserClass +} + +// Wrapper for API calls that return a response with updates. +func APICallWithUpdates[U hasUpdates](ctx context.Context, t *TelegramClient, fn func() (U, error)) (U, error) { + resp, err := fn() + if err != nil { + return resp, err + } + + // TODO do we also need to expand this to chats and messages? + for _, user := range resp.GetUsers() { + user, ok := user.(*tg.User) + if !ok { + return resp, fmt.Errorf("user is %T not *tg.User", user) + } + err := t.updateGhost(ctx, user.ID, user) + if err != nil { + return resp, err + } + } + + return resp, nil +} diff --git a/pkg/connector/backfill.go b/pkg/connector/backfill.go index 8c30d5d5..9ea94718 100644 --- a/pkg/connector/backfill.go +++ b/pkg/connector/backfill.go @@ -36,14 +36,20 @@ func (t *TelegramClient) FetchMessages(ctx context.Context, fetchParams bridgev2 return nil, err } } - rawMsgs, err := t.client.API().MessagesGetHistory(ctx, &req) + msgs, err := APICallWithUpdates(ctx, t, func() (tg.ModifiedMessagesMessages, error) { + rawMsgs, err := t.client.API().MessagesGetHistory(ctx, &req) + if err != nil { + return nil, err + } + msgs, ok := rawMsgs.(tg.ModifiedMessagesMessages) + if !ok { + return nil, fmt.Errorf("unsupported messages type %T", rawMsgs) + } + return msgs, nil + }) if err != nil { return nil, err } - msgs, ok := rawMsgs.(interface{ GetMessages() []tg.MessageClass }) - if !ok { - return nil, fmt.Errorf("unsupported messages type %T", rawMsgs) - } var markRead bool // TODO implement messages := msgs.GetMessages() diff --git a/pkg/connector/chatinfo.go b/pkg/connector/chatinfo.go index 30cb24b6..0509def3 100644 --- a/pkg/connector/chatinfo.go +++ b/pkg/connector/chatinfo.go @@ -53,11 +53,7 @@ func (t *TelegramClient) getDMChatInfo(ctx context.Context, userID int64) (*brid return &chatInfo, nil } -func (t *TelegramClient) getGroupChatInfo(ctx context.Context, fullChat *tg.MessagesChatFull, chatID int64) (*bridgev2.ChatInfo, bool, error) { - if err := t.updateUsersFromResponse(ctx, fullChat); err != nil { - return nil, false, err - } - +func (t *TelegramClient) getGroupChatInfo(fullChat *tg.MessagesChatFull, chatID int64) (*bridgev2.ChatInfo, bool, error) { var name *string var isBroadcastChannel, isMegagroup bool for _, c := range fullChat.GetChats() { @@ -140,7 +136,6 @@ func (t *TelegramClient) filterChannelParticipants(chatParticipants []tg.Channel } func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { - // fmt.Printf("get chat info %+v\n", portal) peerType, id, err := ids.ParsePortalID(portal.ID) if err != nil { return nil, err @@ -150,11 +145,13 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta case ids.PeerTypeUser: return t.getDMChatInfo(ctx, id) case ids.PeerTypeChat: - fullChat, err := t.client.API().MessagesGetFullChat(ctx, id) + fullChat, err := APICallWithUpdates(ctx, t, func() (*tg.MessagesChatFull, error) { + return t.client.API().MessagesGetFullChat(ctx, id) + }) if err != nil { return nil, err } - chatInfo, _, err := t.getGroupChatInfo(ctx, fullChat, id) + chatInfo, _, err := t.getGroupChatInfo(fullChat, id) if err != nil { return nil, err } @@ -202,12 +199,14 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta return nil, fmt.Errorf("channel access hash not found for %d", id) } inputChannel := &tg.InputChannel{ChannelID: id, AccessHash: accessHash} - fullChat, err := t.client.API().ChannelsGetFullChannel(ctx, inputChannel) + fullChat, err := APICallWithUpdates(ctx, t, func() (*tg.MessagesChatFull, error) { + return t.client.API().ChannelsGetFullChannel(ctx, inputChannel) + }) if err != nil { return nil, err } - chatInfo, isBroadcastChannel, err := t.getGroupChatInfo(ctx, fullChat, id) + chatInfo, isBroadcastChannel, err := t.getGroupChatInfo(fullChat, id) if err != nil { return nil, err } @@ -245,32 +244,47 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta limit := t.main.Config.MemberList.NormalizedMaxInitialSync() if limit <= 200 { - p, err := t.client.API().ChannelsGetParticipants(ctx, &tg.ChannelsGetParticipantsRequest{ - Channel: inputChannel, - Filter: &tg.ChannelParticipantsRecent{}, - Limit: limit, + participants, err := APICallWithUpdates(ctx, t, func() (*tg.ChannelsChannelParticipants, error) { + p, err := t.client.API().ChannelsGetParticipants(ctx, &tg.ChannelsGetParticipantsRequest{ + Channel: inputChannel, + Filter: &tg.ChannelParticipantsRecent{}, + Limit: limit, + }) + if err != nil { + return nil, err + } + participants, ok := p.(*tg.ChannelsChannelParticipants) + if !ok { + return nil, fmt.Errorf("returned participants is %T not *tg.ChannelsChannelParticipants", p) + } else { + return participants, nil + } }) if err != nil { return nil, err } - participants, ok := p.(*tg.ChannelsChannelParticipants) - if !ok { - return nil, fmt.Errorf("returned participants is %T not *tg.ChannelsChannelParticipants", p) - } chatInfo.Members.IsFull = len(participants.Participants) < limit - if err := t.updateUsersFromResponse(ctx, participants); err != nil { - return nil, err - } chatInfo.Members.Members = append(chatInfo.Members.Members, t.filterChannelParticipants(participants.Participants, limit)...) } else { remaining := t.main.Config.MemberList.NormalizedMaxInitialSync() var offset int for remaining > 0 { - p, err := t.client.API().ChannelsGetParticipants(ctx, &tg.ChannelsGetParticipantsRequest{ - Channel: inputChannel, - Filter: &tg.ChannelParticipantsSearch{}, - Limit: min(remaining, 200), - Offset: offset, + participants, err := APICallWithUpdates(ctx, t, func() (*tg.ChannelsChannelParticipants, error) { + p, err := t.client.API().ChannelsGetParticipants(ctx, &tg.ChannelsGetParticipantsRequest{ + Channel: inputChannel, + Filter: &tg.ChannelParticipantsSearch{}, + Limit: min(remaining, 200), + Offset: offset, + }) + if err != nil { + return nil, err + } + participants, ok := p.(*tg.ChannelsChannelParticipants) + if !ok { + return nil, fmt.Errorf("returned participants is %T not *tg.ChannelsChannelParticipants", p) + } else { + return participants, nil + } }) if err != nil { return nil, err @@ -283,10 +297,6 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta chatInfo.Members.IsFull = true break } - - if err := t.updateUsersFromResponse(ctx, participants); err != nil { - return nil, err - } chatInfo.Members.Members = append(chatInfo.Members.Members, t.filterChannelParticipants(participants.Participants, limit)...) offset += len(participants.Participants) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index e173c979..63ed6663 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -383,21 +383,6 @@ func (t *TelegramClient) Disconnect() { t.clientCancel() } -func (t *TelegramClient) updateUsersFromResponse(ctx context.Context, resp interface{ GetUsers() []tg.UserClass }) error { - // TODO table for the access hashes? - for _, user := range resp.GetUsers() { - user, ok := user.(*tg.User) - if !ok { - return fmt.Errorf("user is %T not *tg.User", user) - } - err := t.updateGhost(ctx, user.ID, user) - if err != nil { - return err - } - } - return nil -} - func (t *TelegramClient) GetUserInfo(ctx context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) { id, err := ids.ParseUserID(ghost.ID) if err != nil { diff --git a/pkg/connector/directdownload.go b/pkg/connector/directdownload.go index 820ccdf2..9a1fa5c5 100644 --- a/pkg/connector/directdownload.go +++ b/pkg/connector/directdownload.go @@ -50,11 +50,20 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med } client := userLogin.Client.(*TelegramClient) - var messages tg.MessagesMessagesClass + var messages tg.ModifiedMessagesMessages switch info.PeerType { case ids.PeerTypeUser, ids.PeerTypeChat: - messages, err = client.client.API().MessagesGetMessages(ctx, []tg.InputMessageClass{ - &tg.InputMessageID{ID: int(info.MessageID)}, + messages, err = APICallWithUpdates(ctx, client, func() (tg.ModifiedMessagesMessages, error) { + m, err := client.client.API().MessagesGetMessages(ctx, []tg.InputMessageClass{ + &tg.InputMessageID{ID: int(info.MessageID)}, + }) + if err != nil { + return nil, err + } else if messages, ok := m.(tg.ModifiedMessagesMessages); !ok { + return nil, fmt.Errorf("unsupported messages type %T", messages) + } else { + return messages, nil + } }) case ids.PeerTypeChannel: var accessHash int64 @@ -65,11 +74,20 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med } else if !found { return nil, fmt.Errorf("channel access hash not found for %d", info.ChatID) } else { - messages, err = client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{ - Channel: &tg.InputChannel{ChannelID: info.ChatID, AccessHash: accessHash}, - ID: []tg.InputMessageClass{ - &tg.InputMessageID{ID: int(info.MessageID)}, - }, + messages, err = APICallWithUpdates(ctx, client, func() (tg.ModifiedMessagesMessages, error) { + m, err := client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{ + Channel: &tg.InputChannel{ChannelID: info.ChatID, AccessHash: accessHash}, + ID: []tg.InputMessageClass{ + &tg.InputMessageID{ID: int(info.MessageID)}, + }, + }) + if err != nil { + return nil, err + } else if messages, ok := m.(tg.ModifiedMessagesMessages); !ok { + return nil, fmt.Errorf("unsupported messages type %T", messages) + } else { + return messages, nil + } }) } default: @@ -80,21 +98,17 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med } var msgMedia tg.MessageMediaClass - if m, ok := messages.(getMessages); !ok { - return nil, fmt.Errorf("unknown message type %T", messages) - } else { - var found bool - for _, message := range m.GetMessages() { - if msg, ok := message.(*tg.Message); ok && msg.ID == int(info.MessageID) { - msgMedia = msg.Media - found = true - break - } - } - if !found { - return nil, fmt.Errorf("no media found with ID %d", info.MessageID) + var found bool + for _, message := range messages.GetMessages() { + if msg, ok := message.(*tg.Message); ok && msg.ID == int(info.MessageID) { + msgMedia = msg.Media + found = true + break } } + if !found { + return nil, fmt.Errorf("no media found with ID %d", info.MessageID) + } transferer := media.NewTransferer(client.client.API()) var readyTransferer *media.ReadyTransferer diff --git a/pkg/connector/reactions.go b/pkg/connector/reactions.go index 73d87ef8..10c23375 100644 --- a/pkg/connector/reactions.go +++ b/pkg/connector/reactions.go @@ -51,8 +51,10 @@ func (t *TelegramClient) computeReactionsList(ctx context.Context, msg *tg.Messa } else if peer, err := t.inputPeerForPortalID(ctx, ids.MakePortalKey(msg.PeerID, t.loginID).ID); err != nil { return nil, false, nil, fmt.Errorf("failed to get input peer: %w", err) } else { - reactions, err := t.client.API().MessagesGetMessageReactionsList(ctx, &tg.MessagesGetMessageReactionsListRequest{ - Peer: peer, ID: msg.ID, Limit: 100, + reactions, err := APICallWithUpdates(ctx, t, func() (*tg.MessagesMessageReactionsList, error) { + return t.client.API().MessagesGetMessageReactionsList(ctx, &tg.MessagesGetMessageReactionsListRequest{ + Peer: peer, ID: msg.ID, Limit: 100, + }) }) if err != nil { return nil, false, nil, fmt.Errorf("failed to get reactions list: %w", err)