7a04f298d2
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
137 lines
3.0 KiB
Go
137 lines
3.0 KiB
Go
package file
|
|
|
|
import (
	"context"
	"fmt"
	"io"

	"github.com/go-faster/errors"

	"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
	"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
	"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
|
|
|
|
func getLocation(loc tg.InputFileLocationClass) (string, error) {
|
|
v, ok := loc.(interface {
|
|
GetLocalID() int
|
|
GetVolumeID() int64
|
|
})
|
|
if !ok {
|
|
return "", tgerr.New(400, tg.ErrFileIDInvalid)
|
|
}
|
|
|
|
return fmt.Sprintf("%d_%d", v.GetLocalID(), v.GetVolumeID()), nil
|
|
}
|
|
|
|
func (m *Service) openLocation(loc tg.InputFileLocationClass) (File, error) {
|
|
name, err := getLocation(loc)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
f, err := m.storage.Open(name)
|
|
if err != nil {
|
|
return nil, tgerr.New(400, tg.ErrFileIDInvalid)
|
|
}
|
|
|
|
return f, nil
|
|
}
|
|
|
|
// getPart reads up to limit bytes of the file identified by loc, starting at
// offset, and returns the bytes actually read.
//
// Errors from openLocation are returned unchanged. A negative limit is
// rejected with a 400 LIMIT_INVALID error, because allocating a negative-length
// buffer would panic.
//
// Reaching the end of the file is not an error. io.ReaderAt implementations
// may return io.EOF together with the last bytes of a short read, so in that
// case the bytes read so far are returned. That slice may be empty when offset
// is at or beyond the end of the file. Any other read error is wrapped and
// returned.
func (m *Service) getPart(loc tg.InputFileLocationClass, offset int64, limit int) ([]byte, error) {
	f, err := m.openLocation(loc)
	if err != nil {
		return nil, err
	}

	if limit < 0 {
		// limit comes straight from the client request; make() panics on a
		// negative length, so reject it as a client error instead.
		return nil, tgerr.New(400, "LIMIT_INVALID")
	}

	buf := make([]byte, limit)
	n, err := f.ReadAt(buf, offset)
	if err != nil && !errors.Is(err, io.EOF) {
		return nil, errors.Wrap(err, "read from storage")
	}

	return buf[:n], nil
}
|
|
|
|
// UploadGetFile serves an upload.getFile request by returning the requested
// slice of the stored file (see getPart for read semantics).
//
// The response always uses tg.StorageFilePartial as its storage type and a
// zero Mtime. Errors from getPart are returned unchanged.
func (m *Service) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) {
	part, err := m.getPart(request.Location, request.Offset, request.Limit)
	if err != nil {
		return nil, err
	}

	// Mtime is deliberately left at its zero value.
	result := &tg.UploadFile{
		Type:  &tg.StorageFilePartial{},
		Bytes: part,
	}
	return result, nil
}
|
|
|
|
func countHashes(data []byte, offset int64, partSize int) []tg.FileHash {
|
|
actions := data
|
|
batchSize := partSize
|
|
batches := make([][]byte, 0, (len(actions)+batchSize-1)/batchSize)
|
|
|
|
for batchSize < len(actions) {
|
|
actions, batches = actions[batchSize:], append(batches, actions[0:batchSize:batchSize])
|
|
}
|
|
batches = append(batches, actions)
|
|
|
|
currentRange := make([]tg.FileHash, 0, 10)
|
|
for _, batch := range batches {
|
|
currentRange = append(currentRange, tg.FileHash{
|
|
Offset: offset,
|
|
Limit: partSize,
|
|
Hash: crypto.SHA256(batch),
|
|
})
|
|
offset += int64(len(batch))
|
|
}
|
|
return currentRange
|
|
}
|
|
|
|
// divAndCeil returns a/b rounded toward positive infinity, i.e. ceil(a/b).
//
// Go's integer division truncates toward zero. The truncated quotient therefore
// needs to be bumped up by one only when there is a remainder AND the exact
// quotient is positive, which happens when a and b have the same sign. When the
// exact quotient is negative, truncation toward zero already is the ceiling.
// b must be non-zero; integer division by zero panics.
func divAndCeil(a, b int) int {
	q := a / b
	if a%b != 0 && (a < 0) == (b < 0) {
		q++
	}
	return q
}
|
|
|
|
// computeBatch computes hash range number for given offset.
|
|
func computeBatch(offset int64, rangeSize, partSize int) int {
|
|
// Compute number of parts in partSize from offset.
|
|
parts := divAndCeil(int(offset+1), partSize)
|
|
// Compute number of hash ranges in rangeSize.
|
|
batches := divAndCeil(parts, rangeSize)
|
|
|
|
return batches
|
|
}
|
|
|
|
// UploadGetFileHashes serves an upload.getFileHashes request.
//
// It locates the hash range (rangeSize parts of partSize bytes, see
// computeBatch) that contains request.Offset. It reads that range from storage,
// clipped to the end of the file, and returns one FileHash per part (see
// countHashes).
//
// Behaviour for edge cases and errors:
//   - A negative offset is rejected with a 400 OFFSET_INVALID error.
//   - An offset at or past the end of the file yields no hashes and no error.
//   - Errors from openLocation are returned unchanged.
//   - Storage read errors, other than the io.EOF that ReaderAt may report on
//     a short read at end of file, are wrapped and returned.
func (m *Service) UploadGetFileHashes(
	ctx context.Context,
	request *tg.UploadGetFileHashesRequest,
) ([]tg.FileHash, error) {
	f, err := m.openLocation(request.Location)
	if err != nil {
		return nil, err
	}

	if request.Offset < 0 {
		// The offset is client-controlled; a negative value would otherwise
		// produce a negative read position below.
		return nil, tgerr.New(400, "OFFSET_INVALID")
	}

	size := int64(f.Size())
	if request.Offset >= size {
		return nil, nil
	}

	partSize := m.hashPartSize
	rangeSize := m.hashRangeSize
	batch := computeBatch(request.Offset, rangeSize, partSize)

	// Byte window [low, high) of the selected hash range, in int64 so large
	// files cannot overflow on 32-bit platforms.
	rangeBytes := int64(rangeSize) * int64(partSize)
	low := int64(batch-1) * rangeBytes
	high := int64(batch) * rangeBytes
	if high > size {
		// The last range usually extends past EOF; only read what exists.
		high = size
	}

	buf := make([]byte, high-low)
	n, err := f.ReadAt(buf, low)
	if err != nil && !errors.Is(err, io.EOF) {
		return nil, errors.Wrap(err, "read from storage")
	}

	return countHashes(buf[:n], low, partSize), nil
}
|