move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
type bufferedConn struct {
|
||||
conn transport.Conn
|
||||
|
||||
recv []bin.Buffer
|
||||
recvMux sync.Mutex
|
||||
}
|
||||
|
||||
func newBufferedConn(conn transport.Conn) *bufferedConn {
|
||||
return &bufferedConn{conn: conn}
|
||||
}
|
||||
|
||||
func (c *bufferedConn) push(b *bin.Buffer) {
|
||||
c.recvMux.Lock()
|
||||
c.recv = append(c.recv, bin.Buffer{Buf: b.Copy()})
|
||||
c.recvMux.Unlock()
|
||||
}
|
||||
|
||||
func (c *bufferedConn) pop() (r bin.Buffer, ok bool) {
|
||||
c.recvMux.Lock()
|
||||
defer c.recvMux.Unlock()
|
||||
if len(c.recv) < 1 {
|
||||
return
|
||||
}
|
||||
r, c.recv = c.recv[len(c.recv)-1], c.recv[:len(c.recv)-1]
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Push(b *bin.Buffer) {
|
||||
c.push(b)
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Pop() (bin.Buffer, bool) {
|
||||
return c.pop()
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Send(ctx context.Context, b *bin.Buffer) error {
|
||||
return c.conn.Send(ctx, b)
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Recv(ctx context.Context, b *bin.Buffer) error {
|
||||
e, ok := c.Pop()
|
||||
if ok {
|
||||
b.ResetTo(e.Copy())
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.conn.Recv(ctx, b)
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
func TestBufferedConn(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
i := transport.Intermediate
|
||||
c1, c2 := i.Pipe()
|
||||
b := newBufferedConn(c1)
|
||||
defer func() {
|
||||
a.NoError(b.Close())
|
||||
a.NoError(c2.Close())
|
||||
}()
|
||||
|
||||
payload := []byte("abcdabcd")
|
||||
go func() {
|
||||
b1 := &bin.Buffer{Buf: payload}
|
||||
a.NoError(c2.Send(ctx, b1))
|
||||
}()
|
||||
|
||||
// Test Recv before Push.
|
||||
recvBuf := &bin.Buffer{}
|
||||
a.NoError(b.Recv(ctx, recvBuf))
|
||||
a.Equal(payload, recvBuf.Buf)
|
||||
|
||||
pushed := []byte("12345678")
|
||||
b.Push(&bin.Buffer{Buf: pushed})
|
||||
go func() {
|
||||
b1 := &bin.Buffer{Buf: payload}
|
||||
a.NoError(c2.Send(ctx, b1))
|
||||
}()
|
||||
|
||||
// Test Push.
|
||||
recvBuf.Reset()
|
||||
a.NoError(b.Recv(ctx, recvBuf))
|
||||
a.Equal(pushed, recvBuf.Buf)
|
||||
|
||||
// Test Recv after Push.
|
||||
recvBuf.Reset()
|
||||
a.NoError(b.Recv(ctx, recvBuf))
|
||||
a.Equal(payload, recvBuf.Buf)
|
||||
|
||||
// Test send.
|
||||
go func() {
|
||||
b1 := &bin.Buffer{Buf: payload}
|
||||
a.NoError(b.Send(ctx, b1))
|
||||
}()
|
||||
|
||||
recvBuf.Reset()
|
||||
a.NoError(c2.Recv(ctx, recvBuf))
|
||||
a.Equal(payload, recvBuf.Buf)
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
// Package cluster contains Telegram multi-DC setup utilities.
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services/config"
|
||||
)
|
||||
|
||||
type setup struct {
|
||||
srv *tgtest.Server
|
||||
dispatch *tgtest.Dispatcher
|
||||
}
|
||||
|
||||
// Cluster is a cluster of multiple servers, representing multiple Telegram datacenters.
|
||||
type Cluster struct {
|
||||
// denotes to use websocket listener
|
||||
web bool
|
||||
|
||||
setups map[int]setup
|
||||
keys []exchange.PublicKey
|
||||
|
||||
// DCs config state.
|
||||
cfg tg.Config
|
||||
cdnCfg tg.CDNConfig
|
||||
domains map[int]string
|
||||
|
||||
// Signal for readiness.
|
||||
ready *tdsync.Ready
|
||||
|
||||
// RPC dispatcher.
|
||||
common *tgtest.Dispatcher
|
||||
|
||||
log *zap.Logger
|
||||
random io.Reader
|
||||
protocol dcs.Protocol
|
||||
}
|
||||
|
||||
// NewCluster creates new server Cluster.
|
||||
func NewCluster(opts Options) *Cluster {
|
||||
opts.setDefaults()
|
||||
|
||||
q := &Cluster{
|
||||
web: opts.Web,
|
||||
setups: map[int]setup{},
|
||||
keys: nil,
|
||||
cfg: opts.Config,
|
||||
cdnCfg: opts.CDNConfig,
|
||||
domains: map[int]string{},
|
||||
ready: tdsync.NewReady(),
|
||||
common: tgtest.NewDispatcher(),
|
||||
log: opts.Logger,
|
||||
random: opts.Random,
|
||||
protocol: opts.Protocol,
|
||||
}
|
||||
config.NewService(&q.cfg, &q.cdnCfg).Register(q.common)
|
||||
q.common.Fallback(q.fallback())
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// List returns DCs list.
|
||||
func (c *Cluster) List() dcs.List {
|
||||
return dcs.List{
|
||||
Options: c.cfg.DCOptions,
|
||||
Domains: c.domains,
|
||||
}
|
||||
}
|
||||
|
||||
// Resolver returns dcs.Resolver to use.
|
||||
func (c *Cluster) Resolver() dcs.Resolver {
|
||||
if c.web {
|
||||
return dcs.Websocket(dcs.WebsocketOptions{})
|
||||
}
|
||||
|
||||
return dcs.Plain(dcs.PlainOptions{
|
||||
Protocol: c.protocol,
|
||||
})
|
||||
}
|
||||
|
||||
// Keys returns all servers public keys.
|
||||
func (c *Cluster) Keys() []exchange.PublicKey {
|
||||
return c.keys
|
||||
}
|
||||
|
||||
// Ready returns signal channel to await readiness.
|
||||
func (c *Cluster) Ready() <-chan struct{} {
|
||||
return c.ready.Ready()
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
// Common returns common dispatcher.
|
||||
func (c *Cluster) Common() *tgtest.Dispatcher {
|
||||
return c.common
|
||||
}
|
||||
|
||||
func (c *Cluster) getCodec() (codec func() transport.Codec) {
|
||||
if !c.web {
|
||||
codec = c.protocol.Codec
|
||||
}
|
||||
return codec
|
||||
}
|
||||
|
||||
// DC registers new server and returns it.
|
||||
func (c *Cluster) DC(id int, name string) (*tgtest.Server, *tgtest.Dispatcher) {
|
||||
if s, ok := c.setups[id]; ok {
|
||||
return s.srv, s.dispatch
|
||||
}
|
||||
|
||||
key, err := rsa.GenerateKey(c.random, crypto.RSAKeyBits)
|
||||
if err != nil {
|
||||
// TODO(tdakkota): Return error instead.
|
||||
panic(err)
|
||||
}
|
||||
privateKey := exchange.PrivateKey{
|
||||
RSA: key,
|
||||
}
|
||||
|
||||
d := tgtest.NewDispatcher()
|
||||
server := tgtest.NewServer(privateKey, tgtest.UnpackInvoke(d), tgtest.ServerOptions{
|
||||
DC: id,
|
||||
Logger: c.log.Named(name).With(zap.Int("dc_id", id)),
|
||||
Codec: c.getCodec(),
|
||||
})
|
||||
c.setups[id] = setup{
|
||||
srv: server,
|
||||
dispatch: d,
|
||||
}
|
||||
c.keys = append(c.keys, server.Key())
|
||||
|
||||
// We set server fallback handler to dispatch request in order
|
||||
// 1) Explicit DC handler
|
||||
// 2) Explicit common handler
|
||||
// 3) Common fallback
|
||||
d.Fallback(c.Common())
|
||||
return server, d
|
||||
}
|
||||
|
||||
// Dispatch registers new server and returns its dispatcher.
|
||||
func (c *Cluster) Dispatch(id int, name string) *tgtest.Dispatcher {
|
||||
_, d := c.DC(id, name)
|
||||
return d
|
||||
}
|
||||
|
||||
func (c *Cluster) fallback() tgtest.HandlerFunc {
|
||||
return services.NotImplemented
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
)
|
||||
|
||||
func newLocalListener(ctx context.Context) (net.Listener, error) {
|
||||
cfg := net.ListenConfig{}
|
||||
l, err := cfg.Listen(ctx, "tcp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "listen")
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
// Options of Cluster.
|
||||
type Options struct {
|
||||
// Web denotes to use websocket listener.
|
||||
Web bool
|
||||
// Random is random source. Used to generate RSA keys.
|
||||
// Defaults to rand.Reader.
|
||||
Random io.Reader
|
||||
// Logger is instance of zap.Logger. No logs by default.
|
||||
Logger *zap.Logger
|
||||
// Codec constructor.
|
||||
// Defaults to nil (underlying transport server detects protocol automatically).
|
||||
Protocol dcs.Protocol
|
||||
// Config is an initial cluster config.
|
||||
Config tg.Config
|
||||
// CDNConfig is an initial cluster CDN config.
|
||||
CDNConfig tg.CDNConfig
|
||||
}
|
||||
|
||||
func (opt *Options) setDefaults() {
|
||||
// It's okay to use zero value Web.
|
||||
if opt.Random == nil {
|
||||
opt.Random = crypto.DefaultRand()
|
||||
}
|
||||
if opt.Logger == nil {
|
||||
opt.Logger = zap.NewNop()
|
||||
}
|
||||
if opt.Protocol == nil {
|
||||
opt.Protocol = transport.Intermediate
|
||||
}
|
||||
// It's okay to use zero value Config.
|
||||
// It's okay to use zero value CDNConfig.
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
// Up runs all servers in a cluster.
|
||||
func (c *Cluster) Up(ctx context.Context) error {
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
|
||||
listen := func(ctx context.Context, _ int) (net.Listener, error) {
|
||||
return newLocalListener(ctx)
|
||||
}
|
||||
if c.web {
|
||||
// Create local random listener
|
||||
l, err := newLocalListener(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
srv := http.Server{
|
||||
ReadHeaderTimeout: time.Second * 10,
|
||||
Handler: mux,
|
||||
BaseContext: func(net.Listener) context.Context {
|
||||
return ctx
|
||||
},
|
||||
}
|
||||
g.Go(func(ctx context.Context) error {
|
||||
if err := srv.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return errors.Wrap(err, "serve")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
g.Go(func(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
|
||||
return srv.Close()
|
||||
})
|
||||
|
||||
baseURL := url.URL{
|
||||
Scheme: "http",
|
||||
Host: l.Addr().String(),
|
||||
}
|
||||
listen = func(ctx context.Context, dc int) (net.Listener, error) {
|
||||
listener, handler := transport.WebsocketListener(l.Addr())
|
||||
|
||||
path := fmt.Sprintf("/dc/%d", dc)
|
||||
mux.Handle(path, handler)
|
||||
|
||||
dcURL := baseURL
|
||||
dcURL.Path = path
|
||||
c.domains[dc] = dcURL.String()
|
||||
return listener, nil
|
||||
}
|
||||
}
|
||||
|
||||
for dcID, s := range c.setups {
|
||||
l, err := listen(ctx, dcID)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "DC %d: listen port", dcID)
|
||||
}
|
||||
|
||||
if !c.web {
|
||||
// Add TCP listeners to config.
|
||||
if addr, ok := l.Addr().(*net.TCPAddr); ok {
|
||||
c.cfg.DCOptions = append(c.cfg.DCOptions, tg.DCOption{
|
||||
Ipv6: addr.IP.To16() != nil,
|
||||
Static: true,
|
||||
ID: dcID,
|
||||
IPAddress: addr.IP.String(),
|
||||
Port: addr.Port,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Copy iteration value.
|
||||
srv := s.srv
|
||||
g.Go(func(ctx context.Context) error {
|
||||
return srv.Serve(ctx, transport.ListenCodec(nil, l))
|
||||
})
|
||||
}
|
||||
c.ready.Signal()
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
type connection struct {
|
||||
transport.Conn
|
||||
sent atomic.Bool
|
||||
}
|
||||
|
||||
func (conn *connection) sentCreated() bool {
|
||||
return conn.sent.Swap(true)
|
||||
}
|
||||
|
||||
// users contains all server connections and sessions.
|
||||
type users struct {
|
||||
sessions map[[8]byte]crypto.AuthKey
|
||||
sessionsMux sync.Mutex
|
||||
|
||||
conns map[int64]*connection
|
||||
connsMux sync.Mutex
|
||||
}
|
||||
|
||||
func newUsers() *users {
|
||||
return &users{
|
||||
conns: map[int64]*connection{},
|
||||
sessions: map[[8]byte]crypto.AuthKey{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *users) createConnection(key int64, tConn transport.Conn) *connection {
|
||||
c.connsMux.Lock()
|
||||
defer c.connsMux.Unlock()
|
||||
|
||||
if v, ok := c.conns[key]; ok {
|
||||
return v
|
||||
}
|
||||
|
||||
conn := &connection{
|
||||
Conn: tConn,
|
||||
}
|
||||
c.conns[key] = conn
|
||||
return conn
|
||||
}
|
||||
|
||||
func (c *users) getConnection(key int64) (conn *connection, ok bool) {
|
||||
c.connsMux.Lock()
|
||||
conn, ok = c.conns[key]
|
||||
c.connsMux.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *users) deleteConnection(key int64) {
|
||||
c.connsMux.Lock()
|
||||
conn := c.conns[key]
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
delete(c.conns, key)
|
||||
c.connsMux.Unlock()
|
||||
}
|
||||
|
||||
func (c *users) addSession(key crypto.AuthKey) {
|
||||
c.sessionsMux.Lock()
|
||||
c.sessions[key.ID] = key
|
||||
c.sessionsMux.Unlock()
|
||||
}
|
||||
|
||||
func (c *users) getSession(k [8]byte) (s crypto.AuthKey, ok bool) {
|
||||
c.connsMux.Lock()
|
||||
s, ok = c.sessions[k]
|
||||
c.connsMux.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *users) Close() error {
|
||||
c.connsMux.Lock()
|
||||
for _, conn := range c.conns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
c.connsMux.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// Dispatcher is a plain handler to map requests by ID.
|
||||
type Dispatcher struct {
|
||||
reqs map[uint32]Handler
|
||||
mux sync.Mutex
|
||||
fallback Handler
|
||||
}
|
||||
|
||||
// NewDispatcher creates new Dispatcher.
|
||||
func NewDispatcher() *Dispatcher {
|
||||
return &Dispatcher{
|
||||
reqs: map[uint32]Handler{},
|
||||
}
|
||||
}
|
||||
|
||||
// OnMessage implements Handler
|
||||
func (d *Dispatcher) OnMessage(server *Server, req *Request) error {
|
||||
id, err := req.Buf.PeekID()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "peek id")
|
||||
}
|
||||
|
||||
d.mux.Lock()
|
||||
h, ok := d.reqs[id]
|
||||
fallback := d.fallback
|
||||
d.mux.Unlock()
|
||||
if ok {
|
||||
return h.OnMessage(server, req)
|
||||
}
|
||||
|
||||
if fallback != nil {
|
||||
return fallback.OnMessage(server, req)
|
||||
}
|
||||
|
||||
return errors.Errorf("unexpected type %#x", id)
|
||||
}
|
||||
|
||||
// Handle sets handler for given TypeID.
|
||||
func (d *Dispatcher) Handle(id uint32, h Handler) *Dispatcher {
|
||||
d.mux.Lock()
|
||||
d.reqs[id] = h
|
||||
d.mux.Unlock()
|
||||
return d
|
||||
}
|
||||
|
||||
// HandleFunc sets handler for given TypeID.
|
||||
func (d *Dispatcher) HandleFunc(id uint32, h func(server *Server, req *Request) error) *Dispatcher {
|
||||
return d.Handle(id, HandlerFunc(h))
|
||||
}
|
||||
|
||||
// Result sets constant result for given TypeID.
|
||||
// NB: it uses rpc_result to pack given encoder.
|
||||
func (d *Dispatcher) Result(id uint32, msg bin.Encoder) *Dispatcher {
|
||||
return d.HandleFunc(id, func(server *Server, req *Request) error {
|
||||
return server.SendResult(req, msg)
|
||||
})
|
||||
}
|
||||
|
||||
// Vector sets constant Vector result for given TypeID.
|
||||
// NB: it uses rpc_result to pack generic vector with given encoders.
|
||||
func (d *Dispatcher) Vector(id uint32, msgs ...bin.Encoder) *Dispatcher {
|
||||
return d.HandleFunc(id, func(server *Server, req *Request) error {
|
||||
return server.SendVector(req, msgs...)
|
||||
})
|
||||
}
|
||||
|
||||
// Fallback sets fallback handler.
|
||||
func (d *Dispatcher) Fallback(h Handler) *Dispatcher {
|
||||
d.mux.Lock()
|
||||
d.fallback = h
|
||||
d.mux.Unlock()
|
||||
return d
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package tgtest provides test Telegram server for basic end-to-end tests.
|
||||
package tgtest
|
||||
@@ -0,0 +1,57 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
type exchangeConn struct {
|
||||
transport.Conn
|
||||
}
|
||||
|
||||
func (e exchangeConn) Recv(ctx context.Context, b *bin.Buffer) error {
|
||||
for {
|
||||
if err := e.Conn.Recv(ctx, b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var authKeyID [8]byte
|
||||
if err := b.PeekN(authKeyID[:], len(authKeyID)); err != nil {
|
||||
return errors.Wrap(err, "peek id")
|
||||
}
|
||||
if authKeyID != [8]byte{} {
|
||||
// TODO(tdakkota): what if client send registered auth key during key exchange?
|
||||
buf := bin.Buffer{}
|
||||
buf.PutInt32(-codec.CodeAuthKeyNotFound)
|
||||
|
||||
if err := e.Conn.Send(ctx, &buf); err != nil {
|
||||
return errors.Wrap(err, "send")
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// exchange starts MTProto key exchange.
|
||||
func (s *Server) exchange(ctx context.Context, conn transport.Conn) (crypto.AuthKey, error) {
|
||||
r, err := exchange.NewExchanger(conn, s.dcID).
|
||||
WithClock(s.clock).
|
||||
WithLogger(s.log.Named("exchange")).
|
||||
WithRand(s.cipher.Rand()).
|
||||
Server(s.key).Run(ctx)
|
||||
if err != nil {
|
||||
return crypto.AuthKey{}, err
|
||||
}
|
||||
|
||||
return r.Key, nil
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
func Test_exchangeConn_Recv(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
i := transport.Intermediate
|
||||
c1, c2 := i.Pipe()
|
||||
defer func() {
|
||||
a.NoError(c1.Close())
|
||||
a.NoError(c2.Close())
|
||||
}()
|
||||
e := exchangeConn{Conn: c1}
|
||||
|
||||
s := "abcdabcd"
|
||||
a.Len(s, 8)
|
||||
|
||||
grp := tdsync.NewCancellableGroup(ctx)
|
||||
grp.Go(func(ctx context.Context) error {
|
||||
b := bin.Buffer{Buf: []byte(s)}
|
||||
if err := c2.Send(ctx, &b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.Reset()
|
||||
var protocolErr *codec.ProtocolErr
|
||||
if err := c2.Recv(ctx, &b); err != nil && !errors.As(err, &protocolErr) {
|
||||
return err
|
||||
}
|
||||
|
||||
b.ResetN(8)
|
||||
b.Put([]byte(s))
|
||||
if err := c2.Send(ctx, &b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
var b bin.Buffer
|
||||
a.NoError(e.Recv(ctx, &b))
|
||||
b.Skip(8)
|
||||
a.Equal(s, string(b.Buf))
|
||||
|
||||
a.NoError(grp.Wait())
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
func (s *Server) rpcHandle(ctx context.Context, c transport.Conn, b *bin.Buffer) error {
|
||||
m := &crypto.EncryptedMessage{}
|
||||
if err := m.DecodeWithoutCopy(b); err != nil {
|
||||
return errors.Wrap(err, "decode encrypted message")
|
||||
}
|
||||
|
||||
key, ok := s.users.getSession(m.AuthKeyID)
|
||||
if !ok {
|
||||
return errors.New("invalid session")
|
||||
}
|
||||
|
||||
msg, err := s.cipher.Decrypt(key, m)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "decrypt message")
|
||||
}
|
||||
|
||||
session := Session{
|
||||
ID: msg.SessionID,
|
||||
AuthKey: key,
|
||||
}
|
||||
if conn := s.users.createConnection(msg.SessionID, c); !conn.sentCreated() {
|
||||
s.log.Debug("Send handleSessionCreated event", zap.Inline(session))
|
||||
salt := int64(binary.LittleEndian.Uint64(key.ID[:]))
|
||||
if err := s.sendSessionCreated(ctx, session, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Buffer now contains plaintext message payload.
|
||||
b.ResetTo(msg.Data())
|
||||
|
||||
if err := s.handle(&Request{
|
||||
DC: s.dcID,
|
||||
Session: session,
|
||||
MsgID: msg.MessageID,
|
||||
Buf: b,
|
||||
RequestCtx: ctx,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "handle")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handle(req *Request) error {
|
||||
in := req.Buf
|
||||
id, err := in.PeekID()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "peek id")
|
||||
}
|
||||
|
||||
s.log.Debug("Got request",
|
||||
zap.Inline(req.Session),
|
||||
zap.Int64("msg_id", req.MsgID),
|
||||
zap.String("type", s.types.Get(id)),
|
||||
)
|
||||
|
||||
// TODO(tdakkota): unpack all containers
|
||||
switch id {
|
||||
case mt.PingDelayDisconnectRequestTypeID:
|
||||
pingReq := mt.PingDelayDisconnectRequest{}
|
||||
if err := pingReq.Decode(in); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.SendPong(req, pingReq.PingID)
|
||||
case mt.PingRequestTypeID:
|
||||
pingReq := mt.PingRequest{}
|
||||
if err := pingReq.Decode(in); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.SendPong(req, pingReq.PingID)
|
||||
|
||||
case mt.GetFutureSaltsRequestTypeID:
|
||||
saltsRequest := mt.GetFutureSaltsRequest{}
|
||||
if err := saltsRequest.Decode(in); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.SendEternalSalt(req)
|
||||
|
||||
case mt.RPCDropAnswerRequestTypeID:
|
||||
drop := mt.RPCDropAnswerRequest{}
|
||||
if err := drop.Decode(in); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.SendResult(req, &mt.RPCAnswerDroppedRunning{})
|
||||
|
||||
case proto.GZIPTypeID:
|
||||
var content proto.GZIP
|
||||
if err := content.Decode(in); err != nil {
|
||||
return errors.Wrap(err, "gzip")
|
||||
}
|
||||
req.Buf = &bin.Buffer{Buf: content.Data}
|
||||
|
||||
case proto.MessageContainerTypeID:
|
||||
var container proto.MessageContainer
|
||||
if err := container.Decode(in); err != nil {
|
||||
return errors.Wrap(err, "container")
|
||||
}
|
||||
|
||||
var err error
|
||||
for _, msg := range container.Messages {
|
||||
err = multierr.Append(err, s.handle(&Request{
|
||||
DC: req.DC,
|
||||
Session: req.Session,
|
||||
MsgID: msg.ID,
|
||||
Buf: &bin.Buffer{Buf: msg.Body},
|
||||
RequestCtx: req.RequestCtx,
|
||||
}))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.handler.OnMessage(s, req); err != nil {
|
||||
var rpcErr *tgerr.Error
|
||||
if errors.As(err, &rpcErr) {
|
||||
return s.SendErr(req, rpcErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// Request represents MTProto RPC request structure.
|
||||
type Request struct {
|
||||
// DC ID from server structure.
|
||||
// Used to make handler less stateful.
|
||||
DC int
|
||||
// Session is a user session.
|
||||
Session Session
|
||||
// MsgID is a message ID of RPC request.
|
||||
MsgID int64
|
||||
// Buf contains RPC request
|
||||
Buf *bin.Buffer
|
||||
// RequestCtx is a request context.
|
||||
RequestCtx context.Context
|
||||
}
|
||||
|
||||
// Handler is a RPC request handler.
|
||||
type Handler interface {
|
||||
OnMessage(server *Server, req *Request) error
|
||||
}
|
||||
|
||||
var _ Handler = HandlerFunc(nil)
|
||||
|
||||
// HandlerFunc is functional adapter for Handler.OnMessage method.
|
||||
type HandlerFunc func(server *Server, req *Request) error
|
||||
|
||||
// OnMessage implements Handler.
|
||||
func (h HandlerFunc) OnMessage(server *Server, req *Request) error {
|
||||
return h(server, req)
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
func (s *Server) read(ctx context.Context, conn transport.Conn, b *bin.Buffer) error {
|
||||
b.Reset()
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, s.readTimeout)
|
||||
defer cancel()
|
||||
if err := conn.Recv(ctx, b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendProtoError(ctx context.Context, conn transport.Conn, e int32) error {
|
||||
var buf bin.Buffer
|
||||
buf.PutInt32(-e)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, s.writeTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := conn.Send(ctx, &buf); err != nil {
|
||||
return errors.Wrap(err, "send")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) serveConn(ctx context.Context, conn transport.Conn) error {
|
||||
s.log.Debug("User connected")
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
s.log.Debug("User disconnected")
|
||||
}()
|
||||
|
||||
b := new(bin.Buffer)
|
||||
for {
|
||||
if err := s.read(ctx, conn, b); err != nil {
|
||||
return errors.Wrap(err, "read")
|
||||
}
|
||||
|
||||
var authKeyID [8]byte
|
||||
if err := b.PeekN(authKeyID[:], len(authKeyID)); err != nil {
|
||||
return errors.Wrap(err, "peek id")
|
||||
}
|
||||
|
||||
// TODO(tdakkota): dispatch by type ID instead?
|
||||
if _, ok := s.users.getSession(authKeyID); ok {
|
||||
if err := s.rpcHandle(ctx, conn, b); err != nil {
|
||||
return errors.Wrap(err, "handle")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// If authKeyID not found and is not zero, so send protocol error.
|
||||
if authKeyID != [8]byte{} {
|
||||
if err := s.sendProtoError(ctx, conn, codec.CodeAuthKeyNotFound); err != nil {
|
||||
return errors.Wrap(err, "send AuthKeyNotFound")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
s.log.Debug("Starting key exchange")
|
||||
c := newBufferedConn(conn)
|
||||
c.Push(b)
|
||||
|
||||
key, err := s.exchange(ctx, exchangeConn{Conn: c})
|
||||
if err != nil {
|
||||
var exchangeErr *exchange.ServerExchangeError
|
||||
if errors.As(err, &exchangeErr) {
|
||||
code := exchangeErr.Code
|
||||
if err := s.sendProtoError(ctx, c, code); err != nil {
|
||||
return errors.Wrapf(err, "send proto error %v", code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(err, "key exchange failed")
|
||||
}
|
||||
|
||||
s.users.addSession(key)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package tgtest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/cluster"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services/config"
|
||||
)
|
||||
|
||||
func TestSessionHandle(t *testing.T) {
|
||||
test := func(storage session.Storage, t *testing.T) error {
|
||||
log := zaptest.NewLogger(t)
|
||||
defer func() { _ = log.Sync() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
c := cluster.NewCluster(cluster.Options{
|
||||
Logger: log.Named("cluster"),
|
||||
})
|
||||
d := c.Dispatch(2, "server").Fallback(services.NotImplemented)
|
||||
config.NewService(&tg.Config{}, &tg.CDNConfig{}).Register(d)
|
||||
|
||||
g.Go(c.Up)
|
||||
|
||||
g.Go(func(ctx context.Context) error {
|
||||
select {
|
||||
case <-c.Ready():
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
defer g.Cancel()
|
||||
|
||||
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{
|
||||
PublicKeys: c.Keys(),
|
||||
DC: 2,
|
||||
DCList: c.List(),
|
||||
Resolver: c.Resolver(),
|
||||
NoUpdates: true,
|
||||
Logger: log.Named("client"),
|
||||
SessionStorage: storage,
|
||||
RetryInterval: 100 * time.Millisecond,
|
||||
})
|
||||
|
||||
return client.Run(ctx, func(ctx context.Context) error {
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
return errors.Wrap(err, "wait")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
storage := session.StorageMemory{}
|
||||
a.NoError(test(&storage, t))
|
||||
|
||||
_, err := storage.Bytes(nil)
|
||||
a.NoError(err, "Must create new session")
|
||||
})
|
||||
t.Run("Unknown", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
loader := session.Loader{Storage: &session.StorageMemory{}}
|
||||
ctx := context.Background()
|
||||
|
||||
key := crypto.Key{}
|
||||
_, err := io.ReadFull(rand.Reader, key[:])
|
||||
a.NoError(err)
|
||||
authKey := key.WithID()
|
||||
|
||||
was := &session.Data{
|
||||
DC: 2,
|
||||
AuthKey: authKey.Value[:],
|
||||
AuthKeyID: authKey.ID[:],
|
||||
}
|
||||
a.NoError(loader.Save(context.Background(), was))
|
||||
a.NoError(test(loader.Storage, t))
|
||||
|
||||
data, err := loader.Load(ctx)
|
||||
a.NoError(err)
|
||||
a.NotEqual(was, data, "Must regenerate session")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// UnpackInvoke is a simple Handler middleware to unpack some Invoke*-like requests.
|
||||
// Including:
|
||||
//
|
||||
// tg.InvokeWithLayerRequest
|
||||
// tg.InitConnectionRequest
|
||||
// tg.InvokeWithoutUpdatesRequest
|
||||
func UnpackInvoke(next Handler) Handler {
|
||||
return HandlerFunc(func(srv *Server, req *Request) error {
|
||||
id, err := req.Buf.PeekID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(tdakkota): handle more Invoke* requests.
|
||||
var (
|
||||
obj = peekIDObject{}
|
||||
r bin.Decoder
|
||||
)
|
||||
for {
|
||||
switch id {
|
||||
case tg.InvokeWithLayerRequestTypeID:
|
||||
r = &tg.InvokeWithLayerRequest{
|
||||
Query: &obj,
|
||||
}
|
||||
// TODO(tdakkota): pass Layer to session.
|
||||
case tg.InitConnectionRequestTypeID:
|
||||
r = &tg.InitConnectionRequest{
|
||||
Query: &obj,
|
||||
}
|
||||
// TODO(tdakkota): pass DeviceInfo to session.
|
||||
case tg.InvokeWithoutUpdatesRequestTypeID:
|
||||
r = &tg.InvokeWithoutUpdatesRequest{
|
||||
Query: &obj,
|
||||
}
|
||||
// TODO(tdakkota): pass NoUpdates flag to session.
|
||||
default:
|
||||
return next.OnMessage(srv, req)
|
||||
}
|
||||
|
||||
if err := r.Decode(req.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
id = obj.TypeID
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type peekIDObject struct {
|
||||
TypeID uint32
|
||||
}
|
||||
|
||||
func (t *peekIDObject) Decode(b *bin.Buffer) error {
|
||||
id, err := b.PeekID()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "peek id")
|
||||
}
|
||||
t.TypeID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *peekIDObject) Encode(*bin.Buffer) error {
|
||||
return errors.New("peekIDObject must not be encoded")
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
const (
|
||||
// MessageServerResponse is a message type of RPC calls result.
|
||||
MessageServerResponse = proto.MessageServerResponse
|
||||
// MessageFromServer is a message type of server-side updates.
|
||||
MessageFromServer = proto.MessageFromServer
|
||||
)
|
||||
|
||||
// Send sends given message to user session k.
|
||||
// Parameter t denotes MTProto message type. It should be MessageServerResponse or MessageFromServer.
|
||||
func (s *Server) Send(ctx context.Context, k Session, t proto.MessageType, message bin.Encoder) error {
|
||||
conn, ok := s.users.getConnection(k.ID)
|
||||
if !ok {
|
||||
return errors.Errorf("send %T: invalid key: connection %s not found", message, k.AuthKey.String())
|
||||
}
|
||||
|
||||
var b bin.Buffer
|
||||
if err := message.Encode(&b); err != nil {
|
||||
return errors.Wrap(err, "encode")
|
||||
}
|
||||
|
||||
data := crypto.EncryptedMessageData{
|
||||
SessionID: k.ID,
|
||||
MessageDataLen: int32(b.Len()),
|
||||
MessageDataWithPadding: b.Copy(),
|
||||
MessageID: s.msgID.New(t),
|
||||
}
|
||||
|
||||
err := s.cipher.Encrypt(k.AuthKey, data, &b)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "encrypt")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, s.writeTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := conn.Send(ctx, &b); err != nil {
|
||||
return errors.Wrap(err, "send")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendReq(req *Request, t proto.MessageType, encoder bin.Encoder) error {
|
||||
return s.Send(req.RequestCtx, req.Session, t, encoder)
|
||||
}
|
||||
|
||||
// SendResult sends RPC answer using msg as result.
|
||||
func (s *Server) SendResult(req *Request, msg bin.Encoder) error {
|
||||
var buf bin.Buffer
|
||||
|
||||
if err := msg.Encode(&buf); err != nil {
|
||||
return errors.Wrap(err, "encode result")
|
||||
}
|
||||
|
||||
if err := s.sendReq(req, proto.MessageServerResponse, &proto.Result{
|
||||
RequestMessageID: req.MsgID,
|
||||
Result: buf.Raw(),
|
||||
}); err != nil {
|
||||
return errors.Wrapf(err, "send result [%T]", msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendGZIP sends RPC answer and packs it into proto.GZIP.
|
||||
func (s *Server) SendGZIP(req *Request, msg bin.Encoder) error {
|
||||
var buf bin.Buffer
|
||||
|
||||
if err := msg.Encode(&buf); err != nil {
|
||||
return errors.Wrap(err, "encode gzip data")
|
||||
}
|
||||
|
||||
return s.SendResult(req, proto.GZIP{Data: buf.Buf})
|
||||
}
|
||||
|
||||
// SendErr sends RPC answer using given error as result.
|
||||
func (s *Server) SendErr(req *Request, e *tgerr.Error) error {
|
||||
return s.SendResult(req, &mt.RPCError{
|
||||
ErrorCode: e.Code,
|
||||
ErrorMessage: e.Message,
|
||||
})
|
||||
}
|
||||
|
||||
// SendBool sends RPC answer using given bool as result.
|
||||
// Usually used in methods without explicit response.
|
||||
func (s *Server) SendBool(req *Request, r bool) error {
|
||||
var msg tg.BoolClass = &tg.BoolTrue{}
|
||||
if !r {
|
||||
msg = &tg.BoolFalse{}
|
||||
}
|
||||
return s.SendResult(req, msg)
|
||||
}
|
||||
|
||||
// SendVector sends RPC answer using given vector as result.
|
||||
func (s *Server) SendVector(req *Request, msgs ...bin.Encoder) error {
|
||||
return s.SendResult(req, &genericVector{Elems: msgs})
|
||||
}
|
||||
|
||||
// sendSessionCreated sends mt.NewSessionCreated `new_session_created` notification.
|
||||
func (s *Server) sendSessionCreated(ctx context.Context, k Session, serverSalt int64) error {
|
||||
if err := s.Send(ctx, k, proto.MessageFromServer, &mt.NewSessionCreated{
|
||||
FirstMsgID: s.msgID.New(proto.MessageFromClient),
|
||||
ServerSalt: serverSalt,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "send sessionCreated")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendPong sends response for mt.PingRequest request.
|
||||
func (s *Server) SendPong(req *Request, pingID int64) error {
|
||||
if err := s.sendReq(req, proto.MessageServerResponse, &mt.Pong{
|
||||
MsgID: req.MsgID,
|
||||
PingID: pingID,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "send pong")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendEternalSalt sends response for mt.GetFutureSaltsRequest.
|
||||
// It sends an `eternal` salt, which valid until maximum possible date.
|
||||
func (s *Server) SendEternalSalt(req *Request) error {
|
||||
return s.SendFutureSalts(req, mt.FutureSalt{
|
||||
ValidSince: 1,
|
||||
ValidUntil: math.MaxInt32,
|
||||
Salt: 10,
|
||||
})
|
||||
}
|
||||
|
||||
// SendFutureSalts sends response for mt.GetFutureSaltsRequest.
|
||||
func (s *Server) SendFutureSalts(req *Request, salts ...mt.FutureSalt) error {
|
||||
if err := s.Send(req.RequestCtx, req.Session, proto.MessageServerResponse, &mt.FutureSalts{
|
||||
ReqMsgID: req.MsgID,
|
||||
Now: int(s.clock.Now().Unix()),
|
||||
Salts: salts,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "send future salts")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendUpdates sends given updates to user session k.
|
||||
func (s *Server) SendUpdates(ctx context.Context, k Session, updates ...tg.UpdateClass) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.Send(ctx, k, proto.MessageFromServer, &tg.Updates{
|
||||
Updates: updates,
|
||||
Date: int(s.clock.Now().Unix()),
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "send updates")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendAck sends acknowledgment for received message.
|
||||
func (s *Server) SendAck(ctx context.Context, k Session, ids ...int64) error {
|
||||
if err := s.Send(ctx, k, proto.MessageFromServer, &mt.MsgsAck{MsgIDs: ids}); err != nil {
|
||||
return errors.Wrap(err, "send ack")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForceDisconnect forcibly disconnect user from server.
|
||||
// It deletes MTProto session (session_id), but not auth key.
|
||||
func (s *Server) ForceDisconnect(k Session) {
|
||||
s.users.deleteConnection(k.ID)
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
// Server is a MTProto server structure.
|
||||
type Server struct {
|
||||
// DC ID of this server.
|
||||
dcID int
|
||||
// Key pair of this server.
|
||||
key exchange.PrivateKey // immutable
|
||||
|
||||
// Codec constructor. May be nil.
|
||||
codec func() transport.Codec // immutable,nilable
|
||||
// Server-side message cipher.
|
||||
cipher crypto.Cipher // immutable
|
||||
// Clock to use in key exchange and message ID generation.
|
||||
clock clock.Clock // immutable
|
||||
// MessageID generator
|
||||
msgID mtproto.MessageIDSource // immutable
|
||||
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
|
||||
// RPC handler.
|
||||
handler Handler // immutable
|
||||
|
||||
// users stores session info.
|
||||
users *users
|
||||
|
||||
// type map for logging.
|
||||
types *tmap.Map // immutable
|
||||
log *zap.Logger // immutable
|
||||
}
|
||||
|
||||
// NewPrivateKey creates new private key from RSA private key.
|
||||
func NewPrivateKey(k *rsa.PrivateKey) exchange.PrivateKey {
|
||||
return exchange.PrivateKey{
|
||||
RSA: k,
|
||||
}
|
||||
}
|
||||
|
||||
// NewServer creates new Server.
|
||||
func NewServer(key exchange.PrivateKey, handler Handler, opts ServerOptions) *Server {
|
||||
opts.setDefaults()
|
||||
|
||||
s := &Server{
|
||||
dcID: opts.DC,
|
||||
key: key,
|
||||
codec: opts.Codec,
|
||||
cipher: crypto.NewServerCipher(opts.Random),
|
||||
clock: opts.Clock,
|
||||
msgID: opts.MessageID,
|
||||
readTimeout: opts.ReadTimeout,
|
||||
writeTimeout: opts.WriteTimeout,
|
||||
handler: handler,
|
||||
users: newUsers(),
|
||||
types: opts.Types,
|
||||
log: opts.Logger,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Key returns public key of this server.
|
||||
func (s *Server) Key() exchange.PublicKey {
|
||||
return s.key.Public()
|
||||
}
|
||||
|
||||
// Serve runs server loop using given listener.
|
||||
func (s *Server) Serve(ctx context.Context, l transport.Listener) error {
|
||||
return s.serve(ctx, l)
|
||||
}
|
||||
|
||||
func (s *Server) serve(ctx context.Context, l transport.Listener) error {
|
||||
s.log.Info("Serving")
|
||||
defer func() {
|
||||
s.log.Info("Stopping")
|
||||
}()
|
||||
|
||||
grp := tdsync.NewCancellableGroup(ctx)
|
||||
grp.Go(func(context.Context) error {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(err, "accept")
|
||||
}
|
||||
|
||||
grp.Go(func(ctx context.Context) error {
|
||||
if err := s.serveConn(ctx, conn); err != nil {
|
||||
// Client disconnected.
|
||||
var syscallErr *net.OpError
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
return nil
|
||||
case errors.As(err, &syscallErr) &&
|
||||
(syscallErr.Op == "write" || syscallErr.Op == "read"):
|
||||
return nil
|
||||
}
|
||||
// TODO(tdakkota): emulate errors too?
|
||||
if code := websocket.CloseStatus(err); code >= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.log.Info("Serving handler error", zap.Error(err))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
})
|
||||
grp.Go(func(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
return l.Close()
|
||||
})
|
||||
return grp.Wait()
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
// ServerOptions of Server.
|
||||
type ServerOptions struct {
|
||||
// DC ID of this server. Default to 2.
|
||||
DC int
|
||||
// Random is random source. Defaults to rand.Reader.
|
||||
Random io.Reader
|
||||
// Logger is instance of zap.Logger. No logs by default.
|
||||
Logger *zap.Logger
|
||||
// Codec constructor.
|
||||
// Defaults to nil (underlying transport server detects protocol automatically).
|
||||
Codec func() transport.Codec
|
||||
// Clock to use. Defaults to clock.System.
|
||||
Clock clock.Clock
|
||||
// MessageID generator. Creates a new proto.MessageIDGen by default.
|
||||
// Clock will be used for creation.
|
||||
MessageID mtproto.MessageIDSource
|
||||
// Types map, used in verbose logging of incoming message.
|
||||
Types *tmap.Map
|
||||
// ReadTimeout is a connection read timeout.
|
||||
ReadTimeout time.Duration
|
||||
// ReadTimeout is a connection write timeout.
|
||||
WriteTimeout time.Duration
|
||||
}
|
||||
|
||||
func (opt *ServerOptions) setDefaults() {
|
||||
if opt.DC == 0 {
|
||||
opt.DC = 2
|
||||
}
|
||||
if opt.Random == nil {
|
||||
opt.Random = crypto.DefaultRand()
|
||||
}
|
||||
if opt.Logger == nil {
|
||||
opt.Logger = zap.NewNop()
|
||||
}
|
||||
|
||||
// Ignore opt.Codec, will be handled by transport.NewCustomServer.
|
||||
|
||||
if opt.Clock == nil {
|
||||
opt.Clock = clock.System
|
||||
}
|
||||
if opt.MessageID == nil {
|
||||
opt.MessageID = proto.NewMessageIDGen(opt.Clock.Now)
|
||||
}
|
||||
if opt.Types == nil {
|
||||
opt.Types = tmap.New(
|
||||
tg.TypesMap(),
|
||||
mt.TypesMap(),
|
||||
proto.TypesMap(),
|
||||
)
|
||||
}
|
||||
if opt.ReadTimeout == 0 {
|
||||
opt.ReadTimeout = 30 * time.Second
|
||||
}
|
||||
if opt.WriteTimeout == 0 {
|
||||
opt.WriteTimeout = 30 * time.Second
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
// Package config contains config service implementation for tgtest server.
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services"
|
||||
)
|
||||
|
||||
// Service is a Telegram config service.
|
||||
type Service struct {
|
||||
cfg *tg.Config
|
||||
cdnCfg *tg.CDNConfig
|
||||
}
|
||||
|
||||
// NewService creates new Service.
|
||||
func NewService(cfg *tg.Config, cdnCfg *tg.CDNConfig) *Service {
|
||||
return &Service{cfg: cfg, cdnCfg: cdnCfg}
|
||||
}
|
||||
|
||||
func (c *Service) HelpGetCDNConfig(ctx context.Context, req *tg.HelpGetCDNConfigRequest) (*tg.CDNConfig, error) {
|
||||
cfg := c.cdnCfg
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (c *Service) HelpGetConfig(ctx context.Context, dc int, req *tg.HelpGetConfigRequest) (*tg.Config, error) {
|
||||
cfg := *c.cfg
|
||||
cfg.ThisDC = dc
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// OnMessage implements tgtest.Handler.
|
||||
func (c *Service) OnMessage(server *tgtest.Server, req *tgtest.Request) error {
|
||||
id, err := req.Buf.PeekID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
decode bin.Decoder
|
||||
result bin.Encoder
|
||||
)
|
||||
switch id {
|
||||
case tg.HelpGetCDNConfigRequestTypeID:
|
||||
cfg := c.cdnCfg
|
||||
|
||||
decode = &tg.HelpGetCDNConfigRequest{}
|
||||
result = cfg
|
||||
case tg.HelpGetConfigRequestTypeID:
|
||||
cfg := *c.cfg
|
||||
cfg.ThisDC = req.DC
|
||||
|
||||
decode = &tg.HelpGetConfigRequest{}
|
||||
result = &cfg
|
||||
default:
|
||||
return services.ErrMethodNotImplemented
|
||||
}
|
||||
|
||||
if err := decode.Decode(req.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return server.SendResult(req, result)
|
||||
}
|
||||
|
||||
// Register registers service handlers.
|
||||
func (c *Service) Register(dispatcher *tgtest.Dispatcher) {
|
||||
dispatcher.HandleFunc(tg.HelpGetCDNConfigRequestTypeID, c.OnMessage)
|
||||
dispatcher.HandleFunc(tg.HelpGetConfigRequestTypeID, c.OnMessage)
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package services contains some Telegram services implemented for testing.
|
||||
package services
|
||||
@@ -0,0 +1,16 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrMethodNotImplemented denotes that method is not implemented.
|
||||
ErrMethodNotImplemented error = tgerr.New(400, "INPUT_METHOD_INVALID")
|
||||
|
||||
// NotImplemented is a simple handler which returns ErrMethodNotImplemented.
|
||||
NotImplemented tgtest.HandlerFunc = func(server *tgtest.Server, req *tgtest.Request) error {
|
||||
return ErrMethodNotImplemented
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,24 @@
|
||||
package file
|
||||
|
||||
type Config struct {
|
||||
// Storage to store files.
|
||||
// InMemory will be used.
|
||||
Storage Storage
|
||||
// HashPartSize is a size of part to use in tg.FileHash.
|
||||
HashPartSize int
|
||||
// HashRangeSize is size of range to return in upload.getFileHashes.
|
||||
HashRangeSize int
|
||||
}
|
||||
|
||||
func (c *Config) setDefaults() {
|
||||
if c.Storage == nil {
|
||||
c.Storage = NewInMemory()
|
||||
}
|
||||
// Telegram usually uses this values.
|
||||
if c.HashPartSize == 0 {
|
||||
c.HashPartSize = 131072
|
||||
}
|
||||
if c.HashRangeSize == 0 {
|
||||
c.HashRangeSize = 10
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
func getLocation(loc tg.InputFileLocationClass) (string, error) {
|
||||
v, ok := loc.(interface {
|
||||
GetLocalID() int
|
||||
GetVolumeID() int64
|
||||
})
|
||||
if !ok {
|
||||
return "", tgerr.New(400, tg.ErrFileIDInvalid)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%d_%d", v.GetLocalID(), v.GetVolumeID()), nil
|
||||
}
|
||||
|
||||
func (m *Service) openLocation(loc tg.InputFileLocationClass) (File, error) {
|
||||
name, err := getLocation(loc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := m.storage.Open(name)
|
||||
if err != nil {
|
||||
return nil, tgerr.New(400, tg.ErrFileIDInvalid)
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (m *Service) getPart(loc tg.InputFileLocationClass, offset int64, limit int) ([]byte, error) {
|
||||
f, err := m.openLocation(loc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := make([]byte, limit)
|
||||
n, err := f.ReadAt(r, offset)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read from storage")
|
||||
}
|
||||
|
||||
return r[:n], nil
|
||||
}
|
||||
|
||||
func (m *Service) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) {
|
||||
data, err := m.getPart(request.Location, request.Offset, request.Limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tg.UploadFile{
|
||||
Type: &tg.StorageFilePartial{},
|
||||
Mtime: 0,
|
||||
Bytes: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func countHashes(data []byte, offset int64, partSize int) []tg.FileHash {
|
||||
actions := data
|
||||
batchSize := partSize
|
||||
batches := make([][]byte, 0, (len(actions)+batchSize-1)/batchSize)
|
||||
|
||||
for batchSize < len(actions) {
|
||||
actions, batches = actions[batchSize:], append(batches, actions[0:batchSize:batchSize])
|
||||
}
|
||||
batches = append(batches, actions)
|
||||
|
||||
currentRange := make([]tg.FileHash, 0, 10)
|
||||
for _, batch := range batches {
|
||||
currentRange = append(currentRange, tg.FileHash{
|
||||
Offset: offset,
|
||||
Limit: partSize,
|
||||
Hash: crypto.SHA256(batch),
|
||||
})
|
||||
offset += int64(len(batch))
|
||||
}
|
||||
return currentRange
|
||||
}
|
||||
|
||||
func divAndCeil(a, b int) int {
|
||||
r := a / b
|
||||
if a%b != 0 {
|
||||
r++
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// computeBatch computes hash range number for given offset.
|
||||
func computeBatch(offset int64, rangeSize, partSize int) int {
|
||||
// Compute number of parts in partSize from offset.
|
||||
parts := divAndCeil(int(offset+1), partSize)
|
||||
// Compute number of hash ranges in rangeSize.
|
||||
batches := divAndCeil(parts, rangeSize)
|
||||
|
||||
return batches
|
||||
}
|
||||
|
||||
func (m *Service) UploadGetFileHashes(
|
||||
ctx context.Context,
|
||||
request *tg.UploadGetFileHashesRequest,
|
||||
) ([]tg.FileHash, error) {
|
||||
f, err := m.openLocation(request.Location)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if request.Offset >= int64(f.Size()) {
|
||||
return nil, nil
|
||||
}
|
||||
partSize := m.hashPartSize
|
||||
rangeSize := m.hashRangeSize
|
||||
batch := computeBatch(request.Offset, rangeSize, partSize)
|
||||
|
||||
low := (batch - 1) * rangeSize * partSize
|
||||
high := batch * rangeSize * partSize
|
||||
|
||||
r := make([]byte, high-low)
|
||||
n, err := f.ReadAt(r, int64(low))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r = r[:n]
|
||||
|
||||
return countHashes(r, int64(low), partSize), nil
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
// Package file contains file service implementation for tgtest server.
|
||||
package file
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services"
|
||||
)
|
||||
|
||||
// Service is a Telegram file service.
|
||||
type Service struct {
|
||||
storage Storage
|
||||
// Size of part to use in tg.FileHash.
|
||||
hashPartSize int
|
||||
// Size of range to return in upload.getFileHashes.
|
||||
hashRangeSize int
|
||||
}
|
||||
|
||||
// NewService creates new file Service.
|
||||
func NewService(cfg Config) *Service {
|
||||
cfg.setDefaults()
|
||||
|
||||
return &Service{
|
||||
storage: cfg.Storage,
|
||||
hashPartSize: cfg.HashPartSize,
|
||||
hashRangeSize: cfg.HashRangeSize,
|
||||
}
|
||||
}
|
||||
|
||||
// OnMessage implements tgtest.Handler.
|
||||
func (m *Service) OnMessage(server *tgtest.Server, req *tgtest.Request) error {
|
||||
id, err := req.Buf.PeekID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch id {
|
||||
case tg.UploadGetFileRequestTypeID:
|
||||
fileReq := tg.UploadGetFileRequest{}
|
||||
if err := fileReq.Decode(req.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r, err := m.UploadGetFile(req.RequestCtx, &fileReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return server.SendResult(req, r)
|
||||
case tg.UploadGetFileHashesRequestTypeID:
|
||||
fileReq := tg.UploadGetFileHashesRequest{}
|
||||
if err := fileReq.Decode(req.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r, err := m.UploadGetFileHashes(req.RequestCtx, &fileReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return server.SendResult(req, &tg.FileHashVector{Elems: r})
|
||||
case tg.UploadSaveFilePartRequestTypeID:
|
||||
fileReq := tg.UploadSaveFilePartRequest{}
|
||||
if err := fileReq.Decode(req.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r, err := m.UploadSaveFilePart(req.RequestCtx, &fileReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return server.SendBool(req, r)
|
||||
case tg.UploadSaveBigFilePartRequestTypeID:
|
||||
fileReq := tg.UploadSaveBigFilePartRequest{}
|
||||
if err := fileReq.Decode(req.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r, err := m.UploadSaveBigFilePart(req.RequestCtx, &fileReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return server.SendBool(req, r)
|
||||
default:
|
||||
return services.ErrMethodNotImplemented
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers service handlers.
|
||||
func (m *Service) Register(dispatcher *tgtest.Dispatcher) {
|
||||
dispatcher.HandleFunc(tg.UploadGetFileRequestTypeID, m.OnMessage)
|
||||
dispatcher.HandleFunc(tg.UploadGetFileHashesRequestTypeID, m.OnMessage)
|
||||
dispatcher.HandleFunc(tg.UploadSaveFilePartRequestTypeID, m.OnMessage)
|
||||
dispatcher.HandleFunc(tg.UploadSaveBigFilePartRequestTypeID, m.OnMessage)
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
|
||||
)
|
||||
|
||||
// File represents Telegram file.
|
||||
type File interface {
|
||||
io.ReaderAt
|
||||
io.WriterAt
|
||||
io.Closer
|
||||
PartSize() int
|
||||
SetPartSize(v int)
|
||||
Size() int
|
||||
}
|
||||
|
||||
// Storage is an abstraction for Telegram file storage.
|
||||
type Storage interface {
|
||||
Open(name string) (File, error)
|
||||
}
|
||||
|
||||
type memFile struct {
|
||||
syncio.BufWriterAt
|
||||
partSize int32
|
||||
_ [4]byte
|
||||
}
|
||||
|
||||
func (m *memFile) Size() int {
|
||||
return m.Len()
|
||||
}
|
||||
|
||||
func (m *memFile) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memFile) PartSize() int {
|
||||
return int(atomic.LoadInt32(&m.partSize))
|
||||
}
|
||||
|
||||
func (m *memFile) SetPartSize(v int) {
|
||||
atomic.StoreInt32(&m.partSize, int32(v))
|
||||
}
|
||||
|
||||
// InMemory is an inmemory implementation of file storage.
|
||||
type InMemory struct {
|
||||
files map[string]*memFile
|
||||
filesMux sync.Mutex
|
||||
}
|
||||
|
||||
// NewInMemory creates new InMemory.
|
||||
func NewInMemory() *InMemory {
|
||||
return &InMemory{
|
||||
files: map[string]*memFile{},
|
||||
}
|
||||
}
|
||||
|
||||
// Open implement Storage.
|
||||
func (i *InMemory) Open(name string) (File, error) {
|
||||
i.filesMux.Lock()
|
||||
defer i.filesMux.Unlock()
|
||||
file, ok := i.files[name]
|
||||
if !ok {
|
||||
file = &memFile{}
|
||||
i.files[name] = file
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/constant"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
// https://core.telegram.org/api/files#uploading-files
|
||||
const (
|
||||
// Each part should have a sequence number, file_part, with a value ranging from 0 to 3,999.
|
||||
uploadPartsLimit = constant.UploadMaxParts
|
||||
|
||||
// `part_size % 1024 = 0` (divisible by 1KB)
|
||||
uploadPaddingPartSize = constant.UploadPadding
|
||||
// `524288 % part_size = 0` (512KB must be evenly divisible by part_size)
|
||||
uploadMaximumPartSize = constant.UploadMaxPartSize
|
||||
)
|
||||
|
||||
type upload interface {
|
||||
// GetFileID returns random file identifier created by the client.
|
||||
GetFileID() int64
|
||||
// GetFilePart returns numerical order of a part.
|
||||
GetFilePart() int
|
||||
// GetBytes returns binary data, content of a part.
|
||||
GetBytes() []byte
|
||||
}
|
||||
|
||||
func validatePartSize(got, stored int) *tgerr.Error {
|
||||
switch {
|
||||
case got == 0:
|
||||
return tgerr.New(400, tg.ErrFilePartEmpty)
|
||||
case got > uploadMaximumPartSize:
|
||||
return tgerr.New(400, tg.ErrFilePartTooBig)
|
||||
}
|
||||
|
||||
if stored == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case got != stored:
|
||||
return tgerr.New(400, tg.ErrFilePartSizeChanged)
|
||||
case uploadMaximumPartSize%got != 0,
|
||||
got%uploadPaddingPartSize != 0:
|
||||
return tgerr.New(400, tg.ErrFilePartSizeInvalid)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Service) write(ctx context.Context, request upload) (err error) {
|
||||
// TODO(tdakkota): Better way to handle user id. For now we haven't auth service to pair
|
||||
// user ID and authkey
|
||||
id, ok := ctx.Value("user_id").(int)
|
||||
if !ok {
|
||||
id = 10
|
||||
}
|
||||
|
||||
file, err := m.storage.Open(fmt.Sprintf("%d_%d", id, request.GetFileID()))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "open file")
|
||||
}
|
||||
defer func() {
|
||||
multierr.AppendInto(&err, file.Close())
|
||||
}()
|
||||
|
||||
part := request.GetFilePart()
|
||||
if part < 0 || part > uploadPartsLimit {
|
||||
return tgerr.New(400, tg.ErrFilePartInvalid)
|
||||
}
|
||||
data := request.GetBytes()
|
||||
partSize := file.PartSize()
|
||||
if err := validatePartSize(len(data), partSize); err != nil {
|
||||
return err
|
||||
}
|
||||
if partSize == 0 {
|
||||
partSize = len(data)
|
||||
file.SetPartSize(partSize)
|
||||
}
|
||||
|
||||
offset := int64(partSize * part)
|
||||
if _, err := file.WriteAt(data, offset); err != nil {
|
||||
return errors.Errorf("write at %d-%d", offset, offset+int64(len(data)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Service) UploadSaveFilePart(ctx context.Context, request *tg.UploadSaveFilePartRequest) (bool, error) {
|
||||
if err := m.write(ctx, request); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *Service) UploadSaveBigFilePart(ctx context.Context, request *tg.UploadSaveBigFilePartRequest) (bool, error) {
|
||||
part := request.FileTotalParts
|
||||
if part < 0 || part > uploadPartsLimit {
|
||||
return false, tgerr.New(400, tg.ErrFilePartsInvalid)
|
||||
}
|
||||
|
||||
if err := m.write(ctx, request); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
)
|
||||
|
||||
// Session represents connection session.
|
||||
type Session struct {
|
||||
// ID is a Session ID.
|
||||
ID int64
|
||||
// AuthKey is an attached key.
|
||||
AuthKey crypto.AuthKey
|
||||
}
|
||||
|
||||
// MarshalLogObject implements zap.ObjectMarshaler.
|
||||
func (s Session) MarshalLogObject(encoder zapcore.ObjectEncoder) error {
|
||||
encoder.AddInt64("session_id", s.ID)
|
||||
encoder.AddString("key_id", hex.EncodeToString(s.AuthKey.ID[:]))
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type testTransportHandler struct {
|
||||
t testing.TB
|
||||
logger *zap.Logger
|
||||
// For ACK testing proposes.
|
||||
// We send ack only after second request
|
||||
counter int
|
||||
counterMx sync.Mutex
|
||||
|
||||
message string // immutable
|
||||
}
|
||||
|
||||
// TestTransport is a handler for testing MTProto transport.
|
||||
func TestTransport(t testing.TB, logger *zap.Logger, message string) Handler {
|
||||
return &testTransportHandler{
|
||||
t: t,
|
||||
logger: logger,
|
||||
message: message,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *testTransportHandler) OnMessage(server *Server, req *Request) error {
|
||||
id, err := req.Buf.PeekID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Info("New message", zap.String("id", fmt.Sprintf("%#x", id)))
|
||||
|
||||
switch id {
|
||||
case tg.UsersGetUsersRequestTypeID:
|
||||
getUsers := tg.UsersGetUsersRequest{}
|
||||
|
||||
if err := getUsers.Decode(req.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Info("New client connected, invoke received")
|
||||
|
||||
if err := server.SendVector(req, &tg.User{
|
||||
ID: 10,
|
||||
AccessHash: 10,
|
||||
Username: "rustcocks",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Info("Sending message", zap.String("message", h.message))
|
||||
return server.SendUpdates(req.RequestCtx, req.Session, &tg.UpdateNewMessage{
|
||||
Message: &tg.Message{
|
||||
ID: 1,
|
||||
PeerID: &tg.PeerUser{UserID: 1},
|
||||
Message: h.message,
|
||||
},
|
||||
})
|
||||
case tg.MessagesSendMessageRequestTypeID:
|
||||
m := &tg.MessagesSendMessageRequest{}
|
||||
if err := m.Decode(req.Buf); err != nil {
|
||||
h.t.Fail()
|
||||
return err
|
||||
}
|
||||
|
||||
assert.Equal(h.t, "какими деньгами?", m.Message)
|
||||
|
||||
h.counterMx.Lock()
|
||||
h.counter++
|
||||
if h.counter < 2 {
|
||||
h.counterMx.Unlock()
|
||||
return nil
|
||||
}
|
||||
h.counterMx.Unlock()
|
||||
|
||||
return server.SendResult(req, &tg.Updates{})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// genericVector is a simple helper to encode a vector of TL objects.
|
||||
type genericVector struct {
|
||||
Elems []bin.Encoder
|
||||
}
|
||||
|
||||
// Encode implements bin.Encoder.
|
||||
func (vec *genericVector) Encode(b *bin.Buffer) error {
|
||||
b.PutVectorHeader(len(vec.Elems))
|
||||
for idx, v := range vec.Elems {
|
||||
if v == nil {
|
||||
return fmt.Errorf("unable to encode Vector<%T>: field Elems element with index %d is nil", v, idx)
|
||||
}
|
||||
if err := v.Encode(b); err != nil {
|
||||
return fmt.Errorf("unable to encode Vector<%T>: field Elems element with index %d: %w", v, idx, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package tgtest
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type badEncoder struct{}
|
||||
|
||||
func (e badEncoder) Encode(b *bin.Buffer) error {
|
||||
return testutil.TestError()
|
||||
}
|
||||
|
||||
func TestGenericVector_Encode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []bin.Encoder
|
||||
expect bin.Decoder
|
||||
wantErr bool
|
||||
}{
|
||||
{"Empty", nil, nil, false},
|
||||
{"Nil", []bin.Encoder{nil}, nil, true},
|
||||
{"BadObject", []bin.Encoder{badEncoder{}}, nil, true},
|
||||
{
|
||||
"Plain",
|
||||
[]bin.Encoder{&tg.BotCommand{
|
||||
Command: "hello",
|
||||
Description: "world",
|
||||
}},
|
||||
&tg.BotCommandVector{},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var (
|
||||
a = require.New(t)
|
||||
v = genericVector{Elems: test.data}
|
||||
buf bin.Buffer
|
||||
)
|
||||
|
||||
err := v.Encode(&buf)
|
||||
if test.wantErr {
|
||||
a.Error(err)
|
||||
} else if test.expect != nil {
|
||||
a.NoError(test.expect.Decode(&buf))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user