gotd: ensure user is member of channels before starting getDifference loop
This commit is contained in:
@@ -226,6 +226,7 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
|
|||||||
dispatcher.OnPhoneCall(client.onPhoneCall)
|
dispatcher.OnPhoneCall(client.onPhoneCall)
|
||||||
|
|
||||||
client.updatesManager = updates.New(updates.Config{
|
client.updatesManager = updates.New(updates.Config{
|
||||||
|
OnNotChannelMember: client.onNotChannelMember,
|
||||||
OnChannelTooLong: func(channelID int64) error {
|
OnChannelTooLong: func(channelID int64) error {
|
||||||
// TODO resync topics?
|
// TODO resync topics?
|
||||||
res := tc.Bridge.QueueRemoteEvent(login, &simplevent.ChatResync{
|
res := tc.Bridge.QueueRemoteEvent(login, &simplevent.ChatResync{
|
||||||
|
|||||||
@@ -107,6 +107,10 @@ func (t *TelegramClient) selfLeaveChat(ctx context.Context, portalKey networkid.
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TelegramClient) onNotChannelMember(ctx context.Context, channelID int64) error {
|
||||||
|
return t.selfLeaveChat(ctx, t.makePortalKeyFromID(ids.PeerTypeChannel, channelID, 0), fmt.Errorf("startup channel member check failed"))
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TelegramClient) onUpdateChannel(ctx context.Context, e tg.Entities, update *tg.UpdateChannel) error {
|
func (t *TelegramClient) onUpdateChannel(ctx context.Context, e tg.Entities, update *tg.UpdateChannel) error {
|
||||||
log := zerolog.Ctx(ctx).With().
|
log := zerolog.Ctx(ctx).With().
|
||||||
Str("handler", "on_update_channel").
|
Str("handler", "on_update_channel").
|
||||||
|
|||||||
@@ -74,19 +74,26 @@ const (
|
|||||||
|
|
||||||
var _ updates.StateStorage = (*ScopedStore)(nil)
|
var _ updates.StateStorage = (*ScopedStore)(nil)
|
||||||
|
|
||||||
|
type channelIDPtsTuple struct {
|
||||||
|
ChannelID int64
|
||||||
|
Pts int
|
||||||
|
}
|
||||||
|
|
||||||
|
var ciptScanner = dbutil.ConvertRowFn[channelIDPtsTuple](func(row dbutil.Scannable) (cipt channelIDPtsTuple, err error) {
|
||||||
|
err = row.Scan(&cipt.ChannelID, &cipt.Pts)
|
||||||
|
return
|
||||||
|
})
|
||||||
|
|
||||||
func (s *ScopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
func (s *ScopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
rows, err := s.db.Query(ctx, allChannelsQuery, userID)
|
items, err := ciptScanner.NewRowIter(s.db.Query(ctx, allChannelsQuery, userID)).AsList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var channelID int64
|
for _, item := range items {
|
||||||
var pts int
|
err = f(ctx, item.ChannelID, item.Pts)
|
||||||
for rows.Next() {
|
if err != nil {
|
||||||
if err = rows.Scan(&channelID, &pts); err != nil {
|
return fmt.Errorf("iteration error for channel %d: %w", item.ChannelID, err)
|
||||||
return err
|
|
||||||
} else if err = f(ctx, channelID, pts); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ type API interface {
|
|||||||
UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error)
|
UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error)
|
||||||
UpdatesGetDifference(ctx context.Context, request *tg.UpdatesGetDifferenceRequest) (tg.UpdatesDifferenceClass, error)
|
UpdatesGetDifference(ctx context.Context, request *tg.UpdatesGetDifferenceRequest) (tg.UpdatesDifferenceClass, error)
|
||||||
UpdatesGetChannelDifference(ctx context.Context, request *tg.UpdatesGetChannelDifferenceRequest) (tg.UpdatesChannelDifferenceClass, error)
|
UpdatesGetChannelDifference(ctx context.Context, request *tg.UpdatesGetChannelDifferenceRequest) (tg.UpdatesChannelDifferenceClass, error)
|
||||||
|
ChannelsGetParticipant(ctx context.Context, request *tg.ChannelsGetParticipantRequest) (*tg.ChannelsChannelParticipant, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config of the manager.
|
// Config of the manager.
|
||||||
@@ -24,7 +25,8 @@ type Config struct {
|
|||||||
Handler telegram.UpdateHandler
|
Handler telegram.UpdateHandler
|
||||||
// Callback called if manager cannot
|
// Callback called if manager cannot
|
||||||
// recover channel gap (optional).
|
// recover channel gap (optional).
|
||||||
OnChannelTooLong func(channelID int64) error
|
OnChannelTooLong func(channelID int64) error
|
||||||
|
OnNotChannelMember func(ctx context.Context, channelID int64) error
|
||||||
// State storage.
|
// State storage.
|
||||||
// In-mem used if not provided.
|
// In-mem used if not provided.
|
||||||
Storage StateStorage
|
Storage StateStorage
|
||||||
|
|||||||
@@ -34,6 +34,12 @@ func newServer() *server {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *server) ChannelsGetParticipant(ctx context.Context, request *tg.ChannelsGetParticipantRequest) (*tg.ChannelsChannelParticipant, error) {
|
||||||
|
return &tg.ChannelsChannelParticipant{
|
||||||
|
Participant: &tg.ChannelParticipantSelf{},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UpdatesGetState returns current remote state.
|
// UpdatesGetState returns current remote state.
|
||||||
func (s *server) UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error) {
|
func (s *server) UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package updates
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/go-faster/errors"
|
"github.com/go-faster/errors"
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||||
|
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ telegram.UpdateHandler = (*Manager)(nil)
|
var _ telegram.UpdateHandler = (*Manager)(nil)
|
||||||
@@ -74,6 +76,47 @@ type AuthOptions struct {
|
|||||||
OnStart func(ctx context.Context)
|
OnStart func(ctx context.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PtsAccessHashTuple struct {
|
||||||
|
Pts int
|
||||||
|
AccessHash int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) checkParticipant(ctx context.Context, api API, userID, channelID, hash int64) error {
|
||||||
|
lg := m.lg.With(zap.Int64("channel_id", channelID))
|
||||||
|
lg.Info("Ensuring user is still in channel")
|
||||||
|
pcp, err := api.ChannelsGetParticipant(ctx, &tg.ChannelsGetParticipantRequest{
|
||||||
|
Channel: &tg.InputChannel{
|
||||||
|
ChannelID: channelID,
|
||||||
|
AccessHash: hash,
|
||||||
|
},
|
||||||
|
Participant: &tg.InputPeerSelf{},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if tgerr.Is(err, tg.ErrChannelInvalid, tg.ErrChannelPrivate, tg.ErrUserNotParticipant) {
|
||||||
|
lg.Warn("Removing update state for channel after error", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
lg.Error("channels.getParticipant failed", zap.Error(err))
|
||||||
|
// TODO fatal error?
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch pcp.Participant.(type) {
|
||||||
|
case *tg.ChannelParticipantLeft, *tg.ChannelParticipantBanned:
|
||||||
|
lg.Warn("Removing update state for channel as user is left or banned")
|
||||||
|
default:
|
||||||
|
lg.Debug("Membership confirmed", zap.Any("participant", pcp.Participant))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.cfg.Storage.SetChannelPts(ctx, userID, channelID, -1); err != nil {
|
||||||
|
return fmt.Errorf("failed to clear pts: %w", err)
|
||||||
|
} else if err = m.cfg.OnNotChannelMember(ctx, channelID); err != nil {
|
||||||
|
return fmt.Errorf("OnNotChannelMember callback failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Run notifies manager about user authentication on the telegram server.
|
// Run notifies manager about user authentication on the telegram server.
|
||||||
//
|
//
|
||||||
// If forget is true, local internalState (if exist) will be overwritten
|
// If forget is true, local internalState (if exist) will be overwritten
|
||||||
@@ -101,10 +144,7 @@ func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOption
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "load internalState")
|
return errors.Wrap(err, "load internalState")
|
||||||
}
|
}
|
||||||
channels := make(map[int64]struct {
|
channels := make(map[int64]PtsAccessHashTuple)
|
||||||
Pts int
|
|
||||||
AccessHash int64
|
|
||||||
})
|
|
||||||
if err := m.cfg.Storage.ForEachChannels(ctx, userID, func(ctx context.Context, channelID int64, pts int) error {
|
if err := m.cfg.Storage.ForEachChannels(ctx, userID, func(ctx context.Context, channelID int64, pts int) error {
|
||||||
if pts == -1 {
|
if pts == -1 {
|
||||||
return nil
|
return nil
|
||||||
@@ -112,16 +152,12 @@ func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOption
|
|||||||
hash, found, err := m.cfg.AccessHasher.GetChannelAccessHash(ctx, userID, channelID)
|
hash, found, err := m.cfg.AccessHasher.GetChannelAccessHash(ctx, userID, channelID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "get channel access hash")
|
return errors.Wrap(err, "get channel access hash")
|
||||||
}
|
} else if !found {
|
||||||
|
|
||||||
if !found {
|
|
||||||
return nil
|
return nil
|
||||||
|
} else if err = m.checkParticipant(ctx, api, userID, channelID, hash); err != nil {
|
||||||
|
return fmt.Errorf("failed to check if user is participant of channel %d: %w", channelID, err)
|
||||||
}
|
}
|
||||||
|
channels[channelID] = PtsAccessHashTuple{Pts: pts, AccessHash: hash}
|
||||||
channels[channelID] = struct {
|
|
||||||
Pts int
|
|
||||||
AccessHash int64
|
|
||||||
}{Pts: pts, AccessHash: hash}
|
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return errors.Wrap(err, "iterate channels")
|
return errors.Wrap(err, "iterate channels")
|
||||||
|
|||||||
@@ -62,11 +62,8 @@ type internalState struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type stateConfig struct {
|
type stateConfig struct {
|
||||||
State State
|
State State
|
||||||
Channels map[int64]struct {
|
Channels map[int64]PtsAccessHashTuple
|
||||||
Pts int
|
|
||||||
AccessHash int64
|
|
||||||
}
|
|
||||||
RawClient API
|
RawClient API
|
||||||
Logger *zap.Logger
|
Logger *zap.Logger
|
||||||
Tracer trace.Tracer
|
Tracer trace.Tracer
|
||||||
@@ -422,6 +419,10 @@ func (s *internalState) createAndRunChannelState(ctx context.Context, channelID,
|
|||||||
s.channelsLock.Unlock()
|
s.channelsLock.Unlock()
|
||||||
s.log.Info("Removed channel state due to error", zap.Int64("channel_id", channelID), zap.Error(err))
|
s.log.Info("Removed channel state due to error", zap.Int64("channel_id", channelID), zap.Error(err))
|
||||||
return nil
|
return nil
|
||||||
|
} else if ctx.Err() == nil {
|
||||||
|
s.log.Error("Channel state stopped with unexpected error, new messages may stop arriving",
|
||||||
|
zap.Int64("channel_id", channelID), zap.Error(err))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -205,12 +205,12 @@ func (s *channelState) applyPts(ctx context.Context, state int, updates []update
|
|||||||
Users: ents.Users,
|
Users: ents.Users,
|
||||||
Chats: ents.Chats,
|
Chats: ents.Chats,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
s.log.Error("Handle update error", zap.Error(err))
|
s.log.Error("Handle update error (applyPts)", zap.Error(err))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, state); err != nil {
|
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, state); err != nil {
|
||||||
s.log.Error("SetChannelPts error", zap.Error(err))
|
s.log.Error("SetChannelPts error (applyPts)", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -297,7 +297,7 @@ func (s *channelState) getDifference(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil {
|
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil {
|
||||||
s.log.Warn("SetChannelPts error", zap.Error(err))
|
s.log.Warn("SetChannelPts error (getDifference)", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
s.pts.SetState(diff.Pts, "updates.channelDifference")
|
s.pts.SetState(diff.Pts, "updates.channelDifference")
|
||||||
@@ -313,7 +313,7 @@ func (s *channelState) getDifference(ctx context.Context) error {
|
|||||||
|
|
||||||
case *tg.UpdatesChannelDifferenceEmpty:
|
case *tg.UpdatesChannelDifferenceEmpty:
|
||||||
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil {
|
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil {
|
||||||
s.log.Warn("SetChannelPts error", zap.Error(err))
|
s.log.Warn("SetChannelPts error (getDifference empty)", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
s.pts.SetState(diff.Pts, "updates.channelDifferenceEmpty")
|
s.pts.SetState(diff.Pts, "updates.channelDifferenceEmpty")
|
||||||
@@ -333,7 +333,7 @@ func (s *channelState) getDifference(ctx context.Context) error {
|
|||||||
s.log.Warn("UpdatesChannelDifferenceTooLong invalid Dialog", zap.Error(err))
|
s.log.Warn("UpdatesChannelDifferenceTooLong invalid Dialog", zap.Error(err))
|
||||||
} else {
|
} else {
|
||||||
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, remotePts); err != nil {
|
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, remotePts); err != nil {
|
||||||
s.log.Warn("SetChannelPts error", zap.Error(err))
|
s.log.Warn("SetChannelPts error (getDifference too long)", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
s.pts.SetState(remotePts, "updates.channelDifferenceTooLong dialog new pts")
|
s.pts.SetState(remotePts, "updates.channelDifferenceTooLong dialog new pts")
|
||||||
|
|||||||
Reference in New Issue
Block a user