Files
mautrix-telegram/pkg/gotd/telegram/updates/state_channel.go
T
2025-12-10 19:39:56 +02:00

361 lines
8.9 KiB
Go

package updates
import (
"context"
"fmt"
"time"
"github.com/go-faster/errors"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
type channelUpdate struct {
update tg.UpdateClass
entities entities
span trace.SpanContext
}
type channelState struct {
// Updates from *internalState.
updates chan channelUpdate
// Channel to pass diff.OtherUpdates into *internalState.
out chan<- tracedUpdate
// Channel internalState.
pts *sequenceBox
idleTimeout *time.Timer
diffTimeout time.Time
// Immutable fields.
channelID int64
accessHash int64
selfID int64
diffLim int
client API
storage StateStorage
log *zap.Logger
tracer trace.Tracer
handler telegram.UpdateHandler
onTooLong func(channelID int64) error
runCtx context.Context
stop context.CancelCauseFunc
}
type channelStateConfig struct {
Out chan tracedUpdate
InitialPts int
ChannelID int64
AccessHash int64
SelfID int64
DiffLimit int
RawClient API
Storage StateStorage
Handler telegram.UpdateHandler
OnChannelTooLong func(channelID int64) error
Logger *zap.Logger
Tracer trace.Tracer
}
func newChannelState(cfg channelStateConfig) *channelState {
state := &channelState{
updates: make(chan channelUpdate, 10),
out: cfg.Out,
idleTimeout: time.NewTimer(idleTimeout),
channelID: cfg.ChannelID,
accessHash: cfg.AccessHash,
selfID: cfg.SelfID,
diffLim: cfg.DiffLimit,
client: cfg.RawClient,
storage: cfg.Storage,
log: cfg.Logger,
handler: cfg.Handler,
onTooLong: cfg.OnChannelTooLong,
tracer: cfg.Tracer,
}
state.pts = newSequenceBox(sequenceConfig{
InitialState: cfg.InitialPts,
Apply: state.applyPts,
Logger: cfg.Logger.Named("pts"),
Tracer: cfg.Tracer,
})
return state
}
func (s *channelState) Push(ctx context.Context, u channelUpdate) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-s.runCtx.Done():
return s.runCtx.Err()
case s.updates <- u:
return nil
}
}
var ErrRemoveChannelState = errors.New("remove channel state")
func (s *channelState) Run(ctx context.Context) error {
s.runCtx, s.stop = context.WithCancelCause(ctx)
defer s.stop(nil)
// Subscribe to channel updates.
if err := s.getDifference(s.runCtx); err != nil {
s.log.Error("Failed to subscribe to channel updates", zap.Error(err))
}
for {
select {
case u := <-s.updates:
ctx := trace.ContextWithSpanContext(s.runCtx, u.span)
if err := s.handleUpdate(ctx, u.update, u.entities); err != nil {
s.log.Error("Handle update error", zap.Error(err))
}
case <-s.pts.gapTimeout.C:
s.log.Debug("Gap timeout")
s.getDifferenceLogger(s.runCtx)
case <-s.runCtx.Done():
if cause := context.Cause(s.runCtx); cause != nil && ctx.Err() == nil {
return cause
}
return s.runCtx.Err()
case <-s.idleTimeout.C:
s.log.Debug("Idle timeout")
s.resetIdleTimer()
s.getDifferenceLogger(s.runCtx)
}
}
}
func (s *channelState) handleUpdate(ctx context.Context, u tg.UpdateClass, ents entities) error {
ctx, span := s.tracer.Start(ctx, "channelState.handleUpdate")
defer span.End()
s.resetIdleTimer()
if long, ok := u.(*tg.UpdateChannelTooLong); ok {
return s.handleTooLong(ctx, long)
}
channelID, pts, ptsCount, ok, err := tg.IsChannelPtsUpdate(u)
if err != nil {
return errors.Wrap(err, "invalid update")
}
if !ok {
return errors.Errorf("expected channel update, got: %T", u)
}
if channelID != s.channelID {
return errors.Errorf("update for wrong channel (channelID: %d)", channelID)
}
return s.pts.Handle(ctx, update{
Value: u,
State: pts,
Count: ptsCount,
Entities: ents,
})
}
func (s *channelState) handleTooLong(ctx context.Context, long *tg.UpdateChannelTooLong) error {
ctx, span := s.tracer.Start(ctx, "channelState.handleTooLong")
defer span.End()
remotePts, ok := long.GetPts()
if !ok {
s.log.Warn("Got UpdateChannelTooLong without pts field")
return s.getDifference(ctx)
}
// Note: we still can fetch latest diffLim updates.
// Should we do?
if remotePts-s.pts.State() > s.diffLim {
return s.onTooLong(s.channelID)
}
return s.getDifference(ctx)
}
func (s *channelState) applyPts(ctx context.Context, state int, updates []update) error {
ctx, span := s.tracer.Start(ctx, "channelState.applyPts")
defer span.End()
var (
converted []tg.UpdateClass
ents entities
)
for _, update := range updates {
converted = append(converted, update.Value.(tg.UpdateClass))
ents.Merge(update.Entities)
}
if err := s.handler.Handle(ctx, &tg.Updates{
Updates: converted,
Users: ents.Users,
Chats: ents.Chats,
}); err != nil {
s.log.Error("Handle update error", zap.Error(err))
return nil
}
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, state); err != nil {
s.log.Error("SetChannelPts error", zap.Error(err))
}
return nil
}
func (s *channelState) getDifference(ctx context.Context) error {
ctx, span := s.tracer.Start(ctx, "channelState.getDifference")
defer span.End()
s.pts.gaps.Clear()
s.log.Debug("Getting difference")
if now := time.Now(); now.Before(s.diffTimeout) {
dur := s.diffTimeout.Sub(now)
s.log.Debug("GetChannelDifference timeout", zap.Duration("duration", dur))
if err := func() error {
afterC := time.After(dur)
for {
select {
case <-afterC:
return nil
case _, ok := <-s.updates:
if !ok {
continue
}
// Ignoring updates to prevent *internalState worker from blocking.
// All ignored updates should be restored by future getChannelDifference call.
// At least I hope so...
s.log.Debug("Ignoring update due to getChannelDifference timeout") // , zap.Any("update", u.update))
case <-ctx.Done():
return ctx.Err()
}
}
}(); err != nil {
return err
}
}
diff, err := s.client.UpdatesGetChannelDifference(ctx, &tg.UpdatesGetChannelDifferenceRequest{
Channel: &tg.InputChannel{
ChannelID: s.channelID,
AccessHash: s.accessHash,
},
Filter: &tg.ChannelMessagesFilterEmpty{},
Pts: s.pts.State(),
Limit: s.diffLim,
})
if err != nil {
if tgerr.Is(err, tg.ErrChannelPrivate) || tgerr.Is(err, tg.ErrChannelInvalid) {
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, -1); err != nil {
s.log.Error("SetChannelPts error (clear)", zap.Error(err))
}
s.stop(fmt.Errorf("%w: %w", ErrRemoveChannelState, err))
}
return errors.Wrap(err, "get channel difference")
}
switch diff := diff.(type) {
case *tg.UpdatesChannelDifference:
if len(diff.OtherUpdates) > 0 {
select {
case s.out <- tracedUpdate{
span: trace.SpanContextFromContext(ctx),
update: &tg.Updates{
Updates: diff.OtherUpdates,
Users: diff.Users,
Chats: diff.Chats,
},
}:
case <-ctx.Done():
return ctx.Err()
}
}
if len(diff.NewMessages) > 0 {
if err := s.handler.Handle(ctx, &tg.Updates{
Updates: msgsToUpdates(diff.NewMessages, true),
Users: diff.Users,
Chats: diff.Chats,
}); err != nil {
return err
}
}
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil {
s.log.Warn("SetChannelPts error", zap.Error(err))
}
s.pts.SetState(diff.Pts, "updates.channelDifference")
if seconds, ok := diff.GetTimeout(); ok {
s.diffTimeout = time.Now().Add(time.Second * time.Duration(seconds))
}
if !diff.Final {
return s.getDifference(ctx)
}
return nil
case *tg.UpdatesChannelDifferenceEmpty:
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil {
s.log.Warn("SetChannelPts error", zap.Error(err))
}
s.pts.SetState(diff.Pts, "updates.channelDifferenceEmpty")
if seconds, ok := diff.GetTimeout(); ok {
s.diffTimeout = time.Now().Add(time.Second * time.Duration(seconds))
}
return nil
case *tg.UpdatesChannelDifferenceTooLong:
if seconds, ok := diff.GetTimeout(); ok {
s.diffTimeout = time.Now().Add(time.Second * time.Duration(seconds))
}
remotePts, err := getDialogPts(diff.Dialog)
if err != nil {
s.log.Warn("UpdatesChannelDifferenceTooLong invalid Dialog", zap.Error(err))
} else {
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, remotePts); err != nil {
s.log.Warn("SetChannelPts error", zap.Error(err))
}
s.pts.SetState(remotePts, "updates.channelDifferenceTooLong dialog new pts")
}
return s.onTooLong(s.channelID)
default:
return errors.Errorf("unexpected channel diff type: %T", diff)
}
}
func (s *channelState) getDifferenceLogger(ctx context.Context) {
if err := s.getDifference(ctx); err != nil {
s.log.Error("get channel difference error", zap.Error(err))
}
}
func (s *channelState) resetIdleTimer() {
if len(s.idleTimeout.C) > 0 {
<-s.idleTimeout.C
}
_ = s.idleTimeout.Reset(idleTimeout)
}