Files
mautrix-telegram/pkg/gotd/proto/codec/codec_test.go
T
2025-06-27 20:03:37 -07:00

198 lines
4.4 KiB
Go

package codec
import (
"bytes"
"encoding/binary"
"io"
"strings"
"testing"
"testing/iotest"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
// codecTest describes one codec configuration exercised by TestCodecs.
type codecTest struct {
	// name is used as the subtest name.
	name string
	// align is the payload alignment the codec is expected to enforce on
	// Write; 0 means the alignment check in testCodec is skipped.
	align int
	// create is a factory for the codec; it is called once per check so each
	// check works on its own instance.
	create func() Codec
}
// codecs returns every codec configuration covered by TestCodecs, each with
// a factory that builds a fresh instance.
func codecs() []codecTest {
	return []codecTest{
		{name: "Abridged", align: 4, create: func() Codec { return Abridged{} }},
		{name: "Intermediate", align: 4, create: func() Codec { return Intermediate{} }},
		{name: "PaddedIntermediate", align: 4, create: func() Codec { return PaddedIntermediate{} }},
		// Full is handed out by pointer; no alignment check is run for it.
		{name: "Full", align: 0, create: func() Codec { return &Full{} }},
		// Intermediate framing wrapped in NoHeader; align 0 skips AlignError.
		{name: "NoHeaderIntermediate", align: 0, create: func() Codec { return NoHeader{Codec: Intermediate{}} }},
	}
}
// payload is a single message pushed through every codec under test.
type payload struct {
	// name is used as the subtest name.
	name string
	// testData is the raw message body.
	testData string
	// mustFail marks inputs the codec is expected to reject.
	mustFail bool
	// readTestOnly restricts a failing payload to the Read check; the Write
	// check is skipped for it.
	readTestOnly bool
}
// payloads returns the messages fed to every codec. Entries with mustFail set
// are expected to be rejected; the others must round-trip unchanged.
func payloads() []payload {
	// Four little-endian bytes holding CodeTransportFlood; presumably this
	// mimics a transport-level error code arriving instead of a frame.
	var flood [4]byte
	binary.LittleEndian.PutUint32(flood[:], CodeTransportFlood)

	return []payload{
		{name: "Empty", testData: "", mustFail: true},
		{name: "Protocol error", testData: string(flood[:]), mustFail: true, readTestOnly: true},
		{name: "Small 8b", testData: "abcdabcd"},
		{name: "Medium 1kb", testData: strings.Repeat("a", 1024)},
	}
}
func testGood(c codecTest, p payload) func(t *testing.T) {
return func(t *testing.T) {
t.Run("One message", func(t *testing.T) {
a := require.New(t)
codec := c.create()
buf := bytes.NewBuffer(nil)
payload := &bin.Buffer{Buf: []byte(p.testData)}
l := payload.Len()
// Encode
a.NoError(codec.Write(buf, payload))
a.Equal(l, payload.Len(), "Codec must not change buffer length")
// Decode
payload.Reset()
a.NoError(codec.Read(buf, payload))
a.Equal(p.testData, string(payload.Buf))
})
t.Run("Two messages", func(t *testing.T) {
a := require.New(t)
codec := c.create()
buf := bytes.NewBuffer(nil)
payload := &bin.Buffer{Buf: []byte(p.testData)}
l := payload.Len()
// Encode twice
a.NoError(codec.Write(buf, payload))
payload.ResetTo([]byte(p.testData))
a.NoError(codec.Write(buf, payload))
a.Equal(l, payload.Len(), "Codec must not change buffer length")
// Decode twice
payload.Reset()
a.NoError(codec.Read(buf, payload))
a.Equal(p.testData, string(payload.Buf))
payload.Reset()
a.NoError(codec.Read(buf, payload))
a.Equal(p.testData, string(payload.Buf))
})
}
}
// testBad returns a subtest asserting that codec c rejects p. Write must fail
// unless p is marked readTestOnly, and Read of the raw bytes must always fail.
func testBad(c codecTest, p payload) func(t *testing.T) {
	return func(t *testing.T) {
		writeRejects := func(t *testing.T) {
			msg := &bin.Buffer{Buf: []byte(p.testData)}
			require.Error(t, c.create().Write(io.Discard, msg))
		}
		readRejects := func(t *testing.T) {
			src := bytes.NewBufferString(p.testData)
			require.Error(t, c.create().Read(src, &bin.Buffer{}))
		}

		if !p.readTestOnly {
			t.Run("Write", writeRejects)
		}
		t.Run("Read", readRejects)
	}
}
// testHeaderTag returns a subtest exercising WriteHeader/ReadHeader of codec c.
//
// Every codec must accept the header it wrote itself (GoodTag). Codecs that
// implement TaggedCodec are also checked to see that reader/writer errors are
// propagated, and that a corrupted tag is rejected with
// ErrProtocolHeaderMismatch.
func testHeaderTag(c codecTest) func(t *testing.T) {
	e := io.ErrClosedPipe
	return func(t *testing.T) {
		t.Run("GoodTag", func(t *testing.T) {
			a := require.New(t)
			buf := bytes.NewBuffer(nil)
			a.NoError(c.create().WriteHeader(buf))
			a.NoError(c.create().ReadHeader(buf))
		})

		tagged, ok := c.create().(TaggedCodec)
		if !ok {
			// The remaining checks only apply to codecs that expose a tag.
			return
		}
		t.Run("ReadError", func(t *testing.T) {
			r := iotest.ErrReader(e)
			require.ErrorIs(t, c.create().ReadHeader(r), e)
		})
		t.Run("WriteError", func(t *testing.T) {
			w := testutil.ErrWriter(e)
			require.ErrorIs(t, c.create().WriteHeader(w), e)
		})
		t.Run("BadTag", func(t *testing.T) {
			// Corrupt the tag before handing it to the reader. The previous code
			// mutated it afterwards and only worked because bytes.Buffer aliases
			// the array's storage. XOR with 0xff guarantees the byte actually
			// changes, even if it was originally 0.
			tag := tagged.ObfuscatedTag()
			tag[0] ^= 0xff
			buf := bytes.NewBuffer(tag[:])
			require.ErrorIs(t, c.create().ReadHeader(buf), ErrProtocolHeaderMismatch)
		})
	}
}
// testCodec returns a subtest checking that codec c surfaces errors from the
// underlying reader and writer. When c.align is non-zero, it also checks that
// Write rejects a payload one byte shorter than the alignment.
func testCodec(c codecTest) func(t *testing.T) {
	errPipe := io.ErrClosedPipe
	return func(t *testing.T) {
		t.Run("ReadError", func(t *testing.T) {
			src := iotest.ErrReader(errPipe)
			err := c.create().Read(src, &bin.Buffer{})
			require.ErrorIs(t, err, errPipe)
		})
		t.Run("WriteError", func(t *testing.T) {
			dst := testutil.ErrWriter(errPipe)
			// 16 bytes is a multiple of 4, presumably so the failure comes from
			// the writer rather than from payload validation.
			err := c.create().Write(dst, &bin.Buffer{Buf: make([]byte, 16)})
			require.ErrorIs(t, err, errPipe)
		})

		if c.align == 0 {
			return
		}
		t.Run("AlignError", func(t *testing.T) {
			short := &bin.Buffer{Buf: make([]byte, c.align-1)}
			require.Error(t, c.create().Write(io.Discard, short))
		})
	}
}
// TestCodecs runs the payload, header and error-path checks against every
// codec configuration returned by codecs().
func TestCodecs(t *testing.T) {
	for _, c := range codecs() {
		t.Run(c.name, func(t *testing.T) {
			for _, p := range payloads() {
				// Payloads flagged mustFail go to testBad; the rest must round-trip.
				check := testGood
				if p.mustFail {
					check = testBad
				}
				t.Run(p.name, check(c, p))
			}
			t.Run("Header", testHeaderTag(c))
			// testCodec's subtests hang directly off the per-codec test.
			testCodec(c)(t)
		})
	}
}