move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
-----BEGIN RSA PUBLIC KEY-----
|
||||
MIIBCgKCAQEAyMEdY1aR+sCR3ZSJrtztKTKqigvO/vBfqACJLZtS7QMgCGXJ6XIR
|
||||
yy7mx66W0/sOFa7/1mAZtEoIokDP3ShoqF4fVNb6XeqgQfaUHd8wJpDWHcR2OFwv
|
||||
plUUI1PLTktZ9uW2WE23b+ixNwJjJGwBDJPQEQFBE+vfmH0JP503wr5INS1poWg/
|
||||
j25sIWeYPHYeOrFp/eXaqhISP6G+q2IeTaWTXpwZj4LzXq5YOpk4bYEQ6mvRq7D1
|
||||
aHWfYmlEGepfaYR8Q0YqvvhYtMte3ITnuSJs171+GDqpdKcSwHnd6FudwGO4pcCO
|
||||
j4WcDuXc2CTHgH8gFTNhp/Y8/SpDOhvn9QIDAQAB
|
||||
-----END RSA PUBLIC KEY-----
|
||||
-----BEGIN RSA PUBLIC KEY-----
|
||||
MIIBCgKCAQEA6LszBcC1LGzyr992NzE0ieY+BSaOW622Aa9Bd4ZHLl+TuFQ4lo4g
|
||||
5nKaMBwK/BIb9xUfg0Q29/2mgIR6Zr9krM7HjuIcCzFvDtr+L0GQjae9H0pRB2OO
|
||||
62cECs5HKhT5DZ98K33vmWiLowc621dQuwKWSQKjWf50XYFw42h21P2KXUGyp2y/
|
||||
+aEyZ+uVgLLQbRA1dEjSDZ2iGRy12Mk5gpYc397aYp438fsJoHIgJ2lgMv5h7WY9
|
||||
t6N/byY9Nw9p21Og3AoXSL2q/2IJ1WRUhebgAdGVMlV1fkuOQoEzR7EdpqtQD9Cs
|
||||
5+bfo3Nhmcyvk5ftB0WkJ9z6bNZ7yxrP8wIDAQAB
|
||||
-----END RSA PUBLIC KEY-----
|
||||
@@ -0,0 +1,46 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
)
|
||||
|
||||
func (c *Conn) ackLoop(ctx context.Context) error {
|
||||
log := c.log.Named("ack")
|
||||
|
||||
var buf []int64
|
||||
send := func() {
|
||||
defer func() { buf = buf[:0] }()
|
||||
|
||||
if err := c.writeServiceMessage(ctx, &mt.MsgsAck{MsgIDs: buf}); err != nil {
|
||||
c.log.Error("Failed to ACK", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("Ack", zap.Int64s("msg_ids", buf))
|
||||
}
|
||||
|
||||
ticker := c.clock.Ticker(c.ackInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "acl")
|
||||
case <-ticker.C():
|
||||
if len(buf) > 0 {
|
||||
send()
|
||||
}
|
||||
case msgID := <-c.ackSendChan:
|
||||
buf = append(buf, msgID)
|
||||
if len(buf) >= c.ackBatchSize {
|
||||
send()
|
||||
ticker.Reset(c.ackInterval)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto/salts"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
// Handler will be called on received message from Telegram.
|
||||
type Handler interface {
|
||||
OnMessage(b *bin.Buffer) error
|
||||
OnSession(session Session) error
|
||||
}
|
||||
|
||||
// MessageIDSource is message id generator.
|
||||
type MessageIDSource interface {
|
||||
New(t proto.MessageType) int64
|
||||
}
|
||||
|
||||
// MessageBuf is message id buffer.
|
||||
type MessageBuf interface {
|
||||
Consume(id int64) bool
|
||||
}
|
||||
|
||||
// Cipher handles message encryption and decryption.
|
||||
type Cipher interface {
|
||||
DecryptFromBuffer(k crypto.AuthKey, buf *bin.Buffer) (*crypto.EncryptedMessageData, error)
|
||||
Encrypt(key crypto.AuthKey, data crypto.EncryptedMessageData, b *bin.Buffer) error
|
||||
}
|
||||
|
||||
// Dialer is an abstraction for MTProto transport connection creator.
|
||||
type Dialer func(ctx context.Context) (transport.Conn, error)
|
||||
|
||||
// Conn represents a MTProto client to Telegram.
|
||||
type Conn struct {
|
||||
dcID int
|
||||
|
||||
dialer Dialer
|
||||
conn transport.Conn
|
||||
handler Handler
|
||||
rpc *rpc.Engine
|
||||
rsaPublicKeys []exchange.PublicKey
|
||||
types *tmap.Map
|
||||
|
||||
// Wrappers for external world, like current time, logs or PRNG.
|
||||
// Should be immutable.
|
||||
clock clock.Clock
|
||||
rand io.Reader
|
||||
cipher Cipher
|
||||
log *zap.Logger
|
||||
messageID MessageIDSource
|
||||
messageIDBuf MessageBuf // replay attack protection
|
||||
|
||||
// use session() to access authKey, salt or sessionID.
|
||||
sessionMux sync.RWMutex
|
||||
authKey crypto.AuthKey
|
||||
salt int64
|
||||
sessionID int64
|
||||
|
||||
// server salts fetched by getSalts.
|
||||
salts salts.Salts
|
||||
|
||||
// sentContentMessages is count of created content messages, used to
|
||||
// compute sequence number within session.
|
||||
sentContentMessages int32
|
||||
reqMux sync.Mutex
|
||||
|
||||
// ackSendChan is queue for outgoing message id's that require waiting for
|
||||
// ack from server.
|
||||
ackSendChan chan int64
|
||||
ackBatchSize int
|
||||
ackInterval time.Duration
|
||||
|
||||
// callbacks for ping results.
|
||||
// Key is ping id.
|
||||
ping map[int64]chan struct{}
|
||||
pingMux sync.Mutex
|
||||
// pingTimeout sets ping_delay_disconnect delay.
|
||||
pingTimeout time.Duration
|
||||
// pingInterval is duration between ping_delay_disconnect request.
|
||||
pingInterval time.Duration
|
||||
pingCallback func()
|
||||
|
||||
// gotSession is a signal channel for wait for handleSessionCreated message.
|
||||
gotSession *tdsync.Ready
|
||||
|
||||
// exchangeLock locks write calls during key exchange.
|
||||
exchangeLock sync.RWMutex
|
||||
|
||||
// compressThreshold is a threshold in bytes to determine that message
|
||||
// is large enough to be compressed using gzip.
|
||||
compressThreshold int
|
||||
dialTimeout time.Duration
|
||||
exchangeTimeout time.Duration
|
||||
saltFetchInterval time.Duration
|
||||
getTimeout func(req uint32) time.Duration
|
||||
// Ensure Run once.
|
||||
ran atomic.Bool
|
||||
}
|
||||
|
||||
// New creates new unstarted connection.
|
||||
func New(dialer Dialer, opt Options) *Conn {
|
||||
// Set default values, if user does not set.
|
||||
opt.setDefaults()
|
||||
|
||||
conn := &Conn{
|
||||
dcID: opt.DC,
|
||||
|
||||
dialer: dialer,
|
||||
clock: opt.Clock,
|
||||
rand: opt.Random,
|
||||
cipher: opt.Cipher,
|
||||
log: opt.Logger,
|
||||
messageID: opt.MessageID,
|
||||
messageIDBuf: proto.NewMessageIDBuf(100),
|
||||
|
||||
ackSendChan: make(chan int64),
|
||||
ackInterval: opt.AckInterval,
|
||||
ackBatchSize: opt.AckBatchSize,
|
||||
|
||||
rsaPublicKeys: opt.PublicKeys,
|
||||
handler: opt.Handler,
|
||||
types: opt.Types,
|
||||
|
||||
authKey: opt.Key,
|
||||
salt: opt.Salt,
|
||||
|
||||
ping: map[int64]chan struct{}{},
|
||||
pingTimeout: opt.PingTimeout,
|
||||
pingInterval: opt.PingInterval,
|
||||
pingCallback: opt.PingCallback,
|
||||
|
||||
gotSession: tdsync.NewReady(),
|
||||
|
||||
rpc: opt.engine,
|
||||
compressThreshold: opt.CompressThreshold,
|
||||
dialTimeout: opt.DialTimeout,
|
||||
exchangeTimeout: opt.ExchangeTimeout,
|
||||
saltFetchInterval: opt.SaltFetchInterval,
|
||||
getTimeout: opt.RequestTimeout,
|
||||
}
|
||||
if conn.rpc == nil {
|
||||
conn.rpc = rpc.New(conn.writeContentMessage, rpc.Options{
|
||||
Logger: opt.Logger.Named("rpc"),
|
||||
RetryInterval: opt.RetryInterval,
|
||||
MaxRetries: opt.MaxRetries,
|
||||
Clock: opt.Clock,
|
||||
DropHandler: conn.dropRPC,
|
||||
OnError: opt.OnError,
|
||||
})
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// handleClose closes rpc engine and underlying connection on context done.
|
||||
func (c *Conn) handleClose(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
c.log.Debug("Closing")
|
||||
|
||||
// Close RPC Engine.
|
||||
c.rpc.ForceClose()
|
||||
// Close connection.
|
||||
if err := c.conn.Close(); err != nil {
|
||||
c.log.Debug("Failed to cleanup connection", zap.Error(err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run initializes MTProto connection to server and blocks until disconnection.
|
||||
//
|
||||
// When connection is ready, Handler.OnSession is called.
|
||||
func (c *Conn) Run(ctx context.Context, f func(ctx context.Context) error) error {
|
||||
// Starting connection.
|
||||
//
|
||||
// This will send initial packet to telegram and perform key exchange
|
||||
// if needed.
|
||||
if c.ran.Swap(true) {
|
||||
return errors.New("do Run on closed connection")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
c.log.Debug("Run: start")
|
||||
defer c.log.Debug("Run: end")
|
||||
if err := c.connect(ctx); err != nil {
|
||||
return errors.Wrap(err, "start")
|
||||
}
|
||||
{
|
||||
// All goroutines are bound to current call.
|
||||
g := tdsync.NewLogGroup(ctx, c.log.Named("group"))
|
||||
g.Go("handleClose", c.handleClose)
|
||||
g.Go("pingLoop", c.pingLoop)
|
||||
g.Go("ackLoop", c.ackLoop)
|
||||
g.Go("saltsLoop", c.saltLoop)
|
||||
g.Go("userCallback", f)
|
||||
g.Go("readLoop", c.readLoop)
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return errors.Wrap(err, "group")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
|
||||
)
|
||||
|
||||
type testHandler func(msgID int64, seqNo int32, body bin.Encoder) (bin.Encoder, error)
|
||||
|
||||
type testClientOption func(o Options)
|
||||
|
||||
func newTestClient(h testHandler, opts ...testClientOption) *Conn {
|
||||
var engine *rpc.Engine
|
||||
|
||||
engine = rpc.New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
|
||||
if response, err := h(msgID, seqNo, in); err != nil {
|
||||
engine.NotifyError(msgID, err)
|
||||
} else {
|
||||
var b bin.Buffer
|
||||
if err := b.Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.NotifyResult(msgID, &b)
|
||||
}
|
||||
return nil
|
||||
}, rpc.Options{})
|
||||
|
||||
opt := Options{
|
||||
Logger: zap.NewNop(),
|
||||
Random: rand.New(rand.NewSource(1)),
|
||||
Key: crypto.Key{}.WithID(),
|
||||
MessageID: proto.NewMessageIDGen(time.Now),
|
||||
|
||||
engine: engine,
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(opt)
|
||||
}
|
||||
|
||||
return New(nil, opt)
|
||||
}
|
||||
|
||||
// newCorpusTracer will save incoming messages to corpus folder.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// client.trace.OnMessage = newCorpusTracer(t)
|
||||
//
|
||||
// nolint: deadcode,unused // optional
|
||||
func newCorpusTracer(t testing.TB) func(b *bin.Buffer) {
|
||||
types := tmap.New(
|
||||
mt.TypesMap(),
|
||||
tg.TypesMap(),
|
||||
proto.TypesMap(),
|
||||
)
|
||||
dir := filepath.Join("..", "_fuzz", "handle_message", "corpus")
|
||||
|
||||
return func(b *bin.Buffer) {
|
||||
id, _ := b.PeekID()
|
||||
h := md5.Sum(b.Buf)
|
||||
name := types.Get(id)
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
if idx := strings.Index(name, "#"); idx > 0 {
|
||||
// Removing type id from name.
|
||||
name = name[:idx]
|
||||
}
|
||||
base := fmt.Sprintf("trace_%x_%s_%x",
|
||||
id, name, h,
|
||||
)
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(dir, base), b.Buf, 0600))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
)
|
||||
|
||||
// connect establishes connection using configured transport, creating
|
||||
// new auth key if needed.
|
||||
func (c *Conn) connect(ctx context.Context) (rErr error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.dialTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := c.dialer(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "dial failed")
|
||||
}
|
||||
c.conn = conn
|
||||
defer func() {
|
||||
if rErr != nil {
|
||||
multierr.AppendInto(&rErr, conn.Close())
|
||||
}
|
||||
}()
|
||||
|
||||
session := c.session()
|
||||
if session.Key.Zero() {
|
||||
c.log.Info("Generating new auth key")
|
||||
start := c.clock.Now()
|
||||
if err := c.createAuthKey(ctx); err != nil {
|
||||
return errors.Wrap(err, "create auth key")
|
||||
}
|
||||
|
||||
c.log.Info("Auth key generated",
|
||||
zap.Duration("duration", c.clock.Now().Sub(start)),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.log.Info("Key already exists")
|
||||
if session.ID == 0 {
|
||||
// NB: Telegram can return 404 error if session id is zero.
|
||||
//
|
||||
// See https://github.com/gotd/td/issues/107.
|
||||
c.log.Debug("Generating new session id")
|
||||
if err := c.newSessionID(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAuthKey generates new authorization key.
|
||||
func (c *Conn) createAuthKey(ctx context.Context) error {
|
||||
// Grab exclusive lock for writing.
|
||||
// It prevents message sending during key regeneration if server forgot current auth key.
|
||||
c.exchangeLock.Lock()
|
||||
defer c.exchangeLock.Unlock()
|
||||
|
||||
if ce := c.log.Check(zap.DebugLevel, "Initializing new key exchange"); ce != nil {
|
||||
// Useful for debugging i/o timeout errors on tcp reads or writes.
|
||||
fields := []zap.Field{
|
||||
zap.Duration("timeout", c.exchangeTimeout),
|
||||
}
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
fields = append(fields, zap.Time("context_deadline", deadline))
|
||||
}
|
||||
ce.Write(fields...)
|
||||
}
|
||||
|
||||
r, err := exchange.NewExchanger(c.conn, c.dcID).
|
||||
WithClock(c.clock).
|
||||
WithLogger(c.log.Named("exchange")).
|
||||
WithTimeout(c.exchangeTimeout).
|
||||
WithRand(c.rand).
|
||||
Client(c.rsaPublicKeys).Run(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.sessionMux.Lock()
|
||||
c.authKey = r.AuthKey
|
||||
c.sessionID = r.SessionID
|
||||
c.salt = r.ServerSalt
|
||||
c.sessionMux.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
type closeConn struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (c *closeConn) Send(ctx context.Context, b *bin.Buffer) error {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (c *closeConn) Recv(ctx context.Context, b *bin.Buffer) error {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (c *closeConn) Close() error {
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConn_connect(t *testing.T) {
|
||||
t.Run("EnsureClose", func(t *testing.T) {
|
||||
t.Run("Exchange", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
|
||||
closeMe := &closeConn{}
|
||||
c := Conn{
|
||||
dialer: func(ctx context.Context) (transport.Conn, error) {
|
||||
return closeMe, nil
|
||||
},
|
||||
clock: clock.System,
|
||||
rand: rand.Reader,
|
||||
log: zap.NewNop(),
|
||||
}
|
||||
|
||||
a.Error(c.connect(context.Background()))
|
||||
a.True(closeMe.closed)
|
||||
})
|
||||
t.Run("SessionID", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
|
||||
closeMe := &closeConn{}
|
||||
c := Conn{
|
||||
dialer: func(ctx context.Context) (transport.Conn, error) {
|
||||
return closeMe, nil
|
||||
},
|
||||
clock: clock.System,
|
||||
authKey: crypto.AuthKey{
|
||||
ID: [8]byte{1}, // Skip exchange.
|
||||
},
|
||||
rand: iotest.ErrReader(io.EOF),
|
||||
log: zap.NewNop(),
|
||||
}
|
||||
|
||||
a.Error(c.connect(context.Background()))
|
||||
a.True(closeMe.closed)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package mtproto implements MTProto connection.
|
||||
package mtproto
|
||||
@@ -0,0 +1,51 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/gotd/neo"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
func benchEncryption(b *testing.B, c *Conn, n int) {
|
||||
b.Helper()
|
||||
|
||||
buf := &bin.Buffer{Buf: make([]byte, 0, n)}
|
||||
p := testPayload{Data: make([]byte, n-4)}
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(n))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf.Reset()
|
||||
if err := c.newEncryptedMessage(12345, 0, p, buf); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncryption(b *testing.B) {
|
||||
c := &Conn{
|
||||
rand: Zero{},
|
||||
log: zap.NewNop(),
|
||||
cipher: crypto.NewClientCipher(Zero{}),
|
||||
clock: neo.NewTime(time.Now()),
|
||||
compressThreshold: -1,
|
||||
}
|
||||
for i := 0; i < 256; i++ {
|
||||
c.authKey.Value[i] = byte(i)
|
||||
}
|
||||
|
||||
for _, payload := range testutil.Payloads() {
|
||||
b.Run(fmt.Sprintf("%db", payload), func(b *testing.B) {
|
||||
benchEncryption(b, c, payload)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
)
|
||||
|
||||
func (c *Conn) handleAck(b *bin.Buffer) error {
|
||||
var ack mt.MsgsAck
|
||||
if err := ack.Decode(b); err != nil {
|
||||
return errors.Wrap(err, "decode")
|
||||
}
|
||||
|
||||
c.log.Debug("Received ack", zap.Int64s("msg_ids", ack.MsgIDs))
|
||||
c.rpc.NotifyAcks(ack.MsgIDs)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
)
|
||||
|
||||
type badMessageError struct {
|
||||
Code int
|
||||
NewSalt int64
|
||||
}
|
||||
|
||||
const (
|
||||
codeMessageIDTooLow = 16
|
||||
codeMessageIDTooHigh = 17
|
||||
codeIncorrectServerSalt = 48
|
||||
)
|
||||
|
||||
func (c badMessageError) Error() string {
|
||||
description := map[int]string{
|
||||
codeMessageIDTooLow: "msg_id too low",
|
||||
codeMessageIDTooHigh: "msg_id too high",
|
||||
codeIncorrectServerSalt: "incorrect server salt",
|
||||
|
||||
18: "incorrect two lower order msg_id bits",
|
||||
19: "container msg_id is the same as msg_id of a previously received message",
|
||||
20: "message too old",
|
||||
32: "msg_seqno too low",
|
||||
33: "msg_seqno too high",
|
||||
34: "even msg_seqno expected, but odd received",
|
||||
35: "odd msg_seqno expected, but even received",
|
||||
}[c.Code]
|
||||
if description == "" {
|
||||
return fmt.Sprintf("bad msg error code %d", c.Code)
|
||||
}
|
||||
return description
|
||||
}
|
||||
|
||||
func (c *Conn) handleBadMsg(b *bin.Buffer) error {
|
||||
id, err := b.PeekID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch id {
|
||||
case mt.BadMsgNotificationTypeID:
|
||||
var bad mt.BadMsgNotification
|
||||
if err := bad.Decode(b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode})
|
||||
return nil
|
||||
case mt.BadServerSaltTypeID:
|
||||
var bad mt.BadServerSalt
|
||||
if err := bad.Decode(b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode, NewSalt: bad.NewServerSalt})
|
||||
return nil
|
||||
default:
|
||||
return errors.Errorf("unknown type id 0x%d", id)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
func (c *Conn) handleContainer(msgID int64, b *bin.Buffer) error {
|
||||
var container proto.MessageContainer
|
||||
if err := container.Decode(b); err != nil {
|
||||
return errors.Wrap(err, "container")
|
||||
}
|
||||
for _, msg := range container.Messages {
|
||||
if err := c.processContainerMessage(msgID, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) processContainerMessage(msgID int64, msg proto.Message) error {
|
||||
b := &bin.Buffer{Buf: msg.Body}
|
||||
return c.handleMessage(msgID, b)
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
)
|
||||
|
||||
func (c *Conn) handleFutureSalts(b *bin.Buffer) error {
|
||||
var res mt.FutureSalts
|
||||
|
||||
if err := res.Decode(b); err != nil {
|
||||
return errors.Wrap(err, "error decode")
|
||||
}
|
||||
|
||||
c.salts.Store(res.Salts)
|
||||
|
||||
serverTime := time.Unix(int64(res.Now), 0)
|
||||
c.log.Debug("Got future salts", zap.Time("server_time", serverTime))
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
)
|
||||
|
||||
func TestConn_handleFutureSalts(t *testing.T) {
|
||||
now := time.Now()
|
||||
ts := int(now.Unix())
|
||||
testdata := []mt.FutureSalt{
|
||||
{
|
||||
ValidSince: ts - 1,
|
||||
ValidUntil: ts + 1,
|
||||
Salt: 10,
|
||||
},
|
||||
{
|
||||
ValidSince: ts + 1,
|
||||
ValidUntil: ts + 3,
|
||||
Salt: 11,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
conn := Conn{log: zap.NewNop()}
|
||||
buf := bin.Buffer{}
|
||||
|
||||
a.NoError(buf.Encode(&mt.FutureSalts{
|
||||
ReqMsgID: 1,
|
||||
Now: 1,
|
||||
Salts: testdata,
|
||||
}))
|
||||
a.NoError(conn.handleFutureSalts(&buf))
|
||||
|
||||
salt, ok := conn.salts.Get(now)
|
||||
a.Equal(int64(10), salt)
|
||||
a.True(ok)
|
||||
})
|
||||
t.Run("Invalid", func(t *testing.T) {
|
||||
conn := Conn{}
|
||||
buf := bin.Buffer{}
|
||||
buf.PutID(mt.FutureSaltsTypeID)
|
||||
require.Error(t, conn.handleFutureSalts(&buf))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
)
|
||||
|
||||
func gzip(b *bin.Buffer) (*bin.Buffer, error) {
|
||||
var content proto.GZIP
|
||||
if err := content.Decode(b); err != nil {
|
||||
return nil, errors.Wrap(err, "decode")
|
||||
}
|
||||
return &bin.Buffer{Buf: content.Data}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) handleGZIP(msgID int64, b *bin.Buffer) error {
|
||||
content, err := gzip(b)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "unzip")
|
||||
}
|
||||
return c.handleMessage(msgID, content)
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
)
|
||||
|
||||
func TestGzipDecode(t *testing.T) {
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
|
||||
// from grammers
|
||||
// https://github.com/Lonami/grammers/blob/ed5470b3f79f5d41a5090450b866026568f3cd60/lib/grammers-mtproto/src/manual_tl.rs#L197
|
||||
data := []byte{
|
||||
1, 109, 92, 243, 132, 41, 150, 69, 54, 75, 49, 94, 161, 207, 114, 48, 254, 140, 1, 0,
|
||||
31, 139, 8, 0, 0, 0, 0, 0, 0, 3, 149, 147, 61, 75, 195, 80, 20, 134, 79, 62, 20, 84,
|
||||
170, 1, 17, 28, 68, 28, 132, 110, 183, 247, 38, 185, 249, 154, 58, 10, 130, 116, 116,
|
||||
170, 54, 109, 18, 11, 173, 169, 109, 90, 112, 210, 209, 81, 112, 112, 22, 127, 131, 56,
|
||||
232, 232, 224, 143, 112, 112, 238, 159, 168, 55, 31, 234, 77, 91, 82, 26, 184, 57, 36,
|
||||
79, 222, 115, 222, 251, 38, 9, 170, 27, 218, 209, 62, 128, 113, 76, 234, 181, 83, 82,
|
||||
55, 31, 175, 223, 101, 0, 216, 249, 120, 217, 219, 102, 181, 244, 244, 186, 203, 10, 8,
|
||||
108, 109, 18, 221, 70, 132, 234, 136, 152, 20, 81, 13, 222, 132, 148, 43, 11, 184, 144,
|
||||
241, 178, 138, 49, 113, 176, 171, 90, 142, 175, 106, 45, 199, 79, 46, 217, 145, 63, 53,
|
||||
126, 117, 241, 92, 49, 215, 215, 48, 17, 197, 185, 185, 179, 156, 252, 113, 49, 227,
|
||||
91, 60, 39, 148, 240, 190, 196, 127, 95, 134, 217, 116, 176, 238, 89, 177, 47, 181,
|
||||
200, 151, 180, 156, 206, 229, 247, 35, 229, 252, 176, 156, 8, 198, 252, 126, 138, 184,
|
||||
144, 241, 57, 57, 106, 139, 114, 148, 167, 115, 178, 73, 46, 199, 34, 46, 100, 124,
|
||||
206, 126, 245, 162, 185, 98, 166, 227, 242, 55, 16, 81, 49, 159, 227, 18, 125, 93, 222,
|
||||
207, 202, 76, 14, 126, 172, 163, 69, 126, 148, 76, 87, 178, 9, 139, 213, 66, 148, 185,
|
||||
177, 48, 0, 159, 211, 52, 167, 118, 198, 27, 189, 145, 134, 6, 145, 215, 65, 205, 176,
|
||||
11, 240, 201, 158, 171, 150, 36, 104, 177, 90, 211, 37, 184, 99, 63, 11, 30, 2, 124,
|
||||
63, 200, 73, 253, 98, 141, 206, 199, 233, 119, 18, 63, 11, 207, 34, 76, 38, 147, 155,
|
||||
120, 193, 248, 48, 185, 23, 207, 186, 117, 214, 146, 26, 247, 57, 56, 1, 184, 63, 19,
|
||||
18, 189, 82, 102, 51, 47, 162, 168, 55, 112, 42, 149, 8, 117, 189, 10, 123, 247, 65,
|
||||
219, 95, 247, 195, 97, 127, 112, 53, 108, 244, 61, 144, 221, 246, 101, 192, 116, 171,
|
||||
65, 24, 6, 29, 47, 13, 83, 73, 203, 15, 58, 186, 13, 141, 216, 3, 0, 0,
|
||||
}
|
||||
|
||||
var result proto.Result
|
||||
err := result.Decode(&bin.Buffer{Buf: data})
|
||||
a.NoError(err)
|
||||
|
||||
r, err := gzip(&bin.Buffer{Buf: result.Result})
|
||||
a.NoError(err)
|
||||
a.Equal(r.Len(), 984)
|
||||
})
|
||||
t.Run("Invalid", func(t *testing.T) {
|
||||
_, err := gzip(new(bin.Buffer))
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
)
|
||||
|
||||
func (c *Conn) handleMessage(msgID int64, b *bin.Buffer) error {
|
||||
id, err := b.PeekID()
|
||||
if err != nil {
|
||||
// Empty body.
|
||||
return errors.Wrap(err, "peek message type")
|
||||
}
|
||||
|
||||
c.logWithBuffer(b).Debug("Handle message", zap.Int64("msg_id", msgID))
|
||||
|
||||
switch id {
|
||||
case mt.NewSessionCreatedTypeID:
|
||||
return c.handleSessionCreated(b)
|
||||
case mt.BadMsgNotificationTypeID, mt.BadServerSaltTypeID:
|
||||
return c.handleBadMsg(b)
|
||||
case mt.FutureSaltsTypeID:
|
||||
return c.handleFutureSalts(b)
|
||||
case proto.MessageContainerTypeID:
|
||||
return c.handleContainer(msgID, b)
|
||||
case proto.ResultTypeID:
|
||||
return c.handleResult(b)
|
||||
case mt.PongTypeID:
|
||||
return c.handlePong(b)
|
||||
case mt.MsgsAckTypeID:
|
||||
return c.handleAck(b)
|
||||
case proto.GZIPTypeID:
|
||||
return c.handleGZIP(msgID, b)
|
||||
case mt.MsgDetailedInfoTypeID,
|
||||
mt.MsgNewDetailedInfoTypeID:
|
||||
return nil
|
||||
default:
|
||||
return c.handler.OnMessage(b)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
//go:build fuzz
|
||||
// +build fuzz
|
||||
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
|
||||
)
|
||||
|
||||
type fuzzHandler struct {
|
||||
types *tmap.Constructor
|
||||
}
|
||||
|
||||
func (h fuzzHandler) OnMessage(b *bin.Buffer) error {
|
||||
id, err := b.PeekID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := h.types.New(id)
|
||||
if v == nil {
|
||||
return errors.New("not found")
|
||||
}
|
||||
if err := v.Decode(b); err != nil {
|
||||
return errors.Wrap(err, "decode")
|
||||
}
|
||||
|
||||
// Performing decode cycle.
|
||||
var newBuff bin.Buffer
|
||||
newV := h.types.New(id)
|
||||
if err := v.Encode(&newBuff); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := newV.Decode(&newBuff); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fuzzHandler) OnSession(session Session) error { return nil }
|
||||
|
||||
var (
|
||||
conn *Conn
|
||||
buf *bin.Buffer
|
||||
)
|
||||
|
||||
func init() {
|
||||
handler := fuzzHandler{
|
||||
// Handler will try to dynamically decode any incoming message.
|
||||
types: tmap.NewConstructor(
|
||||
tg.TypesConstructorMap(),
|
||||
mt.TypesConstructorMap(),
|
||||
),
|
||||
}
|
||||
c := &Conn{
|
||||
rand: testutil.ZeroRand{},
|
||||
rpc: rpc.New(rpc.NopSend, rpc.Options{}),
|
||||
log: zap.NewNop(),
|
||||
messageID: proto.NewMessageIDGen(time.Now),
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
conn = c
|
||||
buf = &bin.Buffer{}
|
||||
}
|
||||
|
||||
func FuzzHandleMessage(data []byte) int {
|
||||
buf.ResetTo(data)
|
||||
if err := conn.handleMessage(buf); err != nil {
|
||||
return 0
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/gotd/neo"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
|
||||
)
|
||||
|
||||
type testUpdateHandler struct {
|
||||
types *tmap.Constructor
|
||||
}
|
||||
|
||||
func (h testUpdateHandler) OnMessage(b *bin.Buffer) error {
|
||||
id, err := b.PeekID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := h.types.New(id)
|
||||
if v == nil {
|
||||
return errors.New("not found")
|
||||
}
|
||||
if err := v.Decode(b); err != nil {
|
||||
return errors.Wrap(err, "decode")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (testUpdateHandler) OnSession(session Session) error { return nil }
|
||||
|
||||
func newTestHandler() Handler {
|
||||
return &testUpdateHandler{
|
||||
types: tmap.NewConstructor(
|
||||
tg.TypesConstructorMap(),
|
||||
mt.TypesConstructorMap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnHandleMessage(t *testing.T) {
|
||||
c := &Conn{
|
||||
rand: Zero{},
|
||||
log: zap.NewNop(),
|
||||
handler: newTestHandler(),
|
||||
}
|
||||
|
||||
for i, input := range []string{
|
||||
"\xdc\xf8\xf1stewa\x00O\x03expired c" +
|
||||
"ertificate\x02\x00\x00\x00\xef",
|
||||
|
||||
"\x01m\\\xf300000000\x19\xcaD!0000" +
|
||||
"\x1100000000000000000",
|
||||
|
||||
"\x01m\\\xf300000000\x19\xcaD!0000" +
|
||||
"\xfe0",
|
||||
|
||||
"@B\xaet\x15ĵ\x1c0000,\x8f\xf8B0000" +
|
||||
"00000000\x15ĵ\x1c0000\xff000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"00000000000000000000" +
|
||||
"000000000000",
|
||||
} {
|
||||
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
|
||||
if err := c.handleMessage(0, &bin.Buffer{Buf: []byte(input)}); err == nil {
|
||||
t.Fatal("error expected")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnHandleMessageCorpus(t *testing.T) {
|
||||
if testutil.Race {
|
||||
t.Skip("Skipped")
|
||||
}
|
||||
|
||||
c := &Conn{
|
||||
handler: newTestHandler(),
|
||||
rpc: rpc.New(rpc.NopSend, rpc.Options{}),
|
||||
clock: neo.NewTime(time.Now()),
|
||||
rand: Zero{},
|
||||
log: zap.NewNop(),
|
||||
gotSession: tdsync.NewReady(),
|
||||
}
|
||||
|
||||
corpusDir := filepath.Join("..", "_fuzz", "handle_message", "corpus")
|
||||
corpus, err := os.ReadDir(corpusDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
b := &bin.Buffer{}
|
||||
types := tmap.New(
|
||||
tg.TypesMap(),
|
||||
mt.TypesMap(),
|
||||
)
|
||||
|
||||
for _, f := range corpus {
|
||||
t.Run(f.Name(), func(t *testing.T) {
|
||||
data, err := os.ReadFile(filepath.Join(corpusDir, f.Name()))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Default to 128 bytes per invocation.
|
||||
allocThreshold := 128
|
||||
|
||||
// Adjusting threshold for specific types.
|
||||
//
|
||||
// Probably there should be better way to do this, but
|
||||
// manually ensuring allocation distribution by type is
|
||||
// pretty ok.
|
||||
b.ResetTo(data)
|
||||
if id, err := b.PeekID(); err == nil {
|
||||
t.Logf("Type: 0x%x %s", id, types.Get(id))
|
||||
switch id {
|
||||
case tg.UpdatesTypeID,
|
||||
tg.TextFixedTypeID,
|
||||
tg.InputPeerChannelFromMessageTypeID,
|
||||
tg.PageBlockRelatedArticlesTypeID:
|
||||
allocThreshold = 512
|
||||
case tg.TextBoldTypeID,
|
||||
tg.TextItalicTypeID,
|
||||
tg.TextMarkedTypeID,
|
||||
tg.MessageTypeID,
|
||||
tg.PageBlockCoverTypeID,
|
||||
tg.InputMediaUploadedDocumentTypeID,
|
||||
tg.SecureRequiredTypeOneOfTypeID:
|
||||
allocThreshold = 256
|
||||
}
|
||||
}
|
||||
|
||||
testutil.MaxAlloc(t, allocThreshold, func() {
|
||||
b.ResetTo(data)
|
||||
_ = c.handleMessage(0, b)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
func (c *Conn) handleResult(b *bin.Buffer) error {
|
||||
// Response to an RPC query.
|
||||
var res proto.Result
|
||||
if err := res.Decode(b); err != nil {
|
||||
return errors.Wrap(err, "decode")
|
||||
}
|
||||
|
||||
// Now b contains result message.
|
||||
b.ResetTo(res.Result)
|
||||
|
||||
msgID := zap.Int64("msg_id", res.RequestMessageID)
|
||||
c.logWithBuffer(b).Debug("Handle result", msgID)
|
||||
|
||||
// Handling gzipped results.
|
||||
id, err := b.PeekID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if id == proto.GZIPTypeID {
|
||||
content, err := gzip(b)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "decompress")
|
||||
}
|
||||
|
||||
// Replacing buffer so callback will deal with uncompressed data.
|
||||
b = content
|
||||
c.logWithBuffer(b).Debug("Decompressed", msgID)
|
||||
|
||||
// Replacing id with inner id if error is compressed for any reason.
|
||||
if id, err = b.PeekID(); err != nil {
|
||||
return errors.Wrap(err, "peek id")
|
||||
}
|
||||
}
|
||||
|
||||
if id == mt.RPCErrorTypeID {
|
||||
var rpcErr mt.RPCError
|
||||
if err := rpcErr.Decode(b); err != nil {
|
||||
return errors.Wrap(err, "error decode")
|
||||
}
|
||||
|
||||
c.log.Debug("Got error", msgID,
|
||||
zap.Int("err_code", rpcErr.ErrorCode),
|
||||
zap.String("err_msg", rpcErr.ErrorMessage),
|
||||
)
|
||||
c.rpc.NotifyError(res.RequestMessageID, tgerr.New(rpcErr.ErrorCode, rpcErr.ErrorMessage))
|
||||
|
||||
return nil
|
||||
}
|
||||
if id == mt.PongTypeID {
|
||||
return c.handlePong(b)
|
||||
}
|
||||
|
||||
return c.rpc.NotifyResult(res.RequestMessageID, b)
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
)
|
||||
|
||||
func (c *Conn) handleSessionCreated(b *bin.Buffer) error {
|
||||
var s mt.NewSessionCreated
|
||||
if err := s.Decode(b); err != nil {
|
||||
return errors.Wrap(err, "decode")
|
||||
}
|
||||
c.gotSession.Signal()
|
||||
|
||||
created := proto.MessageID(s.FirstMsgID).Time()
|
||||
now := c.clock.Now()
|
||||
c.log.Debug("Session created",
|
||||
zap.Int64("unique_id", s.UniqueID),
|
||||
zap.Int64("first_msg_id", s.FirstMsgID),
|
||||
zap.Time("first_msg_time", created.Local()),
|
||||
)
|
||||
|
||||
if (created.Before(now) && now.Sub(created) > maxPast) || created.Sub(now) > maxFuture {
|
||||
c.log.Warn("Local clock needs synchronization",
|
||||
zap.Time("first_msg_time", created),
|
||||
zap.Time("local", now),
|
||||
zap.Duration("time_difference", now.Sub(created)),
|
||||
)
|
||||
}
|
||||
|
||||
c.storeSalt(s.ServerSalt)
|
||||
if err := c.handler.OnSession(c.session()); err != nil {
|
||||
return errors.Wrap(err, "handler.OnSession")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"go.uber.org/zap/zaptest/observer"
|
||||
|
||||
"github.com/gotd/neo"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
)
|
||||
|
||||
func TestConn_handleSessionCreated(t *testing.T) {
|
||||
t.Run("NeedSynchronization", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
logger, logs := observer.New(zapcore.WarnLevel)
|
||||
|
||||
now := time.Unix(100, 0)
|
||||
clock := neo.NewTime(now)
|
||||
gotSession := tdsync.NewReady()
|
||||
conn := Conn{
|
||||
clock: clock,
|
||||
log: zap.New(logger),
|
||||
gotSession: gotSession,
|
||||
handler: newTestHandler(),
|
||||
}
|
||||
|
||||
buf := bin.Buffer{}
|
||||
msgID := proto.NewMessageID(now.Add(maxFuture+time.Second), proto.MessageFromClient)
|
||||
a.NoError(buf.Encode(&mt.NewSessionCreated{
|
||||
FirstMsgID: int64(msgID),
|
||||
UniqueID: 10,
|
||||
ServerSalt: 10,
|
||||
}))
|
||||
a.NoError(conn.handleSessionCreated(&buf))
|
||||
|
||||
select {
|
||||
case <-gotSession.Ready():
|
||||
default:
|
||||
t.Fatal("expected gotSession signal")
|
||||
}
|
||||
a.Equal(int64(10), conn.salt)
|
||||
|
||||
msgs := logs.All()
|
||||
a.Len(msgs, 1)
|
||||
a.Equal("Local clock needs synchronization", msgs[0].Message)
|
||||
})
|
||||
t.Run("Invalid", func(t *testing.T) {
|
||||
conn := Conn{}
|
||||
buf := bin.Buffer{}
|
||||
buf.PutID(mt.NewSessionCreatedTypeID)
|
||||
require.Error(t, conn.handleSessionCreated(&buf))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package mtproto
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
|
||||
func (c *Conn) newMessageID() int64 {
|
||||
return c.messageID.New(proto.MessageFromClient)
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
)
|
||||
|
||||
func TestClientNewMessageID(t *testing.T) {
|
||||
c := newTestClient(nil)
|
||||
now := c.clock.Now()
|
||||
id := proto.MessageID(c.newMessageID())
|
||||
assert.Equal(t, proto.MessageFromClient, id.Type())
|
||||
|
||||
lag := id.Time().Sub(now)
|
||||
if lag < 0 {
|
||||
lag *= -1
|
||||
}
|
||||
if lag > time.Second {
|
||||
t.Error("generated id lags in time")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func TestMsgSeq(t *testing.T) {
|
||||
type record struct {
|
||||
msgID int64
|
||||
seqNo int32
|
||||
}
|
||||
|
||||
var (
|
||||
records []record
|
||||
mux sync.Mutex
|
||||
)
|
||||
|
||||
const (
|
||||
workers = 32
|
||||
requestsPerWorker = 200
|
||||
)
|
||||
client := newTestClient(func(msgID int64, seqNo int32, body bin.Encoder) (bin.Encoder, error) {
|
||||
mux.Lock()
|
||||
records = append(records, record{msgID, seqNo})
|
||||
mux.Unlock()
|
||||
return &tg.Config{}, nil
|
||||
})
|
||||
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(workers)
|
||||
for i := 0; i < workers; i++ {
|
||||
go func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
|
||||
for i := 0; i < requestsPerWorker; i++ {
|
||||
err := client.Invoke(context.Background(), &tg.Config{}, &tg.Config{})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}(t)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
var received []record
|
||||
less := func(msgID int64, f func(r record)) {
|
||||
for _, recv := range received {
|
||||
if recv.msgID < msgID {
|
||||
f(recv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
greater := func(msgID int64, f func(r record)) {
|
||||
for _, recv := range received {
|
||||
if recv.msgID > msgID {
|
||||
f(recv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
contains := func(msgID int64) bool {
|
||||
for _, rec := range received {
|
||||
if rec.msgID == msgID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
for _, current := range records {
|
||||
if contains(current.msgID) {
|
||||
// Ignore duplicates.
|
||||
continue
|
||||
}
|
||||
|
||||
less(current.msgID, func(less record) {
|
||||
// The server has already received a message with a lower msg_id
|
||||
// but with either a higher or an equal and odd seqno.
|
||||
if less.seqNo >= current.seqNo && less.seqNo%2 == 1 {
|
||||
t.Fatal("seqNo too low")
|
||||
}
|
||||
})
|
||||
|
||||
greater(current.msgID, func(greater record) {
|
||||
// Similarly, there is a message with a higher msg_id
|
||||
// but with either a lower or an equal and odd seqno.
|
||||
if greater.seqNo <= current.seqNo && greater.seqNo%2 == 1 {
|
||||
t.Fatal("seqNo too high")
|
||||
}
|
||||
})
|
||||
|
||||
received = append(received, current)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
type testPayload struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (d testPayload) Decode(b *bin.Buffer) error {
|
||||
_, err := b.Bytes()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d testPayload) Encode(b *bin.Buffer) error {
|
||||
b.PutBytes(d.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
type noopBuf struct{}
|
||||
|
||||
func (n noopBuf) Consume(id int64) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type constantConn struct {
|
||||
data []byte
|
||||
cancel context.CancelFunc
|
||||
counter int
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func (c *constantConn) Send(ctx context.Context, b *bin.Buffer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *constantConn) Recv(ctx context.Context, b *bin.Buffer) error {
|
||||
c.mux.Lock()
|
||||
exit := c.counter == 0
|
||||
if exit {
|
||||
c.mux.Unlock()
|
||||
c.cancel()
|
||||
return errors.New("error")
|
||||
}
|
||||
c.counter--
|
||||
c.mux.Unlock()
|
||||
|
||||
b.Put(c.data)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *constantConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
)
|
||||
|
||||
func (c *Conn) newEncryptedMessage(id int64, seq int32, payload bin.Encoder, b *bin.Buffer) error {
|
||||
s := c.session()
|
||||
|
||||
// TODO(tdakkota): Smarter gzip.
|
||||
// 1) Generate Length() method for every encoder, to count length without encoding.
|
||||
// 2) Re-use buffer instead of using yet one.
|
||||
// 3) Do not send proto.GZIP if gzipped size is equal or bigger.
|
||||
var (
|
||||
d crypto.EncryptedMessageData
|
||||
log = c.log
|
||||
)
|
||||
if c.compressThreshold <= 0 {
|
||||
if obj, ok := payload.(interface{ TypeID() uint32 }); ok {
|
||||
log = c.logWithTypeID(obj.TypeID())
|
||||
}
|
||||
d = crypto.EncryptedMessageData{
|
||||
SessionID: s.ID,
|
||||
Salt: s.Salt,
|
||||
MessageID: id,
|
||||
SeqNo: seq,
|
||||
Message: payload,
|
||||
}
|
||||
} else {
|
||||
payloadBuf := bufPool.Get()
|
||||
defer bufPool.Put(payloadBuf)
|
||||
if err := payload.Encode(payloadBuf); err != nil {
|
||||
return errors.Wrap(err, "encode payload")
|
||||
}
|
||||
|
||||
log = c.logWithType(payloadBuf)
|
||||
if payloadBuf.Len() > c.compressThreshold {
|
||||
d = crypto.EncryptedMessageData{
|
||||
SessionID: s.ID,
|
||||
Salt: s.Salt,
|
||||
MessageID: id,
|
||||
SeqNo: seq,
|
||||
Message: proto.GZIP{Data: payloadBuf.Raw()},
|
||||
}
|
||||
} else {
|
||||
d = crypto.EncryptedMessageData{
|
||||
SessionID: s.ID,
|
||||
Salt: s.Salt,
|
||||
MessageID: id,
|
||||
SeqNo: seq,
|
||||
MessageDataLen: int32(payloadBuf.Len()),
|
||||
MessageDataWithPadding: payloadBuf.Buf,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Request", zap.Int64("msg_id", id))
|
||||
if err := c.cipher.Encrypt(s.Key, d, b); err != nil {
|
||||
return errors.Wrap(err, "encrypt")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
|
||||
)
|
||||
|
||||
// Options of Conn.
|
||||
type Options struct {
|
||||
// DC is datacenter ID for key exchange.
|
||||
// Defaults to 2.
|
||||
DC int
|
||||
|
||||
// PublicKeys of telegram.
|
||||
//
|
||||
// If not provided, embedded public keys will be used.
|
||||
PublicKeys []exchange.PublicKey
|
||||
|
||||
// Random is random source. Defaults to crypto.
|
||||
Random io.Reader
|
||||
// Logger is instance of zap.Logger. No logs by default.
|
||||
Logger *zap.Logger
|
||||
// Handler will be called on received message.
|
||||
Handler Handler
|
||||
|
||||
// AckBatchSize is maximum ack-s to buffer.
|
||||
AckBatchSize int
|
||||
// AckInterval is maximum time to buffer ack.
|
||||
AckInterval time.Duration
|
||||
|
||||
// RetryInterval is duration between retries.
|
||||
RetryInterval time.Duration
|
||||
// MaxRetries is max retry count until rpc request failure.
|
||||
MaxRetries int
|
||||
|
||||
// DialTimeout is timeout of creating connection.
|
||||
DialTimeout time.Duration
|
||||
// ExchangeTimeout is timeout of every key exchange request.
|
||||
ExchangeTimeout time.Duration
|
||||
// SaltFetchInterval is duration between get_future_salts request.
|
||||
SaltFetchInterval time.Duration
|
||||
// PingTimeout sets ping_delay_disconnect timeout.
|
||||
PingTimeout time.Duration
|
||||
// PingInterval is duration between ping_delay_disconnect request.
|
||||
PingInterval time.Duration
|
||||
// PingCallback is called after a ping is acknowledged by the server.
|
||||
PingCallback func()
|
||||
// RequestTimeout is function which returns request timeout for given type ID.
|
||||
RequestTimeout func(req uint32) time.Duration
|
||||
|
||||
// CompressThreshold is a threshold in bytes to determine that message
|
||||
// is large enough to be compressed using GZIP.
|
||||
// If < 0, compression will be disabled.
|
||||
// If == 0, default value will be used.
|
||||
CompressThreshold int
|
||||
// MessageID is message id source. Share source between connection to
|
||||
// reduce collision probability.
|
||||
MessageID MessageIDSource
|
||||
// Clock is current time source. Defaults to system time.
|
||||
Clock clock.Clock
|
||||
// Types map, used in verbose logging of incoming message.
|
||||
Types *tmap.Map
|
||||
// Key that can be used to restore previous connection.
|
||||
Key crypto.AuthKey
|
||||
// Salt from server that can be used to restore previous connection.
|
||||
Salt int64
|
||||
|
||||
// Tracer for OTEL.
|
||||
Tracer trace.Tracer
|
||||
|
||||
// Private options.
|
||||
|
||||
// Cipher defines message crypto.
|
||||
Cipher Cipher
|
||||
// engine for replacing RPC engine.
|
||||
engine *rpc.Engine
|
||||
|
||||
// OnError is a function that is called if there is any error.
|
||||
OnError func(err error)
|
||||
}
|
||||
|
||||
type nopHandler struct{}
|
||||
|
||||
func (nopHandler) OnMessage(b *bin.Buffer) error { return nil }
|
||||
func (nopHandler) OnSession(session Session) error { return nil }
|
||||
|
||||
func (opt *Options) setDefaultPublicKeys() {
|
||||
// Using public keys that are included with distribution if not
|
||||
// provided.
|
||||
//
|
||||
// This should never fail and keys should be valid for recent
|
||||
// library versions.
|
||||
opt.PublicKeys = vendoredKeys()
|
||||
}
|
||||
|
||||
func (opt *Options) setDefaults() {
|
||||
if opt.DC == 0 {
|
||||
opt.DC = 2
|
||||
}
|
||||
if opt.Random == nil {
|
||||
opt.Random = crypto.DefaultRand()
|
||||
}
|
||||
if opt.Logger == nil {
|
||||
opt.Logger = zap.NewNop()
|
||||
}
|
||||
if opt.AckBatchSize == 0 {
|
||||
opt.AckBatchSize = 20
|
||||
}
|
||||
if opt.AckInterval == 0 {
|
||||
opt.AckInterval = 15 * time.Second
|
||||
}
|
||||
if opt.RetryInterval == 0 {
|
||||
opt.RetryInterval = 5 * time.Second
|
||||
}
|
||||
if opt.MaxRetries == 0 {
|
||||
opt.MaxRetries = 5
|
||||
}
|
||||
if opt.DialTimeout == 0 {
|
||||
opt.DialTimeout = 35 * time.Second
|
||||
}
|
||||
if opt.ExchangeTimeout == 0 {
|
||||
opt.ExchangeTimeout = exchange.DefaultTimeout
|
||||
}
|
||||
if opt.SaltFetchInterval == 0 {
|
||||
opt.SaltFetchInterval = 1 * time.Hour
|
||||
}
|
||||
if opt.PingTimeout == 0 {
|
||||
opt.PingTimeout = 15 * time.Second
|
||||
}
|
||||
if opt.PingInterval == 0 {
|
||||
opt.PingInterval = 1 * time.Minute
|
||||
}
|
||||
if opt.RequestTimeout == nil {
|
||||
opt.RequestTimeout = func(req uint32) time.Duration {
|
||||
return 15 * time.Second
|
||||
}
|
||||
}
|
||||
if opt.CompressThreshold == 0 {
|
||||
opt.CompressThreshold = 1024
|
||||
}
|
||||
if opt.Clock == nil {
|
||||
opt.Clock = clock.System
|
||||
}
|
||||
if opt.MessageID == nil {
|
||||
opt.MessageID = proto.NewMessageIDGen(opt.Clock.Now)
|
||||
}
|
||||
if len(opt.PublicKeys) == 0 {
|
||||
opt.setDefaultPublicKeys()
|
||||
}
|
||||
if opt.Handler == nil {
|
||||
opt.Handler = nopHandler{}
|
||||
}
|
||||
if opt.Cipher == nil {
|
||||
opt.Cipher = crypto.NewClientCipher(opt.Random)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
)
|
||||
|
||||
// Ping sends ping request to server and waits until pong is received or
|
||||
// context is canceled.
|
||||
func (c *Conn) Ping(ctx context.Context) error {
|
||||
// Generating random id.
|
||||
// Probably we should check for collisions here.
|
||||
pingID, err := crypto.RandInt64(c.rand)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pong := c.pong(pingID)
|
||||
defer c.removePong(pingID)
|
||||
|
||||
if err := c.writeServiceMessage(ctx, &mt.PingRequest{PingID: pingID}); err != nil {
|
||||
return errors.Wrap(err, "write")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-pong:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) handlePong(b *bin.Buffer) error {
|
||||
var pong mt.Pong
|
||||
if err := pong.Decode(b); err != nil {
|
||||
return errors.Errorf("decode: %x", err)
|
||||
}
|
||||
c.log.Debug("Pong")
|
||||
|
||||
c.pingMux.Lock()
|
||||
ch, ok := c.ping[pong.PingID]
|
||||
if ok {
|
||||
close(ch)
|
||||
delete(c.ping, pong.PingID)
|
||||
}
|
||||
c.pingMux.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) pingDelayDisconnect(ctx context.Context, delay int) error {
|
||||
// Generating random id.
|
||||
// Probably we should check for collisions here.
|
||||
pingID, err := crypto.RandInt64(c.rand)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pong := c.pong(pingID)
|
||||
defer c.removePong(pingID)
|
||||
|
||||
if err := c.writeServiceMessage(ctx, &mt.PingDelayDisconnectRequest{
|
||||
PingID: pingID,
|
||||
DisconnectDelay: delay,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "write")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-pong:
|
||||
if c.pingCallback != nil {
|
||||
c.pingCallback()
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) pong(pingID int64) chan struct{} {
|
||||
ch := make(chan struct{})
|
||||
c.pingMux.Lock()
|
||||
c.ping[pingID] = ch
|
||||
c.pingMux.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (c *Conn) removePong(pingID int64) {
|
||||
c.pingMux.Lock()
|
||||
delete(c.ping, pingID)
|
||||
c.pingMux.Unlock()
|
||||
}
|
||||
|
||||
func (c *Conn) pingLoop(ctx context.Context) error {
|
||||
// If the client sends these pings once every 60 seconds,
|
||||
// for example, it may set disconnect_delay equal to 75 seconds.
|
||||
delay := c.pingInterval + c.pingTimeout
|
||||
|
||||
ticker := c.clock.Ticker(c.pingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "ping loop")
|
||||
case <-ticker.C():
|
||||
if err := func() error {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
return c.pingDelayDisconnect(ctx, int(delay.Seconds()))
|
||||
}(); err != nil {
|
||||
return errors.Wrap(err, "disconnect (pong missed)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
)
|
||||
|
||||
// https://core.telegram.org/mtproto/description#message-identifier-msg-id
|
||||
// A message is rejected over 300 seconds after it is created or 30 seconds
|
||||
// before it is created (this is needed to protect from replay attacks).
|
||||
const (
|
||||
maxPast = time.Second * 300
|
||||
maxFuture = time.Second * 30
|
||||
)
|
||||
|
||||
// errRejected is returned on invalid message that should not be processed.
|
||||
var errRejected = errors.New("message rejected")
|
||||
|
||||
func checkMessageID(now time.Time, rawID int64) error {
|
||||
id := proto.MessageID(rawID)
|
||||
|
||||
// Check that message is from server.
|
||||
switch id.Type() {
|
||||
case proto.MessageFromServer, proto.MessageServerResponse:
|
||||
// Valid.
|
||||
default:
|
||||
return errors.Wrapf(errRejected, "unexpected type %s", id.Type())
|
||||
}
|
||||
|
||||
created := id.Time()
|
||||
if created.Before(now) && now.Sub(created) > maxPast {
|
||||
return errors.Wrap(errRejected, "created too far in past")
|
||||
}
|
||||
if created.Sub(now) > maxFuture {
|
||||
return errors.Wrap(errRejected, "created too far in future")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) decryptMessage(b *bin.Buffer) (*crypto.EncryptedMessageData, error) {
|
||||
session := c.session()
|
||||
msg, err := c.cipher.DecryptFromBuffer(session.Key, b)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decrypt")
|
||||
}
|
||||
|
||||
// Validating message. This protects from replay attacks.
|
||||
if msg.SessionID != session.ID {
|
||||
return nil, errors.Wrapf(errRejected, "invalid session (got %d, expected %d)", msg.SessionID, session.ID)
|
||||
}
|
||||
if err := checkMessageID(c.clock.Now(), msg.MessageID); err != nil {
|
||||
return nil, errors.Wrapf(err, "bad message id %d", msg.MessageID)
|
||||
}
|
||||
if !c.messageIDBuf.Consume(msg.MessageID) {
|
||||
return nil, errors.Wrapf(errRejected, "duplicate or too low message id %d", msg.MessageID)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (c *Conn) consumeMessage(ctx context.Context, buf *bin.Buffer) error {
|
||||
msg, err := c.decryptMessage(buf)
|
||||
if errors.Is(err, errRejected) {
|
||||
c.log.Warn("Ignoring rejected message", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "consume message")
|
||||
}
|
||||
|
||||
if err := c.handleMessage(msg.MessageID, &bin.Buffer{Buf: msg.Data()}); err != nil {
|
||||
// Probably we can return here, but this will shutdown whole
|
||||
// connection which can be unexpected.
|
||||
c.log.Warn("Error while handling message", zap.Error(err))
|
||||
// Sending acknowledge even on error. Client should restore
|
||||
// from missing updates via explicit pts check and getDiff call.
|
||||
}
|
||||
|
||||
needAck := (msg.SeqNo & 0x01) != 0
|
||||
if needAck {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case c.ackSendChan <- msg.MessageID:
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) noUpdates(err error) bool {
|
||||
// Checking for read timeout.
|
||||
var syscall *net.OpError
|
||||
if errors.As(err, &syscall) && syscall.Timeout() {
|
||||
// We call SetReadDeadline so such error is expected.
|
||||
c.log.Debug("No updates")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Conn) handleAuthKeyNotFound(ctx context.Context) error {
|
||||
if c.session().ID == 0 {
|
||||
// The 404 error can also be caused by zero session id.
|
||||
// See https://github.com/gotd/td/issues/107
|
||||
//
|
||||
// We should recover from this in createAuthKey, but in general
|
||||
// this code branch should be unreachable.
|
||||
c.log.Warn("BUG: zero session id found")
|
||||
}
|
||||
c.log.Warn("Re-generating keys (server not found key that we provided)")
|
||||
if err := c.createAuthKey(ctx); err != nil {
|
||||
return errors.Wrap(err, "unable to create auth key")
|
||||
}
|
||||
c.log.Info("Re-created auth keys")
|
||||
// Request will be retried by ack loop.
|
||||
// Probably we can speed-up this.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) readLoop(ctx context.Context) (err error) {
|
||||
log := c.log.Named("read")
|
||||
log.Debug("Read loop started")
|
||||
defer func() {
|
||||
l := log
|
||||
if err != nil {
|
||||
l = log.With(zap.NamedError("reason", err))
|
||||
}
|
||||
l.Debug("Read loop done")
|
||||
}()
|
||||
|
||||
var (
|
||||
// Last error encountered by consumeMessage.
|
||||
lastErr atomic.Value
|
||||
// To wait all spawned goroutines
|
||||
handlers sync.WaitGroup
|
||||
)
|
||||
defer handlers.Wait()
|
||||
|
||||
for {
|
||||
// We've tried multiple ways to reduce allocations via reusing buffer,
|
||||
// but naive implementation induces high idle memory waste.
|
||||
//
|
||||
// Proper optimization will probably require total rework of bin.Buffer
|
||||
// with sharded (by payload size?) pool that can be used after message
|
||||
// size read (after readLen).
|
||||
//
|
||||
// Such optimization can introduce additional complexity overhead and
|
||||
// is probably not worth it.
|
||||
buf := &bin.Buffer{}
|
||||
|
||||
// Halting if consumeMessage encountered error.
|
||||
// Should be something critical with crypto.
|
||||
if err, ok := lastErr.Load().(error); ok && err != nil {
|
||||
return errors.Wrap(err, "halting")
|
||||
}
|
||||
|
||||
if err := c.conn.Recv(ctx, buf); err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
if c.noUpdates(err) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var protoErr *codec.ProtocolErr
|
||||
if errors.As(err, &protoErr) && protoErr.Code == codec.CodeAuthKeyNotFound {
|
||||
if err := c.handleAuthKeyNotFound(ctx); err != nil {
|
||||
return errors.Wrap(err, "auth key not found")
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "read loop")
|
||||
default:
|
||||
return errors.Wrap(err, "read")
|
||||
}
|
||||
}
|
||||
|
||||
handlers.Add(1)
|
||||
go func() {
|
||||
defer handlers.Done()
|
||||
|
||||
// Spawning goroutine per incoming message to utilize as much
|
||||
// resources as possible while keeping idle utilization low.
|
||||
//
|
||||
// The "worker" model was replaced by this due to idle utilization
|
||||
// overhead, especially on multi-CPU systems with multiple running
|
||||
// clients.
|
||||
if err := c.consumeMessage(ctx, buf); err != nil {
|
||||
log.Error("Failed to process message", zap.Error(err))
|
||||
lastErr.Store(errors.Wrap(err, "consume"))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/gotd/neo"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
func TestCheckMessageID(t *testing.T) {
|
||||
now := testutil.Date()
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
for _, good := range []proto.MessageID{
|
||||
proto.NewMessageID(now, proto.MessageFromServer),
|
||||
proto.NewMessageID(now, proto.MessageServerResponse),
|
||||
proto.NewMessageID(now.Add(time.Second*29), proto.MessageFromServer),
|
||||
proto.NewMessageID(now.Add(-time.Second*299), proto.MessageFromServer),
|
||||
} {
|
||||
t.Run(good.String(), func(t *testing.T) {
|
||||
require.NoError(t, checkMessageID(now, int64(good)))
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("Bad", func(t *testing.T) {
|
||||
for _, bad := range []proto.MessageID{
|
||||
proto.NewMessageID(now, proto.MessageFromClient),
|
||||
proto.NewMessageID(now.Add(time.Second*31), proto.MessageFromServer),
|
||||
proto.NewMessageID(now.Add(-time.Second*301), proto.MessageFromServer),
|
||||
proto.NewMessageID(time.Time{}, proto.MessageFromServer),
|
||||
proto.NewMessageID(now.AddDate(-10, 0, 0), proto.MessageServerResponse),
|
||||
proto.NewMessageID(time.Time{}, proto.MessageFromClient),
|
||||
} {
|
||||
t.Run(bad.String(), func(t *testing.T) {
|
||||
require.ErrorIs(t, checkMessageID(now, int64(bad)), errRejected)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func benchRead(payloadSize int) func(b *testing.B) {
|
||||
return func(b *testing.B) {
|
||||
a := require.New(b)
|
||||
logger := zap.NewNop()
|
||||
random := rand.Reader
|
||||
c := neo.NewTime(time.Now())
|
||||
|
||||
var key crypto.Key
|
||||
_, err := io.ReadFull(random, key[:])
|
||||
a.NoError(err)
|
||||
authKey := key.WithID()
|
||||
|
||||
payload := make([]byte, payloadSize)
|
||||
_, err = io.ReadFull(random, payload)
|
||||
a.NoError(err)
|
||||
|
||||
msg := new(bin.Buffer)
|
||||
serverCipher := crypto.NewServerCipher(random)
|
||||
id := proto.NewMessageIDGen(c.Now).New(proto.MessageServerResponse)
|
||||
a.NoError(msg.Encode(&testPayload{
|
||||
Data: payload,
|
||||
}))
|
||||
|
||||
length := msg.Len()
|
||||
data := msg.Copy()
|
||||
a.NoError(serverCipher.Encrypt(authKey, crypto.EncryptedMessageData{
|
||||
MessageID: id,
|
||||
SeqNo: 0,
|
||||
MessageDataLen: int32(length),
|
||||
MessageDataWithPadding: data,
|
||||
}, msg))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
conn := Conn{
|
||||
conn: &constantConn{
|
||||
data: msg.Raw(),
|
||||
cancel: cancel,
|
||||
counter: b.N,
|
||||
},
|
||||
handler: nopHandler{},
|
||||
clock: c,
|
||||
rand: random,
|
||||
cipher: crypto.NewClientCipher(random),
|
||||
log: logger,
|
||||
messageIDBuf: noopBuf{},
|
||||
authKey: authKey,
|
||||
compressThreshold: -1,
|
||||
}
|
||||
grp := tdsync.NewCancellableGroup(ctx)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(payloadSize))
|
||||
|
||||
grp.Go(conn.readLoop)
|
||||
a.ErrorIs(grp.Wait(), context.Canceled)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRead(b *testing.B) {
|
||||
testutil.RunPayloads(b, benchRead)
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
||||
)
|
||||
|
||||
// Invoke sends input and decodes result into output.
|
||||
//
|
||||
// NOTE: Assuming that call contains content message (seqno increment).
|
||||
func (c *Conn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
msgID, seqNo := c.nextMsgSeq(true)
|
||||
req := rpc.Request{
|
||||
MsgID: msgID,
|
||||
SeqNo: seqNo,
|
||||
Input: input,
|
||||
Output: output,
|
||||
}
|
||||
|
||||
log := c.log.With(
|
||||
zap.Int64("msg_id", req.MsgID),
|
||||
)
|
||||
log.Debug("Invoke start")
|
||||
defer log.Debug("Invoke end")
|
||||
|
||||
if err := c.rpc.Do(ctx, req); err != nil {
|
||||
var badMsgErr *badMessageError
|
||||
if errors.As(err, &badMsgErr) && badMsgErr.Code == codeIncorrectServerSalt {
|
||||
// Should retry with new salt.
|
||||
c.log.Debug("Setting server salt")
|
||||
// Store salt from server.
|
||||
c.storeSalt(badMsgErr.NewSalt)
|
||||
// Reset saved salts to fetch new.
|
||||
c.salts.Reset()
|
||||
c.log.Info("Retrying request after basMsgErr", zap.Int64("msg_id", req.MsgID))
|
||||
return c.rpc.Do(ctx, req)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) dropRPC(req rpc.Request) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(),
|
||||
c.getTimeout(mt.RPCDropAnswerRequestTypeID),
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
var resp mt.RPCDropAnswerBox
|
||||
if err := c.Invoke(ctx, &mt.RPCDropAnswerRequest{
|
||||
ReqMsgID: req.MsgID,
|
||||
}, &resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch resp.RpcDropAnswer.(type) {
|
||||
case *mt.RPCAnswerDropped, *mt.RPCAnswerDroppedRunning:
|
||||
return nil
|
||||
case *mt.RPCAnswerUnknown:
|
||||
return errors.New("answer unknown")
|
||||
default:
|
||||
return errors.Errorf("unexpected response type: %T", resp.RpcDropAnswer)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
func TestConn_dropRPC(t *testing.T) {
|
||||
dropID := int64(10)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
result bin.Encoder
|
||||
resultErr error
|
||||
wantErr bool
|
||||
}{
|
||||
{"Dropped", &mt.RPCAnswerDropped{MsgID: dropID}, nil, false},
|
||||
{"DroppedRunning", &mt.RPCAnswerDroppedRunning{}, nil, false},
|
||||
{"Unknown", &mt.RPCAnswerUnknown{}, nil, true},
|
||||
{"Error", nil, testutil.TestError(), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
client := newTestClient(func(msgID int64, seqNo int32, body bin.Encoder) (bin.Encoder, error) {
|
||||
req, ok := body.(*mt.RPCDropAnswerRequest)
|
||||
a.True(ok)
|
||||
if ok {
|
||||
a.Equal(dropID, req.ReqMsgID)
|
||||
}
|
||||
return tt.result, tt.resultErr
|
||||
})
|
||||
|
||||
err := client.dropRPC(rpc.Request{
|
||||
MsgID: dropID,
|
||||
SeqNo: 1,
|
||||
})
|
||||
if tt.wantErr {
|
||||
a.Error(err)
|
||||
} else {
|
||||
a.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
)
|
||||
|
||||
func (c *Conn) storeSalt(salt int64) {
|
||||
c.sessionMux.Lock()
|
||||
// Copy to log.
|
||||
oldSalt := c.salt
|
||||
c.salt = salt
|
||||
c.sessionMux.Unlock()
|
||||
|
||||
if salt != oldSalt {
|
||||
c.log.Info("Salt updated", zap.Int64("old", oldSalt), zap.Int64("new", salt))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) updateSalt() {
|
||||
salt, ok := c.salts.Get(c.clock.Now().Add(time.Minute * 5))
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
c.storeSalt(salt)
|
||||
}
|
||||
|
||||
const defaultSaltsNum = 4
|
||||
|
||||
func (c *Conn) getSalts(ctx context.Context) error {
|
||||
request := &mt.GetFutureSaltsRequest{
|
||||
Num: defaultSaltsNum,
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, c.getTimeout(request.TypeID()))
|
||||
defer cancel()
|
||||
|
||||
if err := c.writeServiceMessage(ctx, request); err != nil {
|
||||
return errors.Wrap(err, "request salts")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) saltLoop(ctx context.Context) error {
|
||||
select {
|
||||
case <-c.gotSession.Ready():
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Get salts first time.
|
||||
if err := c.getSalts(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ticker := c.clock.Ticker(c.saltFetchInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C():
|
||||
if err := c.getSalts(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
// Package salts contains MTProto server salt storage.
|
||||
package salts
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
)
|
||||
|
||||
// Salts is a simple struct store server salts.
|
||||
type Salts struct {
|
||||
// server salts fetched by getSalts.
|
||||
salts []mt.FutureSalt
|
||||
saltsMux sync.Mutex
|
||||
}
|
||||
|
||||
// Get returns next valid salt.
|
||||
func (s *Salts) Get(deadline time.Time) (int64, bool) {
|
||||
s.saltsMux.Lock()
|
||||
defer s.saltsMux.Unlock()
|
||||
|
||||
check:
|
||||
if len(s.salts) < 1 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
date := int(deadline.Unix())
|
||||
if salt := s.salts[len(s.salts)-1]; salt.ValidUntil > date {
|
||||
return salt.Salt, true
|
||||
}
|
||||
|
||||
// Filter (in place) from SliceTricks.
|
||||
n := 0
|
||||
// Check that the salt will be valid until deadline.
|
||||
for _, salt := range s.salts {
|
||||
// Filter expired salts.
|
||||
if salt.ValidUntil > date {
|
||||
// Keep valid salt.
|
||||
s.salts[n] = salt
|
||||
n++
|
||||
}
|
||||
}
|
||||
s.salts = s.salts[:n]
|
||||
goto check
|
||||
}
|
||||
|
||||
type saltSlice []mt.FutureSalt
|
||||
|
||||
func (s saltSlice) Len() int {
|
||||
return len(s)
|
||||
}
|
||||
|
||||
func (s saltSlice) Less(i, j int) bool {
|
||||
return s[i].ValidUntil > s[j].ValidUntil
|
||||
}
|
||||
|
||||
func (s saltSlice) Swap(i, j int) {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
|
||||
// Store stores all given salts.
|
||||
func (s *Salts) Store(salts []mt.FutureSalt) {
|
||||
s.saltsMux.Lock()
|
||||
defer s.saltsMux.Unlock()
|
||||
|
||||
s.salts = append(s.salts, salts...)
|
||||
// Filter duplicates.
|
||||
n := 0
|
||||
dedup := make(map[int64]struct{}, len(s.salts)+1)
|
||||
for _, salt := range s.salts {
|
||||
if _, ok := dedup[salt.Salt]; !ok {
|
||||
dedup[salt.Salt] = struct{}{}
|
||||
s.salts[n] = salt
|
||||
n++
|
||||
}
|
||||
}
|
||||
s.salts = s.salts[:n]
|
||||
|
||||
// Sort slice by valid until.
|
||||
sort.Sort(saltSlice(s.salts))
|
||||
}
|
||||
|
||||
// Reset deletes all stored salts.
|
||||
func (s *Salts) Reset() {
|
||||
s.saltsMux.Lock()
|
||||
s.salts = s.salts[:0]
|
||||
s.saltsMux.Unlock()
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
package salts
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
func generateSalts(n int) []mt.FutureSalt {
|
||||
r := make([]mt.FutureSalt, n)
|
||||
for i := range r {
|
||||
since := (i + 1) * 10
|
||||
|
||||
r[i] = mt.FutureSalt{
|
||||
ValidSince: since,
|
||||
ValidUntil: since + 15,
|
||||
Salt: int64(i),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestSalts(t *testing.T) {
|
||||
a := require.New(t)
|
||||
salts := &Salts{}
|
||||
var testData = []mt.FutureSalt{
|
||||
{
|
||||
ValidSince: 10,
|
||||
ValidUntil: 25,
|
||||
Salt: 1,
|
||||
},
|
||||
{
|
||||
ValidSince: 20,
|
||||
ValidUntil: 35,
|
||||
Salt: 2,
|
||||
},
|
||||
{
|
||||
ValidSince: 30,
|
||||
ValidUntil: 45,
|
||||
Salt: 3,
|
||||
},
|
||||
}
|
||||
|
||||
salts.Store(testData[:2])
|
||||
a.Len(salts.salts, 2)
|
||||
|
||||
salt, ok := salts.Get(time.Unix(11, 0))
|
||||
a.Equal(int64(1), salt)
|
||||
a.True(ok)
|
||||
|
||||
_, ok = salts.Get(time.Unix(36, 0))
|
||||
a.False(ok)
|
||||
|
||||
salts.Store(testData[:2])
|
||||
a.Len(salts.salts, 2)
|
||||
|
||||
salts.Store(testData[:3])
|
||||
a.Len(salts.salts, 3)
|
||||
|
||||
salt, ok = salts.Get(time.Unix(26, 0))
|
||||
a.Equal(int64(2), salt)
|
||||
a.True(ok)
|
||||
|
||||
salt, ok = salts.Get(time.Unix(36, 0))
|
||||
a.Equal(int64(3), salt)
|
||||
a.True(ok)
|
||||
|
||||
salts.Reset()
|
||||
_, ok = salts.Get(time.Unix(36, 0))
|
||||
a.False(ok)
|
||||
}
|
||||
|
||||
func TestSalts_Get(t *testing.T) {
|
||||
salts := &Salts{}
|
||||
salts.Store(generateSalts(64))
|
||||
|
||||
now := time.Unix(11, 0)
|
||||
testutil.ZeroAlloc(t, func() {
|
||||
salts.Get(now)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkSalts_Get(b *testing.B) {
|
||||
salts := &Salts{}
|
||||
salts.Store(generateSalts(64))
|
||||
t := time.Unix(11, 0)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
salts.Get(t)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSalts_Store(b *testing.B) {
|
||||
testData := generateSalts(64)
|
||||
salts := &Salts{
|
||||
salts: make([]mt.FutureSalt, 0, len(testData)),
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
salts.Store(testData)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package mtproto
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
|
||||
// Session represents connection state.
|
||||
type Session struct {
|
||||
ID int64
|
||||
Key crypto.AuthKey
|
||||
Salt int64
|
||||
}
|
||||
|
||||
// Session returns current connection session info.
|
||||
func (c *Conn) session() Session {
|
||||
c.updateSalt()
|
||||
|
||||
c.sessionMux.RLock()
|
||||
defer c.sessionMux.RUnlock()
|
||||
return Session{
|
||||
Key: c.authKey,
|
||||
Salt: c.salt,
|
||||
ID: c.sessionID,
|
||||
}
|
||||
}
|
||||
|
||||
// newSessionID sets session id to random value.
|
||||
func (c *Conn) newSessionID() error {
|
||||
id, err := crypto.RandInt64(c.rand)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.sessionMux.Lock()
|
||||
defer c.sessionMux.Unlock()
|
||||
c.sessionID = id
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
// For embedding public keys.
|
||||
_ "embed"
|
||||
"sync"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
)
|
||||
|
||||
var (
|
||||
// publicKeys is byte blob of new keys added for PQInnerData encryption (key exchange).
|
||||
//
|
||||
// See https://github.com/telegramdesktop/tdesktop/commit/95a7ce4622dc24717dc5b95fc99599dddfd4ff6c.
|
||||
//
|
||||
// See https://github.com/tdlib/td/commit/e9e24282378fcdb3a3ce020bee4253b65ac98213.
|
||||
//go:embed _data/public_keys.pem
|
||||
publicKeys []byte
|
||||
|
||||
parsedKeys struct {
|
||||
Keys []exchange.PublicKey
|
||||
Once sync.Once
|
||||
}
|
||||
)
|
||||
|
||||
//nolint:gochecknoinits
|
||||
func init() {
|
||||
makePublicKeys := func(data []byte) ([]exchange.PublicKey, error) {
|
||||
rsaKeys, err := crypto.ParseRSAPublicKeys(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := make([]exchange.PublicKey, 0, len(rsaKeys))
|
||||
for _, key := range rsaKeys {
|
||||
keys = append(keys, exchange.PublicKey{
|
||||
RSA: key,
|
||||
})
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
parsedKeys.Once.Do(func() {
|
||||
keys, err := makePublicKeys(publicKeys)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
parsedKeys.Keys = append(parsedKeys.Keys, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
// vendoredKeys parses vendored file _data/public_keys.pem as list of
|
||||
// PEM-encoded public RSA keys.
|
||||
//
|
||||
// Most recent key list can be found on https://my.telegram.org/apps.
|
||||
func vendoredKeys() []exchange.PublicKey {
|
||||
return parsedKeys.Keys
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package mtproto
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestVendoredKeys(t *testing.T) {
|
||||
keys := vendoredKeys()
|
||||
if len(keys) == 0 {
|
||||
t.Fatal("empty keys")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
func (c *Conn) writeContentMessage(ctx context.Context, msgID int64, seqNo int32, message bin.Encoder) error {
|
||||
return c.write(ctx, msgID, seqNo, message)
|
||||
}
|
||||
|
||||
func (c *Conn) writeServiceMessage(ctx context.Context, message bin.Encoder) error {
|
||||
msgID, seqNo := c.nextMsgSeq(false)
|
||||
return c.write(ctx, msgID, seqNo, message)
|
||||
}
|
||||
|
||||
var bufPool = bin.NewPool(0)
|
||||
|
||||
func (c *Conn) write(ctx context.Context, msgID int64, seqNo int32, message bin.Encoder) error {
|
||||
// Grab shared lock for writing.
|
||||
// It prevents message sending during key regeneration if server forgot current auth key.
|
||||
c.exchangeLock.RLock()
|
||||
defer c.exchangeLock.RUnlock()
|
||||
|
||||
b := bufPool.Get()
|
||||
defer bufPool.Put(b)
|
||||
|
||||
if err := c.newEncryptedMessage(msgID, seqNo, message, b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.conn.Send(ctx, b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) nextMsgSeq(content bool) (msgID int64, seqNo int32) {
|
||||
c.reqMux.Lock()
|
||||
defer c.reqMux.Unlock()
|
||||
|
||||
msgID = c.newMessageID()
|
||||
|
||||
// Computing current sequence number (seqno).
|
||||
// This should be serialized with new message id generation.
|
||||
//
|
||||
// See https://github.com/gotd/td/issues/245 for reference.
|
||||
seqNo = c.sentContentMessages * 2
|
||||
if content {
|
||||
seqNo++
|
||||
c.sentContentMessages++
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/gotd/neo"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
func benchWrite(payloadSize int) func(b *testing.B) {
|
||||
return func(b *testing.B) {
|
||||
a := require.New(b)
|
||||
logger := zap.NewNop()
|
||||
random := rand.Reader
|
||||
c := neo.NewTime(time.Now())
|
||||
|
||||
var key crypto.Key
|
||||
_, err := io.ReadFull(random, key[:])
|
||||
a.NoError(err)
|
||||
authKey := key.WithID()
|
||||
|
||||
payload := make([]byte, payloadSize)
|
||||
_, err = io.ReadFull(random, payload)
|
||||
a.NoError(err)
|
||||
data := &testPayload{Data: payload}
|
||||
|
||||
conn := Conn{
|
||||
conn: &constantConn{},
|
||||
clock: c,
|
||||
rand: random,
|
||||
cipher: crypto.NewClientCipher(random),
|
||||
log: logger,
|
||||
authKey: authKey,
|
||||
compressThreshold: -1,
|
||||
}
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(payloadSize))
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = conn.write(context.Background(), 1, 1, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWrite(b *testing.B) {
|
||||
testutil.RunPayloads(b, benchWrite)
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
type logType struct {
|
||||
ID uint32
|
||||
Name string
|
||||
}
|
||||
|
||||
func (l logType) MarshalLogObject(e zapcore.ObjectEncoder) error {
|
||||
typeIDStr := fmt.Sprintf("0x%x", l.ID)
|
||||
e.AddString("type_id", typeIDStr)
|
||||
if l.Name != "" {
|
||||
e.AddString("type_name", l.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) logWithBuffer(b *bin.Buffer) *zap.Logger {
|
||||
return c.logWithType(b).With(zap.Int("size_bytes", b.Len()))
|
||||
}
|
||||
|
||||
func (c *Conn) logWithType(b *bin.Buffer) *zap.Logger {
|
||||
id, err := b.PeekID()
|
||||
if err != nil {
|
||||
// Type info not available.
|
||||
return c.log
|
||||
}
|
||||
|
||||
return c.logWithTypeID(id)
|
||||
}
|
||||
|
||||
func (c *Conn) logWithTypeID(id uint32) *zap.Logger {
|
||||
return c.log.With(zap.Inline(logType{
|
||||
ID: id,
|
||||
Name: c.types.Get(id),
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package mtproto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
|
||||
)
|
||||
|
||||
func BenchmarkConn_logWithType(b *testing.B) {
|
||||
c := Conn{
|
||||
log: zap.NewNop(),
|
||||
types: tmap.New(map[uint32]string{
|
||||
0x3fedd339: "true",
|
||||
}),
|
||||
}
|
||||
buf := bin.Buffer{}
|
||||
buf.PutID(0x3fedd339)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
c.logWithType(&buf).Info("Hi!")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
//go:build !fuzz
|
||||
// +build !fuzz
|
||||
|
||||
package mtproto
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
|
||||
type Zero = testutil.ZeroRand
|
||||
Reference in New Issue
Block a user