7a04f298d2
Update to the latest Telegram layer and remove references to fields in tg.Entities that don't exist in the schema. These fields were originally added in https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e and are referenced in https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 and https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e.
301 lines
7.0 KiB
Go
301 lines
7.0 KiB
Go
package downloader
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"io"
|
|
"runtime"
|
|
"strconv"
|
|
"testing"
|
|
|
|
"github.com/go-faster/errors"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
|
)
|
|
|
|
type mock struct {
|
|
data []byte
|
|
hashes mockHashes
|
|
migrate bool
|
|
err bool
|
|
hashesErr bool
|
|
redirect *tg.UploadFileCDNRedirect
|
|
}
|
|
|
|
// testErr is the error the mock returns when its err or hashesErr switch is set.
var testErr = testutil.TestError()
|
|
|
|
func (m mock) getPart(offset int64, limit int) []byte {
|
|
length := len(m.data)
|
|
if offset >= int64(length) {
|
|
return []byte{}
|
|
}
|
|
|
|
size := length - int(offset)
|
|
if size > limit {
|
|
size = limit
|
|
}
|
|
|
|
r := make([]byte, size)
|
|
copy(r, m.data[offset:])
|
|
return r
|
|
}
|
|
|
|
func (m mock) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) {
|
|
if m.err {
|
|
return nil, testErr
|
|
}
|
|
|
|
if m.migrate {
|
|
return m.redirect, nil
|
|
}
|
|
|
|
return &tg.UploadFile{
|
|
Bytes: m.getPart(request.Offset, request.Limit),
|
|
}, nil
|
|
}
|
|
|
|
func (m mock) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error) {
|
|
if m.hashesErr {
|
|
return nil, testErr
|
|
}
|
|
|
|
return m.hashes.Hashes(ctx, request.Offset)
|
|
}
|
|
|
|
func (m mock) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (m mock) UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error) {
|
|
if m.err {
|
|
return nil, testErr
|
|
}
|
|
|
|
if m.migrate {
|
|
return &tg.UploadCDNFileReuploadNeeded{
|
|
RequestToken: []byte{1, 2, 3},
|
|
}, nil
|
|
}
|
|
|
|
block, err := aes.NewCipher(m.redirect.EncryptionKey)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "CDN mock cipher creation")
|
|
}
|
|
|
|
iv := make([]byte, len(m.redirect.EncryptionIv))
|
|
copy(iv, m.redirect.EncryptionIv)
|
|
binary.BigEndian.PutUint32(iv[len(iv)-4:], uint32(request.Offset/16))
|
|
|
|
part := m.getPart(request.Offset, request.Limit)
|
|
r := make([]byte, len(part))
|
|
cipher.NewCTR(block, iv).XORKeyStream(r, part)
|
|
return &tg.UploadCDNFile{
|
|
Bytes: r,
|
|
}, nil
|
|
}
|
|
|
|
func (m mock) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error) {
|
|
if m.hashesErr {
|
|
return nil, testErr
|
|
}
|
|
|
|
return m.hashes.Hashes(ctx, request.Offset)
|
|
}
|
|
|
|
func (m mock) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) {
|
|
if m.err {
|
|
return nil, testErr
|
|
}
|
|
|
|
return &tg.UploadWebFile{
|
|
Bytes: m.getPart(int64(request.Offset), request.Limit),
|
|
}, nil
|
|
}
|
|
|
|
func countHashes(data []byte, partSize int) (r [][]tg.FileHash) {
|
|
actions := data
|
|
batchSize := partSize
|
|
batches := make([][]byte, 0, (len(actions)+batchSize-1)/batchSize)
|
|
|
|
for batchSize < len(actions) {
|
|
actions, batches = actions[batchSize:], append(batches, actions[0:batchSize:batchSize])
|
|
}
|
|
batches = append(batches, actions)
|
|
|
|
currentRange := make([]tg.FileHash, 0, 10)
|
|
offset := 0
|
|
for _, batch := range batches {
|
|
if len(currentRange) >= 10 {
|
|
r = append(r, currentRange)
|
|
currentRange = make([]tg.FileHash, 0, 10)
|
|
}
|
|
currentRange = append(currentRange, tg.FileHash{
|
|
Offset: int64(offset),
|
|
Limit: partSize,
|
|
Hash: crypto.SHA256(batch),
|
|
})
|
|
offset += len(batch)
|
|
|
|
if len(batch) < partSize {
|
|
break
|
|
}
|
|
}
|
|
r = append(r, currentRange)
|
|
return
|
|
}
|
|
|
|
func Test_countHashes(t *testing.T) {
|
|
a := require.New(t)
|
|
data := bytes.Repeat([]byte{1, 2, 3, 4, 5}, 10)
|
|
hashes := countHashes(data, 4)
|
|
|
|
a.NotEmpty(hashes)
|
|
for _, hashRange := range hashes {
|
|
for _, hash := range hashRange {
|
|
from := hash.Offset
|
|
to := int(hash.Offset) + hash.Limit
|
|
if to > len(data) {
|
|
to = len(data)
|
|
}
|
|
a.Equal(crypto.SHA256(data[from:to]), hash.Hash)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDownloader(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
key := make([]byte, 32)
|
|
iv := make([]byte, aes.BlockSize)
|
|
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
redirect := &tg.UploadFileCDNRedirect{
|
|
DCID: 1,
|
|
FileToken: []byte{10},
|
|
EncryptionKey: key,
|
|
EncryptionIv: iv,
|
|
}
|
|
|
|
testData := make([]byte, defaultPartSize*2)
|
|
if _, err := io.ReadFull(rand.Reader, testData); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
data []byte
|
|
migrate bool
|
|
err bool
|
|
hashesErr bool
|
|
}{
|
|
{"5b", []byte{1, 2, 3, 4, 5}, false, false, false},
|
|
{strconv.Itoa(len(testData)) + "b", testData, false, false, false},
|
|
{"Error", []byte{}, false, true, false},
|
|
{"HashesError", []byte{}, false, true, true},
|
|
{"Migrate", []byte{}, true, false, false},
|
|
}
|
|
schemas := []struct {
|
|
name string
|
|
creator func(c Client, cdn CDN) *Builder
|
|
}{
|
|
{"Master", func(c Client, cdn CDN) *Builder {
|
|
return NewDownloader().Download(c, nil)
|
|
}},
|
|
{"Web", func(c Client, cdn CDN) *Builder {
|
|
return NewDownloader().Web(c, nil)
|
|
}},
|
|
}
|
|
ways := []struct {
|
|
name string
|
|
action func(b *Builder) ([]byte, error)
|
|
}{
|
|
{"Stream", func(b *Builder) ([]byte, error) {
|
|
output := new(bytes.Buffer)
|
|
_, err := b.Stream(ctx, output)
|
|
return output.Bytes(), err
|
|
}},
|
|
{"Parallel", func(b *Builder) ([]byte, error) {
|
|
output := new(syncio.BufWriterAt)
|
|
_, err := b.WithThreads(runtime.GOMAXPROCS(0)).Parallel(ctx, output)
|
|
return output.Bytes(), err
|
|
}},
|
|
{"Parallel-OneThread", func(b *Builder) ([]byte, error) {
|
|
output := new(syncio.BufWriterAt)
|
|
_, err := b.WithThreads(1).Parallel(ctx, output)
|
|
return output.Bytes(), err
|
|
}},
|
|
}
|
|
options := []struct {
|
|
name string
|
|
action func(b *Builder) *Builder
|
|
}{
|
|
{"NoVerify", func(b *Builder) *Builder {
|
|
return b.WithVerify(false)
|
|
}},
|
|
{"Verify", func(b *Builder) *Builder {
|
|
return b.WithVerify(true)
|
|
}},
|
|
}
|
|
|
|
for _, schema := range schemas {
|
|
t.Run(schema.name, func(t *testing.T) {
|
|
for _, test := range tests {
|
|
// Telegram can't redirect web file downloads.
|
|
if schema.name == "Web" && test.migrate {
|
|
continue
|
|
}
|
|
t.Run(test.name, func(t *testing.T) {
|
|
for _, option := range options {
|
|
// Telegram can't return hashes for web files.
|
|
if schema.name == "Web" && option.name == "Verify" {
|
|
continue
|
|
}
|
|
|
|
t.Run(option.name, func(t *testing.T) {
|
|
for _, way := range ways {
|
|
t.Run(way.name, func(t *testing.T) {
|
|
a := require.New(t)
|
|
client := &mock{
|
|
data: test.data,
|
|
hashes: mockHashes{
|
|
ranges: countHashes(test.data, 128*1024),
|
|
},
|
|
migrate: test.migrate,
|
|
err: test.err,
|
|
redirect: redirect,
|
|
}
|
|
|
|
b := schema.creator(client, client)
|
|
b = option.action(b)
|
|
data, err := way.action(b)
|
|
switch {
|
|
case test.migrate:
|
|
a.Error(err)
|
|
case test.err:
|
|
a.Error(err)
|
|
default:
|
|
a.NoError(err)
|
|
a.True(bytes.Equal(test.data, data))
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
}
|