gotd: ensure user is member of channels before starting getDifference loop

This commit is contained in:
Tulir Asokan
2025-12-12 15:45:39 +02:00
parent c1d92ce051
commit ba4dd48d5a
8 changed files with 88 additions and 31 deletions
+48 -12
View File
@@ -2,6 +2,7 @@ package updates
import (
"context"
"fmt"
"sync"
"github.com/go-faster/errors"
@@ -11,6 +12,7 @@ import (
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
var _ telegram.UpdateHandler = (*Manager)(nil)
@@ -74,6 +76,47 @@ type AuthOptions struct {
OnStart func(ctx context.Context)
}
type PtsAccessHashTuple struct {
Pts int
AccessHash int64
}
func (m *Manager) checkParticipant(ctx context.Context, api API, userID, channelID, hash int64) error {
lg := m.lg.With(zap.Int64("channel_id", channelID))
lg.Info("Ensuring user is still in channel")
pcp, err := api.ChannelsGetParticipant(ctx, &tg.ChannelsGetParticipantRequest{
Channel: &tg.InputChannel{
ChannelID: channelID,
AccessHash: hash,
},
Participant: &tg.InputPeerSelf{},
})
if err != nil {
if tgerr.Is(err, tg.ErrChannelInvalid, tg.ErrChannelPrivate, tg.ErrUserNotParticipant) {
lg.Warn("Removing update state for channel after error", zap.Error(err))
} else {
lg.Error("channels.getParticipant failed", zap.Error(err))
// TODO fatal error?
return nil
}
} else {
switch pcp.Participant.(type) {
case *tg.ChannelParticipantLeft, *tg.ChannelParticipantBanned:
lg.Warn("Removing update state for channel as user is left or banned")
default:
lg.Debug("Membership confirmed", zap.Any("participant", pcp.Participant))
return nil
}
}
if err := m.cfg.Storage.SetChannelPts(ctx, userID, channelID, -1); err != nil {
return fmt.Errorf("failed to clear pts: %w", err)
} else if err = m.cfg.OnNotChannelMember(ctx, channelID); err != nil {
return fmt.Errorf("OnNotChannelMember callback failed: %w", err)
}
return nil
}
// Run notifies manager about user authentication on the telegram server.
//
// If forget is true, local internalState (if exist) will be overwritten
@@ -101,10 +144,7 @@ func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOption
if err != nil {
return errors.Wrap(err, "load internalState")
}
channels := make(map[int64]struct {
Pts int
AccessHash int64
})
channels := make(map[int64]PtsAccessHashTuple)
if err := m.cfg.Storage.ForEachChannels(ctx, userID, func(ctx context.Context, channelID int64, pts int) error {
if pts == -1 {
return nil
@@ -112,16 +152,12 @@ func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOption
hash, found, err := m.cfg.AccessHasher.GetChannelAccessHash(ctx, userID, channelID)
if err != nil {
return errors.Wrap(err, "get channel access hash")
}
if !found {
} else if !found {
return nil
} else if err = m.checkParticipant(ctx, api, userID, channelID, hash); err != nil {
return fmt.Errorf("failed to check if user is participant of channel %d: %w", channelID, err)
}
channels[channelID] = struct {
Pts int
AccessHash int64
}{Pts: pts, AccessHash: hash}
channels[channelID] = PtsAccessHashTuple{Pts: pts, AccessHash: hash}
return nil
}); err != nil {
return errors.Wrap(err, "iterate channels")