Files
mautrix-telegram/pkg/gotd/telegram/uploader/uploader_test.go
T
2025-06-27 20:03:37 -07:00

225 lines
4.9 KiB
Go

package uploader
import (
"bytes"
"context"
"crypto/rand"
"io"
"net/url"
"runtime"
"strconv"
"sync"
"testing"
"testing/fstest"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/uploader/source"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type mockClient struct {
err bool
// Upload state.
buf *syncio.BufWriterAt
parts []atomic.Int64
partSize int
partSizeMux sync.Mutex
}
func newMockClient(err bool) *mockClient {
return &mockClient{
err: err,
buf: &syncio.BufWriterAt{},
parts: make([]atomic.Int64, partsLimit+1),
}
}
var testErr = testutil.TestError()
func (m *mockClient) write(part int, data []byte) error {
m.partSizeMux.Lock()
if m.partSize == 0 {
m.partSize = len(data)
} else if m.partSize != len(data) {
m.partSizeMux.Unlock()
return errors.Errorf(
"invalid part size, expected %d, got %d",
m.partSize, len(data),
)
}
partSize := m.partSize
m.partSizeMux.Unlock()
// Every part have ID which is offset in partSize from start of file.
// But maximal ID is 3999, so part ID for big files can overflow.
// We use parts array to count received parts by ID to compute the offset.
rangeOffset := int(m.parts[part].Inc() - 1)
// If rangeOffset is zero, so offset will be zero, part ID received first time.
// Otherwise, we count next range offset.
offset := rangeOffset * partsLimit * partSize
_, err := m.buf.WriteAt(data, int64(part*partSize+offset))
return err
}
func (m *mockClient) UploadSaveFilePart(ctx context.Context, request *tg.UploadSaveFilePartRequest) (bool, error) {
if m.err {
return false, testErr
}
if err := m.write(request.FilePart, request.Bytes); err != nil {
return false, err
}
return true, nil
}
func (m *mockClient) UploadSaveBigFilePart(ctx context.Context, request *tg.UploadSaveBigFilePartRequest) (bool, error) {
if m.err {
return false, testErr
}
if err := m.write(request.FilePart, request.Bytes); err != nil {
return false, err
}
return true, nil
}
type mockSource struct {
name string
data *bytes.Reader
}
func (m mockSource) Open(ctx context.Context, u *url.URL) (source.RemoteFile, error) {
return m, nil
}
func (m mockSource) Read(p []byte) (n int, err error) {
return m.data.Read(p)
}
func (m mockSource) Close() error {
return nil
}
func (m mockSource) Name() string {
return m.name
}
func (m mockSource) Size() int64 {
return m.data.Size()
}
func TestUploader(t *testing.T) {
ctx := context.Background()
testData := make([]byte, 15*1024*1024)
if _, err := io.ReadFull(rand.Reader, testData); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
data []byte
err bool
}{
{"5b", []byte{1, 2, 3, 4, 5}, false},
{strconv.Itoa(defaultPartSize) + "b", bytes.Repeat([]byte{1}, defaultPartSize), false},
{strconv.Itoa(len(testData)) + "b", testData, false},
{"Error", []byte{1, 2, 3, 4, 5}, true},
}
ways := []struct {
name string
action func(b *Uploader, data []byte) error
}{
{"FromReader", func(b *Uploader, data []byte) error {
if len(data) == len(testData) {
b = b.WithPartSize(16384)
}
_, err := b.FromReader(ctx, "10.jpg", bytes.NewReader(data))
return err
}},
{"FromBytes", func(b *Uploader, data []byte) error {
if len(data) == len(testData) {
b = b.WithPartSize(MaximumPartSize)
}
_, err := b.FromBytes(ctx, "10.jpg", data)
return err
}},
{"FromFS", func(b *Uploader, data []byte) error {
if len(data) == len(testData) {
b = b.WithPartSize(MaximumPartSize)
}
_, err := b.FromFS(ctx, fstest.MapFS{
"10.jpg": &fstest.MapFile{
Data: data,
},
}, "10.jpg", "")
return err
}},
{"FromURL", func(b *Uploader, data []byte) error {
if len(data) == len(testData) {
b = b.WithPartSize(MaximumPartSize)
}
b = b.WithSource(mockSource{
name: "img.jpg",
data: bytes.NewReader(data),
})
_, err := b.FromURL(ctx, "http://example.com")
return err
}},
}
options := []struct {
name string
action func(b *Uploader) *Uploader
}{
{"OneThread", func(b *Uploader) *Uploader {
return b.WithThreads(1)
}},
{"ManyThread", func(b *Uploader) *Uploader {
return b.WithThreads(runtime.GOMAXPROCS(0))
}},
}
for _, way := range ways {
t.Run(way.name, func(t *testing.T) {
for _, option := range options {
t.Run(option.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
client := newMockClient(test.err)
u := NewUploader(client)
err := way.action(option.action(u), test.data)
if test.err {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Truef(
t, bytes.Equal(test.data, client.buf.Bytes()),
"expected uploaded and given equal",
)
})
}
})
}
})
}
}