move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,134 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func (s *internalState) saveChannelHashes(ctx context.Context, chats []tg.ChatClass) {
|
||||
ctx, span := s.tracer.Start(ctx, "updates.saveChannelHashes")
|
||||
defer span.End()
|
||||
|
||||
for _, c := range chats {
|
||||
switch c := c.(type) {
|
||||
case *tg.Channel:
|
||||
if c.Min {
|
||||
continue
|
||||
}
|
||||
|
||||
if hash, ok := c.GetAccessHash(); ok {
|
||||
if _, ok = s.channels[c.ID]; ok {
|
||||
continue
|
||||
}
|
||||
s.log.Debug("New channel access hash",
|
||||
zap.Int64("channel_id", c.ID),
|
||||
zap.String("title", c.Title),
|
||||
)
|
||||
if err := s.hasher.SetChannelAccessHash(ctx, s.selfID, c.ID, hash); err != nil {
|
||||
s.log.Error("SetChannelAccessHash error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
case *tg.ChannelForbidden:
|
||||
if _, ok := s.channels[c.ID]; ok {
|
||||
continue
|
||||
}
|
||||
s.log.Debug("New channel access hash",
|
||||
zap.Int64("channel_id", c.ID),
|
||||
zap.String("title", c.Title),
|
||||
)
|
||||
if err := s.hasher.SetChannelAccessHash(ctx, s.selfID, c.ID, c.AccessHash); err != nil {
|
||||
s.log.Error("SetChannelAccessHash error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *internalState) saveUserHashes(ctx context.Context, chats []tg.UserClass) {
|
||||
ctx, span := s.tracer.Start(ctx, "updates.saveChannelHashes")
|
||||
defer span.End()
|
||||
|
||||
for _, u := range chats {
|
||||
if user, ok := u.(*tg.User); !ok {
|
||||
continue
|
||||
} else if hash, ok := user.GetAccessHash(); !ok {
|
||||
continue
|
||||
} else if user.Min {
|
||||
s.log.Debug("User is min, not saving access hash")
|
||||
continue
|
||||
} else {
|
||||
s.log.Debug("New user access hash", zap.Int64("user_id", user.ID))
|
||||
if err := s.hasher.SetUserAccessHash(ctx, s.selfID, user.ID, hash); err != nil {
|
||||
s.log.Error("SetUserAccessHash error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *internalState) handleDifference(ctx context.Context, date int) (chats []tg.ChatClass, users []tg.UserClass, err error) {
|
||||
ctx, span := s.tracer.Start(ctx, "updates.handleDifference")
|
||||
defer span.End()
|
||||
|
||||
diff, err := s.client.UpdatesGetDifference(ctx, &tg.UpdatesGetDifferenceRequest{
|
||||
Pts: s.pts.State(),
|
||||
Qts: s.qts.State(),
|
||||
Date: date,
|
||||
})
|
||||
if err != nil {
|
||||
s.log.Error("UpdatesGetDifference error", zap.Error(err))
|
||||
return nil, nil, fmt.Errorf("get difference: %w", err)
|
||||
}
|
||||
|
||||
switch diff := diff.(type) {
|
||||
case *tg.UpdatesDifference:
|
||||
chats = diff.Chats
|
||||
users = diff.Users
|
||||
case *tg.UpdatesDifferenceSlice:
|
||||
chats = diff.Chats
|
||||
users = diff.Users
|
||||
}
|
||||
|
||||
s.saveChannelHashes(ctx, chats)
|
||||
s.saveUserHashes(ctx, users)
|
||||
return chats, users, nil
|
||||
}
|
||||
|
||||
func (s *internalState) restoreChannelAccessHash(ctx context.Context, channelID int64, date int) (accessHash int64, ok bool) {
|
||||
ctx, span := s.tracer.Start(ctx, "updates.restoreAccessHash")
|
||||
defer span.End()
|
||||
|
||||
chats, _, err := s.handleDifference(ctx, date)
|
||||
if err != nil {
|
||||
s.log.Error("getDifference error", zap.Error(err))
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for _, c := range chats {
|
||||
switch c := c.(type) {
|
||||
case *tg.Channel:
|
||||
if c.Min {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.ID != channelID {
|
||||
continue
|
||||
}
|
||||
|
||||
if hash, ok := c.GetAccessHash(); ok {
|
||||
return hash, true
|
||||
}
|
||||
|
||||
case *tg.ChannelForbidden:
|
||||
if c.ID != channelID {
|
||||
continue
|
||||
}
|
||||
|
||||
return c.AccessHash, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// API is the interface which contains
|
||||
// Telegram RPC methods used by manager for internalState synchronization.
|
||||
type API interface {
|
||||
UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error)
|
||||
UpdatesGetDifference(ctx context.Context, request *tg.UpdatesGetDifferenceRequest) (tg.UpdatesDifferenceClass, error)
|
||||
UpdatesGetChannelDifference(ctx context.Context, request *tg.UpdatesGetChannelDifferenceRequest) (tg.UpdatesChannelDifferenceClass, error)
|
||||
}
|
||||
|
||||
// Config of the manager.
|
||||
type Config struct {
|
||||
// Handler where updates will be passed.
|
||||
Handler telegram.UpdateHandler
|
||||
// Callback called if manager cannot
|
||||
// recover channel gap (optional).
|
||||
OnChannelTooLong func(channelID int64) error
|
||||
// State storage.
|
||||
// In-mem used if not provided.
|
||||
Storage StateStorage
|
||||
// Channel access hash storage.
|
||||
// In-mem used if not provided.
|
||||
AccessHasher AccessHasher
|
||||
// Logger (optional).
|
||||
Logger *zap.Logger
|
||||
// TracerProvider (optional).
|
||||
TracerProvider trace.TracerProvider
|
||||
}
|
||||
|
||||
func (cfg *Config) setDefaults() {
|
||||
if cfg.Handler == nil {
|
||||
panic("Handler is nil")
|
||||
}
|
||||
if cfg.AccessHasher == nil {
|
||||
cfg.AccessHasher = newMemAccessHasher()
|
||||
}
|
||||
if cfg.Logger == nil {
|
||||
cfg.Logger = zap.NewNop()
|
||||
}
|
||||
if cfg.TracerProvider == nil {
|
||||
cfg.TracerProvider = trace.NewNoopTracerProvider()
|
||||
}
|
||||
if cfg.Storage == nil {
|
||||
cfg.Storage = newMemStorage()
|
||||
}
|
||||
if cfg.OnChannelTooLong == nil {
|
||||
cfg.OnChannelTooLong = func(channelID int64) error {
|
||||
cfg.Logger.Error("Difference too long", zap.Int64("channel_id", channelID))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package updates
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
|
||||
func convertOptional(msg *tg.Message, i tg.UpdatesClass) {
|
||||
if u, ok := i.(interface {
|
||||
GetFwdFrom() (tg.MessageFwdHeader, bool)
|
||||
}); ok {
|
||||
if v, ok := u.GetFwdFrom(); ok {
|
||||
msg.SetFwdFrom(v)
|
||||
}
|
||||
}
|
||||
if u, ok := i.(interface{ GetViaBotID() (int64, bool) }); ok {
|
||||
if v, ok := u.GetViaBotID(); ok {
|
||||
msg.SetViaBotID(v)
|
||||
}
|
||||
}
|
||||
if u, ok := i.(interface {
|
||||
GetReplyTo() (tg.MessageReplyHeaderClass, bool)
|
||||
}); ok {
|
||||
if v, ok := u.GetReplyTo(); ok {
|
||||
msg.SetReplyTo(v)
|
||||
}
|
||||
}
|
||||
if u, ok := i.(interface {
|
||||
GetReplyTo() (tg.MessageReplyHeader, bool)
|
||||
}); ok {
|
||||
if v, ok := u.GetReplyTo(); ok {
|
||||
msg.SetReplyTo(&v)
|
||||
}
|
||||
}
|
||||
if u, ok := i.(interface {
|
||||
GetEntities() ([]tg.MessageEntityClass, bool)
|
||||
}); ok {
|
||||
if v, ok := u.GetEntities(); ok {
|
||||
msg.SetEntities(v)
|
||||
}
|
||||
}
|
||||
if u, ok := i.(interface {
|
||||
GetMedia() (tg.MessageMediaClass, bool)
|
||||
}); ok {
|
||||
if v, ok := u.GetMedia(); ok {
|
||||
msg.SetMedia(v)
|
||||
}
|
||||
}
|
||||
if u, ok := i.(interface {
|
||||
GetTTLPeriod() (int, bool)
|
||||
}); ok {
|
||||
if v, ok := u.GetTTLPeriod(); ok {
|
||||
msg.SetTTLPeriod(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *internalState) convertShortMessage(u *tg.UpdateShortMessage) *tg.UpdateShort {
|
||||
msg := &tg.Message{
|
||||
ID: u.ID,
|
||||
PeerID: &tg.PeerUser{UserID: u.UserID},
|
||||
Message: u.Message,
|
||||
Date: u.Date,
|
||||
}
|
||||
// Optional fields should set by SetXXX(), so GetXXX and Flags.Has()
|
||||
// can return the right values even we hav't call .Encode()
|
||||
msg.SetOut(u.Out)
|
||||
msg.SetMentioned(u.Mentioned)
|
||||
msg.SetMediaUnread(u.MediaUnread)
|
||||
msg.SetSilent(u.Silent)
|
||||
|
||||
msg.SetFromID(&tg.PeerUser{UserID: s.selfID})
|
||||
if !u.Out {
|
||||
msg.SetFromID(&tg.PeerUser{UserID: u.UserID})
|
||||
}
|
||||
convertOptional(msg, u)
|
||||
|
||||
return &tg.UpdateShort{
|
||||
Update: &tg.UpdateNewMessage{
|
||||
Message: msg,
|
||||
Pts: u.Pts,
|
||||
PtsCount: u.PtsCount,
|
||||
},
|
||||
Date: u.Date,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *internalState) convertShortChatMessage(u *tg.UpdateShortChatMessage) *tg.UpdateShort {
|
||||
msg := &tg.Message{
|
||||
ID: u.ID,
|
||||
PeerID: &tg.PeerChat{ChatID: u.ChatID},
|
||||
Message: u.Message,
|
||||
Date: u.Date,
|
||||
}
|
||||
|
||||
msg.SetOut(u.Out)
|
||||
msg.SetMentioned(u.Mentioned)
|
||||
msg.SetMediaUnread(u.MediaUnread)
|
||||
msg.SetSilent(u.Silent)
|
||||
msg.SetFromScheduled(msg.FromScheduled)
|
||||
msg.SetFromID(&tg.PeerUser{UserID: u.FromID})
|
||||
convertOptional(msg, u)
|
||||
|
||||
return &tg.UpdateShort{
|
||||
Update: &tg.UpdateNewMessage{
|
||||
Message: msg,
|
||||
Pts: u.Pts,
|
||||
PtsCount: u.PtsCount,
|
||||
},
|
||||
Date: u.Date,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *internalState) convertShortSentMessage(u *tg.UpdateShortSentMessage) *tg.UpdateShort {
|
||||
// This update should be converted by the one who called the method
|
||||
// that returned this update, because we do not have any context about
|
||||
// it (message text, sender/recipient, etc.)
|
||||
//
|
||||
// In theory, this update should come only as a response to an RPC call,
|
||||
// and we get it here because of the update hook.
|
||||
// We use it to make sure there are no pts gaps.
|
||||
return &tg.UpdateShort{
|
||||
Update: &tg.UpdateNewMessage{
|
||||
Message: &tg.MessageEmpty{
|
||||
ID: u.ID,
|
||||
},
|
||||
Pts: u.Pts,
|
||||
PtsCount: u.PtsCount,
|
||||
},
|
||||
Date: u.Date,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// Package updates provides a Telegram's internalState synchronization manager.
|
||||
//
|
||||
// It guarantees that all internalState-sensitive updates will be performed
|
||||
// in correct order.
|
||||
//
|
||||
// Limitations:
|
||||
//
|
||||
// 1. Manager cannot verify stateless types of updates
|
||||
// (tg.UpdatesClass without Seq, or tg.UpdateClass without Pts or Qts).
|
||||
//
|
||||
// 2. Due to the fact that updates.getDifference and updates.getChannelDifference
|
||||
// do not return event sequences, manager cannot guarantee the correctness
|
||||
// of these operations. We rely on the server here.
|
||||
//
|
||||
// 3. Manager cannot recover the channel gap if there is a ChannelDifferenceTooLong error.
|
||||
// Restoring the internalState in such situation is not the prerogative of this manager.
|
||||
// See: https://core.telegram.org/constructor/updates.channelDifferenceTooLong
|
||||
//
|
||||
// TODO: Write implementation details.
|
||||
package updates
|
||||
@@ -0,0 +1,54 @@
|
||||
package updates
|
||||
|
||||
import "go.uber.org/zap/zapcore"
|
||||
|
||||
type gap struct {
|
||||
from, to int
|
||||
}
|
||||
|
||||
type gapBuffer struct {
|
||||
gaps []gap
|
||||
}
|
||||
|
||||
func (b gapBuffer) Has() bool { return len(b.gaps) > 0 }
|
||||
|
||||
func (b *gapBuffer) Clear() { b.gaps = make([]gap, 0, 1) }
|
||||
|
||||
func (b *gapBuffer) Enable(from, to int) {
|
||||
if len(b.gaps) > 0 {
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
b.gaps = append(b.gaps, gap{from, to})
|
||||
}
|
||||
|
||||
func (b *gapBuffer) Consume(u update) (accepted bool) {
|
||||
for i, g := range b.gaps {
|
||||
if g.from <= u.start() && g.to >= u.end() {
|
||||
if g.from < u.start() {
|
||||
b.gaps = append(b.gaps, gap{from: g.from, to: u.start()})
|
||||
}
|
||||
if g.to > u.end() {
|
||||
b.gaps = append(b.gaps, gap{from: u.end(), to: g.to})
|
||||
}
|
||||
|
||||
b.gaps = append(b.gaps[:i], b.gaps[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (b gapBuffer) MarshalLogArray(e zapcore.ArrayEncoder) error {
|
||||
for _, g := range b.gaps {
|
||||
if err := e.AppendObject(zapcore.ObjectMarshalerFunc(func(e zapcore.ObjectEncoder) error {
|
||||
e.AddInt("from", g.from)
|
||||
e.AddInt("to", g.to)
|
||||
return nil
|
||||
})); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGapBuffer(t *testing.T) {
|
||||
buf := new(gapBuffer)
|
||||
buf.Enable(1, 7)
|
||||
|
||||
require.True(t, buf.Consume(update{State: 4, Count: 3}))
|
||||
require.Equal(t, []gap{{4, 7}}, buf.gaps)
|
||||
|
||||
require.False(t, buf.Consume(update{State: 1, Count: 1}))
|
||||
require.Equal(t, []gap{{4, 7}}, buf.gaps)
|
||||
|
||||
require.False(t, buf.Consume(update{State: 8, Count: 1}))
|
||||
require.Equal(t, []gap{{4, 7}}, buf.gaps)
|
||||
|
||||
require.True(t, buf.Consume(update{State: 7, Count: 3}))
|
||||
require.Empty(t, buf.gaps)
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package updates
|
||||
|
||||
type gapCheckResult byte
|
||||
|
||||
const (
|
||||
_ gapCheckResult = iota
|
||||
gapApply
|
||||
gapIgnore
|
||||
gapRefetch
|
||||
)
|
||||
|
||||
func checkGap(localState, remoteState, count int) gapCheckResult {
|
||||
// Temporary fix for handling qts updates gaps.
|
||||
if remoteState == 0 {
|
||||
return gapApply
|
||||
}
|
||||
|
||||
if localState+count == remoteState {
|
||||
return gapApply
|
||||
}
|
||||
|
||||
if localState+count > remoteState {
|
||||
return gapIgnore
|
||||
}
|
||||
|
||||
return gapRefetch
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
// Package hook contains telegram update hook middleware.
|
||||
package hook
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// UpdateHook middleware is called on each tg.UpdatesClass method result.
|
||||
//
|
||||
// Function is called before invoker return. Returned error will be wrapped
|
||||
// and returned as InvokeRaw result.
|
||||
type UpdateHook func(ctx context.Context, u tg.UpdatesClass) error
|
||||
|
||||
// Handle implements telegram.Middleware.
|
||||
func (h UpdateHook) Handle(next tg.Invoker) telegram.InvokeFunc {
|
||||
return func(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
if err := next.Invoke(ctx, input, output); err != nil {
|
||||
return err
|
||||
}
|
||||
if u, ok := output.(*tg.UpdatesBox); ok {
|
||||
if err := h(ctx, u.Updates); err != nil {
|
||||
return errors.Wrap(err, "hook")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func TestUpdateHook_InvokeRaw(t *testing.T) {
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
var invokerCalled, hookCalled bool
|
||||
assert.NoError(t, UpdateHook(func(ctx context.Context, u tg.UpdatesClass) error {
|
||||
assert.NotNil(t, u)
|
||||
hookCalled = true
|
||||
return nil
|
||||
}).Handle(telegram.InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
invokerCalled = true
|
||||
return nil
|
||||
})).Invoke(context.TODO(), nil, &tg.UpdatesBox{
|
||||
Updates: &tg.UpdateShortMessage{
|
||||
ID: 100,
|
||||
},
|
||||
}))
|
||||
|
||||
assert.True(t, invokerCalled, "invoker should be called")
|
||||
assert.True(t, hookCalled, "hook should be called")
|
||||
})
|
||||
t.Run("Error", func(t *testing.T) {
|
||||
t.Run("Handler", func(t *testing.T) {
|
||||
var invokerCalled, hookCalled bool
|
||||
err := errors.New("failure")
|
||||
assert.ErrorIs(t, UpdateHook(func(ctx context.Context, u tg.UpdatesClass) error {
|
||||
assert.NotNil(t, u)
|
||||
hookCalled = true
|
||||
return nil
|
||||
}).Handle(telegram.InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
invokerCalled = true
|
||||
return err
|
||||
})).Invoke(context.TODO(), nil, &tg.UpdatesBox{
|
||||
Updates: &tg.UpdateShortMessage{
|
||||
ID: 100,
|
||||
},
|
||||
}), err)
|
||||
|
||||
assert.True(t, invokerCalled, "invoker should be called")
|
||||
assert.False(t, hookCalled, "hook should not be called")
|
||||
})
|
||||
t.Run("Hook", func(t *testing.T) {
|
||||
var invokerCalled, hookCalled bool
|
||||
err := errors.New("failure")
|
||||
assert.ErrorIs(t, UpdateHook(func(ctx context.Context, u tg.UpdatesClass) error {
|
||||
assert.NotNil(t, u)
|
||||
hookCalled = true
|
||||
return err
|
||||
}).Handle(telegram.InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
invokerCalled = true
|
||||
return nil
|
||||
})).Invoke(context.TODO(), nil, &tg.UpdatesBox{
|
||||
Updates: &tg.UpdateShortMessage{
|
||||
ID: 100,
|
||||
},
|
||||
}), err)
|
||||
|
||||
assert.True(t, invokerCalled, "invoker should be called")
|
||||
assert.True(t, hookCalled, "hook should be called")
|
||||
})
|
||||
})
|
||||
t.Run("Not update", func(t *testing.T) {
|
||||
var invokerCalled, hookCalled bool
|
||||
assert.NoError(t, UpdateHook(func(ctx context.Context, u tg.UpdatesClass) error {
|
||||
assert.NotNil(t, u)
|
||||
hookCalled = true
|
||||
return nil
|
||||
}).Handle(telegram.InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
invokerCalled = true
|
||||
return nil
|
||||
})).Invoke(context.TODO(), nil, &tg.User{}))
|
||||
|
||||
assert.True(t, invokerCalled, "invoker should be called")
|
||||
assert.False(t, hookCalled, "hook should not be called")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package e2e
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
|
||||
// Entities contains update entities.
|
||||
type Entities struct {
|
||||
Users map[int64]*tg.User
|
||||
Chats map[int64]*tg.Chat
|
||||
Channels map[int64]*tg.Channel
|
||||
ChannelsForbidden map[int64]*tg.ChannelForbidden
|
||||
}
|
||||
|
||||
// NewEntities creates new Entities.
|
||||
func NewEntities() *Entities {
|
||||
return &Entities{
|
||||
Users: map[int64]*tg.User{},
|
||||
Chats: map[int64]*tg.Chat{},
|
||||
Channels: map[int64]*tg.Channel{},
|
||||
ChannelsForbidden: map[int64]*tg.ChannelForbidden{},
|
||||
}
|
||||
}
|
||||
|
||||
// Merge merges entities.
|
||||
func (e *Entities) Merge(from *Entities) {
|
||||
if from == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for userID, user := range from.Users {
|
||||
e.Users[userID] = user
|
||||
}
|
||||
|
||||
for chanID, chat := range from.Chats {
|
||||
e.Chats[chanID] = chat
|
||||
}
|
||||
|
||||
for channelID, channel := range from.Channels {
|
||||
e.Channels[channelID] = channel
|
||||
}
|
||||
|
||||
for channelID, channel := range from.ChannelsForbidden {
|
||||
e.ChannelsForbidden[channelID] = channel
|
||||
}
|
||||
}
|
||||
|
||||
// FromUpdates method.
|
||||
func (e *Entities) FromUpdates(u interface {
|
||||
tg.UpdatesClass
|
||||
MapUsers() tg.UserClassArray
|
||||
MapChats() tg.ChatClassArray
|
||||
}) *Entities {
|
||||
u.MapChats().FillChatMap(e.Chats)
|
||||
u.MapChats().FillChannelMap(e.Channels)
|
||||
u.MapChats().FillChannelForbiddenMap(e.ChannelsForbidden)
|
||||
u.MapUsers().FillUserMap(e.Users)
|
||||
return e
|
||||
}
|
||||
|
||||
// AsUsers returns users as tg.UserClass slice.
|
||||
func (e *Entities) AsUsers() []tg.UserClass {
|
||||
var users []tg.UserClass
|
||||
for _, u := range e.Users {
|
||||
users = append(users, u)
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
// AsChats returns chats as tg.ChatClass slice.
|
||||
func (e *Entities) AsChats() []tg.ChatClass {
|
||||
var chats []tg.ChatClass
|
||||
for _, c := range e.Chats {
|
||||
chats = append(chats, c)
|
||||
}
|
||||
for _, c := range e.Channels {
|
||||
chats = append(chats, c)
|
||||
}
|
||||
for _, c := range e.ChannelsForbidden {
|
||||
chats = append(chats, c)
|
||||
}
|
||||
return chats
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Handler handles updates.
|
||||
type handler struct {
|
||||
messages *messageDatabase
|
||||
ents *Entities
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newHandler() *handler {
|
||||
return &handler{
|
||||
messages: &messageDatabase{
|
||||
channels: make(map[int64][]tg.MessageClass),
|
||||
},
|
||||
ents: NewEntities(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) Handle(ctx context.Context, u tg.UpdatesClass) error {
|
||||
switch u := u.(type) {
|
||||
case *tg.Updates:
|
||||
return h.handleUpdates(NewEntities().FromUpdates(u), u.Updates)
|
||||
case *tg.UpdatesCombined:
|
||||
return h.handleUpdates(NewEntities().FromUpdates(u), u.Updates)
|
||||
default:
|
||||
panic(u)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleUpdates handler.
|
||||
func (h *handler) handleUpdates(ents *Entities, upds []tg.UpdateClass) error {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
|
||||
h.ents.Merge(ents)
|
||||
for _, u := range upds {
|
||||
switch u := u.(type) {
|
||||
case *tg.UpdateNewMessage:
|
||||
h.messages.common = append(h.messages.common, u.Message)
|
||||
case *tg.UpdateNewEncryptedMessage:
|
||||
h.messages.secret = append(h.messages.secret, u.Message)
|
||||
case *tg.UpdateNewChannelMessage:
|
||||
channelID := u.Message.(*tg.Message).PeerID.(*tg.PeerChannel).ChannelID
|
||||
msgs := h.messages.channels[channelID]
|
||||
msgs = append(msgs, u.Message)
|
||||
h.messages.channels[channelID] = msgs
|
||||
default:
|
||||
panic("unexpected update type")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
//go:build linux
|
||||
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/updates"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func TestE2E(t *testing.T) {
|
||||
testManager(t, func(s *server, storage updates.StateStorage) chan *tg.Updates {
|
||||
t.Helper()
|
||||
|
||||
c := make(chan *tg.Updates, 10)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
biba = s.peers.createUser("biba")
|
||||
boba = s.peers.createUser("boba")
|
||||
chat = s.peers.createChat("chat")
|
||||
)
|
||||
|
||||
var channels []*tg.PeerChannel
|
||||
require.NoError(t, storage.ForEachChannels(ctx, 123, func(ctx context.Context, channelID int64, pts int) error {
|
||||
channels = append(channels, &tg.PeerChannel{
|
||||
ChannelID: channelID,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Biba.
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 3; i++ {
|
||||
c <- s.CreateEvent(func(ev *EventBuilder) {
|
||||
ev.SendMessage(biba, chat, fmt.Sprintf("biba-%d", i))
|
||||
|
||||
for mi, c := range channels {
|
||||
ev.SendMessage(biba, c, fmt.Sprintf("biba-channel-%d-%d", i, mi))
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
// Boba.
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 3; i++ {
|
||||
c <- s.CreateEvent(func(ev *EventBuilder) {
|
||||
ev.SendMessage(boba, chat, fmt.Sprintf("boba-%d", i))
|
||||
|
||||
for _, c := range channels {
|
||||
ev.SendMessage(boba, c, fmt.Sprintf("boba-channel-%d", i))
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(c)
|
||||
}()
|
||||
return c
|
||||
})
|
||||
}
|
||||
|
||||
func testManager(t *testing.T, f func(s *server, storage updates.StateStorage) chan *tg.Updates) {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var (
|
||||
log = zaptest.NewLogger(t)
|
||||
s = newServer()
|
||||
h = newHandler()
|
||||
storage = newMemStorage()
|
||||
hasher = newMemAccessHasher()
|
||||
)
|
||||
|
||||
const uid = 123
|
||||
|
||||
require.NoError(t, storage.SetState(ctx, uid, updates.State{
|
||||
Pts: 0,
|
||||
Qts: 0,
|
||||
Date: 0,
|
||||
Seq: 0,
|
||||
}))
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
c := s.peers.createChannel(fmt.Sprintf("channel-%d", i))
|
||||
require.NoError(t, storage.SetChannelPts(ctx, uid, c.ChannelID, 0))
|
||||
require.NoError(t, hasher.SetChannelAccessHash(ctx, uid, c.ChannelID, c.ChannelID*2))
|
||||
}
|
||||
|
||||
e := updates.New(updates.Config{
|
||||
Handler: h,
|
||||
Logger: log.Named("gaps"),
|
||||
Storage: storage,
|
||||
AccessHasher: hasher,
|
||||
})
|
||||
|
||||
uchan := loss(f(s, storage))
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
ready := make(chan struct{})
|
||||
opts := updates.AuthOptions{
|
||||
OnStart: func(ctx context.Context) {
|
||||
t.Log("OnStart")
|
||||
close(ready)
|
||||
},
|
||||
}
|
||||
g.Go(func() error {
|
||||
t.Log("Starting manager")
|
||||
defer t.Log("Manager stopped")
|
||||
return e.Run(ctx, s, uid, opts)
|
||||
})
|
||||
g.Go(func() error {
|
||||
t.Log("Starting updates generator")
|
||||
defer t.Log("Updates generator stopped")
|
||||
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-ready:
|
||||
t.Log("Ready")
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var g errgroup.Group
|
||||
for i := 0; i < 2; i++ {
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case u, ok := <-uchan:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.Handle(ctx, u); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Log("Waiting")
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Log("Sending pts changed")
|
||||
|
||||
ups := []tg.UpdateClass{&tg.UpdatePtsChanged{}}
|
||||
if err := storage.ForEachChannels(ctx, uid, func(ctx context.Context, channelID int64, pts int) error {
|
||||
ups = append(ups, &tg.UpdateChannelTooLong{ChannelID: channelID})
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Log("Handle")
|
||||
|
||||
return e.Handle(ctx, &tg.Updates{
|
||||
Updates: ups,
|
||||
})
|
||||
})
|
||||
|
||||
t.Log("Waiting for shutdown")
|
||||
require.ErrorIs(t, g.Wait(), context.Canceled)
|
||||
|
||||
t.Log("Checking")
|
||||
require.Equal(t, s.messages, h.messages)
|
||||
require.Equal(t, s.peers.channels, h.ents.Channels)
|
||||
require.Equal(t, s.peers.chats, h.ents.Chats)
|
||||
require.Equal(t, s.peers.users, h.ents.Users)
|
||||
}
|
||||
|
||||
func loss(in chan *tg.Updates) chan *tg.Updates {
|
||||
out := make(chan *tg.Updates)
|
||||
|
||||
go func() {
|
||||
defer close(out)
|
||||
|
||||
for u := range in {
|
||||
if rand.Intn(5) == 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
out <- u
|
||||
}
|
||||
}()
|
||||
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package e2e
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
|
||||
type messageDatabase struct {
|
||||
common []tg.MessageClass
|
||||
secret []tg.EncryptedMessageClass
|
||||
channels map[int64][]tg.MessageClass
|
||||
}
|
||||
|
||||
type peerDatabase struct {
|
||||
users map[int64]*tg.User
|
||||
chats map[int64]*tg.Chat
|
||||
channels map[int64]*tg.Channel
|
||||
|
||||
id int64
|
||||
}
|
||||
|
||||
func (p *peerDatabase) createUser(username string) *tg.PeerUser {
|
||||
p.users[p.id] = &tg.User{
|
||||
ID: p.id,
|
||||
Username: username,
|
||||
}
|
||||
|
||||
defer func() { p.id++ }()
|
||||
return &tg.PeerUser{UserID: p.id}
|
||||
}
|
||||
|
||||
func (p *peerDatabase) createChat(title string) *tg.PeerChat {
|
||||
p.chats[p.id] = &tg.Chat{
|
||||
ID: p.id,
|
||||
Title: title,
|
||||
}
|
||||
|
||||
defer func() { p.id++ }()
|
||||
return &tg.PeerChat{ChatID: p.id}
|
||||
}
|
||||
|
||||
func (p *peerDatabase) createChannel(username string) *tg.PeerChannel {
|
||||
p.channels[p.id] = &tg.Channel{
|
||||
ID: p.id,
|
||||
Username: username,
|
||||
}
|
||||
p.channels[p.id].SetAccessHash(p.id * 2)
|
||||
|
||||
defer func() { p.id++ }()
|
||||
return &tg.PeerChannel{ChannelID: p.id}
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
// Package e2e contains end-to-end updates processing test.
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Server for testing gaps.
|
||||
type server struct {
|
||||
date int
|
||||
peers *peerDatabase
|
||||
messages *messageDatabase
|
||||
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// NewServer creates new test server.
|
||||
func newServer() *server {
|
||||
return &server{
|
||||
date: 1,
|
||||
peers: &peerDatabase{
|
||||
users: make(map[int64]*tg.User),
|
||||
chats: make(map[int64]*tg.Chat),
|
||||
channels: make(map[int64]*tg.Channel),
|
||||
},
|
||||
messages: &messageDatabase{
|
||||
channels: make(map[int64][]tg.MessageClass),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UpdatesGetState returns current remote state.
|
||||
func (s *server) UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
return &tg.UpdatesState{
|
||||
Pts: len(s.messages.common),
|
||||
Qts: len(s.messages.secret),
|
||||
Date: s.date,
|
||||
Seq: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdatesGetDifference returns difference between local and remote states.
|
||||
func (s *server) UpdatesGetDifference(ctx context.Context, request *tg.UpdatesGetDifferenceRequest) (tg.UpdatesDifferenceClass, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
ents := NewEntities()
|
||||
var common []tg.MessageClass
|
||||
for i := request.Pts + 1; i <= len(s.messages.common); i++ {
|
||||
common = append(common, s.messages.common[i-1])
|
||||
s.fillMessageEnts(s.messages.common[i-1], ents)
|
||||
}
|
||||
|
||||
var secret []tg.EncryptedMessageClass
|
||||
for i := request.Qts + 1; i <= len(s.messages.secret); i++ {
|
||||
secret = append(secret, s.messages.secret[i-1])
|
||||
}
|
||||
|
||||
var others []tg.UpdateClass
|
||||
for _, msgs := range s.messages.channels {
|
||||
for i, msg := range msgs {
|
||||
if msg.(*tg.Message).Date > request.Date {
|
||||
others = append(others, &tg.UpdateNewChannelMessage{
|
||||
Message: msg,
|
||||
Pts: i + 1,
|
||||
PtsCount: 1,
|
||||
})
|
||||
s.fillMessageEnts(msg, ents)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(common) == 0 && len(secret) == 0 && len(others) == 0 {
|
||||
return &tg.UpdatesDifferenceEmpty{
|
||||
Date: s.date,
|
||||
Seq: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &tg.UpdatesDifference{
|
||||
NewMessages: common,
|
||||
NewEncryptedMessages: secret,
|
||||
OtherUpdates: others,
|
||||
Users: ents.AsUsers(),
|
||||
Chats: ents.AsChats(),
|
||||
State: tg.UpdatesState{
|
||||
Pts: len(s.messages.common),
|
||||
Qts: len(s.messages.secret),
|
||||
Date: s.date,
|
||||
Seq: 0,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdatesGetChannelDifference returns difference between local and remote channel states.
|
||||
func (s *server) UpdatesGetChannelDifference(
|
||||
ctx context.Context, request *tg.UpdatesGetChannelDifferenceRequest,
|
||||
) (tg.UpdatesChannelDifferenceClass, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
channel, ok := request.Channel.(*tg.InputChannel)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("bad InputChannelClass type: %T", request.Channel)
|
||||
}
|
||||
|
||||
if peer, ok := s.peers.channels[channel.ChannelID]; true {
|
||||
if !ok {
|
||||
return nil, errors.Errorf("channel %d not found", channel.ChannelID)
|
||||
}
|
||||
|
||||
if peer.AccessHash != channel.AccessHash {
|
||||
return nil, errors.New("invalid access hash")
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
channelMsgs = s.messages.channels[channel.ChannelID]
|
||||
ents = NewEntities()
|
||||
prepared []tg.MessageClass
|
||||
)
|
||||
|
||||
for i := request.Pts + 1; i <= len(channelMsgs); i++ {
|
||||
prepared = append(prepared, channelMsgs[i-1])
|
||||
s.fillMessageEnts(channelMsgs[i-1], ents)
|
||||
}
|
||||
|
||||
if len(prepared) == 0 {
|
||||
return &tg.UpdatesChannelDifferenceEmpty{
|
||||
Pts: len(channelMsgs),
|
||||
Final: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &tg.UpdatesChannelDifference{
|
||||
NewMessages: prepared,
|
||||
Users: ents.AsUsers(),
|
||||
Chats: ents.AsChats(),
|
||||
Pts: len(channelMsgs),
|
||||
Final: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *server) fillMessageEnts(msg tg.MessageClass, ents *Entities) {
|
||||
switch peer := msg.(*tg.Message).PeerID.(type) {
|
||||
case *tg.PeerUser:
|
||||
user, ok := s.peers.users[peer.UserID]
|
||||
if !ok {
|
||||
panic("bad user")
|
||||
}
|
||||
|
||||
ents.Users[user.ID] = user
|
||||
case *tg.PeerChat:
|
||||
chat, ok := s.peers.chats[peer.ChatID]
|
||||
if !ok {
|
||||
panic("bad chat")
|
||||
}
|
||||
|
||||
ents.Chats[chat.ID] = chat
|
||||
case *tg.PeerChannel:
|
||||
channel, ok := s.peers.channels[peer.ChannelID]
|
||||
if !ok {
|
||||
panic("bad channel")
|
||||
}
|
||||
|
||||
ents.Channels[channel.ID] = channel
|
||||
default:
|
||||
panic("unexpected peer type")
|
||||
}
|
||||
|
||||
peerUser, ok := msg.(*tg.Message).FromID.(*tg.PeerUser)
|
||||
if !ok {
|
||||
panic("bad fromID")
|
||||
}
|
||||
|
||||
user, ok := s.peers.users[peerUser.UserID]
|
||||
if !ok {
|
||||
panic("bad user")
|
||||
}
|
||||
|
||||
ents.Users[user.ID] = user
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// EventBuilder struct.
|
||||
type EventBuilder struct {
|
||||
updates []tg.UpdateClass
|
||||
ents *Entities
|
||||
s *server
|
||||
date int
|
||||
}
|
||||
|
||||
// SendMessage send a new message.
|
||||
func (e *EventBuilder) SendMessage(from *tg.PeerUser, peer tg.PeerClass, text string) {
|
||||
msg := &tg.Message{
|
||||
Message: text,
|
||||
PeerID: peer,
|
||||
FromID: from,
|
||||
Date: e.date,
|
||||
}
|
||||
|
||||
fromUser, ok := e.s.peers.users[from.UserID]
|
||||
if !ok {
|
||||
panic("bad fromID")
|
||||
}
|
||||
e.ents.Users[from.UserID] = fromUser
|
||||
|
||||
switch peer := peer.(type) {
|
||||
case *tg.PeerUser:
|
||||
user, ok := e.s.peers.users[peer.UserID]
|
||||
if !ok {
|
||||
panic("peer not found")
|
||||
}
|
||||
|
||||
e.ents.Users[user.ID] = user
|
||||
e.s.messages.common = append(e.s.messages.common, msg)
|
||||
e.updates = append(e.updates, &tg.UpdateNewMessage{
|
||||
Message: msg,
|
||||
Pts: len(e.s.messages.common),
|
||||
PtsCount: 1,
|
||||
})
|
||||
case *tg.PeerChat:
|
||||
chat, ok := e.s.peers.chats[peer.ChatID]
|
||||
if !ok {
|
||||
panic("peer not found")
|
||||
}
|
||||
|
||||
e.ents.Chats[chat.ID] = chat
|
||||
e.s.messages.common = append(e.s.messages.common, msg)
|
||||
e.updates = append(e.updates, &tg.UpdateNewMessage{
|
||||
Message: msg,
|
||||
Pts: len(e.s.messages.common),
|
||||
PtsCount: 1,
|
||||
})
|
||||
case *tg.PeerChannel:
|
||||
channel, ok := e.s.peers.channels[peer.ChannelID]
|
||||
if !ok {
|
||||
panic("peer not found")
|
||||
}
|
||||
|
||||
e.ents.Channels[channel.ID] = channel
|
||||
msgs := append(e.s.messages.channels[peer.ChannelID], msg)
|
||||
e.s.messages.channels[peer.ChannelID] = msgs
|
||||
e.updates = append(e.updates, &tg.UpdateNewChannelMessage{
|
||||
Message: msg,
|
||||
Pts: len(msgs),
|
||||
PtsCount: 1,
|
||||
})
|
||||
default:
|
||||
panic("unexpected peer type")
|
||||
}
|
||||
}
|
||||
|
||||
// CreateEvent creates new event.
|
||||
func (s *server) CreateEvent(f func(ev *EventBuilder)) *tg.Updates {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.date++
|
||||
ev := &EventBuilder{
|
||||
ents: NewEntities(),
|
||||
s: s,
|
||||
date: s.date,
|
||||
}
|
||||
f(ev)
|
||||
|
||||
return &tg.Updates{
|
||||
Updates: ev.updates,
|
||||
Users: ev.ents.AsUsers(),
|
||||
Chats: ev.ents.AsChats(),
|
||||
Date: s.date,
|
||||
Seq: 0,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/updates"
|
||||
)
|
||||
|
||||
var _ updates.StateStorage = (*memStorage)(nil)
|
||||
|
||||
type memStorage struct {
|
||||
states map[int64]updates.State
|
||||
channels map[int64]map[int64]int
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newMemStorage() *memStorage {
|
||||
return &memStorage{
|
||||
states: map[int64]updates.State{},
|
||||
channels: map[int64]map[int64]int{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *memStorage) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, found = s.states[userID]
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) SetState(ctx context.Context, userID int64, state updates.State) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.states[userID] = state
|
||||
s.channels[userID] = map[int64]int{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetPts(ctx context.Context, userID int64, pts int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
|
||||
state.Pts = pts
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetQts(ctx context.Context, userID int64, qts int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
|
||||
state.Qts = qts
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetDate(ctx context.Context, userID int64, date int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
|
||||
state.Date = date
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetSeq(ctx context.Context, userID int64, seq int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
|
||||
state.Seq = seq
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetDateSeq(ctx context.Context, userID int64, date, seq int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
|
||||
state.Date = date
|
||||
state.Seq = seq
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
channels, ok := s.channels[userID]
|
||||
if !ok {
|
||||
return errors.New("user state does not exist")
|
||||
}
|
||||
|
||||
channels[channelID] = pts
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) GetChannelPts(ctx context.Context, userID, channelID int64) (pts int, found bool, err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
channels, ok := s.channels[userID]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
pts, found = channels[channelID]
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
cmap, ok := s.channels[userID]
|
||||
if !ok {
|
||||
return errors.New("channels map does not exist")
|
||||
}
|
||||
|
||||
for id, pts := range cmap {
|
||||
if err := f(ctx, id, pts); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ updates.AccessHasher = (*memAccessHasher)(nil)
|
||||
|
||||
type memAccessHasher struct {
|
||||
channelHashes map[int64]map[int64]int64
|
||||
userHashes map[int64]map[int64]int64
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newMemAccessHasher() *memAccessHasher {
|
||||
return &memAccessHasher{
|
||||
channelHashes: map[int64]map[int64]int64{},
|
||||
userHashes: map[int64]map[int64]int64{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) GetChannelAccessHash(ctx context.Context, forUserID, channelID int64) (accessHash int64, found bool, err error) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.channelHashes[forUserID]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
accessHash, found = accessHashes[channelID]
|
||||
return
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) SetChannelAccessHash(ctx context.Context, forUserID, channelID, accessHash int64) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.channelHashes[forUserID]
|
||||
if !ok {
|
||||
accessHashes = map[int64]int64{}
|
||||
m.channelHashes[forUserID] = accessHashes
|
||||
}
|
||||
|
||||
accessHashes[channelID] = accessHash
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) GetUserAccessHash(ctx context.Context, forUserID, userID int64) (accessHash int64, found bool, err error) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.userHashes[forUserID]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
accessHash, found = accessHashes[userID]
|
||||
return
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) SetUserAccessHash(ctx context.Context, forUserID, userID, accessHash int64) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.userHashes[forUserID]
|
||||
if !ok {
|
||||
accessHashes = map[int64]int64{}
|
||||
m.channelHashes[forUserID] = accessHashes
|
||||
}
|
||||
|
||||
accessHashes[userID] = accessHash
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
var _ telegram.UpdateHandler = (*Manager)(nil)
|
||||
|
||||
// Manager deals with gaps.
|
||||
//
|
||||
// Important:
|
||||
// Updates produced by this manager may contain
|
||||
// negative Pts/Qts/Seq values in tg.UpdateClass/tg.UpdatesClass
|
||||
// (does not affects to the tg.MessageClass).
|
||||
//
|
||||
// This is because telegram server does not return these sequences
|
||||
// for getDifference/getChannelDifference results.
|
||||
// You SHOULD NOT use them in update handlers at all.
|
||||
type Manager struct {
|
||||
state *internalState
|
||||
mux sync.Mutex
|
||||
|
||||
// immutable:
|
||||
|
||||
cfg Config
|
||||
lg *zap.Logger
|
||||
tracer trace.Tracer
|
||||
}
|
||||
|
||||
// New creates new manager.
|
||||
func New(cfg Config) *Manager {
|
||||
cfg.setDefaults()
|
||||
return &Manager{
|
||||
cfg: cfg,
|
||||
lg: cfg.Logger,
|
||||
tracer: cfg.TracerProvider.Tracer(""),
|
||||
}
|
||||
}
|
||||
|
||||
// Handle handles updates.
|
||||
//
|
||||
// Important:
|
||||
// If Run method not called, all updates will be passed
|
||||
// to the provided handler as-is without any order verification
|
||||
// or short updates transformation.
|
||||
func (m *Manager) Handle(ctx context.Context, u tg.UpdatesClass) error {
|
||||
ctx, span := m.tracer.Start(ctx, "updates.Manager.Handle")
|
||||
defer span.End()
|
||||
|
||||
m.lg.Debug("Handle")
|
||||
defer m.lg.Debug("Handled")
|
||||
|
||||
m.mux.Lock()
|
||||
state := m.state
|
||||
m.mux.Unlock()
|
||||
|
||||
if state == nil {
|
||||
m.lg.Debug("Handle (no internalState)")
|
||||
return m.cfg.Handler.Handle(ctx, u)
|
||||
}
|
||||
|
||||
return state.Push(ctx, u)
|
||||
}
|
||||
|
||||
type AuthOptions struct {
|
||||
IsBot bool
|
||||
Forget bool
|
||||
OnStart func(ctx context.Context)
|
||||
}
|
||||
|
||||
// Run notifies manager about user authentication on the telegram server.
|
||||
//
|
||||
// If forget is true, local internalState (if exist) will be overwritten
|
||||
// with remote internalState.
|
||||
func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOptions) error {
|
||||
lg := m.lg.With(
|
||||
zap.Int64("user_id", userID),
|
||||
zap.Bool("is_bot", opt.IsBot),
|
||||
zap.Bool("forget", opt.Forget),
|
||||
)
|
||||
lg.Debug("Run")
|
||||
defer lg.Debug("Done")
|
||||
|
||||
wg, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
if err := func() error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
if m.state != nil {
|
||||
return errors.Errorf("already authorized (userID: %d)", m.state.selfID)
|
||||
}
|
||||
|
||||
state, err := m.loadState(ctx, api, userID, opt.Forget)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "load internalState")
|
||||
}
|
||||
channels := make(map[int64]struct {
|
||||
Pts int
|
||||
AccessHash int64
|
||||
})
|
||||
if err := m.cfg.Storage.ForEachChannels(ctx, userID, func(ctx context.Context, channelID int64, pts int) error {
|
||||
hash, found, err := m.cfg.AccessHasher.GetChannelAccessHash(ctx, userID, channelID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get channel access hash")
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
channels[channelID] = struct {
|
||||
Pts int
|
||||
AccessHash int64
|
||||
}{Pts: pts, AccessHash: hash}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "iterate channels")
|
||||
}
|
||||
|
||||
diffLim := diffLimitUser
|
||||
if opt.IsBot {
|
||||
diffLim = diffLimitBot
|
||||
}
|
||||
|
||||
m.state = newState(ctx, stateConfig{
|
||||
State: state,
|
||||
Channels: channels,
|
||||
RawClient: api,
|
||||
Tracer: m.tracer,
|
||||
Logger: m.cfg.Logger,
|
||||
Handler: m.cfg.Handler,
|
||||
OnChannelTooLong: m.cfg.OnChannelTooLong,
|
||||
Storage: m.cfg.Storage,
|
||||
Hasher: m.cfg.AccessHasher,
|
||||
SelfID: userID,
|
||||
DiffLimit: diffLim,
|
||||
WorkGroup: wg,
|
||||
})
|
||||
|
||||
return nil
|
||||
}(); err != nil {
|
||||
return errors.Wrap(err, "setup")
|
||||
}
|
||||
if opt.OnStart != nil {
|
||||
opt.OnStart(ctx)
|
||||
}
|
||||
wg.Go(func() error {
|
||||
return m.state.Run(ctx)
|
||||
})
|
||||
lg.Debug("Wait")
|
||||
return wg.Wait()
|
||||
}
|
||||
|
||||
func (m *Manager) loadState(ctx context.Context, api API, userID int64, forget bool) (State, error) {
|
||||
onNotFound:
|
||||
var state State
|
||||
if forget {
|
||||
remote, err := api.UpdatesGetState(ctx)
|
||||
if err != nil {
|
||||
return State{}, errors.Wrap(err, "get remote internalState")
|
||||
}
|
||||
|
||||
state = state.fromRemote(remote)
|
||||
if err := m.cfg.Storage.SetState(ctx, userID, state); err != nil {
|
||||
return State{}, err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
state, found, err := m.cfg.Storage.GetState(ctx, userID)
|
||||
if err != nil {
|
||||
return State{}, errors.Wrap(err, "restore local internalState")
|
||||
}
|
||||
|
||||
if !found {
|
||||
forget = true
|
||||
goto onNotFound
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// Reset notifies manager about user logout.
|
||||
func (m *Manager) Reset() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
m.state = nil
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type sequenceBox struct {
|
||||
state int
|
||||
gaps gapBuffer
|
||||
gapTimeout *time.Timer
|
||||
pending []update
|
||||
|
||||
apply func(ctx context.Context, state int, updates []update) error
|
||||
log *zap.Logger
|
||||
tracer trace.Tracer
|
||||
}
|
||||
|
||||
type sequenceConfig struct {
|
||||
InitialState int
|
||||
Apply func(ctx context.Context, state int, updates []update) error
|
||||
Logger *zap.Logger
|
||||
Tracer trace.Tracer
|
||||
}
|
||||
|
||||
func newSequenceBox(cfg sequenceConfig) *sequenceBox {
|
||||
if cfg.Apply == nil {
|
||||
panic("Apply func nil")
|
||||
}
|
||||
if cfg.Logger == nil {
|
||||
cfg.Logger = zap.NewNop()
|
||||
}
|
||||
if cfg.Tracer == nil {
|
||||
cfg.Tracer = trace.NewNoopTracerProvider().Tracer("")
|
||||
}
|
||||
|
||||
cfg.Logger.Debug("Initialized", zap.Int("internalState", cfg.InitialState))
|
||||
|
||||
t := time.NewTimer(fastgapTimeout)
|
||||
_ = t.Stop()
|
||||
return &sequenceBox{
|
||||
state: cfg.InitialState,
|
||||
gapTimeout: t,
|
||||
apply: cfg.Apply,
|
||||
log: cfg.Logger,
|
||||
tracer: cfg.Tracer,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sequenceBox) Handle(ctx context.Context, u update) error {
|
||||
ctx, span := s.tracer.Start(ctx, "sequenceBox.Handle")
|
||||
defer span.End()
|
||||
|
||||
log := s.log.With(zap.Int("upd_from", u.start()), zap.Int("upd_to", u.end()))
|
||||
if checkGap(s.state, u.State, u.Count) == gapIgnore {
|
||||
log.Debug("Outdated update, skipping", zap.Int("internalState", s.state))
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.gaps.Has() {
|
||||
s.pending = append(s.pending, u)
|
||||
if accepted := s.gaps.Consume(u); !accepted {
|
||||
log.Debug("Out of gap range, postponed", zap.Array("gaps", s.gaps))
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("Gap accepted", zap.Array("gaps", s.gaps))
|
||||
if !s.gaps.Has() {
|
||||
_ = s.gapTimeout.Stop()
|
||||
s.log.Debug("Gap was resolved by waiting")
|
||||
return s.applyPending(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
switch checkGap(s.state, u.State, u.Count) {
|
||||
case gapApply:
|
||||
if len(s.pending) > 0 {
|
||||
s.pending = append(s.pending, u)
|
||||
return s.applyPending(ctx)
|
||||
}
|
||||
|
||||
if err := s.apply(ctx, u.State, []update{u}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Accepted")
|
||||
s.setState(u.State, "update")
|
||||
return nil
|
||||
case gapRefetch:
|
||||
s.pending = append(s.pending, u)
|
||||
s.gaps.Enable(s.state, u.start())
|
||||
|
||||
// Check if we already have acceptable updates in buffer.
|
||||
for _, u := range s.pending {
|
||||
_ = s.gaps.Consume(u)
|
||||
}
|
||||
|
||||
if !s.gaps.Has() {
|
||||
log.Debug("Gap was resolved by pending updates")
|
||||
return s.applyPending(ctx)
|
||||
}
|
||||
|
||||
_ = s.gapTimeout.Reset(fastgapTimeout)
|
||||
s.log.Debug("Gap detected", zap.Array("gap", s.gaps))
|
||||
return nil
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sequenceBox) applyPending(ctx context.Context) error {
|
||||
ctx, span := s.tracer.Start(ctx, "sequenceBox.applyPending")
|
||||
defer span.End()
|
||||
|
||||
sort.SliceStable(s.pending, func(i, j int) bool {
|
||||
return s.pending[i].start() < s.pending[j].start()
|
||||
})
|
||||
|
||||
var (
|
||||
cursor = 0
|
||||
state = s.state
|
||||
accepted []update
|
||||
)
|
||||
|
||||
loop:
|
||||
for i, update := range s.pending {
|
||||
switch checkGap(state, update.State, update.Count) {
|
||||
case gapApply:
|
||||
accepted = append(accepted, update)
|
||||
state = update.State
|
||||
cursor = i + 1
|
||||
continue
|
||||
|
||||
case gapIgnore:
|
||||
cursor = i + 1
|
||||
continue
|
||||
|
||||
case gapRefetch:
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
// Trim processed updates. Setting zero values for the rest
|
||||
// of the slice lets GC collect referenced objects.
|
||||
end := len(s.pending)
|
||||
trim := end - cursor
|
||||
copy(s.pending, s.pending[cursor:])
|
||||
for i := trim; i < end; i++ {
|
||||
s.pending[i] = update{}
|
||||
}
|
||||
s.pending = s.pending[:trim]
|
||||
if len(accepted) == 0 {
|
||||
s.log.Warn("Empty buffer", zap.Any("pending", s.pending), zap.Int("internalState", s.state))
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.apply(ctx, state, accepted); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.log.Debug("Pending updates applied",
|
||||
zap.Int("prev_state", s.state),
|
||||
zap.Int("new_state", state),
|
||||
zap.Int("accepted_count", len(accepted)),
|
||||
)
|
||||
|
||||
s.setState(state, "pending updates")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sequenceBox) State() int { return s.state }
|
||||
|
||||
func (s *sequenceBox) SetState(state int, reason string) {
|
||||
s.setState(state, reason)
|
||||
}
|
||||
|
||||
func (s *sequenceBox) setState(state int, reason string) {
|
||||
old := s.state
|
||||
s.state = state
|
||||
s.log.Debug("State changed",
|
||||
zap.Int("old", old),
|
||||
zap.Int("new", state),
|
||||
zap.String("reason", reason),
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func TestSequenceBox(t *testing.T) {
|
||||
var (
|
||||
state int
|
||||
updates []update
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
box := newSequenceBox(sequenceConfig{
|
||||
InitialState: 3,
|
||||
Apply: func(ctx context.Context, s int, u []update) error {
|
||||
state = s
|
||||
updates = append(updates, u...)
|
||||
return nil
|
||||
},
|
||||
Logger: zaptest.NewLogger(t),
|
||||
})
|
||||
|
||||
require.Nil(t, box.Handle(ctx, update{
|
||||
Value: 1,
|
||||
State: 2,
|
||||
Count: 1,
|
||||
}))
|
||||
require.Zero(t, state)
|
||||
require.Empty(t, updates)
|
||||
require.Empty(t, box.pending)
|
||||
|
||||
require.Nil(t, box.Handle(ctx, update{
|
||||
Value: 1,
|
||||
State: 3,
|
||||
Count: 1,
|
||||
}))
|
||||
require.Zero(t, state)
|
||||
require.Empty(t, updates)
|
||||
require.Empty(t, box.pending)
|
||||
|
||||
require.Nil(t, box.Handle(ctx, update{
|
||||
Value: 1,
|
||||
State: 4,
|
||||
Count: 1,
|
||||
}))
|
||||
require.Equal(t, 4, state)
|
||||
require.Equal(t, []update{{Value: 1, State: 4, Count: 1}}, updates)
|
||||
require.Empty(t, box.pending)
|
||||
updates = nil
|
||||
|
||||
require.Nil(t, box.Handle(ctx, update{
|
||||
Value: 1,
|
||||
State: 6,
|
||||
Count: 1,
|
||||
}))
|
||||
require.Equal(t, 4, state)
|
||||
require.Empty(t, updates)
|
||||
require.Equal(t, []update{{Value: 1, State: 6, Count: 1}}, box.pending)
|
||||
|
||||
require.Nil(t, box.Handle(ctx, update{
|
||||
Value: 2,
|
||||
State: 5,
|
||||
Count: 1,
|
||||
}))
|
||||
require.Equal(t, 6, state)
|
||||
require.Equal(t, []update{{Value: 2, State: 5, Count: 1}, {Value: 1, State: 6, Count: 1}}, updates)
|
||||
require.Empty(t, box.pending)
|
||||
updates = nil
|
||||
|
||||
require.Nil(t, box.Handle(ctx, update{
|
||||
Value: 3,
|
||||
State: 8,
|
||||
Count: 1,
|
||||
}))
|
||||
require.Equal(t, 6, state)
|
||||
require.Empty(t, updates)
|
||||
require.Equal(t, []update{{Value: 3, State: 8, Count: 1}}, box.pending)
|
||||
<-box.gapTimeout.C
|
||||
|
||||
require.Equal(t, []gap{{from: 6, to: 7}}, box.gaps.gaps)
|
||||
box.gaps.Clear()
|
||||
require.False(t, box.gaps.Has())
|
||||
}
|
||||
|
||||
func TestSequenceBoxApplyPending(t *testing.T) {
|
||||
tests := []struct {
|
||||
InitialState int
|
||||
Pending []update
|
||||
PendingAfter []update
|
||||
Applied []update
|
||||
}{
|
||||
{
|
||||
InitialState: 5,
|
||||
Pending: []update{
|
||||
{Value: 1, State: 3, Count: 1},
|
||||
{Value: 1, State: 4, Count: 1},
|
||||
{Value: 1, State: 1, Count: 1},
|
||||
},
|
||||
PendingAfter: []update{},
|
||||
Applied: []update{},
|
||||
},
|
||||
{
|
||||
InitialState: 5,
|
||||
Pending: []update{
|
||||
{Value: 1, State: 3, Count: 1},
|
||||
{Value: 1, State: 8, Count: 1},
|
||||
{Value: 1, State: 7, Count: 1},
|
||||
{Value: 1, State: 4, Count: 1},
|
||||
{Value: 1, State: 1, Count: 1},
|
||||
},
|
||||
PendingAfter: []update{
|
||||
{1, 7, 1, entities{}},
|
||||
{1, 8, 1, entities{}},
|
||||
},
|
||||
Applied: []update{},
|
||||
},
|
||||
{
|
||||
InitialState: 5,
|
||||
Pending: []update{
|
||||
{Value: 1, State: 8, Count: 1},
|
||||
{Value: 1, State: 7, Count: 1},
|
||||
},
|
||||
PendingAfter: []update{
|
||||
{1, 7, 1, entities{}},
|
||||
{Value: 1, State: 8, Count: 1},
|
||||
},
|
||||
Applied: []update{},
|
||||
},
|
||||
{
|
||||
InitialState: 5,
|
||||
Pending: []update{
|
||||
{Value: 1, State: 3, Count: 1},
|
||||
{Value: 1, State: 6, Count: 1},
|
||||
{Value: 1, State: 8, Count: 1},
|
||||
{Value: 1, State: 4, Count: 1},
|
||||
{Value: 1, State: 1, Count: 1},
|
||||
},
|
||||
PendingAfter: []update{
|
||||
{Value: 1, State: 8, Count: 1},
|
||||
},
|
||||
Applied: []update{
|
||||
{Value: 1, State: 6, Count: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
applied := make([]update, 0)
|
||||
box := newSequenceBox(sequenceConfig{
|
||||
InitialState: test.InitialState,
|
||||
Apply: func(_ context.Context, s int, u []update) error {
|
||||
applied = append(applied, u...)
|
||||
return nil
|
||||
},
|
||||
Logger: zaptest.NewLogger(t),
|
||||
})
|
||||
|
||||
box.pending = test.Pending
|
||||
require.NoError(t, box.applyPending(context.TODO()))
|
||||
require.Equal(t, test.PendingAfter, box.pending)
|
||||
require.Equal(t, test.Applied, applied)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func sortUpdatesByPts(u []tg.UpdateClass) {
|
||||
sort.Stable(ptsSorter(u))
|
||||
}
|
||||
|
||||
type ptsSorter []tg.UpdateClass
|
||||
|
||||
func (p ptsSorter) Len() int {
|
||||
return len(p)
|
||||
}
|
||||
|
||||
func (p ptsSorter) Less(i, j int) bool {
|
||||
type (
|
||||
ptsType int
|
||||
compare struct {
|
||||
typ ptsType
|
||||
channelID int64
|
||||
ptsDiff int
|
||||
}
|
||||
)
|
||||
// Sorting order
|
||||
//
|
||||
// 0) Common updates without PTS.
|
||||
// 1) Common PTS updates
|
||||
// 2) Common QTS updates
|
||||
// 3) Channel PTS updates (by channelID and by pts)
|
||||
const (
|
||||
other ptsType = iota
|
||||
commonPts
|
||||
commonQts
|
||||
channelPts
|
||||
)
|
||||
getType := func(u tg.UpdateClass) compare {
|
||||
if pts, ptsCount, ok := tg.IsPtsUpdate(u); ok {
|
||||
return compare{typ: commonPts, ptsDiff: pts - ptsCount}
|
||||
}
|
||||
if channelID, pts, ptsCount, ok, _ := tg.IsChannelPtsUpdate(u); ok {
|
||||
return compare{typ: channelPts, channelID: channelID, ptsDiff: pts - ptsCount}
|
||||
}
|
||||
if qts, ok := tg.IsQtsUpdate(u); ok {
|
||||
return compare{typ: commonQts, ptsDiff: qts}
|
||||
}
|
||||
return compare{typ: other}
|
||||
}
|
||||
|
||||
a, b := getType(p[i]), getType(p[j])
|
||||
switch {
|
||||
case a.typ < b.typ:
|
||||
return true
|
||||
case a.typ > b.typ:
|
||||
return false
|
||||
}
|
||||
|
||||
// a.typ == b.typ
|
||||
switch a.typ {
|
||||
case other:
|
||||
// Keep original order
|
||||
case commonPts, commonQts:
|
||||
return a.ptsDiff < b.ptsDiff
|
||||
case channelPts:
|
||||
if a.channelID < b.channelID {
|
||||
return true
|
||||
}
|
||||
if a.channelID > b.channelID {
|
||||
return false
|
||||
}
|
||||
return a.ptsDiff < b.ptsDiff
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (p ptsSorter) Swap(i, j int) {
|
||||
p[i], p[j] = p[j], p[i]
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func Test_sortUpdatesByPts(t *testing.T) {
|
||||
channelNewMessage := func(pts int, id int64) *tg.UpdateNewChannelMessage {
|
||||
return &tg.UpdateNewChannelMessage{
|
||||
Message: &tg.Message{
|
||||
PeerID: &tg.PeerChannel{ChannelID: id},
|
||||
},
|
||||
Pts: pts - 1,
|
||||
PtsCount: 1,
|
||||
}
|
||||
}
|
||||
newMessage := func(pts int) *tg.UpdateNewMessage {
|
||||
return &tg.UpdateNewMessage{
|
||||
Message: &tg.Message{
|
||||
PeerID: &tg.PeerUser{UserID: 10},
|
||||
},
|
||||
Pts: pts - 1,
|
||||
PtsCount: 1,
|
||||
}
|
||||
}
|
||||
encryptedNewMessage := func(pts int) *tg.UpdateNewEncryptedMessage {
|
||||
return &tg.UpdateNewEncryptedMessage{
|
||||
Qts: pts,
|
||||
}
|
||||
}
|
||||
channelReadInbox := func(pts int, id int64) *tg.UpdateReadChannelInbox {
|
||||
return &tg.UpdateReadChannelInbox{ChannelID: id, MaxID: 25, Pts: pts}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
input []tg.UpdateClass
|
||||
result []tg.UpdateClass
|
||||
}{
|
||||
{
|
||||
[]tg.UpdateClass{
|
||||
channelReadInbox(26, 1),
|
||||
channelNewMessage(25, 1),
|
||||
},
|
||||
[]tg.UpdateClass{
|
||||
channelNewMessage(25, 1),
|
||||
channelReadInbox(26, 1),
|
||||
},
|
||||
},
|
||||
{
|
||||
[]tg.UpdateClass{
|
||||
channelReadInbox(26, 1),
|
||||
channelNewMessage(25, 2),
|
||||
newMessage(26),
|
||||
encryptedNewMessage(26),
|
||||
newMessage(25),
|
||||
encryptedNewMessage(25),
|
||||
encryptedNewMessage(27),
|
||||
channelNewMessage(25, 1),
|
||||
},
|
||||
[]tg.UpdateClass{
|
||||
newMessage(25),
|
||||
newMessage(26),
|
||||
encryptedNewMessage(25),
|
||||
encryptedNewMessage(26),
|
||||
encryptedNewMessage(27),
|
||||
channelNewMessage(25, 1),
|
||||
channelReadInbox(26, 1),
|
||||
channelNewMessage(25, 2),
|
||||
},
|
||||
},
|
||||
{
|
||||
[]tg.UpdateClass{
|
||||
channelReadInbox(26, 1),
|
||||
&tg.UpdateConfig{},
|
||||
channelNewMessage(25, 1),
|
||||
},
|
||||
[]tg.UpdateClass{
|
||||
&tg.UpdateConfig{},
|
||||
channelNewMessage(25, 1),
|
||||
channelReadInbox(26, 1),
|
||||
},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(fmt.Sprintf("Test%d", i+1), func(t *testing.T) {
|
||||
sortUpdatesByPts(tt.input)
|
||||
require.Equal(t, tt.result, tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,528 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
const (
|
||||
idleTimeout = time.Minute * 15
|
||||
fastgapTimeout = time.Millisecond * 500
|
||||
|
||||
diffLimitUser = 100
|
||||
diffLimitBot = 100000
|
||||
)
|
||||
|
||||
type tracedUpdate struct {
|
||||
update tg.UpdatesClass
|
||||
span trace.SpanContext
|
||||
}
|
||||
|
||||
type internalState struct {
|
||||
// Updates channel.
|
||||
externalQueue chan tracedUpdate
|
||||
|
||||
// Updates from channel states
|
||||
// during updates.getChannelDifference.
|
||||
internalQueue chan tracedUpdate
|
||||
|
||||
// Common internalState.
|
||||
pts, qts, seq *sequenceBox
|
||||
date int
|
||||
idleTimeout *time.Timer
|
||||
|
||||
// Channel states.
|
||||
channels map[int64]*channelState
|
||||
|
||||
// Immutable fields.
|
||||
client API
|
||||
log *zap.Logger
|
||||
handler telegram.UpdateHandler
|
||||
onTooLong func(channelID int64) error
|
||||
storage StateStorage
|
||||
hasher AccessHasher
|
||||
selfID int64
|
||||
diffLim int
|
||||
wg *errgroup.Group
|
||||
tracer trace.Tracer
|
||||
}
|
||||
|
||||
type stateConfig struct {
|
||||
State State
|
||||
Channels map[int64]struct {
|
||||
Pts int
|
||||
AccessHash int64
|
||||
}
|
||||
RawClient API
|
||||
Logger *zap.Logger
|
||||
Tracer trace.Tracer
|
||||
Handler telegram.UpdateHandler
|
||||
OnChannelTooLong func(channelID int64) error
|
||||
Storage StateStorage
|
||||
Hasher AccessHasher
|
||||
SelfID int64
|
||||
DiffLimit int
|
||||
WorkGroup *errgroup.Group
|
||||
}
|
||||
|
||||
func newState(ctx context.Context, cfg stateConfig) *internalState {
|
||||
s := &internalState{
|
||||
externalQueue: make(chan tracedUpdate, 10),
|
||||
internalQueue: make(chan tracedUpdate, 10),
|
||||
|
||||
date: cfg.State.Date,
|
||||
idleTimeout: time.NewTimer(idleTimeout),
|
||||
|
||||
channels: make(map[int64]*channelState),
|
||||
|
||||
client: cfg.RawClient,
|
||||
log: cfg.Logger,
|
||||
handler: cfg.Handler,
|
||||
onTooLong: cfg.OnChannelTooLong,
|
||||
storage: cfg.Storage,
|
||||
hasher: cfg.Hasher,
|
||||
selfID: cfg.SelfID,
|
||||
diffLim: cfg.DiffLimit,
|
||||
wg: cfg.WorkGroup,
|
||||
tracer: cfg.Tracer,
|
||||
}
|
||||
s.pts = newSequenceBox(sequenceConfig{
|
||||
InitialState: cfg.State.Pts,
|
||||
Apply: s.applyPts,
|
||||
Logger: s.log.Named("pts"),
|
||||
Tracer: s.tracer,
|
||||
})
|
||||
s.qts = newSequenceBox(sequenceConfig{
|
||||
InitialState: cfg.State.Qts,
|
||||
Apply: s.applyQts,
|
||||
Logger: s.log.Named("qts"),
|
||||
Tracer: s.tracer,
|
||||
})
|
||||
s.seq = newSequenceBox(sequenceConfig{
|
||||
InitialState: cfg.State.Seq,
|
||||
Apply: s.applySeq,
|
||||
Logger: s.log.Named("seq"),
|
||||
})
|
||||
|
||||
for id, info := range cfg.Channels {
|
||||
state := s.newChannelState(id, info.AccessHash, info.Pts)
|
||||
s.channels[id] = state
|
||||
s.wg.Go(func() error {
|
||||
return state.Run(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *internalState) Push(ctx context.Context, u tg.UpdatesClass) error {
|
||||
tu := tracedUpdate{
|
||||
update: u,
|
||||
span: trace.SpanContextFromContext(ctx),
|
||||
}
|
||||
select {
|
||||
case s.externalQueue <- tu:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// isFatalError returns true if error is fatal so we should stop updates handler.
|
||||
func isFatalError(err error) bool {
|
||||
// See https://github.com/gotd/td/issues/1458.
|
||||
if errors.Is(err, exchange.ErrKeyFingerprintNotFound) {
|
||||
return true
|
||||
}
|
||||
if tgerr.Is(err, "AUTH_KEY_UNREGISTERED", "SESSION_EXPIRED") {
|
||||
return true
|
||||
}
|
||||
if auth.IsUnauthorized(err) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *internalState) Run(ctx context.Context) error {
|
||||
if s == nil {
|
||||
return errors.New("invalid: nil internalState")
|
||||
}
|
||||
if s.log == nil {
|
||||
return errors.New("invalid: nil logger")
|
||||
}
|
||||
s.log.Debug("Starting updates handler")
|
||||
defer s.log.Debug("Updates handler stopped")
|
||||
s.getDifferenceLogger(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if len(s.pts.pending) > 0 || len(s.qts.pending) > 0 || len(s.seq.pending) > 0 {
|
||||
s.getDifferenceLogger(ctx)
|
||||
}
|
||||
return ctx.Err()
|
||||
case u := <-s.externalQueue:
|
||||
ctx := trace.ContextWithSpanContext(ctx, u.span)
|
||||
if err := s.handleUpdates(ctx, u.update); err != nil {
|
||||
return err
|
||||
}
|
||||
case u := <-s.internalQueue:
|
||||
ctx := trace.ContextWithSpanContext(ctx, u.span)
|
||||
if err := s.handleUpdates(ctx, u.update); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-s.pts.gapTimeout.C:
|
||||
s.log.Debug("Pts gap timeout")
|
||||
s.getDifferenceLogger(ctx)
|
||||
case <-s.qts.gapTimeout.C:
|
||||
s.log.Debug("Qts gap timeout")
|
||||
s.getDifferenceLogger(ctx)
|
||||
case <-s.seq.gapTimeout.C:
|
||||
s.log.Debug("Seq gap timeout")
|
||||
s.getDifferenceLogger(ctx)
|
||||
case <-s.idleTimeout.C:
|
||||
s.log.Debug("Idle timeout")
|
||||
s.getDifferenceLogger(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *internalState) handleUpdates(ctx context.Context, u tg.UpdatesClass) error {
|
||||
ctx, span := s.tracer.Start(ctx, "handleUpdates")
|
||||
defer span.End()
|
||||
|
||||
s.resetIdleTimer()
|
||||
|
||||
switch u := u.(type) {
|
||||
case *tg.Updates:
|
||||
s.saveUserHashes(ctx, u.Users)
|
||||
s.saveChannelHashes(ctx, u.Chats)
|
||||
return s.handleSeq(ctx, &tg.UpdatesCombined{
|
||||
Updates: u.Updates,
|
||||
Users: u.Users,
|
||||
Chats: u.Chats,
|
||||
Date: u.Date,
|
||||
Seq: u.Seq,
|
||||
SeqStart: u.Seq,
|
||||
})
|
||||
case *tg.UpdatesCombined:
|
||||
s.saveUserHashes(ctx, u.Users)
|
||||
s.saveChannelHashes(ctx, u.Chats)
|
||||
return s.handleSeq(ctx, u)
|
||||
case *tg.UpdateShort:
|
||||
return s.handleUpdates(ctx, &tg.UpdatesCombined{
|
||||
Updates: []tg.UpdateClass{u.Update},
|
||||
Date: u.Date,
|
||||
})
|
||||
case *tg.UpdateShortMessage:
|
||||
updateShort := s.convertShortMessage(u)
|
||||
if _, found, err := s.hasher.GetUserAccessHash(ctx, s.selfID, u.UserID); err != nil {
|
||||
return err
|
||||
} else if !found {
|
||||
chats, users, err := s.handleDifference(ctx, u.Date)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.handleUpdates(ctx, &tg.UpdatesCombined{
|
||||
Updates: []tg.UpdateClass{updateShort.Update},
|
||||
Users: users,
|
||||
Chats: chats,
|
||||
Date: u.Date,
|
||||
})
|
||||
} else {
|
||||
return s.handleUpdates(ctx, updateShort)
|
||||
}
|
||||
case *tg.UpdateShortChatMessage:
|
||||
updateShort := s.convertShortChatMessage(u)
|
||||
if _, found, err := s.hasher.GetUserAccessHash(ctx, s.selfID, u.FromID); err != nil {
|
||||
return err
|
||||
} else if !found {
|
||||
chats, users, err := s.handleDifference(ctx, u.Date)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.handleUpdates(ctx, &tg.UpdatesCombined{
|
||||
Updates: []tg.UpdateClass{updateShort.Update},
|
||||
Users: users,
|
||||
Chats: chats,
|
||||
Date: u.Date,
|
||||
})
|
||||
} else {
|
||||
return s.handleUpdates(ctx, updateShort)
|
||||
}
|
||||
case *tg.UpdateShortSentMessage:
|
||||
return s.handleUpdates(ctx, s.convertShortSentMessage(u))
|
||||
case *tg.UpdatesTooLong:
|
||||
return s.getDifference(ctx)
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected update type: %T", u))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *internalState) handleSeq(ctx context.Context, u *tg.UpdatesCombined) error {
|
||||
ctx, span := s.tracer.Start(ctx, "handleSeq")
|
||||
defer span.End()
|
||||
|
||||
if err := validateSeq(u.Seq, u.SeqStart); err != nil {
|
||||
s.log.Error("Seq validation failed", zap.Error(err), zap.Any("update", u))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Special case.
|
||||
if u.Seq == 0 {
|
||||
ptsChanged, err := s.applyCombined(ctx, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ptsChanged {
|
||||
return s.getDifference(ctx)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.seq.Handle(ctx, update{
|
||||
Value: u,
|
||||
State: u.Seq,
|
||||
Count: u.Seq - u.SeqStart + 1,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *internalState) handlePts(ctx context.Context, pts, ptsCount int, u tg.UpdateClass, ents entities) error {
|
||||
if err := validatePts(pts, ptsCount); err != nil {
|
||||
s.log.Error("Pts validation failed", zap.Error(err), zap.Any("update", u))
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.pts.Handle(ctx, update{
|
||||
Value: u,
|
||||
State: pts,
|
||||
Count: ptsCount,
|
||||
Entities: ents,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *internalState) handleQts(ctx context.Context, qts int, u tg.UpdateClass, ents entities) error {
|
||||
if err := validateQts(qts); err != nil {
|
||||
s.log.Error("Qts validation failed", zap.Error(err), zap.Any("update", u))
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.qts.Handle(ctx, update{
|
||||
Value: u,
|
||||
State: qts,
|
||||
Count: 1,
|
||||
Entities: ents,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *internalState) handleChannel(ctx context.Context, channelID int64, date, pts, ptsCount int, cu channelUpdate) error {
|
||||
if err := validatePts(pts, ptsCount); err != nil {
|
||||
s.log.Error("Pts validation failed", zap.Error(err), zap.Any("update", cu.update))
|
||||
return nil
|
||||
}
|
||||
|
||||
state, ok := s.channels[channelID]
|
||||
if !ok {
|
||||
accessHash, found, err := s.hasher.GetChannelAccessHash(context.Background(), s.selfID, channelID)
|
||||
if err != nil {
|
||||
s.log.Error("GetAccessHash error", zap.Error(err))
|
||||
}
|
||||
|
||||
if !found {
|
||||
if date == 0 {
|
||||
// Received update has no date field.
|
||||
date = s.date - 30
|
||||
} else {
|
||||
date-- // 1 sec back
|
||||
}
|
||||
|
||||
// Try to get access hash from updates.getDifference.
|
||||
accessHash, found = s.restoreChannelAccessHash(ctx, channelID, date)
|
||||
if !found {
|
||||
s.log.Debug("Failed to recover missing access hash, update ignored",
|
||||
zap.Int64("channel_id", channelID),
|
||||
// zap.Any("update", cu.update),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
localPts, found, err := s.storage.GetChannelPts(ctx, s.selfID, channelID)
|
||||
if err != nil {
|
||||
localPts = pts - ptsCount
|
||||
s.log.Error("GetChannelPts error", zap.Error(err))
|
||||
}
|
||||
|
||||
if !found {
|
||||
localPts = pts - ptsCount
|
||||
if err := s.storage.SetChannelPts(ctx, s.selfID, channelID, localPts); err != nil {
|
||||
s.log.Error("SetChannelPts error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
state = s.newChannelState(channelID, accessHash, localPts)
|
||||
s.channels[channelID] = state
|
||||
s.wg.Go(func() error {
|
||||
return state.Run(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
return state.Push(ctx, cu)
|
||||
}
|
||||
|
||||
func (s *internalState) newChannelState(channelID, accessHash int64, initialPts int) *channelState {
|
||||
return newChannelState(channelStateConfig{
|
||||
Out: s.internalQueue,
|
||||
InitialPts: initialPts,
|
||||
ChannelID: channelID,
|
||||
AccessHash: accessHash,
|
||||
SelfID: s.selfID,
|
||||
Storage: s.storage,
|
||||
DiffLimit: s.diffLim,
|
||||
RawClient: s.client,
|
||||
Handler: s.handler,
|
||||
OnChannelTooLong: s.onTooLong,
|
||||
Logger: s.log.Named("channel").With(zap.Int64("channel_id", channelID)),
|
||||
Tracer: s.tracer,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *internalState) getDifference(ctx context.Context) error {
|
||||
ctx, span := s.tracer.Start(ctx, "getDifference")
|
||||
defer span.End()
|
||||
|
||||
s.resetIdleTimer()
|
||||
s.pts.gaps.Clear()
|
||||
s.qts.gaps.Clear()
|
||||
s.seq.gaps.Clear()
|
||||
|
||||
s.log.Debug("Getting difference")
|
||||
|
||||
setState := func(state tg.UpdatesState, reason string) {
|
||||
if err := s.storage.SetState(ctx, s.selfID, State{}.fromRemote(&state)); err != nil {
|
||||
s.log.Warn("SetState error", zap.Error(err))
|
||||
}
|
||||
|
||||
s.pts.SetState(state.Pts, reason)
|
||||
s.qts.SetState(state.Qts, reason)
|
||||
s.seq.SetState(state.Seq, reason)
|
||||
s.date = state.Date
|
||||
}
|
||||
|
||||
diff, err := s.client.UpdatesGetDifference(ctx, &tg.UpdatesGetDifferenceRequest{
|
||||
Pts: s.pts.State(),
|
||||
Qts: s.qts.State(),
|
||||
Date: s.date,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get difference")
|
||||
}
|
||||
|
||||
s.log.Debug("Difference received", zap.String("diff", fmt.Sprintf("%T", diff)))
|
||||
|
||||
switch diff := diff.(type) {
|
||||
case *tg.UpdatesDifference:
|
||||
if len(diff.OtherUpdates) > 0 {
|
||||
if err := s.handleUpdates(ctx, &tg.UpdatesCombined{
|
||||
Updates: diff.OtherUpdates,
|
||||
Users: diff.Users,
|
||||
Chats: diff.Chats,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "handle diff.OtherUpdates")
|
||||
}
|
||||
}
|
||||
|
||||
if len(diff.NewMessages) > 0 || len(diff.NewEncryptedMessages) > 0 {
|
||||
if err := s.handler.Handle(ctx, &tg.Updates{
|
||||
Updates: append(
|
||||
msgsToUpdates(diff.NewMessages, false),
|
||||
encryptedMsgsToUpdates(diff.NewEncryptedMessages)...,
|
||||
),
|
||||
Users: diff.Users,
|
||||
Chats: diff.Chats,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
setState(diff.State, "updates.Difference")
|
||||
return nil
|
||||
|
||||
// No events.
|
||||
case *tg.UpdatesDifferenceEmpty:
|
||||
if err := s.storage.SetDateSeq(ctx, s.selfID, diff.Date, diff.Seq); err != nil {
|
||||
s.log.Warn("SetDateSeq error", zap.Error(err))
|
||||
}
|
||||
|
||||
s.date = diff.Date
|
||||
s.seq.SetState(diff.Seq, "updates.differenceEmpty")
|
||||
return nil
|
||||
|
||||
// Incomplete list of occurred events.
|
||||
case *tg.UpdatesDifferenceSlice:
|
||||
if len(diff.OtherUpdates) > 0 {
|
||||
if err := s.handleUpdates(ctx, &tg.UpdatesCombined{
|
||||
Updates: diff.OtherUpdates,
|
||||
Users: diff.Users,
|
||||
Chats: diff.Chats,
|
||||
Date: diff.IntermediateState.Date,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(diff.NewMessages) > 0 || len(diff.NewEncryptedMessages) > 0 {
|
||||
if err := s.handler.Handle(ctx, &tg.Updates{
|
||||
Updates: append(
|
||||
msgsToUpdates(diff.NewMessages, false),
|
||||
encryptedMsgsToUpdates(diff.NewEncryptedMessages)...,
|
||||
),
|
||||
Users: diff.Users,
|
||||
Chats: diff.Chats,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
setState(diff.IntermediateState, "updates.differenceSlice")
|
||||
return s.getDifference(ctx)
|
||||
|
||||
// The difference is too long, and the specified internalState must be used to refetch updates.
|
||||
case *tg.UpdatesDifferenceTooLong:
|
||||
if err := s.storage.SetPts(ctx, s.selfID, diff.Pts); err != nil {
|
||||
s.log.Error("SetPts error", zap.Error(err))
|
||||
}
|
||||
s.pts.SetState(diff.Pts, "updates.differenceTooLong")
|
||||
return s.getDifference(ctx)
|
||||
|
||||
default:
|
||||
return errors.Errorf("unexpected diff type: %T", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *internalState) getDifferenceLogger(ctx context.Context) {
|
||||
if err := s.getDifference(ctx); err != nil {
|
||||
s.log.Error("get difference error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *internalState) resetIdleTimer() {
|
||||
if len(s.idleTimeout.C) > 0 {
|
||||
<-s.idleTimeout.C
|
||||
}
|
||||
_ = s.idleTimeout.Reset(idleTimeout)
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func (s *internalState) applySeq(ctx context.Context, state int, updates []update) error {
|
||||
recoverState := false
|
||||
for _, u := range updates {
|
||||
ptsChanged, err := s.applyCombined(ctx, u.Value.(*tg.UpdatesCombined))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ptsChanged {
|
||||
recoverState = true
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.storage.SetSeq(ctx, s.selfID, state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if recoverState {
|
||||
return s.getDifference(ctx)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *internalState) applyCombined(ctx context.Context, comb *tg.UpdatesCombined) (ptsChanged bool, err error) {
|
||||
ctx, span := s.tracer.Start(ctx, "internalState.applyCombined")
|
||||
defer span.End()
|
||||
|
||||
var (
|
||||
ents = entities{
|
||||
Users: comb.Users,
|
||||
Chats: comb.Chats,
|
||||
}
|
||||
)
|
||||
sortUpdatesByPts(comb.Updates)
|
||||
|
||||
for _, u := range comb.Updates {
|
||||
switch u := u.(type) {
|
||||
case *tg.UpdatePtsChanged:
|
||||
ptsChanged = true
|
||||
continue
|
||||
case *tg.UpdateChannelTooLong:
|
||||
st, ok := s.channels[u.ChannelID]
|
||||
if !ok {
|
||||
s.log.Debug("ChannelTooLong for channel that is not in the internalState, update ignored", zap.Int64("channel_id", u.ChannelID))
|
||||
continue
|
||||
}
|
||||
if err := st.Push(ctx, channelUpdate{
|
||||
update: u,
|
||||
entities: ents,
|
||||
span: trace.SpanContextFromContext(ctx),
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if pts, ptsCount, ok := tg.IsPtsUpdate(u); ok {
|
||||
if err := s.handlePts(ctx, pts, ptsCount, u, ents); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
if channelID, pts, ptsCount, ok, err := tg.IsChannelPtsUpdate(u); ok {
|
||||
if err != nil {
|
||||
s.log.Debug("Invalid channel update", zap.Error(err)) //, zap.Any("update", u))
|
||||
continue
|
||||
}
|
||||
if err := s.handleChannel(ctx, channelID, comb.Date, pts, ptsCount, channelUpdate{
|
||||
update: u,
|
||||
entities: ents,
|
||||
span: trace.SpanContextFromContext(ctx),
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
if qts, ok := tg.IsQtsUpdate(u); ok {
|
||||
if err := s.handleQts(ctx, qts, u, ents); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.handler.Handle(ctx, &tg.Updates{
|
||||
Updates: comb.Updates,
|
||||
Users: ents.Users,
|
||||
Chats: ents.Chats,
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
setDate, setSeq := comb.Date > s.date, comb.Seq > 0
|
||||
switch {
|
||||
case setDate && setSeq:
|
||||
if err := s.storage.SetDateSeq(ctx, s.selfID, comb.Date, comb.Seq); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
s.date = comb.Date
|
||||
s.seq.SetState(comb.Seq, "seq update")
|
||||
case setDate:
|
||||
if err := s.storage.SetDate(ctx, s.selfID, comb.Date); err != nil {
|
||||
return false, err
|
||||
}
|
||||
s.date = comb.Date
|
||||
case setSeq:
|
||||
if err := s.storage.SetSeq(ctx, s.selfID, comb.Seq); err != nil {
|
||||
return false, err
|
||||
}
|
||||
s.seq.SetState(comb.Seq, "seq update")
|
||||
}
|
||||
|
||||
return ptsChanged, nil
|
||||
}
|
||||
|
||||
func (s *internalState) applyPts(ctx context.Context, state int, updates []update) error {
|
||||
ctx, span := s.tracer.Start(ctx, "internalState.applyPts")
|
||||
defer span.End()
|
||||
|
||||
var (
|
||||
converted []tg.UpdateClass
|
||||
ents entities
|
||||
)
|
||||
|
||||
for _, update := range updates {
|
||||
converted = append(converted, update.Value.(tg.UpdateClass))
|
||||
ents.Merge(update.Entities)
|
||||
}
|
||||
|
||||
if err := s.handler.Handle(ctx, &tg.Updates{
|
||||
Updates: converted,
|
||||
Users: ents.Users,
|
||||
Chats: ents.Chats,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.storage.SetPts(ctx, s.selfID, state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *internalState) applyQts(ctx context.Context, state int, updates []update) error {
|
||||
ctx, span := s.tracer.Start(ctx, "internalState.applyQts")
|
||||
defer span.End()
|
||||
|
||||
var (
|
||||
converted []tg.UpdateClass
|
||||
ents entities
|
||||
)
|
||||
|
||||
for _, update := range updates {
|
||||
converted = append(converted, update.Value.(tg.UpdateClass))
|
||||
ents.Merge(update.Entities)
|
||||
}
|
||||
|
||||
if err := s.handler.Handle(ctx, &tg.Updates{
|
||||
Updates: converted,
|
||||
Users: ents.Users,
|
||||
Chats: ents.Chats,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Don't set qts if it's 0, because it means that we are apllying gaps updates
|
||||
if state == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.storage.SetQts(ctx, s.selfID, state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,344 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type channelUpdate struct {
|
||||
update tg.UpdateClass
|
||||
entities entities
|
||||
span trace.SpanContext
|
||||
}
|
||||
|
||||
type channelState struct {
|
||||
// Updates from *internalState.
|
||||
updates chan channelUpdate
|
||||
// Channel to pass diff.OtherUpdates into *internalState.
|
||||
out chan<- tracedUpdate
|
||||
|
||||
// Channel internalState.
|
||||
pts *sequenceBox
|
||||
idleTimeout *time.Timer
|
||||
diffTimeout time.Time
|
||||
|
||||
// Immutable fields.
|
||||
channelID int64
|
||||
accessHash int64
|
||||
selfID int64
|
||||
diffLim int
|
||||
client API
|
||||
storage StateStorage
|
||||
log *zap.Logger
|
||||
tracer trace.Tracer
|
||||
handler telegram.UpdateHandler
|
||||
onTooLong func(channelID int64) error
|
||||
}
|
||||
|
||||
type channelStateConfig struct {
|
||||
Out chan tracedUpdate
|
||||
InitialPts int
|
||||
ChannelID int64
|
||||
AccessHash int64
|
||||
SelfID int64
|
||||
DiffLimit int
|
||||
RawClient API
|
||||
Storage StateStorage
|
||||
Handler telegram.UpdateHandler
|
||||
OnChannelTooLong func(channelID int64) error
|
||||
Logger *zap.Logger
|
||||
Tracer trace.Tracer
|
||||
}
|
||||
|
||||
func newChannelState(cfg channelStateConfig) *channelState {
|
||||
state := &channelState{
|
||||
updates: make(chan channelUpdate, 10),
|
||||
out: cfg.Out,
|
||||
|
||||
idleTimeout: time.NewTimer(idleTimeout),
|
||||
|
||||
channelID: cfg.ChannelID,
|
||||
accessHash: cfg.AccessHash,
|
||||
selfID: cfg.SelfID,
|
||||
diffLim: cfg.DiffLimit,
|
||||
client: cfg.RawClient,
|
||||
storage: cfg.Storage,
|
||||
log: cfg.Logger,
|
||||
handler: cfg.Handler,
|
||||
onTooLong: cfg.OnChannelTooLong,
|
||||
tracer: cfg.Tracer,
|
||||
}
|
||||
|
||||
state.pts = newSequenceBox(sequenceConfig{
|
||||
InitialState: cfg.InitialPts,
|
||||
Apply: state.applyPts,
|
||||
Logger: cfg.Logger.Named("pts"),
|
||||
Tracer: cfg.Tracer,
|
||||
})
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (s *channelState) Push(ctx context.Context, u channelUpdate) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case s.updates <- u:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *channelState) Run(ctx context.Context) error {
|
||||
// Subscribe to channel updates.
|
||||
if err := s.getDifference(ctx); err != nil {
|
||||
s.log.Error("Failed to subscribe to channel updates", zap.Error(err))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case u := <-s.updates:
|
||||
ctx := trace.ContextWithSpanContext(ctx, u.span)
|
||||
if err := s.handleUpdate(ctx, u.update, u.entities); err != nil {
|
||||
s.log.Error("Handle update error", zap.Error(err))
|
||||
}
|
||||
case <-s.pts.gapTimeout.C:
|
||||
s.log.Debug("Gap timeout")
|
||||
s.getDifferenceLogger(ctx)
|
||||
case <-ctx.Done():
|
||||
if len(s.pts.pending) > 0 {
|
||||
// This will probably fail.
|
||||
s.getDifferenceLogger(ctx)
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-s.idleTimeout.C:
|
||||
s.log.Debug("Idle timeout")
|
||||
s.resetIdleTimer()
|
||||
s.getDifferenceLogger(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *channelState) handleUpdate(ctx context.Context, u tg.UpdateClass, ents entities) error {
|
||||
ctx, span := s.tracer.Start(ctx, "channelState.handleUpdate")
|
||||
defer span.End()
|
||||
|
||||
s.resetIdleTimer()
|
||||
|
||||
if long, ok := u.(*tg.UpdateChannelTooLong); ok {
|
||||
return s.handleTooLong(ctx, long)
|
||||
}
|
||||
|
||||
channelID, pts, ptsCount, ok, err := tg.IsChannelPtsUpdate(u)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "invalid update")
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return errors.Errorf("expected channel update, got: %T", u)
|
||||
}
|
||||
|
||||
if channelID != s.channelID {
|
||||
return errors.Errorf("update for wrong channel (channelID: %d)", channelID)
|
||||
}
|
||||
|
||||
return s.pts.Handle(ctx, update{
|
||||
Value: u,
|
||||
State: pts,
|
||||
Count: ptsCount,
|
||||
Entities: ents,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *channelState) handleTooLong(ctx context.Context, long *tg.UpdateChannelTooLong) error {
|
||||
ctx, span := s.tracer.Start(ctx, "channelState.handleTooLong")
|
||||
defer span.End()
|
||||
|
||||
remotePts, ok := long.GetPts()
|
||||
if !ok {
|
||||
s.log.Warn("Got UpdateChannelTooLong without pts field")
|
||||
return s.getDifference(ctx)
|
||||
}
|
||||
|
||||
// Note: we still can fetch latest diffLim updates.
|
||||
// Should we do?
|
||||
if remotePts-s.pts.State() > s.diffLim {
|
||||
return s.onTooLong(s.channelID)
|
||||
}
|
||||
|
||||
return s.getDifference(ctx)
|
||||
}
|
||||
|
||||
func (s *channelState) applyPts(ctx context.Context, state int, updates []update) error {
|
||||
ctx, span := s.tracer.Start(ctx, "channelState.applyPts")
|
||||
defer span.End()
|
||||
|
||||
var (
|
||||
converted []tg.UpdateClass
|
||||
ents entities
|
||||
)
|
||||
|
||||
for _, update := range updates {
|
||||
converted = append(converted, update.Value.(tg.UpdateClass))
|
||||
ents.Merge(update.Entities)
|
||||
}
|
||||
|
||||
if err := s.handler.Handle(ctx, &tg.Updates{
|
||||
Updates: converted,
|
||||
Users: ents.Users,
|
||||
Chats: ents.Chats,
|
||||
}); err != nil {
|
||||
s.log.Error("Handle update error", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, state); err != nil {
|
||||
s.log.Error("SetChannelPts error", zap.Error(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *channelState) getDifference(ctx context.Context) error {
|
||||
ctx, span := s.tracer.Start(ctx, "channelState.getDifference")
|
||||
defer span.End()
|
||||
s.pts.gaps.Clear()
|
||||
|
||||
s.log.Debug("Getting difference")
|
||||
|
||||
if now := time.Now(); now.Before(s.diffTimeout) {
|
||||
dur := s.diffTimeout.Sub(now)
|
||||
s.log.Debug("GetChannelDifference timeout", zap.Duration("duration", dur))
|
||||
if err := func() error {
|
||||
afterC := time.After(dur)
|
||||
for {
|
||||
select {
|
||||
case <-afterC:
|
||||
return nil
|
||||
case _, ok := <-s.updates:
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Ignoring updates to prevent *internalState worker from blocking.
|
||||
// All ignored updates should be restored by future getChannelDifference call.
|
||||
// At least I hope so...
|
||||
s.log.Debug("Ignoring update due to getChannelDifference timeout") // , zap.Any("update", u.update))
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
diff, err := s.client.UpdatesGetChannelDifference(ctx, &tg.UpdatesGetChannelDifferenceRequest{
|
||||
Channel: &tg.InputChannel{
|
||||
ChannelID: s.channelID,
|
||||
AccessHash: s.accessHash,
|
||||
},
|
||||
Filter: &tg.ChannelMessagesFilterEmpty{},
|
||||
Pts: s.pts.State(),
|
||||
Limit: s.diffLim,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get channel difference")
|
||||
}
|
||||
|
||||
switch diff := diff.(type) {
|
||||
case *tg.UpdatesChannelDifference:
|
||||
if len(diff.OtherUpdates) > 0 {
|
||||
select {
|
||||
case s.out <- tracedUpdate{
|
||||
span: trace.SpanContextFromContext(ctx),
|
||||
update: &tg.Updates{
|
||||
Updates: diff.OtherUpdates,
|
||||
Users: diff.Users,
|
||||
Chats: diff.Chats,
|
||||
},
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
if len(diff.NewMessages) > 0 {
|
||||
if err := s.handler.Handle(ctx, &tg.Updates{
|
||||
Updates: msgsToUpdates(diff.NewMessages, true),
|
||||
Users: diff.Users,
|
||||
Chats: diff.Chats,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil {
|
||||
s.log.Warn("SetChannelPts error", zap.Error(err))
|
||||
}
|
||||
|
||||
s.pts.SetState(diff.Pts, "updates.channelDifference")
|
||||
if seconds, ok := diff.GetTimeout(); ok {
|
||||
s.diffTimeout = time.Now().Add(time.Second * time.Duration(seconds))
|
||||
}
|
||||
|
||||
if !diff.Final {
|
||||
return s.getDifference(ctx)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
case *tg.UpdatesChannelDifferenceEmpty:
|
||||
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil {
|
||||
s.log.Warn("SetChannelPts error", zap.Error(err))
|
||||
}
|
||||
|
||||
s.pts.SetState(diff.Pts, "updates.channelDifferenceEmpty")
|
||||
if seconds, ok := diff.GetTimeout(); ok {
|
||||
s.diffTimeout = time.Now().Add(time.Second * time.Duration(seconds))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
case *tg.UpdatesChannelDifferenceTooLong:
|
||||
if seconds, ok := diff.GetTimeout(); ok {
|
||||
s.diffTimeout = time.Now().Add(time.Second * time.Duration(seconds))
|
||||
}
|
||||
|
||||
remotePts, err := getDialogPts(diff.Dialog)
|
||||
if err != nil {
|
||||
s.log.Warn("UpdatesChannelDifferenceTooLong invalid Dialog", zap.Error(err))
|
||||
} else {
|
||||
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, remotePts); err != nil {
|
||||
s.log.Warn("SetChannelPts error", zap.Error(err))
|
||||
}
|
||||
|
||||
s.pts.SetState(remotePts, "updates.channelDifferenceTooLong dialog new pts")
|
||||
}
|
||||
|
||||
return s.onTooLong(s.channelID)
|
||||
|
||||
default:
|
||||
return errors.Errorf("unexpected channel diff type: %T", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *channelState) getDifferenceLogger(ctx context.Context) {
|
||||
if err := s.getDifference(ctx); err != nil {
|
||||
s.log.Error("get channel difference error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *channelState) resetIdleTimer() {
|
||||
if len(s.idleTimeout.C) > 0 {
|
||||
<-s.idleTimeout.C
|
||||
}
|
||||
|
||||
_ = s.idleTimeout.Reset(idleTimeout)
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// State is the user internalState.
|
||||
type State struct {
|
||||
Pts, Qts, Date, Seq int
|
||||
}
|
||||
|
||||
func (s State) fromRemote(remote *tg.UpdatesState) State {
|
||||
return State{
|
||||
Pts: remote.Pts,
|
||||
Qts: remote.Qts,
|
||||
Date: remote.Date,
|
||||
Seq: remote.Seq,
|
||||
}
|
||||
}
|
||||
|
||||
// StateStorage is the users internalState storage.
|
||||
//
|
||||
// Note:
|
||||
// SetPts, SetQts, SetDate, SetSeq, SetDateSeq
|
||||
// should return error if user internalState does not exist.
|
||||
type StateStorage interface {
|
||||
GetState(ctx context.Context, userID int64) (state State, found bool, err error)
|
||||
SetState(ctx context.Context, userID int64, state State) error
|
||||
SetPts(ctx context.Context, userID int64, pts int) error
|
||||
SetQts(ctx context.Context, userID int64, qts int) error
|
||||
SetDate(ctx context.Context, userID int64, date int) error
|
||||
SetSeq(ctx context.Context, userID int64, seq int) error
|
||||
SetDateSeq(ctx context.Context, userID int64, date, seq int) error
|
||||
GetChannelPts(ctx context.Context, userID, channelID int64) (pts int, found bool, err error)
|
||||
SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error
|
||||
ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error
|
||||
}
|
||||
|
||||
// AccessHasher stores user and channel access hashes for a user.
|
||||
type AccessHasher interface {
|
||||
SetChannelAccessHash(ctx context.Context, forUserID, channelID, accessHash int64) error
|
||||
GetChannelAccessHash(ctx context.Context, forUserID, channelID int64) (accessHash int64, found bool, err error)
|
||||
SetUserAccessHash(ctx context.Context, forUserID, userID, accessHash int64) error
|
||||
GetUserAccessHash(ctx context.Context, forUserID, userID int64) (accessHash int64, found bool, err error)
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
)
|
||||
|
||||
var _ StateStorage = (*memStorage)(nil)
|
||||
|
||||
type memStorage struct {
|
||||
states map[int64]State
|
||||
channels map[int64]map[int64]int
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newMemStorage() *memStorage {
|
||||
return &memStorage{
|
||||
states: map[int64]State{},
|
||||
channels: map[int64]map[int64]int{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *memStorage) GetState(ctx context.Context, userID int64) (state State, found bool, err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, found = s.states[userID]
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) SetState(ctx context.Context, userID int64, state State) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.states[userID] = state
|
||||
s.channels[userID] = map[int64]int{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetPts(ctx context.Context, userID int64, pts int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("internalState not found")
|
||||
}
|
||||
|
||||
state.Pts = pts
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetQts(ctx context.Context, userID int64, qts int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("internalState not found")
|
||||
}
|
||||
|
||||
state.Qts = qts
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetDate(ctx context.Context, userID int64, date int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("internalState not found")
|
||||
}
|
||||
|
||||
state.Date = date
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetSeq(ctx context.Context, userID int64, seq int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("internalState not found")
|
||||
}
|
||||
|
||||
state.Seq = seq
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetDateSeq(ctx context.Context, userID int64, date, seq int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("internalState not found")
|
||||
}
|
||||
|
||||
state.Date = date
|
||||
state.Seq = seq
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
channels, ok := s.channels[userID]
|
||||
if !ok {
|
||||
return errors.New("user internalState does not exist")
|
||||
}
|
||||
|
||||
channels[channelID] = pts
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) GetChannelPts(ctx context.Context, userID, channelID int64) (pts int, found bool, err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
channels, ok := s.channels[userID]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
pts, found = channels[channelID]
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
cmap, ok := s.channels[userID]
|
||||
if !ok {
|
||||
return errors.New("channels map does not exist")
|
||||
}
|
||||
|
||||
for id, pts := range cmap {
|
||||
if err := f(ctx, id, pts); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ AccessHasher = (*memAccessHasher)(nil)
|
||||
|
||||
type memAccessHasher struct {
|
||||
channelHashes map[int64]map[int64]int64
|
||||
userHashes map[int64]map[int64]int64
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newMemAccessHasher() *memAccessHasher {
|
||||
return &memAccessHasher{
|
||||
channelHashes: map[int64]map[int64]int64{},
|
||||
userHashes: map[int64]map[int64]int64{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) GetChannelAccessHash(ctx context.Context, forUserID, channelID int64) (accessHash int64, found bool, err error) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.channelHashes[forUserID]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
accessHash, found = accessHashes[channelID]
|
||||
return
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) SetChannelAccessHash(ctx context.Context, forUserID, channelID, accessHash int64) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.channelHashes[forUserID]
|
||||
if !ok {
|
||||
accessHashes = map[int64]int64{}
|
||||
m.channelHashes[forUserID] = accessHashes
|
||||
}
|
||||
|
||||
accessHashes[channelID] = accessHash
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) GetUserAccessHash(ctx context.Context, forUserID, userID int64) (accessHash int64, found bool, err error) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.userHashes[forUserID]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
accessHash, found = accessHashes[userID]
|
||||
return
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) SetUserAccessHash(ctx context.Context, forUserID, userID, accessHash int64) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.userHashes[forUserID]
|
||||
if !ok {
|
||||
accessHashes = map[int64]int64{}
|
||||
m.channelHashes[forUserID] = accessHashes
|
||||
}
|
||||
|
||||
accessHashes[userID] = accessHash
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type update struct {
|
||||
Value any
|
||||
State int
|
||||
Count int
|
||||
Entities entities
|
||||
}
|
||||
|
||||
func (u update) start() int { return u.State - u.Count }
|
||||
|
||||
func (u update) end() int { return u.State }
|
||||
|
||||
// Entities contains update entities.
|
||||
type entities struct {
|
||||
Users []tg.UserClass
|
||||
Chats []tg.ChatClass
|
||||
}
|
||||
|
||||
// Merge merges entities.
|
||||
func (e *entities) Merge(from entities) {
|
||||
for _, candidate := range from.Users {
|
||||
merge := true
|
||||
for _, exist := range e.Users {
|
||||
if exist.GetID() == candidate.GetID() {
|
||||
merge = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if merge {
|
||||
e.Users = append(e.Users, candidate)
|
||||
}
|
||||
}
|
||||
|
||||
for _, candidate := range from.Chats {
|
||||
merge := true
|
||||
for _, exist := range e.Chats {
|
||||
if exist.GetID() == candidate.GetID() {
|
||||
merge = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if merge {
|
||||
e.Chats = append(e.Chats, candidate)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package updates
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func validatePts(pts, ptsCount int) error {
|
||||
if pts < 0 {
|
||||
return errors.Errorf("invalid pts value: %d", pts)
|
||||
}
|
||||
|
||||
if ptsCount < 0 {
|
||||
return errors.Errorf("invalid ptsCount value: %d", ptsCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateQts(qts int) error {
|
||||
if qts < 0 {
|
||||
return errors.Errorf("invalid qts value: %d", qts)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSeq(seq, seqStart int) error {
|
||||
if seq < 0 {
|
||||
return errors.Errorf("invalid seq value: %d", seq)
|
||||
}
|
||||
|
||||
if seqStart < 0 {
|
||||
return errors.Errorf("invalid seqStart value: %d", seq)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDialogPts(dialog tg.DialogClass) (int, error) {
|
||||
d, ok := dialog.(*tg.Dialog)
|
||||
if !ok {
|
||||
return 0, errors.Errorf("unexpected dialog type: %T", dialog)
|
||||
}
|
||||
|
||||
pts, ok := d.GetPts()
|
||||
if !ok {
|
||||
return 0, errors.New("dialog has no pts field")
|
||||
}
|
||||
|
||||
return pts, nil
|
||||
}
|
||||
|
||||
func msgsToUpdates(msgs []tg.MessageClass, channel bool) []tg.UpdateClass {
|
||||
updates := make([]tg.UpdateClass, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
if channel {
|
||||
updates = append(updates, &tg.UpdateNewChannelMessage{
|
||||
Message: msg,
|
||||
Pts: -1,
|
||||
PtsCount: -1,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
updates = append(updates, &tg.UpdateNewMessage{
|
||||
Message: msg,
|
||||
Pts: -1,
|
||||
PtsCount: -1,
|
||||
})
|
||||
}
|
||||
|
||||
return updates
|
||||
}
|
||||
|
||||
func encryptedMsgsToUpdates(msgs []tg.EncryptedMessageClass) []tg.UpdateClass {
|
||||
updates := make([]tg.UpdateClass, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
updates = append(updates, &tg.UpdateNewEncryptedMessage{
|
||||
Message: msg,
|
||||
Qts: -1,
|
||||
})
|
||||
}
|
||||
|
||||
return updates
|
||||
}
|
||||
Reference in New Issue
Block a user