move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,457 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/gotd/neo"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
type request struct {
|
||||
MsgID int64
|
||||
SeqNo int32
|
||||
Input bin.Encoder
|
||||
}
|
||||
|
||||
var defaultNow = testutil.Date()
|
||||
|
||||
const (
|
||||
msgID int64 = 1
|
||||
pingID int64 = 1337
|
||||
seqNo int32 = 1
|
||||
)
|
||||
|
||||
func TestRPCError(t *testing.T) {
|
||||
clock := neo.NewTime(defaultNow)
|
||||
observer := clock.Observe()
|
||||
expectedErr := errors.New("server side error")
|
||||
log := zaptest.NewLogger(t)
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
log.Info("Waiting ping request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
|
||||
log.Info("Got ping request")
|
||||
|
||||
// Make sure that client calls time.After
|
||||
// before time travel
|
||||
<-observer
|
||||
|
||||
log.Info("Traveling into the future for a second (simulate job)")
|
||||
clock.Travel(time.Second)
|
||||
|
||||
log.Info("Sending RPC error")
|
||||
e.NotifyError(msgID, expectedErr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
log := log.Named("client")
|
||||
|
||||
log.Info("Sending ping request")
|
||||
err := e.Do(context.TODO(), Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{
|
||||
PingID: pingID,
|
||||
},
|
||||
})
|
||||
|
||||
log.Info("Got pong response")
|
||||
require.True(t, errors.Is(err, expectedErr), "expected error")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 3,
|
||||
MaxRetries: 2,
|
||||
Clock: clock,
|
||||
Logger: log.Named("rpc"),
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func TestRPCResult(t *testing.T) {
|
||||
clock := neo.NewTime(defaultNow)
|
||||
observer := clock.Observe()
|
||||
log := zaptest.NewLogger(t)
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
log.Info("Waiting ping request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
|
||||
log.Info("Got ping request")
|
||||
// Make sure that engine calls time.After
|
||||
// before time travel.
|
||||
<-observer
|
||||
|
||||
log.Info("Traveling into the future for 2 seconds (simulate job)")
|
||||
clock.Travel(time.Second * 2)
|
||||
|
||||
var b bin.Buffer
|
||||
if err := b.Encode(&mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Sending pong response")
|
||||
return e.NotifyResult(msgID, &b)
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
log := log.Named("client")
|
||||
|
||||
log.Info("Sending ping request")
|
||||
var out mt.Pong
|
||||
require.NoError(t, e.Do(context.TODO(), Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
Output: &out,
|
||||
}))
|
||||
|
||||
log.Info("Got pong response")
|
||||
require.Equal(t, mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}, out)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 4,
|
||||
MaxRetries: 2,
|
||||
Clock: clock,
|
||||
Logger: log.Named("rpc"),
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func TestRPCAckThenResult(t *testing.T) {
|
||||
clock := neo.NewTime(defaultNow)
|
||||
observer := clock.Observe()
|
||||
log := zaptest.NewLogger(t)
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
log.Info("Waiting ping request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
|
||||
// Make sure that client calls time.After
|
||||
// before time travel.
|
||||
<-observer
|
||||
|
||||
log.Info("Traveling into the future for 2 seconds (simulate job)")
|
||||
clock.Travel(time.Second * 2)
|
||||
|
||||
log.Info("Sending ACK")
|
||||
e.NotifyAcks([]int64{msgID})
|
||||
|
||||
log.Info("Traveling into the future for 6 seconds (simulate request processing)")
|
||||
clock.Travel(time.Second * 6)
|
||||
|
||||
var b bin.Buffer
|
||||
if err := b.Encode(&mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Sending response")
|
||||
return e.NotifyResult(msgID, &b)
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
log := log.Named("client")
|
||||
|
||||
log.Info("Sending ping request")
|
||||
var out mt.Pong
|
||||
require.NoError(t, e.Do(context.TODO(), Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
Output: &out,
|
||||
}))
|
||||
|
||||
log.Info("Got pong response")
|
||||
require.Equal(t, mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}, out)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 4,
|
||||
MaxRetries: 2,
|
||||
Clock: clock,
|
||||
Logger: log.Named("rpc"),
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func TestRPCWithRetryResult(t *testing.T) {
|
||||
clock := neo.NewTime(defaultNow)
|
||||
observer := clock.Observe()
|
||||
log := zaptest.NewLogger(t)
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
log.Info("Waiting ping request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
log.Info("Got ping request")
|
||||
|
||||
// Make sure that client calls time.After
|
||||
// before time travel.
|
||||
<-observer
|
||||
|
||||
log.Info("Traveling into the future for 6 seconds (simulate request loss)")
|
||||
clock.Travel(time.Second * 6)
|
||||
|
||||
log.Info("Waiting re-sending request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
log.Info("Got ping request")
|
||||
|
||||
var b bin.Buffer
|
||||
if err := b.Encode(&mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Send pong response")
|
||||
return e.NotifyResult(msgID, &b)
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
log := log.Named("client")
|
||||
|
||||
log.Info("Sending ping request")
|
||||
var out mt.Pong
|
||||
require.NoError(t, e.Do(context.TODO(), Request{
|
||||
MsgID: 1,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
Output: &out,
|
||||
}))
|
||||
|
||||
log.Info("Got pong response")
|
||||
require.Equal(t, mt.Pong{
|
||||
MsgID: msgID,
|
||||
PingID: pingID,
|
||||
}, out)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 4,
|
||||
MaxRetries: 5,
|
||||
Clock: clock,
|
||||
Logger: log.Named("rpc"),
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func TestEngineGracefulShutdown(t *testing.T) {
|
||||
var (
|
||||
log = zaptest.NewLogger(t)
|
||||
expectedErr = errors.New("server side error")
|
||||
requestsCount = 10
|
||||
serverRecv sync.WaitGroup
|
||||
canSendResponse sync.Mutex
|
||||
)
|
||||
|
||||
serverRecv.Add(requestsCount)
|
||||
canSendResponse.Lock()
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
var batch []request
|
||||
for i := 0; i < requestsCount; i++ {
|
||||
batch = append(batch, <-incoming)
|
||||
serverRecv.Done()
|
||||
}
|
||||
e.log.Info("Got all requests")
|
||||
|
||||
canSendResponse.Lock()
|
||||
e.log.Info("Sending responses")
|
||||
for _, req := range batch {
|
||||
log.Info("send response")
|
||||
e.NotifyError(req.MsgID, expectedErr)
|
||||
}
|
||||
canSendResponse.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
var currMsgID int64
|
||||
|
||||
for i := 0; i < requestsCount; i++ {
|
||||
go func(t *testing.T, msgID int64) {
|
||||
var out mt.Pong
|
||||
require.Equal(t, e.Do(context.TODO(), Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
Output: &out,
|
||||
}), expectedErr)
|
||||
}(t, currMsgID)
|
||||
|
||||
currMsgID++
|
||||
}
|
||||
|
||||
// wait until server receive all requests
|
||||
serverRecv.Wait()
|
||||
// allow server to send responses
|
||||
canSendResponse.Unlock()
|
||||
// close the engine
|
||||
e.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 5,
|
||||
MaxRetries: 5,
|
||||
Logger: log.Named("rpc"),
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func TestDropRPC(t *testing.T) {
|
||||
clock := neo.NewTime(defaultNow)
|
||||
log := zaptest.NewLogger(t)
|
||||
serverRecvRequest := make(chan struct{})
|
||||
clientCancelledCtx := make(chan struct{})
|
||||
dropChan := make(chan Request)
|
||||
|
||||
server := func(t *testing.T, e *Engine, incoming <-chan request) error {
|
||||
log := log.Named("server")
|
||||
|
||||
log.Info("Waiting ping request")
|
||||
require.Equal(t, request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
}, <-incoming)
|
||||
|
||||
close(serverRecvRequest)
|
||||
<-clientCancelledCtx
|
||||
|
||||
log.Info("Waiting drop request")
|
||||
require.Equal(t, msgID, (<-dropChan).MsgID)
|
||||
return nil
|
||||
}
|
||||
|
||||
client := func(t *testing.T, e *Engine) error {
|
||||
log := log.Named("client")
|
||||
|
||||
log.Info("Sending ping request")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-serverRecvRequest
|
||||
log.Info("Canceling request context")
|
||||
cancel()
|
||||
close(clientCancelledCtx)
|
||||
}()
|
||||
|
||||
require.ErrorIs(t, e.Do(ctx, Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: &mt.PingRequest{PingID: pingID},
|
||||
Output: &mt.Pong{},
|
||||
}), context.Canceled)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
runTest(t, Options{
|
||||
RetryInterval: time.Second * 4,
|
||||
MaxRetries: 2,
|
||||
Clock: clock,
|
||||
Logger: log.Named("rpc"),
|
||||
DropHandler: func(req Request) error { dropChan <- req; return nil },
|
||||
}, server, client)
|
||||
}
|
||||
|
||||
func runTest(
|
||||
t *testing.T,
|
||||
cfg Options,
|
||||
server func(t *testing.T, e *Engine, incoming <-chan request) error,
|
||||
client func(t *testing.T, e *Engine) error,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
// Channel of client requests sent to the server.
|
||||
requests := make(chan request)
|
||||
defer close(requests)
|
||||
|
||||
e := New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
|
||||
req := request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: in,
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case requests <- req:
|
||||
return nil
|
||||
}
|
||||
}, cfg)
|
||||
|
||||
var g errgroup.Group
|
||||
g.Go(func() error { return server(t, e, requests) })
|
||||
g.Go(func() error { return client(t, e) })
|
||||
|
||||
require.NoError(t, g.Wait())
|
||||
e.Close()
|
||||
require.NoError(t, cfg.Logger.Sync())
|
||||
}
|
||||
Reference in New Issue
Block a user