Files
mautrix-telegram/pkg/gotd/telegram/updates/state.go
T

556 lines
14 KiB
Go

package updates
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-faster/errors"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
const (
	// idleTimeout is how long Run waits without activity before it forces an
	// updates.getDifference (see newState / resetIdleTimer).
	idleTimeout = 15 * time.Minute
	// fastgapTimeout is not referenced in this file.
	// NOTE(review): presumably the short gap wait used by the sequence boxes — confirm.
	fastgapTimeout = 500 * time.Millisecond

	// Difference limits, by name for user and bot sessions respectively.
	// NOTE(review): consumed outside this file — confirm where they are selected.
	diffLimitUser = 100
	diffLimitBot  = 100_000
)
// tracedUpdate pairs an updates container with a span context so that Run can
// handle the update inside the trace it originated from. Push captures the
// span context from the caller's ctx.
type tracedUpdate struct {
	update tg.UpdatesClass
	span   trace.SpanContext
}
// internalState is the updates state machine of a single account. Run consumes
// queued update containers and routes them through the common pts/qts/seq
// sequence boxes. Channel updates are forwarded to per-channel states.
// updates.getDifference is used to catch up on gaps, idle periods and
// updatesTooLong.
type internalState struct {
	// Updates channel: containers queued by Push and consumed by Run.
	externalQueue chan tracedUpdate
	// Updates from channel states
	// during updates.getChannelDifference.
	// It is handed to every channel state as its output and is drained by Run.
	internalQueue chan tracedUpdate

	// Common (account-wide, non-channel) sequence boxes.
	pts, qts, seq *sequenceBox
	// date is the last known state date, sent with updates.getDifference.
	date int
	// idleTimeout is restarted on every handled container / difference.
	// When it fires, Run calls getDifference.
	idleTimeout *time.Timer

	// Channel states, keyed by channel ID.
	channels map[int64]*channelState
	// channelsLock guards every access to channels.
	channelsLock sync.Mutex

	// Immutable fields (assigned once in newState).
	client    API
	log       *zap.Logger
	handler   telegram.UpdateHandler
	onTooLong func(channelID int64) error
	storage   StateStorage
	hasher    AccessHasher
	// selfID identifies this account in storage and hasher calls.
	selfID int64
	// diffLim is forwarded to channel states as their DiffLimit.
	diffLim int
	// wg runs the per-channel state goroutines.
	wg     *errgroup.Group
	tracer trace.Tracer
}
// stateConfig carries the initial state and the dependencies newState needs to
// build an internalState.
type stateConfig struct {
	// State is the initial common state; newState reads its Pts, Qts, Seq and Date.
	State State
	// Channels holds the known per-channel pts and access hash, keyed by
	// channel ID. newState skips entries whose Pts is -1.
	Channels map[int64]struct {
		Pts        int
		AccessHash int64
	}
	RawClient API
	Logger    *zap.Logger
	Tracer    trace.Tracer
	Handler   telegram.UpdateHandler
	// OnChannelTooLong is forwarded to every channel state.
	OnChannelTooLong func(channelID int64) error
	Storage          StateStorage
	Hasher           AccessHasher
	// SelfID identifies this account in Storage and Hasher calls.
	SelfID int64
	// DiffLimit is forwarded to channel states.
	DiffLimit int
	// WorkGroup runs the per-channel state goroutines.
	WorkGroup *errgroup.Group
}
// newState builds the internal updates state described by cfg.
//
// The common pts, qts and seq sequence boxes start from cfg.State and apply
// through applyPts, applyQts and applySeq respectively. All three share
// cfg.Tracer, so seq handling is traced just like pts and qts.
//
// For every entry of cfg.Channels whose Pts is not -1, a channel state is
// created and started on cfg.WorkGroup with ctx. Entries with Pts == -1 are
// skipped.
func newState(ctx context.Context, cfg stateConfig) *internalState {
	s := &internalState{
		externalQueue: make(chan tracedUpdate, 10),
		internalQueue: make(chan tracedUpdate, 10),

		date:        cfg.State.Date,
		idleTimeout: time.NewTimer(idleTimeout),

		channels: make(map[int64]*channelState),

		client:    cfg.RawClient,
		log:       cfg.Logger,
		handler:   cfg.Handler,
		onTooLong: cfg.OnChannelTooLong,
		storage:   cfg.Storage,
		hasher:    cfg.Hasher,
		selfID:    cfg.SelfID,
		diffLim:   cfg.DiffLimit,
		wg:        cfg.WorkGroup,
		tracer:    cfg.Tracer,
	}

	s.pts = newSequenceBox(sequenceConfig{
		InitialState: cfg.State.Pts,
		Apply:        s.applyPts,
		Logger:       s.log.Named("pts"),
		Tracer:       s.tracer,
	})
	s.qts = newSequenceBox(sequenceConfig{
		InitialState: cfg.State.Qts,
		Apply:        s.applyQts,
		Logger:       s.log.Named("qts"),
		Tracer:       s.tracer,
	})
	s.seq = newSequenceBox(sequenceConfig{
		InitialState: cfg.State.Seq,
		Apply:        s.applySeq,
		Logger:       s.log.Named("seq"),
		// Same tracer as pts/qts so seq spans are not silently dropped.
		Tracer: s.tracer,
	})

	for id, info := range cfg.Channels {
		if info.Pts == -1 {
			continue
		}
		s.createAndRunChannelState(ctx, id, info.AccessHash, info.Pts)
	}

	return s
}
// Push queues an updates container for Run.
//
// The span context carried by ctx travels with the container, so its handling
// continues the caller's trace. Push blocks while the queue is full. It
// returns ctx.Err() if ctx is done before the container could be queued.
func (s *internalState) Push(ctx context.Context, u tg.UpdatesClass) error {
	item := tracedUpdate{
		span:   trace.SpanContextFromContext(ctx),
		update: u,
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case s.externalQueue <- item:
		return nil
	}
}
// isFatalError reports whether err means the updates handler cannot continue
// and should be stopped. Fatal errors are:
//   - a missing key fingerprint,
//   - the RPC errors AUTH_KEY_UNREGISTERED and SESSION_EXPIRED,
//   - any error auth.IsUnauthorized recognizes.
func isFatalError(err error) bool {
	switch {
	case errors.Is(err, exchange.ErrKeyFingerprintNotFound):
		// See https://github.com/gotd/td/issues/1458.
		return true
	case tgerr.Is(err, "AUTH_KEY_UNREGISTERED", "SESSION_EXPIRED"):
		return true
	default:
		return auth.IsUnauthorized(err)
	}
}
// Run drives the updates state machine until ctx is done or handling an
// updates container fails.
//
// It first performs an updates.getDifference; errors from that call are only
// logged. It then multiplexes:
//   - containers from Push (externalQueue) and from channel states (internalQueue),
//   - the pts, qts and seq gap timers,
//   - the idle timer.
//
// Each timer triggers another best-effort getDifference. Queued containers are
// handled in the span context they were queued with. A handling error ends Run
// with that error; cancellation ends it with ctx.Err().
func (s *internalState) Run(ctx context.Context) error {
	switch {
	case s == nil:
		return errors.New("invalid: nil internalState")
	case s.log == nil:
		return errors.New("invalid: nil logger")
	}

	s.log.Info("Starting updates handler")
	defer s.log.Info("Updates handler stopped")

	s.getDifferenceLogger(ctx)

	for {
		var next tracedUpdate

		select {
		case <-ctx.Done():
			return ctx.Err()
		case next = <-s.externalQueue:
		case next = <-s.internalQueue:
		case <-s.pts.gapTimeout.C:
			s.log.Debug("Pts gap timeout")
			s.getDifferenceLogger(ctx)
			continue
		case <-s.qts.gapTimeout.C:
			s.log.Debug("Qts gap timeout")
			s.getDifferenceLogger(ctx)
			continue
		case <-s.seq.gapTimeout.C:
			s.log.Debug("Seq gap timeout")
			s.getDifferenceLogger(ctx)
			continue
		case <-s.idleTimeout.C:
			s.log.Debug("Idle timeout")
			s.getDifferenceLogger(ctx)
			continue
		}

		// Only queued containers reach this point.
		updCtx := trace.ContextWithSpanContext(ctx, next.span)
		if err := s.handleUpdates(updCtx, next.update); err != nil {
			return err
		}
	}
}
// handleUpdates processes one updates container according to its concrete
// type. Every call restarts the idle timer.
//
//   - updates / updatesCombined: user and channel access hashes are saved,
//     then the container goes through seq handling. updates is first converted
//     to updatesCombined with SeqStart == Seq.
//   - updateShort: wrapped into an updatesCombined without entities.
//   - updateShortMessage / updateShortChatMessage: converted to updateShort
//     and handed to handleShortWithSender.
//   - updateShortSentMessage: converted and handled recursively.
//   - updatesTooLong: triggers updates.getDifference.
//
// An unexpected container type is returned as an error instead of panicking.
// A schema mismatch therefore stops the handler through Run's error path
// rather than crashing the process.
func (s *internalState) handleUpdates(ctx context.Context, u tg.UpdatesClass) error {
	ctx, span := s.tracer.Start(ctx, "handleUpdates")
	defer span.End()

	s.resetIdleTimer()

	switch u := u.(type) {
	case *tg.Updates:
		s.saveUserHashes(ctx, u.Users)
		s.saveChannelHashes(ctx, u.Chats)
		return s.handleSeq(ctx, &tg.UpdatesCombined{
			Updates:  u.Updates,
			Users:    u.Users,
			Chats:    u.Chats,
			Date:     u.Date,
			Seq:      u.Seq,
			SeqStart: u.Seq,
		})
	case *tg.UpdatesCombined:
		s.saveUserHashes(ctx, u.Users)
		s.saveChannelHashes(ctx, u.Chats)
		return s.handleSeq(ctx, u)
	case *tg.UpdateShort:
		return s.handleUpdates(ctx, &tg.UpdatesCombined{
			Updates: []tg.UpdateClass{u.Update},
			Date:    u.Date,
		})
	case *tg.UpdateShortMessage:
		return s.handleShortWithSender(ctx, u.UserID, u.Date, s.convertShortMessage(u))
	case *tg.UpdateShortChatMessage:
		return s.handleShortWithSender(ctx, u.FromID, u.Date, s.convertShortChatMessage(u))
	case *tg.UpdateShortSentMessage:
		return s.handleUpdates(ctx, s.convertShortSentMessage(u))
	case *tg.UpdatesTooLong:
		return s.getDifference(ctx)
	default:
		return errors.Errorf("unexpected update type: %T", u)
	}
}

