7a04f298d2
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
458 lines
9.6 KiB
Go
458 lines
9.6 KiB
Go
package rpc
|
|
|
|
import (
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-faster/errors"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap/zaptest"
|
|
"golang.org/x/net/context"
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
"github.com/gotd/neo"
|
|
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
|
)
|
|
|
|
type request struct {
|
|
MsgID int64
|
|
SeqNo int32
|
|
Input bin.Encoder
|
|
}
|
|
|
|
var defaultNow = testutil.Date()
|
|
|
|
const (
|
|
msgID int64 = 1
|
|
pingID int64 = 1337
|
|
seqNo int32 = 1
|
|
)
|
|
|
|
func TestRPCError(t *testing.T) {
|
|
clock := neo.NewTime(defaultNow)
|
|
observer := clock.Observe()
|
|
expectedErr := errors.New("server side error")
|
|
log := zaptest.NewLogger(t)
|
|
|
|
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
|
log := log.Named("server")
|
|
|
|
log.Info("Waiting ping request")
|
|
require.Equal(t, request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
}, <-incoming)
|
|
|
|
log.Info("Got ping request")
|
|
|
|
// Make sure that client calls time.After
|
|
// before time travel
|
|
<-observer
|
|
|
|
log.Info("Traveling into the future for a second (simulate job)")
|
|
clock.Travel(time.Second)
|
|
|
|
log.Info("Sending RPC error")
|
|
e.NotifyError(msgID, expectedErr)
|
|
|
|
return nil
|
|
}
|
|
|
|
client := func(t *testing.T, e *Engine) error {
|
|
log := log.Named("client")
|
|
|
|
log.Info("Sending ping request")
|
|
err := e.Do(context.TODO(), Request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{
|
|
PingID: pingID,
|
|
},
|
|
})
|
|
|
|
log.Info("Got pong response")
|
|
require.True(t, errors.Is(err, expectedErr), "expected error")
|
|
|
|
return nil
|
|
}
|
|
|
|
runTest(t, Options{
|
|
RetryInterval: time.Second * 3,
|
|
MaxRetries: 2,
|
|
Clock: clock,
|
|
Logger: log.Named("rpc"),
|
|
}, server, client)
|
|
}
|
|
|
|
func TestRPCResult(t *testing.T) {
|
|
clock := neo.NewTime(defaultNow)
|
|
observer := clock.Observe()
|
|
log := zaptest.NewLogger(t)
|
|
|
|
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
|
log := log.Named("server")
|
|
|
|
log.Info("Waiting ping request")
|
|
require.Equal(t, request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
}, <-incoming)
|
|
|
|
log.Info("Got ping request")
|
|
// Make sure that engine calls time.After
|
|
// before time travel.
|
|
<-observer
|
|
|
|
log.Info("Traveling into the future for 2 seconds (simulate job)")
|
|
clock.Travel(time.Second * 2)
|
|
|
|
var b bin.Buffer
|
|
if err := b.Encode(&mt.Pong{
|
|
MsgID: msgID,
|
|
PingID: pingID,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Info("Sending pong response")
|
|
return e.NotifyResult(msgID, &b)
|
|
}
|
|
|
|
client := func(t *testing.T, e *Engine) error {
|
|
log := log.Named("client")
|
|
|
|
log.Info("Sending ping request")
|
|
var out mt.Pong
|
|
require.NoError(t, e.Do(context.TODO(), Request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
Output: &out,
|
|
}))
|
|
|
|
log.Info("Got pong response")
|
|
require.Equal(t, mt.Pong{
|
|
MsgID: msgID,
|
|
PingID: pingID,
|
|
}, out)
|
|
|
|
return nil
|
|
}
|
|
|
|
runTest(t, Options{
|
|
RetryInterval: time.Second * 4,
|
|
MaxRetries: 2,
|
|
Clock: clock,
|
|
Logger: log.Named("rpc"),
|
|
}, server, client)
|
|
}
|
|
|
|
func TestRPCAckThenResult(t *testing.T) {
|
|
clock := neo.NewTime(defaultNow)
|
|
observer := clock.Observe()
|
|
log := zaptest.NewLogger(t)
|
|
|
|
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
|
log := log.Named("server")
|
|
|
|
log.Info("Waiting ping request")
|
|
require.Equal(t, request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
}, <-incoming)
|
|
|
|
// Make sure that client calls time.After
|
|
// before time travel.
|
|
<-observer
|
|
|
|
log.Info("Traveling into the future for 2 seconds (simulate job)")
|
|
clock.Travel(time.Second * 2)
|
|
|
|
log.Info("Sending ACK")
|
|
e.NotifyAcks([]int64{msgID})
|
|
|
|
log.Info("Traveling into the future for 6 seconds (simulate request processing)")
|
|
clock.Travel(time.Second * 6)
|
|
|
|
var b bin.Buffer
|
|
if err := b.Encode(&mt.Pong{
|
|
MsgID: msgID,
|
|
PingID: pingID,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Info("Sending response")
|
|
return e.NotifyResult(msgID, &b)
|
|
}
|
|
|
|
client := func(t *testing.T, e *Engine) error {
|
|
log := log.Named("client")
|
|
|
|
log.Info("Sending ping request")
|
|
var out mt.Pong
|
|
require.NoError(t, e.Do(context.TODO(), Request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
Output: &out,
|
|
}))
|
|
|
|
log.Info("Got pong response")
|
|
require.Equal(t, mt.Pong{
|
|
MsgID: msgID,
|
|
PingID: pingID,
|
|
}, out)
|
|
|
|
return nil
|
|
}
|
|
|
|
runTest(t, Options{
|
|
RetryInterval: time.Second * 4,
|
|
MaxRetries: 2,
|
|
Clock: clock,
|
|
Logger: log.Named("rpc"),
|
|
}, server, client)
|
|
}
|
|
|
|
func TestRPCWithRetryResult(t *testing.T) {
|
|
clock := neo.NewTime(defaultNow)
|
|
observer := clock.Observe()
|
|
log := zaptest.NewLogger(t)
|
|
|
|
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
|
log := log.Named("server")
|
|
|
|
log.Info("Waiting ping request")
|
|
require.Equal(t, request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
}, <-incoming)
|
|
log.Info("Got ping request")
|
|
|
|
// Make sure that client calls time.After
|
|
// before time travel.
|
|
<-observer
|
|
|
|
log.Info("Traveling into the future for 6 seconds (simulate request loss)")
|
|
clock.Travel(time.Second * 6)
|
|
|
|
log.Info("Waiting re-sending request")
|
|
require.Equal(t, request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
}, <-incoming)
|
|
log.Info("Got ping request")
|
|
|
|
var b bin.Buffer
|
|
if err := b.Encode(&mt.Pong{
|
|
MsgID: msgID,
|
|
PingID: pingID,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Info("Send pong response")
|
|
return e.NotifyResult(msgID, &b)
|
|
}
|
|
|
|
client := func(t *testing.T, e *Engine) error {
|
|
log := log.Named("client")
|
|
|
|
log.Info("Sending ping request")
|
|
var out mt.Pong
|
|
require.NoError(t, e.Do(context.TODO(), Request{
|
|
MsgID: 1,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
Output: &out,
|
|
}))
|
|
|
|
log.Info("Got pong response")
|
|
require.Equal(t, mt.Pong{
|
|
MsgID: msgID,
|
|
PingID: pingID,
|
|
}, out)
|
|
|
|
return nil
|
|
}
|
|
|
|
runTest(t, Options{
|
|
RetryInterval: time.Second * 4,
|
|
MaxRetries: 5,
|
|
Clock: clock,
|
|
Logger: log.Named("rpc"),
|
|
}, server, client)
|
|
}
|
|
|
|
func TestEngineGracefulShutdown(t *testing.T) {
|
|
var (
|
|
log = zaptest.NewLogger(t)
|
|
expectedErr = errors.New("server side error")
|
|
requestsCount = 10
|
|
serverRecv sync.WaitGroup
|
|
canSendResponse sync.Mutex
|
|
)
|
|
|
|
serverRecv.Add(requestsCount)
|
|
canSendResponse.Lock()
|
|
|
|
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
|
log := log.Named("server")
|
|
|
|
var batch []request
|
|
for i := 0; i < requestsCount; i++ {
|
|
batch = append(batch, <-incoming)
|
|
serverRecv.Done()
|
|
}
|
|
e.log.Info("Got all requests")
|
|
|
|
canSendResponse.Lock()
|
|
e.log.Info("Sending responses")
|
|
for _, req := range batch {
|
|
log.Info("send response")
|
|
e.NotifyError(req.MsgID, expectedErr)
|
|
}
|
|
canSendResponse.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
client := func(t *testing.T, e *Engine) error {
|
|
var currMsgID int64
|
|
|
|
for i := 0; i < requestsCount; i++ {
|
|
go func(t *testing.T, msgID int64) {
|
|
var out mt.Pong
|
|
require.Equal(t, e.Do(context.TODO(), Request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
Output: &out,
|
|
}), expectedErr)
|
|
}(t, currMsgID)
|
|
|
|
currMsgID++
|
|
}
|
|
|
|
// wait until server receive all requests
|
|
serverRecv.Wait()
|
|
// allow server to send responses
|
|
canSendResponse.Unlock()
|
|
// close the engine
|
|
e.Close()
|
|
|
|
return nil
|
|
}
|
|
|
|
runTest(t, Options{
|
|
RetryInterval: time.Second * 5,
|
|
MaxRetries: 5,
|
|
Logger: log.Named("rpc"),
|
|
}, server, client)
|
|
}
|
|
|
|
func TestDropRPC(t *testing.T) {
|
|
clock := neo.NewTime(defaultNow)
|
|
log := zaptest.NewLogger(t)
|
|
serverRecvRequest := make(chan struct{})
|
|
clientCancelledCtx := make(chan struct{})
|
|
dropChan := make(chan Request)
|
|
|
|
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
|
log := log.Named("server")
|
|
|
|
log.Info("Waiting ping request")
|
|
require.Equal(t, request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
}, <-incoming)
|
|
|
|
close(serverRecvRequest)
|
|
<-clientCancelledCtx
|
|
|
|
log.Info("Waiting drop request")
|
|
require.Equal(t, msgID, (<-dropChan).MsgID)
|
|
return nil
|
|
}
|
|
|
|
client := func(t *testing.T, e *Engine) error {
|
|
log := log.Named("client")
|
|
|
|
log.Info("Sending ping request")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
go func() {
|
|
<-serverRecvRequest
|
|
log.Info("Canceling request context")
|
|
cancel()
|
|
close(clientCancelledCtx)
|
|
}()
|
|
|
|
require.ErrorIs(t, e.Do(ctx, Request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: &mt.PingRequest{PingID: pingID},
|
|
Output: &mt.Pong{},
|
|
}), context.Canceled)
|
|
|
|
return nil
|
|
}
|
|
|
|
runTest(t, Options{
|
|
RetryInterval: time.Second * 4,
|
|
MaxRetries: 2,
|
|
Clock: clock,
|
|
Logger: log.Named("rpc"),
|
|
DropHandler: func(req Request) error { dropChan <- req; return nil },
|
|
}, server, client)
|
|
}
|
|
|
|
func runTest(
|
|
t *testing.T,
|
|
cfg Options,
|
|
server func(t *testing.T, e *Engine, incoming <-chan request) error,
|
|
client func(t *testing.T, e *Engine) error,
|
|
) {
|
|
t.Helper()
|
|
|
|
// Channel of client requests sent to the server.
|
|
requests := make(chan request)
|
|
defer close(requests)
|
|
|
|
e := New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
|
|
req := request{
|
|
MsgID: msgID,
|
|
SeqNo: seqNo,
|
|
Input: in,
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case requests <- req:
|
|
return nil
|
|
}
|
|
}, cfg)
|
|
|
|
var g errgroup.Group
|
|
g.Go(func() error { return server(t, e, requests) })
|
|
g.Go(func() error { return client(t, e) })
|
|
|
|
require.NoError(t, g.Wait())
|
|
e.Close()
|
|
require.NoError(t, cfg.Logger.Sync())
|
|
}
|