From 088900aee1610fce7f1864f6ead537964d274803 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 29 Aug 2024 11:57:21 -0600 Subject: [PATCH] connector: save channel access hashes in more places Signed-off-by: Sumner Evans --- pkg/connector/api.go | 35 +++++++++++++++++++++++------ pkg/connector/startnewchat.go | 2 +- pkg/connector/store/scoped_store.go | 8 +++---- pkg/connector/telegram.go | 9 +++++++- 4 files changed, 41 insertions(+), 13 deletions(-) diff --git a/pkg/connector/api.go b/pkg/connector/api.go index 4f3cbde9..bf2868cf 100644 --- a/pkg/connector/api.go +++ b/pkg/connector/api.go @@ -7,26 +7,47 @@ import ( "github.com/gotd/td/tg" ) -type hasUpdates interface { +type hasUserUpdates interface { GetUsers() []tg.UserClass } -// Wrapper for API calls that return a response with updates. -func APICallWithUpdates[U hasUpdates](ctx context.Context, t *TelegramClient, fn func() (U, error)) (U, error) { +type hasUpdates interface { + hasUserUpdates + GetChats() []tg.ChatClass +} + +func APICallWithOnlyUserUpdates[U hasUserUpdates](ctx context.Context, t *TelegramClient, fn func() (U, error)) (U, error) { resp, err := fn() if err != nil { - return resp, err + return *new(U), err } - // TODO do we also need to expand this to chats and messages? for _, user := range resp.GetUsers() { user, ok := user.(*tg.User) if !ok { - return resp, fmt.Errorf("user is %T not *tg.User", user) + return *new(U), fmt.Errorf("user is %T not *tg.User", user) } _, err := t.updateGhost(ctx, user.ID, user) if err != nil { - return resp, err + return *new(U), err + } + } + + return resp, nil +} + +// Wrapper for API calls that return a response with updates. +func APICallWithUpdates[U hasUpdates](ctx context.Context, t *TelegramClient, fn func() (U, error)) (U, error) { + resp, err := APICallWithOnlyUserUpdates(ctx, t, fn) + if err != nil { + return *new(U), err + } + + for _, c := range resp.GetChats() { + if channel, ok := c.(*tg.Channel); ok { + if err := t.ScopedStore.SetAccessHash(ctx, channel.ID, channel.AccessHash); err != nil { + return *new(U), err + } } } diff --git a/pkg/connector/startnewchat.go b/pkg/connector/startnewchat.go index c0b256fa..13d09850 100644 --- a/pkg/connector/startnewchat.go +++ b/pkg/connector/startnewchat.go @@ -155,7 +155,7 @@ func (t *TelegramClient) SearchUsers(ctx context.Context, query string) (resp [] } func (t *TelegramClient) GetContactList(ctx context.Context) (resp []*bridgev2.ResolveIdentifierResponse, err error) { - contacts, err := APICallWithUpdates(ctx, t, func() (*tg.ContactsContacts, error) { + contacts, err := APICallWithOnlyUserUpdates(ctx, t, func() (*tg.ContactsContacts, error) { c, err := t.client.API().ContactsGetContacts(ctx, t.cachedContactsHash) if err != nil { return nil, err diff --git a/pkg/connector/store/scoped_store.go b/pkg/connector/store/scoped_store.go index 59126f00..e5e83742 100644 --- a/pkg/connector/store/scoped_store.go +++ b/pkg/connector/store/scoped_store.go @@ -167,16 +167,16 @@ func (s *ScopedStore) SetChannelAccessHash(ctx context.Context, userID, channelI var ErrNoAccessHash = errors.New("access hash not found") -func (s *ScopedStore) GetAccessHash(ctx context.Context, userID int64) (accessHash int64, err error) { - err = s.db.QueryRow(ctx, getAccessHashQuery, s.telegramUserID, userID).Scan(&accessHash) +func (s *ScopedStore) GetAccessHash(ctx context.Context, entityID int64) (accessHash int64, err error) { + err = s.db.QueryRow(ctx, getAccessHashQuery, s.telegramUserID, entityID).Scan(&accessHash) if errors.Is(err, sql.ErrNoRows) { err = ErrNoAccessHash } return } -func (s *ScopedStore) SetAccessHash(ctx context.Context, userID, accessHash int64) (err error) { - _, err = s.db.Exec(ctx, setAccessHashQuery, s.telegramUserID, userID, accessHash) +func (s *ScopedStore) SetAccessHash(ctx context.Context, entityID, accessHash int64) (err error) { + _, err = s.db.Exec(ctx, setAccessHashQuery, s.telegramUserID, entityID, accessHash) return } diff --git a/pkg/connector/telegram.go b/pkg/connector/telegram.go index 994b180f..7a48ad84 100644 --- a/pkg/connector/telegram.go +++ b/pkg/connector/telegram.go @@ -320,7 +320,14 @@ func (t *TelegramClient) updateGhost(ctx context.Context, userID int64, user *tg func (t *TelegramClient) onEntityUpdate(ctx context.Context, e tg.Entities) error { for userID, user := range e.Users { - t.updateGhost(ctx, userID, user) + if _, err := t.updateGhost(ctx, userID, user); err != nil { + return err + } + } + for channelID, channel := range e.Channels { + if err := t.ScopedStore.SetAccessHash(ctx, channelID, channel.AccessHash); err != nil { + return err + } } return nil }