// handleShortWithSender handles a short message update that has already been
// converted to updateShort.
//
// If an access hash for senderID is stored, short is handled as-is. If no
// hash is stored, handleDifference is called for date to obtain users and
// chats, and the inner update is re-wrapped together with those entities.
// Errors from the hash lookup or from handleDifference are returned.
func (s *internalState) handleShortWithSender(ctx context.Context, senderID int64, date int, short *tg.UpdateShort) error {
	_, found, err := s.hasher.GetUserAccessHash(ctx, s.selfID, senderID)
	if err != nil {
		return err
	}
	if found {
		return s.handleUpdates(ctx, short)
	}

	chats, users, err := s.handleDifference(ctx, date)
	if err != nil {
		return err
	}
	return s.handleUpdates(ctx, &tg.UpdatesCombined{
		Updates: []tg.UpdateClass{short.Update},
		Users:   users,
		Chats:   chats,
		Date:    date,
	})
}
// handleSeq routes a combined updates container through seq ordering.
//
// A container that fails seq validation is logged and dropped; no error is
// returned. A container with Seq == 0 bypasses the seq box and is applied
// directly; if applyCombined reports that pts changed, a getDifference
// follows. Any other container is queued in the seq box, covering the range
// SeqStart..Seq.
func (s *internalState) handleSeq(ctx context.Context, u *tg.UpdatesCombined) error {
	ctx, span := s.tracer.Start(ctx, "handleSeq")
	defer span.End()

	if err := validateSeq(u.Seq, u.SeqStart); err != nil {
		s.log.Error("Seq validation failed", zap.Error(err), zap.Any("update", u))
		return nil
	}

	if u.Seq != 0 {
		return s.seq.Handle(ctx, update{
			Value: u,
			State: u.Seq,
			Count: u.Seq - u.SeqStart + 1,
		})
	}

	// Special case: seq == 0 is applied without the seq box.
	ptsChanged, err := s.applyCombined(ctx, u)
	switch {
	case err != nil:
		return err
	case ptsChanged:
		return s.getDifference(ctx)
	default:
		return nil
	}
}
// handlePts passes a pts-ordered update, together with its entities, to the
// common pts sequence box. An invalid pts/ptsCount pair is logged and the
// update is dropped without returning an error.
func (s *internalState) handlePts(ctx context.Context, pts, ptsCount int, u tg.UpdateClass, ents entities) error {
	err := validatePts(pts, ptsCount)
	if err != nil {
		s.log.Error("Pts validation failed", zap.Error(err), zap.Any("update", u))
		return nil
	}

	upd := update{
		Value:    u,
		State:    pts,
		Count:    ptsCount,
		Entities: ents,
	}
	return s.pts.Handle(ctx, upd)
}
// handleQts passes a qts-ordered update, together with its entities, to the
// common qts sequence box. Each such update counts as a single step. An
// invalid qts is logged and the update is dropped without returning an error.
func (s *internalState) handleQts(ctx context.Context, qts int, u tg.UpdateClass, ents entities) error {
	err := validateQts(qts)
	if err != nil {
		s.log.Error("Qts validation failed", zap.Error(err), zap.Any("update", u))
		return nil
	}

	upd := update{
		Value:    u,
		State:    qts,
		Count:    1,
		Entities: ents,
	}
	return s.qts.Handle(ctx, upd)
}
// handleChannel routes a channel pts update to that channel's state machine.
// On first sight of a channel, it creates and starts the state machine.
//
// For an unknown channel:
//   - The access hash is looked up via the hasher.
//   - If no hash is stored, restoreChannelAccessHash is tried from one second
//     before the update's date. If the update carries no date, 30 seconds
//     before the last known common date is used instead.
//   - If no hash can be recovered, the update is dropped.
//   - The starting pts is read from storage. It falls back to pts-ptsCount,
//     the state just before this update, and persists that value when storage
//     has no usable pts (not found, -1, or an error).
//
// Updates whose pts/ptsCount fail validation are logged and dropped.
func (s *internalState) handleChannel(ctx context.Context, channelID int64, date, pts, ptsCount int, cu channelUpdate) error {
	if err := validatePts(pts, ptsCount); err != nil {
		s.log.Error("Pts validation failed", zap.Error(err), zap.Any("update", cu.update))
		return nil
	}

	s.channelsLock.Lock()
	state, ok := s.channels[channelID]
	s.channelsLock.Unlock()

	if !ok {
		// Use the caller's ctx (not context.Background()) so the lookup is
		// cancelled together with the handler and stays in the current trace.
		accessHash, found, err := s.hasher.GetChannelAccessHash(ctx, s.selfID, channelID)
		if err != nil {
			s.log.Error("GetAccessHash error", zap.Int64("channel_id", channelID), zap.Error(err))
		}
		if !found {
			if date == 0 {
				// Received update has no date field.
				date = s.date - 30
			} else {
				date-- // 1 sec back
			}
			// Try to get access hash from updates.getDifference.
			accessHash, found = s.restoreChannelAccessHash(ctx, channelID, date)
			if !found {
				s.log.Debug("Failed to recover missing access hash, update ignored",
					zap.Int64("channel_id", channelID),
					// zap.Any("update", cu.update),
				)
				return nil
			}
		}

		localPts, found, err := s.storage.GetChannelPts(ctx, s.selfID, channelID)
		if localPts == -1 {
			// A stored -1 is treated as "no pts known".
			found = false
		}
		if err != nil {
			localPts = pts - ptsCount
			s.log.Error("GetChannelPts error", zap.Int64("channel_id", channelID), zap.Error(err))
		}
		if !found {
			localPts = pts - ptsCount
			if err := s.storage.SetChannelPts(ctx, s.selfID, channelID, localPts); err != nil {
				s.log.Error("SetChannelPts error", zap.Int64("channel_id", channelID), zap.Error(err))
			}
		}

		state = s.createAndRunChannelState(ctx, channelID, accessHash, localPts)
	}

	return state.Push(ctx, cu)
}
// RemoveChannel stops the state machine of channelID, if one is registered.
//
// The stop cause always wraps ErrRemoveChannelState. createAndRunChannelState
// treats a Run result matching that sentinel as a removal: the channel is
// deleted from the map and the work group is not failed. This relies on
// channelState.stop making Run return the given error (NOTE(review): confirm
// in channel_state.go).
//
// reason is wrapped alongside the sentinel when non-nil. A nil reason yields
// the bare sentinel instead of a malformed "%!w(<nil>)" message. Calling
// RemoveChannel on a nil receiver or for an unknown channel is a no-op.
func (s *internalState) RemoveChannel(channelID int64, reason error) {
	if s == nil {
		return
	}

	s.channelsLock.Lock()
	state, ok := s.channels[channelID]
	s.channelsLock.Unlock()
	if !ok {
		return
	}

	var cause error = ErrRemoveChannelState
	if reason != nil {
		cause = fmt.Errorf("%w: %w", ErrRemoveChannelState, reason)
	}
	state.stop(cause)
}
// createAndRunChannelState builds the state machine for one channel and
// registers it in s.channels, replacing any existing entry. It then starts the
// machine's Run loop on the work group and returns the new state.
//
// The channel state emits its updates into internalQueue, which Run drains.
// If its Run ends with an error matching ErrRemoveChannelState (see
// RemoveChannel), the channel is removed from the map and nil is reported to
// the work group. Any other error is propagated to the work group.
func (s *internalState) createAndRunChannelState(ctx context.Context, channelID, accessHash int64, initialPts int) *channelState {
	chLog := s.log.Named("channel").With(zap.Int64("channel_id", channelID))

	cs := newChannelState(ctx, channelStateConfig{
		Out:              s.internalQueue,
		InitialPts:       initialPts,
		ChannelID:        channelID,
		AccessHash:       accessHash,
		SelfID:           s.selfID,
		Storage:          s.storage,
		DiffLimit:        s.diffLim,
		RawClient:        s.client,
		Handler:          s.handler,
		OnChannelTooLong: s.onTooLong,
		Logger:           chLog,
		Tracer:           s.tracer,
	})

	s.channelsLock.Lock()
	s.channels[channelID] = cs
	s.channelsLock.Unlock()

	s.wg.Go(func() error {
		runErr := cs.Run(ctx)
		if !errors.Is(runErr, ErrRemoveChannelState) {
			return runErr
		}

		s.channelsLock.Lock()
		delete(s.channels, channelID)
		s.channelsLock.Unlock()

		s.log.Info("Removed channel state due to error", zap.Int64("channel_id", channelID), zap.Error(runErr))
		return nil
	})

	return cs
}
// getDifference fetches and applies everything missed since the locally known
// pts/qts/date, using updates.getDifference.
//
// The server may need several rounds. differenceSlice carries an intermediate
// state and differenceTooLong only a new pts; both require another request.
// Rounds run in a loop, one trace span each, rather than recursively. A long
// backlog therefore neither grows the call stack nor nests spans arbitrarily
// deep. The first error ends the loop and is returned.
func (s *internalState) getDifference(ctx context.Context) error {
	for {
		more, err := s.getDifferenceRound(ctx)
		if err != nil {
			return err
		}
		if !more {
			return nil
		}
	}
}

