Files
2026-03-03 14:34:02 +02:00

156 lines
3.7 KiB
Go

package telegram
import (
"context"
"crypto/md5"
"fmt"
"math/rand"
"net"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/cenkalti/backoff/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
)
type testHandler func(id int64, body bin.Encoder) (bin.Encoder, error)
type testConn struct {
id atomic.Int64
engine *rpc.Engine
ready *tdsync.Ready
}
func (t *testConn) Ping(ctx context.Context) error {
return errors.New("not implemented")
}
func (t *testConn) Ready() <-chan struct{} {
return t.ready.Ready()
}
func (t *testConn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
id := t.id.Inc() - 1
return t.engine.Do(ctx, rpc.Request{
Input: input,
Output: output,
MsgID: id,
})
}
func (testConn) Run(ctx context.Context) error { return nil }
func newTestClient(h testHandler) *Client {
var engine *rpc.Engine
engine = rpc.New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
if response, err := h(msgID, in); err != nil {
engine.NotifyError(msgID, err)
} else {
var b bin.Buffer
if err := b.Encode(response); err != nil {
return err
}
return engine.NotifyResult(msgID, &b)
}
return nil
}, rpc.Options{})
ready := tdsync.NewReady()
ready.Signal()
client := &Client{
log: zap.NewNop(),
rand: rand.New(rand.NewSource(1)),
appID: TestAppID,
appHash: TestAppHash,
conn: &testConn{engine: engine, ready: ready},
ctx: context.Background(),
cancel: func(error) {},
updateHandler: UpdateHandlerFunc(func(ctx context.Context, u tg.UpdatesClass) error { return nil }),
onTransfer: noopOnTransfer,
}
client.init()
return client
}
func mockClient(cb func(mock *tgmock.Mock, client *Client)) func(t *testing.T) {
return func(t *testing.T) {
t.Helper()
mock := tgmock.NewRequire(t)
client := newTestClient(testHandler(mock.Handler()))
cb(mock, client)
}
}
func TestEnsureErrorIfCantConnect(t *testing.T) {
testErr := testutil.TestError()
dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, testErr
}
opts := Options{
Resolver: dcs.Plain(dcs.PlainOptions{Dial: dialer}),
ReconnectionBackoff: func() backoff.BackOff {
return backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Nanosecond), 2)
},
}
err := NewClient(1, "hash", opts).Run(context.Background(),
func(ctx context.Context) error {
return nil
})
require.ErrorIs(t, err, testErr)
}
// newCorpusTracer will save incoming messages to corpus folder.
//
// Usage:
//
// client.trace.OnMessage = newCorpusTracer(t)
//
// nolint: deadcode,unused // optional
func newCorpusTracer(t testing.TB) func(b *bin.Buffer) {
types := tmap.New(
mt.TypesMap(),
tg.TypesMap(),
proto.TypesMap(),
)
dir := filepath.Join("..", "_fuzz", "handle_message", "corpus")
return func(b *bin.Buffer) {
id, _ := b.PeekID()
h := md5.Sum(b.Buf)
name := types.Get(id)
if name == "" {
name = "unknown"
}
if idx := strings.Index(name, "#"); idx > 0 {
// Removing type id from name.
name = name[:idx]
}
base := fmt.Sprintf("trace_%x_%s_%x",
id, name, h,
)
assert.NoError(t, os.WriteFile(filepath.Join(dir, base), b.Buf, 0600))
}
}