move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// NotifyAcks notifies engine about received acknowledgements.
|
||||
func (e *Engine) NotifyAcks(ids []int64) {
|
||||
e.mux.Lock()
|
||||
defer e.mux.Unlock()
|
||||
|
||||
for _, id := range ids {
|
||||
ch, ok := e.ack[id]
|
||||
if !ok {
|
||||
e.log.Debug("Acknowledge callback not set", zap.Int64("msg_id", id))
|
||||
continue
|
||||
}
|
||||
|
||||
close(ch)
|
||||
delete(e.ack, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) waitAck(id int64) chan struct{} {
|
||||
e.mux.Lock()
|
||||
defer e.mux.Unlock()
|
||||
|
||||
log := e.log.With(zap.Int64("ack_id", id))
|
||||
if c, found := e.ack[id]; found {
|
||||
log.Warn("Ack already registered")
|
||||
return c
|
||||
}
|
||||
|
||||
log.Debug("Waiting for acknowledge")
|
||||
c := make(chan struct{})
|
||||
e.ack[id] = c
|
||||
return c
|
||||
}
|
||||
|
||||
func (e *Engine) removeAck(id int64) {
|
||||
e.mux.Lock()
|
||||
defer e.mux.Unlock()
|
||||
|
||||
delete(e.ack, id)
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package rpc implements rpc engine.
|
||||
package rpc
|
||||
@@ -0,0 +1,286 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
)
|
||||
|
||||
// Engine handles RPC requests.
|
||||
type Engine struct {
|
||||
send Send
|
||||
drop DropHandler
|
||||
|
||||
mux sync.Mutex
|
||||
rpc map[int64]func(*bin.Buffer, error) error
|
||||
ack map[int64]chan struct{}
|
||||
|
||||
clock clock.Clock
|
||||
log *zap.Logger
|
||||
retryInterval time.Duration
|
||||
maxRetries int
|
||||
|
||||
// Canceling pending requests in ForceClose.
|
||||
reqCtx context.Context
|
||||
reqCancel context.CancelFunc
|
||||
|
||||
wg sync.WaitGroup
|
||||
closed uint32
|
||||
|
||||
onError func(error)
|
||||
}
|
||||
|
||||
// New creates new rpc Engine.
|
||||
func New(send Send, cfg Options) *Engine {
|
||||
cfg.setDefaults()
|
||||
|
||||
cfg.Logger.Info("Initialized",
|
||||
zap.Duration("retry_interval", cfg.RetryInterval),
|
||||
zap.Int("max_retries", cfg.MaxRetries),
|
||||
)
|
||||
|
||||
reqCtx, reqCancel := context.WithCancel(context.Background())
|
||||
return &Engine{
|
||||
rpc: map[int64]func(*bin.Buffer, error) error{},
|
||||
ack: map[int64]chan struct{}{},
|
||||
|
||||
send: send,
|
||||
drop: cfg.DropHandler,
|
||||
|
||||
log: cfg.Logger,
|
||||
maxRetries: cfg.MaxRetries,
|
||||
retryInterval: cfg.RetryInterval,
|
||||
clock: cfg.Clock,
|
||||
|
||||
reqCtx: reqCtx,
|
||||
reqCancel: reqCancel,
|
||||
|
||||
onError: cfg.OnError,
|
||||
}
|
||||
}
|
||||
|
||||
// Request represents client RPC request.
|
||||
type Request struct {
|
||||
MsgID int64
|
||||
SeqNo int32
|
||||
Input bin.Encoder
|
||||
Output bin.Decoder
|
||||
}
|
||||
|
||||
// Do sends request to server and blocks until response is received, performing
|
||||
// multiple retries if needed.
|
||||
func (e *Engine) Do(ctx context.Context, req Request) error {
|
||||
if e.isClosed() {
|
||||
return ErrEngineClosed
|
||||
}
|
||||
|
||||
e.wg.Add(1)
|
||||
defer e.wg.Done()
|
||||
|
||||
retryCtx, retryClose := context.WithCancel(ctx)
|
||||
defer retryClose()
|
||||
|
||||
log := e.log.With(zap.Int64("msg_id", req.MsgID))
|
||||
log.Debug("Do called")
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
var (
|
||||
// Handler result.
|
||||
resultErr error
|
||||
// Needed to prevent multiple handler calls.
|
||||
handlerCalled uint32
|
||||
)
|
||||
|
||||
handler := func(rpcBuff *bin.Buffer, rpcErr error) error {
|
||||
log.Debug("Handler called")
|
||||
|
||||
if ok := atomic.CompareAndSwapUint32(&handlerCalled, 0, 1); !ok {
|
||||
log.Warn("Handler already called")
|
||||
|
||||
return errors.New("handler already called")
|
||||
}
|
||||
|
||||
defer retryClose()
|
||||
defer close(done)
|
||||
|
||||
if rpcErr != nil {
|
||||
resultErr = rpcErr
|
||||
return nil
|
||||
}
|
||||
|
||||
resultErr = req.Output.Decode(rpcBuff)
|
||||
return resultErr
|
||||
}
|
||||
|
||||
// Setting callback that will be called if message is received.
|
||||
e.mux.Lock()
|
||||
e.rpc[req.MsgID] = handler
|
||||
e.mux.Unlock()
|
||||
|
||||
defer func() {
|
||||
// Ensuring that callback can't be called after function return.
|
||||
e.mux.Lock()
|
||||
delete(e.rpc, req.MsgID)
|
||||
e.mux.Unlock()
|
||||
}()
|
||||
|
||||
// Start retrying.
|
||||
sent, err := e.retryUntilAck(retryCtx, req)
|
||||
if err != nil && !errors.Is(err, retryCtx.Err()) {
|
||||
// If the retryCtx was canceled, then one of two things happened:
|
||||
// 1. User canceled the parent context.
|
||||
// 2. The RPC result came and callback canceled retryCtx.
|
||||
//
|
||||
// If this is not a Context’s error, most likely we did not receive ack
|
||||
// and exceeded the limit of attempts to send a request,
|
||||
// or could not write data to the connection, so we return an error.
|
||||
return errors.Wrap(err, "retryUntilAck")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !sent {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Set nop callback because server will respond with 'RpcDropAnswer' instead of expected result.
|
||||
//
|
||||
// NOTE(ccln): We can decode 'RpcDropAnswer' here but I see no reason to do this
|
||||
// because it will also come as a response to 'RPCDropAnswerRequest'.
|
||||
//
|
||||
// https://core.telegram.org/mtproto/service_messages#cancellation-of-an-rpc-query
|
||||
e.mux.Lock()
|
||||
e.rpc[req.MsgID] = func(b *bin.Buffer, e error) error { return nil }
|
||||
e.mux.Unlock()
|
||||
|
||||
if err := e.drop(req); err != nil {
|
||||
log.Info("Failed to drop request", zap.Error(err))
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
log.Debug("Request dropped")
|
||||
return ctx.Err()
|
||||
case <-e.reqCtx.Done():
|
||||
return errors.Wrap(e.reqCtx.Err(), "engine forcibly closed")
|
||||
case <-done:
|
||||
return resultErr
|
||||
}
|
||||
}
|
||||
|
||||
// retryUntilAck resends the request to the server until request is
|
||||
// acknowledged.
|
||||
//
|
||||
// Returns nil if acknowledge was received or error otherwise.
|
||||
func (e *Engine) retryUntilAck(ctx context.Context, req Request) (sent bool, err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
ackChan = e.waitAck(req.MsgID)
|
||||
retries = 0
|
||||
log = e.log.Named("retry").With(zap.Int64("msg_id", req.MsgID))
|
||||
)
|
||||
|
||||
defer e.removeAck(req.MsgID)
|
||||
|
||||
// Encoding request.
|
||||
if err := e.send(ctx, req.MsgID, req.SeqNo, req.Input); err != nil {
|
||||
return false, errors.Wrap(err, "send")
|
||||
}
|
||||
|
||||
loop := func() error {
|
||||
timer := e.clock.Timer(e.retryInterval)
|
||||
defer clock.StopTimer(timer)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-e.reqCtx.Done():
|
||||
return errors.Wrap(e.reqCtx.Err(), "engine forcibly closed")
|
||||
case <-ackChan:
|
||||
log.Debug("Acknowledged")
|
||||
return nil
|
||||
case <-timer.C():
|
||||
timer.Reset(e.retryInterval)
|
||||
|
||||
log.Debug("Acknowledge timed out, performing retry")
|
||||
if err := e.send(ctx, req.MsgID, req.SeqNo, req.Input); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Error("Retry failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
retries++
|
||||
if retries >= e.maxRetries {
|
||||
log.Error("Retry limit reached", zap.Int64("msg_id", req.MsgID))
|
||||
return &RetryLimitReachedErr{
|
||||
Retries: retries,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, loop()
|
||||
}
|
||||
|
||||
// NotifyResult notifies engine about received RPC response.
|
||||
func (e *Engine) NotifyResult(msgID int64, b *bin.Buffer) error {
|
||||
e.mux.Lock()
|
||||
fn, ok := e.rpc[msgID]
|
||||
e.mux.Unlock()
|
||||
if !ok {
|
||||
e.log.Warn("rpc callback not set", zap.Int64("msg_id", msgID))
|
||||
return nil
|
||||
}
|
||||
|
||||
return fn(b, nil)
|
||||
}
|
||||
|
||||
// NotifyError notifies engine about received RPC error.
|
||||
func (e *Engine) NotifyError(msgID int64, rpcErr error) {
|
||||
e.onError(rpcErr)
|
||||
e.mux.Lock()
|
||||
fn, ok := e.rpc[msgID]
|
||||
e.mux.Unlock()
|
||||
if !ok {
|
||||
e.log.Warn("rpc callback not set", zap.Int64("msg_id", msgID))
|
||||
return
|
||||
}
|
||||
|
||||
// Callback with rpcError always return nil.
|
||||
_ = fn(nil, rpcErr)
|
||||
}
|
||||
|
||||
func (e *Engine) isClosed() bool {
|
||||
return atomic.LoadUint32(&e.closed) == 1
|
||||
}
|
||||
|
||||
// Close gracefully closes the engine.
|
||||
// All pending requests will be awaited.
|
||||
// All Do method calls of closed engine will return ErrEngineClosed error.
|
||||
func (e *Engine) Close() {
|
||||
atomic.StoreUint32(&e.closed, 1)
|
||||
e.log.Info("Close called")
|
||||
e.wg.Wait()
|
||||
}
|
||||
|
||||
// ForceClose forcibly closes the engine.
|
||||
// All pending requests will be canceled.
|
||||
// All Do method calls of closed engine will return ErrEngineClosed error.
|
||||
func (e *Engine) ForceClose() {
|
||||
e.reqCancel()
|
||||
e.Close()
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// mockObject implements bin.Object for testing.
|
||||
type mockObject struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (m mockObject) Decode(b *bin.Buffer) error {
|
||||
if !bytes.Equal(b.Buf, m.data) {
|
||||
return errors.New("mismatch")
|
||||
}
|
||||
b.Skip(len(b.Buf))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockObject) Encode(b *bin.Buffer) error {
|
||||
b.Put(m.data)
|
||||
return nil
|
||||
}
|
||||
|
||||
func BenchmarkEngine_Do(b *testing.B) {
|
||||
ids := make(chan int64, 100)
|
||||
defer close(ids)
|
||||
|
||||
e := New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
|
||||
ids <- msgID
|
||||
return nil
|
||||
}, Options{})
|
||||
|
||||
var id int64
|
||||
|
||||
ctx := context.Background()
|
||||
b.ReportAllocs()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
go func() {
|
||||
buf := &bin.Buffer{}
|
||||
// Fake handler.
|
||||
obj := mockObject{data: make([]byte, 100)}
|
||||
|
||||
for id := range ids {
|
||||
e.NotifyAcks([]int64{id})
|
||||
|
||||
buf.ResetTo(obj.data)
|
||||
if err := e.NotifyResult(id, buf); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
obj := mockObject{data: make([]byte, 100)}
|
||||
|
||||
for pb.Next() {
|
||||
nextID := atomic.AddInt64(&id, 1)
|
||||
if err := e.Do(ctx, Request{
|
||||
MsgID: nextID,
|
||||
Input: obj,
|
||||
Output: obj,
|
||||
}); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,457 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/gotd/neo"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
type request struct {
|
||||
MsgID int64
|
||||
SeqNo int32
|
||||
Input bin.Encoder
|
||||
}
|
||||
|
||||
var defaultNow = testutil.Date()
|
||||
|
||||
const (
|
||||
msgID int64 = 1
|
||||
pingID int64 = 1337
|
||||
seqNo int32 = 1
|
||||
)
|
||||
|
||||
func TestRPCError(t *testing.T) {
|
||||
clock := neo.NewTime(defaultNow)
|
||||
observer := clock.Observe()
|
||||
expectedErr := errors.New("server side error")
|
||||
log := zaptest.NewLogger(t)
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
log.Info("Waiting ping request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
|
||||
log.Info("Got ping request")
|
||||
|
||||
// Make sure that client calls time.After
|
||||
// before time travel
|
||||
<-observer
|
||||
|
||||
log.Info("Traveling into the future for a second (simulate job)")
|
||||
clock.Travel(time.Second)
|
||||
|
||||
log.Info("Sending RPC error")
|
||||
e.NotifyError(msgID, expectedErr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
log := log.Named("client")
|
||||
|
||||
log.Info("Sending ping request")
|
||||
err := e.Do(context.TODO(), Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{
|
||||
PingID: pingID,
|
||||
},
|
||||
})
|
||||
|
||||
log.Info("Got pong response")
|
||||
require.True(t, errors.Is(err, expectedErr), "expected error")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 3,
|
||||
MaxRetries: 2,
|
||||
Clock: clock,
|
||||
Logger: log.Named("rpc"),
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func TestRPCResult(t *testing.T) {
|
||||
clock := neo.NewTime(defaultNow)
|
||||
observer := clock.Observe()
|
||||
log := zaptest.NewLogger(t)
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
log.Info("Waiting ping request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
|
||||
log.Info("Got ping request")
|
||||
// Make sure that engine calls time.After
|
||||
// before time travel.
|
||||
<-observer
|
||||
|
||||
log.Info("Traveling into the future for 2 seconds (simulate job)")
|
||||
clock.Travel(time.Second * 2)
|
||||
|
||||
var b bin.Buffer
|
||||
if err := b.Encode(&mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Sending pong response")
|
||||
return e.NotifyResult(msgID, &b)
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
log := log.Named("client")
|
||||
|
||||
log.Info("Sending ping request")
|
||||
var out mt.Pong
|
||||
require.NoError(t, e.Do(context.TODO(), Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
Output: &out,
|
||||
}))
|
||||
|
||||
log.Info("Got pong response")
|
||||
require.Equal(t, mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}, out)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 4,
|
||||
MaxRetries: 2,
|
||||
Clock: clock,
|
||||
Logger: log.Named("rpc"),
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func TestRPCAckThenResult(t *testing.T) {
|
||||
clock := neo.NewTime(defaultNow)
|
||||
observer := clock.Observe()
|
||||
log := zaptest.NewLogger(t)
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
log.Info("Waiting ping request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
|
||||
// Make sure that client calls time.After
|
||||
// before time travel.
|
||||
<-observer
|
||||
|
||||
log.Info("Traveling into the future for 2 seconds (simulate job)")
|
||||
clock.Travel(time.Second * 2)
|
||||
|
||||
log.Info("Sending ACK")
|
||||
e.NotifyAcks([]int64{msgID})
|
||||
|
||||
log.Info("Traveling into the future for 6 seconds (simulate request processing)")
|
||||
clock.Travel(time.Second * 6)
|
||||
|
||||
var b bin.Buffer
|
||||
if err := b.Encode(&mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Sending response")
|
||||
return e.NotifyResult(msgID, &b)
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
log := log.Named("client")
|
||||
|
||||
log.Info("Sending ping request")
|
||||
var out mt.Pong
|
||||
require.NoError(t, e.Do(context.TODO(), Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
Output: &out,
|
||||
}))
|
||||
|
||||
log.Info("Got pong response")
|
||||
require.Equal(t, mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}, out)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 4,
|
||||
MaxRetries: 2,
|
||||
Clock: clock,
|
||||
Logger: log.Named("rpc"),
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func TestRPCWithRetryResult(t *testing.T) {
|
||||
clock := neo.NewTime(defaultNow)
|
||||
observer := clock.Observe()
|
||||
log := zaptest.NewLogger(t)
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
log.Info("Waiting ping request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
log.Info("Got ping request")
|
||||
|
||||
// Make sure that client calls time.After
|
||||
// before time travel.
|
||||
<-observer
|
||||
|
||||
log.Info("Traveling into the future for 6 seconds (simulate request loss)")
|
||||
clock.Travel(time.Second * 6)
|
||||
|
||||
log.Info("Waiting re-sending request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
log.Info("Got ping request")
|
||||
|
||||
var b bin.Buffer
|
||||
if err := b.Encode(&mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Send pong response")
|
||||
return e.NotifyResult(msgID, &b)
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
log := log.Named("client")
|
||||
|
||||
log.Info("Sending ping request")
|
||||
var out mt.Pong
|
||||
require.NoError(t, e.Do(context.TODO(), Request{
|
||||
MsgID: 1,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
Output: &out,
|
||||
}))
|
||||
|
||||
log.Info("Got pong response")
|
||||
require.Equal(t, mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}, out)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 4,
|
||||
MaxRetries: 5,
|
||||
Clock: clock,
|
||||
Logger: log.Named("rpc"),
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func TestEngineGracefulShutdown(t *testing.T) {
|
||||
var (
|
||||
log = zaptest.NewLogger(t)
|
||||
expectedErr = errors.New("server side error")
|
||||
requestsCount = 10
|
||||
serverRecv sync.WaitGroup
|
||||
canSendResponse sync.Mutex
|
||||
)
|
||||
|
||||
serverRecv.Add(requestsCount)
|
||||
canSendResponse.Lock()
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
var batch []request
|
||||
for i := 0; i < requestsCount; i++ {
|
||||
batch = append(batch, <-incoming)
|
||||
serverRecv.Done()
|
||||
}
|
||||
e.log.Info("Got all requests")
|
||||
|
||||
canSendResponse.Lock()
|
||||
e.log.Info("Sending responses")
|
||||
for _, req := range batch {
|
||||
log.Info("send response")
|
||||
e.NotifyError(req.MsgID, expectedErr)
|
||||
}
|
||||
canSendResponse.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
var currMsgID int64
|
||||
|
||||
for i := 0; i < requestsCount; i++ {
|
||||
go func(t *testing.T, msgID int64) {
|
||||
var out mt.Pong
|
||||
require.Equal(t, e.Do(context.TODO(), Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
Output: &out,
|
||||
}), expectedErr)
|
||||
}(t, currMsgID)
|
||||
|
||||
currMsgID++
|
||||
}
|
||||
|
||||
// wait until server receive all requests
|
||||
serverRecv.Wait()
|
||||
// allow server to send responses
|
||||
canSendResponse.Unlock()
|
||||
// close the engine
|
||||
e.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 5,
|
||||
MaxRetries: 5,
|
||||
Logger: log.Named("rpc"),
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func TestDropRPC(t *testing.T) {
|
||||
clock := neo.NewTime(defaultNow)
|
||||
log := zaptest.NewLogger(t)
|
||||
serverRecvRequest := make(chan struct{})
|
||||
clientCancelledCtx := make(chan struct{})
|
||||
dropChan := make(chan Request)
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
log.Info("Waiting ping request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
|
||||
close(serverRecvRequest)
|
||||
<-clientCancelledCtx
|
||||
|
||||
log.Info("Waiting drop request")
|
||||
require.Equal(t, msgID, (<-dropChan).MsgID)
|
||||
return nil
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
log := log.Named("client")
|
||||
|
||||
log.Info("Sending ping request")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-serverRecvRequest
|
||||
log.Info("Canceling request context")
|
||||
cancel()
|
||||
close(clientCancelledCtx)
|
||||
}()
|
||||
|
||||
require.ErrorIs(t, e.Do(ctx, Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
Output: &mt.Pong{},
|
||||
}), context.Canceled)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 4,
|
||||
MaxRetries: 2,
|
||||
Clock: clock,
|
||||
Logger: log.Named("rpc"),
|
||||
DropHandler: func(req Request) error { dropChan <- req; return nil },
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func runTest(
|
||||
t *testing.T,
|
||||
cfg Options,
|
||||
server func(t *testing.T, e *Engine, incoming <-chan request) error,
|
||||
client func(t *testing.T, e *Engine) error,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
// Channel of client requests sent to the server.
|
||||
requests := make(chan request)
|
||||
defer close(requests)
|
||||
|
||||
e := New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
|
||||
req := request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: in,
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case requests <- req:
|
||||
return nil
|
||||
}
|
||||
}, cfg)
|
||||
|
||||
var g errgroup.Group
|
||||
g.Go(func() error { return server(t, e, requests) })
|
||||
g.Go(func() error { return client(t, e) })
|
||||
|
||||
require.NoError(t, g.Wait())
|
||||
e.Close()
|
||||
require.NoError(t, cfg.Logger.Sync())
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
)
|
||||
|
||||
// RetryLimitReachedErr means that server does not acknowledge request
|
||||
// after multiple retries.
|
||||
type RetryLimitReachedErr struct {
|
||||
Retries int
|
||||
}
|
||||
|
||||
func (r *RetryLimitReachedErr) Error() string {
|
||||
return fmt.Sprintf("retry limit reached after %d attempts", r.Retries)
|
||||
}
|
||||
|
||||
// Is reports whether err is RetryLimitReachedErr.
|
||||
func (r *RetryLimitReachedErr) Is(err error) bool {
|
||||
_, ok := err.(*RetryLimitReachedErr)
|
||||
return ok
|
||||
}
|
||||
|
||||
// ErrEngineClosed means that engine was closed.
|
||||
var ErrEngineClosed = errors.New("engine was closed")
|
||||
@@ -0,0 +1,23 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// Send is a function that sends requests to the server.
|
||||
type Send func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error
|
||||
|
||||
// NopSend does nothing.
|
||||
func NopSend(context.Context, int64, int32, bin.Encoder) error { return nil }
|
||||
|
||||
var _ Send = NopSend
|
||||
|
||||
// DropHandler handles drop rpc requests.
|
||||
type DropHandler func(req Request) error
|
||||
|
||||
// NopDrop does nothing.
|
||||
func NopDrop(Request) error { return nil }
|
||||
|
||||
var _ DropHandler = NopDrop
|
||||
@@ -0,0 +1,16 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNopDrop(t *testing.T) {
|
||||
require.NoError(t, NopDrop(Request{}))
|
||||
}
|
||||
|
||||
func TestNopSend(t *testing.T) {
|
||||
require.NoError(t, NopSend(context.TODO(), 0, 0, nil))
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
)
|
||||
|
||||
// Options of rpc engine.
|
||||
type Options struct {
|
||||
RetryInterval time.Duration
|
||||
MaxRetries int
|
||||
Logger *zap.Logger
|
||||
Clock clock.Clock
|
||||
DropHandler DropHandler
|
||||
OnError func(error)
|
||||
}
|
||||
|
||||
func (cfg *Options) setDefaults() {
|
||||
if cfg.RetryInterval == 0 {
|
||||
cfg.RetryInterval = time.Second * 10
|
||||
}
|
||||
if cfg.MaxRetries == 0 {
|
||||
cfg.MaxRetries = 5
|
||||
}
|
||||
if cfg.Logger == nil {
|
||||
cfg.Logger = zap.NewNop()
|
||||
}
|
||||
if cfg.Clock == nil {
|
||||
cfg.Clock = clock.System
|
||||
}
|
||||
if cfg.DropHandler == nil {
|
||||
cfg.DropHandler = NopDrop
|
||||
}
|
||||
if cfg.OnError == nil {
|
||||
cfg.OnError = func(err error) {}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user