move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
+129
View File
@@ -0,0 +1,129 @@
package uploader
import (
"context"
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
type part struct {
id int
buf *bin.Buffer
upload *Upload
}
func (u *Uploader) uploadBigFilePart(ctx context.Context, p part) (int, error) {
defer u.pool.Put(p.buf)
// Upload loop.
for {
r, err := u.rpc.UploadSaveBigFilePart(ctx, &tg.UploadSaveBigFilePartRequest{
FileID: p.upload.id,
FilePart: p.id,
FileTotalParts: p.upload.totalParts,
Bytes: p.buf.Buf,
})
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
if flood {
continue
}
return 0, errors.Wrapf(err, "send upload part %d RPC", p.id)
}
// If Telegram returned false, it seems save is not successful, so we retry to send.
if r {
return p.buf.Len(), nil
}
}
}
func (u *Uploader) bigLoop(ctx context.Context, threads int, upload *Upload) error { // nolint:gocognit
g := tdsync.NewCancellableGroup(ctx)
toSend := make(chan part, threads)
// Run read loop
r := syncio.NewReader(upload.from)
g.Go(func(ctx context.Context) error {
last := false
totalStreamSize := 0
for {
buf := u.pool.GetSize(u.partSize)
n, err := io.ReadFull(r, buf.Buf)
if n > 0 {
totalStreamSize += n
}
switch {
case errors.Is(err, io.ErrUnexpectedEOF):
last = true
if upload.totalParts == -1 {
totalParts := (totalStreamSize + u.partSize - 1) / u.partSize
upload.totalParts = int(totalParts)
}
case errors.Is(err, io.EOF):
u.pool.Put(buf)
close(toSend)
return nil
case err != nil:
u.pool.Put(buf)
return errors.Wrap(err, "read source")
}
buf.Buf = buf.Buf[:n]
nextPart := part{
id: int(upload.sentParts.Load()),
buf: buf,
upload: upload,
}
select {
case toSend <- nextPart:
upload.sentParts.Inc()
if last {
close(toSend)
return nil
}
case <-ctx.Done():
u.pool.Put(buf)
return ctx.Err()
}
}
})
for i := 0; i < threads; i++ {
g.Go(func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case part, ok := <-toSend:
if !ok {
return nil
}
n, err := u.uploadBigFilePart(ctx, part)
if err != nil {
return errors.Wrap(err, "upload part")
}
if err := u.callback(ctx, upload.confirm(part.id, n)); err != nil {
return errors.Wrap(err, "progress callback")
}
}
}
})
}
return g.Wait()
}
+13
View File
@@ -0,0 +1,13 @@
package uploader
import (
"context"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Client represents Telegram RPC client.
type Client interface {
UploadSaveFilePart(ctx context.Context, request *tg.UploadSaveFilePartRequest) (bool, error)
UploadSaveBigFilePart(ctx context.Context, request *tg.UploadSaveBigFilePartRequest) (bool, error)
}
+109
View File
@@ -0,0 +1,109 @@
package uploader
import (
"bytes"
"context"
"io"
"io/fs"
"net/url"
"os"
"path/filepath"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/uploader/source"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// File is file abstraction.
type File interface {
Stat() (os.FileInfo, error)
io.Reader
}
// FromFile uploads given File.
// NB: FromFile does not close given file.
func (u *Uploader) FromFile(ctx context.Context, f File, name string) (tg.InputFileClass, error) {
info, err := f.Stat()
if err != nil {
return nil, errors.Wrap(err, "stat")
}
if name == "" {
name = info.Name()
}
return u.Upload(ctx, NewUpload(name, f, info.Size()))
}
// FromPath uploads file from given path.
func (u *Uploader) FromPath(ctx context.Context, path, name string) (tg.InputFileClass, error) {
return u.FromFS(ctx, osFS{}, path, name)
}
type osFS struct{}
func (o osFS) Open(name string) (fs.File, error) {
return os.Open(filepath.Clean(name))
}
// FromFS uploads file from fs using given path.
func (u *Uploader) FromFS(ctx context.Context, filesystem fs.FS, path, name string) (_ tg.InputFileClass, err error) {
f, err := filesystem.Open(path)
if err != nil {
return nil, errors.Wrap(err, "open")
}
defer func() {
multierr.AppendInto(&err, f.Close())
}()
return u.FromFile(ctx, f, name)
}
// FromReader uploads file from given io.Reader.
// NB: totally stream should not exceed the limit for
// small files (10 MB as docs says, may be a bit bigger).
// Support For Big Files
// https://core.telegram.org/api/files#streamed-uploads
func (u *Uploader) FromReader(ctx context.Context, name string, f io.Reader) (tg.InputFileClass, error) {
return u.Upload(ctx, NewUpload(name, f, -1))
}
// FromBytes uploads file from given byte slice.
func (u *Uploader) FromBytes(ctx context.Context, name string, b []byte) (tg.InputFileClass, error) {
return u.Upload(ctx, NewUpload(name, bytes.NewReader(b), int64(len(b))))
}
// FromURL uses given source to upload to Telegram.
func (u *Uploader) FromURL(ctx context.Context, rawURL string) (_ tg.InputFileClass, rerr error) {
return u.FromSource(ctx, u.src, rawURL)
}
// FromSource uses given source and URL to fetch data and upload it to Telegram.
func (u *Uploader) FromSource(ctx context.Context, src source.Source, rawURL string) (_ tg.InputFileClass, rerr error) {
parsed, err := url.Parse(rawURL)
if err != nil {
return nil, errors.Wrapf(err, "parse url %q", rawURL)
}
f, err := src.Open(ctx, parsed)
if err != nil {
return nil, errors.Wrapf(err, "open %q", rawURL)
}
defer func() {
multierr.AppendInto(&rerr, f.Close())
}()
name := f.Name()
if name == "" {
return nil, errors.Errorf("invalid name %q got from %q", name, rawURL)
}
size := f.Size()
if size < 0 {
size = -1
}
return u.Upload(ctx, NewUpload(f.Name(), f, size))
}
+82
View File
@@ -0,0 +1,82 @@
package uploader
import (
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/constant"
)
// https://core.telegram.org/api/files#uploading-files
const (
// Use upload.saveBigFilePart in case the full size of the file is more than 10 MB
// and upload.saveFilePart for smaller files.
bigFileLimit = constant.UploadMaxSmallSize
// Each part should have a sequence number, file_part, with a value ranging from 0 to 3,999.
partsLimit = constant.UploadMaxParts
defaultPartSize = 128 * 1024 // 128 KB
// The files binary content is then split into parts. All parts must have the same size (part_size)
// and the following conditions must be met:
// `part_size % 1024 = 0` (divisible by 1KB)
paddingPartSize = constant.UploadPadding
// MaximumPartSize is maximum size of single part.
MaximumPartSize = constant.UploadMaxPartSize
)
func checkPartSize(partSize int) error {
switch {
case partSize == 0:
return errors.New("is equal to zero")
case partSize%paddingPartSize != 0:
return errors.Errorf("%d is not divisible by %d", partSize, paddingPartSize)
case MaximumPartSize%partSize != 0:
return errors.Errorf("%d is not divisible by %d", MaximumPartSize, partSize)
}
return nil
}
func computeParts(partSize, total int) int {
if total <= 0 {
return 0
}
parts := total / partSize
if total%partSize != 0 {
parts++
}
return parts
}
func (u *Uploader) initUpload(upload *Upload) error {
big := upload.totalBytes > bigFileLimit
totalParts := computeParts(u.partSize, int(upload.totalBytes))
if !big && totalParts > partsLimit {
return errors.Errorf(
"part size is too small: total size = %d, part size = %d, %d / %d > %d",
upload.totalBytes, u.partSize, upload.totalBytes, u.partSize, partsLimit,
)
}
if upload.id == 0 {
id, err := u.id()
if err != nil {
return errors.Wrap(err, "id generation")
}
upload.id = id
upload.partSize = u.partSize
} else if upload.partSize != u.partSize {
return errors.Errorf(
"previous upload has part size %d, but uploader size is %d",
upload.partSize, u.partSize,
)
}
upload.big = big
upload.totalParts = totalParts
return nil
}
+48
View File
@@ -0,0 +1,48 @@
package uploader
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestUploader_checkPartSize(t *testing.T) {
tests := []struct {
name string
partSize int
err bool
}{
{"Zero", 0, true},
{"Not divisible by 1024", 1023, true},
{"Max not divisible by part", MaximumPartSize + 1024, true},
{"Default", defaultPartSize, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := checkPartSize(tt.partSize)
if tt.err {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}
func Test_computeParts(t *testing.T) {
tests := []struct {
name string
partSize int
total int
want int
}{
{"Exact part", 1024, 1024, 1},
{"Bit more than part", 1024, 1024 + 1, 2},
{"Stream", 1024, -1, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, computeParts(tt.partSize, tt.total))
})
}
}
+25
View File
@@ -0,0 +1,25 @@
package uploader
import "context"
// ProgressState represents upload state change.
type ProgressState struct {
// ID of upload.
ID int64
// Name of uploading file.
Name string
// Part is an ID of uploaded part.
Part int
// PartSize is a size of uploaded part.
PartSize int
// Uploaded is a total sum of uploaded bytes.
Uploaded int64
// Total is a total size of uploading file.
// May be equal to -1, in case when Upload created without size (stream upload).
Total int64
}
// Progress is interface of upload process tracker.
type Progress interface {
Chunk(ctx context.Context, state ProgressState) error
}
+66
View File
@@ -0,0 +1,66 @@
package uploader
import (
"context"
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
func (u *Uploader) smallLoop(ctx context.Context, h io.Writer, upload *Upload) error {
buf := u.pool.GetSize(u.partSize)
defer u.pool.Put(buf)
last := false
r := io.TeeReader(upload.from, h)
for {
n, err := io.ReadFull(r, buf.Buf)
switch {
case errors.Is(err, io.ErrUnexpectedEOF):
last = true
case errors.Is(err, io.EOF):
return nil
case err != nil:
return errors.Wrap(err, "read source")
}
read := buf.Buf[:n]
// Upload loop.
for {
r, err := u.rpc.UploadSaveFilePart(ctx, &tg.UploadSaveFilePartRequest{
FileID: upload.id,
FilePart: int(upload.sentParts.Load()) % partsLimit,
Bytes: read,
})
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
if flood {
continue
}
return errors.Wrap(err, "send upload RPC")
}
// If Telegram returned false, it seems save is not successful, so we retry to send.
if !r {
continue
}
break
}
upload.sentParts.Inc()
if err := u.callback(ctx, upload.confirmSmall(n)); err != nil {
return errors.Wrap(err, "progress callback")
}
if last {
break
}
}
return nil
}
+2
View File
@@ -0,0 +1,2 @@
// Package source contains remote source interface and implementations for uploader.
package source
+82
View File
@@ -0,0 +1,82 @@
package source
import (
"context"
"io"
"net/http"
"net/url"
"path"
"github.com/go-faster/errors"
"go.uber.org/multierr"
)
// HTTPSource is HTTP source.
type HTTPSource struct {
client *http.Client
}
// NewHTTPSource creates new HTTPSource.
func NewHTTPSource() *HTTPSource {
return &HTTPSource{client: http.DefaultClient}
}
// WithClient sets HTTP client to use.
func (s *HTTPSource) WithClient(client *http.Client) *HTTPSource {
s.client = client
return s
}
type httpFile struct {
body io.ReadCloser
name string
size int64
}
func (h httpFile) Read(p []byte) (n int, err error) {
return h.body.Read(p)
}
func (h httpFile) Close() error {
return h.body.Close()
}
func (h httpFile) Name() string {
return h.name
}
func (h httpFile) Size() int64 {
return h.size
}
// Open implements Source.
func (s *HTTPSource) Open(ctx context.Context, u *url.URL) (_ RemoteFile, rerr error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, errors.Wrap(err, "create request")
}
resp, err := s.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "get")
}
defer func() {
if rerr != nil {
multierr.AppendInto(&rerr, resp.Body.Close())
}
}()
if resp.StatusCode >= 400 {
return nil, errors.Errorf("bad code %d", resp.StatusCode)
}
lastURL := u
if resp.Request.URL != nil {
lastURL = resp.Request.URL
}
return httpFile{
body: resp.Body,
name: path.Base(lastURL.Path),
size: resp.ContentLength,
}, nil
}
@@ -0,0 +1,68 @@
package source
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestHTTPSource(t *testing.T) {
t.Run("OK", func(t *testing.T) {
a := require.New(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
data := bytes.Repeat([]byte{1}, 10)
var h http.HandlerFunc = func(w http.ResponseWriter, req *http.Request) {
_, err := w.Write(data)
a.NoError(err)
}
s := httptest.NewServer(h)
defer s.Close()
src := new(HTTPSource).WithClient(s.Client())
f, err := src.Open(ctx, &url.URL{
Scheme: "http",
Host: s.Listener.Addr().String(),
Path: "img.jpg",
})
a.NoError(err)
a.Len(data, int(f.Size()))
a.Equal("img.jpg", f.Name())
r, err := io.ReadAll(f)
a.NoError(err)
a.Equal(data, r)
a.NoError(f.Close())
})
t.Run("NotFound", func(t *testing.T) {
a := require.New(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var h http.HandlerFunc = func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusNotFound)
}
s := httptest.NewServer(h)
defer s.Close()
src := new(HTTPSource).WithClient(s.Client())
_, err := src.Open(ctx, &url.URL{
Scheme: "http",
Host: s.Listener.Addr().String(),
Path: "img.jpg",
})
a.Error(err)
})
}
@@ -0,0 +1,21 @@
package source
import (
"context"
"io"
"net/url"
)
// RemoteFile is abstraction for remote file.
type RemoteFile interface {
io.ReadCloser
// Name returns filename. Should not be empty.
Name() string
// Size returns size of file. If size is unknown, -1 should be returned.
Size() int64
}
// Source is abstraction for remote upload source.
type Source interface {
Open(ctx context.Context, u *url.URL) (RemoteFile, error)
}
Binary file not shown.
+77
View File
@@ -0,0 +1,77 @@
// Package uploader contains uploading files helpers.
package uploader
import (
"io"
"sync"
"go.uber.org/atomic"
)
// NewUpload creates new Upload struct using given
// name and reader.
func NewUpload(name string, from io.Reader, total int64) *Upload {
return &Upload{
name: name,
totalBytes: total,
from: from,
partSize: -1,
}
}
// Upload represents Telegram file upload.
type Upload struct {
// Fields which will be set by Uploader.
// File ID for Telegram.
id int64
// Sent parts (in partSize).
sentParts atomic.Int64
// Confirmed uploaded parts.
confirmedParts int
// Confirmed uploaded bytes.
confirmedBytes int64
confirmedMux sync.Mutex
// Total parts.
totalParts int
// Part size of uploader.
partSize int
// Flag to determine class of size of file.
big bool
// Total size (in bytes) of upload.
totalBytes int64 // immutable
// Name of file.
name string // immutable
// Reader of data.
from io.Reader // immutable
}
func (u *Upload) confirmSmall(bytes int) ProgressState {
u.confirmedMux.Lock()
defer u.confirmedMux.Unlock()
u.confirmedParts++
return u.confirmLocked(u.confirmedParts, bytes)
}
func (u *Upload) confirm(part, bytes int) ProgressState {
u.confirmedMux.Lock()
defer u.confirmedMux.Unlock()
return u.confirmLocked(part, bytes)
}
func (u *Upload) confirmLocked(part, bytes int) ProgressState {
u.confirmedBytes += int64(bytes)
return ProgressState{
ID: u.id,
Name: u.name,
Part: part,
PartSize: u.partSize,
Uploaded: u.confirmedBytes,
Total: u.totalBytes,
}
}
+129
View File
@@ -0,0 +1,129 @@
package uploader
import (
"context"
"crypto/md5" // #nosec G501
"encoding/hex"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/uploader/source"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Uploader is Telegram file uploader.
type Uploader struct {
rpc Client
id func() (int64, error)
partSize int
pool *bin.Pool
threads int
progress Progress
src source.Source
}
// NewUploader creates new Uploader.
func NewUploader(rpc Client) *Uploader {
return (&Uploader{
rpc: rpc,
id: func() (int64, error) {
return crypto.RandInt64(crypto.DefaultRand())
},
src: source.NewHTTPSource(),
threads: 1,
}).WithPartSize(defaultPartSize)
}
// WithProgress sets progress callback.
func (u *Uploader) WithProgress(progress Progress) *Uploader {
u.progress = progress
return u
}
// WithSource sets URL resolver to use.
func (u *Uploader) WithSource(src source.Source) *Uploader {
u.src = src
return u
}
// WithThreads sets uploading goroutines limit per upload.
func (u *Uploader) WithThreads(threads int) *Uploader {
if threads > 0 {
u.threads = threads
}
return u
}
// WithIDGenerator sets id generator.
func (u *Uploader) WithIDGenerator(cb func() (int64, error)) *Uploader {
u.id = cb
return u
}
// WithPartSize sets part size.
// Should be divisible by 1024.
// 524288 should be divisible by partSize.
//
// See https://core.telegram.org/api/files#uploading-files.
func (u *Uploader) WithPartSize(partSize int) *Uploader {
u.partSize = partSize
u.pool = bin.NewPool(partSize)
return u
}
// Upload uploads data from Upload object.
func (u *Uploader) Upload(ctx context.Context, upload *Upload) (tg.InputFileClass, error) {
if err := checkPartSize(u.partSize); err != nil {
return nil, errors.Wrap(err, "invalid part size")
}
if err := u.initUpload(upload); err != nil {
return nil, err
}
if upload.totalBytes == -1 {
upload.big = true
upload.totalParts = -1
}
if !upload.big {
return u.uploadSmall(ctx, upload)
}
return u.uploadBig(ctx, upload)
}
func (u *Uploader) uploadSmall(ctx context.Context, upload *Upload) (tg.InputFileClass, error) {
h := md5.New() // #nosec G401
if err := u.smallLoop(ctx, h, upload); err != nil {
return nil, err
}
return &tg.InputFile{
ID: upload.id,
Parts: int(upload.sentParts.Load()),
Name: upload.name,
MD5Checksum: hex.EncodeToString(h.Sum(nil)),
}, nil
}
func (u *Uploader) uploadBig(ctx context.Context, upload *Upload) (tg.InputFileClass, error) {
if err := u.bigLoop(ctx, u.threads, upload); err != nil {
return nil, err
}
return &tg.InputFileBig{
ID: upload.id,
Parts: int(upload.sentParts.Load()),
Name: upload.name,
}, nil
}
func (u *Uploader) callback(ctx context.Context, state ProgressState) error {
if u.progress != nil {
return u.progress.Chunk(ctx, state)
}
return nil
}
+224
View File
@@ -0,0 +1,224 @@
package uploader
import (
"bytes"
"context"
"crypto/rand"
"io"
"net/url"
"runtime"
"strconv"
"sync"
"testing"
"testing/fstest"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/uploader/source"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type mockClient struct {
err bool
// Upload state.
buf *syncio.BufWriterAt
parts []atomic.Int64
partSize int
partSizeMux sync.Mutex
}
func newMockClient(err bool) *mockClient {
return &mockClient{
err: err,
buf: &syncio.BufWriterAt{},
parts: make([]atomic.Int64, partsLimit+1),
}
}
var testErr = testutil.TestError()
func (m *mockClient) write(part int, data []byte) error {
m.partSizeMux.Lock()
if m.partSize == 0 {
m.partSize = len(data)
} else if m.partSize != len(data) {
m.partSizeMux.Unlock()
return errors.Errorf(
"invalid part size, expected %d, got %d",
m.partSize, len(data),
)
}
partSize := m.partSize
m.partSizeMux.Unlock()
// Every part have ID which is offset in partSize from start of file.
// But maximal ID is 3999, so part ID for big files can overflow.
// We use parts array to count received parts by ID to compute the offset.
rangeOffset := int(m.parts[part].Inc() - 1)
// If rangeOffset is zero, so offset will be zero, part ID received first time.
// Otherwise, we count next range offset.
offset := rangeOffset * partsLimit * partSize
_, err := m.buf.WriteAt(data, int64(part*partSize+offset))
return err
}
func (m *mockClient) UploadSaveFilePart(ctx context.Context, request *tg.UploadSaveFilePartRequest) (bool, error) {
if m.err {
return false, testErr
}
if err := m.write(request.FilePart, request.Bytes); err != nil {
return false, err
}
return true, nil
}
func (m *mockClient) UploadSaveBigFilePart(ctx context.Context, request *tg.UploadSaveBigFilePartRequest) (bool, error) {
if m.err {
return false, testErr
}
if err := m.write(request.FilePart, request.Bytes); err != nil {
return false, err
}
return true, nil
}
type mockSource struct {
name string
data *bytes.Reader
}
func (m mockSource) Open(ctx context.Context, u *url.URL) (source.RemoteFile, error) {
return m, nil
}
func (m mockSource) Read(p []byte) (n int, err error) {
return m.data.Read(p)
}
func (m mockSource) Close() error {
return nil
}
func (m mockSource) Name() string {
return m.name
}
func (m mockSource) Size() int64 {
return m.data.Size()
}
func TestUploader(t *testing.T) {
ctx := context.Background()
testData := make([]byte, 15*1024*1024)
if _, err := io.ReadFull(rand.Reader, testData); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
data []byte
err bool
}{
{"5b", []byte{1, 2, 3, 4, 5}, false},
{strconv.Itoa(defaultPartSize) + "b", bytes.Repeat([]byte{1}, defaultPartSize), false},
{strconv.Itoa(len(testData)) + "b", testData, false},
{"Error", []byte{1, 2, 3, 4, 5}, true},
}
ways := []struct {
name string
action func(b *Uploader, data []byte) error
}{
{"FromReader", func(b *Uploader, data []byte) error {
if len(data) == len(testData) {
b = b.WithPartSize(16384)
}
_, err := b.FromReader(ctx, "10.jpg", bytes.NewReader(data))
return err
}},
{"FromBytes", func(b *Uploader, data []byte) error {
if len(data) == len(testData) {
b = b.WithPartSize(MaximumPartSize)
}
_, err := b.FromBytes(ctx, "10.jpg", data)
return err
}},
{"FromFS", func(b *Uploader, data []byte) error {
if len(data) == len(testData) {
b = b.WithPartSize(MaximumPartSize)
}
_, err := b.FromFS(ctx, fstest.MapFS{
"10.jpg": &fstest.MapFile{
Data: data,
},
}, "10.jpg", "")
return err
}},
{"FromURL", func(b *Uploader, data []byte) error {
if len(data) == len(testData) {
b = b.WithPartSize(MaximumPartSize)
}
b = b.WithSource(mockSource{
name: "img.jpg",
data: bytes.NewReader(data),
})
_, err := b.FromURL(ctx, "http://example.com")
return err
}},
}
options := []struct {
name string
action func(b *Uploader) *Uploader
}{
{"OneThread", func(b *Uploader) *Uploader {
return b.WithThreads(1)
}},
{"ManyThread", func(b *Uploader) *Uploader {
return b.WithThreads(runtime.GOMAXPROCS(0))
}},
}
for _, way := range ways {
t.Run(way.name, func(t *testing.T) {
for _, option := range options {
t.Run(option.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
client := newMockClient(test.err)
u := NewUploader(client)
err := way.action(option.action(u), test.data)
if test.err {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Truef(
t, bytes.Equal(test.data, client.buf.Bytes()),
"expected uploaded and given equal",
)
})
}
})
}
})
}
}