Files
mautrix-telegram/pkg/gotd/telegram/migrate_to_dc_test.go
T
2025-06-27 20:03:37 -07:00

157 lines
3.5 KiB
Go

package telegram
import (
"context"
"fmt"
"math/rand"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
"go.mau.fi/mautrix-telegram/pkg/gotd/pool"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/manager"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
type (
	// migrationTestHandler answers a single fake RPC call. It receives the
	// message ID, the DC of the connection the request was sent on, and the
	// request body; it returns either the response to encode or an error to
	// deliver as the RPC error.
	migrationTestHandler func(id int64, dc int, body bin.Encoder) (bin.Encoder, error)

	// migrateTestConn is a fake pool connection bound to a single DC. It
	// embeds testConn and additionally carries the config, mtproto options and
	// owning client used when it reports its session from Run.
	migrateTestConn struct {
		testConn
		dc     int
		cfg    tg.Config
		opts   mtproto.Options
		client *Client
	}
)
// Ping is not supported by the fake migration connection; it always returns
// a "not implemented" error.
func (c *migrateTestConn) Ping(_ context.Context) error {
	return errors.New("not implemented")
}
// Run reports a fixed session (ID 10, salt 10, key taken from the connection's
// mtproto options) to the owning client via onSession, using a copy of the
// connection's config with ThisDC set to the connection's DC. If onSession
// fails its error is returned; otherwise Run blocks until ctx is done and
// returns ctx.Err().
func (c *migrateTestConn) Run(ctx context.Context) error {
	// Copy so that rewriting ThisDC does not touch the shared c.cfg value.
	sessionCfg := c.cfg
	sessionCfg.ThisDC = c.dc

	session := mtproto.Session{
		ID:   10,
		Key:  c.opts.Key,
		Salt: 10,
	}
	err := c.client.onSession(sessionCfg, session)
	if err != nil {
		return err
	}

	<-ctx.Done()
	return ctx.Err()
}
// newMigrationClient builds a Client whose connections are all fakes.
//
// Every connection created by the client answers RPC calls by invoking h with
// the message ID, the connection's DC and the request body. A non-nil error
// from h is delivered with NotifyError. Otherwise the returned value is
// encoded into a bin.Buffer and delivered with NotifyResult.
//
// The stored config advertises DC 2 as ThisDC plus DC options for DCs 10 and
// 2. The session starts on DC 2, client.conn is created up front, and
// migrationTimeout is 10 seconds.
func newMigrationClient(t *testing.T, h migrationTestHandler) *Client {
	cfg := tg.Config{
		ThisDC: 2,
		DCOptions: []tg.DCOption{
			{ID: 10, IPAddress: "10"},
			{ID: 2, IPAddress: "2"},
		},
	}

	// client is assigned below, before createConn triggers the first call to
	// creator, so each fake connection captures the fully built client.
	var client *Client
	creator := func(
		create mtproto.Dialer,
		mode manager.ConnMode,
		appID int,
		opts mtproto.Options,
		connOpts manager.ConnOptions,
	) pool.Conn {
		// Fake connections are ready immediately.
		ready := tdsync.NewReady()
		ready.Signal()

		// respond needs the engine it belongs to, so engine is declared first
		// and assigned once rpc.New returns.
		var engine *rpc.Engine
		respond := func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
			response, err := h(msgID, connOpts.DC, in)
			if err != nil {
				engine.NotifyError(msgID, err)
				return nil
			}

			var buf bin.Buffer
			if err := buf.Encode(response); err != nil {
				return err
			}
			return engine.NotifyResult(msgID, &buf)
		}
		engine = rpc.New(respond, rpc.Options{})

		return &migrateTestConn{
			testConn: testConn{engine: engine, ready: ready},
			dc:       connOpts.DC,
			cfg:      cfg,
			opts:     opts,
			client:   client,
		}
	}

	client = &Client{
		log:              zaptest.NewLogger(t),
		rand:             rand.New(rand.NewSource(1)),
		appID:            TestAppID,
		appHash:          TestAppHash,
		create:           creator,
		clock:            clock.System,
		session:          pool.NewSyncSession(pool.Session{DC: 2}),
		newConnBackoff:   defaultBackoff(clock.System),
		ctx:              context.Background(),
		cancel:           func() {},
		migrationTimeout: 10 * time.Second,
	}
	client.init()
	client.conn = client.createConn(0, manager.ConnModeUpdates, nil, nil, nil)
	client.cfg.Store(cfg)
	return client
}
func TestMigration(t *testing.T) {
codes := []int{303, 400}
for _, code := range codes {
t.Run(fmt.Sprintf("Code%d", code), func(t *testing.T) {
ctx := context.Background()
expected := &tg.BoolTrue{}
a := require.New(t)
client := newMigrationClient(t, func(id int64, dc int, body bin.Encoder) (bin.Encoder, error) {
switch body.(type) {
case *tg.UsersGetUsersRequest:
return nil, tgerr.New(401, "AUTH_KEY_UNREGISTERED")
case *tg.AuthLogOutRequest:
if dc == 2 {
return nil, tgerr.New(code, "USER_MIGRATE_10")
}
a.Equal(10, dc)
return expected, nil
default:
return nil, errors.Errorf("unexpected body %T", body)
}
})
err := client.Run(ctx, func(ctx context.Context) error {
var result tg.BoolTrue
return client.Invoke(ctx, &tg.AuthLogOutRequest{}, &result)
})
a.NoError(err)
})
}
}