move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,87 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Builder is a download builder.
|
||||
type Builder struct {
|
||||
downloader *Downloader
|
||||
|
||||
schema schema
|
||||
hashes []tg.FileHash
|
||||
verify bool
|
||||
threads int
|
||||
}
|
||||
|
||||
func newBuilder(downloader *Downloader, schema schema) *Builder {
|
||||
return &Builder{
|
||||
schema: schema,
|
||||
threads: 1,
|
||||
downloader: downloader,
|
||||
}
|
||||
}
|
||||
|
||||
// WithThreads sets downloading goroutines limit.
|
||||
func (b *Builder) WithThreads(threads int) *Builder {
|
||||
if threads > 0 {
|
||||
b.threads = threads
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithVerify sets verify parameter.
|
||||
// If verify is true, file hashes will be checked
|
||||
// Verify is true by default for CDN downloads.
|
||||
func (b *Builder) WithVerify(verify bool) *Builder {
|
||||
b.verify = verify
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) reader() *reader {
|
||||
if b.verify {
|
||||
return verifiedReader(b.schema, newVerifier(b.schema, b.hashes...))
|
||||
}
|
||||
|
||||
return plainReader(b.schema, b.downloader.partSize)
|
||||
}
|
||||
|
||||
// Stream downloads file to given io.Writer.
|
||||
// NB: in this mode download can't be parallel.
|
||||
func (b *Builder) Stream(ctx context.Context, output io.Writer) (tg.StorageFileTypeClass, error) {
|
||||
return b.downloader.stream(ctx, b.reader(), output)
|
||||
}
|
||||
|
||||
// StreamToReader streams a file to the returned [io.Reader].
|
||||
// NB: in this mode download can't be parallel.
|
||||
func (b *Builder) StreamToReader(ctx context.Context) (tg.StorageFileTypeClass, io.Reader, error) {
|
||||
var tgDC int
|
||||
ctx = context.WithValue(ctx, "tg_dc", &tgDC)
|
||||
return b.downloader.streamToReader(ctx, b.reader())
|
||||
}
|
||||
|
||||
// Parallel downloads file to given io.WriterAt.
|
||||
func (b *Builder) Parallel(ctx context.Context, output io.WriterAt) (tg.StorageFileTypeClass, error) {
|
||||
return b.downloader.parallel(ctx, b.reader(), b.threads, output)
|
||||
}
|
||||
|
||||
// ToPath downloads file to given path.
|
||||
func (b *Builder) ToPath(ctx context.Context, path string) (_ tg.StorageFileTypeClass, err error) {
|
||||
f, err := os.Create(filepath.Clean(path))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create output file")
|
||||
}
|
||||
defer func() {
|
||||
multierr.AppendInto(&err, f.Close())
|
||||
}()
|
||||
|
||||
return b.Parallel(ctx, f)
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// ExpiredTokenError error is returned when Downloader get expired file token for CDN.
|
||||
// See https://core.telegram.org/constructor/upload.fileCdnRedirect.
|
||||
type ExpiredTokenError struct {
|
||||
*tg.UploadCDNFileReuploadNeeded
|
||||
}
|
||||
|
||||
// Error implements error interface.
|
||||
func (r *ExpiredTokenError) Error() string {
|
||||
return "redirect to master DC for requesting new file token"
|
||||
}
|
||||
|
||||
// cdn is a CDN DC download schema.
|
||||
// See https://core.telegram.org/cdn#getting-files-from-a-cdn.
|
||||
type cdn struct {
|
||||
cdn CDN
|
||||
client Client
|
||||
pool *bin.Pool
|
||||
redirect *tg.UploadFileCDNRedirect
|
||||
}
|
||||
|
||||
var _ schema = cdn{}
|
||||
|
||||
// decrypt decrypts file chunk from Telegram CDN.
|
||||
// See https://core.telegram.org/cdn#decrypting-files.
|
||||
func (c cdn) decrypt(src []byte, offset int64) ([]byte, error) {
|
||||
block, err := aes.NewCipher(c.redirect.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create cipher")
|
||||
}
|
||||
|
||||
if block.BlockSize() != len(c.redirect.EncryptionIv) {
|
||||
return nil, errors.Errorf(
|
||||
"invalid IV or key length, block size %d != IV %d",
|
||||
block.BlockSize(), len(c.redirect.EncryptionIv),
|
||||
)
|
||||
}
|
||||
|
||||
// Copy IV to buffer from Pool.
|
||||
iv := c.pool.GetSize(len(c.redirect.EncryptionIv))
|
||||
defer c.pool.Put(iv)
|
||||
copy(iv.Buf, c.redirect.EncryptionIv)
|
||||
|
||||
// For IV, it should use the value of encryption_iv, modified in the following manner:
|
||||
// for each offset replace the last 4 bytes of the encryption_iv with offset / 16 in big-endian.
|
||||
binary.BigEndian.PutUint32(iv.Buf[iv.Len()-4:], uint32(offset/16))
|
||||
|
||||
dst := make([]byte, len(src))
|
||||
cipher.NewCTR(block, iv.Buf).XORKeyStream(dst, src)
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
func (c cdn) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
|
||||
r, err := c.cdn.UploadGetCDNFile(ctx, &tg.UploadGetCDNFileRequest{
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
FileToken: c.redirect.FileToken,
|
||||
})
|
||||
if err != nil {
|
||||
return chunk{}, err
|
||||
}
|
||||
|
||||
switch result := r.(type) {
|
||||
case *tg.UploadCDNFile:
|
||||
data, err := c.decrypt(result.Bytes, offset)
|
||||
if err != nil {
|
||||
return chunk{}, err
|
||||
}
|
||||
|
||||
return chunk{
|
||||
data: data,
|
||||
}, nil
|
||||
case *tg.UploadCDNFileReuploadNeeded:
|
||||
return chunk{}, &ExpiredTokenError{UploadCDNFileReuploadNeeded: result}
|
||||
default:
|
||||
return chunk{}, errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
func (c cdn) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
|
||||
return c.client.UploadGetCDNFileHashes(ctx, &tg.UploadGetCDNFileHashesRequest{
|
||||
FileToken: c.redirect.FileToken,
|
||||
Offset: offset,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func Test_cdn_decrypt(t *testing.T) {
|
||||
testdata := make([]byte, 32)
|
||||
tests := []struct {
|
||||
name string
|
||||
key, iv []byte
|
||||
err bool
|
||||
}{
|
||||
{"Bad key", []byte{10}, nil, true},
|
||||
{"Bad IV", make([]byte, 32), nil, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
c := &cdn{
|
||||
redirect: &tg.UploadFileCDNRedirect{
|
||||
EncryptionKey: test.key,
|
||||
EncryptionIv: test.iv,
|
||||
},
|
||||
}
|
||||
_, err := c.decrypt(testdata, 0)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// CDN represents Telegram RPC client to CDN server.
|
||||
type CDN interface {
|
||||
UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error)
|
||||
}
|
||||
|
||||
// Client represents Telegram RPC client.
|
||||
type Client interface {
|
||||
UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error)
|
||||
UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error)
|
||||
|
||||
UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error)
|
||||
UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error)
|
||||
|
||||
UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error)
|
||||
}
|
||||
|
||||
type chunk struct {
|
||||
data []byte
|
||||
tag tg.StorageFileTypeClass
|
||||
}
|
||||
|
||||
// schema is simple interface for different download schemas.
|
||||
type schema interface {
|
||||
Chunk(ctx context.Context, offset int64, limit int) (chunk, error)
|
||||
Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error)
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// Package downloader contains downloading files helpers.
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Downloader is Telegram file downloader.
|
||||
type Downloader struct {
|
||||
partSize int
|
||||
pool *bin.Pool
|
||||
}
|
||||
|
||||
const defaultPartSize = 512 * 1024 // 512 kb
|
||||
|
||||
// NewDownloader creates new Downloader.
|
||||
func NewDownloader() *Downloader {
|
||||
return new(Downloader).WithPartSize(defaultPartSize)
|
||||
}
|
||||
|
||||
// WithPartSize sets chunk size.
|
||||
// Must be divisible by 4KB.
|
||||
//
|
||||
// See https://core.telegram.org/api/files#downloading-files.
|
||||
func (d *Downloader) WithPartSize(partSize int) *Downloader {
|
||||
d.partSize = partSize
|
||||
d.pool = bin.NewPool(partSize)
|
||||
return d
|
||||
}
|
||||
|
||||
// Download creates Builder for plain downloads.
|
||||
func (d *Downloader) Download(rpc Client, location tg.InputFileLocationClass) *Builder {
|
||||
return newBuilder(d, master{
|
||||
client: rpc,
|
||||
precise: true,
|
||||
allowCDN: false,
|
||||
location: location,
|
||||
})
|
||||
}
|
||||
|
||||
// Web creates Builder for web files downloads.
|
||||
func (d *Downloader) Web(rpc Client, location tg.InputWebFileLocationClass) *Builder {
|
||||
return newBuilder(d, web{
|
||||
client: rpc,
|
||||
location: location,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type mock struct {
|
||||
data []byte
|
||||
hashes mockHashes
|
||||
migrate bool
|
||||
err bool
|
||||
hashesErr bool
|
||||
redirect *tg.UploadFileCDNRedirect
|
||||
}
|
||||
|
||||
var testErr = testutil.TestError()
|
||||
|
||||
func (m mock) getPart(offset int64, limit int) []byte {
|
||||
length := len(m.data)
|
||||
if offset >= int64(length) {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
size := length - int(offset)
|
||||
if size > limit {
|
||||
size = limit
|
||||
}
|
||||
|
||||
r := make([]byte, size)
|
||||
copy(r, m.data[offset:])
|
||||
return r
|
||||
}
|
||||
|
||||
func (m mock) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) {
|
||||
if m.err {
|
||||
return nil, testErr
|
||||
}
|
||||
|
||||
if m.migrate {
|
||||
return m.redirect, nil
|
||||
}
|
||||
|
||||
return &tg.UploadFile{
|
||||
Bytes: m.getPart(request.Offset, request.Limit),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m mock) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error) {
|
||||
if m.hashesErr {
|
||||
return nil, testErr
|
||||
}
|
||||
|
||||
return m.hashes.Hashes(ctx, request.Offset)
|
||||
}
|
||||
|
||||
func (m mock) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mock) UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error) {
|
||||
if m.err {
|
||||
return nil, testErr
|
||||
}
|
||||
|
||||
if m.migrate {
|
||||
return &tg.UploadCDNFileReuploadNeeded{
|
||||
RequestToken: []byte{1, 2, 3},
|
||||
}, nil
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(m.redirect.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CDN mock cipher creation")
|
||||
}
|
||||
|
||||
iv := make([]byte, len(m.redirect.EncryptionIv))
|
||||
copy(iv, m.redirect.EncryptionIv)
|
||||
binary.BigEndian.PutUint32(iv[len(iv)-4:], uint32(request.Offset/16))
|
||||
|
||||
part := m.getPart(request.Offset, request.Limit)
|
||||
r := make([]byte, len(part))
|
||||
cipher.NewCTR(block, iv).XORKeyStream(r, part)
|
||||
return &tg.UploadCDNFile{
|
||||
Bytes: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m mock) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error) {
|
||||
if m.hashesErr {
|
||||
return nil, testErr
|
||||
}
|
||||
|
||||
return m.hashes.Hashes(ctx, request.Offset)
|
||||
}
|
||||
|
||||
func (m mock) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) {
|
||||
if m.err {
|
||||
return nil, testErr
|
||||
}
|
||||
|
||||
return &tg.UploadWebFile{
|
||||
Bytes: m.getPart(int64(request.Offset), request.Limit),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func countHashes(data []byte, partSize int) (r [][]tg.FileHash) {
|
||||
actions := data
|
||||
batchSize := partSize
|
||||
batches := make([][]byte, 0, (len(actions)+batchSize-1)/batchSize)
|
||||
|
||||
for batchSize < len(actions) {
|
||||
actions, batches = actions[batchSize:], append(batches, actions[0:batchSize:batchSize])
|
||||
}
|
||||
batches = append(batches, actions)
|
||||
|
||||
currentRange := make([]tg.FileHash, 0, 10)
|
||||
offset := 0
|
||||
for _, batch := range batches {
|
||||
if len(currentRange) >= 10 {
|
||||
r = append(r, currentRange)
|
||||
currentRange = make([]tg.FileHash, 0, 10)
|
||||
}
|
||||
currentRange = append(currentRange, tg.FileHash{
|
||||
Offset: int64(offset),
|
||||
Limit: partSize,
|
||||
Hash: crypto.SHA256(batch),
|
||||
})
|
||||
offset += len(batch)
|
||||
|
||||
if len(batch) < partSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
r = append(r, currentRange)
|
||||
return
|
||||
}
|
||||
|
||||
func Test_countHashes(t *testing.T) {
|
||||
a := require.New(t)
|
||||
data := bytes.Repeat([]byte{1, 2, 3, 4, 5}, 10)
|
||||
hashes := countHashes(data, 4)
|
||||
|
||||
a.NotEmpty(hashes)
|
||||
for _, hashRange := range hashes {
|
||||
for _, hash := range hashRange {
|
||||
from := hash.Offset
|
||||
to := int(hash.Offset) + hash.Limit
|
||||
if to > len(data) {
|
||||
to = len(data)
|
||||
}
|
||||
a.Equal(crypto.SHA256(data[from:to]), hash.Hash)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloader(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
key := make([]byte, 32)
|
||||
iv := make([]byte, aes.BlockSize)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
redirect := &tg.UploadFileCDNRedirect{
|
||||
DCID: 1,
|
||||
FileToken: []byte{10},
|
||||
EncryptionKey: key,
|
||||
EncryptionIv: iv,
|
||||
}
|
||||
|
||||
testData := make([]byte, defaultPartSize*2)
|
||||
if _, err := io.ReadFull(rand.Reader, testData); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
migrate bool
|
||||
err bool
|
||||
hashesErr bool
|
||||
}{
|
||||
{"5b", []byte{1, 2, 3, 4, 5}, false, false, false},
|
||||
{strconv.Itoa(len(testData)) + "b", testData, false, false, false},
|
||||
{"Error", []byte{}, false, true, false},
|
||||
{"HashesError", []byte{}, false, true, true},
|
||||
{"Migrate", []byte{}, true, false, false},
|
||||
}
|
||||
schemas := []struct {
|
||||
name string
|
||||
creator func(c Client, cdn CDN) *Builder
|
||||
}{
|
||||
{"Master", func(c Client, cdn CDN) *Builder {
|
||||
return NewDownloader().Download(c, nil)
|
||||
}},
|
||||
{"Web", func(c Client, cdn CDN) *Builder {
|
||||
return NewDownloader().Web(c, nil)
|
||||
}},
|
||||
}
|
||||
ways := []struct {
|
||||
name string
|
||||
action func(b *Builder) ([]byte, error)
|
||||
}{
|
||||
{"Stream", func(b *Builder) ([]byte, error) {
|
||||
output := new(bytes.Buffer)
|
||||
_, err := b.Stream(ctx, output)
|
||||
return output.Bytes(), err
|
||||
}},
|
||||
{"Parallel", func(b *Builder) ([]byte, error) {
|
||||
output := new(syncio.BufWriterAt)
|
||||
_, err := b.WithThreads(runtime.GOMAXPROCS(0)).Parallel(ctx, output)
|
||||
return output.Bytes(), err
|
||||
}},
|
||||
{"Parallel-OneThread", func(b *Builder) ([]byte, error) {
|
||||
output := new(syncio.BufWriterAt)
|
||||
_, err := b.WithThreads(1).Parallel(ctx, output)
|
||||
return output.Bytes(), err
|
||||
}},
|
||||
}
|
||||
options := []struct {
|
||||
name string
|
||||
action func(b *Builder) *Builder
|
||||
}{
|
||||
{"NoVerify", func(b *Builder) *Builder {
|
||||
return b.WithVerify(false)
|
||||
}},
|
||||
{"Verify", func(b *Builder) *Builder {
|
||||
return b.WithVerify(true)
|
||||
}},
|
||||
}
|
||||
|
||||
for _, schema := range schemas {
|
||||
t.Run(schema.name, func(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
// Telegram can't redirect web file downloads.
|
||||
if schema.name == "Web" && test.migrate {
|
||||
continue
|
||||
}
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
for _, option := range options {
|
||||
// Telegram can't return hashes for web files.
|
||||
if schema.name == "Web" && option.name == "Verify" {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Run(option.name, func(t *testing.T) {
|
||||
for _, way := range ways {
|
||||
t.Run(way.name, func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
client := &mock{
|
||||
data: test.data,
|
||||
hashes: mockHashes{
|
||||
ranges: countHashes(test.data, 128*1024),
|
||||
},
|
||||
migrate: test.migrate,
|
||||
err: test.err,
|
||||
redirect: redirect,
|
||||
}
|
||||
|
||||
b := schema.creator(client, client)
|
||||
b = option.action(b)
|
||||
data, err := way.action(b)
|
||||
switch {
|
||||
case test.migrate:
|
||||
a.Error(err)
|
||||
case test.err:
|
||||
a.Error(err)
|
||||
default:
|
||||
a.NoError(err)
|
||||
a.True(bytes.Equal(test.data, data))
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// RedirectError error is returned when Downloader get CDN redirect.
|
||||
// See https://core.telegram.org/constructor/upload.fileCdnRedirect.
|
||||
type RedirectError struct {
|
||||
Redirect *tg.UploadFileCDNRedirect
|
||||
}
|
||||
|
||||
// Error implements error interface.
|
||||
func (r *RedirectError) Error() string {
|
||||
return "redirect to CDN DC " + strconv.Itoa(r.Redirect.DCID)
|
||||
}
|
||||
|
||||
// master is a master DC download schema.
|
||||
// See https://core.telegram.org/api/files#downloading-files.
|
||||
type master struct {
|
||||
client Client
|
||||
|
||||
precise bool
|
||||
allowCDN bool
|
||||
location tg.InputFileLocationClass
|
||||
}
|
||||
|
||||
var _ schema = master{}
|
||||
|
||||
func (c master) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
|
||||
req := &tg.UploadGetFileRequest{
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
Location: c.location,
|
||||
}
|
||||
req.SetCDNSupported(c.allowCDN)
|
||||
req.SetPrecise(c.precise)
|
||||
|
||||
r, err := c.client.UploadGetFile(ctx, req)
|
||||
if err != nil {
|
||||
return chunk{}, err
|
||||
}
|
||||
|
||||
switch result := r.(type) {
|
||||
case *tg.UploadFile:
|
||||
return chunk{data: result.Bytes, tag: result.Type}, nil
|
||||
case *tg.UploadFileCDNRedirect:
|
||||
return chunk{}, &RedirectError{Redirect: result}
|
||||
default:
|
||||
return chunk{}, errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
func (c master) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
|
||||
return c.client.UploadGetFileHashes(ctx, &tg.UploadGetFileHashesRequest{
|
||||
Location: c.location,
|
||||
Offset: offset,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// nolint:gocognit
|
||||
func (d *Downloader) parallel(
|
||||
ctx context.Context, r *reader,
|
||||
threads int, w io.WriterAt,
|
||||
) (tg.StorageFileTypeClass, error) {
|
||||
var typ tg.StorageFileTypeClass
|
||||
typOnce := &sync.Once{}
|
||||
|
||||
ready := tdsync.NewReady()
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
toWrite := make(chan block, threads)
|
||||
|
||||
stop := func(t tg.StorageFileTypeClass) {
|
||||
typOnce.Do(func() {
|
||||
typ = t
|
||||
})
|
||||
ready.Signal()
|
||||
}
|
||||
|
||||
// Download loop
|
||||
g.Go(func(ctx context.Context) error {
|
||||
downloads := tdsync.NewCancellableGroup(ctx)
|
||||
defer close(toWrite)
|
||||
|
||||
for i := 0; i < threads; i++ {
|
||||
downloads.Go(func(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ready.Ready():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
b, err := r.Next(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get file")
|
||||
}
|
||||
|
||||
// If returned chunk is zero, that means we read all file.
|
||||
n := len(b.data)
|
||||
if n < 1 {
|
||||
stop(b.tag)
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case toWrite <- b:
|
||||
}
|
||||
|
||||
if b.last() {
|
||||
stop(b.tag)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return downloads.Wait()
|
||||
})
|
||||
|
||||
// Write loop
|
||||
g.Go(writeAtLoop(syncio.NewWriterAt(w), toWrite))
|
||||
|
||||
return typ, g.Wait()
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
type block struct {
|
||||
chunk
|
||||
offset int64
|
||||
partSize int
|
||||
}
|
||||
|
||||
// last compares partSize and chunk length to determine last part.
|
||||
func (b block) last() bool {
|
||||
// If returned chunk is smaller than requested part, it seems
|
||||
// it is last part.
|
||||
return len(b.data) < b.partSize
|
||||
}
|
||||
|
||||
type reader struct {
|
||||
sch schema // immutable
|
||||
verifier *verifier // immutable
|
||||
partSize int // immutable
|
||||
|
||||
offset int64
|
||||
offsetMux sync.Mutex
|
||||
}
|
||||
|
||||
func verifiedReader(sch schema, verifier *verifier) *reader {
|
||||
return &reader{
|
||||
sch: sch,
|
||||
verifier: verifier,
|
||||
}
|
||||
}
|
||||
|
||||
func plainReader(sch schema, partSize int) *reader {
|
||||
return &reader{
|
||||
sch: sch,
|
||||
partSize: partSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *reader) Next(ctx context.Context) (block, error) {
|
||||
if r.verifier != nil {
|
||||
return r.nextHashed(ctx)
|
||||
}
|
||||
|
||||
return r.nextPlain(ctx)
|
||||
}
|
||||
|
||||
func (r *reader) nextHashed(ctx context.Context) (block, error) {
|
||||
// Fetch next hashes.
|
||||
hash, ok, err := r.verifier.next(ctx)
|
||||
if err != nil {
|
||||
return block{}, err
|
||||
}
|
||||
if !ok {
|
||||
return block{}, nil
|
||||
}
|
||||
|
||||
// Get next chunk.
|
||||
b, err := r.next(ctx, hash.Offset, hash.Limit)
|
||||
if err != nil {
|
||||
return block{}, err
|
||||
}
|
||||
|
||||
// Verify chunk.
|
||||
if !r.verifier.verify(hash, b.data) {
|
||||
return block{}, ErrHashMismatch
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (r *reader) nextPlain(ctx context.Context) (block, error) {
|
||||
r.offsetMux.Lock()
|
||||
offset := r.offset
|
||||
r.offset += int64(r.partSize)
|
||||
r.offsetMux.Unlock()
|
||||
|
||||
return r.next(ctx, offset, r.partSize)
|
||||
}
|
||||
|
||||
func (r *reader) next(ctx context.Context, offset int64, limit int) (block, error) {
|
||||
for {
|
||||
ch, err := r.sch.Chunk(ctx, offset, limit)
|
||||
|
||||
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
|
||||
if flood || tgerr.Is(err, tg.ErrTimeout) {
|
||||
continue
|
||||
}
|
||||
return block{}, errors.Wrap(err, "get next chunk")
|
||||
}
|
||||
|
||||
return block{
|
||||
chunk: ch,
|
||||
offset: offset,
|
||||
partSize: r.partSize,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
)
|
||||
|
||||
func writeAtLoop(w io.WriterAt, toWrite <-chan block) func(context.Context) error {
|
||||
return func(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case part, ok := <-toWrite:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := w.WriteAt(part.data, part.offset)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "write output")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeLoop(w io.Writer, toWrite <-chan block) func(context.Context) error {
|
||||
return func(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case part, ok := <-toWrite:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := w.Write(part.data)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "write output")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func (d *Downloader) stream(ctx context.Context, r *reader, w io.Writer) (tg.StorageFileTypeClass, error) {
|
||||
var typ tg.StorageFileTypeClass
|
||||
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
toWrite := make(chan block, 1)
|
||||
|
||||
stop := func(t tg.StorageFileTypeClass) {
|
||||
typ = t
|
||||
close(toWrite)
|
||||
}
|
||||
// Download loop
|
||||
g.Go(func(ctx context.Context) error {
|
||||
for {
|
||||
b, err := r.Next(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get file")
|
||||
}
|
||||
|
||||
n := len(b.data)
|
||||
if n < 1 {
|
||||
stop(b.tag)
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case toWrite <- b:
|
||||
}
|
||||
|
||||
if b.last() {
|
||||
stop(b.tag)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Write loop
|
||||
g.Go(writeLoop(w, toWrite))
|
||||
|
||||
return typ, g.Wait()
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type streamReader struct {
|
||||
ctx context.Context
|
||||
reader *reader
|
||||
curBlock block
|
||||
last bool
|
||||
}
|
||||
|
||||
var _ io.Reader = (*streamReader)(nil)
|
||||
|
||||
func (s *streamReader) Read(p []byte) (n int, err error) {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return 0, s.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if len(s.curBlock.data) == 0 {
|
||||
if s.last {
|
||||
return 0, io.EOF
|
||||
} else {
|
||||
s.curBlock, err = s.reader.Next(s.ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
s.last = s.curBlock.last()
|
||||
}
|
||||
}
|
||||
|
||||
n = copy(p, s.curBlock.data)
|
||||
s.curBlock.data = s.curBlock.data[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Downloader) streamToReader(ctx context.Context, r *reader) (tg.StorageFileTypeClass, io.Reader, error) {
|
||||
first, err := r.Next(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return first.tag, &streamReader{ctx, r, first, first.last()}, nil
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
// ErrHashMismatch means that download hash verification was failed.
|
||||
var ErrHashMismatch = errors.New("file hash mismatch")
|
||||
|
||||
type verifier struct {
|
||||
client schema
|
||||
|
||||
hashes []tg.FileHash
|
||||
offset int64
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newVerifier(client schema, hashes ...tg.FileHash) *verifier {
|
||||
r := make([]tg.FileHash, len(hashes))
|
||||
|
||||
copy(r, hashes)
|
||||
sort.SliceStable(r, func(i, j int) bool {
|
||||
return r[i].Offset < r[j].Offset
|
||||
})
|
||||
|
||||
return &verifier{client: client, hashes: r}
|
||||
}
|
||||
|
||||
func (v *verifier) pop() (tg.FileHash, bool) {
|
||||
if len(v.hashes) < 1 {
|
||||
return tg.FileHash{}, false
|
||||
}
|
||||
|
||||
// Pop and move.
|
||||
hash := v.hashes[0]
|
||||
copy(v.hashes, v.hashes[1:])
|
||||
v.hashes[len(v.hashes)-1] = tg.FileHash{}
|
||||
v.hashes = v.hashes[:len(v.hashes)-1]
|
||||
|
||||
return hash, true
|
||||
}
|
||||
|
||||
func (v *verifier) update(hashes ...tg.FileHash) (tg.FileHash, bool) {
|
||||
// If result is empty and queue is empty, so we can't return next hash.
|
||||
if len(hashes) < 1 {
|
||||
return tg.FileHash{}, false
|
||||
}
|
||||
|
||||
// Sort hashes by offset.
|
||||
// Usually Telegram server returns sorted parts, but...
|
||||
// you never known what can they do.
|
||||
sort.SliceStable(hashes, func(i, j int) bool {
|
||||
return hashes[i].Offset < hashes[j].Offset
|
||||
})
|
||||
|
||||
last := hashes[len(hashes)-1]
|
||||
// Check if we have reached the end.
|
||||
// If current state offset is equal the last offset + limit (right border)
|
||||
// then we got all hashes.
|
||||
if last.Offset == v.offset-int64(last.Limit) {
|
||||
return tg.FileHash{}, false
|
||||
}
|
||||
|
||||
// Otherwise, we update current offset and add hashes to the end of queue.
|
||||
v.offset = last.Offset + int64(last.Limit)
|
||||
v.hashes = append(v.hashes, hashes...)
|
||||
return v.pop()
|
||||
}
|
||||
|
||||
func (v *verifier) next(ctx context.Context) (tg.FileHash, bool, error) {
|
||||
v.mux.Lock()
|
||||
defer v.mux.Unlock()
|
||||
|
||||
hash, ok := v.pop()
|
||||
if ok {
|
||||
return hash, ok, nil
|
||||
}
|
||||
|
||||
for {
|
||||
hashes, err := v.client.Hashes(ctx, v.offset)
|
||||
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
|
||||
if flood || tgerr.Is(err, tg.ErrTimeout) {
|
||||
continue
|
||||
}
|
||||
return tg.FileHash{}, false, errors.Wrap(err, "get hashes")
|
||||
}
|
||||
|
||||
hash, ok = v.update(hashes...)
|
||||
return hash, ok, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (v *verifier) verify(hash tg.FileHash, data []byte) bool {
|
||||
return bytes.Equal(crypto.SHA256(data), hash.Hash)
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
var hashRanges = [][]tg.FileHash{
|
||||
{
|
||||
tg.FileHash{Offset: 0, Limit: 131072},
|
||||
tg.FileHash{Offset: 131072, Limit: 131072},
|
||||
tg.FileHash{Offset: 262144, Limit: 131072},
|
||||
tg.FileHash{Offset: 393216, Limit: 131072},
|
||||
tg.FileHash{Offset: 524288, Limit: 131072},
|
||||
tg.FileHash{Offset: 655360, Limit: 131072},
|
||||
tg.FileHash{Offset: 786432, Limit: 131072},
|
||||
tg.FileHash{Offset: 917504, Limit: 131072},
|
||||
}, {
|
||||
tg.FileHash{Offset: 1048576, Limit: 131072},
|
||||
tg.FileHash{Offset: 1179648, Limit: 131072},
|
||||
tg.FileHash{Offset: 1310720, Limit: 131072},
|
||||
tg.FileHash{Offset: 1441792, Limit: 131072},
|
||||
tg.FileHash{Offset: 1572864, Limit: 131072},
|
||||
tg.FileHash{Offset: 1703936, Limit: 131072},
|
||||
tg.FileHash{Offset: 1835008, Limit: 131072},
|
||||
tg.FileHash{Offset: 1966080, Limit: 131072},
|
||||
}, {
|
||||
tg.FileHash{Offset: 2097152, Limit: 131072},
|
||||
tg.FileHash{Offset: 2228224, Limit: 131072},
|
||||
tg.FileHash{Offset: 2359296, Limit: 131072},
|
||||
tg.FileHash{Offset: 2490368, Limit: 131072},
|
||||
tg.FileHash{Offset: 2621440, Limit: 131072},
|
||||
tg.FileHash{Offset: 2752512, Limit: 131072},
|
||||
tg.FileHash{Offset: 2883584, Limit: 131072},
|
||||
tg.FileHash{Offset: 3014656, Limit: 131072},
|
||||
},
|
||||
}
|
||||
|
||||
type mockHashes struct {
|
||||
ranges [][]tg.FileHash
|
||||
}
|
||||
|
||||
func (m mockHashes) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockHashes) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
|
||||
for _, r := range m.ranges {
|
||||
last := r[len(r)-1]
|
||||
if last.Offset+int64(last.Limit) <= offset {
|
||||
continue
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
return m.ranges[len(m.ranges)-1], nil
|
||||
}
|
||||
|
||||
func TestVerifier(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ranges [][]tg.FileHash
|
||||
// Hashes returned from CDN redirect, for example.
|
||||
predefined []tg.FileHash
|
||||
expected [][]tg.FileHash
|
||||
}{
|
||||
{"NoPredefined", hashRanges, nil, hashRanges},
|
||||
{"Predefined", hashRanges[1:], hashRanges[0], hashRanges},
|
||||
{"OnlyPredefined", hashRanges[:1], hashRanges[0], hashRanges[:1]},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
m := mockHashes{ranges: test.ranges}
|
||||
v := newVerifier(m, test.predefined...)
|
||||
|
||||
hashes := make([]tg.FileHash, 0, len(test.predefined))
|
||||
for {
|
||||
hash, ok, err := v.next(ctx)
|
||||
a.NoError(err)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
hashes = append(hashes, hash)
|
||||
}
|
||||
|
||||
i := 0
|
||||
for _, hashRange := range test.expected {
|
||||
for _, expected := range hashRange {
|
||||
a.Equal(expected, hashes[i])
|
||||
i++
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// errHashesNotSupported is returned from Hashes by schemas that cannot fetch
// file hashes, such as the web schema.
var errHashesNotSupported = errors.New("this schema does not support hashes fetch")
|
||||
|
||||
// web is a web file download schema.
|
||||
// See https://core.telegram.org/api/files#downloading-webfiles.
|
||||
type web struct {
|
||||
client Client
|
||||
|
||||
location tg.InputWebFileLocationClass
|
||||
}
|
||||
|
||||
// Compile-time check that web satisfies the schema interface.
var _ schema = web{}
|
||||
|
||||
func (w web) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
|
||||
file, err := w.client.UploadGetWebFile(ctx, &tg.UploadGetWebFileRequest{
|
||||
Location: w.location,
|
||||
Offset: int(offset),
|
||||
Limit: limit,
|
||||
})
|
||||
if err != nil {
|
||||
return chunk{}, err
|
||||
}
|
||||
|
||||
return chunk{data: file.Bytes, tag: file.FileType}, nil
|
||||
}
|
||||
|
||||
func (w web) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
|
||||
return nil, errHashesNotSupported
|
||||
}
|
||||
Reference in New Issue
Block a user