Files
mautrix-telegram/pkg/gotd/telegram/downloader/downloader_test.go
T
2025-06-27 20:03:37 -07:00

301 lines
7.0 KiB
Go

package downloader
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"io"
"runtime"
"strconv"
"testing"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// mock is a fake Telegram client for the downloader tests. One value is
// passed as both the regular client and the CDN client.
type mock struct {
	// data is the file content served by the download methods.
	data []byte
	// hashes backs UploadGetFileHashes and UploadGetCDNFileHashes.
	hashes mockHashes
	// migrate makes UploadGetFile answer with redirect and UploadGetCDNFile
	// answer with a reupload-needed response.
	migrate bool
	// err makes UploadGetFile, UploadGetCDNFile and UploadGetWebFile fail
	// with testErr; hashesErr does the same for the two hash methods.
	err, hashesErr bool
	// redirect is returned by UploadGetFile on migrate; its key and IV are
	// also used to encrypt UploadGetCDNFile responses.
	redirect *tg.UploadFileCDNRedirect
}
// testErr is the error every mock method returns when its failure flag
// (err or hashesErr) is set.
var testErr = testutil.TestError()
// getPart returns a freshly allocated copy of at most limit bytes of m.data
// starting at offset. When offset is at or beyond the end of the data it
// returns an empty, non-nil slice.
func (m mock) getPart(offset int64, limit int) []byte {
	total := int64(len(m.data))
	if offset >= total {
		return []byte{}
	}

	n := int(total - offset)
	if limit < n {
		n = limit
	}

	part := make([]byte, n)
	copy(part, m.data[offset:])
	return part
}
// UploadGetFile fakes upload.getFile. Failure takes precedence: when m.err
// is set it returns testErr. Otherwise, when m.migrate is set, it answers with
// the configured CDN redirect. In all remaining cases it returns the
// requested slice of m.data.
func (m mock) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) {
	switch {
	case m.err:
		return nil, testErr
	case m.migrate:
		return m.redirect, nil
	default:
		chunk := m.getPart(request.Offset, request.Limit)
		return &tg.UploadFile{Bytes: chunk}, nil
	}
}
// UploadGetFileHashes fakes upload.getFileHashes. It delegates to m.hashes
// unless m.hashesErr is set, in which case it returns testErr.
func (m mock) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error) {
	if !m.hashesErr {
		return m.hashes.Hashes(ctx, request.Offset)
	}
	return nil, testErr
}
// UploadReuploadCDNFile is not supported by the mock.
//
// It returns an error instead of panicking. If the panic happened on a
// goroutine other than the test's own (for example a parallel download
// worker), it would crash the whole test binary rather than fail the
// current test.
func (m mock) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) {
	return nil, errors.New("mock: UploadReuploadCDNFile is not implemented")
}
// UploadGetCDNFile fakes upload.getCdnFile.
//
// It returns testErr when m.err is set. It returns a reupload-needed response
// (fixed request token) when m.migrate is set. Otherwise it returns the
// requested slice of m.data, AES-CTR encrypted with m.redirect's key. The
// counter is m.redirect's IV with its last 4 bytes overwritten by
// request.Offset/16 (big-endian).
func (m mock) UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error) {
	switch {
	case m.err:
		return nil, testErr
	case m.migrate:
		return &tg.UploadCDNFileReuploadNeeded{
			RequestToken: []byte{1, 2, 3},
		}, nil
	}

	block, err := aes.NewCipher(m.redirect.EncryptionKey)
	if err != nil {
		return nil, errors.Wrap(err, "CDN mock cipher creation")
	}

	// Copy the IV so the shared redirect value is never modified.
	counter := append([]byte(nil), m.redirect.EncryptionIv...)
	binary.BigEndian.PutUint32(counter[len(counter)-4:], uint32(request.Offset/16))

	plain := m.getPart(request.Offset, request.Limit)
	sealed := make([]byte, len(plain))
	cipher.NewCTR(block, counter).XORKeyStream(sealed, plain)

	return &tg.UploadCDNFile{Bytes: sealed}, nil
}
// UploadGetCDNFileHashes fakes upload.getCdnFileHashes. It serves the same
// hash ranges as UploadGetFileHashes and fails with testErr when m.hashesErr
// is set.
func (m mock) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error) {
	if !m.hashesErr {
		return m.hashes.Hashes(ctx, request.Offset)
	}
	return nil, testErr
}
// UploadGetWebFile fakes upload.getWebFile. It returns testErr when m.err is
// set. Otherwise it returns the requested slice of m.data.
func (m mock) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) {
	if m.err {
		return nil, testErr
	}

	offset := int64(request.Offset)
	return &tg.UploadWebFile{Bytes: m.getPart(offset, request.Limit)}, nil
}
// countHashes splits data into consecutive partSize-byte parts and returns one
// tg.FileHash per part (SHA-256 of the part), grouped into slices of at most
// 10 entries.
//
// The final part holds the remainder; it is empty only when data is empty.
// Every entry's Limit is partSize, even when the final part is shorter.
func countHashes(data []byte, partSize int) [][]tg.FileHash {
	parts := make([][]byte, 0, (len(data)+partSize-1)/partSize)
	rest := data
	for len(rest) > partSize {
		parts = append(parts, rest[:partSize:partSize])
		rest = rest[partSize:]
	}
	parts = append(parts, rest)

	var (
		groups [][]tg.FileHash
		group  = make([]tg.FileHash, 0, 10)
		offset int64
	)
	for i, part := range parts {
		// Close the current group after every 10 entries.
		if i > 0 && i%10 == 0 {
			groups = append(groups, group)
			group = make([]tg.FileHash, 0, 10)
		}
		group = append(group, tg.FileHash{
			Offset: offset,
			Limit:  partSize,
			Hash:   crypto.SHA256(part),
		})
		offset += int64(len(part))
	}
	return append(groups, group)
}
// Test_countHashes checks the output of countHashes on 50 bytes split into
// 4-byte parts. It verifies four properties:
//   - every group is non-empty and holds at most 10 entries;
//   - entries are contiguous and each Limit equals the part size;
//   - each entry hashes exactly the bytes it describes (clamped to the end of
//     the data);
//   - the entries together cover the whole input.
func Test_countHashes(t *testing.T) {
	const partSize = 4

	a := require.New(t)
	data := bytes.Repeat([]byte{1, 2, 3, 4, 5}, 10)
	hashes := countHashes(data, partSize)
	a.NotEmpty(hashes)

	var next int64 // offset the next entry is expected to start at
	for _, hashRange := range hashes {
		a.NotEmpty(hashRange)
		a.LessOrEqual(len(hashRange), 10)
		for _, hash := range hashRange {
			a.Equal(next, hash.Offset)
			a.Equal(partSize, hash.Limit)

			to := int(hash.Offset) + hash.Limit
			if to > len(data) {
				to = len(data)
			}
			a.Equal(crypto.SHA256(data[hash.Offset:to]), hash.Hash)
			next = int64(to)
		}
	}
	// The entries must cover the data up to its last byte.
	a.Equal(int64(len(data)), next)
}
// TestDownloader runs every combination of schema (master / web), test case,
// verification option and download strategy (stream / parallel) against the
// mock client. Failing cases must return an error; all other cases must
// reproduce the input bytes exactly.
func TestDownloader(t *testing.T) {
	ctx := context.Background()

	key := make([]byte, 32)
	iv := make([]byte, aes.BlockSize)
	if _, err := io.ReadFull(rand.Reader, key); err != nil {
		t.Fatal(err)
	}
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		t.Fatal(err)
	}
	redirect := &tg.UploadFileCDNRedirect{
		DCID:          1,
		FileToken:     []byte{10},
		EncryptionKey: key,
		EncryptionIv:  iv,
	}

	testData := make([]byte, defaultPartSize*2)
	if _, err := io.ReadFull(rand.Reader, testData); err != nil {
		t.Fatal(err)
	}

	tests := []struct {
		name      string
		data      []byte
		migrate   bool
		err       bool
		hashesErr bool
	}{
		{"5b", []byte{1, 2, 3, 4, 5}, false, false, false},
		{strconv.Itoa(len(testData)) + "b", testData, false, false, false},
		{"Error", []byte{}, false, true, false},
		{"HashesError", []byte{}, false, true, true},
		{"Migrate", []byte{}, true, false, false},
	}
	schemas := []struct {
		name    string
		creator func(c Client, cdn CDN) *Builder
	}{
		{"Master", func(c Client, cdn CDN) *Builder {
			return NewDownloader().Download(c, nil)
		}},
		{"Web", func(c Client, cdn CDN) *Builder {
			return NewDownloader().Web(c, nil)
		}},
	}
	ways := []struct {
		name   string
		action func(b *Builder) ([]byte, error)
	}{
		{"Stream", func(b *Builder) ([]byte, error) {
			output := new(bytes.Buffer)
			_, err := b.Stream(ctx, output)
			return output.Bytes(), err
		}},
		{"Parallel", func(b *Builder) ([]byte, error) {
			output := new(syncio.BufWriterAt)
			_, err := b.WithThreads(runtime.GOMAXPROCS(0)).Parallel(ctx, output)
			return output.Bytes(), err
		}},
		{"Parallel-OneThread", func(b *Builder) ([]byte, error) {
			output := new(syncio.BufWriterAt)
			_, err := b.WithThreads(1).Parallel(ctx, output)
			return output.Bytes(), err
		}},
	}
	options := []struct {
		name   string
		action func(b *Builder) *Builder
	}{
		{"NoVerify", func(b *Builder) *Builder {
			return b.WithVerify(false)
		}},
		{"Verify", func(b *Builder) *Builder {
			return b.WithVerify(true)
		}},
	}

	for _, schema := range schemas {
		t.Run(schema.name, func(t *testing.T) {
			for _, test := range tests {
				// Telegram can't redirect web file downloads.
				if schema.name == "Web" && test.migrate {
					continue
				}
				t.Run(test.name, func(t *testing.T) {
					for _, option := range options {
						// Telegram can't return hashes for web files.
						if schema.name == "Web" && option.name == "Verify" {
							continue
						}
						t.Run(option.name, func(t *testing.T) {
							for _, way := range ways {
								t.Run(way.name, func(t *testing.T) {
									a := require.New(t)
									client := &mock{
										data: test.data,
										hashes: mockHashes{
											ranges: countHashes(test.data, 128*1024),
										},
										migrate: test.migrate,
										err:     test.err,
										// Previously dropped, so hash failures were never injected.
										hashesErr: test.hashesErr,
										redirect:  redirect,
									}
									b := schema.creator(client, client)
									b = option.action(b)
									data, err := way.action(b)
									switch {
									case test.migrate, test.err:
										// Injected errors must surface. A redirect must
										// fail too: the schemas build their Builder with
										// a nil CDN client.
										a.Error(err)
									default:
										a.NoError(err)
										a.Equal(test.data, data)
									}
								})
							}
						})
					}
				})
			}
		})
	}
}