157 lines
3.5 KiB
Go
157 lines
3.5 KiB
Go
package telegram
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-faster/errors"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap/zaptest"
|
|
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/pool"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/manager"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
|
)
|
|
|
|
// migrationTestHandler is the callback the fake connections use to answer
// requests. It receives the message ID, the DC ID of the connection the
// request was sent through and the request body, and returns either the
// response to encode or an error to report for that message.
type migrationTestHandler func(id int64, dc int, body bin.Encoder) (bin.Encoder, error)
|
|
|
|
type migrateTestConn struct {
|
|
testConn
|
|
dc int
|
|
cfg tg.Config
|
|
opts mtproto.Options
|
|
client *Client
|
|
}
|
|
|
|
func (c *migrateTestConn) Ping(ctx context.Context) error {
|
|
return errors.New("not implemented")
|
|
}
|
|
|
|
func (c *migrateTestConn) Run(ctx context.Context) error {
|
|
cfg := c.cfg
|
|
cfg.ThisDC = c.dc
|
|
if err := c.client.onSession(cfg, mtproto.Session{
|
|
ID: 10,
|
|
Key: c.opts.Key,
|
|
Salt: 10,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
<-ctx.Done()
|
|
return ctx.Err()
|
|
}
|
|
|
|
func newMigrationClient(t *testing.T, h migrationTestHandler) *Client {
|
|
cfg := tg.Config{
|
|
ThisDC: 2,
|
|
DCOptions: []tg.DCOption{
|
|
{
|
|
ID: 10,
|
|
IPAddress: "10",
|
|
},
|
|
{
|
|
ID: 2,
|
|
IPAddress: "2",
|
|
},
|
|
},
|
|
}
|
|
|
|
var client *Client
|
|
creator := func(
|
|
create mtproto.Dialer,
|
|
mode manager.ConnMode,
|
|
appID int,
|
|
opts mtproto.Options,
|
|
connOpts manager.ConnOptions,
|
|
) pool.Conn {
|
|
var engine *rpc.Engine
|
|
|
|
ready := tdsync.NewReady()
|
|
ready.Signal()
|
|
engine = rpc.New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
|
|
if response, err := h(msgID, connOpts.DC, in); err != nil {
|
|
engine.NotifyError(msgID, err)
|
|
} else {
|
|
var b bin.Buffer
|
|
if err := b.Encode(response); err != nil {
|
|
return err
|
|
}
|
|
return engine.NotifyResult(msgID, &b)
|
|
}
|
|
return nil
|
|
}, rpc.Options{})
|
|
|
|
return &migrateTestConn{
|
|
testConn: testConn{engine: engine, ready: ready},
|
|
dc: connOpts.DC,
|
|
cfg: cfg,
|
|
opts: opts,
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
client = &Client{
|
|
log: zaptest.NewLogger(t),
|
|
rand: rand.New(rand.NewSource(1)),
|
|
appID: TestAppID,
|
|
appHash: TestAppHash,
|
|
create: creator,
|
|
clock: clock.System,
|
|
session: pool.NewSyncSession(pool.Session{
|
|
DC: 2,
|
|
}),
|
|
newConnBackoff: defaultBackoff(clock.System),
|
|
ctx: context.Background(),
|
|
cancel: func(error) {},
|
|
migrationTimeout: 10 * time.Second,
|
|
}
|
|
client.init()
|
|
client.conn = client.createConn(0, manager.ConnModeUpdates, nil, nil, nil)
|
|
client.cfg.Store(cfg)
|
|
return client
|
|
}
|
|
|
|
func TestMigration(t *testing.T) {
|
|
codes := []int{303, 400}
|
|
|
|
for _, code := range codes {
|
|
t.Run(fmt.Sprintf("Code%d", code), func(t *testing.T) {
|
|
ctx := context.Background()
|
|
expected := &tg.BoolTrue{}
|
|
a := require.New(t)
|
|
|
|
client := newMigrationClient(t, func(id int64, dc int, body bin.Encoder) (bin.Encoder, error) {
|
|
switch body.(type) {
|
|
case *tg.UsersGetUsersRequest:
|
|
return nil, tgerr.New(401, "AUTH_KEY_UNREGISTERED")
|
|
case *tg.AuthLogOutRequest:
|
|
if dc == 2 {
|
|
return nil, tgerr.New(code, "USER_MIGRATE_10")
|
|
}
|
|
|
|
a.Equal(10, dc)
|
|
return expected, nil
|
|
default:
|
|
return nil, errors.Errorf("unexpected body %T", body)
|
|
}
|
|
})
|
|
|
|
err := client.Run(ctx, func(ctx context.Context) error {
|
|
var result tg.BoolTrue
|
|
return client.Invoke(ctx, &tg.AuthLogOutRequest{}, &result)
|
|
})
|
|
a.NoError(err)
|
|
})
|
|
}
|
|
}
|