move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
@@ -0,0 +1,134 @@
package updates
import (
"fmt"
"go.uber.org/zap"
"golang.org/x/net/context"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func (s *internalState) saveChannelHashes(ctx context.Context, chats []tg.ChatClass) {
ctx, span := s.tracer.Start(ctx, "updates.saveChannelHashes")
defer span.End()
for _, c := range chats {
switch c := c.(type) {
case *tg.Channel:
if c.Min {
continue
}
if hash, ok := c.GetAccessHash(); ok {
if _, ok = s.channels[c.ID]; ok {
continue
}
s.log.Debug("New channel access hash",
zap.Int64("channel_id", c.ID),
zap.String("title", c.Title),
)
if err := s.hasher.SetChannelAccessHash(ctx, s.selfID, c.ID, hash); err != nil {
s.log.Error("SetChannelAccessHash error", zap.Error(err))
}
}
case *tg.ChannelForbidden:
if _, ok := s.channels[c.ID]; ok {
continue
}
s.log.Debug("New channel access hash",
zap.Int64("channel_id", c.ID),
zap.String("title", c.Title),
)
if err := s.hasher.SetChannelAccessHash(ctx, s.selfID, c.ID, c.AccessHash); err != nil {
s.log.Error("SetChannelAccessHash error", zap.Error(err))
}
}
}
}
func (s *internalState) saveUserHashes(ctx context.Context, chats []tg.UserClass) {
ctx, span := s.tracer.Start(ctx, "updates.saveChannelHashes")
defer span.End()
for _, u := range chats {
if user, ok := u.(*tg.User); !ok {
continue
} else if hash, ok := user.GetAccessHash(); !ok {
continue
} else if user.Min {
s.log.Debug("User is min, not saving access hash")
continue
} else {
s.log.Debug("New user access hash", zap.Int64("user_id", user.ID))
if err := s.hasher.SetUserAccessHash(ctx, s.selfID, user.ID, hash); err != nil {
s.log.Error("SetUserAccessHash error", zap.Error(err))
}
}
}
}
func (s *internalState) handleDifference(ctx context.Context, date int) (chats []tg.ChatClass, users []tg.UserClass, err error) {
ctx, span := s.tracer.Start(ctx, "updates.handleDifference")
defer span.End()
diff, err := s.client.UpdatesGetDifference(ctx, &tg.UpdatesGetDifferenceRequest{
Pts: s.pts.State(),
Qts: s.qts.State(),
Date: date,
})
if err != nil {
s.log.Error("UpdatesGetDifference error", zap.Error(err))
return nil, nil, fmt.Errorf("get difference: %w", err)
}
switch diff := diff.(type) {
case *tg.UpdatesDifference:
chats = diff.Chats
users = diff.Users
case *tg.UpdatesDifferenceSlice:
chats = diff.Chats
users = diff.Users
}
s.saveChannelHashes(ctx, chats)
s.saveUserHashes(ctx, users)
return chats, users, nil
}
func (s *internalState) restoreChannelAccessHash(ctx context.Context, channelID int64, date int) (accessHash int64, ok bool) {
ctx, span := s.tracer.Start(ctx, "updates.restoreAccessHash")
defer span.End()
chats, _, err := s.handleDifference(ctx, date)
if err != nil {
s.log.Error("getDifference error", zap.Error(err))
return 0, false
}
for _, c := range chats {
switch c := c.(type) {
case *tg.Channel:
if c.Min {
continue
}
if c.ID != channelID {
continue
}
if hash, ok := c.GetAccessHash(); ok {
return hash, true
}
case *tg.ChannelForbidden:
if c.ID != channelID {
continue
}
return c.AccessHash, true
}
}
return 0, false
}
+62
View File
@@ -0,0 +1,62 @@
package updates
import (
"context"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// API is the interface which contains
// Telegram RPC methods used by manager for internalState synchronization.
type API interface {
UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error)
UpdatesGetDifference(ctx context.Context, request *tg.UpdatesGetDifferenceRequest) (tg.UpdatesDifferenceClass, error)
UpdatesGetChannelDifference(ctx context.Context, request *tg.UpdatesGetChannelDifferenceRequest) (tg.UpdatesChannelDifferenceClass, error)
}
// Config of the manager.
type Config struct {
// Handler where updates will be passed.
Handler telegram.UpdateHandler
// Callback called if manager cannot
// recover channel gap (optional).
OnChannelTooLong func(channelID int64) error
// State storage.
// In-mem used if not provided.
Storage StateStorage
// Channel access hash storage.
// In-mem used if not provided.
AccessHasher AccessHasher
// Logger (optional).
Logger *zap.Logger
// TracerProvider (optional).
TracerProvider trace.TracerProvider
}
func (cfg *Config) setDefaults() {
if cfg.Handler == nil {
panic("Handler is nil")
}
if cfg.AccessHasher == nil {
cfg.AccessHasher = newMemAccessHasher()
}
if cfg.Logger == nil {
cfg.Logger = zap.NewNop()
}
if cfg.TracerProvider == nil {
cfg.TracerProvider = trace.NewNoopTracerProvider()
}
if cfg.Storage == nil {
cfg.Storage = newMemStorage()
}
if cfg.OnChannelTooLong == nil {
cfg.OnChannelTooLong = func(channelID int64) error {
cfg.Logger.Error("Difference too long", zap.Int64("channel_id", channelID))
return nil
}
}
}
+129
View File
@@ -0,0 +1,129 @@
package updates
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
func convertOptional(msg *tg.Message, i tg.UpdatesClass) {
if u, ok := i.(interface {
GetFwdFrom() (tg.MessageFwdHeader, bool)
}); ok {
if v, ok := u.GetFwdFrom(); ok {
msg.SetFwdFrom(v)
}
}
if u, ok := i.(interface{ GetViaBotID() (int64, bool) }); ok {
if v, ok := u.GetViaBotID(); ok {
msg.SetViaBotID(v)
}
}
if u, ok := i.(interface {
GetReplyTo() (tg.MessageReplyHeaderClass, bool)
}); ok {
if v, ok := u.GetReplyTo(); ok {
msg.SetReplyTo(v)
}
}
if u, ok := i.(interface {
GetReplyTo() (tg.MessageReplyHeader, bool)
}); ok {
if v, ok := u.GetReplyTo(); ok {
msg.SetReplyTo(&v)
}
}
if u, ok := i.(interface {
GetEntities() ([]tg.MessageEntityClass, bool)
}); ok {
if v, ok := u.GetEntities(); ok {
msg.SetEntities(v)
}
}
if u, ok := i.(interface {
GetMedia() (tg.MessageMediaClass, bool)
}); ok {
if v, ok := u.GetMedia(); ok {
msg.SetMedia(v)
}
}
if u, ok := i.(interface {
GetTTLPeriod() (int, bool)
}); ok {
if v, ok := u.GetTTLPeriod(); ok {
msg.SetTTLPeriod(v)
}
}
}
func (s *internalState) convertShortMessage(u *tg.UpdateShortMessage) *tg.UpdateShort {
msg := &tg.Message{
ID: u.ID,
PeerID: &tg.PeerUser{UserID: u.UserID},
Message: u.Message,
Date: u.Date,
}
// Optional fields should set by SetXXX(), so GetXXX and Flags.Has()
// can return the right values even we hav't call .Encode()
msg.SetOut(u.Out)
msg.SetMentioned(u.Mentioned)
msg.SetMediaUnread(u.MediaUnread)
msg.SetSilent(u.Silent)
msg.SetFromID(&tg.PeerUser{UserID: s.selfID})
if !u.Out {
msg.SetFromID(&tg.PeerUser{UserID: u.UserID})
}
convertOptional(msg, u)
return &tg.UpdateShort{
Update: &tg.UpdateNewMessage{
Message: msg,
Pts: u.Pts,
PtsCount: u.PtsCount,
},
Date: u.Date,
}
}
func (s *internalState) convertShortChatMessage(u *tg.UpdateShortChatMessage) *tg.UpdateShort {
msg := &tg.Message{
ID: u.ID,
PeerID: &tg.PeerChat{ChatID: u.ChatID},
Message: u.Message,
Date: u.Date,
}
msg.SetOut(u.Out)
msg.SetMentioned(u.Mentioned)
msg.SetMediaUnread(u.MediaUnread)
msg.SetSilent(u.Silent)
msg.SetFromScheduled(msg.FromScheduled)
msg.SetFromID(&tg.PeerUser{UserID: u.FromID})
convertOptional(msg, u)
return &tg.UpdateShort{
Update: &tg.UpdateNewMessage{
Message: msg,
Pts: u.Pts,
PtsCount: u.PtsCount,
},
Date: u.Date,
}
}
func (s *internalState) convertShortSentMessage(u *tg.UpdateShortSentMessage) *tg.UpdateShort {
// This update should be converted by the one who called the method
// that returned this update, because we do not have any context about
// it (message text, sender/recipient, etc.)
//
// In theory, this update should come only as a response to an RPC call,
// and we get it here because of the update hook.
// We use it to make sure there are no pts gaps.
return &tg.UpdateShort{
Update: &tg.UpdateNewMessage{
Message: &tg.MessageEmpty{
ID: u.ID,
},
Pts: u.Pts,
PtsCount: u.PtsCount,
},
Date: u.Date,
}
}
+20
View File
@@ -0,0 +1,20 @@
// Package updates provides a Telegram's internalState synchronization manager.
//
// It guarantees that all internalState-sensitive updates will be performed
// in correct order.
//
// Limitations:
//
// 1. Manager cannot verify stateless types of updates
// (tg.UpdatesClass without Seq, or tg.UpdateClass without Pts or Qts).
//
// 2. Due to the fact that updates.getDifference and updates.getChannelDifference
// do not return event sequences, manager cannot guarantee the correctness
// of these operations. We rely on the server here.
//
// 3. Manager cannot recover the channel gap if there is a ChannelDifferenceTooLong error.
// Restoring the internalState in such situation is not the prerogative of this manager.
// See: https://core.telegram.org/constructor/updates.channelDifferenceTooLong
//
// TODO: Write implementation details.
package updates
+54
View File
@@ -0,0 +1,54 @@
package updates
import "go.uber.org/zap/zapcore"
type gap struct {
from, to int
}
type gapBuffer struct {
gaps []gap
}
func (b gapBuffer) Has() bool { return len(b.gaps) > 0 }
func (b *gapBuffer) Clear() { b.gaps = make([]gap, 0, 1) }
func (b *gapBuffer) Enable(from, to int) {
if len(b.gaps) > 0 {
panic("unreachable")
}
b.gaps = append(b.gaps, gap{from, to})
}
func (b *gapBuffer) Consume(u update) (accepted bool) {
for i, g := range b.gaps {
if g.from <= u.start() && g.to >= u.end() {
if g.from < u.start() {
b.gaps = append(b.gaps, gap{from: g.from, to: u.start()})
}
if g.to > u.end() {
b.gaps = append(b.gaps, gap{from: u.end(), to: g.to})
}
b.gaps = append(b.gaps[:i], b.gaps[i+1:]...)
return true
}
}
return false
}
func (b gapBuffer) MarshalLogArray(e zapcore.ArrayEncoder) error {
for _, g := range b.gaps {
if err := e.AppendObject(zapcore.ObjectMarshalerFunc(func(e zapcore.ObjectEncoder) error {
e.AddInt("from", g.from)
e.AddInt("to", g.to)
return nil
})); err != nil {
return err
}
}
return nil
}
@@ -0,0 +1,24 @@
package updates
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGapBuffer(t *testing.T) {
buf := new(gapBuffer)
buf.Enable(1, 7)
require.True(t, buf.Consume(update{State: 4, Count: 3}))
require.Equal(t, []gap{{4, 7}}, buf.gaps)
require.False(t, buf.Consume(update{State: 1, Count: 1}))
require.Equal(t, []gap{{4, 7}}, buf.gaps)
require.False(t, buf.Consume(update{State: 8, Count: 1}))
require.Equal(t, []gap{{4, 7}}, buf.gaps)
require.True(t, buf.Consume(update{State: 7, Count: 3}))
require.Empty(t, buf.gaps)
}
+27
View File
@@ -0,0 +1,27 @@
package updates
type gapCheckResult byte
const (
_ gapCheckResult = iota
gapApply
gapIgnore
gapRefetch
)
func checkGap(localState, remoteState, count int) gapCheckResult {
// Temporary fix for handling qts updates gaps.
if remoteState == 0 {
return gapApply
}
if localState+count == remoteState {
return gapApply
}
if localState+count > remoteState {
return gapIgnore
}
return gapRefetch
}
+34
View File
@@ -0,0 +1,34 @@
// Package hook contains telegram update hook middleware.
package hook
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// UpdateHook middleware is called on each tg.UpdatesClass method result.
//
// Function is called before invoker return. Returned error will be wrapped
// and returned as InvokeRaw result.
type UpdateHook func(ctx context.Context, u tg.UpdatesClass) error
// Handle implements telegram.Middleware.
func (h UpdateHook) Handle(next tg.Invoker) telegram.InvokeFunc {
return func(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
if err := next.Invoke(ctx, input, output); err != nil {
return err
}
if u, ok := output.(*tg.UpdatesBox); ok {
if err := h(ctx, u.Updates); err != nil {
return errors.Wrap(err, "hook")
}
}
return nil
}
}
@@ -0,0 +1,90 @@
package hook
import (
"context"
"testing"
"github.com/go-faster/errors"
"github.com/stretchr/testify/assert"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func TestUpdateHook_InvokeRaw(t *testing.T) {
t.Run("Success", func(t *testing.T) {
var invokerCalled, hookCalled bool
assert.NoError(t, UpdateHook(func(ctx context.Context, u tg.UpdatesClass) error {
assert.NotNil(t, u)
hookCalled = true
return nil
}).Handle(telegram.InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
invokerCalled = true
return nil
})).Invoke(context.TODO(), nil, &tg.UpdatesBox{
Updates: &tg.UpdateShortMessage{
ID: 100,
},
}))
assert.True(t, invokerCalled, "invoker should be called")
assert.True(t, hookCalled, "hook should be called")
})
t.Run("Error", func(t *testing.T) {
t.Run("Handler", func(t *testing.T) {
var invokerCalled, hookCalled bool
err := errors.New("failure")
assert.ErrorIs(t, UpdateHook(func(ctx context.Context, u tg.UpdatesClass) error {
assert.NotNil(t, u)
hookCalled = true
return nil
}).Handle(telegram.InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
invokerCalled = true
return err
})).Invoke(context.TODO(), nil, &tg.UpdatesBox{
Updates: &tg.UpdateShortMessage{
ID: 100,
},
}), err)
assert.True(t, invokerCalled, "invoker should be called")
assert.False(t, hookCalled, "hook should not be called")
})
t.Run("Hook", func(t *testing.T) {
var invokerCalled, hookCalled bool
err := errors.New("failure")
assert.ErrorIs(t, UpdateHook(func(ctx context.Context, u tg.UpdatesClass) error {
assert.NotNil(t, u)
hookCalled = true
return err
}).Handle(telegram.InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
invokerCalled = true
return nil
})).Invoke(context.TODO(), nil, &tg.UpdatesBox{
Updates: &tg.UpdateShortMessage{
ID: 100,
},
}), err)
assert.True(t, invokerCalled, "invoker should be called")
assert.True(t, hookCalled, "hook should be called")
})
})
t.Run("Not update", func(t *testing.T) {
var invokerCalled, hookCalled bool
assert.NoError(t, UpdateHook(func(ctx context.Context, u tg.UpdatesClass) error {
assert.NotNil(t, u)
hookCalled = true
return nil
}).Handle(telegram.InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
invokerCalled = true
return nil
})).Invoke(context.TODO(), nil, &tg.User{}))
assert.True(t, invokerCalled, "invoker should be called")
assert.False(t, hookCalled, "hook should not be called")
})
}
@@ -0,0 +1,81 @@
package e2e
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
// Entities contains update entities.
type Entities struct {
Users map[int64]*tg.User
Chats map[int64]*tg.Chat
Channels map[int64]*tg.Channel
ChannelsForbidden map[int64]*tg.ChannelForbidden
}
// NewEntities creates new Entities.
func NewEntities() *Entities {
return &Entities{
Users: map[int64]*tg.User{},
Chats: map[int64]*tg.Chat{},
Channels: map[int64]*tg.Channel{},
ChannelsForbidden: map[int64]*tg.ChannelForbidden{},
}
}
// Merge merges entities.
func (e *Entities) Merge(from *Entities) {
if from == nil {
return
}
for userID, user := range from.Users {
e.Users[userID] = user
}
for chanID, chat := range from.Chats {
e.Chats[chanID] = chat
}
for channelID, channel := range from.Channels {
e.Channels[channelID] = channel
}
for channelID, channel := range from.ChannelsForbidden {
e.ChannelsForbidden[channelID] = channel
}
}
// FromUpdates method.
func (e *Entities) FromUpdates(u interface {
tg.UpdatesClass
MapUsers() tg.UserClassArray
MapChats() tg.ChatClassArray
}) *Entities {
u.MapChats().FillChatMap(e.Chats)
u.MapChats().FillChannelMap(e.Channels)
u.MapChats().FillChannelForbiddenMap(e.ChannelsForbidden)
u.MapUsers().FillUserMap(e.Users)
return e
}
// AsUsers returns users as tg.UserClass slice.
func (e *Entities) AsUsers() []tg.UserClass {
var users []tg.UserClass
for _, u := range e.Users {
users = append(users, u)
}
return users
}
// AsChats returns chats as tg.ChatClass slice.
func (e *Entities) AsChats() []tg.ChatClass {
var chats []tg.ChatClass
for _, c := range e.Chats {
chats = append(chats, c)
}
for _, c := range e.Channels {
chats = append(chats, c)
}
for _, c := range e.ChannelsForbidden {
chats = append(chats, c)
}
return chats
}
@@ -0,0 +1,60 @@
package e2e
import (
"context"
"sync"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Handler handles updates.
type handler struct {
messages *messageDatabase
ents *Entities
mux sync.Mutex
}
func newHandler() *handler {
return &handler{
messages: &messageDatabase{
channels: make(map[int64][]tg.MessageClass),
},
ents: NewEntities(),
}
}
func (h *handler) Handle(ctx context.Context, u tg.UpdatesClass) error {
switch u := u.(type) {
case *tg.Updates:
return h.handleUpdates(NewEntities().FromUpdates(u), u.Updates)
case *tg.UpdatesCombined:
return h.handleUpdates(NewEntities().FromUpdates(u), u.Updates)
default:
panic(u)
}
}
// HandleUpdates handler.
func (h *handler) handleUpdates(ents *Entities, upds []tg.UpdateClass) error {
h.mux.Lock()
defer h.mux.Unlock()
h.ents.Merge(ents)
for _, u := range upds {
switch u := u.(type) {
case *tg.UpdateNewMessage:
h.messages.common = append(h.messages.common, u.Message)
case *tg.UpdateNewEncryptedMessage:
h.messages.secret = append(h.messages.secret, u.Message)
case *tg.UpdateNewChannelMessage:
channelID := u.Message.(*tg.Message).PeerID.(*tg.PeerChannel).ChannelID
msgs := h.messages.channels[channelID]
msgs = append(msgs, u.Message)
h.messages.channels[channelID] = msgs
default:
panic("unexpected update type")
}
}
return nil
}
@@ -0,0 +1,211 @@
//go:build linux
package e2e
import (
"context"
"fmt"
"math/rand"
"sync"
"testing"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"golang.org/x/sync/errgroup"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/updates"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func TestE2E(t *testing.T) {
testManager(t, func(s *server, storage updates.StateStorage) chan *tg.Updates {
t.Helper()
c := make(chan *tg.Updates, 10)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var (
biba = s.peers.createUser("biba")
boba = s.peers.createUser("boba")
chat = s.peers.createChat("chat")
)
var channels []*tg.PeerChannel
require.NoError(t, storage.ForEachChannels(ctx, 123, func(ctx context.Context, channelID int64, pts int) error {
channels = append(channels, &tg.PeerChannel{
ChannelID: channelID,
})
return nil
}))
var wg sync.WaitGroup
wg.Add(2)
// Biba.
go func() {
defer wg.Done()
for i := 0; i < 3; i++ {
c <- s.CreateEvent(func(ev *EventBuilder) {
ev.SendMessage(biba, chat, fmt.Sprintf("biba-%d", i))
for mi, c := range channels {
ev.SendMessage(biba, c, fmt.Sprintf("biba-channel-%d-%d", i, mi))
}
})
}
}()
// Boba.
go func() {
defer wg.Done()
for i := 0; i < 3; i++ {
c <- s.CreateEvent(func(ev *EventBuilder) {
ev.SendMessage(boba, chat, fmt.Sprintf("boba-%d", i))
for _, c := range channels {
ev.SendMessage(boba, c, fmt.Sprintf("boba-channel-%d", i))
}
})
}
}()
go func() {
wg.Wait()
close(c)
}()
return c
})
}
func testManager(t *testing.T, f func(s *server, storage updates.StateStorage) chan *tg.Updates) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
var (
log = zaptest.NewLogger(t)
s = newServer()
h = newHandler()
storage = newMemStorage()
hasher = newMemAccessHasher()
)
const uid = 123
require.NoError(t, storage.SetState(ctx, uid, updates.State{
Pts: 0,
Qts: 0,
Date: 0,
Seq: 0,
}))
for i := 0; i < 2; i++ {
c := s.peers.createChannel(fmt.Sprintf("channel-%d", i))
require.NoError(t, storage.SetChannelPts(ctx, uid, c.ChannelID, 0))
require.NoError(t, hasher.SetChannelAccessHash(ctx, uid, c.ChannelID, c.ChannelID*2))
}
e := updates.New(updates.Config{
Handler: h,
Logger: log.Named("gaps"),
Storage: storage,
AccessHasher: hasher,
})
uchan := loss(f(s, storage))
g, ctx := errgroup.WithContext(ctx)
ready := make(chan struct{})
opts := updates.AuthOptions{
OnStart: func(ctx context.Context) {
t.Log("OnStart")
close(ready)
},
}
g.Go(func() error {
t.Log("Starting manager")
defer t.Log("Manager stopped")
return e.Run(ctx, s, uid, opts)
})
g.Go(func() error {
t.Log("Starting updates generator")
defer t.Log("Updates generator stopped")
defer cancel()
select {
case <-ready:
t.Log("Ready")
case <-ctx.Done():
return ctx.Err()
}
var g errgroup.Group
for i := 0; i < 2; i++ {
g.Go(func() error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case u, ok := <-uchan:
if !ok {
return nil
}
if err := e.Handle(ctx, u); err != nil {
return err
}
}
}
})
}
t.Log("Waiting")
if err := g.Wait(); err != nil {
return err
}
t.Log("Sending pts changed")
ups := []tg.UpdateClass{&tg.UpdatePtsChanged{}}
if err := storage.ForEachChannels(ctx, uid, func(ctx context.Context, channelID int64, pts int) error {
ups = append(ups, &tg.UpdateChannelTooLong{ChannelID: channelID})
return nil
}); err != nil {
return err
}
t.Log("Handle")
return e.Handle(ctx, &tg.Updates{
Updates: ups,
})
})
t.Log("Waiting for shutdown")
require.ErrorIs(t, g.Wait(), context.Canceled)
t.Log("Checking")
require.Equal(t, s.messages, h.messages)
require.Equal(t, s.peers.channels, h.ents.Channels)
require.Equal(t, s.peers.chats, h.ents.Chats)
require.Equal(t, s.peers.users, h.ents.Users)
}
func loss(in chan *tg.Updates) chan *tg.Updates {
out := make(chan *tg.Updates)
go func() {
defer close(out)
for u := range in {
if rand.Intn(5) == 1 {
continue
}
out <- u
}
}()
return out
}
@@ -0,0 +1,48 @@
package e2e
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
type messageDatabase struct {
common []tg.MessageClass
secret []tg.EncryptedMessageClass
channels map[int64][]tg.MessageClass
}
type peerDatabase struct {
users map[int64]*tg.User
chats map[int64]*tg.Chat
channels map[int64]*tg.Channel
id int64
}
func (p *peerDatabase) createUser(username string) *tg.PeerUser {
p.users[p.id] = &tg.User{
ID: p.id,
Username: username,
}
defer func() { p.id++ }()
return &tg.PeerUser{UserID: p.id}
}
func (p *peerDatabase) createChat(title string) *tg.PeerChat {
p.chats[p.id] = &tg.Chat{
ID: p.id,
Title: title,
}
defer func() { p.id++ }()
return &tg.PeerChat{ChatID: p.id}
}
func (p *peerDatabase) createChannel(username string) *tg.PeerChannel {
p.channels[p.id] = &tg.Channel{
ID: p.id,
Username: username,
}
p.channels[p.id].SetAccessHash(p.id * 2)
defer func() { p.id++ }()
return &tg.PeerChannel{ChannelID: p.id}
}
@@ -0,0 +1,190 @@
// Package e2e contains end-to-end updates processing test.
package e2e
import (
"context"
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Server for testing gaps.
type server struct {
date int
peers *peerDatabase
messages *messageDatabase
mux sync.Mutex
}
// NewServer creates new test server.
func newServer() *server {
return &server{
date: 1,
peers: &peerDatabase{
users: make(map[int64]*tg.User),
chats: make(map[int64]*tg.Chat),
channels: make(map[int64]*tg.Channel),
},
messages: &messageDatabase{
channels: make(map[int64][]tg.MessageClass),
},
}
}
// UpdatesGetState returns current remote state.
func (s *server) UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error) {
s.mux.Lock()
defer s.mux.Unlock()
return &tg.UpdatesState{
Pts: len(s.messages.common),
Qts: len(s.messages.secret),
Date: s.date,
Seq: 0,
}, nil
}
// UpdatesGetDifference returns difference between local and remote states.
func (s *server) UpdatesGetDifference(ctx context.Context, request *tg.UpdatesGetDifferenceRequest) (tg.UpdatesDifferenceClass, error) {
s.mux.Lock()
defer s.mux.Unlock()
ents := NewEntities()
var common []tg.MessageClass
for i := request.Pts + 1; i <= len(s.messages.common); i++ {
common = append(common, s.messages.common[i-1])
s.fillMessageEnts(s.messages.common[i-1], ents)
}
var secret []tg.EncryptedMessageClass
for i := request.Qts + 1; i <= len(s.messages.secret); i++ {
secret = append(secret, s.messages.secret[i-1])
}
var others []tg.UpdateClass
for _, msgs := range s.messages.channels {
for i, msg := range msgs {
if msg.(*tg.Message).Date > request.Date {
others = append(others, &tg.UpdateNewChannelMessage{
Message: msg,
Pts: i + 1,
PtsCount: 1,
})
s.fillMessageEnts(msg, ents)
}
}
}
if len(common) == 0 && len(secret) == 0 && len(others) == 0 {
return &tg.UpdatesDifferenceEmpty{
Date: s.date,
Seq: 0,
}, nil
}
return &tg.UpdatesDifference{
NewMessages: common,
NewEncryptedMessages: secret,
OtherUpdates: others,
Users: ents.AsUsers(),
Chats: ents.AsChats(),
State: tg.UpdatesState{
Pts: len(s.messages.common),
Qts: len(s.messages.secret),
Date: s.date,
Seq: 0,
},
}, nil
}
// UpdatesGetChannelDifference returns difference between local and remote channel states.
func (s *server) UpdatesGetChannelDifference(
ctx context.Context, request *tg.UpdatesGetChannelDifferenceRequest,
) (tg.UpdatesChannelDifferenceClass, error) {
s.mux.Lock()
defer s.mux.Unlock()
channel, ok := request.Channel.(*tg.InputChannel)
if !ok {
return nil, errors.Errorf("bad InputChannelClass type: %T", request.Channel)
}
if peer, ok := s.peers.channels[channel.ChannelID]; true {
if !ok {
return nil, errors.Errorf("channel %d not found", channel.ChannelID)
}
if peer.AccessHash != channel.AccessHash {
return nil, errors.New("invalid access hash")
}
}
var (
channelMsgs = s.messages.channels[channel.ChannelID]
ents = NewEntities()
prepared []tg.MessageClass
)
for i := request.Pts + 1; i <= len(channelMsgs); i++ {
prepared = append(prepared, channelMsgs[i-1])
s.fillMessageEnts(channelMsgs[i-1], ents)
}
if len(prepared) == 0 {
return &tg.UpdatesChannelDifferenceEmpty{
Pts: len(channelMsgs),
Final: true,
}, nil
}
return &tg.UpdatesChannelDifference{
NewMessages: prepared,
Users: ents.AsUsers(),
Chats: ents.AsChats(),
Pts: len(channelMsgs),
Final: true,
}, nil
}
func (s *server) fillMessageEnts(msg tg.MessageClass, ents *Entities) {
switch peer := msg.(*tg.Message).PeerID.(type) {
case *tg.PeerUser:
user, ok := s.peers.users[peer.UserID]
if !ok {
panic("bad user")
}
ents.Users[user.ID] = user
case *tg.PeerChat:
chat, ok := s.peers.chats[peer.ChatID]
if !ok {
panic("bad chat")
}
ents.Chats[chat.ID] = chat
case *tg.PeerChannel:
channel, ok := s.peers.channels[peer.ChannelID]
if !ok {
panic("bad channel")
}
ents.Channels[channel.ID] = channel
default:
panic("unexpected peer type")
}
peerUser, ok := msg.(*tg.Message).FromID.(*tg.PeerUser)
if !ok {
panic("bad fromID")
}
user, ok := s.peers.users[peerUser.UserID]
if !ok {
panic("bad user")
}
ents.Users[user.ID] = user
}
@@ -0,0 +1,96 @@
package e2e
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// EventBuilder struct.
type EventBuilder struct {
updates []tg.UpdateClass
ents *Entities
s *server
date int
}
// SendMessage send a new message.
func (e *EventBuilder) SendMessage(from *tg.PeerUser, peer tg.PeerClass, text string) {
msg := &tg.Message{
Message: text,
PeerID: peer,
FromID: from,
Date: e.date,
}
fromUser, ok := e.s.peers.users[from.UserID]
if !ok {
panic("bad fromID")
}
e.ents.Users[from.UserID] = fromUser
switch peer := peer.(type) {
case *tg.PeerUser:
user, ok := e.s.peers.users[peer.UserID]
if !ok {
panic("peer not found")
}
e.ents.Users[user.ID] = user
e.s.messages.common = append(e.s.messages.common, msg)
e.updates = append(e.updates, &tg.UpdateNewMessage{
Message: msg,
Pts: len(e.s.messages.common),
PtsCount: 1,
})
case *tg.PeerChat:
chat, ok := e.s.peers.chats[peer.ChatID]
if !ok {
panic("peer not found")
}
e.ents.Chats[chat.ID] = chat
e.s.messages.common = append(e.s.messages.common, msg)
e.updates = append(e.updates, &tg.UpdateNewMessage{
Message: msg,
Pts: len(e.s.messages.common),
PtsCount: 1,
})
case *tg.PeerChannel:
channel, ok := e.s.peers.channels[peer.ChannelID]
if !ok {
panic("peer not found")
}
e.ents.Channels[channel.ID] = channel
msgs := append(e.s.messages.channels[peer.ChannelID], msg)
e.s.messages.channels[peer.ChannelID] = msgs
e.updates = append(e.updates, &tg.UpdateNewChannelMessage{
Message: msg,
Pts: len(msgs),
PtsCount: 1,
})
default:
panic("unexpected peer type")
}
}
// CreateEvent creates new event.
func (s *server) CreateEvent(f func(ev *EventBuilder)) *tg.Updates {
s.mux.Lock()
defer s.mux.Unlock()
s.date++
ev := &EventBuilder{
ents: NewEntities(),
s: s,
date: s.date,
}
f(ev)
return &tg.Updates{
Updates: ev.updates,
Users: ev.ents.AsUsers(),
Chats: ev.ents.AsChats(),
Date: s.date,
Seq: 0,
}
}
@@ -0,0 +1,226 @@
package e2e
import (
"context"
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/updates"
)
var _ updates.StateStorage = (*memStorage)(nil)
type memStorage struct {
states map[int64]updates.State
channels map[int64]map[int64]int
mux sync.Mutex
}
func newMemStorage() *memStorage {
return &memStorage{
states: map[int64]updates.State{},
channels: map[int64]map[int64]int{},
}
}
func (s *memStorage) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
s.mux.Lock()
defer s.mux.Unlock()
state, found = s.states[userID]
return
}
func (s *memStorage) SetState(ctx context.Context, userID int64, state updates.State) error {
s.mux.Lock()
defer s.mux.Unlock()
s.states[userID] = state
s.channels[userID] = map[int64]int{}
return nil
}
func (s *memStorage) SetPts(ctx context.Context, userID int64, pts int) error {
s.mux.Lock()
defer s.mux.Unlock()
state, ok := s.states[userID]
if !ok {
return errors.New("state not found")
}
state.Pts = pts
s.states[userID] = state
return nil
}
func (s *memStorage) SetQts(ctx context.Context, userID int64, qts int) error {
s.mux.Lock()
defer s.mux.Unlock()
state, ok := s.states[userID]
if !ok {
return errors.New("state not found")
}
state.Qts = qts
s.states[userID] = state
return nil
}
func (s *memStorage) SetDate(ctx context.Context, userID int64, date int) error {
s.mux.Lock()
defer s.mux.Unlock()
state, ok := s.states[userID]
if !ok {
return errors.New("state not found")
}
state.Date = date
s.states[userID] = state
return nil
}
func (s *memStorage) SetSeq(ctx context.Context, userID int64, seq int) error {
s.mux.Lock()
defer s.mux.Unlock()
state, ok := s.states[userID]
if !ok {
return errors.New("state not found")
}
state.Seq = seq
s.states[userID] = state
return nil
}
func (s *memStorage) SetDateSeq(ctx context.Context, userID int64, date, seq int) error {
s.mux.Lock()
defer s.mux.Unlock()
state, ok := s.states[userID]
if !ok {
return errors.New("state not found")
}
state.Date = date
state.Seq = seq
s.states[userID] = state
return nil
}
func (s *memStorage) SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error {
s.mux.Lock()
defer s.mux.Unlock()
channels, ok := s.channels[userID]
if !ok {
return errors.New("user state does not exist")
}
channels[channelID] = pts
return nil
}
func (s *memStorage) GetChannelPts(ctx context.Context, userID, channelID int64) (pts int, found bool, err error) {
s.mux.Lock()
defer s.mux.Unlock()
channels, ok := s.channels[userID]
if !ok {
return 0, false, nil
}
pts, found = channels[channelID]
return
}
func (s *memStorage) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
s.mux.Lock()
defer s.mux.Unlock()
cmap, ok := s.channels[userID]
if !ok {
return errors.New("channels map does not exist")
}
for id, pts := range cmap {
if err := f(ctx, id, pts); err != nil {
return err
}
}
return nil
}
var _ updates.AccessHasher = (*memAccessHasher)(nil)
type memAccessHasher struct {
channelHashes map[int64]map[int64]int64
userHashes map[int64]map[int64]int64
mux sync.Mutex
}
func newMemAccessHasher() *memAccessHasher {
return &memAccessHasher{
channelHashes: map[int64]map[int64]int64{},
userHashes: map[int64]map[int64]int64{},
}
}
func (m *memAccessHasher) GetChannelAccessHash(ctx context.Context, forUserID, channelID int64) (accessHash int64, found bool, err error) {
m.mux.Lock()
defer m.mux.Unlock()
accessHashes, ok := m.channelHashes[forUserID]
if !ok {
return 0, false, nil
}
accessHash, found = accessHashes[channelID]
return
}
func (m *memAccessHasher) SetChannelAccessHash(ctx context.Context, forUserID, channelID, accessHash int64) error {
m.mux.Lock()
defer m.mux.Unlock()
accessHashes, ok := m.channelHashes[forUserID]
if !ok {
accessHashes = map[int64]int64{}
m.channelHashes[forUserID] = accessHashes
}
accessHashes[channelID] = accessHash
return nil
}
func (m *memAccessHasher) GetUserAccessHash(ctx context.Context, forUserID, userID int64) (accessHash int64, found bool, err error) {
m.mux.Lock()
defer m.mux.Unlock()
accessHashes, ok := m.userHashes[forUserID]
if !ok {
return 0, false, nil
}
accessHash, found = accessHashes[userID]
return
}
func (m *memAccessHasher) SetUserAccessHash(ctx context.Context, forUserID, userID, accessHash int64) error {
m.mux.Lock()
defer m.mux.Unlock()
accessHashes, ok := m.userHashes[forUserID]
if !ok {
accessHashes = map[int64]int64{}
m.channelHashes[forUserID] = accessHashes
}
accessHashes[userID] = accessHash
return nil
}
+199
View File
@@ -0,0 +1,199 @@
package updates
import (
"context"
"sync"
"github.com/go-faster/errors"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
var _ telegram.UpdateHandler = (*Manager)(nil)
// Manager deals with gaps.
//
// Important:
// Updates produced by this manager may contain
// negative Pts/Qts/Seq values in tg.UpdateClass/tg.UpdatesClass
// (does not affects to the tg.MessageClass).
//
// This is because telegram server does not return these sequences
// for getDifference/getChannelDifference results.
// You SHOULD NOT use them in update handlers at all.
type Manager struct {
state *internalState
mux sync.Mutex
// immutable:
cfg Config
lg *zap.Logger
tracer trace.Tracer
}
// New creates new manager.
func New(cfg Config) *Manager {
cfg.setDefaults()
return &Manager{
cfg: cfg,
lg: cfg.Logger,
tracer: cfg.TracerProvider.Tracer(""),
}
}
// Handle handles updates.
//
// Important:
// If Run method not called, all updates will be passed
// to the provided handler as-is without any order verification
// or short updates transformation.
func (m *Manager) Handle(ctx context.Context, u tg.UpdatesClass) error {
ctx, span := m.tracer.Start(ctx, "updates.Manager.Handle")
defer span.End()
m.lg.Debug("Handle")
defer m.lg.Debug("Handled")
m.mux.Lock()
state := m.state
m.mux.Unlock()
if state == nil {
m.lg.Debug("Handle (no internalState)")
return m.cfg.Handler.Handle(ctx, u)
}
return state.Push(ctx, u)
}
type AuthOptions struct {
IsBot bool
Forget bool
OnStart func(ctx context.Context)
}
// Run notifies manager about user authentication on the telegram server.
//
// If forget is true, local internalState (if exist) will be overwritten
// with remote internalState.
func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOptions) error {
lg := m.lg.With(
zap.Int64("user_id", userID),
zap.Bool("is_bot", opt.IsBot),
zap.Bool("forget", opt.Forget),
)
lg.Debug("Run")
defer lg.Debug("Done")
wg, ctx := errgroup.WithContext(ctx)
if err := func() error {
m.mux.Lock()
defer m.mux.Unlock()
if m.state != nil {
return errors.Errorf("already authorized (userID: %d)", m.state.selfID)
}
state, err := m.loadState(ctx, api, userID, opt.Forget)
if err != nil {
return errors.Wrap(err, "load internalState")
}
channels := make(map[int64]struct {
Pts int
AccessHash int64
})
if err := m.cfg.Storage.ForEachChannels(ctx, userID, func(ctx context.Context, channelID int64, pts int) error {
hash, found, err := m.cfg.AccessHasher.GetChannelAccessHash(ctx, userID, channelID)
if err != nil {
return errors.Wrap(err, "get channel access hash")
}
if !found {
return nil
}
channels[channelID] = struct {
Pts int
AccessHash int64
}{Pts: pts, AccessHash: hash}
return nil
}); err != nil {
return errors.Wrap(err, "iterate channels")
}
diffLim := diffLimitUser
if opt.IsBot {
diffLim = diffLimitBot
}
m.state = newState(ctx, stateConfig{
State: state,
Channels: channels,
RawClient: api,
Tracer: m.tracer,
Logger: m.cfg.Logger,
Handler: m.cfg.Handler,
OnChannelTooLong: m.cfg.OnChannelTooLong,
Storage: m.cfg.Storage,
Hasher: m.cfg.AccessHasher,
SelfID: userID,
DiffLimit: diffLim,
WorkGroup: wg,
})
return nil
}(); err != nil {
return errors.Wrap(err, "setup")
}
if opt.OnStart != nil {
opt.OnStart(ctx)
}
wg.Go(func() error {
return m.state.Run(ctx)
})
lg.Debug("Wait")
return wg.Wait()
}
func (m *Manager) loadState(ctx context.Context, api API, userID int64, forget bool) (State, error) {
onNotFound:
var state State
if forget {
remote, err := api.UpdatesGetState(ctx)
if err != nil {
return State{}, errors.Wrap(err, "get remote internalState")
}
state = state.fromRemote(remote)
if err := m.cfg.Storage.SetState(ctx, userID, state); err != nil {
return State{}, err
}
return state, nil
}
state, found, err := m.cfg.Storage.GetState(ctx, userID)
if err != nil {
return State{}, errors.Wrap(err, "restore local internalState")
}
if !found {
forget = true
goto onNotFound
}
return state, nil
}
// Reset notifies manager about user logout.
func (m *Manager) Reset() {
m.mux.Lock()
defer m.mux.Unlock()
m.state = nil
}
+189
View File
@@ -0,0 +1,189 @@
package updates
import (
"context"
"sort"
"time"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
type sequenceBox struct {
state int
gaps gapBuffer
gapTimeout *time.Timer
pending []update
apply func(ctx context.Context, state int, updates []update) error
log *zap.Logger
tracer trace.Tracer
}
type sequenceConfig struct {
InitialState int
Apply func(ctx context.Context, state int, updates []update) error
Logger *zap.Logger
Tracer trace.Tracer
}
func newSequenceBox(cfg sequenceConfig) *sequenceBox {
if cfg.Apply == nil {
panic("Apply func nil")
}
if cfg.Logger == nil {
cfg.Logger = zap.NewNop()
}
if cfg.Tracer == nil {
cfg.Tracer = trace.NewNoopTracerProvider().Tracer("")
}
cfg.Logger.Debug("Initialized", zap.Int("internalState", cfg.InitialState))
t := time.NewTimer(fastgapTimeout)
_ = t.Stop()
return &sequenceBox{
state: cfg.InitialState,
gapTimeout: t,
apply: cfg.Apply,
log: cfg.Logger,
tracer: cfg.Tracer,
}
}
func (s *sequenceBox) Handle(ctx context.Context, u update) error {
ctx, span := s.tracer.Start(ctx, "sequenceBox.Handle")
defer span.End()
log := s.log.With(zap.Int("upd_from", u.start()), zap.Int("upd_to", u.end()))
if checkGap(s.state, u.State, u.Count) == gapIgnore {
log.Debug("Outdated update, skipping", zap.Int("internalState", s.state))
return nil
}
if s.gaps.Has() {
s.pending = append(s.pending, u)
if accepted := s.gaps.Consume(u); !accepted {
log.Debug("Out of gap range, postponed", zap.Array("gaps", s.gaps))
return nil
}
log.Debug("Gap accepted", zap.Array("gaps", s.gaps))
if !s.gaps.Has() {
_ = s.gapTimeout.Stop()
s.log.Debug("Gap was resolved by waiting")
return s.applyPending(ctx)
}
return nil
}
switch checkGap(s.state, u.State, u.Count) {
case gapApply:
if len(s.pending) > 0 {
s.pending = append(s.pending, u)
return s.applyPending(ctx)
}
if err := s.apply(ctx, u.State, []update{u}); err != nil {
return err
}
log.Debug("Accepted")
s.setState(u.State, "update")
return nil
case gapRefetch:
s.pending = append(s.pending, u)
s.gaps.Enable(s.state, u.start())
// Check if we already have acceptable updates in buffer.
for _, u := range s.pending {
_ = s.gaps.Consume(u)
}
if !s.gaps.Has() {
log.Debug("Gap was resolved by pending updates")
return s.applyPending(ctx)
}
_ = s.gapTimeout.Reset(fastgapTimeout)
s.log.Debug("Gap detected", zap.Array("gap", s.gaps))
return nil
default:
panic("unreachable")
}
}
func (s *sequenceBox) applyPending(ctx context.Context) error {
ctx, span := s.tracer.Start(ctx, "sequenceBox.applyPending")
defer span.End()
sort.SliceStable(s.pending, func(i, j int) bool {
return s.pending[i].start() < s.pending[j].start()
})
var (
cursor = 0
state = s.state
accepted []update
)
loop:
for i, update := range s.pending {
switch checkGap(state, update.State, update.Count) {
case gapApply:
accepted = append(accepted, update)
state = update.State
cursor = i + 1
continue
case gapIgnore:
cursor = i + 1
continue
case gapRefetch:
break loop
}
}
// Trim processed updates. Setting zero values for the rest
// of the slice lets GC collect referenced objects.
end := len(s.pending)
trim := end - cursor
copy(s.pending, s.pending[cursor:])
for i := trim; i < end; i++ {
s.pending[i] = update{}
}
s.pending = s.pending[:trim]
if len(accepted) == 0 {
s.log.Warn("Empty buffer", zap.Any("pending", s.pending), zap.Int("internalState", s.state))
return nil
}
if err := s.apply(ctx, state, accepted); err != nil {
return err
}
s.log.Debug("Pending updates applied",
zap.Int("prev_state", s.state),
zap.Int("new_state", state),
zap.Int("accepted_count", len(accepted)),
)
s.setState(state, "pending updates")
return nil
}
func (s *sequenceBox) State() int { return s.state }
func (s *sequenceBox) SetState(state int, reason string) {
s.setState(state, reason)
}
func (s *sequenceBox) setState(state int, reason string) {
old := s.state
s.state = state
s.log.Debug("State changed",
zap.Int("old", old),
zap.Int("new", state),
zap.String("reason", reason),
)
}
@@ -0,0 +1,168 @@
package updates
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
)
func TestSequenceBox(t *testing.T) {
var (
state int
updates []update
)
ctx := context.Background()
box := newSequenceBox(sequenceConfig{
InitialState: 3,
Apply: func(ctx context.Context, s int, u []update) error {
state = s
updates = append(updates, u...)
return nil
},
Logger: zaptest.NewLogger(t),
})
require.Nil(t, box.Handle(ctx, update{
Value: 1,
State: 2,
Count: 1,
}))
require.Zero(t, state)
require.Empty(t, updates)
require.Empty(t, box.pending)
require.Nil(t, box.Handle(ctx, update{
Value: 1,
State: 3,
Count: 1,
}))
require.Zero(t, state)
require.Empty(t, updates)
require.Empty(t, box.pending)
require.Nil(t, box.Handle(ctx, update{
Value: 1,
State: 4,
Count: 1,
}))
require.Equal(t, 4, state)
require.Equal(t, []update{{Value: 1, State: 4, Count: 1}}, updates)
require.Empty(t, box.pending)
updates = nil
require.Nil(t, box.Handle(ctx, update{
Value: 1,
State: 6,
Count: 1,
}))
require.Equal(t, 4, state)
require.Empty(t, updates)
require.Equal(t, []update{{Value: 1, State: 6, Count: 1}}, box.pending)
require.Nil(t, box.Handle(ctx, update{
Value: 2,
State: 5,
Count: 1,
}))
require.Equal(t, 6, state)
require.Equal(t, []update{{Value: 2, State: 5, Count: 1}, {Value: 1, State: 6, Count: 1}}, updates)
require.Empty(t, box.pending)
updates = nil
require.Nil(t, box.Handle(ctx, update{
Value: 3,
State: 8,
Count: 1,
}))
require.Equal(t, 6, state)
require.Empty(t, updates)
require.Equal(t, []update{{Value: 3, State: 8, Count: 1}}, box.pending)
<-box.gapTimeout.C
require.Equal(t, []gap{{from: 6, to: 7}}, box.gaps.gaps)
box.gaps.Clear()
require.False(t, box.gaps.Has())
}
func TestSequenceBoxApplyPending(t *testing.T) {
tests := []struct {
InitialState int
Pending []update
PendingAfter []update
Applied []update
}{
{
InitialState: 5,
Pending: []update{
{Value: 1, State: 3, Count: 1},
{Value: 1, State: 4, Count: 1},
{Value: 1, State: 1, Count: 1},
},
PendingAfter: []update{},
Applied: []update{},
},
{
InitialState: 5,
Pending: []update{
{Value: 1, State: 3, Count: 1},
{Value: 1, State: 8, Count: 1},
{Value: 1, State: 7, Count: 1},
{Value: 1, State: 4, Count: 1},
{Value: 1, State: 1, Count: 1},
},
PendingAfter: []update{
{1, 7, 1, entities{}},
{1, 8, 1, entities{}},
},
Applied: []update{},
},
{
InitialState: 5,
Pending: []update{
{Value: 1, State: 8, Count: 1},
{Value: 1, State: 7, Count: 1},
},
PendingAfter: []update{
{1, 7, 1, entities{}},
{Value: 1, State: 8, Count: 1},
},
Applied: []update{},
},
{
InitialState: 5,
Pending: []update{
{Value: 1, State: 3, Count: 1},
{Value: 1, State: 6, Count: 1},
{Value: 1, State: 8, Count: 1},
{Value: 1, State: 4, Count: 1},
{Value: 1, State: 1, Count: 1},
},
PendingAfter: []update{
{Value: 1, State: 8, Count: 1},
},
Applied: []update{
{Value: 1, State: 6, Count: 1},
},
},
}
for _, test := range tests {
applied := make([]update, 0)
box := newSequenceBox(sequenceConfig{
InitialState: test.InitialState,
Apply: func(_ context.Context, s int, u []update) error {
applied = append(applied, u...)
return nil
},
Logger: zaptest.NewLogger(t),
})
box.pending = test.Pending
require.NoError(t, box.applyPending(context.TODO()))
require.Equal(t, test.PendingAfter, box.pending)
require.Equal(t, test.Applied, applied)
}
}
+82
View File
@@ -0,0 +1,82 @@
package updates
import (
"sort"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func sortUpdatesByPts(u []tg.UpdateClass) {
sort.Stable(ptsSorter(u))
}
type ptsSorter []tg.UpdateClass
func (p ptsSorter) Len() int {
return len(p)
}
func (p ptsSorter) Less(i, j int) bool {
type (
ptsType int
compare struct {
typ ptsType
channelID int64
ptsDiff int
}
)
// Sorting order
//
// 0) Common updates without PTS.
// 1) Common PTS updates
// 2) Common QTS updates
// 3) Channel PTS updates (by channelID and by pts)
const (
other ptsType = iota
commonPts
commonQts
channelPts
)
getType := func(u tg.UpdateClass) compare {
if pts, ptsCount, ok := tg.IsPtsUpdate(u); ok {
return compare{typ: commonPts, ptsDiff: pts - ptsCount}
}
if channelID, pts, ptsCount, ok, _ := tg.IsChannelPtsUpdate(u); ok {
return compare{typ: channelPts, channelID: channelID, ptsDiff: pts - ptsCount}
}
if qts, ok := tg.IsQtsUpdate(u); ok {
return compare{typ: commonQts, ptsDiff: qts}
}
return compare{typ: other}
}
a, b := getType(p[i]), getType(p[j])
switch {
case a.typ < b.typ:
return true
case a.typ > b.typ:
return false
}
// a.typ == b.typ
switch a.typ {
case other:
// Keep original order
case commonPts, commonQts:
return a.ptsDiff < b.ptsDiff
case channelPts:
if a.channelID < b.channelID {
return true
}
if a.channelID > b.channelID {
return false
}
return a.ptsDiff < b.ptsDiff
}
return false
}
func (p ptsSorter) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}
@@ -0,0 +1,96 @@
package updates
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func Test_sortUpdatesByPts(t *testing.T) {
channelNewMessage := func(pts int, id int64) *tg.UpdateNewChannelMessage {
return &tg.UpdateNewChannelMessage{
Message: &tg.Message{
PeerID: &tg.PeerChannel{ChannelID: id},
},
Pts: pts - 1,
PtsCount: 1,
}
}
newMessage := func(pts int) *tg.UpdateNewMessage {
return &tg.UpdateNewMessage{
Message: &tg.Message{
PeerID: &tg.PeerUser{UserID: 10},
},
Pts: pts - 1,
PtsCount: 1,
}
}
encryptedNewMessage := func(pts int) *tg.UpdateNewEncryptedMessage {
return &tg.UpdateNewEncryptedMessage{
Qts: pts,
}
}
channelReadInbox := func(pts int, id int64) *tg.UpdateReadChannelInbox {
return &tg.UpdateReadChannelInbox{ChannelID: id, MaxID: 25, Pts: pts}
}
tests := []struct {
input []tg.UpdateClass
result []tg.UpdateClass
}{
{
[]tg.UpdateClass{
channelReadInbox(26, 1),
channelNewMessage(25, 1),
},
[]tg.UpdateClass{
channelNewMessage(25, 1),
channelReadInbox(26, 1),
},
},
{
[]tg.UpdateClass{
channelReadInbox(26, 1),
channelNewMessage(25, 2),
newMessage(26),
encryptedNewMessage(26),
newMessage(25),
encryptedNewMessage(25),
encryptedNewMessage(27),
channelNewMessage(25, 1),
},
[]tg.UpdateClass{
newMessage(25),
newMessage(26),
encryptedNewMessage(25),
encryptedNewMessage(26),
encryptedNewMessage(27),
channelNewMessage(25, 1),
channelReadInbox(26, 1),
channelNewMessage(25, 2),
},
},
{
[]tg.UpdateClass{
channelReadInbox(26, 1),
&tg.UpdateConfig{},
channelNewMessage(25, 1),
},
[]tg.UpdateClass{
&tg.UpdateConfig{},
channelNewMessage(25, 1),
channelReadInbox(26, 1),
},
},
}
for i, tt := range tests {
tt := tt
t.Run(fmt.Sprintf("Test%d", i+1), func(t *testing.T) {
sortUpdatesByPts(tt.input)
require.Equal(t, tt.result, tt.input)
})
}
}
+528
View File
@@ -0,0 +1,528 @@
package updates
import (
"context"
"fmt"
"time"
"github.com/go-faster/errors"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
const (
idleTimeout = time.Minute * 15
fastgapTimeout = time.Millisecond * 500
diffLimitUser = 100
diffLimitBot = 100000
)
type tracedUpdate struct {
update tg.UpdatesClass
span trace.SpanContext
}
type internalState struct {
// Updates channel.
externalQueue chan tracedUpdate
// Updates from channel states
// during updates.getChannelDifference.
internalQueue chan tracedUpdate
// Common internalState.
pts, qts, seq *sequenceBox
date int
idleTimeout *time.Timer
// Channel states.
channels map[int64]*channelState
// Immutable fields.
client API
log *zap.Logger
handler telegram.UpdateHandler
onTooLong func(channelID int64) error
storage StateStorage
hasher AccessHasher
selfID int64
diffLim int
wg *errgroup.Group
tracer trace.Tracer
}
type stateConfig struct {
State State
Channels map[int64]struct {
Pts int
AccessHash int64
}
RawClient API
Logger *zap.Logger
Tracer trace.Tracer
Handler telegram.UpdateHandler
OnChannelTooLong func(channelID int64) error
Storage StateStorage
Hasher AccessHasher
SelfID int64
DiffLimit int
WorkGroup *errgroup.Group
}
func newState(ctx context.Context, cfg stateConfig) *internalState {
s := &internalState{
externalQueue: make(chan tracedUpdate, 10),
internalQueue: make(chan tracedUpdate, 10),
date: cfg.State.Date,
idleTimeout: time.NewTimer(idleTimeout),
channels: make(map[int64]*channelState),
client: cfg.RawClient,
log: cfg.Logger,
handler: cfg.Handler,
onTooLong: cfg.OnChannelTooLong,
storage: cfg.Storage,
hasher: cfg.Hasher,
selfID: cfg.SelfID,
diffLim: cfg.DiffLimit,
wg: cfg.WorkGroup,
tracer: cfg.Tracer,
}
s.pts = newSequenceBox(sequenceConfig{
InitialState: cfg.State.Pts,
Apply: s.applyPts,
Logger: s.log.Named("pts"),
Tracer: s.tracer,
})
s.qts = newSequenceBox(sequenceConfig{
InitialState: cfg.State.Qts,
Apply: s.applyQts,
Logger: s.log.Named("qts"),
Tracer: s.tracer,
})
s.seq = newSequenceBox(sequenceConfig{
InitialState: cfg.State.Seq,
Apply: s.applySeq,
Logger: s.log.Named("seq"),
})
for id, info := range cfg.Channels {
state := s.newChannelState(id, info.AccessHash, info.Pts)
s.channels[id] = state
s.wg.Go(func() error {
return state.Run(ctx)
})
}
return s
}
func (s *internalState) Push(ctx context.Context, u tg.UpdatesClass) error {
tu := tracedUpdate{
update: u,
span: trace.SpanContextFromContext(ctx),
}
select {
case s.externalQueue <- tu:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// isFatalError returns true if error is fatal so we should stop updates handler.
func isFatalError(err error) bool {
// See https://github.com/gotd/td/issues/1458.
if errors.Is(err, exchange.ErrKeyFingerprintNotFound) {
return true
}
if tgerr.Is(err, "AUTH_KEY_UNREGISTERED", "SESSION_EXPIRED") {
return true
}
if auth.IsUnauthorized(err) {
return true
}
return false
}
func (s *internalState) Run(ctx context.Context) error {
if s == nil {
return errors.New("invalid: nil internalState")
}
if s.log == nil {
return errors.New("invalid: nil logger")
}
s.log.Debug("Starting updates handler")
defer s.log.Debug("Updates handler stopped")
s.getDifferenceLogger(ctx)
for {
select {
case <-ctx.Done():
if len(s.pts.pending) > 0 || len(s.qts.pending) > 0 || len(s.seq.pending) > 0 {
s.getDifferenceLogger(ctx)
}
return ctx.Err()
case u := <-s.externalQueue:
ctx := trace.ContextWithSpanContext(ctx, u.span)
if err := s.handleUpdates(ctx, u.update); err != nil {
return err
}
case u := <-s.internalQueue:
ctx := trace.ContextWithSpanContext(ctx, u.span)
if err := s.handleUpdates(ctx, u.update); err != nil {
return err
}
case <-s.pts.gapTimeout.C:
s.log.Debug("Pts gap timeout")
s.getDifferenceLogger(ctx)
case <-s.qts.gapTimeout.C:
s.log.Debug("Qts gap timeout")
s.getDifferenceLogger(ctx)
case <-s.seq.gapTimeout.C:
s.log.Debug("Seq gap timeout")
s.getDifferenceLogger(ctx)
case <-s.idleTimeout.C:
s.log.Debug("Idle timeout")
s.getDifferenceLogger(ctx)
}
}
}
func (s *internalState) handleUpdates(ctx context.Context, u tg.UpdatesClass) error {
ctx, span := s.tracer.Start(ctx, "handleUpdates")
defer span.End()
s.resetIdleTimer()
switch u := u.(type) {
case *tg.Updates:
s.saveUserHashes(ctx, u.Users)
s.saveChannelHashes(ctx, u.Chats)
return s.handleSeq(ctx, &tg.UpdatesCombined{
Updates: u.Updates,
Users: u.Users,
Chats: u.Chats,
Date: u.Date,
Seq: u.Seq,
SeqStart: u.Seq,
})
case *tg.UpdatesCombined:
s.saveUserHashes(ctx, u.Users)
s.saveChannelHashes(ctx, u.Chats)
return s.handleSeq(ctx, u)
case *tg.UpdateShort:
return s.handleUpdates(ctx, &tg.UpdatesCombined{
Updates: []tg.UpdateClass{u.Update},
Date: u.Date,
})
case *tg.UpdateShortMessage:
updateShort := s.convertShortMessage(u)
if _, found, err := s.hasher.GetUserAccessHash(ctx, s.selfID, u.UserID); err != nil {
return err
} else if !found {
chats, users, err := s.handleDifference(ctx, u.Date)
if err != nil {
return err
}
return s.handleUpdates(ctx, &tg.UpdatesCombined{
Updates: []tg.UpdateClass{updateShort.Update},
Users: users,
Chats: chats,
Date: u.Date,
})
} else {
return s.handleUpdates(ctx, updateShort)
}
case *tg.UpdateShortChatMessage:
updateShort := s.convertShortChatMessage(u)
if _, found, err := s.hasher.GetUserAccessHash(ctx, s.selfID, u.FromID); err != nil {
return err
} else if !found {
chats, users, err := s.handleDifference(ctx, u.Date)
if err != nil {
return err
}
return s.handleUpdates(ctx, &tg.UpdatesCombined{
Updates: []tg.UpdateClass{updateShort.Update},
Users: users,
Chats: chats,
Date: u.Date,
})
} else {
return s.handleUpdates(ctx, updateShort)
}
case *tg.UpdateShortSentMessage:
return s.handleUpdates(ctx, s.convertShortSentMessage(u))
case *tg.UpdatesTooLong:
return s.getDifference(ctx)
default:
panic(fmt.Sprintf("unexpected update type: %T", u))
}
}
func (s *internalState) handleSeq(ctx context.Context, u *tg.UpdatesCombined) error {
ctx, span := s.tracer.Start(ctx, "handleSeq")
defer span.End()
if err := validateSeq(u.Seq, u.SeqStart); err != nil {
s.log.Error("Seq validation failed", zap.Error(err), zap.Any("update", u))
return nil
}
// Special case.
if u.Seq == 0 {
ptsChanged, err := s.applyCombined(ctx, u)
if err != nil {
return err
}
if ptsChanged {
return s.getDifference(ctx)
}
return nil
}
return s.seq.Handle(ctx, update{
Value: u,
State: u.Seq,
Count: u.Seq - u.SeqStart + 1,
})
}
func (s *internalState) handlePts(ctx context.Context, pts, ptsCount int, u tg.UpdateClass, ents entities) error {
if err := validatePts(pts, ptsCount); err != nil {
s.log.Error("Pts validation failed", zap.Error(err), zap.Any("update", u))
return nil
}
return s.pts.Handle(ctx, update{
Value: u,
State: pts,
Count: ptsCount,
Entities: ents,
})
}
func (s *internalState) handleQts(ctx context.Context, qts int, u tg.UpdateClass, ents entities) error {
if err := validateQts(qts); err != nil {
s.log.Error("Qts validation failed", zap.Error(err), zap.Any("update", u))
return nil
}
return s.qts.Handle(ctx, update{
Value: u,
State: qts,
Count: 1,
Entities: ents,
})
}
func (s *internalState) handleChannel(ctx context.Context, channelID int64, date, pts, ptsCount int, cu channelUpdate) error {
if err := validatePts(pts, ptsCount); err != nil {
s.log.Error("Pts validation failed", zap.Error(err), zap.Any("update", cu.update))
return nil
}
state, ok := s.channels[channelID]
if !ok {
accessHash, found, err := s.hasher.GetChannelAccessHash(context.Background(), s.selfID, channelID)
if err != nil {
s.log.Error("GetAccessHash error", zap.Error(err))
}
if !found {
if date == 0 {
// Received update has no date field.
date = s.date - 30
} else {
date-- // 1 sec back
}
// Try to get access hash from updates.getDifference.
accessHash, found = s.restoreChannelAccessHash(ctx, channelID, date)
if !found {
s.log.Debug("Failed to recover missing access hash, update ignored",
zap.Int64("channel_id", channelID),
// zap.Any("update", cu.update),
)
return nil
}
}
localPts, found, err := s.storage.GetChannelPts(ctx, s.selfID, channelID)
if err != nil {
localPts = pts - ptsCount
s.log.Error("GetChannelPts error", zap.Error(err))
}
if !found {
localPts = pts - ptsCount
if err := s.storage.SetChannelPts(ctx, s.selfID, channelID, localPts); err != nil {
s.log.Error("SetChannelPts error", zap.Error(err))
}
}
state = s.newChannelState(channelID, accessHash, localPts)
s.channels[channelID] = state
s.wg.Go(func() error {
return state.Run(ctx)
})
}
return state.Push(ctx, cu)
}
func (s *internalState) newChannelState(channelID, accessHash int64, initialPts int) *channelState {
return newChannelState(channelStateConfig{
Out: s.internalQueue,
InitialPts: initialPts,
ChannelID: channelID,
AccessHash: accessHash,
SelfID: s.selfID,
Storage: s.storage,
DiffLimit: s.diffLim,
RawClient: s.client,
Handler: s.handler,
OnChannelTooLong: s.onTooLong,
Logger: s.log.Named("channel").With(zap.Int64("channel_id", channelID)),
Tracer: s.tracer,
})
}
func (s *internalState) getDifference(ctx context.Context) error {
ctx, span := s.tracer.Start(ctx, "getDifference")
defer span.End()
s.resetIdleTimer()
s.pts.gaps.Clear()
s.qts.gaps.Clear()
s.seq.gaps.Clear()
s.log.Debug("Getting difference")
setState := func(state tg.UpdatesState, reason string) {
if err := s.storage.SetState(ctx, s.selfID, State{}.fromRemote(&state)); err != nil {
s.log.Warn("SetState error", zap.Error(err))
}
s.pts.SetState(state.Pts, reason)
s.qts.SetState(state.Qts, reason)
s.seq.SetState(state.Seq, reason)
s.date = state.Date
}
diff, err := s.client.UpdatesGetDifference(ctx, &tg.UpdatesGetDifferenceRequest{
Pts: s.pts.State(),
Qts: s.qts.State(),
Date: s.date,
})
if err != nil {
return errors.Wrap(err, "get difference")
}
s.log.Debug("Difference received", zap.String("diff", fmt.Sprintf("%T", diff)))
switch diff := diff.(type) {
case *tg.UpdatesDifference:
if len(diff.OtherUpdates) > 0 {
if err := s.handleUpdates(ctx, &tg.UpdatesCombined{
Updates: diff.OtherUpdates,
Users: diff.Users,
Chats: diff.Chats,
}); err != nil {
return errors.Wrap(err, "handle diff.OtherUpdates")
}
}
if len(diff.NewMessages) > 0 || len(diff.NewEncryptedMessages) > 0 {
if err := s.handler.Handle(ctx, &tg.Updates{
Updates: append(
msgsToUpdates(diff.NewMessages, false),
encryptedMsgsToUpdates(diff.NewEncryptedMessages)...,
),
Users: diff.Users,
Chats: diff.Chats,
}); err != nil {
return err
}
}
setState(diff.State, "updates.Difference")
return nil
// No events.
case *tg.UpdatesDifferenceEmpty:
if err := s.storage.SetDateSeq(ctx, s.selfID, diff.Date, diff.Seq); err != nil {
s.log.Warn("SetDateSeq error", zap.Error(err))
}
s.date = diff.Date
s.seq.SetState(diff.Seq, "updates.differenceEmpty")
return nil
// Incomplete list of occurred events.
case *tg.UpdatesDifferenceSlice:
if len(diff.OtherUpdates) > 0 {
if err := s.handleUpdates(ctx, &tg.UpdatesCombined{
Updates: diff.OtherUpdates,
Users: diff.Users,
Chats: diff.Chats,
Date: diff.IntermediateState.Date,
}); err != nil {
return err
}
}
if len(diff.NewMessages) > 0 || len(diff.NewEncryptedMessages) > 0 {
if err := s.handler.Handle(ctx, &tg.Updates{
Updates: append(
msgsToUpdates(diff.NewMessages, false),
encryptedMsgsToUpdates(diff.NewEncryptedMessages)...,
),
Users: diff.Users,
Chats: diff.Chats,
}); err != nil {
return err
}
}
setState(diff.IntermediateState, "updates.differenceSlice")
return s.getDifference(ctx)
// The difference is too long, and the specified internalState must be used to refetch updates.
case *tg.UpdatesDifferenceTooLong:
if err := s.storage.SetPts(ctx, s.selfID, diff.Pts); err != nil {
s.log.Error("SetPts error", zap.Error(err))
}
s.pts.SetState(diff.Pts, "updates.differenceTooLong")
return s.getDifference(ctx)
default:
return errors.Errorf("unexpected diff type: %T", diff)
}
}
func (s *internalState) getDifferenceLogger(ctx context.Context) {
if err := s.getDifference(ctx); err != nil {
s.log.Error("get difference error", zap.Error(err))
}
}
func (s *internalState) resetIdleTimer() {
if len(s.idleTimeout.C) > 0 {
<-s.idleTimeout.C
}
_ = s.idleTimeout.Reset(idleTimeout)
}
+189
View File
@@ -0,0 +1,189 @@
package updates
import (
"context"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func (s *internalState) applySeq(ctx context.Context, state int, updates []update) error {
recoverState := false
for _, u := range updates {
ptsChanged, err := s.applyCombined(ctx, u.Value.(*tg.UpdatesCombined))
if err != nil {
return err
}
if ptsChanged {
recoverState = true
}
}
if err := s.storage.SetSeq(ctx, s.selfID, state); err != nil {
return err
}
if recoverState {
return s.getDifference(ctx)
}
return nil
}
func (s *internalState) applyCombined(ctx context.Context, comb *tg.UpdatesCombined) (ptsChanged bool, err error) {
ctx, span := s.tracer.Start(ctx, "internalState.applyCombined")
defer span.End()
var (
ents = entities{
Users: comb.Users,
Chats: comb.Chats,
}
)
sortUpdatesByPts(comb.Updates)
for _, u := range comb.Updates {
switch u := u.(type) {
case *tg.UpdatePtsChanged:
ptsChanged = true
continue
case *tg.UpdateChannelTooLong:
st, ok := s.channels[u.ChannelID]
if !ok {
s.log.Debug("ChannelTooLong for channel that is not in the internalState, update ignored", zap.Int64("channel_id", u.ChannelID))
continue
}
if err := st.Push(ctx, channelUpdate{
update: u,
entities: ents,
span: trace.SpanContextFromContext(ctx),
}); err != nil {
return false, err
}
continue
}
if pts, ptsCount, ok := tg.IsPtsUpdate(u); ok {
if err := s.handlePts(ctx, pts, ptsCount, u, ents); err != nil {
return false, err
}
}
if channelID, pts, ptsCount, ok, err := tg.IsChannelPtsUpdate(u); ok {
if err != nil {
s.log.Debug("Invalid channel update", zap.Error(err)) //, zap.Any("update", u))
continue
}
if err := s.handleChannel(ctx, channelID, comb.Date, pts, ptsCount, channelUpdate{
update: u,
entities: ents,
span: trace.SpanContextFromContext(ctx),
}); err != nil {
return false, err
}
}
if qts, ok := tg.IsQtsUpdate(u); ok {
if err := s.handleQts(ctx, qts, u, ents); err != nil {
return false, err
}
}
}
if err := s.handler.Handle(ctx, &tg.Updates{
Updates: comb.Updates,
Users: ents.Users,
Chats: ents.Chats,
}); err != nil {
return false, err
}
setDate, setSeq := comb.Date > s.date, comb.Seq > 0
switch {
case setDate && setSeq:
if err := s.storage.SetDateSeq(ctx, s.selfID, comb.Date, comb.Seq); err != nil {
return false, err
}
s.date = comb.Date
s.seq.SetState(comb.Seq, "seq update")
case setDate:
if err := s.storage.SetDate(ctx, s.selfID, comb.Date); err != nil {
return false, err
}
s.date = comb.Date
case setSeq:
if err := s.storage.SetSeq(ctx, s.selfID, comb.Seq); err != nil {
return false, err
}
s.seq.SetState(comb.Seq, "seq update")
}
return ptsChanged, nil
}
func (s *internalState) applyPts(ctx context.Context, state int, updates []update) error {
ctx, span := s.tracer.Start(ctx, "internalState.applyPts")
defer span.End()
var (
converted []tg.UpdateClass
ents entities
)
for _, update := range updates {
converted = append(converted, update.Value.(tg.UpdateClass))
ents.Merge(update.Entities)
}
if err := s.handler.Handle(ctx, &tg.Updates{
Updates: converted,
Users: ents.Users,
Chats: ents.Chats,
}); err != nil {
return err
}
if err := s.storage.SetPts(ctx, s.selfID, state); err != nil {
return err
}
return nil
}
func (s *internalState) applyQts(ctx context.Context, state int, updates []update) error {
ctx, span := s.tracer.Start(ctx, "internalState.applyQts")
defer span.End()
var (
converted []tg.UpdateClass
ents entities
)
for _, update := range updates {
converted = append(converted, update.Value.(tg.UpdateClass))
ents.Merge(update.Entities)
}
if err := s.handler.Handle(ctx, &tg.Updates{
Updates: converted,
Users: ents.Users,
Chats: ents.Chats,
}); err != nil {
return err
}
// Don't set qts if it's 0, because it means that we are apllying gaps updates
if state == 0 {
return nil
}
if err := s.storage.SetQts(ctx, s.selfID, state); err != nil {
return err
}
return nil
}
+344
View File
@@ -0,0 +1,344 @@
package updates
import (
"context"
"time"
"github.com/go-faster/errors"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type channelUpdate struct {
update tg.UpdateClass
entities entities
span trace.SpanContext
}
type channelState struct {
// Updates from *internalState.
updates chan channelUpdate
// Channel to pass diff.OtherUpdates into *internalState.
out chan<- tracedUpdate
// Channel internalState.
pts *sequenceBox
idleTimeout *time.Timer
diffTimeout time.Time
// Immutable fields.
channelID int64
accessHash int64
selfID int64
diffLim int
client API
storage StateStorage
log *zap.Logger
tracer trace.Tracer
handler telegram.UpdateHandler
onTooLong func(channelID int64) error
}
type channelStateConfig struct {
Out chan tracedUpdate
InitialPts int
ChannelID int64
AccessHash int64
SelfID int64
DiffLimit int
RawClient API
Storage StateStorage
Handler telegram.UpdateHandler
OnChannelTooLong func(channelID int64) error
Logger *zap.Logger
Tracer trace.Tracer
}
func newChannelState(cfg channelStateConfig) *channelState {
state := &channelState{
updates: make(chan channelUpdate, 10),
out: cfg.Out,
idleTimeout: time.NewTimer(idleTimeout),
channelID: cfg.ChannelID,
accessHash: cfg.AccessHash,
selfID: cfg.SelfID,
diffLim: cfg.DiffLimit,
client: cfg.RawClient,
storage: cfg.Storage,
log: cfg.Logger,
handler: cfg.Handler,
onTooLong: cfg.OnChannelTooLong,
tracer: cfg.Tracer,
}
state.pts = newSequenceBox(sequenceConfig{
InitialState: cfg.InitialPts,
Apply: state.applyPts,
Logger: cfg.Logger.Named("pts"),
Tracer: cfg.Tracer,
})
return state
}
func (s *channelState) Push(ctx context.Context, u channelUpdate) error {
select {
case <-ctx.Done():
return ctx.Err()
case s.updates <- u:
return nil
}
}
func (s *channelState) Run(ctx context.Context) error {
// Subscribe to channel updates.
if err := s.getDifference(ctx); err != nil {
s.log.Error("Failed to subscribe to channel updates", zap.Error(err))
}
for {
select {
case u := <-s.updates:
ctx := trace.ContextWithSpanContext(ctx, u.span)
if err := s.handleUpdate(ctx, u.update, u.entities); err != nil {
s.log.Error("Handle update error", zap.Error(err))
}
case <-s.pts.gapTimeout.C:
s.log.Debug("Gap timeout")
s.getDifferenceLogger(ctx)
case <-ctx.Done():
if len(s.pts.pending) > 0 {
// This will probably fail.
s.getDifferenceLogger(ctx)
}
return ctx.Err()
case <-s.idleTimeout.C:
s.log.Debug("Idle timeout")
s.resetIdleTimer()
s.getDifferenceLogger(ctx)
}
}
}
func (s *channelState) handleUpdate(ctx context.Context, u tg.UpdateClass, ents entities) error {
ctx, span := s.tracer.Start(ctx, "channelState.handleUpdate")
defer span.End()
s.resetIdleTimer()
if long, ok := u.(*tg.UpdateChannelTooLong); ok {
return s.handleTooLong(ctx, long)
}
channelID, pts, ptsCount, ok, err := tg.IsChannelPtsUpdate(u)
if err != nil {
return errors.Wrap(err, "invalid update")
}
if !ok {
return errors.Errorf("expected channel update, got: %T", u)
}
if channelID != s.channelID {
return errors.Errorf("update for wrong channel (channelID: %d)", channelID)
}
return s.pts.Handle(ctx, update{
Value: u,
State: pts,
Count: ptsCount,
Entities: ents,
})
}
func (s *channelState) handleTooLong(ctx context.Context, long *tg.UpdateChannelTooLong) error {
ctx, span := s.tracer.Start(ctx, "channelState.handleTooLong")
defer span.End()
remotePts, ok := long.GetPts()
if !ok {
s.log.Warn("Got UpdateChannelTooLong without pts field")
return s.getDifference(ctx)
}
// Note: we still can fetch latest diffLim updates.
// Should we do?
if remotePts-s.pts.State() > s.diffLim {
return s.onTooLong(s.channelID)
}
return s.getDifference(ctx)
}
func (s *channelState) applyPts(ctx context.Context, state int, updates []update) error {
ctx, span := s.tracer.Start(ctx, "channelState.applyPts")
defer span.End()
var (
converted []tg.UpdateClass
ents entities
)
for _, update := range updates {
converted = append(converted, update.Value.(tg.UpdateClass))
ents.Merge(update.Entities)
}
if err := s.handler.Handle(ctx, &tg.Updates{
Updates: converted,
Users: ents.Users,
Chats: ents.Chats,
}); err != nil {
s.log.Error("Handle update error", zap.Error(err))
return nil
}
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, state); err != nil {
s.log.Error("SetChannelPts error", zap.Error(err))
}
return nil
}
func (s *channelState) getDifference(ctx context.Context) error {
ctx, span := s.tracer.Start(ctx, "channelState.getDifference")
defer span.End()
s.pts.gaps.Clear()
s.log.Debug("Getting difference")
if now := time.Now(); now.Before(s.diffTimeout) {
dur := s.diffTimeout.Sub(now)
s.log.Debug("GetChannelDifference timeout", zap.Duration("duration", dur))
if err := func() error {
afterC := time.After(dur)
for {
select {
case <-afterC:
return nil
case _, ok := <-s.updates:
if !ok {
continue
}
// Ignoring updates to prevent *internalState worker from blocking.
// All ignored updates should be restored by future getChannelDifference call.
// At least I hope so...
s.log.Debug("Ignoring update due to getChannelDifference timeout") // , zap.Any("update", u.update))
case <-ctx.Done():
return ctx.Err()
}
}
}(); err != nil {
return err
}
}
diff, err := s.client.UpdatesGetChannelDifference(ctx, &tg.UpdatesGetChannelDifferenceRequest{
Channel: &tg.InputChannel{
ChannelID: s.channelID,
AccessHash: s.accessHash,
},
Filter: &tg.ChannelMessagesFilterEmpty{},
Pts: s.pts.State(),
Limit: s.diffLim,
})
if err != nil {
return errors.Wrap(err, "get channel difference")
}
switch diff := diff.(type) {
case *tg.UpdatesChannelDifference:
if len(diff.OtherUpdates) > 0 {
select {
case s.out <- tracedUpdate{
span: trace.SpanContextFromContext(ctx),
update: &tg.Updates{
Updates: diff.OtherUpdates,
Users: diff.Users,
Chats: diff.Chats,
},
}:
case <-ctx.Done():
return ctx.Err()
}
}
if len(diff.NewMessages) > 0 {
if err := s.handler.Handle(ctx, &tg.Updates{
Updates: msgsToUpdates(diff.NewMessages, true),
Users: diff.Users,
Chats: diff.Chats,
}); err != nil {
return err
}
}
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil {
s.log.Warn("SetChannelPts error", zap.Error(err))
}
s.pts.SetState(diff.Pts, "updates.channelDifference")
if seconds, ok := diff.GetTimeout(); ok {
s.diffTimeout = time.Now().Add(time.Second * time.Duration(seconds))
}
if !diff.Final {
return s.getDifference(ctx)
}
return nil
case *tg.UpdatesChannelDifferenceEmpty:
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil {
s.log.Warn("SetChannelPts error", zap.Error(err))
}
s.pts.SetState(diff.Pts, "updates.channelDifferenceEmpty")
if seconds, ok := diff.GetTimeout(); ok {
s.diffTimeout = time.Now().Add(time.Second * time.Duration(seconds))
}
return nil
case *tg.UpdatesChannelDifferenceTooLong:
if seconds, ok := diff.GetTimeout(); ok {
s.diffTimeout = time.Now().Add(time.Second * time.Duration(seconds))
}
remotePts, err := getDialogPts(diff.Dialog)
if err != nil {
s.log.Warn("UpdatesChannelDifferenceTooLong invalid Dialog", zap.Error(err))
} else {
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, remotePts); err != nil {
s.log.Warn("SetChannelPts error", zap.Error(err))
}
s.pts.SetState(remotePts, "updates.channelDifferenceTooLong dialog new pts")
}
return s.onTooLong(s.channelID)
default:
return errors.Errorf("unexpected channel diff type: %T", diff)
}
}
func (s *channelState) getDifferenceLogger(ctx context.Context) {
if err := s.getDifference(ctx); err != nil {
s.log.Error("get channel difference error", zap.Error(err))
}
}
func (s *channelState) resetIdleTimer() {
if len(s.idleTimeout.C) > 0 {
<-s.idleTimeout.C
}
_ = s.idleTimeout.Reset(idleTimeout)
}
+47
View File
@@ -0,0 +1,47 @@
package updates
import (
"context"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// State is the user internalState.
type State struct {
Pts, Qts, Date, Seq int
}
func (s State) fromRemote(remote *tg.UpdatesState) State {
return State{
Pts: remote.Pts,
Qts: remote.Qts,
Date: remote.Date,
Seq: remote.Seq,
}
}
// StateStorage is the users internalState storage.
//
// Note:
// SetPts, SetQts, SetDate, SetSeq, SetDateSeq
// should return error if user internalState does not exist.
type StateStorage interface {
GetState(ctx context.Context, userID int64) (state State, found bool, err error)
SetState(ctx context.Context, userID int64, state State) error
SetPts(ctx context.Context, userID int64, pts int) error
SetQts(ctx context.Context, userID int64, qts int) error
SetDate(ctx context.Context, userID int64, date int) error
SetSeq(ctx context.Context, userID int64, seq int) error
SetDateSeq(ctx context.Context, userID int64, date, seq int) error
GetChannelPts(ctx context.Context, userID, channelID int64) (pts int, found bool, err error)
SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error
ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error
}
// AccessHasher stores user and channel access hashes for a user.
type AccessHasher interface {
SetChannelAccessHash(ctx context.Context, forUserID, channelID, accessHash int64) error
GetChannelAccessHash(ctx context.Context, forUserID, channelID int64) (accessHash int64, found bool, err error)
SetUserAccessHash(ctx context.Context, forUserID, userID, accessHash int64) error
GetUserAccessHash(ctx context.Context, forUserID, userID int64) (accessHash int64, found bool, err error)
}
+224
View File
@@ -0,0 +1,224 @@
package updates
import (
"context"
"sync"
"github.com/go-faster/errors"
)
var _ StateStorage = (*memStorage)(nil)
type memStorage struct {
states map[int64]State
channels map[int64]map[int64]int
mux sync.Mutex
}
func newMemStorage() *memStorage {
return &memStorage{
states: map[int64]State{},
channels: map[int64]map[int64]int{},
}
}
func (s *memStorage) GetState(ctx context.Context, userID int64) (state State, found bool, err error) {
s.mux.Lock()
defer s.mux.Unlock()
state, found = s.states[userID]
return
}
func (s *memStorage) SetState(ctx context.Context, userID int64, state State) error {
s.mux.Lock()
defer s.mux.Unlock()
s.states[userID] = state
s.channels[userID] = map[int64]int{}
return nil
}
func (s *memStorage) SetPts(ctx context.Context, userID int64, pts int) error {
s.mux.Lock()
defer s.mux.Unlock()
state, ok := s.states[userID]
if !ok {
return errors.New("internalState not found")
}
state.Pts = pts
s.states[userID] = state
return nil
}
func (s *memStorage) SetQts(ctx context.Context, userID int64, qts int) error {
s.mux.Lock()
defer s.mux.Unlock()
state, ok := s.states[userID]
if !ok {
return errors.New("internalState not found")
}
state.Qts = qts
s.states[userID] = state
return nil
}
func (s *memStorage) SetDate(ctx context.Context, userID int64, date int) error {
s.mux.Lock()
defer s.mux.Unlock()
state, ok := s.states[userID]
if !ok {
return errors.New("internalState not found")
}
state.Date = date
s.states[userID] = state
return nil
}
func (s *memStorage) SetSeq(ctx context.Context, userID int64, seq int) error {
s.mux.Lock()
defer s.mux.Unlock()
state, ok := s.states[userID]
if !ok {
return errors.New("internalState not found")
}
state.Seq = seq
s.states[userID] = state
return nil
}
func (s *memStorage) SetDateSeq(ctx context.Context, userID int64, date, seq int) error {
s.mux.Lock()
defer s.mux.Unlock()
state, ok := s.states[userID]
if !ok {
return errors.New("internalState not found")
}
state.Date = date
state.Seq = seq
s.states[userID] = state
return nil
}
func (s *memStorage) SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error {
s.mux.Lock()
defer s.mux.Unlock()
channels, ok := s.channels[userID]
if !ok {
return errors.New("user internalState does not exist")
}
channels[channelID] = pts
return nil
}
func (s *memStorage) GetChannelPts(ctx context.Context, userID, channelID int64) (pts int, found bool, err error) {
s.mux.Lock()
defer s.mux.Unlock()
channels, ok := s.channels[userID]
if !ok {
return 0, false, nil
}
pts, found = channels[channelID]
return
}
func (s *memStorage) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
s.mux.Lock()
defer s.mux.Unlock()
cmap, ok := s.channels[userID]
if !ok {
return errors.New("channels map does not exist")
}
for id, pts := range cmap {
if err := f(ctx, id, pts); err != nil {
return err
}
}
return nil
}
var _ AccessHasher = (*memAccessHasher)(nil)
type memAccessHasher struct {
channelHashes map[int64]map[int64]int64
userHashes map[int64]map[int64]int64
mux sync.Mutex
}
func newMemAccessHasher() *memAccessHasher {
return &memAccessHasher{
channelHashes: map[int64]map[int64]int64{},
userHashes: map[int64]map[int64]int64{},
}
}
func (m *memAccessHasher) GetChannelAccessHash(ctx context.Context, forUserID, channelID int64) (accessHash int64, found bool, err error) {
m.mux.Lock()
defer m.mux.Unlock()
accessHashes, ok := m.channelHashes[forUserID]
if !ok {
return 0, false, nil
}
accessHash, found = accessHashes[channelID]
return
}
func (m *memAccessHasher) SetChannelAccessHash(ctx context.Context, forUserID, channelID, accessHash int64) error {
m.mux.Lock()
defer m.mux.Unlock()
accessHashes, ok := m.channelHashes[forUserID]
if !ok {
accessHashes = map[int64]int64{}
m.channelHashes[forUserID] = accessHashes
}
accessHashes[channelID] = accessHash
return nil
}
func (m *memAccessHasher) GetUserAccessHash(ctx context.Context, forUserID, userID int64) (accessHash int64, found bool, err error) {
m.mux.Lock()
defer m.mux.Unlock()
accessHashes, ok := m.userHashes[forUserID]
if !ok {
return 0, false, nil
}
accessHash, found = accessHashes[userID]
return
}
func (m *memAccessHasher) SetUserAccessHash(ctx context.Context, forUserID, userID, accessHash int64) error {
m.mux.Lock()
defer m.mux.Unlock()
accessHashes, ok := m.userHashes[forUserID]
if !ok {
accessHashes = map[int64]int64{}
m.channelHashes[forUserID] = accessHashes
}
accessHashes[userID] = accessHash
return nil
}
+53
View File
@@ -0,0 +1,53 @@
package updates
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type update struct {
Value any
State int
Count int
Entities entities
}
func (u update) start() int { return u.State - u.Count }
func (u update) end() int { return u.State }
// Entities contains update entities.
type entities struct {
Users []tg.UserClass
Chats []tg.ChatClass
}
// Merge merges entities.
func (e *entities) Merge(from entities) {
for _, candidate := range from.Users {
merge := true
for _, exist := range e.Users {
if exist.GetID() == candidate.GetID() {
merge = false
break
}
}
if merge {
e.Users = append(e.Users, candidate)
}
}
for _, candidate := range from.Chats {
merge := true
for _, exist := range e.Chats {
if exist.GetID() == candidate.GetID() {
merge = false
break
}
}
if merge {
e.Chats = append(e.Chats, candidate)
}
}
}
+87
View File
@@ -0,0 +1,87 @@
package updates
import (
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func validatePts(pts, ptsCount int) error {
if pts < 0 {
return errors.Errorf("invalid pts value: %d", pts)
}
if ptsCount < 0 {
return errors.Errorf("invalid ptsCount value: %d", ptsCount)
}
return nil
}
func validateQts(qts int) error {
if qts < 0 {
return errors.Errorf("invalid qts value: %d", qts)
}
return nil
}
func validateSeq(seq, seqStart int) error {
if seq < 0 {
return errors.Errorf("invalid seq value: %d", seq)
}
if seqStart < 0 {
return errors.Errorf("invalid seqStart value: %d", seq)
}
return nil
}
func getDialogPts(dialog tg.DialogClass) (int, error) {
d, ok := dialog.(*tg.Dialog)
if !ok {
return 0, errors.Errorf("unexpected dialog type: %T", dialog)
}
pts, ok := d.GetPts()
if !ok {
return 0, errors.New("dialog has no pts field")
}
return pts, nil
}
func msgsToUpdates(msgs []tg.MessageClass, channel bool) []tg.UpdateClass {
updates := make([]tg.UpdateClass, 0, len(msgs))
for _, msg := range msgs {
if channel {
updates = append(updates, &tg.UpdateNewChannelMessage{
Message: msg,
Pts: -1,
PtsCount: -1,
})
continue
}
updates = append(updates, &tg.UpdateNewMessage{
Message: msg,
Pts: -1,
PtsCount: -1,
})
}
return updates
}
func encryptedMsgsToUpdates(msgs []tg.EncryptedMessageClass) []tg.UpdateClass {
updates := make([]tg.UpdateClass, 0, len(msgs))
for _, msg := range msgs {
updates = append(updates, &tg.UpdateNewEncryptedMessage{
Message: msg,
Qts: -1,
})
}
return updates
}