move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
+63
View File
@@ -0,0 +1,63 @@
package tgtest
import (
"context"
"sync"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
type bufferedConn struct {
conn transport.Conn
recv []bin.Buffer
recvMux sync.Mutex
}
func newBufferedConn(conn transport.Conn) *bufferedConn {
return &bufferedConn{conn: conn}
}
func (c *bufferedConn) push(b *bin.Buffer) {
c.recvMux.Lock()
c.recv = append(c.recv, bin.Buffer{Buf: b.Copy()})
c.recvMux.Unlock()
}
func (c *bufferedConn) pop() (r bin.Buffer, ok bool) {
c.recvMux.Lock()
defer c.recvMux.Unlock()
if len(c.recv) < 1 {
return
}
r, c.recv = c.recv[len(c.recv)-1], c.recv[:len(c.recv)-1]
ok = true
return
}
func (c *bufferedConn) Push(b *bin.Buffer) {
c.push(b)
}
func (c *bufferedConn) Pop() (bin.Buffer, bool) {
return c.pop()
}
func (c *bufferedConn) Send(ctx context.Context, b *bin.Buffer) error {
return c.conn.Send(ctx, b)
}
func (c *bufferedConn) Recv(ctx context.Context, b *bin.Buffer) error {
e, ok := c.Pop()
if ok {
b.ResetTo(e.Copy())
return nil
}
return c.conn.Recv(ctx, b)
}
func (c *bufferedConn) Close() error {
return c.conn.Close()
}
+64
View File
@@ -0,0 +1,64 @@
package tgtest
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
func TestBufferedConn(t *testing.T) {
a := assert.New(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
i := transport.Intermediate
c1, c2 := i.Pipe()
b := newBufferedConn(c1)
defer func() {
a.NoError(b.Close())
a.NoError(c2.Close())
}()
payload := []byte("abcdabcd")
go func() {
b1 := &bin.Buffer{Buf: payload}
a.NoError(c2.Send(ctx, b1))
}()
// Test Recv before Push.
recvBuf := &bin.Buffer{}
a.NoError(b.Recv(ctx, recvBuf))
a.Equal(payload, recvBuf.Buf)
pushed := []byte("12345678")
b.Push(&bin.Buffer{Buf: pushed})
go func() {
b1 := &bin.Buffer{Buf: payload}
a.NoError(c2.Send(ctx, b1))
}()
// Test Push.
recvBuf.Reset()
a.NoError(b.Recv(ctx, recvBuf))
a.Equal(pushed, recvBuf.Buf)
// Test Recv after Push.
recvBuf.Reset()
a.NoError(b.Recv(ctx, recvBuf))
a.Equal(payload, recvBuf.Buf)
// Test send.
go func() {
b1 := &bin.Buffer{Buf: payload}
a.NoError(b.Send(ctx, b1))
}()
recvBuf.Reset()
a.NoError(c2.Recv(ctx, recvBuf))
a.Equal(payload, recvBuf.Buf)
}
+96
View File
@@ -0,0 +1,96 @@
// Package cluster contains Telegram multi-DC setup utilities.
package cluster
import (
"io"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services/config"
)
type setup struct {
srv *tgtest.Server
dispatch *tgtest.Dispatcher
}
// Cluster is a cluster of multiple servers, representing multiple Telegram datacenters.
type Cluster struct {
// denotes to use websocket listener
web bool
setups map[int]setup
keys []exchange.PublicKey
// DCs config state.
cfg tg.Config
cdnCfg tg.CDNConfig
domains map[int]string
// Signal for readiness.
ready *tdsync.Ready
// RPC dispatcher.
common *tgtest.Dispatcher
log *zap.Logger
random io.Reader
protocol dcs.Protocol
}
// NewCluster creates new server Cluster.
func NewCluster(opts Options) *Cluster {
opts.setDefaults()
q := &Cluster{
web: opts.Web,
setups: map[int]setup{},
keys: nil,
cfg: opts.Config,
cdnCfg: opts.CDNConfig,
domains: map[int]string{},
ready: tdsync.NewReady(),
common: tgtest.NewDispatcher(),
log: opts.Logger,
random: opts.Random,
protocol: opts.Protocol,
}
config.NewService(&q.cfg, &q.cdnCfg).Register(q.common)
q.common.Fallback(q.fallback())
return q
}
// List returns DCs list.
func (c *Cluster) List() dcs.List {
return dcs.List{
Options: c.cfg.DCOptions,
Domains: c.domains,
}
}
// Resolver returns dcs.Resolver to use.
func (c *Cluster) Resolver() dcs.Resolver {
if c.web {
return dcs.Websocket(dcs.WebsocketOptions{})
}
return dcs.Plain(dcs.PlainOptions{
Protocol: c.protocol,
})
}
// Keys returns all servers public keys.
func (c *Cluster) Keys() []exchange.PublicKey {
return c.keys
}
// Ready returns signal channel to await readiness.
func (c *Cluster) Ready() <-chan struct{} {
return c.ready.Ready()
}
+70
View File
@@ -0,0 +1,70 @@
package cluster
import (
"crypto/rsa"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
// Common returns common dispatcher.
func (c *Cluster) Common() *tgtest.Dispatcher {
return c.common
}
func (c *Cluster) getCodec() (codec func() transport.Codec) {
if !c.web {
codec = c.protocol.Codec
}
return codec
}
// DC registers new server and returns it.
func (c *Cluster) DC(id int, name string) (*tgtest.Server, *tgtest.Dispatcher) {
if s, ok := c.setups[id]; ok {
return s.srv, s.dispatch
}
key, err := rsa.GenerateKey(c.random, crypto.RSAKeyBits)
if err != nil {
// TODO(tdakkota): Return error instead.
panic(err)
}
privateKey := exchange.PrivateKey{
RSA: key,
}
d := tgtest.NewDispatcher()
server := tgtest.NewServer(privateKey, tgtest.UnpackInvoke(d), tgtest.ServerOptions{
DC: id,
Logger: c.log.Named(name).With(zap.Int("dc_id", id)),
Codec: c.getCodec(),
})
c.setups[id] = setup{
srv: server,
dispatch: d,
}
c.keys = append(c.keys, server.Key())
// We set server fallback handler to dispatch request in order
// 1) Explicit DC handler
// 2) Explicit common handler
// 3) Common fallback
d.Fallback(c.Common())
return server, d
}
// Dispatch registers new server and returns its dispatcher.
func (c *Cluster) Dispatch(id int, name string) *tgtest.Dispatcher {
_, d := c.DC(id, name)
return d
}
func (c *Cluster) fallback() tgtest.HandlerFunc {
return services.NotImplemented
}
+17
View File
@@ -0,0 +1,17 @@
package cluster
import (
"context"
"net"
"github.com/go-faster/errors"
)
func newLocalListener(ctx context.Context) (net.Listener, error) {
cfg := net.ListenConfig{}
l, err := cfg.Listen(ctx, "tcp4", "127.0.0.1:0")
if err != nil {
return nil, errors.Wrap(err, "listen")
}
return l, nil
}
+45
View File
@@ -0,0 +1,45 @@
package cluster
import (
"io"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
// Options of Cluster.
type Options struct {
// Web denotes to use websocket listener.
Web bool
// Random is random source. Used to generate RSA keys.
// Defaults to rand.Reader.
Random io.Reader
// Logger is instance of zap.Logger. No logs by default.
Logger *zap.Logger
// Codec constructor.
// Defaults to nil (underlying transport server detects protocol automatically).
Protocol dcs.Protocol
// Config is an initial cluster config.
Config tg.Config
// CDNConfig is an initial cluster CDN config.
CDNConfig tg.CDNConfig
}
func (opt *Options) setDefaults() {
// It's okay to use zero value Web.
if opt.Random == nil {
opt.Random = crypto.DefaultRand()
}
if opt.Logger == nil {
opt.Logger = zap.NewNop()
}
if opt.Protocol == nil {
opt.Protocol = transport.Intermediate
}
// It's okay to use zero value Config.
// It's okay to use zero value CDNConfig.
}
+97
View File
@@ -0,0 +1,97 @@
package cluster
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"time"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
// Up runs all servers in a cluster.
func (c *Cluster) Up(ctx context.Context) error {
g := tdsync.NewCancellableGroup(ctx)
listen := func(ctx context.Context, _ int) (net.Listener, error) {
return newLocalListener(ctx)
}
if c.web {
// Create local random listener
l, err := newLocalListener(ctx)
if err != nil {
return err
}
mux := http.NewServeMux()
srv := http.Server{
ReadHeaderTimeout: time.Second * 10,
Handler: mux,
BaseContext: func(net.Listener) context.Context {
return ctx
},
}
g.Go(func(ctx context.Context) error {
if err := srv.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) {
return errors.Wrap(err, "serve")
}
return nil
})
g.Go(func(ctx context.Context) error {
<-ctx.Done()
return srv.Close()
})
baseURL := url.URL{
Scheme: "http",
Host: l.Addr().String(),
}
listen = func(ctx context.Context, dc int) (net.Listener, error) {
listener, handler := transport.WebsocketListener(l.Addr())
path := fmt.Sprintf("/dc/%d", dc)
mux.Handle(path, handler)
dcURL := baseURL
dcURL.Path = path
c.domains[dc] = dcURL.String()
return listener, nil
}
}
for dcID, s := range c.setups {
l, err := listen(ctx, dcID)
if err != nil {
return errors.Wrapf(err, "DC %d: listen port", dcID)
}
if !c.web {
// Add TCP listeners to config.
if addr, ok := l.Addr().(*net.TCPAddr); ok {
c.cfg.DCOptions = append(c.cfg.DCOptions, tg.DCOption{
Ipv6: addr.IP.To16() != nil,
Static: true,
ID: dcID,
IPAddress: addr.IP.String(),
Port: addr.Port,
})
}
}
// Copy iteration value.
srv := s.srv
g.Go(func(ctx context.Context) error {
return srv.Serve(ctx, transport.ListenCodec(nil, l))
})
}
c.ready.Signal()
return g.Wait()
}
+92
View File
@@ -0,0 +1,92 @@
package tgtest
import (
"sync"
"go.uber.org/atomic"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
type connection struct {
transport.Conn
sent atomic.Bool
}
func (conn *connection) sentCreated() bool {
return conn.sent.Swap(true)
}
// users contains all server connections and sessions.
type users struct {
sessions map[[8]byte]crypto.AuthKey
sessionsMux sync.Mutex
conns map[int64]*connection
connsMux sync.Mutex
}
func newUsers() *users {
return &users{
conns: map[int64]*connection{},
sessions: map[[8]byte]crypto.AuthKey{},
}
}
func (c *users) createConnection(key int64, tConn transport.Conn) *connection {
c.connsMux.Lock()
defer c.connsMux.Unlock()
if v, ok := c.conns[key]; ok {
return v
}
conn := &connection{
Conn: tConn,
}
c.conns[key] = conn
return conn
}
func (c *users) getConnection(key int64) (conn *connection, ok bool) {
c.connsMux.Lock()
conn, ok = c.conns[key]
c.connsMux.Unlock()
return
}
func (c *users) deleteConnection(key int64) {
c.connsMux.Lock()
conn := c.conns[key]
if conn != nil {
_ = conn.Close()
}
delete(c.conns, key)
c.connsMux.Unlock()
}
func (c *users) addSession(key crypto.AuthKey) {
c.sessionsMux.Lock()
c.sessions[key.ID] = key
c.sessionsMux.Unlock()
}
func (c *users) getSession(k [8]byte) (s crypto.AuthKey, ok bool) {
c.connsMux.Lock()
s, ok = c.sessions[k]
c.connsMux.Unlock()
return
}
func (c *users) Close() error {
c.connsMux.Lock()
for _, conn := range c.conns {
_ = conn.Close()
}
c.connsMux.Unlock()
return nil
}
+82
View File
@@ -0,0 +1,82 @@
package tgtest
import (
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// Dispatcher is a plain handler to map requests by ID.
type Dispatcher struct {
reqs map[uint32]Handler
mux sync.Mutex
fallback Handler
}
// NewDispatcher creates new Dispatcher.
func NewDispatcher() *Dispatcher {
return &Dispatcher{
reqs: map[uint32]Handler{},
}
}
// OnMessage implements Handler
func (d *Dispatcher) OnMessage(server *Server, req *Request) error {
id, err := req.Buf.PeekID()
if err != nil {
return errors.Wrap(err, "peek id")
}
d.mux.Lock()
h, ok := d.reqs[id]
fallback := d.fallback
d.mux.Unlock()
if ok {
return h.OnMessage(server, req)
}
if fallback != nil {
return fallback.OnMessage(server, req)
}
return errors.Errorf("unexpected type %#x", id)
}
// Handle sets handler for given TypeID.
func (d *Dispatcher) Handle(id uint32, h Handler) *Dispatcher {
d.mux.Lock()
d.reqs[id] = h
d.mux.Unlock()
return d
}
// HandleFunc sets handler for given TypeID.
func (d *Dispatcher) HandleFunc(id uint32, h func(server *Server, req *Request) error) *Dispatcher {
return d.Handle(id, HandlerFunc(h))
}
// Result sets constant result for given TypeID.
// NB: it uses rpc_result to pack given encoder.
func (d *Dispatcher) Result(id uint32, msg bin.Encoder) *Dispatcher {
return d.HandleFunc(id, func(server *Server, req *Request) error {
return server.SendResult(req, msg)
})
}
// Vector sets constant Vector result for given TypeID.
// NB: it uses rpc_result to pack generic vector with given encoders.
func (d *Dispatcher) Vector(id uint32, msgs ...bin.Encoder) *Dispatcher {
return d.HandleFunc(id, func(server *Server, req *Request) error {
return server.SendVector(req, msgs...)
})
}
// Fallback sets fallback handler.
func (d *Dispatcher) Fallback(h Handler) *Dispatcher {
d.mux.Lock()
d.fallback = h
d.mux.Unlock()
return d
}
+2
View File
@@ -0,0 +1,2 @@
// Package tgtest provides test Telegram server for basic end-to-end tests.
package tgtest
+57
View File
@@ -0,0 +1,57 @@
package tgtest
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
type exchangeConn struct {
transport.Conn
}
func (e exchangeConn) Recv(ctx context.Context, b *bin.Buffer) error {
for {
if err := e.Conn.Recv(ctx, b); err != nil {
return err
}
var authKeyID [8]byte
if err := b.PeekN(authKeyID[:], len(authKeyID)); err != nil {
return errors.Wrap(err, "peek id")
}
if authKeyID != [8]byte{} {
// TODO(tdakkota): what if client send registered auth key during key exchange?
buf := bin.Buffer{}
buf.PutInt32(-codec.CodeAuthKeyNotFound)
if err := e.Conn.Send(ctx, &buf); err != nil {
return errors.Wrap(err, "send")
}
continue
}
return nil
}
}
// exchange starts MTProto key exchange.
func (s *Server) exchange(ctx context.Context, conn transport.Conn) (crypto.AuthKey, error) {
r, err := exchange.NewExchanger(conn, s.dcID).
WithClock(s.clock).
WithLogger(s.log.Named("exchange")).
WithRand(s.cipher.Rand()).
Server(s.key).Run(ctx)
if err != nil {
return crypto.AuthKey{}, err
}
return r.Key, nil
}
+61
View File
@@ -0,0 +1,61 @@
package tgtest
import (
"context"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/assert"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
func Test_exchangeConn_Recv(t *testing.T) {
a := assert.New(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
i := transport.Intermediate
c1, c2 := i.Pipe()
defer func() {
a.NoError(c1.Close())
a.NoError(c2.Close())
}()
e := exchangeConn{Conn: c1}
s := "abcdabcd"
a.Len(s, 8)
grp := tdsync.NewCancellableGroup(ctx)
grp.Go(func(ctx context.Context) error {
b := bin.Buffer{Buf: []byte(s)}
if err := c2.Send(ctx, &b); err != nil {
return err
}
b.Reset()
var protocolErr *codec.ProtocolErr
if err := c2.Recv(ctx, &b); err != nil && !errors.As(err, &protocolErr) {
return err
}
b.ResetN(8)
b.Put([]byte(s))
if err := c2.Send(ctx, &b); err != nil {
return err
}
return nil
})
var b bin.Buffer
a.NoError(e.Recv(ctx, &b))
b.Skip(8)
a.Equal(s, string(b.Buf))
a.NoError(grp.Wait())
}
+144
View File
@@ -0,0 +1,144 @@
package tgtest
import (
"context"
"encoding/binary"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
func (s *Server) rpcHandle(ctx context.Context, c transport.Conn, b *bin.Buffer) error {
m := &crypto.EncryptedMessage{}
if err := m.DecodeWithoutCopy(b); err != nil {
return errors.Wrap(err, "decode encrypted message")
}
key, ok := s.users.getSession(m.AuthKeyID)
if !ok {
return errors.New("invalid session")
}
msg, err := s.cipher.Decrypt(key, m)
if err != nil {
return errors.Wrap(err, "decrypt message")
}
session := Session{
ID: msg.SessionID,
AuthKey: key,
}
if conn := s.users.createConnection(msg.SessionID, c); !conn.sentCreated() {
s.log.Debug("Send handleSessionCreated event", zap.Inline(session))
salt := int64(binary.LittleEndian.Uint64(key.ID[:]))
if err := s.sendSessionCreated(ctx, session, salt); err != nil {
return err
}
}
// Buffer now contains plaintext message payload.
b.ResetTo(msg.Data())
if err := s.handle(&Request{
DC: s.dcID,
Session: session,
MsgID: msg.MessageID,
Buf: b,
RequestCtx: ctx,
}); err != nil {
return errors.Wrap(err, "handle")
}
return nil
}
func (s *Server) handle(req *Request) error {
in := req.Buf
id, err := in.PeekID()
if err != nil {
return errors.Wrap(err, "peek id")
}
s.log.Debug("Got request",
zap.Inline(req.Session),
zap.Int64("msg_id", req.MsgID),
zap.String("type", s.types.Get(id)),
)
// TODO(tdakkota): unpack all containers
switch id {
case mt.PingDelayDisconnectRequestTypeID:
pingReq := mt.PingDelayDisconnectRequest{}
if err := pingReq.Decode(in); err != nil {
return err
}
return s.SendPong(req, pingReq.PingID)
case mt.PingRequestTypeID:
pingReq := mt.PingRequest{}
if err := pingReq.Decode(in); err != nil {
return err
}
return s.SendPong(req, pingReq.PingID)
case mt.GetFutureSaltsRequestTypeID:
saltsRequest := mt.GetFutureSaltsRequest{}
if err := saltsRequest.Decode(in); err != nil {
return err
}
return s.SendEternalSalt(req)
case mt.RPCDropAnswerRequestTypeID:
drop := mt.RPCDropAnswerRequest{}
if err := drop.Decode(in); err != nil {
return err
}
return s.SendResult(req, &mt.RPCAnswerDroppedRunning{})
case proto.GZIPTypeID:
var content proto.GZIP
if err := content.Decode(in); err != nil {
return errors.Wrap(err, "gzip")
}
req.Buf = &bin.Buffer{Buf: content.Data}
case proto.MessageContainerTypeID:
var container proto.MessageContainer
if err := container.Decode(in); err != nil {
return errors.Wrap(err, "container")
}
var err error
for _, msg := range container.Messages {
err = multierr.Append(err, s.handle(&Request{
DC: req.DC,
Session: req.Session,
MsgID: msg.ID,
Buf: &bin.Buffer{Buf: msg.Body},
RequestCtx: req.RequestCtx,
}))
}
return err
}
if err := s.handler.OnMessage(s, req); err != nil {
var rpcErr *tgerr.Error
if errors.As(err, &rpcErr) {
return s.SendErr(req, rpcErr)
}
return err
}
return nil
}
+37
View File
@@ -0,0 +1,37 @@
package tgtest
import (
"context"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// Request represents MTProto RPC request structure.
type Request struct {
// DC ID from server structure.
// Used to make handler less stateful.
DC int
// Session is a user session.
Session Session
// MsgID is a message ID of RPC request.
MsgID int64
// Buf contains RPC request
Buf *bin.Buffer
// RequestCtx is a request context.
RequestCtx context.Context
}
// Handler is a RPC request handler.
type Handler interface {
OnMessage(server *Server, req *Request) error
}
var _ Handler = HandlerFunc(nil)
// HandlerFunc is functional adapter for Handler.OnMessage method.
type HandlerFunc func(server *Server, req *Request) error
// OnMessage implements Handler.
func (h HandlerFunc) OnMessage(server *Server, req *Request) error {
return h(server, req)
}
+92
View File
@@ -0,0 +1,92 @@
package tgtest
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
func (s *Server) read(ctx context.Context, conn transport.Conn, b *bin.Buffer) error {
b.Reset()
ctx, cancel := context.WithTimeout(ctx, s.readTimeout)
defer cancel()
if err := conn.Recv(ctx, b); err != nil {
return err
}
return nil
}
func (s *Server) sendProtoError(ctx context.Context, conn transport.Conn, e int32) error {
var buf bin.Buffer
buf.PutInt32(-e)
ctx, cancel := context.WithTimeout(ctx, s.writeTimeout)
defer cancel()
if err := conn.Send(ctx, &buf); err != nil {
return errors.Wrap(err, "send")
}
return nil
}
func (s *Server) serveConn(ctx context.Context, conn transport.Conn) error {
s.log.Debug("User connected")
defer func() {
_ = conn.Close()
s.log.Debug("User disconnected")
}()
b := new(bin.Buffer)
for {
if err := s.read(ctx, conn, b); err != nil {
return errors.Wrap(err, "read")
}
var authKeyID [8]byte
if err := b.PeekN(authKeyID[:], len(authKeyID)); err != nil {
return errors.Wrap(err, "peek id")
}
// TODO(tdakkota): dispatch by type ID instead?
if _, ok := s.users.getSession(authKeyID); ok {
if err := s.rpcHandle(ctx, conn, b); err != nil {
return errors.Wrap(err, "handle")
}
continue
}
// If authKeyID not found and is not zero, so send protocol error.
if authKeyID != [8]byte{} {
if err := s.sendProtoError(ctx, conn, codec.CodeAuthKeyNotFound); err != nil {
return errors.Wrap(err, "send AuthKeyNotFound")
}
continue
}
s.log.Debug("Starting key exchange")
c := newBufferedConn(conn)
c.Push(b)
key, err := s.exchange(ctx, exchangeConn{Conn: c})
if err != nil {
var exchangeErr *exchange.ServerExchangeError
if errors.As(err, &exchangeErr) {
code := exchangeErr.Code
if err := s.sendProtoError(ctx, c, code); err != nil {
return errors.Wrapf(err, "send proto error %v", code)
}
return nil
}
return errors.Wrap(err, "key exchange failed")
}
s.users.addSession(key)
}
}
+101
View File
@@ -0,0 +1,101 @@
package tgtest_test
import (
"context"
"crypto/rand"
"io"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/cluster"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services/config"
)
func TestSessionHandle(t *testing.T) {
test := func(storage session.Storage, t *testing.T) error {
log := zaptest.NewLogger(t)
defer func() { _ = log.Sync() }()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
g := tdsync.NewCancellableGroup(ctx)
c := cluster.NewCluster(cluster.Options{
Logger: log.Named("cluster"),
})
d := c.Dispatch(2, "server").Fallback(services.NotImplemented)
config.NewService(&tg.Config{}, &tg.CDNConfig{}).Register(d)
g.Go(c.Up)
g.Go(func(ctx context.Context) error {
select {
case <-c.Ready():
case <-ctx.Done():
return ctx.Err()
}
defer g.Cancel()
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{
PublicKeys: c.Keys(),
DC: 2,
DCList: c.List(),
Resolver: c.Resolver(),
NoUpdates: true,
Logger: log.Named("client"),
SessionStorage: storage,
RetryInterval: 100 * time.Millisecond,
})
return client.Run(ctx, func(ctx context.Context) error {
return nil
})
})
if err := g.Wait(); err != nil && !errors.Is(err, context.Canceled) {
return errors.Wrap(err, "wait")
}
return nil
}
t.Run("Empty", func(t *testing.T) {
a := require.New(t)
storage := session.StorageMemory{}
a.NoError(test(&storage, t))
_, err := storage.Bytes(nil)
a.NoError(err, "Must create new session")
})
t.Run("Unknown", func(t *testing.T) {
a := require.New(t)
loader := session.Loader{Storage: &session.StorageMemory{}}
ctx := context.Background()
key := crypto.Key{}
_, err := io.ReadFull(rand.Reader, key[:])
a.NoError(err)
authKey := key.WithID()
was := &session.Data{
DC: 2,
AuthKey: authKey.Value[:],
AuthKeyID: authKey.ID[:],
}
a.NoError(loader.Save(context.Background(), was))
a.NoError(test(loader.Storage, t))
data, err := loader.Load(ctx)
a.NoError(err)
a.NotEqual(was, data, "Must regenerate session")
})
}
+72
View File
@@ -0,0 +1,72 @@
package tgtest
import (
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// UnpackInvoke is a simple Handler middleware to unpack some Invoke*-like requests.
// Including:
//
// tg.InvokeWithLayerRequest
// tg.InitConnectionRequest
// tg.InvokeWithoutUpdatesRequest
func UnpackInvoke(next Handler) Handler {
return HandlerFunc(func(srv *Server, req *Request) error {
id, err := req.Buf.PeekID()
if err != nil {
return err
}
// TODO(tdakkota): handle more Invoke* requests.
var (
obj = peekIDObject{}
r bin.Decoder
)
for {
switch id {
case tg.InvokeWithLayerRequestTypeID:
r = &tg.InvokeWithLayerRequest{
Query: &obj,
}
// TODO(tdakkota): pass Layer to session.
case tg.InitConnectionRequestTypeID:
r = &tg.InitConnectionRequest{
Query: &obj,
}
// TODO(tdakkota): pass DeviceInfo to session.
case tg.InvokeWithoutUpdatesRequestTypeID:
r = &tg.InvokeWithoutUpdatesRequest{
Query: &obj,
}
// TODO(tdakkota): pass NoUpdates flag to session.
default:
return next.OnMessage(srv, req)
}
if err := r.Decode(req.Buf); err != nil {
return err
}
id = obj.TypeID
}
})
}
type peekIDObject struct {
TypeID uint32
}
func (t *peekIDObject) Decode(b *bin.Buffer) error {
id, err := b.PeekID()
if err != nil {
return errors.Wrap(err, "peek id")
}
t.TypeID = id
return nil
}
func (t *peekIDObject) Encode(*bin.Buffer) error {
return errors.New("peekIDObject must not be encoded")
}
+191
View File
@@ -0,0 +1,191 @@
package tgtest
import (
"context"
"math"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
const (
// MessageServerResponse is a message type of RPC calls result.
MessageServerResponse = proto.MessageServerResponse
// MessageFromServer is a message type of server-side updates.
MessageFromServer = proto.MessageFromServer
)
// Send sends given message to user session k.
// Parameter t denotes MTProto message type. It should be MessageServerResponse or MessageFromServer.
func (s *Server) Send(ctx context.Context, k Session, t proto.MessageType, message bin.Encoder) error {
conn, ok := s.users.getConnection(k.ID)
if !ok {
return errors.Errorf("send %T: invalid key: connection %s not found", message, k.AuthKey.String())
}
var b bin.Buffer
if err := message.Encode(&b); err != nil {
return errors.Wrap(err, "encode")
}
data := crypto.EncryptedMessageData{
SessionID: k.ID,
MessageDataLen: int32(b.Len()),
MessageDataWithPadding: b.Copy(),
MessageID: s.msgID.New(t),
}
err := s.cipher.Encrypt(k.AuthKey, data, &b)
if err != nil {
return errors.Wrap(err, "encrypt")
}
ctx, cancel := context.WithTimeout(ctx, s.writeTimeout)
defer cancel()
if err := conn.Send(ctx, &b); err != nil {
return errors.Wrap(err, "send")
}
return nil
}
func (s *Server) sendReq(req *Request, t proto.MessageType, encoder bin.Encoder) error {
return s.Send(req.RequestCtx, req.Session, t, encoder)
}
// SendResult sends RPC answer using msg as result.
func (s *Server) SendResult(req *Request, msg bin.Encoder) error {
var buf bin.Buffer
if err := msg.Encode(&buf); err != nil {
return errors.Wrap(err, "encode result")
}
if err := s.sendReq(req, proto.MessageServerResponse, &proto.Result{
RequestMessageID: req.MsgID,
Result: buf.Raw(),
}); err != nil {
return errors.Wrapf(err, "send result [%T]", msg)
}
return nil
}
// SendGZIP sends RPC answer and packs it into proto.GZIP.
func (s *Server) SendGZIP(req *Request, msg bin.Encoder) error {
var buf bin.Buffer
if err := msg.Encode(&buf); err != nil {
return errors.Wrap(err, "encode gzip data")
}
return s.SendResult(req, proto.GZIP{Data: buf.Buf})
}
// SendErr sends RPC answer using given error as result.
func (s *Server) SendErr(req *Request, e *tgerr.Error) error {
return s.SendResult(req, &mt.RPCError{
ErrorCode: e.Code,
ErrorMessage: e.Message,
})
}
// SendBool sends RPC answer using given bool as result.
// Usually used in methods without explicit response.
func (s *Server) SendBool(req *Request, r bool) error {
var msg tg.BoolClass = &tg.BoolTrue{}
if !r {
msg = &tg.BoolFalse{}
}
return s.SendResult(req, msg)
}
// SendVector sends RPC answer using given vector as result.
func (s *Server) SendVector(req *Request, msgs ...bin.Encoder) error {
return s.SendResult(req, &genericVector{Elems: msgs})
}
// sendSessionCreated sends mt.NewSessionCreated `new_session_created` notification.
func (s *Server) sendSessionCreated(ctx context.Context, k Session, serverSalt int64) error {
if err := s.Send(ctx, k, proto.MessageFromServer, &mt.NewSessionCreated{
FirstMsgID: s.msgID.New(proto.MessageFromClient),
ServerSalt: serverSalt,
}); err != nil {
return errors.Wrap(err, "send sessionCreated")
}
return nil
}
// SendPong sends response for mt.PingRequest request.
func (s *Server) SendPong(req *Request, pingID int64) error {
if err := s.sendReq(req, proto.MessageServerResponse, &mt.Pong{
MsgID: req.MsgID,
PingID: pingID,
}); err != nil {
return errors.Wrap(err, "send pong")
}
return nil
}
// SendEternalSalt sends response for mt.GetFutureSaltsRequest.
// It sends an `eternal` salt, which valid until maximum possible date.
func (s *Server) SendEternalSalt(req *Request) error {
return s.SendFutureSalts(req, mt.FutureSalt{
ValidSince: 1,
ValidUntil: math.MaxInt32,
Salt: 10,
})
}
// SendFutureSalts sends response for mt.GetFutureSaltsRequest.
func (s *Server) SendFutureSalts(req *Request, salts ...mt.FutureSalt) error {
if err := s.Send(req.RequestCtx, req.Session, proto.MessageServerResponse, &mt.FutureSalts{
ReqMsgID: req.MsgID,
Now: int(s.clock.Now().Unix()),
Salts: salts,
}); err != nil {
return errors.Wrap(err, "send future salts")
}
return nil
}
// SendUpdates sends given updates to user session k.
func (s *Server) SendUpdates(ctx context.Context, k Session, updates ...tg.UpdateClass) error {
if len(updates) == 0 {
return nil
}
if err := s.Send(ctx, k, proto.MessageFromServer, &tg.Updates{
Updates: updates,
Date: int(s.clock.Now().Unix()),
}); err != nil {
return errors.Wrap(err, "send updates")
}
return nil
}
// SendAck sends acknowledgment for received message.
func (s *Server) SendAck(ctx context.Context, k Session, ids ...int64) error {
if err := s.Send(ctx, k, proto.MessageFromServer, &mt.MsgsAck{MsgIDs: ids}); err != nil {
return errors.Wrap(err, "send ack")
}
return nil
}
// ForceDisconnect forcibly disconnect user from server.
// It deletes MTProto session (session_id), but not auth key.
func (s *Server) ForceDisconnect(k Session) {
s.users.deleteConnection(k.ID)
}
+135
View File
@@ -0,0 +1,135 @@
package tgtest
import (
"context"
"crypto/rsa"
"io"
"net"
"time"
"github.com/go-faster/errors"
"go.uber.org/zap"
"nhooyr.io/websocket"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
// Server is a MTProto server structure.
type Server struct {
// DC ID of this server.
dcID int
// Key pair of this server.
key exchange.PrivateKey // immutable
// Codec constructor. May be nil.
codec func() transport.Codec // immutable,nilable
// Server-side message cipher.
cipher crypto.Cipher // immutable
// Clock to use in key exchange and message ID generation.
clock clock.Clock // immutable
// MessageID generator
msgID mtproto.MessageIDSource // immutable
readTimeout time.Duration
writeTimeout time.Duration
// RPC handler.
handler Handler // immutable
// users stores session info.
users *users
// type map for logging.
types *tmap.Map // immutable
log *zap.Logger // immutable
}
// NewPrivateKey creates new private key from RSA private key.
func NewPrivateKey(k *rsa.PrivateKey) exchange.PrivateKey {
return exchange.PrivateKey{
RSA: k,
}
}
// NewServer creates new Server.
func NewServer(key exchange.PrivateKey, handler Handler, opts ServerOptions) *Server {
opts.setDefaults()
s := &Server{
dcID: opts.DC,
key: key,
codec: opts.Codec,
cipher: crypto.NewServerCipher(opts.Random),
clock: opts.Clock,
msgID: opts.MessageID,
readTimeout: opts.ReadTimeout,
writeTimeout: opts.WriteTimeout,
handler: handler,
users: newUsers(),
types: opts.Types,
log: opts.Logger,
}
return s
}
// Key returns public key of this server.
func (s *Server) Key() exchange.PublicKey {
return s.key.Public()
}
// Serve runs server loop using given listener.
func (s *Server) Serve(ctx context.Context, l transport.Listener) error {
return s.serve(ctx, l)
}
func (s *Server) serve(ctx context.Context, l transport.Listener) error {
s.log.Info("Serving")
defer func() {
s.log.Info("Stopping")
}()
grp := tdsync.NewCancellableGroup(ctx)
grp.Go(func(context.Context) error {
for {
conn, err := l.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
return errors.Wrap(err, "accept")
}
grp.Go(func(ctx context.Context) error {
if err := s.serveConn(ctx, conn); err != nil {
// Client disconnected.
var syscallErr *net.OpError
switch {
case errors.Is(err, io.EOF):
return nil
case errors.As(err, &syscallErr) &&
(syscallErr.Op == "write" || syscallErr.Op == "read"):
return nil
}
// TODO(tdakkota): emulate errors too?
if code := websocket.CloseStatus(err); code >= 0 {
return nil
}
s.log.Info("Serving handler error", zap.Error(err))
}
return nil
})
}
})
grp.Go(func(ctx context.Context) error {
<-ctx.Done()
return l.Close()
})
return grp.Wait()
}
+75
View File
@@ -0,0 +1,75 @@
package tgtest
import (
"io"
"time"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
// ServerOptions of Server.
type ServerOptions struct {
// DC ID of this server. Default to 2.
DC int
// Random is random source. Defaults to rand.Reader.
Random io.Reader
// Logger is instance of zap.Logger. No logs by default.
Logger *zap.Logger
// Codec constructor.
// Defaults to nil (underlying transport server detects protocol automatically).
Codec func() transport.Codec
// Clock to use. Defaults to clock.System.
Clock clock.Clock
// MessageID generator. Creates a new proto.MessageIDGen by default.
// Clock will be used for creation.
MessageID mtproto.MessageIDSource
// Types map, used in verbose logging of incoming message.
Types *tmap.Map
// ReadTimeout is a connection read timeout.
ReadTimeout time.Duration
// ReadTimeout is a connection write timeout.
WriteTimeout time.Duration
}
func (opt *ServerOptions) setDefaults() {
if opt.DC == 0 {
opt.DC = 2
}
if opt.Random == nil {
opt.Random = crypto.DefaultRand()
}
if opt.Logger == nil {
opt.Logger = zap.NewNop()
}
// Ignore opt.Codec, will be handled by transport.NewCustomServer.
if opt.Clock == nil {
opt.Clock = clock.System
}
if opt.MessageID == nil {
opt.MessageID = proto.NewMessageIDGen(opt.Clock.Now)
}
if opt.Types == nil {
opt.Types = tmap.New(
tg.TypesMap(),
mt.TypesMap(),
proto.TypesMap(),
)
}
if opt.ReadTimeout == 0 {
opt.ReadTimeout = 30 * time.Second
}
if opt.WriteTimeout == 0 {
opt.WriteTimeout = 30 * time.Second
}
}
+72
View File
@@ -0,0 +1,72 @@
// Package config contains config service implementation for tgtest server.
package config
import (
"context"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services"
)
// Service is a Telegram config service.
type Service struct {
cfg *tg.Config
cdnCfg *tg.CDNConfig
}
// NewService creates new Service.
func NewService(cfg *tg.Config, cdnCfg *tg.CDNConfig) *Service {
return &Service{cfg: cfg, cdnCfg: cdnCfg}
}
func (c *Service) HelpGetCDNConfig(ctx context.Context, req *tg.HelpGetCDNConfigRequest) (*tg.CDNConfig, error) {
cfg := c.cdnCfg
return cfg, nil
}
func (c *Service) HelpGetConfig(ctx context.Context, dc int, req *tg.HelpGetConfigRequest) (*tg.Config, error) {
cfg := *c.cfg
cfg.ThisDC = dc
return &cfg, nil
}
// OnMessage implements tgtest.Handler.
func (c *Service) OnMessage(server *tgtest.Server, req *tgtest.Request) error {
id, err := req.Buf.PeekID()
if err != nil {
return err
}
var (
decode bin.Decoder
result bin.Encoder
)
switch id {
case tg.HelpGetCDNConfigRequestTypeID:
cfg := c.cdnCfg
decode = &tg.HelpGetCDNConfigRequest{}
result = cfg
case tg.HelpGetConfigRequestTypeID:
cfg := *c.cfg
cfg.ThisDC = req.DC
decode = &tg.HelpGetConfigRequest{}
result = &cfg
default:
return services.ErrMethodNotImplemented
}
if err := decode.Decode(req.Buf); err != nil {
return err
}
return server.SendResult(req, result)
}
// Register registers service handlers.
func (c *Service) Register(dispatcher *tgtest.Dispatcher) {
dispatcher.HandleFunc(tg.HelpGetCDNConfigRequestTypeID, c.OnMessage)
dispatcher.HandleFunc(tg.HelpGetConfigRequestTypeID, c.OnMessage)
}
+2
View File
@@ -0,0 +1,2 @@
// Package services contains some Telegram services implemented for testing.
package services
+16
View File
@@ -0,0 +1,16 @@
package services
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
)
var (
// ErrMethodNotImplemented denotes that method is not implemented.
ErrMethodNotImplemented error = tgerr.New(400, "INPUT_METHOD_INVALID")
// NotImplemented is a simple handler which returns ErrMethodNotImplemented.
NotImplemented tgtest.HandlerFunc = func(server *tgtest.Server, req *tgtest.Request) error {
return ErrMethodNotImplemented
}
)
+24
View File
@@ -0,0 +1,24 @@
package file
type Config struct {
// Storage to store files.
// InMemory will be used.
Storage Storage
// HashPartSize is a size of part to use in tg.FileHash.
HashPartSize int
// HashRangeSize is size of range to return in upload.getFileHashes.
HashRangeSize int
}
func (c *Config) setDefaults() {
if c.Storage == nil {
c.Storage = NewInMemory()
}
// Telegram usually uses this values.
if c.HashPartSize == 0 {
c.HashPartSize = 131072
}
if c.HashRangeSize == 0 {
c.HashRangeSize = 10
}
}
+136
View File
@@ -0,0 +1,136 @@
package file
import (
"context"
"fmt"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
func getLocation(loc tg.InputFileLocationClass) (string, error) {
v, ok := loc.(interface {
GetLocalID() int
GetVolumeID() int64
})
if !ok {
return "", tgerr.New(400, tg.ErrFileIDInvalid)
}
return fmt.Sprintf("%d_%d", v.GetLocalID(), v.GetVolumeID()), nil
}
func (m *Service) openLocation(loc tg.InputFileLocationClass) (File, error) {
name, err := getLocation(loc)
if err != nil {
return nil, err
}
f, err := m.storage.Open(name)
if err != nil {
return nil, tgerr.New(400, tg.ErrFileIDInvalid)
}
return f, nil
}
func (m *Service) getPart(loc tg.InputFileLocationClass, offset int64, limit int) ([]byte, error) {
f, err := m.openLocation(loc)
if err != nil {
return nil, err
}
r := make([]byte, limit)
n, err := f.ReadAt(r, offset)
if err != nil {
return nil, errors.Wrap(err, "read from storage")
}
return r[:n], nil
}
func (m *Service) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) {
data, err := m.getPart(request.Location, request.Offset, request.Limit)
if err != nil {
return nil, err
}
return &tg.UploadFile{
Type: &tg.StorageFilePartial{},
Mtime: 0,
Bytes: data,
}, nil
}
func countHashes(data []byte, offset int64, partSize int) []tg.FileHash {
actions := data
batchSize := partSize
batches := make([][]byte, 0, (len(actions)+batchSize-1)/batchSize)
for batchSize < len(actions) {
actions, batches = actions[batchSize:], append(batches, actions[0:batchSize:batchSize])
}
batches = append(batches, actions)
currentRange := make([]tg.FileHash, 0, 10)
for _, batch := range batches {
currentRange = append(currentRange, tg.FileHash{
Offset: offset,
Limit: partSize,
Hash: crypto.SHA256(batch),
})
offset += int64(len(batch))
}
return currentRange
}
func divAndCeil(a, b int) int {
r := a / b
if a%b != 0 {
r++
}
return r
}
// computeBatch computes hash range number for given offset.
func computeBatch(offset int64, rangeSize, partSize int) int {
// Compute number of parts in partSize from offset.
parts := divAndCeil(int(offset+1), partSize)
// Compute number of hash ranges in rangeSize.
batches := divAndCeil(parts, rangeSize)
return batches
}
func (m *Service) UploadGetFileHashes(
ctx context.Context,
request *tg.UploadGetFileHashesRequest,
) ([]tg.FileHash, error) {
f, err := m.openLocation(request.Location)
if err != nil {
return nil, err
}
if request.Offset >= int64(f.Size()) {
return nil, nil
}
partSize := m.hashPartSize
rangeSize := m.hashRangeSize
batch := computeBatch(request.Offset, rangeSize, partSize)
low := (batch - 1) * rangeSize * partSize
high := batch * rangeSize * partSize
r := make([]byte, high-low)
n, err := f.ReadAt(r, int64(low))
if err != nil {
return nil, err
}
r = r[:n]
return countHashes(r, int64(low), partSize), nil
}
+97
View File
@@ -0,0 +1,97 @@
// Package file contains file service implementation for tgtest server.
package file
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services"
)
// Service is a Telegram file service.
type Service struct {
storage Storage
// Size of part to use in tg.FileHash.
hashPartSize int
// Size of range to return in upload.getFileHashes.
hashRangeSize int
}
// NewService creates new file Service.
func NewService(cfg Config) *Service {
cfg.setDefaults()
return &Service{
storage: cfg.Storage,
hashPartSize: cfg.HashPartSize,
hashRangeSize: cfg.HashRangeSize,
}
}
// OnMessage implements tgtest.Handler.
func (m *Service) OnMessage(server *tgtest.Server, req *tgtest.Request) error {
id, err := req.Buf.PeekID()
if err != nil {
return err
}
switch id {
case tg.UploadGetFileRequestTypeID:
fileReq := tg.UploadGetFileRequest{}
if err := fileReq.Decode(req.Buf); err != nil {
return err
}
r, err := m.UploadGetFile(req.RequestCtx, &fileReq)
if err != nil {
return err
}
return server.SendResult(req, r)
case tg.UploadGetFileHashesRequestTypeID:
fileReq := tg.UploadGetFileHashesRequest{}
if err := fileReq.Decode(req.Buf); err != nil {
return err
}
r, err := m.UploadGetFileHashes(req.RequestCtx, &fileReq)
if err != nil {
return err
}
return server.SendResult(req, &tg.FileHashVector{Elems: r})
case tg.UploadSaveFilePartRequestTypeID:
fileReq := tg.UploadSaveFilePartRequest{}
if err := fileReq.Decode(req.Buf); err != nil {
return err
}
r, err := m.UploadSaveFilePart(req.RequestCtx, &fileReq)
if err != nil {
return err
}
return server.SendBool(req, r)
case tg.UploadSaveBigFilePartRequestTypeID:
fileReq := tg.UploadSaveBigFilePartRequest{}
if err := fileReq.Decode(req.Buf); err != nil {
return err
}
r, err := m.UploadSaveBigFilePart(req.RequestCtx, &fileReq)
if err != nil {
return err
}
return server.SendBool(req, r)
default:
return services.ErrMethodNotImplemented
}
}
// Register registers service handlers.
func (m *Service) Register(dispatcher *tgtest.Dispatcher) {
dispatcher.HandleFunc(tg.UploadGetFileRequestTypeID, m.OnMessage)
dispatcher.HandleFunc(tg.UploadGetFileHashesRequestTypeID, m.OnMessage)
dispatcher.HandleFunc(tg.UploadSaveFilePartRequestTypeID, m.OnMessage)
dispatcher.HandleFunc(tg.UploadSaveBigFilePartRequestTypeID, m.OnMessage)
}
+72
View File
@@ -0,0 +1,72 @@
package file
import (
"io"
"sync"
"sync/atomic"
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
)
// File represents Telegram file.
type File interface {
io.ReaderAt
io.WriterAt
io.Closer
PartSize() int
SetPartSize(v int)
Size() int
}
// Storage is an abstraction for Telegram file storage.
type Storage interface {
Open(name string) (File, error)
}
type memFile struct {
syncio.BufWriterAt
partSize int32
_ [4]byte
}
func (m *memFile) Size() int {
return m.Len()
}
func (m *memFile) Close() error {
return nil
}
func (m *memFile) PartSize() int {
return int(atomic.LoadInt32(&m.partSize))
}
func (m *memFile) SetPartSize(v int) {
atomic.StoreInt32(&m.partSize, int32(v))
}
// InMemory is an inmemory implementation of file storage.
type InMemory struct {
files map[string]*memFile
filesMux sync.Mutex
}
// NewInMemory creates new InMemory.
func NewInMemory() *InMemory {
return &InMemory{
files: map[string]*memFile{},
}
}
// Open implement Storage.
func (i *InMemory) Open(name string) (File, error) {
i.filesMux.Lock()
defer i.filesMux.Unlock()
file, ok := i.files[name]
if !ok {
file = &memFile{}
i.files[name] = file
}
return file, nil
}
+115
View File
@@ -0,0 +1,115 @@
package file
import (
"context"
"fmt"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.mau.fi/mautrix-telegram/pkg/gotd/constant"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
// https://core.telegram.org/api/files#uploading-files
const (
// Each part should have a sequence number, file_part, with a value ranging from 0 to 3,999.
uploadPartsLimit = constant.UploadMaxParts
// `part_size % 1024 = 0` (divisible by 1KB)
uploadPaddingPartSize = constant.UploadPadding
// `524288 % part_size = 0` (512KB must be evenly divisible by part_size)
uploadMaximumPartSize = constant.UploadMaxPartSize
)
type upload interface {
// GetFileID returns random file identifier created by the client.
GetFileID() int64
// GetFilePart returns numerical order of a part.
GetFilePart() int
// GetBytes returns binary data, content of a part.
GetBytes() []byte
}
func validatePartSize(got, stored int) *tgerr.Error {
switch {
case got == 0:
return tgerr.New(400, tg.ErrFilePartEmpty)
case got > uploadMaximumPartSize:
return tgerr.New(400, tg.ErrFilePartTooBig)
}
if stored == 0 {
return nil
}
switch {
case got != stored:
return tgerr.New(400, tg.ErrFilePartSizeChanged)
case uploadMaximumPartSize%got != 0,
got%uploadPaddingPartSize != 0:
return tgerr.New(400, tg.ErrFilePartSizeInvalid)
default:
return nil
}
}
func (m *Service) write(ctx context.Context, request upload) (err error) {
// TODO(tdakkota): Better way to handle user id. For now we haven't auth service to pair
// user ID and authkey
id, ok := ctx.Value("user_id").(int)
if !ok {
id = 10
}
file, err := m.storage.Open(fmt.Sprintf("%d_%d", id, request.GetFileID()))
if err != nil {
return errors.Wrap(err, "open file")
}
defer func() {
multierr.AppendInto(&err, file.Close())
}()
part := request.GetFilePart()
if part < 0 || part > uploadPartsLimit {
return tgerr.New(400, tg.ErrFilePartInvalid)
}
data := request.GetBytes()
partSize := file.PartSize()
if err := validatePartSize(len(data), partSize); err != nil {
return err
}
if partSize == 0 {
partSize = len(data)
file.SetPartSize(partSize)
}
offset := int64(partSize * part)
if _, err := file.WriteAt(data, offset); err != nil {
return errors.Errorf("write at %d-%d", offset, offset+int64(len(data)))
}
return nil
}
func (m *Service) UploadSaveFilePart(ctx context.Context, request *tg.UploadSaveFilePartRequest) (bool, error) {
if err := m.write(ctx, request); err != nil {
return false, err
}
return true, nil
}
func (m *Service) UploadSaveBigFilePart(ctx context.Context, request *tg.UploadSaveBigFilePartRequest) (bool, error) {
part := request.FileTotalParts
if part < 0 || part > uploadPartsLimit {
return false, tgerr.New(400, tg.ErrFilePartsInvalid)
}
if err := m.write(ctx, request); err != nil {
return false, err
}
return true, nil
}
+24
View File
@@ -0,0 +1,24 @@
package tgtest
import (
"encoding/hex"
"go.uber.org/zap/zapcore"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
)
// Session represents connection session.
type Session struct {
// ID is a Session ID.
ID int64
// AuthKey is an attached key.
AuthKey crypto.AuthKey
}
// MarshalLogObject implements zap.ObjectMarshaler.
func (s Session) MarshalLogObject(encoder zapcore.ObjectEncoder) error {
encoder.AddInt64("session_id", s.ID)
encoder.AddString("key_id", hex.EncodeToString(s.AuthKey.ID[:]))
return nil
}
+88
View File
@@ -0,0 +1,88 @@
package tgtest
import (
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type testTransportHandler struct {
t testing.TB
logger *zap.Logger
// For ACK testing proposes.
// We send ack only after second request
counter int
counterMx sync.Mutex
message string // immutable
}
// TestTransport is a handler for testing MTProto transport.
func TestTransport(t testing.TB, logger *zap.Logger, message string) Handler {
return &testTransportHandler{
t: t,
logger: logger,
message: message,
}
}
func (h *testTransportHandler) OnMessage(server *Server, req *Request) error {
id, err := req.Buf.PeekID()
if err != nil {
return err
}
h.logger.Info("New message", zap.String("id", fmt.Sprintf("%#x", id)))
switch id {
case tg.UsersGetUsersRequestTypeID:
getUsers := tg.UsersGetUsersRequest{}
if err := getUsers.Decode(req.Buf); err != nil {
return err
}
h.logger.Info("New client connected, invoke received")
if err := server.SendVector(req, &tg.User{
ID: 10,
AccessHash: 10,
Username: "rustcocks",
}); err != nil {
return err
}
h.logger.Info("Sending message", zap.String("message", h.message))
return server.SendUpdates(req.RequestCtx, req.Session, &tg.UpdateNewMessage{
Message: &tg.Message{
ID: 1,
PeerID: &tg.PeerUser{UserID: 1},
Message: h.message,
},
})
case tg.MessagesSendMessageRequestTypeID:
m := &tg.MessagesSendMessageRequest{}
if err := m.Decode(req.Buf); err != nil {
h.t.Fail()
return err
}
assert.Equal(h.t, "какими деньгами?", m.Message)
h.counterMx.Lock()
h.counter++
if h.counter < 2 {
h.counterMx.Unlock()
return nil
}
h.counterMx.Unlock()
return server.SendResult(req, &tg.Updates{})
}
return nil
}
+26
View File
@@ -0,0 +1,26 @@
package tgtest
import (
"fmt"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// genericVector is a simple helper to encode a vector of TL objects.
type genericVector struct {
Elems []bin.Encoder
}
// Encode implements bin.Encoder.
func (vec *genericVector) Encode(b *bin.Buffer) error {
b.PutVectorHeader(len(vec.Elems))
for idx, v := range vec.Elems {
if v == nil {
return fmt.Errorf("unable to encode Vector<%T>: field Elems element with index %d is nil", v, idx)
}
if err := v.Encode(b); err != nil {
return fmt.Errorf("unable to encode Vector<%T>: field Elems element with index %d: %w", v, idx, err)
}
}
return nil
}
+56
View File
@@ -0,0 +1,56 @@
package tgtest
import (
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type badEncoder struct{}
func (e badEncoder) Encode(b *bin.Buffer) error {
return testutil.TestError()
}
func TestGenericVector_Encode(t *testing.T) {
tests := []struct {
name string
data []bin.Encoder
expect bin.Decoder
wantErr bool
}{
{"Empty", nil, nil, false},
{"Nil", []bin.Encoder{nil}, nil, true},
{"BadObject", []bin.Encoder{badEncoder{}}, nil, true},
{
"Plain",
[]bin.Encoder{&tg.BotCommand{
Command: "hello",
Description: "world",
}},
&tg.BotCommandVector{},
false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
a = require.New(t)
v = genericVector{Elems: test.data}
buf bin.Buffer
)
err := v.Encode(&buf)
if test.wantErr {
a.Error(err)
} else if test.expect != nil {
a.NoError(test.expect.Decode(&buf))
}
})
}
}