move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
+45
View File
@@ -0,0 +1,45 @@
package rpc
import (
"go.uber.org/zap"
)
// NotifyAcks notifies engine about received acknowledgements.
func (e *Engine) NotifyAcks(ids []int64) {
e.mux.Lock()
defer e.mux.Unlock()
for _, id := range ids {
ch, ok := e.ack[id]
if !ok {
e.log.Debug("Acknowledge callback not set", zap.Int64("msg_id", id))
continue
}
close(ch)
delete(e.ack, id)
}
}
func (e *Engine) waitAck(id int64) chan struct{} {
e.mux.Lock()
defer e.mux.Unlock()
log := e.log.With(zap.Int64("ack_id", id))
if c, found := e.ack[id]; found {
log.Warn("Ack already registered")
return c
}
log.Debug("Waiting for acknowledge")
c := make(chan struct{})
e.ack[id] = c
return c
}
func (e *Engine) removeAck(id int64) {
e.mux.Lock()
defer e.mux.Unlock()
delete(e.ack, id)
}
+2
View File
@@ -0,0 +1,2 @@
// Package rpc implements rpc engine.
package rpc
+286
View File
@@ -0,0 +1,286 @@
package rpc
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
)
// Engine handles RPC requests.
type Engine struct {
send Send
drop DropHandler
mux sync.Mutex
rpc map[int64]func(*bin.Buffer, error) error
ack map[int64]chan struct{}
clock clock.Clock
log *zap.Logger
retryInterval time.Duration
maxRetries int
// Canceling pending requests in ForceClose.
reqCtx context.Context
reqCancel context.CancelFunc
wg sync.WaitGroup
closed uint32
onError func(error)
}
// New creates new rpc Engine.
func New(send Send, cfg Options) *Engine {
cfg.setDefaults()
cfg.Logger.Info("Initialized",
zap.Duration("retry_interval", cfg.RetryInterval),
zap.Int("max_retries", cfg.MaxRetries),
)
reqCtx, reqCancel := context.WithCancel(context.Background())
return &Engine{
rpc: map[int64]func(*bin.Buffer, error) error{},
ack: map[int64]chan struct{}{},
send: send,
drop: cfg.DropHandler,
log: cfg.Logger,
maxRetries: cfg.MaxRetries,
retryInterval: cfg.RetryInterval,
clock: cfg.Clock,
reqCtx: reqCtx,
reqCancel: reqCancel,
onError: cfg.OnError,
}
}
// Request represents client RPC request.
type Request struct {
MsgID int64
SeqNo int32
Input bin.Encoder
Output bin.Decoder
}
// Do sends request to server and blocks until response is received, performing
// multiple retries if needed.
func (e *Engine) Do(ctx context.Context, req Request) error {
if e.isClosed() {
return ErrEngineClosed
}
e.wg.Add(1)
defer e.wg.Done()
retryCtx, retryClose := context.WithCancel(ctx)
defer retryClose()
log := e.log.With(zap.Int64("msg_id", req.MsgID))
log.Debug("Do called")
done := make(chan struct{})
var (
// Handler result.
resultErr error
// Needed to prevent multiple handler calls.
handlerCalled uint32
)
handler := func(rpcBuff *bin.Buffer, rpcErr error) error {
log.Debug("Handler called")
if ok := atomic.CompareAndSwapUint32(&handlerCalled, 0, 1); !ok {
log.Warn("Handler already called")
return errors.New("handler already called")
}
defer retryClose()
defer close(done)
if rpcErr != nil {
resultErr = rpcErr
return nil
}
resultErr = req.Output.Decode(rpcBuff)
return resultErr
}
// Setting callback that will be called if message is received.
e.mux.Lock()
e.rpc[req.MsgID] = handler
e.mux.Unlock()
defer func() {
// Ensuring that callback can't be called after function return.
e.mux.Lock()
delete(e.rpc, req.MsgID)
e.mux.Unlock()
}()
// Start retrying.
sent, err := e.retryUntilAck(retryCtx, req)
if err != nil && !errors.Is(err, retryCtx.Err()) {
// If the retryCtx was canceled, then one of two things happened:
// 1. User canceled the parent context.
// 2. The RPC result came and callback canceled retryCtx.
//
// If this is not a Contexts error, most likely we did not receive ack
// and exceeded the limit of attempts to send a request,
// or could not write data to the connection, so we return an error.
return errors.Wrap(err, "retryUntilAck")
}
select {
case <-ctx.Done():
if !sent {
return ctx.Err()
}
// Set nop callback because server will respond with 'RpcDropAnswer' instead of expected result.
//
// NOTE(ccln): We can decode 'RpcDropAnswer' here but I see no reason to do this
// because it will also come as a response to 'RPCDropAnswerRequest'.
//
// https://core.telegram.org/mtproto/service_messages#cancellation-of-an-rpc-query
e.mux.Lock()
e.rpc[req.MsgID] = func(b *bin.Buffer, e error) error { return nil }
e.mux.Unlock()
if err := e.drop(req); err != nil {
log.Info("Failed to drop request", zap.Error(err))
return ctx.Err()
}
log.Debug("Request dropped")
return ctx.Err()
case <-e.reqCtx.Done():
return errors.Wrap(e.reqCtx.Err(), "engine forcibly closed")
case <-done:
return resultErr
}
}
// retryUntilAck resends the request to the server until request is
// acknowledged.
//
// Returns nil if acknowledge was received or error otherwise.
func (e *Engine) retryUntilAck(ctx context.Context, req Request) (sent bool, err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var (
ackChan = e.waitAck(req.MsgID)
retries = 0
log = e.log.Named("retry").With(zap.Int64("msg_id", req.MsgID))
)
defer e.removeAck(req.MsgID)
// Encoding request.
if err := e.send(ctx, req.MsgID, req.SeqNo, req.Input); err != nil {
return false, errors.Wrap(err, "send")
}
loop := func() error {
timer := e.clock.Timer(e.retryInterval)
defer clock.StopTimer(timer)
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-e.reqCtx.Done():
return errors.Wrap(e.reqCtx.Err(), "engine forcibly closed")
case <-ackChan:
log.Debug("Acknowledged")
return nil
case <-timer.C():
timer.Reset(e.retryInterval)
log.Debug("Acknowledge timed out, performing retry")
if err := e.send(ctx, req.MsgID, req.SeqNo, req.Input); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
log.Error("Retry failed", zap.Error(err))
return err
}
retries++
if retries >= e.maxRetries {
log.Error("Retry limit reached", zap.Int64("msg_id", req.MsgID))
return &RetryLimitReachedErr{
Retries: retries,
}
}
}
}
}
return true, loop()
}
// NotifyResult notifies engine about received RPC response.
func (e *Engine) NotifyResult(msgID int64, b *bin.Buffer) error {
e.mux.Lock()
fn, ok := e.rpc[msgID]
e.mux.Unlock()
if !ok {
e.log.Warn("rpc callback not set", zap.Int64("msg_id", msgID))
return nil
}
return fn(b, nil)
}
// NotifyError notifies engine about received RPC error.
func (e *Engine) NotifyError(msgID int64, rpcErr error) {
e.onError(rpcErr)
e.mux.Lock()
fn, ok := e.rpc[msgID]
e.mux.Unlock()
if !ok {
e.log.Warn("rpc callback not set", zap.Int64("msg_id", msgID))
return
}
// Callback with rpcError always return nil.
_ = fn(nil, rpcErr)
}
func (e *Engine) isClosed() bool {
return atomic.LoadUint32(&e.closed) == 1
}
// Close gracefully closes the engine.
// All pending requests will be awaited.
// All Do method calls of closed engine will return ErrEngineClosed error.
func (e *Engine) Close() {
atomic.StoreUint32(&e.closed, 1)
e.log.Info("Close called")
e.wg.Wait()
}
// ForceClose forcibly closes the engine.
// All pending requests will be canceled.
// All Do method calls of closed engine will return ErrEngineClosed error.
func (e *Engine) ForceClose() {
e.reqCancel()
e.Close()
}
+74
View File
@@ -0,0 +1,74 @@
package rpc
import (
"bytes"
"context"
"sync/atomic"
"testing"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// mockObject implements bin.Object for testing.
type mockObject struct {
data []byte
}
func (m mockObject) Decode(b *bin.Buffer) error {
if !bytes.Equal(b.Buf, m.data) {
return errors.New("mismatch")
}
b.Skip(len(b.Buf))
return nil
}
func (m mockObject) Encode(b *bin.Buffer) error {
b.Put(m.data)
return nil
}
func BenchmarkEngine_Do(b *testing.B) {
ids := make(chan int64, 100)
defer close(ids)
e := New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
ids <- msgID
return nil
}, Options{})
var id int64
ctx := context.Background()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
go func() {
buf := &bin.Buffer{}
// Fake handler.
obj := mockObject{data: make([]byte, 100)}
for id := range ids {
e.NotifyAcks([]int64{id})
buf.ResetTo(obj.data)
if err := e.NotifyResult(id, buf); err != nil {
b.Error(err)
}
}
}()
obj := mockObject{data: make([]byte, 100)}
for pb.Next() {
nextID := atomic.AddInt64(&id, 1)
if err := e.Do(ctx, Request{
MsgID: nextID,
Input: obj,
Output: obj,
}); err != nil {
b.Fatal(err)
}
}
})
}
+457
View File
@@ -0,0 +1,457 @@
package rpc
import (
"sync"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"golang.org/x/net/context"
"golang.org/x/sync/errgroup"
"github.com/gotd/neo"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
type request struct {
MsgID int64
SeqNo int32
Input bin.Encoder
}
var defaultNow = testutil.Date()
const (
msgID int64 = 1
pingID int64 = 1337
seqNo int32 = 1
)
func TestRPCError(t *testing.T) {
clock := neo.NewTime(defaultNow)
observer := clock.Observe()
expectedErr := errors.New("server side error")
log := zaptest.NewLogger(t)
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
log := log.Named("server")
log.Info("Waiting ping request")
require.Equal(t, request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
}, <-incoming)
log.Info("Got ping request")
// Make sure that client calls time.After
// before time travel
<-observer
log.Info("Traveling into the future for a second (simulate job)")
clock.Travel(time.Second)
log.Info("Sending RPC error")
e.NotifyError(msgID, expectedErr)
return nil
}
client := func(t *testing.T, e *Engine) error {
log := log.Named("client")
log.Info("Sending ping request")
err := e.Do(context.TODO(), Request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{
PingID: pingID,
},
})
log.Info("Got pong response")
require.True(t, errors.Is(err, expectedErr), "expected error")
return nil
}
runTest(t, Options{
RetryInterval: time.Second * 3,
MaxRetries: 2,
Clock: clock,
Logger: log.Named("rpc"),
}, server, client)
}
func TestRPCResult(t *testing.T) {
clock := neo.NewTime(defaultNow)
observer := clock.Observe()
log := zaptest.NewLogger(t)
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
log := log.Named("server")
log.Info("Waiting ping request")
require.Equal(t, request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
}, <-incoming)
log.Info("Got ping request")
// Make sure that engine calls time.After
// before time travel.
<-observer
log.Info("Traveling into the future for 2 seconds (simulate job)")
clock.Travel(time.Second * 2)
var b bin.Buffer
if err := b.Encode(&mt.Pong{
MsgID: msgID,
PingID: pingID,
}); err != nil {
return err
}
log.Info("Sending pong response")
return e.NotifyResult(msgID, &b)
}
client := func(t *testing.T, e *Engine) error {
log := log.Named("client")
log.Info("Sending ping request")
var out mt.Pong
require.NoError(t, e.Do(context.TODO(), Request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
Output: &out,
}))
log.Info("Got pong response")
require.Equal(t, mt.Pong{
MsgID: msgID,
PingID: pingID,
}, out)
return nil
}
runTest(t, Options{
RetryInterval: time.Second * 4,
MaxRetries: 2,
Clock: clock,
Logger: log.Named("rpc"),
}, server, client)
}
func TestRPCAckThenResult(t *testing.T) {
clock := neo.NewTime(defaultNow)
observer := clock.Observe()
log := zaptest.NewLogger(t)
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
log := log.Named("server")
log.Info("Waiting ping request")
require.Equal(t, request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
}, <-incoming)
// Make sure that client calls time.After
// before time travel.
<-observer
log.Info("Traveling into the future for 2 seconds (simulate job)")
clock.Travel(time.Second * 2)
log.Info("Sending ACK")
e.NotifyAcks([]int64{msgID})
log.Info("Traveling into the future for 6 seconds (simulate request processing)")
clock.Travel(time.Second * 6)
var b bin.Buffer
if err := b.Encode(&mt.Pong{
MsgID: msgID,
PingID: pingID,
}); err != nil {
return err
}
log.Info("Sending response")
return e.NotifyResult(msgID, &b)
}
client := func(t *testing.T, e *Engine) error {
log := log.Named("client")
log.Info("Sending ping request")
var out mt.Pong
require.NoError(t, e.Do(context.TODO(), Request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
Output: &out,
}))
log.Info("Got pong response")
require.Equal(t, mt.Pong{
MsgID: msgID,
PingID: pingID,
}, out)
return nil
}
runTest(t, Options{
RetryInterval: time.Second * 4,
MaxRetries: 2,
Clock: clock,
Logger: log.Named("rpc"),
}, server, client)
}
func TestRPCWithRetryResult(t *testing.T) {
clock := neo.NewTime(defaultNow)
observer := clock.Observe()
log := zaptest.NewLogger(t)
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
log := log.Named("server")
log.Info("Waiting ping request")
require.Equal(t, request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
}, <-incoming)
log.Info("Got ping request")
// Make sure that client calls time.After
// before time travel.
<-observer
log.Info("Traveling into the future for 6 seconds (simulate request loss)")
clock.Travel(time.Second * 6)
log.Info("Waiting re-sending request")
require.Equal(t, request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
}, <-incoming)
log.Info("Got ping request")
var b bin.Buffer
if err := b.Encode(&mt.Pong{
MsgID: msgID,
PingID: pingID,
}); err != nil {
return err
}
log.Info("Send pong response")
return e.NotifyResult(msgID, &b)
}
client := func(t *testing.T, e *Engine) error {
log := log.Named("client")
log.Info("Sending ping request")
var out mt.Pong
require.NoError(t, e.Do(context.TODO(), Request{
MsgID: 1,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
Output: &out,
}))
log.Info("Got pong response")
require.Equal(t, mt.Pong{
MsgID: msgID,
PingID: pingID,
}, out)
return nil
}
runTest(t, Options{
RetryInterval: time.Second * 4,
MaxRetries: 5,
Clock: clock,
Logger: log.Named("rpc"),
}, server, client)
}
func TestEngineGracefulShutdown(t *testing.T) {
var (
log = zaptest.NewLogger(t)
expectedErr = errors.New("server side error")
requestsCount = 10
serverRecv sync.WaitGroup
canSendResponse sync.Mutex
)
serverRecv.Add(requestsCount)
canSendResponse.Lock()
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
log := log.Named("server")
var batch []request
for i := 0; i < requestsCount; i++ {
batch = append(batch, <-incoming)
serverRecv.Done()
}
e.log.Info("Got all requests")
canSendResponse.Lock()
e.log.Info("Sending responses")
for _, req := range batch {
log.Info("send response")
e.NotifyError(req.MsgID, expectedErr)
}
canSendResponse.Unlock()
return nil
}
client := func(t *testing.T, e *Engine) error {
var currMsgID int64
for i := 0; i < requestsCount; i++ {
go func(t *testing.T, msgID int64) {
var out mt.Pong
require.Equal(t, e.Do(context.TODO(), Request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
Output: &out,
}), expectedErr)
}(t, currMsgID)
currMsgID++
}
// wait until server receive all requests
serverRecv.Wait()
// allow server to send responses
canSendResponse.Unlock()
// close the engine
e.Close()
return nil
}
runTest(t, Options{
RetryInterval: time.Second * 5,
MaxRetries: 5,
Logger: log.Named("rpc"),
}, server, client)
}
func TestDropRPC(t *testing.T) {
clock := neo.NewTime(defaultNow)
log := zaptest.NewLogger(t)
serverRecvRequest := make(chan struct{})
clientCancelledCtx := make(chan struct{})
dropChan := make(chan Request)
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
log := log.Named("server")
log.Info("Waiting ping request")
require.Equal(t, request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
}, <-incoming)
close(serverRecvRequest)
<-clientCancelledCtx
log.Info("Waiting drop request")
require.Equal(t, msgID, (<-dropChan).MsgID)
return nil
}
client := func(t *testing.T, e *Engine) error {
log := log.Named("client")
log.Info("Sending ping request")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
<-serverRecvRequest
log.Info("Canceling request context")
cancel()
close(clientCancelledCtx)
}()
require.ErrorIs(t, e.Do(ctx, Request{
MsgID: msgID,
SeqNo: seqNo,
Input: &mt.PingRequest{PingID: pingID},
Output: &mt.Pong{},
}), context.Canceled)
return nil
}
runTest(t, Options{
RetryInterval: time.Second * 4,
MaxRetries: 2,
Clock: clock,
Logger: log.Named("rpc"),
DropHandler: func(req Request) error { dropChan <- req; return nil },
}, server, client)
}
func runTest(
t *testing.T,
cfg Options,
server func(t *testing.T, e *Engine, incoming <-chan request) error,
client func(t *testing.T, e *Engine) error,
) {
t.Helper()
// Channel of client requests sent to the server.
requests := make(chan request)
defer close(requests)
e := New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
req := request{
MsgID: msgID,
SeqNo: seqNo,
Input: in,
}
select {
case <-ctx.Done():
return ctx.Err()
case requests <- req:
return nil
}
}, cfg)
var g errgroup.Group
g.Go(func() error { return server(t, e, requests) })
g.Go(func() error { return client(t, e) })
require.NoError(t, g.Wait())
e.Close()
require.NoError(t, cfg.Logger.Sync())
}
+26
View File
@@ -0,0 +1,26 @@
package rpc
import (
"fmt"
"github.com/go-faster/errors"
)
// RetryLimitReachedErr means that server does not acknowledge request
// after multiple retries.
type RetryLimitReachedErr struct {
Retries int
}
func (r *RetryLimitReachedErr) Error() string {
return fmt.Sprintf("retry limit reached after %d attempts", r.Retries)
}
// Is reports whether err is RetryLimitReachedErr.
func (r *RetryLimitReachedErr) Is(err error) bool {
_, ok := err.(*RetryLimitReachedErr)
return ok
}
// ErrEngineClosed means that engine was closed.
var ErrEngineClosed = errors.New("engine was closed")
+23
View File
@@ -0,0 +1,23 @@
package rpc
import (
"context"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// Send is a function that sends requests to the server.
type Send func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error
// NopSend does nothing.
func NopSend(context.Context, int64, int32, bin.Encoder) error { return nil }
var _ Send = NopSend
// DropHandler handles drop rpc requests.
type DropHandler func(req Request) error
// NopDrop does nothing.
func NopDrop(Request) error { return nil }
var _ DropHandler = NopDrop
+16
View File
@@ -0,0 +1,16 @@
package rpc
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestNopDrop(t *testing.T) {
require.NoError(t, NopDrop(Request{}))
}
func TestNopSend(t *testing.T) {
require.NoError(t, NopSend(context.TODO(), 0, 0, nil))
}
+40
View File
@@ -0,0 +1,40 @@
package rpc
import (
"time"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
)
// Options of rpc engine.
type Options struct {
RetryInterval time.Duration
MaxRetries int
Logger *zap.Logger
Clock clock.Clock
DropHandler DropHandler
OnError func(error)
}
func (cfg *Options) setDefaults() {
if cfg.RetryInterval == 0 {
cfg.RetryInterval = time.Second * 10
}
if cfg.MaxRetries == 0 {
cfg.MaxRetries = 5
}
if cfg.Logger == nil {
cfg.Logger = zap.NewNop()
}
if cfg.Clock == nil {
cfg.Clock = clock.System
}
if cfg.DropHandler == nil {
cfg.DropHandler = NopDrop
}
if cfg.OnError == nil {
cfg.OnError = func(err error) {}
}
}