move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,129 @@
|
||||
package uploader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
type part struct {
|
||||
id int
|
||||
buf *bin.Buffer
|
||||
upload *Upload
|
||||
}
|
||||
|
||||
func (u *Uploader) uploadBigFilePart(ctx context.Context, p part) (int, error) {
|
||||
defer u.pool.Put(p.buf)
|
||||
|
||||
// Upload loop.
|
||||
for {
|
||||
r, err := u.rpc.UploadSaveBigFilePart(ctx, &tg.UploadSaveBigFilePartRequest{
|
||||
FileID: p.upload.id,
|
||||
FilePart: p.id,
|
||||
FileTotalParts: p.upload.totalParts,
|
||||
Bytes: p.buf.Buf,
|
||||
})
|
||||
|
||||
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
|
||||
if flood {
|
||||
continue
|
||||
}
|
||||
return 0, errors.Wrapf(err, "send upload part %d RPC", p.id)
|
||||
}
|
||||
|
||||
// If Telegram returned false, it seems save is not successful, so we retry to send.
|
||||
if r {
|
||||
return p.buf.Len(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Uploader) bigLoop(ctx context.Context, threads int, upload *Upload) error { // nolint:gocognit
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
toSend := make(chan part, threads)
|
||||
|
||||
// Run read loop
|
||||
r := syncio.NewReader(upload.from)
|
||||
g.Go(func(ctx context.Context) error {
|
||||
last := false
|
||||
totalStreamSize := 0
|
||||
|
||||
for {
|
||||
buf := u.pool.GetSize(u.partSize)
|
||||
|
||||
n, err := io.ReadFull(r, buf.Buf)
|
||||
if n > 0 {
|
||||
totalStreamSize += n
|
||||
}
|
||||
switch {
|
||||
case errors.Is(err, io.ErrUnexpectedEOF):
|
||||
last = true
|
||||
if upload.totalParts == -1 {
|
||||
totalParts := (totalStreamSize + u.partSize - 1) / u.partSize
|
||||
upload.totalParts = int(totalParts)
|
||||
}
|
||||
case errors.Is(err, io.EOF):
|
||||
u.pool.Put(buf)
|
||||
|
||||
close(toSend)
|
||||
return nil
|
||||
case err != nil:
|
||||
u.pool.Put(buf)
|
||||
|
||||
return errors.Wrap(err, "read source")
|
||||
}
|
||||
|
||||
buf.Buf = buf.Buf[:n]
|
||||
nextPart := part{
|
||||
id: int(upload.sentParts.Load()),
|
||||
buf: buf,
|
||||
upload: upload,
|
||||
}
|
||||
select {
|
||||
case toSend <- nextPart:
|
||||
upload.sentParts.Inc()
|
||||
if last {
|
||||
close(toSend)
|
||||
return nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
u.pool.Put(buf)
|
||||
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for i := 0; i < threads; i++ {
|
||||
g.Go(func(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case part, ok := <-toSend:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
n, err := u.uploadBigFilePart(ctx, part)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "upload part")
|
||||
}
|
||||
|
||||
if err := u.callback(ctx, upload.confirm(part.id, n)); err != nil {
|
||||
return errors.Wrap(err, "progress callback")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package uploader
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Client represents Telegram RPC client.
|
||||
type Client interface {
|
||||
UploadSaveFilePart(ctx context.Context, request *tg.UploadSaveFilePartRequest) (bool, error)
|
||||
UploadSaveBigFilePart(ctx context.Context, request *tg.UploadSaveBigFilePartRequest) (bool, error)
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package uploader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/uploader/source"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// File is file abstraction.
|
||||
type File interface {
|
||||
Stat() (os.FileInfo, error)
|
||||
io.Reader
|
||||
}
|
||||
|
||||
// FromFile uploads given File.
|
||||
// NB: FromFile does not close given file.
|
||||
func (u *Uploader) FromFile(ctx context.Context, f File, name string) (tg.InputFileClass, error) {
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "stat")
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
name = info.Name()
|
||||
}
|
||||
|
||||
return u.Upload(ctx, NewUpload(name, f, info.Size()))
|
||||
}
|
||||
|
||||
// FromPath uploads file from given path.
|
||||
func (u *Uploader) FromPath(ctx context.Context, path, name string) (tg.InputFileClass, error) {
|
||||
return u.FromFS(ctx, osFS{}, path, name)
|
||||
}
|
||||
|
||||
type osFS struct{}
|
||||
|
||||
func (o osFS) Open(name string) (fs.File, error) {
|
||||
return os.Open(filepath.Clean(name))
|
||||
}
|
||||
|
||||
// FromFS uploads file from fs using given path.
|
||||
func (u *Uploader) FromFS(ctx context.Context, filesystem fs.FS, path, name string) (_ tg.InputFileClass, err error) {
|
||||
f, err := filesystem.Open(path)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "open")
|
||||
}
|
||||
defer func() {
|
||||
multierr.AppendInto(&err, f.Close())
|
||||
}()
|
||||
|
||||
return u.FromFile(ctx, f, name)
|
||||
}
|
||||
|
||||
// FromReader uploads file from given io.Reader.
|
||||
// NB: totally stream should not exceed the limit for
|
||||
// small files (10 MB as docs says, may be a bit bigger).
|
||||
// Support For Big Files
|
||||
// https://core.telegram.org/api/files#streamed-uploads
|
||||
func (u *Uploader) FromReader(ctx context.Context, name string, f io.Reader) (tg.InputFileClass, error) {
|
||||
return u.Upload(ctx, NewUpload(name, f, -1))
|
||||
}
|
||||
|
||||
// FromBytes uploads file from given byte slice.
|
||||
func (u *Uploader) FromBytes(ctx context.Context, name string, b []byte) (tg.InputFileClass, error) {
|
||||
return u.Upload(ctx, NewUpload(name, bytes.NewReader(b), int64(len(b))))
|
||||
}
|
||||
|
||||
// FromURL uses given source to upload to Telegram.
|
||||
func (u *Uploader) FromURL(ctx context.Context, rawURL string) (_ tg.InputFileClass, rerr error) {
|
||||
return u.FromSource(ctx, u.src, rawURL)
|
||||
}
|
||||
|
||||
// FromSource uses given source and URL to fetch data and upload it to Telegram.
|
||||
func (u *Uploader) FromSource(ctx context.Context, src source.Source, rawURL string) (_ tg.InputFileClass, rerr error) {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "parse url %q", rawURL)
|
||||
}
|
||||
|
||||
f, err := src.Open(ctx, parsed)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "open %q", rawURL)
|
||||
}
|
||||
defer func() {
|
||||
multierr.AppendInto(&rerr, f.Close())
|
||||
}()
|
||||
|
||||
name := f.Name()
|
||||
if name == "" {
|
||||
return nil, errors.Errorf("invalid name %q got from %q", name, rawURL)
|
||||
}
|
||||
|
||||
size := f.Size()
|
||||
if size < 0 {
|
||||
size = -1
|
||||
}
|
||||
|
||||
return u.Upload(ctx, NewUpload(f.Name(), f, size))
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package uploader
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/constant"
|
||||
)
|
||||
|
||||
// https://core.telegram.org/api/files#uploading-files
|
||||
const (
|
||||
// Use upload.saveBigFilePart in case the full size of the file is more than 10 MB
|
||||
// and upload.saveFilePart for smaller files.
|
||||
bigFileLimit = constant.UploadMaxSmallSize
|
||||
|
||||
// Each part should have a sequence number, file_part, with a value ranging from 0 to 3,999.
|
||||
partsLimit = constant.UploadMaxParts
|
||||
|
||||
defaultPartSize = 128 * 1024 // 128 KB
|
||||
// The file’s binary content is then split into parts. All parts must have the same size (part_size)
|
||||
// and the following conditions must be met:
|
||||
|
||||
// `part_size % 1024 = 0` (divisible by 1KB)
|
||||
paddingPartSize = constant.UploadPadding
|
||||
|
||||
// MaximumPartSize is maximum size of single part.
|
||||
MaximumPartSize = constant.UploadMaxPartSize
|
||||
)
|
||||
|
||||
func checkPartSize(partSize int) error {
|
||||
switch {
|
||||
case partSize == 0:
|
||||
return errors.New("is equal to zero")
|
||||
case partSize%paddingPartSize != 0:
|
||||
return errors.Errorf("%d is not divisible by %d", partSize, paddingPartSize)
|
||||
case MaximumPartSize%partSize != 0:
|
||||
return errors.Errorf("%d is not divisible by %d", MaximumPartSize, partSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeParts(partSize, total int) int {
|
||||
if total <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
parts := total / partSize
|
||||
if total%partSize != 0 {
|
||||
parts++
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func (u *Uploader) initUpload(upload *Upload) error {
|
||||
big := upload.totalBytes > bigFileLimit
|
||||
totalParts := computeParts(u.partSize, int(upload.totalBytes))
|
||||
if !big && totalParts > partsLimit {
|
||||
return errors.Errorf(
|
||||
"part size is too small: total size = %d, part size = %d, %d / %d > %d",
|
||||
upload.totalBytes, u.partSize, upload.totalBytes, u.partSize, partsLimit,
|
||||
)
|
||||
}
|
||||
|
||||
if upload.id == 0 {
|
||||
id, err := u.id()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "id generation")
|
||||
}
|
||||
|
||||
upload.id = id
|
||||
upload.partSize = u.partSize
|
||||
} else if upload.partSize != u.partSize {
|
||||
return errors.Errorf(
|
||||
"previous upload has part size %d, but uploader size is %d",
|
||||
upload.partSize, u.partSize,
|
||||
)
|
||||
}
|
||||
|
||||
upload.big = big
|
||||
upload.totalParts = totalParts
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package uploader
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUploader_checkPartSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
partSize int
|
||||
err bool
|
||||
}{
|
||||
{"Zero", 0, true},
|
||||
{"Not divisible by 1024", 1023, true},
|
||||
{"Max not divisible by part", MaximumPartSize + 1024, true},
|
||||
{"Default", defaultPartSize, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checkPartSize(tt.partSize)
|
||||
if tt.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_computeParts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
partSize int
|
||||
total int
|
||||
want int
|
||||
}{
|
||||
{"Exact part", 1024, 1024, 1},
|
||||
{"Bit more than part", 1024, 1024 + 1, 2},
|
||||
{"Stream", 1024, -1, 0},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, computeParts(tt.partSize, tt.total))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package uploader
|
||||
|
||||
import "context"
|
||||
|
||||
// ProgressState represents upload state change.
|
||||
type ProgressState struct {
|
||||
// ID of upload.
|
||||
ID int64
|
||||
// Name of uploading file.
|
||||
Name string
|
||||
// Part is an ID of uploaded part.
|
||||
Part int
|
||||
// PartSize is a size of uploaded part.
|
||||
PartSize int
|
||||
// Uploaded is a total sum of uploaded bytes.
|
||||
Uploaded int64
|
||||
// Total is a total size of uploading file.
|
||||
// May be equal to -1, in case when Upload created without size (stream upload).
|
||||
Total int64
|
||||
}
|
||||
|
||||
// Progress is interface of upload process tracker.
|
||||
type Progress interface {
|
||||
Chunk(ctx context.Context, state ProgressState) error
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package uploader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
func (u *Uploader) smallLoop(ctx context.Context, h io.Writer, upload *Upload) error {
|
||||
buf := u.pool.GetSize(u.partSize)
|
||||
defer u.pool.Put(buf)
|
||||
|
||||
last := false
|
||||
|
||||
r := io.TeeReader(upload.from, h)
|
||||
for {
|
||||
n, err := io.ReadFull(r, buf.Buf)
|
||||
switch {
|
||||
case errors.Is(err, io.ErrUnexpectedEOF):
|
||||
last = true
|
||||
case errors.Is(err, io.EOF):
|
||||
return nil
|
||||
case err != nil:
|
||||
return errors.Wrap(err, "read source")
|
||||
}
|
||||
read := buf.Buf[:n]
|
||||
|
||||
// Upload loop.
|
||||
for {
|
||||
r, err := u.rpc.UploadSaveFilePart(ctx, &tg.UploadSaveFilePartRequest{
|
||||
FileID: upload.id,
|
||||
FilePart: int(upload.sentParts.Load()) % partsLimit,
|
||||
Bytes: read,
|
||||
})
|
||||
|
||||
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
|
||||
if flood {
|
||||
continue
|
||||
}
|
||||
return errors.Wrap(err, "send upload RPC")
|
||||
}
|
||||
|
||||
// If Telegram returned false, it seems save is not successful, so we retry to send.
|
||||
if !r {
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
upload.sentParts.Inc()
|
||||
if err := u.callback(ctx, upload.confirmSmall(n)); err != nil {
|
||||
return errors.Wrap(err, "progress callback")
|
||||
}
|
||||
|
||||
if last {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package source contains remote source interface and implementations for uploader.
|
||||
package source
|
||||
@@ -0,0 +1,82 @@
|
||||
package source
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
// HTTPSource is HTTP source.
|
||||
type HTTPSource struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewHTTPSource creates new HTTPSource.
|
||||
func NewHTTPSource() *HTTPSource {
|
||||
return &HTTPSource{client: http.DefaultClient}
|
||||
}
|
||||
|
||||
// WithClient sets HTTP client to use.
|
||||
func (s *HTTPSource) WithClient(client *http.Client) *HTTPSource {
|
||||
s.client = client
|
||||
return s
|
||||
}
|
||||
|
||||
type httpFile struct {
|
||||
body io.ReadCloser
|
||||
name string
|
||||
size int64
|
||||
}
|
||||
|
||||
func (h httpFile) Read(p []byte) (n int, err error) {
|
||||
return h.body.Read(p)
|
||||
}
|
||||
|
||||
func (h httpFile) Close() error {
|
||||
return h.body.Close()
|
||||
}
|
||||
|
||||
func (h httpFile) Name() string {
|
||||
return h.name
|
||||
}
|
||||
|
||||
func (h httpFile) Size() int64 {
|
||||
return h.size
|
||||
}
|
||||
|
||||
// Open implements Source.
|
||||
func (s *HTTPSource) Open(ctx context.Context, u *url.URL) (_ RemoteFile, rerr error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create request")
|
||||
}
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get")
|
||||
}
|
||||
defer func() {
|
||||
if rerr != nil {
|
||||
multierr.AppendInto(&rerr, resp.Body.Close())
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, errors.Errorf("bad code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
lastURL := u
|
||||
if resp.Request.URL != nil {
|
||||
lastURL = resp.Request.URL
|
||||
}
|
||||
|
||||
return httpFile{
|
||||
body: resp.Body,
|
||||
name: path.Base(lastURL.Path),
|
||||
size: resp.ContentLength,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package source
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHTTPSource(t *testing.T) {
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
data := bytes.Repeat([]byte{1}, 10)
|
||||
var h http.HandlerFunc = func(w http.ResponseWriter, req *http.Request) {
|
||||
_, err := w.Write(data)
|
||||
a.NoError(err)
|
||||
}
|
||||
|
||||
s := httptest.NewServer(h)
|
||||
defer s.Close()
|
||||
|
||||
src := new(HTTPSource).WithClient(s.Client())
|
||||
f, err := src.Open(ctx, &url.URL{
|
||||
Scheme: "http",
|
||||
Host: s.Listener.Addr().String(),
|
||||
Path: "img.jpg",
|
||||
})
|
||||
a.NoError(err)
|
||||
a.Len(data, int(f.Size()))
|
||||
a.Equal("img.jpg", f.Name())
|
||||
|
||||
r, err := io.ReadAll(f)
|
||||
a.NoError(err)
|
||||
a.Equal(data, r)
|
||||
|
||||
a.NoError(f.Close())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var h http.HandlerFunc = func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
|
||||
s := httptest.NewServer(h)
|
||||
defer s.Close()
|
||||
|
||||
src := new(HTTPSource).WithClient(s.Client())
|
||||
_, err := src.Open(ctx, &url.URL{
|
||||
Scheme: "http",
|
||||
Host: s.Listener.Addr().String(),
|
||||
Path: "img.jpg",
|
||||
})
|
||||
a.Error(err)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package source
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// RemoteFile is abstraction for remote file.
|
||||
type RemoteFile interface {
|
||||
io.ReadCloser
|
||||
// Name returns filename. Should not be empty.
|
||||
Name() string
|
||||
// Size returns size of file. If size is unknown, -1 should be returned.
|
||||
Size() int64
|
||||
}
|
||||
|
||||
// Source is abstraction for remote upload source.
|
||||
type Source interface {
|
||||
Open(ctx context.Context, u *url.URL) (RemoteFile, error)
|
||||
}
|
||||
BIN
Binary file not shown.
@@ -0,0 +1,77 @@
|
||||
// Package uploader contains uploading files helpers.
|
||||
package uploader
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// NewUpload creates new Upload struct using given
|
||||
// name and reader.
|
||||
func NewUpload(name string, from io.Reader, total int64) *Upload {
|
||||
return &Upload{
|
||||
name: name,
|
||||
totalBytes: total,
|
||||
from: from,
|
||||
partSize: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// Upload represents Telegram file upload.
|
||||
type Upload struct {
|
||||
// Fields which will be set by Uploader.
|
||||
// File ID for Telegram.
|
||||
id int64
|
||||
// Sent parts (in partSize).
|
||||
sentParts atomic.Int64
|
||||
|
||||
// Confirmed uploaded parts.
|
||||
confirmedParts int
|
||||
// Confirmed uploaded bytes.
|
||||
confirmedBytes int64
|
||||
confirmedMux sync.Mutex
|
||||
|
||||
// Total parts.
|
||||
totalParts int
|
||||
// Part size of uploader.
|
||||
partSize int
|
||||
// Flag to determine class of size of file.
|
||||
big bool
|
||||
|
||||
// Total size (in bytes) of upload.
|
||||
totalBytes int64 // immutable
|
||||
// Name of file.
|
||||
name string // immutable
|
||||
// Reader of data.
|
||||
from io.Reader // immutable
|
||||
}
|
||||
|
||||
func (u *Upload) confirmSmall(bytes int) ProgressState {
|
||||
u.confirmedMux.Lock()
|
||||
defer u.confirmedMux.Unlock()
|
||||
|
||||
u.confirmedParts++
|
||||
return u.confirmLocked(u.confirmedParts, bytes)
|
||||
}
|
||||
|
||||
func (u *Upload) confirm(part, bytes int) ProgressState {
|
||||
u.confirmedMux.Lock()
|
||||
defer u.confirmedMux.Unlock()
|
||||
|
||||
return u.confirmLocked(part, bytes)
|
||||
}
|
||||
|
||||
func (u *Upload) confirmLocked(part, bytes int) ProgressState {
|
||||
u.confirmedBytes += int64(bytes)
|
||||
|
||||
return ProgressState{
|
||||
ID: u.id,
|
||||
Name: u.name,
|
||||
Part: part,
|
||||
PartSize: u.partSize,
|
||||
Uploaded: u.confirmedBytes,
|
||||
Total: u.totalBytes,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package uploader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5" // #nosec G501
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/uploader/source"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Uploader is Telegram file uploader.
|
||||
type Uploader struct {
|
||||
rpc Client
|
||||
id func() (int64, error)
|
||||
partSize int
|
||||
pool *bin.Pool
|
||||
threads int
|
||||
progress Progress
|
||||
src source.Source
|
||||
}
|
||||
|
||||
// NewUploader creates new Uploader.
|
||||
func NewUploader(rpc Client) *Uploader {
|
||||
return (&Uploader{
|
||||
rpc: rpc,
|
||||
id: func() (int64, error) {
|
||||
return crypto.RandInt64(crypto.DefaultRand())
|
||||
},
|
||||
src: source.NewHTTPSource(),
|
||||
threads: 1,
|
||||
}).WithPartSize(defaultPartSize)
|
||||
}
|
||||
|
||||
// WithProgress sets progress callback.
|
||||
func (u *Uploader) WithProgress(progress Progress) *Uploader {
|
||||
u.progress = progress
|
||||
return u
|
||||
}
|
||||
|
||||
// WithSource sets URL resolver to use.
|
||||
func (u *Uploader) WithSource(src source.Source) *Uploader {
|
||||
u.src = src
|
||||
return u
|
||||
}
|
||||
|
||||
// WithThreads sets uploading goroutines limit per upload.
|
||||
func (u *Uploader) WithThreads(threads int) *Uploader {
|
||||
if threads > 0 {
|
||||
u.threads = threads
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
// WithIDGenerator sets id generator.
|
||||
func (u *Uploader) WithIDGenerator(cb func() (int64, error)) *Uploader {
|
||||
u.id = cb
|
||||
return u
|
||||
}
|
||||
|
||||
// WithPartSize sets part size.
|
||||
// Should be divisible by 1024.
|
||||
// 524288 should be divisible by partSize.
|
||||
//
|
||||
// See https://core.telegram.org/api/files#uploading-files.
|
||||
func (u *Uploader) WithPartSize(partSize int) *Uploader {
|
||||
u.partSize = partSize
|
||||
u.pool = bin.NewPool(partSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// Upload uploads data from Upload object.
|
||||
func (u *Uploader) Upload(ctx context.Context, upload *Upload) (tg.InputFileClass, error) {
|
||||
if err := checkPartSize(u.partSize); err != nil {
|
||||
return nil, errors.Wrap(err, "invalid part size")
|
||||
}
|
||||
|
||||
if err := u.initUpload(upload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if upload.totalBytes == -1 {
|
||||
upload.big = true
|
||||
upload.totalParts = -1
|
||||
}
|
||||
|
||||
if !upload.big {
|
||||
return u.uploadSmall(ctx, upload)
|
||||
}
|
||||
|
||||
return u.uploadBig(ctx, upload)
|
||||
}
|
||||
|
||||
func (u *Uploader) uploadSmall(ctx context.Context, upload *Upload) (tg.InputFileClass, error) {
|
||||
h := md5.New() // #nosec G401
|
||||
if err := u.smallLoop(ctx, h, upload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tg.InputFile{
|
||||
ID: upload.id,
|
||||
Parts: int(upload.sentParts.Load()),
|
||||
Name: upload.name,
|
||||
MD5Checksum: hex.EncodeToString(h.Sum(nil)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *Uploader) uploadBig(ctx context.Context, upload *Upload) (tg.InputFileClass, error) {
|
||||
if err := u.bigLoop(ctx, u.threads, upload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tg.InputFileBig{
|
||||
ID: upload.id,
|
||||
Parts: int(upload.sentParts.Load()),
|
||||
Name: upload.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *Uploader) callback(ctx context.Context, state ProgressState) error {
|
||||
if u.progress != nil {
|
||||
return u.progress.Chunk(ctx, state)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
package uploader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/uploader/source"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
err bool
|
||||
|
||||
// Upload state.
|
||||
buf *syncio.BufWriterAt
|
||||
parts []atomic.Int64
|
||||
|
||||
partSize int
|
||||
partSizeMux sync.Mutex
|
||||
}
|
||||
|
||||
func newMockClient(err bool) *mockClient {
|
||||
return &mockClient{
|
||||
err: err,
|
||||
buf: &syncio.BufWriterAt{},
|
||||
parts: make([]atomic.Int64, partsLimit+1),
|
||||
}
|
||||
}
|
||||
|
||||
var testErr = testutil.TestError()
|
||||
|
||||
func (m *mockClient) write(part int, data []byte) error {
|
||||
m.partSizeMux.Lock()
|
||||
if m.partSize == 0 {
|
||||
m.partSize = len(data)
|
||||
} else if m.partSize != len(data) {
|
||||
m.partSizeMux.Unlock()
|
||||
|
||||
return errors.Errorf(
|
||||
"invalid part size, expected %d, got %d",
|
||||
m.partSize, len(data),
|
||||
)
|
||||
}
|
||||
partSize := m.partSize
|
||||
m.partSizeMux.Unlock()
|
||||
|
||||
// Every part have ID which is offset in partSize from start of file.
|
||||
// But maximal ID is 3999, so part ID for big files can overflow.
|
||||
// We use parts array to count received parts by ID to compute the offset.
|
||||
rangeOffset := int(m.parts[part].Inc() - 1)
|
||||
|
||||
// If rangeOffset is zero, so offset will be zero, part ID received first time.
|
||||
// Otherwise, we count next range offset.
|
||||
offset := rangeOffset * partsLimit * partSize
|
||||
_, err := m.buf.WriteAt(data, int64(part*partSize+offset))
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *mockClient) UploadSaveFilePart(ctx context.Context, request *tg.UploadSaveFilePartRequest) (bool, error) {
|
||||
if m.err {
|
||||
return false, testErr
|
||||
}
|
||||
|
||||
if err := m.write(request.FilePart, request.Bytes); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockClient) UploadSaveBigFilePart(ctx context.Context, request *tg.UploadSaveBigFilePartRequest) (bool, error) {
|
||||
if m.err {
|
||||
return false, testErr
|
||||
}
|
||||
|
||||
if err := m.write(request.FilePart, request.Bytes); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type mockSource struct {
|
||||
name string
|
||||
data *bytes.Reader
|
||||
}
|
||||
|
||||
func (m mockSource) Open(ctx context.Context, u *url.URL) (source.RemoteFile, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m mockSource) Read(p []byte) (n int, err error) {
|
||||
return m.data.Read(p)
|
||||
}
|
||||
|
||||
func (m mockSource) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockSource) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m mockSource) Size() int64 {
|
||||
return m.data.Size()
|
||||
}
|
||||
|
||||
func TestUploader(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
testData := make([]byte, 15*1024*1024)
|
||||
if _, err := io.ReadFull(rand.Reader, testData); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
err bool
|
||||
}{
|
||||
{"5b", []byte{1, 2, 3, 4, 5}, false},
|
||||
{strconv.Itoa(defaultPartSize) + "b", bytes.Repeat([]byte{1}, defaultPartSize), false},
|
||||
{strconv.Itoa(len(testData)) + "b", testData, false},
|
||||
{"Error", []byte{1, 2, 3, 4, 5}, true},
|
||||
}
|
||||
|
||||
ways := []struct {
|
||||
name string
|
||||
action func(b *Uploader, data []byte) error
|
||||
}{
|
||||
{"FromReader", func(b *Uploader, data []byte) error {
|
||||
if len(data) == len(testData) {
|
||||
b = b.WithPartSize(16384)
|
||||
}
|
||||
_, err := b.FromReader(ctx, "10.jpg", bytes.NewReader(data))
|
||||
return err
|
||||
}},
|
||||
{"FromBytes", func(b *Uploader, data []byte) error {
|
||||
if len(data) == len(testData) {
|
||||
b = b.WithPartSize(MaximumPartSize)
|
||||
}
|
||||
|
||||
_, err := b.FromBytes(ctx, "10.jpg", data)
|
||||
return err
|
||||
}},
|
||||
{"FromFS", func(b *Uploader, data []byte) error {
|
||||
if len(data) == len(testData) {
|
||||
b = b.WithPartSize(MaximumPartSize)
|
||||
}
|
||||
|
||||
_, err := b.FromFS(ctx, fstest.MapFS{
|
||||
"10.jpg": &fstest.MapFile{
|
||||
Data: data,
|
||||
},
|
||||
}, "10.jpg", "")
|
||||
return err
|
||||
}},
|
||||
{"FromURL", func(b *Uploader, data []byte) error {
|
||||
if len(data) == len(testData) {
|
||||
b = b.WithPartSize(MaximumPartSize)
|
||||
}
|
||||
b = b.WithSource(mockSource{
|
||||
name: "img.jpg",
|
||||
data: bytes.NewReader(data),
|
||||
})
|
||||
|
||||
_, err := b.FromURL(ctx, "http://example.com")
|
||||
return err
|
||||
}},
|
||||
}
|
||||
|
||||
options := []struct {
|
||||
name string
|
||||
action func(b *Uploader) *Uploader
|
||||
}{
|
||||
{"OneThread", func(b *Uploader) *Uploader {
|
||||
return b.WithThreads(1)
|
||||
}},
|
||||
{"ManyThread", func(b *Uploader) *Uploader {
|
||||
return b.WithThreads(runtime.GOMAXPROCS(0))
|
||||
}},
|
||||
}
|
||||
|
||||
for _, way := range ways {
|
||||
t.Run(way.name, func(t *testing.T) {
|
||||
for _, option := range options {
|
||||
t.Run(option.name, func(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
client := newMockClient(test.err)
|
||||
u := NewUploader(client)
|
||||
|
||||
err := way.action(option.action(u), test.data)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Truef(
|
||||
t, bytes.Equal(test.data, client.buf.Bytes()),
|
||||
"expected uploaded and given equal",
|
||||
)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user