// getDifferenceRound performs one updates.getDifference request and applies
// its result. It reports whether another round is required.
//
// Pending pts/qts/seq gaps are cleared first, because the difference replaces
// waiting for them. Other updates go through handleUpdates. New (and new
// encrypted) messages are passed directly to the handler. Storage write
// failures are logged and do not abort the round.
func (s *internalState) getDifferenceRound(ctx context.Context) (bool, error) {
	ctx, span := s.tracer.Start(ctx, "getDifference")
	defer span.End()

	s.resetIdleTimer()
	s.pts.gaps.Clear()
	s.qts.gaps.Clear()
	s.seq.gaps.Clear()

	s.log.Debug("Getting difference")

	// setState persists a full common state and moves all sequence boxes to it.
	setState := func(state tg.UpdatesState, reason string) {
		if err := s.storage.SetState(ctx, s.selfID, State{}.fromRemote(&state)); err != nil {
			s.log.Warn("SetState error", zap.Error(err))
		}
		s.pts.SetState(state.Pts, reason)
		s.qts.SetState(state.Qts, reason)
		s.seq.SetState(state.Seq, reason)
		s.date = state.Date
	}

	diff, err := s.client.UpdatesGetDifference(ctx, &tg.UpdatesGetDifferenceRequest{
		Pts:  s.pts.State(),
		Qts:  s.qts.State(),
		Date: s.date,
	})
	if err != nil {
		return false, errors.Wrap(err, "get difference")
	}

	s.log.Debug("Difference received", zap.String("diff", fmt.Sprintf("%T", diff)))

	switch diff := diff.(type) {
	case *tg.UpdatesDifference:
		if len(diff.OtherUpdates) > 0 {
			if err := s.handleUpdates(ctx, &tg.UpdatesCombined{
				Updates: diff.OtherUpdates,
				Users:   diff.Users,
				Chats:   diff.Chats,
			}); err != nil {
				return false, errors.Wrap(err, "handle diff.OtherUpdates")
			}
		}
		if len(diff.NewMessages) > 0 || len(diff.NewEncryptedMessages) > 0 {
			if err := s.handler.Handle(ctx, &tg.Updates{
				Updates: append(
					msgsToUpdates(diff.NewMessages, false),
					encryptedMsgsToUpdates(diff.NewEncryptedMessages)...,
				),
				Users: diff.Users,
				Chats: diff.Chats,
			}); err != nil {
				return false, errors.Wrap(err, "handle diff.NewMessages")
			}
		}
		setState(diff.State, "updates.Difference")
		return false, nil

	// No events.
	case *tg.UpdatesDifferenceEmpty:
		if err := s.storage.SetDateSeq(ctx, s.selfID, diff.Date, diff.Seq); err != nil {
			s.log.Warn("SetDateSeq error", zap.Error(err))
		}
		s.date = diff.Date
		s.seq.SetState(diff.Seq, "updates.differenceEmpty")
		return false, nil

	// Incomplete list of occurred events: apply it, then fetch the rest.
	case *tg.UpdatesDifferenceSlice:
		if len(diff.OtherUpdates) > 0 {
			if err := s.handleUpdates(ctx, &tg.UpdatesCombined{
				Updates: diff.OtherUpdates,
				Users:   diff.Users,
				Chats:   diff.Chats,
				Date:    diff.IntermediateState.Date,
			}); err != nil {
				return false, errors.Wrap(err, "handle diff.OtherUpdates")
			}
		}
		if len(diff.NewMessages) > 0 || len(diff.NewEncryptedMessages) > 0 {
			if err := s.handler.Handle(ctx, &tg.Updates{
				Updates: append(
					msgsToUpdates(diff.NewMessages, false),
					encryptedMsgsToUpdates(diff.NewEncryptedMessages)...,
				),
				Users: diff.Users,
				Chats: diff.Chats,
			}); err != nil {
				return false, errors.Wrap(err, "handle diff.NewMessages")
			}
		}
		setState(diff.IntermediateState, "updates.differenceSlice")
		return true, nil

	// The difference is too long, and the specified pts must be used to
	// refetch updates.
	case *tg.UpdatesDifferenceTooLong:
		if err := s.storage.SetPts(ctx, s.selfID, diff.Pts); err != nil {
			s.log.Error("SetPts error", zap.Error(err))
		}
		s.pts.SetState(diff.Pts, "updates.differenceTooLong")
		return true, nil

	default:
		return false, errors.Errorf("unexpected diff type: %T", diff)
	}
}
// getDifferenceLogger runs getDifference on a best-effort basis: a failure is
// logged and otherwise ignored.
func (s *internalState) getDifferenceLogger(ctx context.Context) {
	err := s.getDifference(ctx)
	if err == nil {
		return
	}
	s.log.Error("get difference error", zap.Error(err))
}
// resetIdleTimer restarts the idle timer, so that it fires idleTimeout after
// the most recent activity.
//
// The timer is stopped first. Any expiry that was already delivered is then
// drained without blocking, so Run does not see a stale idle tick after the
// reset. The drain must be non-blocking: when Run is itself reacting to an
// idle tick, it has already consumed the value, so Stop reports false while
// the channel is empty.
func (s *internalState) resetIdleTimer() {
	if !s.idleTimeout.Stop() {
		select {
		case <-s.idleTimeout.C:
		default:
		}
	}
	s.idleTimeout.Reset(idleTimeout)
}