Files
2025-06-27 20:03:37 -07:00

162 lines
3.6 KiB
Go

package mtproto
import (
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"github.com/gotd/neo"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
)
// testUpdateHandler is the handler used by the tests in this file. Its
// OnMessage method decodes incoming buffers into objects created from types.
type testUpdateHandler struct {
// types creates a fresh object for a peeked type ID (see OnMessage).
types *tmap.Constructor
}
// OnMessage peeks the type ID at the head of b, creates the matching object
// from the registered constructors and decodes b into it.
//
// The peek error is returned unchanged, "not found" is returned when no
// constructor yields an object for the ID, and a decoding failure is wrapped
// with the "decode" message.
func (h testUpdateHandler) OnMessage(b *bin.Buffer) error {
	typeID, peekErr := b.PeekID()
	if peekErr != nil {
		return peekErr
	}

	obj := h.types.New(typeID)
	if obj == nil {
		return errors.New("not found")
	}

	decodeErr := obj.Decode(b)
	if decodeErr != nil {
		return errors.Wrap(decodeErr, "decode")
	}
	return nil
}
// OnSession ignores the given session and always returns nil.
func (testUpdateHandler) OnSession(session Session) error {
	return nil
}
// newTestHandler returns a testUpdateHandler whose constructor registry is
// built from the tg and mt constructor maps.
func newTestHandler() Handler {
	constructors := tmap.NewConstructor(
		tg.TypesConstructorMap(),
		mt.TypesConstructorMap(),
	)
	return &testUpdateHandler{types: constructors}
}
// TestConnHandleMessage passes a fixed set of raw payloads to
// Conn.handleMessage and fails if any of them is accepted without an error.
func TestConnHandleMessage(t *testing.T) {
	conn := &Conn{
		rand:    Zero{},
		log:     zap.NewNop(),
		handler: newTestHandler(),
	}

	// Every payload below is expected to make handleMessage return an error.
	payloads := []string{
		"\xdc\xf8\xf1stewa\x00O\x03expired c" +
			"ertificate\x02\x00\x00\x00\xef",
		"\x01m\\\xf300000000\x19\xcaD!0000" +
			"\x1100000000000000000",
		"\x01m\\\xf300000000\x19\xcaD!0000" +
			"\xfe0",
		"@B\xaet\x15ĵ\x1c0000,\x8f\xf8B0000" +
			"00000000\x15ĵ\x1c0000\xff000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"00000000000000000000" +
			"000000000000",
	}

	for idx, payload := range payloads {
		// Subtests are named after the payload index.
		t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) {
			buf := &bin.Buffer{Buf: []byte(payload)}
			if handleErr := conn.handleMessage(0, buf); handleErr != nil {
				return
			}
			t.Fatal("error expected")
		})
	}
}
// TestConnHandleMessageCorpus runs every file from the handle_message fuzz
// corpus through Conn.handleMessage, wrapping each call in testutil.MaxAlloc
// with a per-type allocation threshold. The error returned by handleMessage
// is ignored; only the allocation check matters here.
//
// The test is skipped when testutil.Race is set.
func TestConnHandleMessageCorpus(t *testing.T) {
	if testutil.Race {
		t.Skip("Skipped")
	}

	conn := &Conn{
		handler:    newTestHandler(),
		rpc:        rpc.New(rpc.NopSend, rpc.Options{}),
		clock:      neo.NewTime(time.Now()),
		rand:       Zero{},
		log:        zap.NewNop(),
		gotSession: tdsync.NewReady(),
	}

	dir := filepath.Join("..", "_fuzz", "handle_message", "corpus")
	entries, err := os.ReadDir(dir)
	if err != nil {
		t.Fatal(err)
	}

	// A single buffer is shared by all subtests and reset before every use.
	buf := &bin.Buffer{}
	// registry is used only to log the type behind a peeked ID.
	registry := tmap.New(
		tg.TypesMap(),
		mt.TypesMap(),
	)

	for _, entry := range entries {
		name := entry.Name()
		t.Run(name, func(t *testing.T) {
			payload, readErr := os.ReadFile(filepath.Join(dir, name))
			require.NoError(t, readErr)

			// Default to 128 bytes per invocation.
			threshold := 128

			// Adjusting threshold for specific types.
			//
			// Probably there should be better way to do this, but
			// manually ensuring allocation distribution by type is
			// pretty ok.
			buf.ResetTo(payload)
			id, peekErr := buf.PeekID()
			if peekErr == nil {
				t.Logf("Type: 0x%x %s", id, registry.Get(id))

				switch id {
				case tg.UpdatesTypeID,
					tg.TextFixedTypeID,
					tg.InputPeerChannelFromMessageTypeID,
					tg.PageBlockRelatedArticlesTypeID:
					threshold = 512
				case tg.TextBoldTypeID,
					tg.TextItalicTypeID,
					tg.TextMarkedTypeID,
					tg.MessageTypeID,
					tg.PageBlockCoverTypeID,
					tg.InputMediaUploadedDocumentTypeID,
					tg.SecureRequiredTypeOneOfTypeID:
					threshold = 256
				}
			}

			// The buffer is rewound inside the measured function so that
			// every invocation sees the full payload.
			testutil.MaxAlloc(t, threshold, func() {
				buf.ResetTo(payload)
				_ = conn.handleMessage(0, buf)
			})
		})
	}
}