move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
+87
View File
@@ -0,0 +1,87 @@
package downloader
import (
"context"
"io"
"os"
"path/filepath"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Builder is a download builder.
type Builder struct {
downloader *Downloader
schema schema
hashes []tg.FileHash
verify bool
threads int
}
func newBuilder(downloader *Downloader, schema schema) *Builder {
return &Builder{
schema: schema,
threads: 1,
downloader: downloader,
}
}
// WithThreads sets downloading goroutines limit.
func (b *Builder) WithThreads(threads int) *Builder {
if threads > 0 {
b.threads = threads
}
return b
}
// WithVerify sets verify parameter.
// If verify is true, file hashes will be checked
// Verify is true by default for CDN downloads.
func (b *Builder) WithVerify(verify bool) *Builder {
b.verify = verify
return b
}
func (b *Builder) reader() *reader {
if b.verify {
return verifiedReader(b.schema, newVerifier(b.schema, b.hashes...))
}
return plainReader(b.schema, b.downloader.partSize)
}
// Stream downloads file to given io.Writer.
// NB: in this mode download can't be parallel.
func (b *Builder) Stream(ctx context.Context, output io.Writer) (tg.StorageFileTypeClass, error) {
return b.downloader.stream(ctx, b.reader(), output)
}
// StreamToReader streams a file to the returned [io.Reader].
// NB: in this mode download can't be parallel.
func (b *Builder) StreamToReader(ctx context.Context) (tg.StorageFileTypeClass, io.Reader, error) {
var tgDC int
ctx = context.WithValue(ctx, "tg_dc", &tgDC)
return b.downloader.streamToReader(ctx, b.reader())
}
// Parallel downloads file to given io.WriterAt.
func (b *Builder) Parallel(ctx context.Context, output io.WriterAt) (tg.StorageFileTypeClass, error) {
return b.downloader.parallel(ctx, b.reader(), b.threads, output)
}
// ToPath downloads file to given path.
func (b *Builder) ToPath(ctx context.Context, path string) (_ tg.StorageFileTypeClass, err error) {
f, err := os.Create(filepath.Clean(path))
if err != nil {
return nil, errors.Wrap(err, "create output file")
}
defer func() {
multierr.AppendInto(&err, f.Close())
}()
return b.Parallel(ctx, f)
}
+98
View File
@@ -0,0 +1,98 @@
package downloader
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// ExpiredTokenError error is returned when Downloader get expired file token for CDN.
// See https://core.telegram.org/constructor/upload.fileCdnRedirect.
type ExpiredTokenError struct {
*tg.UploadCDNFileReuploadNeeded
}
// Error implements error interface.
func (r *ExpiredTokenError) Error() string {
return "redirect to master DC for requesting new file token"
}
// cdn is a CDN DC download schema.
// See https://core.telegram.org/cdn#getting-files-from-a-cdn.
type cdn struct {
cdn CDN
client Client
pool *bin.Pool
redirect *tg.UploadFileCDNRedirect
}
var _ schema = cdn{}
// decrypt decrypts file chunk from Telegram CDN.
// See https://core.telegram.org/cdn#decrypting-files.
func (c cdn) decrypt(src []byte, offset int64) ([]byte, error) {
block, err := aes.NewCipher(c.redirect.EncryptionKey)
if err != nil {
return nil, errors.Wrap(err, "create cipher")
}
if block.BlockSize() != len(c.redirect.EncryptionIv) {
return nil, errors.Errorf(
"invalid IV or key length, block size %d != IV %d",
block.BlockSize(), len(c.redirect.EncryptionIv),
)
}
// Copy IV to buffer from Pool.
iv := c.pool.GetSize(len(c.redirect.EncryptionIv))
defer c.pool.Put(iv)
copy(iv.Buf, c.redirect.EncryptionIv)
// For IV, it should use the value of encryption_iv, modified in the following manner:
// for each offset replace the last 4 bytes of the encryption_iv with offset / 16 in big-endian.
binary.BigEndian.PutUint32(iv.Buf[iv.Len()-4:], uint32(offset/16))
dst := make([]byte, len(src))
cipher.NewCTR(block, iv.Buf).XORKeyStream(dst, src)
return dst, nil
}
func (c cdn) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
r, err := c.cdn.UploadGetCDNFile(ctx, &tg.UploadGetCDNFileRequest{
Offset: offset,
Limit: limit,
FileToken: c.redirect.FileToken,
})
if err != nil {
return chunk{}, err
}
switch result := r.(type) {
case *tg.UploadCDNFile:
data, err := c.decrypt(result.Bytes, offset)
if err != nil {
return chunk{}, err
}
return chunk{
data: data,
}, nil
case *tg.UploadCDNFileReuploadNeeded:
return chunk{}, &ExpiredTokenError{UploadCDNFileReuploadNeeded: result}
default:
return chunk{}, errors.Errorf("unexpected type %T", r)
}
}
func (c cdn) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
return c.client.UploadGetCDNFileHashes(ctx, &tg.UploadGetCDNFileHashesRequest{
FileToken: c.redirect.FileToken,
Offset: offset,
})
}
+36
View File
@@ -0,0 +1,36 @@
package downloader
import (
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func Test_cdn_decrypt(t *testing.T) {
testdata := make([]byte, 32)
tests := []struct {
name string
key, iv []byte
err bool
}{
{"Bad key", []byte{10}, nil, true},
{"Bad IV", make([]byte, 32), nil, true},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
c := &cdn{
redirect: &tg.UploadFileCDNRedirect{
EncryptionKey: test.key,
EncryptionIv: test.iv,
},
}
_, err := c.decrypt(testdata, 0)
if test.err {
require.Error(t, err)
}
})
}
}
+34
View File
@@ -0,0 +1,34 @@
package downloader
import (
"context"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// CDN represents Telegram RPC client to CDN server.
type CDN interface {
UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error)
}
// Client represents Telegram RPC client.
type Client interface {
UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error)
UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error)
UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error)
UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error)
UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error)
}
type chunk struct {
data []byte
tag tg.StorageFileTypeClass
}
// schema is simple interface for different download schemas.
type schema interface {
Chunk(ctx context.Context, offset int64, limit int) (chunk, error)
Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error)
}
@@ -0,0 +1,48 @@
// Package downloader contains downloading files helpers.
package downloader
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Downloader is Telegram file downloader.
type Downloader struct {
partSize int
pool *bin.Pool
}
const defaultPartSize = 512 * 1024 // 512 kb
// NewDownloader creates new Downloader.
func NewDownloader() *Downloader {
return new(Downloader).WithPartSize(defaultPartSize)
}
// WithPartSize sets chunk size.
// Must be divisible by 4KB.
//
// See https://core.telegram.org/api/files#downloading-files.
func (d *Downloader) WithPartSize(partSize int) *Downloader {
d.partSize = partSize
d.pool = bin.NewPool(partSize)
return d
}
// Download creates Builder for plain downloads.
func (d *Downloader) Download(rpc Client, location tg.InputFileLocationClass) *Builder {
return newBuilder(d, master{
client: rpc,
precise: true,
allowCDN: false,
location: location,
})
}
// Web creates Builder for web files downloads.
func (d *Downloader) Web(rpc Client, location tg.InputWebFileLocationClass) *Builder {
return newBuilder(d, web{
client: rpc,
location: location,
})
}
@@ -0,0 +1,300 @@
package downloader
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"io"
"runtime"
"strconv"
"testing"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type mock struct {
data []byte
hashes mockHashes
migrate bool
err bool
hashesErr bool
redirect *tg.UploadFileCDNRedirect
}
var testErr = testutil.TestError()
func (m mock) getPart(offset int64, limit int) []byte {
length := len(m.data)
if offset >= int64(length) {
return []byte{}
}
size := length - int(offset)
if size > limit {
size = limit
}
r := make([]byte, size)
copy(r, m.data[offset:])
return r
}
func (m mock) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) {
if m.err {
return nil, testErr
}
if m.migrate {
return m.redirect, nil
}
return &tg.UploadFile{
Bytes: m.getPart(request.Offset, request.Limit),
}, nil
}
func (m mock) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error) {
if m.hashesErr {
return nil, testErr
}
return m.hashes.Hashes(ctx, request.Offset)
}
func (m mock) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) {
panic("implement me")
}
func (m mock) UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error) {
if m.err {
return nil, testErr
}
if m.migrate {
return &tg.UploadCDNFileReuploadNeeded{
RequestToken: []byte{1, 2, 3},
}, nil
}
block, err := aes.NewCipher(m.redirect.EncryptionKey)
if err != nil {
return nil, errors.Wrap(err, "CDN mock cipher creation")
}
iv := make([]byte, len(m.redirect.EncryptionIv))
copy(iv, m.redirect.EncryptionIv)
binary.BigEndian.PutUint32(iv[len(iv)-4:], uint32(request.Offset/16))
part := m.getPart(request.Offset, request.Limit)
r := make([]byte, len(part))
cipher.NewCTR(block, iv).XORKeyStream(r, part)
return &tg.UploadCDNFile{
Bytes: r,
}, nil
}
func (m mock) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error) {
if m.hashesErr {
return nil, testErr
}
return m.hashes.Hashes(ctx, request.Offset)
}
func (m mock) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) {
if m.err {
return nil, testErr
}
return &tg.UploadWebFile{
Bytes: m.getPart(int64(request.Offset), request.Limit),
}, nil
}
func countHashes(data []byte, partSize int) (r [][]tg.FileHash) {
actions := data
batchSize := partSize
batches := make([][]byte, 0, (len(actions)+batchSize-1)/batchSize)
for batchSize < len(actions) {
actions, batches = actions[batchSize:], append(batches, actions[0:batchSize:batchSize])
}
batches = append(batches, actions)
currentRange := make([]tg.FileHash, 0, 10)
offset := 0
for _, batch := range batches {
if len(currentRange) >= 10 {
r = append(r, currentRange)
currentRange = make([]tg.FileHash, 0, 10)
}
currentRange = append(currentRange, tg.FileHash{
Offset: int64(offset),
Limit: partSize,
Hash: crypto.SHA256(batch),
})
offset += len(batch)
if len(batch) < partSize {
break
}
}
r = append(r, currentRange)
return
}
func Test_countHashes(t *testing.T) {
a := require.New(t)
data := bytes.Repeat([]byte{1, 2, 3, 4, 5}, 10)
hashes := countHashes(data, 4)
a.NotEmpty(hashes)
for _, hashRange := range hashes {
for _, hash := range hashRange {
from := hash.Offset
to := int(hash.Offset) + hash.Limit
if to > len(data) {
to = len(data)
}
a.Equal(crypto.SHA256(data[from:to]), hash.Hash)
}
}
}
func TestDownloader(t *testing.T) {
ctx := context.Background()
key := make([]byte, 32)
iv := make([]byte, aes.BlockSize)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
t.Fatal(err)
}
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
t.Fatal(err)
}
redirect := &tg.UploadFileCDNRedirect{
DCID: 1,
FileToken: []byte{10},
EncryptionKey: key,
EncryptionIv: iv,
}
testData := make([]byte, defaultPartSize*2)
if _, err := io.ReadFull(rand.Reader, testData); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
data []byte
migrate bool
err bool
hashesErr bool
}{
{"5b", []byte{1, 2, 3, 4, 5}, false, false, false},
{strconv.Itoa(len(testData)) + "b", testData, false, false, false},
{"Error", []byte{}, false, true, false},
{"HashesError", []byte{}, false, true, true},
{"Migrate", []byte{}, true, false, false},
}
schemas := []struct {
name string
creator func(c Client, cdn CDN) *Builder
}{
{"Master", func(c Client, cdn CDN) *Builder {
return NewDownloader().Download(c, nil)
}},
{"Web", func(c Client, cdn CDN) *Builder {
return NewDownloader().Web(c, nil)
}},
}
ways := []struct {
name string
action func(b *Builder) ([]byte, error)
}{
{"Stream", func(b *Builder) ([]byte, error) {
output := new(bytes.Buffer)
_, err := b.Stream(ctx, output)
return output.Bytes(), err
}},
{"Parallel", func(b *Builder) ([]byte, error) {
output := new(syncio.BufWriterAt)
_, err := b.WithThreads(runtime.GOMAXPROCS(0)).Parallel(ctx, output)
return output.Bytes(), err
}},
{"Parallel-OneThread", func(b *Builder) ([]byte, error) {
output := new(syncio.BufWriterAt)
_, err := b.WithThreads(1).Parallel(ctx, output)
return output.Bytes(), err
}},
}
options := []struct {
name string
action func(b *Builder) *Builder
}{
{"NoVerify", func(b *Builder) *Builder {
return b.WithVerify(false)
}},
{"Verify", func(b *Builder) *Builder {
return b.WithVerify(true)
}},
}
for _, schema := range schemas {
t.Run(schema.name, func(t *testing.T) {
for _, test := range tests {
// Telegram can't redirect web file downloads.
if schema.name == "Web" && test.migrate {
continue
}
t.Run(test.name, func(t *testing.T) {
for _, option := range options {
// Telegram can't return hashes for web files.
if schema.name == "Web" && option.name == "Verify" {
continue
}
t.Run(option.name, func(t *testing.T) {
for _, way := range ways {
t.Run(way.name, func(t *testing.T) {
a := require.New(t)
client := &mock{
data: test.data,
hashes: mockHashes{
ranges: countHashes(test.data, 128*1024),
},
migrate: test.migrate,
err: test.err,
redirect: redirect,
}
b := schema.creator(client, client)
b = option.action(b)
data, err := way.action(b)
switch {
case test.migrate:
a.Error(err)
case test.err:
a.Error(err)
default:
a.NoError(err)
a.True(bytes.Equal(test.data, data))
}
})
}
})
}
})
}
})
}
}
+64
View File
@@ -0,0 +1,64 @@
package downloader
import (
"context"
"strconv"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// RedirectError error is returned when Downloader get CDN redirect.
// See https://core.telegram.org/constructor/upload.fileCdnRedirect.
type RedirectError struct {
Redirect *tg.UploadFileCDNRedirect
}
// Error implements error interface.
func (r *RedirectError) Error() string {
return "redirect to CDN DC " + strconv.Itoa(r.Redirect.DCID)
}
// master is a master DC download schema.
// See https://core.telegram.org/api/files#downloading-files.
type master struct {
client Client
precise bool
allowCDN bool
location tg.InputFileLocationClass
}
var _ schema = master{}
func (c master) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: limit,
Location: c.location,
}
req.SetCDNSupported(c.allowCDN)
req.SetPrecise(c.precise)
r, err := c.client.UploadGetFile(ctx, req)
if err != nil {
return chunk{}, err
}
switch result := r.(type) {
case *tg.UploadFile:
return chunk{data: result.Bytes, tag: result.Type}, nil
case *tg.UploadFileCDNRedirect:
return chunk{}, &RedirectError{Redirect: result}
default:
return chunk{}, errors.Errorf("unexpected type %T", r)
}
}
func (c master) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
return c.client.UploadGetFileHashes(ctx, &tg.UploadGetFileHashesRequest{
Location: c.location,
Offset: offset,
})
}
+83
View File
@@ -0,0 +1,83 @@
package downloader
import (
"context"
"io"
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// nolint:gocognit
func (d *Downloader) parallel(
ctx context.Context, r *reader,
threads int, w io.WriterAt,
) (tg.StorageFileTypeClass, error) {
var typ tg.StorageFileTypeClass
typOnce := &sync.Once{}
ready := tdsync.NewReady()
g := tdsync.NewCancellableGroup(ctx)
toWrite := make(chan block, threads)
stop := func(t tg.StorageFileTypeClass) {
typOnce.Do(func() {
typ = t
})
ready.Signal()
}
// Download loop
g.Go(func(ctx context.Context) error {
downloads := tdsync.NewCancellableGroup(ctx)
defer close(toWrite)
for i := 0; i < threads; i++ {
downloads.Go(func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ready.Ready():
return nil
default:
}
b, err := r.Next(ctx)
if err != nil {
return errors.Wrap(err, "get file")
}
// If returned chunk is zero, that means we read all file.
n := len(b.data)
if n < 1 {
stop(b.tag)
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case toWrite <- b:
}
if b.last() {
stop(b.tag)
return nil
}
}
})
}
return downloads.Wait()
})
// Write loop
g.Go(writeAtLoop(syncio.NewWriterAt(w), toWrite))
return typ, g.Wait()
}
+107
View File
@@ -0,0 +1,107 @@
package downloader
import (
"context"
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
type block struct {
chunk
offset int64
partSize int
}
// last compares partSize and chunk length to determine last part.
func (b block) last() bool {
// If returned chunk is smaller than requested part, it seems
// it is last part.
return len(b.data) < b.partSize
}
type reader struct {
sch schema // immutable
verifier *verifier // immutable
partSize int // immutable
offset int64
offsetMux sync.Mutex
}
func verifiedReader(sch schema, verifier *verifier) *reader {
return &reader{
sch: sch,
verifier: verifier,
}
}
func plainReader(sch schema, partSize int) *reader {
return &reader{
sch: sch,
partSize: partSize,
}
}
func (r *reader) Next(ctx context.Context) (block, error) {
if r.verifier != nil {
return r.nextHashed(ctx)
}
return r.nextPlain(ctx)
}
func (r *reader) nextHashed(ctx context.Context) (block, error) {
// Fetch next hashes.
hash, ok, err := r.verifier.next(ctx)
if err != nil {
return block{}, err
}
if !ok {
return block{}, nil
}
// Get next chunk.
b, err := r.next(ctx, hash.Offset, hash.Limit)
if err != nil {
return block{}, err
}
// Verify chunk.
if !r.verifier.verify(hash, b.data) {
return block{}, ErrHashMismatch
}
return b, nil
}
func (r *reader) nextPlain(ctx context.Context) (block, error) {
r.offsetMux.Lock()
offset := r.offset
r.offset += int64(r.partSize)
r.offsetMux.Unlock()
return r.next(ctx, offset, r.partSize)
}
func (r *reader) next(ctx context.Context, offset int64, limit int) (block, error) {
for {
ch, err := r.sch.Chunk(ctx, offset, limit)
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
if flood || tgerr.Is(err, tg.ErrTimeout) {
continue
}
return block{}, errors.Wrap(err, "get next chunk")
}
return block{
chunk: ch,
offset: offset,
partSize: r.partSize,
}, nil
}
}
+48
View File
@@ -0,0 +1,48 @@
package downloader
import (
"context"
"io"
"github.com/go-faster/errors"
)
func writeAtLoop(w io.WriterAt, toWrite <-chan block) func(context.Context) error {
return func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case part, ok := <-toWrite:
if !ok {
return nil
}
_, err := w.WriteAt(part.data, part.offset)
if err != nil {
return errors.Wrap(err, "write output")
}
}
}
}
}
func writeLoop(w io.Writer, toWrite <-chan block) func(context.Context) error {
return func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case part, ok := <-toWrite:
if !ok {
return nil
}
_, err := w.Write(part.data)
if err != nil {
return errors.Wrap(err, "write output")
}
}
}
}
}
+54
View File
@@ -0,0 +1,54 @@
package downloader
import (
"context"
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func (d *Downloader) stream(ctx context.Context, r *reader, w io.Writer) (tg.StorageFileTypeClass, error) {
var typ tg.StorageFileTypeClass
g := tdsync.NewCancellableGroup(ctx)
toWrite := make(chan block, 1)
stop := func(t tg.StorageFileTypeClass) {
typ = t
close(toWrite)
}
// Download loop
g.Go(func(ctx context.Context) error {
for {
b, err := r.Next(ctx)
if err != nil {
return errors.Wrap(err, "get file")
}
n := len(b.data)
if n < 1 {
stop(b.tag)
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case toWrite <- b:
}
if b.last() {
stop(b.tag)
return nil
}
}
})
// Write loop
g.Go(writeLoop(w, toWrite))
return typ, g.Wait()
}
@@ -0,0 +1,49 @@
package downloader
import (
"context"
"io"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type streamReader struct {
ctx context.Context
reader *reader
curBlock block
last bool
}
var _ io.Reader = (*streamReader)(nil)
func (s *streamReader) Read(p []byte) (n int, err error) {
select {
case <-s.ctx.Done():
return 0, s.ctx.Err()
default:
}
if len(s.curBlock.data) == 0 {
if s.last {
return 0, io.EOF
} else {
s.curBlock, err = s.reader.Next(s.ctx)
if err != nil {
return 0, err
}
s.last = s.curBlock.last()
}
}
n = copy(p, s.curBlock.data)
s.curBlock.data = s.curBlock.data[n:]
return
}
func (d *Downloader) streamToReader(ctx context.Context, r *reader) (tg.StorageFileTypeClass, io.Reader, error) {
first, err := r.Next(ctx)
if err != nil {
return nil, nil, err
}
return first.tag, &streamReader{ctx, r, first, first.last()}, nil
}
+104
View File
@@ -0,0 +1,104 @@
package downloader
import (
"bytes"
"context"
"sort"
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
// ErrHashMismatch means that download hash verification was failed.
var ErrHashMismatch = errors.New("file hash mismatch")
type verifier struct {
client schema
hashes []tg.FileHash
offset int64
mux sync.Mutex
}
func newVerifier(client schema, hashes ...tg.FileHash) *verifier {
r := make([]tg.FileHash, len(hashes))
copy(r, hashes)
sort.SliceStable(r, func(i, j int) bool {
return r[i].Offset < r[j].Offset
})
return &verifier{client: client, hashes: r}
}
func (v *verifier) pop() (tg.FileHash, bool) {
if len(v.hashes) < 1 {
return tg.FileHash{}, false
}
// Pop and move.
hash := v.hashes[0]
copy(v.hashes, v.hashes[1:])
v.hashes[len(v.hashes)-1] = tg.FileHash{}
v.hashes = v.hashes[:len(v.hashes)-1]
return hash, true
}
func (v *verifier) update(hashes ...tg.FileHash) (tg.FileHash, bool) {
// If result is empty and queue is empty, so we can't return next hash.
if len(hashes) < 1 {
return tg.FileHash{}, false
}
// Sort hashes by offset.
// Usually Telegram server returns sorted parts, but...
// you never known what can they do.
sort.SliceStable(hashes, func(i, j int) bool {
return hashes[i].Offset < hashes[j].Offset
})
last := hashes[len(hashes)-1]
// Check if we have reached the end.
// If current state offset is equal the last offset + limit (right border)
// then we got all hashes.
if last.Offset == v.offset-int64(last.Limit) {
return tg.FileHash{}, false
}
// Otherwise, we update current offset and add hashes to the end of queue.
v.offset = last.Offset + int64(last.Limit)
v.hashes = append(v.hashes, hashes...)
return v.pop()
}
func (v *verifier) next(ctx context.Context) (tg.FileHash, bool, error) {
v.mux.Lock()
defer v.mux.Unlock()
hash, ok := v.pop()
if ok {
return hash, ok, nil
}
for {
hashes, err := v.client.Hashes(ctx, v.offset)
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
if flood || tgerr.Is(err, tg.ErrTimeout) {
continue
}
return tg.FileHash{}, false, errors.Wrap(err, "get hashes")
}
hash, ok = v.update(hashes...)
return hash, ok, nil
}
}
func (v *verifier) verify(hash tg.FileHash, data []byte) bool {
return bytes.Equal(crypto.SHA256(data), hash.Hash)
}
@@ -0,0 +1,104 @@
package downloader
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
var hashRanges = [][]tg.FileHash{
{
tg.FileHash{Offset: 0, Limit: 131072},
tg.FileHash{Offset: 131072, Limit: 131072},
tg.FileHash{Offset: 262144, Limit: 131072},
tg.FileHash{Offset: 393216, Limit: 131072},
tg.FileHash{Offset: 524288, Limit: 131072},
tg.FileHash{Offset: 655360, Limit: 131072},
tg.FileHash{Offset: 786432, Limit: 131072},
tg.FileHash{Offset: 917504, Limit: 131072},
}, {
tg.FileHash{Offset: 1048576, Limit: 131072},
tg.FileHash{Offset: 1179648, Limit: 131072},
tg.FileHash{Offset: 1310720, Limit: 131072},
tg.FileHash{Offset: 1441792, Limit: 131072},
tg.FileHash{Offset: 1572864, Limit: 131072},
tg.FileHash{Offset: 1703936, Limit: 131072},
tg.FileHash{Offset: 1835008, Limit: 131072},
tg.FileHash{Offset: 1966080, Limit: 131072},
}, {
tg.FileHash{Offset: 2097152, Limit: 131072},
tg.FileHash{Offset: 2228224, Limit: 131072},
tg.FileHash{Offset: 2359296, Limit: 131072},
tg.FileHash{Offset: 2490368, Limit: 131072},
tg.FileHash{Offset: 2621440, Limit: 131072},
tg.FileHash{Offset: 2752512, Limit: 131072},
tg.FileHash{Offset: 2883584, Limit: 131072},
tg.FileHash{Offset: 3014656, Limit: 131072},
},
}
type mockHashes struct {
ranges [][]tg.FileHash
}
func (m mockHashes) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
panic("implement me")
}
func (m mockHashes) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
for _, r := range m.ranges {
last := r[len(r)-1]
if last.Offset+int64(last.Limit) <= offset {
continue
}
return r, nil
}
return m.ranges[len(m.ranges)-1], nil
}
func TestVerifier(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
ranges [][]tg.FileHash
// Hashes returned from CDN redirect, for example.
predefined []tg.FileHash
expected [][]tg.FileHash
}{
{"NoPredefined", hashRanges, nil, hashRanges},
{"Predefined", hashRanges[1:], hashRanges[0], hashRanges},
{"OnlyPredefined", hashRanges[:1], hashRanges[0], hashRanges[:1]},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
a := require.New(t)
m := mockHashes{ranges: test.ranges}
v := newVerifier(m, test.predefined...)
hashes := make([]tg.FileHash, 0, len(test.predefined))
for {
hash, ok, err := v.next(ctx)
a.NoError(err)
if !ok {
break
}
hashes = append(hashes, hash)
}
i := 0
for _, hashRange := range test.expected {
for _, expected := range hashRange {
a.Equal(expected, hashes[i])
i++
}
}
})
}
}
+38
View File
@@ -0,0 +1,38 @@
package downloader
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
var errHashesNotSupported = errors.New("this schema does not support hashes fetch")
// web is a web file download schema.
// See https://core.telegram.org/api/files#downloading-webfiles.
type web struct {
client Client
location tg.InputWebFileLocationClass
}
var _ schema = web{}
func (w web) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
file, err := w.client.UploadGetWebFile(ctx, &tg.UploadGetWebFileRequest{
Location: w.location,
Offset: int(offset),
Limit: limit,
})
if err != nil {
return chunk{}, err
}
return chunk{data: file.Bytes, tag: file.FileType}, nil
}
func (w web) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
return nil, errHashesNotSupported
}