move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
+16
View File
@@ -0,0 +1,16 @@
-----BEGIN RSA PUBLIC KEY-----
MIIBCgKCAQEAyMEdY1aR+sCR3ZSJrtztKTKqigvO/vBfqACJLZtS7QMgCGXJ6XIR
yy7mx66W0/sOFa7/1mAZtEoIokDP3ShoqF4fVNb6XeqgQfaUHd8wJpDWHcR2OFwv
plUUI1PLTktZ9uW2WE23b+ixNwJjJGwBDJPQEQFBE+vfmH0JP503wr5INS1poWg/
j25sIWeYPHYeOrFp/eXaqhISP6G+q2IeTaWTXpwZj4LzXq5YOpk4bYEQ6mvRq7D1
aHWfYmlEGepfaYR8Q0YqvvhYtMte3ITnuSJs171+GDqpdKcSwHnd6FudwGO4pcCO
j4WcDuXc2CTHgH8gFTNhp/Y8/SpDOhvn9QIDAQAB
-----END RSA PUBLIC KEY-----
-----BEGIN RSA PUBLIC KEY-----
MIIBCgKCAQEA6LszBcC1LGzyr992NzE0ieY+BSaOW622Aa9Bd4ZHLl+TuFQ4lo4g
5nKaMBwK/BIb9xUfg0Q29/2mgIR6Zr9krM7HjuIcCzFvDtr+L0GQjae9H0pRB2OO
62cECs5HKhT5DZ98K33vmWiLowc621dQuwKWSQKjWf50XYFw42h21P2KXUGyp2y/
+aEyZ+uVgLLQbRA1dEjSDZ2iGRy12Mk5gpYc397aYp438fsJoHIgJ2lgMv5h7WY9
t6N/byY9Nw9p21Og3AoXSL2q/2IJ1WRUhebgAdGVMlV1fkuOQoEzR7EdpqtQD9Cs
5+bfo3Nhmcyvk5ftB0WkJ9z6bNZ7yxrP8wIDAQAB
-----END RSA PUBLIC KEY-----
+46
View File
@@ -0,0 +1,46 @@
package mtproto
import (
"context"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
)
func (c *Conn) ackLoop(ctx context.Context) error {
log := c.log.Named("ack")
var buf []int64
send := func() {
defer func() { buf = buf[:0] }()
if err := c.writeServiceMessage(ctx, &mt.MsgsAck{MsgIDs: buf}); err != nil {
c.log.Error("Failed to ACK", zap.Error(err))
return
}
log.Debug("Ack", zap.Int64s("msg_ids", buf))
}
ticker := c.clock.Ticker(c.ackInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "acl")
case <-ticker.C():
if len(buf) > 0 {
send()
}
case msgID := <-c.ackSendChan:
buf = append(buf, msgID)
if len(buf) >= c.ackBatchSize {
send()
ticker.Reset(c.ackInterval)
}
}
}
}
+221
View File
@@ -0,0 +1,221 @@
package mtproto
import (
"context"
"io"
"sync"
"time"
"github.com/go-faster/errors"
"go.uber.org/atomic"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto/salts"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
// Handler will be called on received message from Telegram.
type Handler interface {
OnMessage(b *bin.Buffer) error
OnSession(session Session) error
}
// MessageIDSource is message id generator.
type MessageIDSource interface {
New(t proto.MessageType) int64
}
// MessageBuf is message id buffer.
type MessageBuf interface {
Consume(id int64) bool
}
// Cipher handles message encryption and decryption.
type Cipher interface {
DecryptFromBuffer(k crypto.AuthKey, buf *bin.Buffer) (*crypto.EncryptedMessageData, error)
Encrypt(key crypto.AuthKey, data crypto.EncryptedMessageData, b *bin.Buffer) error
}
// Dialer is an abstraction for MTProto transport connection creator.
type Dialer func(ctx context.Context) (transport.Conn, error)
// Conn represents a MTProto client to Telegram.
type Conn struct {
dcID int
dialer Dialer
conn transport.Conn
handler Handler
rpc *rpc.Engine
rsaPublicKeys []exchange.PublicKey
types *tmap.Map
// Wrappers for external world, like current time, logs or PRNG.
// Should be immutable.
clock clock.Clock
rand io.Reader
cipher Cipher
log *zap.Logger
messageID MessageIDSource
messageIDBuf MessageBuf // replay attack protection
// use session() to access authKey, salt or sessionID.
sessionMux sync.RWMutex
authKey crypto.AuthKey
salt int64
sessionID int64
// server salts fetched by getSalts.
salts salts.Salts
// sentContentMessages is count of created content messages, used to
// compute sequence number within session.
sentContentMessages int32
reqMux sync.Mutex
// ackSendChan is queue for outgoing message id's that require waiting for
// ack from server.
ackSendChan chan int64
ackBatchSize int
ackInterval time.Duration
// callbacks for ping results.
// Key is ping id.
ping map[int64]chan struct{}
pingMux sync.Mutex
// pingTimeout sets ping_delay_disconnect delay.
pingTimeout time.Duration
// pingInterval is duration between ping_delay_disconnect request.
pingInterval time.Duration
pingCallback func()
// gotSession is a signal channel for wait for handleSessionCreated message.
gotSession *tdsync.Ready
// exchangeLock locks write calls during key exchange.
exchangeLock sync.RWMutex
// compressThreshold is a threshold in bytes to determine that message
// is large enough to be compressed using gzip.
compressThreshold int
dialTimeout time.Duration
exchangeTimeout time.Duration
saltFetchInterval time.Duration
getTimeout func(req uint32) time.Duration
// Ensure Run once.
ran atomic.Bool
}
// New creates new unstarted connection.
func New(dialer Dialer, opt Options) *Conn {
// Set default values, if user does not set.
opt.setDefaults()
conn := &Conn{
dcID: opt.DC,
dialer: dialer,
clock: opt.Clock,
rand: opt.Random,
cipher: opt.Cipher,
log: opt.Logger,
messageID: opt.MessageID,
messageIDBuf: proto.NewMessageIDBuf(100),
ackSendChan: make(chan int64),
ackInterval: opt.AckInterval,
ackBatchSize: opt.AckBatchSize,
rsaPublicKeys: opt.PublicKeys,
handler: opt.Handler,
types: opt.Types,
authKey: opt.Key,
salt: opt.Salt,
ping: map[int64]chan struct{}{},
pingTimeout: opt.PingTimeout,
pingInterval: opt.PingInterval,
pingCallback: opt.PingCallback,
gotSession: tdsync.NewReady(),
rpc: opt.engine,
compressThreshold: opt.CompressThreshold,
dialTimeout: opt.DialTimeout,
exchangeTimeout: opt.ExchangeTimeout,
saltFetchInterval: opt.SaltFetchInterval,
getTimeout: opt.RequestTimeout,
}
if conn.rpc == nil {
conn.rpc = rpc.New(conn.writeContentMessage, rpc.Options{
Logger: opt.Logger.Named("rpc"),
RetryInterval: opt.RetryInterval,
MaxRetries: opt.MaxRetries,
Clock: opt.Clock,
DropHandler: conn.dropRPC,
OnError: opt.OnError,
})
}
return conn
}
// handleClose closes rpc engine and underlying connection on context done.
func (c *Conn) handleClose(ctx context.Context) error {
<-ctx.Done()
c.log.Debug("Closing")
// Close RPC Engine.
c.rpc.ForceClose()
// Close connection.
if err := c.conn.Close(); err != nil {
c.log.Debug("Failed to cleanup connection", zap.Error(err))
}
return nil
}
// Run initializes MTProto connection to server and blocks until disconnection.
//
// When connection is ready, Handler.OnSession is called.
func (c *Conn) Run(ctx context.Context, f func(ctx context.Context) error) error {
// Starting connection.
//
// This will send initial packet to telegram and perform key exchange
// if needed.
if c.ran.Swap(true) {
return errors.New("do Run on closed connection")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
c.log.Debug("Run: start")
defer c.log.Debug("Run: end")
if err := c.connect(ctx); err != nil {
return errors.Wrap(err, "start")
}
{
// All goroutines are bound to current call.
g := tdsync.NewLogGroup(ctx, c.log.Named("group"))
g.Go("handleClose", c.handleClose)
g.Go("pingLoop", c.pingLoop)
g.Go("ackLoop", c.ackLoop)
g.Go("saltsLoop", c.saltLoop)
g.Go("userCallback", f)
g.Go("readLoop", c.readLoop)
if err := g.Wait(); err != nil {
return errors.Wrap(err, "group")
}
}
return nil
}
+92
View File
@@ -0,0 +1,92 @@
package mtproto
import (
"context"
"crypto/md5"
"fmt"
"math/rand"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
)
type testHandler func(msgID int64, seqNo int32, body bin.Encoder) (bin.Encoder, error)
type testClientOption func(o Options)
func newTestClient(h testHandler, opts ...testClientOption) *Conn {
var engine *rpc.Engine
engine = rpc.New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
if response, err := h(msgID, seqNo, in); err != nil {
engine.NotifyError(msgID, err)
} else {
var b bin.Buffer
if err := b.Encode(response); err != nil {
return err
}
return engine.NotifyResult(msgID, &b)
}
return nil
}, rpc.Options{})
opt := Options{
Logger: zap.NewNop(),
Random: rand.New(rand.NewSource(1)),
Key: crypto.Key{}.WithID(),
MessageID: proto.NewMessageIDGen(time.Now),
engine: engine,
}
for _, o := range opts {
o(opt)
}
return New(nil, opt)
}
// newCorpusTracer will save incoming messages to corpus folder.
//
// Usage:
//
// client.trace.OnMessage = newCorpusTracer(t)
//
// nolint: deadcode,unused // optional
func newCorpusTracer(t testing.TB) func(b *bin.Buffer) {
types := tmap.New(
mt.TypesMap(),
tg.TypesMap(),
proto.TypesMap(),
)
dir := filepath.Join("..", "_fuzz", "handle_message", "corpus")
return func(b *bin.Buffer) {
id, _ := b.PeekID()
h := md5.Sum(b.Buf)
name := types.Get(id)
if name == "" {
name = "unknown"
}
if idx := strings.Index(name, "#"); idx > 0 {
// Removing type id from name.
name = name[:idx]
}
base := fmt.Sprintf("trace_%x_%s_%x",
id, name, h,
)
assert.NoError(t, os.WriteFile(filepath.Join(dir, base), b.Buf, 0600))
}
}
+93
View File
@@ -0,0 +1,93 @@
package mtproto
import (
"context"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
)
// connect establishes connection using configured transport, creating
// new auth key if needed.
func (c *Conn) connect(ctx context.Context) (rErr error) {
ctx, cancel := context.WithTimeout(ctx, c.dialTimeout)
defer cancel()
conn, err := c.dialer(ctx)
if err != nil {
return errors.Wrap(err, "dial failed")
}
c.conn = conn
defer func() {
if rErr != nil {
multierr.AppendInto(&rErr, conn.Close())
}
}()
session := c.session()
if session.Key.Zero() {
c.log.Info("Generating new auth key")
start := c.clock.Now()
if err := c.createAuthKey(ctx); err != nil {
return errors.Wrap(err, "create auth key")
}
c.log.Info("Auth key generated",
zap.Duration("duration", c.clock.Now().Sub(start)),
)
return nil
}
c.log.Info("Key already exists")
if session.ID == 0 {
// NB: Telegram can return 404 error if session id is zero.
//
// See https://github.com/gotd/td/issues/107.
c.log.Debug("Generating new session id")
if err := c.newSessionID(); err != nil {
return err
}
}
return nil
}
// createAuthKey generates new authorization key.
func (c *Conn) createAuthKey(ctx context.Context) error {
// Grab exclusive lock for writing.
// It prevents message sending during key regeneration if server forgot current auth key.
c.exchangeLock.Lock()
defer c.exchangeLock.Unlock()
if ce := c.log.Check(zap.DebugLevel, "Initializing new key exchange"); ce != nil {
// Useful for debugging i/o timeout errors on tcp reads or writes.
fields := []zap.Field{
zap.Duration("timeout", c.exchangeTimeout),
}
if deadline, ok := ctx.Deadline(); ok {
fields = append(fields, zap.Time("context_deadline", deadline))
}
ce.Write(fields...)
}
r, err := exchange.NewExchanger(c.conn, c.dcID).
WithClock(c.clock).
WithLogger(c.log.Named("exchange")).
WithTimeout(c.exchangeTimeout).
WithRand(c.rand).
Client(c.rsaPublicKeys).Run(ctx)
if err != nil {
return err
}
c.sessionMux.Lock()
c.authKey = r.AuthKey
c.sessionID = r.SessionID
c.salt = r.ServerSalt
c.sessionMux.Unlock()
return nil
}
+74
View File
@@ -0,0 +1,74 @@
package mtproto
import (
"context"
"crypto/rand"
"io"
"testing"
"testing/iotest"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
type closeConn struct {
closed bool
}
func (c *closeConn) Send(ctx context.Context, b *bin.Buffer) error {
return io.EOF
}
func (c *closeConn) Recv(ctx context.Context, b *bin.Buffer) error {
return io.EOF
}
func (c *closeConn) Close() error {
c.closed = true
return nil
}
func TestConn_connect(t *testing.T) {
t.Run("EnsureClose", func(t *testing.T) {
t.Run("Exchange", func(t *testing.T) {
a := require.New(t)
closeMe := &closeConn{}
c := Conn{
dialer: func(ctx context.Context) (transport.Conn, error) {
return closeMe, nil
},
clock: clock.System,
rand: rand.Reader,
log: zap.NewNop(),
}
a.Error(c.connect(context.Background()))
a.True(closeMe.closed)
})
t.Run("SessionID", func(t *testing.T) {
a := require.New(t)
closeMe := &closeConn{}
c := Conn{
dialer: func(ctx context.Context) (transport.Conn, error) {
return closeMe, nil
},
clock: clock.System,
authKey: crypto.AuthKey{
ID: [8]byte{1}, // Skip exchange.
},
rand: iotest.ErrReader(io.EOF),
log: zap.NewNop(),
}
a.Error(c.connect(context.Background()))
a.True(closeMe.closed)
})
})
}
+2
View File
@@ -0,0 +1,2 @@
// Package mtproto implements MTProto connection.
package mtproto
+51
View File
@@ -0,0 +1,51 @@
package mtproto
import (
"fmt"
"testing"
"time"
"go.uber.org/zap"
"github.com/gotd/neo"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
func benchEncryption(b *testing.B, c *Conn, n int) {
b.Helper()
buf := &bin.Buffer{Buf: make([]byte, 0, n)}
p := testPayload{Data: make([]byte, n-4)}
b.ReportAllocs()
b.SetBytes(int64(n))
b.ResetTimer()
for i := 0; i < b.N; i++ {
buf.Reset()
if err := c.newEncryptedMessage(12345, 0, p, buf); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkEncryption(b *testing.B) {
c := &Conn{
rand: Zero{},
log: zap.NewNop(),
cipher: crypto.NewClientCipher(Zero{}),
clock: neo.NewTime(time.Now()),
compressThreshold: -1,
}
for i := 0; i < 256; i++ {
c.authKey.Value[i] = byte(i)
}
for _, payload := range testutil.Payloads() {
b.Run(fmt.Sprintf("%db", payload), func(b *testing.B) {
benchEncryption(b, c, payload)
})
}
}
+21
View File
@@ -0,0 +1,21 @@
package mtproto
import (
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
)
func (c *Conn) handleAck(b *bin.Buffer) error {
var ack mt.MsgsAck
if err := ack.Decode(b); err != nil {
return errors.Wrap(err, "decode")
}
c.log.Debug("Received ack", zap.Int64s("msg_ids", ack.MsgIDs))
c.rpc.NotifyAcks(ack.MsgIDs)
return nil
}
+68
View File
@@ -0,0 +1,68 @@
package mtproto
import (
"fmt"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
)
type badMessageError struct {
Code int
NewSalt int64
}
const (
codeMessageIDTooLow = 16
codeMessageIDTooHigh = 17
codeIncorrectServerSalt = 48
)
func (c badMessageError) Error() string {
description := map[int]string{
codeMessageIDTooLow: "msg_id too low",
codeMessageIDTooHigh: "msg_id too high",
codeIncorrectServerSalt: "incorrect server salt",
18: "incorrect two lower order msg_id bits",
19: "container msg_id is the same as msg_id of a previously received message",
20: "message too old",
32: "msg_seqno too low",
33: "msg_seqno too high",
34: "even msg_seqno expected, but odd received",
35: "odd msg_seqno expected, but even received",
}[c.Code]
if description == "" {
return fmt.Sprintf("bad msg error code %d", c.Code)
}
return description
}
func (c *Conn) handleBadMsg(b *bin.Buffer) error {
id, err := b.PeekID()
if err != nil {
return err
}
switch id {
case mt.BadMsgNotificationTypeID:
var bad mt.BadMsgNotification
if err := bad.Decode(b); err != nil {
return err
}
c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode})
return nil
case mt.BadServerSaltTypeID:
var bad mt.BadServerSalt
if err := bad.Decode(b); err != nil {
return err
}
c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode, NewSalt: bad.NewServerSalt})
return nil
default:
return errors.Errorf("unknown type id 0x%d", id)
}
}
+27
View File
@@ -0,0 +1,27 @@
package mtproto
import (
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
func (c *Conn) handleContainer(msgID int64, b *bin.Buffer) error {
var container proto.MessageContainer
if err := container.Decode(b); err != nil {
return errors.Wrap(err, "container")
}
for _, msg := range container.Messages {
if err := c.processContainerMessage(msgID, msg); err != nil {
return err
}
}
return nil
}
func (c *Conn) processContainerMessage(msgID int64, msg proto.Message) error {
b := &bin.Buffer{Buf: msg.Body}
return c.handleMessage(msgID, b)
}
+25
View File
@@ -0,0 +1,25 @@
package mtproto
import (
"time"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
)
func (c *Conn) handleFutureSalts(b *bin.Buffer) error {
var res mt.FutureSalts
if err := res.Decode(b); err != nil {
return errors.Wrap(err, "error decode")
}
c.salts.Store(res.Salts)
serverTime := time.Unix(int64(res.Now), 0)
c.log.Debug("Got future salts", zap.Time("server_time", serverTime))
return nil
}
@@ -0,0 +1,52 @@
package mtproto
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
)
func TestConn_handleFutureSalts(t *testing.T) {
now := time.Now()
ts := int(now.Unix())
testdata := []mt.FutureSalt{
{
ValidSince: ts - 1,
ValidUntil: ts + 1,
Salt: 10,
},
{
ValidSince: ts + 1,
ValidUntil: ts + 3,
Salt: 11,
},
}
t.Run("OK", func(t *testing.T) {
a := require.New(t)
conn := Conn{log: zap.NewNop()}
buf := bin.Buffer{}
a.NoError(buf.Encode(&mt.FutureSalts{
ReqMsgID: 1,
Now: 1,
Salts: testdata,
}))
a.NoError(conn.handleFutureSalts(&buf))
salt, ok := conn.salts.Get(now)
a.Equal(int64(10), salt)
a.True(ok)
})
t.Run("Invalid", func(t *testing.T) {
conn := Conn{}
buf := bin.Buffer{}
buf.PutID(mt.FutureSaltsTypeID)
require.Error(t, conn.handleFutureSalts(&buf))
})
}
+24
View File
@@ -0,0 +1,24 @@
package mtproto
import (
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
)
func gzip(b *bin.Buffer) (*bin.Buffer, error) {
var content proto.GZIP
if err := content.Decode(b); err != nil {
return nil, errors.Wrap(err, "decode")
}
return &bin.Buffer{Buf: content.Data}, nil
}
func (c *Conn) handleGZIP(msgID int64, b *bin.Buffer) error {
content, err := gzip(b)
if err != nil {
return errors.Wrap(err, "unzip")
}
return c.handleMessage(msgID, content)
}
+55
View File
@@ -0,0 +1,55 @@
package mtproto
import (
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
)
func TestGzipDecode(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
a := require.New(t)
// from grammers
// https://github.com/Lonami/grammers/blob/ed5470b3f79f5d41a5090450b866026568f3cd60/lib/grammers-mtproto/src/manual_tl.rs#L197
data := []byte{
1, 109, 92, 243, 132, 41, 150, 69, 54, 75, 49, 94, 161, 207, 114, 48, 254, 140, 1, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 3, 149, 147, 61, 75, 195, 80, 20, 134, 79, 62, 20, 84,
170, 1, 17, 28, 68, 28, 132, 110, 183, 247, 38, 185, 249, 154, 58, 10, 130, 116, 116,
170, 54, 109, 18, 11, 173, 169, 109, 90, 112, 210, 209, 81, 112, 112, 22, 127, 131, 56,
232, 232, 224, 143, 112, 112, 238, 159, 168, 55, 31, 234, 77, 91, 82, 26, 184, 57, 36,
79, 222, 115, 222, 251, 38, 9, 170, 27, 218, 209, 62, 128, 113, 76, 234, 181, 83, 82,
55, 31, 175, 223, 101, 0, 216, 249, 120, 217, 219, 102, 181, 244, 244, 186, 203, 10, 8,
108, 109, 18, 221, 70, 132, 234, 136, 152, 20, 81, 13, 222, 132, 148, 43, 11, 184, 144,
241, 178, 138, 49, 113, 176, 171, 90, 142, 175, 106, 45, 199, 79, 46, 217, 145, 63, 53,
126, 117, 241, 92, 49, 215, 215, 48, 17, 197, 185, 185, 179, 156, 252, 113, 49, 227,
91, 60, 39, 148, 240, 190, 196, 127, 95, 134, 217, 116, 176, 238, 89, 177, 47, 181,
200, 151, 180, 156, 206, 229, 247, 35, 229, 252, 176, 156, 8, 198, 252, 126, 138, 184,
144, 241, 57, 57, 106, 139, 114, 148, 167, 115, 178, 73, 46, 199, 34, 46, 100, 124,
206, 126, 245, 162, 185, 98, 166, 227, 242, 55, 16, 81, 49, 159, 227, 18, 125, 93, 222,
207, 202, 76, 14, 126, 172, 163, 69, 126, 148, 76, 87, 178, 9, 139, 213, 66, 148, 185,
177, 48, 0, 159, 211, 52, 167, 118, 198, 27, 189, 145, 134, 6, 145, 215, 65, 205, 176,
11, 240, 201, 158, 171, 150, 36, 104, 177, 90, 211, 37, 184, 99, 63, 11, 30, 2, 124,
63, 200, 73, 253, 98, 141, 206, 199, 233, 119, 18, 63, 11, 207, 34, 76, 38, 147, 155,
120, 193, 248, 48, 185, 23, 207, 186, 117, 214, 146, 26, 247, 57, 56, 1, 184, 63, 19,
18, 189, 82, 102, 51, 47, 162, 168, 55, 112, 42, 149, 8, 117, 189, 10, 123, 247, 65,
219, 95, 247, 195, 97, 127, 112, 53, 108, 244, 61, 144, 221, 246, 101, 192, 116, 171,
65, 24, 6, 29, 47, 13, 83, 73, 203, 15, 58, 186, 13, 141, 216, 3, 0, 0,
}
var result proto.Result
err := result.Decode(&bin.Buffer{Buf: data})
a.NoError(err)
r, err := gzip(&bin.Buffer{Buf: result.Result})
a.NoError(err)
a.Equal(r.Len(), 984)
})
t.Run("Invalid", func(t *testing.T) {
_, err := gzip(new(bin.Buffer))
require.Error(t, err)
})
}
+44
View File
@@ -0,0 +1,44 @@
package mtproto
import (
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
)
func (c *Conn) handleMessage(msgID int64, b *bin.Buffer) error {
id, err := b.PeekID()
if err != nil {
// Empty body.
return errors.Wrap(err, "peek message type")
}
c.logWithBuffer(b).Debug("Handle message", zap.Int64("msg_id", msgID))
switch id {
case mt.NewSessionCreatedTypeID:
return c.handleSessionCreated(b)
case mt.BadMsgNotificationTypeID, mt.BadServerSaltTypeID:
return c.handleBadMsg(b)
case mt.FutureSaltsTypeID:
return c.handleFutureSalts(b)
case proto.MessageContainerTypeID:
return c.handleContainer(msgID, b)
case proto.ResultTypeID:
return c.handleResult(b)
case mt.PongTypeID:
return c.handlePong(b)
case mt.MsgsAckTypeID:
return c.handleAck(b)
case proto.GZIPTypeID:
return c.handleGZIP(msgID, b)
case mt.MsgDetailedInfoTypeID,
mt.MsgNewDetailedInfoTypeID:
return nil
default:
return c.handler.OnMessage(b)
}
}
+84
View File
@@ -0,0 +1,84 @@
//go:build fuzz
// +build fuzz
package mtproto
import (
"time"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
)
type fuzzHandler struct {
types *tmap.Constructor
}
func (h fuzzHandler) OnMessage(b *bin.Buffer) error {
id, err := b.PeekID()
if err != nil {
return err
}
v := h.types.New(id)
if v == nil {
return errors.New("not found")
}
if err := v.Decode(b); err != nil {
return errors.Wrap(err, "decode")
}
// Performing decode cycle.
var newBuff bin.Buffer
newV := h.types.New(id)
if err := v.Encode(&newBuff); err != nil {
panic(err)
}
if err := newV.Decode(&newBuff); err != nil {
panic(err)
}
return nil
}
func (fuzzHandler) OnSession(session Session) error { return nil }
var (
conn *Conn
buf *bin.Buffer
)
func init() {
handler := fuzzHandler{
// Handler will try to dynamically decode any incoming message.
types: tmap.NewConstructor(
tg.TypesConstructorMap(),
mt.TypesConstructorMap(),
),
}
c := &Conn{
rand: testutil.ZeroRand{},
rpc: rpc.New(rpc.NopSend, rpc.Options{}),
log: zap.NewNop(),
messageID: proto.NewMessageIDGen(time.Now),
handler: handler,
}
conn = c
buf = &bin.Buffer{}
}
func FuzzHandleMessage(data []byte) int {
buf.ResetTo(data)
if err := conn.handleMessage(buf); err != nil {
return 0
}
return 1
}
+161
View File
@@ -0,0 +1,161 @@
package mtproto
import (
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"github.com/gotd/neo"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
)
type testUpdateHandler struct {
types *tmap.Constructor
}
func (h testUpdateHandler) OnMessage(b *bin.Buffer) error {
id, err := b.PeekID()
if err != nil {
return err
}
v := h.types.New(id)
if v == nil {
return errors.New("not found")
}
if err := v.Decode(b); err != nil {
return errors.Wrap(err, "decode")
}
return nil
}
func (testUpdateHandler) OnSession(session Session) error { return nil }
func newTestHandler() Handler {
return &testUpdateHandler{
types: tmap.NewConstructor(
tg.TypesConstructorMap(),
mt.TypesConstructorMap(),
),
}
}
func TestConnHandleMessage(t *testing.T) {
c := &Conn{
rand: Zero{},
log: zap.NewNop(),
handler: newTestHandler(),
}
for i, input := range []string{
"\xdc\xf8\xf1stewa\x00O\x03expired c" +
"ertificate\x02\x00\x00\x00\xef",
"\x01m\\\xf300000000\x19\xcaD!0000" +
"\x1100000000000000000",
"\x01m\\\xf300000000\x19\xcaD!0000" +
"\xfe0",
"@B\xaet\x15ĵ\x1c0000,\x8f\xf8B0000" +
"00000000\x15ĵ\x1c0000\xff000" +
"00000000000000000000" +
"00000000000000000000" +
"00000000000000000000" +
"00000000000000000000" +
"00000000000000000000" +
"00000000000000000000" +
"00000000000000000000" +
"00000000000000000000" +
"00000000000000000000" +
"00000000000000000000" +
"00000000000000000000" +
"00000000000000000000" +
"000000000000",
} {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
if err := c.handleMessage(0, &bin.Buffer{Buf: []byte(input)}); err == nil {
t.Fatal("error expected")
}
})
}
}
func TestConnHandleMessageCorpus(t *testing.T) {
if testutil.Race {
t.Skip("Skipped")
}
c := &Conn{
handler: newTestHandler(),
rpc: rpc.New(rpc.NopSend, rpc.Options{}),
clock: neo.NewTime(time.Now()),
rand: Zero{},
log: zap.NewNop(),
gotSession: tdsync.NewReady(),
}
corpusDir := filepath.Join("..", "_fuzz", "handle_message", "corpus")
corpus, err := os.ReadDir(corpusDir)
if err != nil {
t.Fatal(err)
}
b := &bin.Buffer{}
types := tmap.New(
tg.TypesMap(),
mt.TypesMap(),
)
for _, f := range corpus {
t.Run(f.Name(), func(t *testing.T) {
data, err := os.ReadFile(filepath.Join(corpusDir, f.Name()))
require.NoError(t, err)
// Default to 128 bytes per invocation.
allocThreshold := 128
// Adjusting threshold for specific types.
//
// Probably there should be better way to do this, but
// manually ensuring allocation distribution by type is
// pretty ok.
b.ResetTo(data)
if id, err := b.PeekID(); err == nil {
t.Logf("Type: 0x%x %s", id, types.Get(id))
switch id {
case tg.UpdatesTypeID,
tg.TextFixedTypeID,
tg.InputPeerChannelFromMessageTypeID,
tg.PageBlockRelatedArticlesTypeID:
allocThreshold = 512
case tg.TextBoldTypeID,
tg.TextItalicTypeID,
tg.TextMarkedTypeID,
tg.MessageTypeID,
tg.PageBlockCoverTypeID,
tg.InputMediaUploadedDocumentTypeID,
tg.SecureRequiredTypeOneOfTypeID:
allocThreshold = 256
}
}
testutil.MaxAlloc(t, allocThreshold, func() {
b.ResetTo(data)
_ = c.handleMessage(0, b)
})
})
}
}
+66
View File
@@ -0,0 +1,66 @@
package mtproto
import (
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
func (c *Conn) handleResult(b *bin.Buffer) error {
// Response to an RPC query.
var res proto.Result
if err := res.Decode(b); err != nil {
return errors.Wrap(err, "decode")
}
// Now b contains result message.
b.ResetTo(res.Result)
msgID := zap.Int64("msg_id", res.RequestMessageID)
c.logWithBuffer(b).Debug("Handle result", msgID)
// Handling gzipped results.
id, err := b.PeekID()
if err != nil {
return err
}
if id == proto.GZIPTypeID {
content, err := gzip(b)
if err != nil {
return errors.Wrap(err, "decompress")
}
// Replacing buffer so callback will deal with uncompressed data.
b = content
c.logWithBuffer(b).Debug("Decompressed", msgID)
// Replacing id with inner id if error is compressed for any reason.
if id, err = b.PeekID(); err != nil {
return errors.Wrap(err, "peek id")
}
}
if id == mt.RPCErrorTypeID {
var rpcErr mt.RPCError
if err := rpcErr.Decode(b); err != nil {
return errors.Wrap(err, "error decode")
}
c.log.Debug("Got error", msgID,
zap.Int("err_code", rpcErr.ErrorCode),
zap.String("err_msg", rpcErr.ErrorMessage),
)
c.rpc.NotifyError(res.RequestMessageID, tgerr.New(rpcErr.ErrorCode, rpcErr.ErrorMessage))
return nil
}
if id == mt.PongTypeID {
return c.handlePong(b)
}
return c.rpc.NotifyResult(res.RequestMessageID, b)
}
@@ -0,0 +1,40 @@
package mtproto
import (
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
)
func (c *Conn) handleSessionCreated(b *bin.Buffer) error {
var s mt.NewSessionCreated
if err := s.Decode(b); err != nil {
return errors.Wrap(err, "decode")
}
c.gotSession.Signal()
created := proto.MessageID(s.FirstMsgID).Time()
now := c.clock.Now()
c.log.Debug("Session created",
zap.Int64("unique_id", s.UniqueID),
zap.Int64("first_msg_id", s.FirstMsgID),
zap.Time("first_msg_time", created.Local()),
)
if (created.Before(now) && now.Sub(created) > maxPast) || created.Sub(now) > maxFuture {
c.log.Warn("Local clock needs synchronization",
zap.Time("first_msg_time", created),
zap.Time("local", now),
zap.Duration("time_difference", now.Sub(created)),
)
}
c.storeSalt(s.ServerSalt)
if err := c.handler.OnSession(c.session()); err != nil {
return errors.Wrap(err, "handler.OnSession")
}
return nil
}
@@ -0,0 +1,61 @@
package mtproto
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"github.com/gotd/neo"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
)
func TestConn_handleSessionCreated(t *testing.T) {
t.Run("NeedSynchronization", func(t *testing.T) {
a := require.New(t)
logger, logs := observer.New(zapcore.WarnLevel)
now := time.Unix(100, 0)
clock := neo.NewTime(now)
gotSession := tdsync.NewReady()
conn := Conn{
clock: clock,
log: zap.New(logger),
gotSession: gotSession,
handler: newTestHandler(),
}
buf := bin.Buffer{}
msgID := proto.NewMessageID(now.Add(maxFuture+time.Second), proto.MessageFromClient)
a.NoError(buf.Encode(&mt.NewSessionCreated{
FirstMsgID: int64(msgID),
UniqueID: 10,
ServerSalt: 10,
}))
a.NoError(conn.handleSessionCreated(&buf))
select {
case <-gotSession.Ready():
default:
t.Fatal("expected gotSession signal")
}
a.Equal(int64(10), conn.salt)
msgs := logs.All()
a.Len(msgs, 1)
a.Equal("Local clock needs synchronization", msgs[0].Message)
})
t.Run("Invalid", func(t *testing.T) {
conn := Conn{}
buf := bin.Buffer{}
buf.PutID(mt.NewSessionCreatedTypeID)
require.Error(t, conn.handleSessionCreated(&buf))
})
}
+7
View File
@@ -0,0 +1,7 @@
package mtproto
import "go.mau.fi/mautrix-telegram/pkg/gotd/proto"
func (c *Conn) newMessageID() int64 {
return c.messageID.New(proto.MessageFromClient)
}
+25
View File
@@ -0,0 +1,25 @@
package mtproto
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
)
func TestClientNewMessageID(t *testing.T) {
c := newTestClient(nil)
now := c.clock.Now()
id := proto.MessageID(c.newMessageID())
assert.Equal(t, proto.MessageFromClient, id.Type())
lag := id.Time().Sub(now)
if lag < 0 {
lag *= -1
}
if lag > time.Second {
t.Error("generated id lags in time")
}
}
+104
View File
@@ -0,0 +1,104 @@
package mtproto
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func TestMsgSeq(t *testing.T) {
type record struct {
msgID int64
seqNo int32
}
var (
records []record
mux sync.Mutex
)
const (
workers = 32
requestsPerWorker = 200
)
client := newTestClient(func(msgID int64, seqNo int32, body bin.Encoder) (bin.Encoder, error) {
mux.Lock()
records = append(records, record{msgID, seqNo})
mux.Unlock()
return &tg.Config{}, nil
})
{
var wg sync.WaitGroup
wg.Add(workers)
for i := 0; i < workers; i++ {
go func(t *testing.T) {
defer wg.Done()
for i := 0; i < requestsPerWorker; i++ {
err := client.Invoke(context.Background(), &tg.Config{}, &tg.Config{})
require.NoError(t, err)
}
}(t)
}
wg.Wait()
}
var received []record
less := func(msgID int64, f func(r record)) {
for _, recv := range received {
if recv.msgID < msgID {
f(recv)
}
}
}
greater := func(msgID int64, f func(r record)) {
for _, recv := range received {
if recv.msgID > msgID {
f(recv)
}
}
}
contains := func(msgID int64) bool {
for _, rec := range received {
if rec.msgID == msgID {
return true
}
}
return false
}
for _, current := range records {
if contains(current.msgID) {
// Ignore duplicates.
continue
}
less(current.msgID, func(less record) {
// The server has already received a message with a lower msg_id
// but with either a higher or an equal and odd seqno.
if less.seqNo >= current.seqNo && less.seqNo%2 == 1 {
t.Fatal("seqNo too low")
}
})
greater(current.msgID, func(greater record) {
// Similarly, there is a message with a higher msg_id
// but with either a lower or an equal and odd seqno.
if greater.seqNo <= current.seqNo && greater.seqNo%2 == 1 {
t.Fatal("seqNo too high")
}
})
received = append(received, current)
}
}
+60
View File
@@ -0,0 +1,60 @@
package mtproto
import (
"context"
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
type testPayload struct {
Data []byte
}
func (d testPayload) Decode(b *bin.Buffer) error {
_, err := b.Bytes()
return err
}
func (d testPayload) Encode(b *bin.Buffer) error {
b.PutBytes(d.Data)
return nil
}
type noopBuf struct{}
func (n noopBuf) Consume(id int64) bool {
return true
}
type constantConn struct {
data []byte
cancel context.CancelFunc
counter int
mux sync.Mutex
}
func (c *constantConn) Send(ctx context.Context, b *bin.Buffer) error {
return nil
}
func (c *constantConn) Recv(ctx context.Context, b *bin.Buffer) error {
c.mux.Lock()
exit := c.counter == 0
if exit {
c.mux.Unlock()
c.cancel()
return errors.New("error")
}
c.counter--
c.mux.Unlock()
b.Put(c.data)
return nil
}
func (c *constantConn) Close() error {
return nil
}
+68
View File
@@ -0,0 +1,68 @@
package mtproto
import (
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
)
func (c *Conn) newEncryptedMessage(id int64, seq int32, payload bin.Encoder, b *bin.Buffer) error {
s := c.session()
// TODO(tdakkota): Smarter gzip.
// 1) Generate Length() method for every encoder, to count length without encoding.
// 2) Re-use buffer instead of using yet one.
// 3) Do not send proto.GZIP if gzipped size is equal or bigger.
var (
d crypto.EncryptedMessageData
log = c.log
)
if c.compressThreshold <= 0 {
if obj, ok := payload.(interface{ TypeID() uint32 }); ok {
log = c.logWithTypeID(obj.TypeID())
}
d = crypto.EncryptedMessageData{
SessionID: s.ID,
Salt: s.Salt,
MessageID: id,
SeqNo: seq,
Message: payload,
}
} else {
payloadBuf := bufPool.Get()
defer bufPool.Put(payloadBuf)
if err := payload.Encode(payloadBuf); err != nil {
return errors.Wrap(err, "encode payload")
}
log = c.logWithType(payloadBuf)
if payloadBuf.Len() > c.compressThreshold {
d = crypto.EncryptedMessageData{
SessionID: s.ID,
Salt: s.Salt,
MessageID: id,
SeqNo: seq,
Message: proto.GZIP{Data: payloadBuf.Raw()},
}
} else {
d = crypto.EncryptedMessageData{
SessionID: s.ID,
Salt: s.Salt,
MessageID: id,
SeqNo: seq,
MessageDataLen: int32(payloadBuf.Len()),
MessageDataWithPadding: payloadBuf.Buf,
}
}
}
log.Debug("Request", zap.Int64("msg_id", id))
if err := c.cipher.Encrypt(s.Key, d, b); err != nil {
return errors.Wrap(err, "encrypt")
}
return nil
}
+167
View File
@@ -0,0 +1,167 @@
package mtproto
import (
"io"
"time"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
)
// Options of Conn.
type Options struct {
// DC is datacenter ID for key exchange.
// Defaults to 2.
DC int
// PublicKeys of telegram.
//
// If not provided, embedded public keys will be used.
PublicKeys []exchange.PublicKey
// Random is random source. Defaults to crypto.
Random io.Reader
// Logger is instance of zap.Logger. No logs by default.
Logger *zap.Logger
// Handler will be called on received message.
Handler Handler
// AckBatchSize is maximum ack-s to buffer.
AckBatchSize int
// AckInterval is maximum time to buffer ack.
AckInterval time.Duration
// RetryInterval is duration between retries.
RetryInterval time.Duration
// MaxRetries is max retry count until rpc request failure.
MaxRetries int
// DialTimeout is timeout of creating connection.
DialTimeout time.Duration
// ExchangeTimeout is timeout of every key exchange request.
ExchangeTimeout time.Duration
// SaltFetchInterval is duration between get_future_salts request.
SaltFetchInterval time.Duration
// PingTimeout sets ping_delay_disconnect timeout.
PingTimeout time.Duration
// PingInterval is duration between ping_delay_disconnect request.
PingInterval time.Duration
// PingCallback is called after a ping is acknowledged by the server.
PingCallback func()
// RequestTimeout is function which returns request timeout for given type ID.
RequestTimeout func(req uint32) time.Duration
// CompressThreshold is a threshold in bytes to determine that message
// is large enough to be compressed using GZIP.
// If < 0, compression will be disabled.
// If == 0, default value will be used.
CompressThreshold int
// MessageID is message id source. Share source between connection to
// reduce collision probability.
MessageID MessageIDSource
// Clock is current time source. Defaults to system time.
Clock clock.Clock
// Types map, used in verbose logging of incoming message.
Types *tmap.Map
// Key that can be used to restore previous connection.
Key crypto.AuthKey
// Salt from server that can be used to restore previous connection.
Salt int64
// Tracer for OTEL.
Tracer trace.Tracer
// Private options.
// Cipher defines message crypto.
Cipher Cipher
// engine for replacing RPC engine.
engine *rpc.Engine
// OnError is a function that is called if there is any error.
OnError func(err error)
}
type nopHandler struct{}
func (nopHandler) OnMessage(b *bin.Buffer) error { return nil }
func (nopHandler) OnSession(session Session) error { return nil }
func (opt *Options) setDefaultPublicKeys() {
// Using public keys that are included with distribution if not
// provided.
//
// This should never fail and keys should be valid for recent
// library versions.
opt.PublicKeys = vendoredKeys()
}
func (opt *Options) setDefaults() {
if opt.DC == 0 {
opt.DC = 2
}
if opt.Random == nil {
opt.Random = crypto.DefaultRand()
}
if opt.Logger == nil {
opt.Logger = zap.NewNop()
}
if opt.AckBatchSize == 0 {
opt.AckBatchSize = 20
}
if opt.AckInterval == 0 {
opt.AckInterval = 15 * time.Second
}
if opt.RetryInterval == 0 {
opt.RetryInterval = 5 * time.Second
}
if opt.MaxRetries == 0 {
opt.MaxRetries = 5
}
if opt.DialTimeout == 0 {
opt.DialTimeout = 35 * time.Second
}
if opt.ExchangeTimeout == 0 {
opt.ExchangeTimeout = exchange.DefaultTimeout
}
if opt.SaltFetchInterval == 0 {
opt.SaltFetchInterval = 1 * time.Hour
}
if opt.PingTimeout == 0 {
opt.PingTimeout = 15 * time.Second
}
if opt.PingInterval == 0 {
opt.PingInterval = 1 * time.Minute
}
if opt.RequestTimeout == nil {
opt.RequestTimeout = func(req uint32) time.Duration {
return 15 * time.Second
}
}
if opt.CompressThreshold == 0 {
opt.CompressThreshold = 1024
}
if opt.Clock == nil {
opt.Clock = clock.System
}
if opt.MessageID == nil {
opt.MessageID = proto.NewMessageIDGen(opt.Clock.Now)
}
if len(opt.PublicKeys) == 0 {
opt.setDefaultPublicKeys()
}
if opt.Handler == nil {
opt.Handler = nopHandler{}
}
if opt.Cipher == nil {
opt.Cipher = crypto.NewClientCipher(opt.Random)
}
}
+122
View File
@@ -0,0 +1,122 @@
package mtproto
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
)
// Ping sends ping request to server and waits until pong is received or
// context is canceled.
func (c *Conn) Ping(ctx context.Context) error {
// Generating random id.
// Probably we should check for collisions here.
pingID, err := crypto.RandInt64(c.rand)
if err != nil {
return err
}
pong := c.pong(pingID)
defer c.removePong(pingID)
if err := c.writeServiceMessage(ctx, &mt.PingRequest{PingID: pingID}); err != nil {
return errors.Wrap(err, "write")
}
select {
case <-pong:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (c *Conn) handlePong(b *bin.Buffer) error {
var pong mt.Pong
if err := pong.Decode(b); err != nil {
return errors.Errorf("decode: %x", err)
}
c.log.Debug("Pong")
c.pingMux.Lock()
ch, ok := c.ping[pong.PingID]
if ok {
close(ch)
delete(c.ping, pong.PingID)
}
c.pingMux.Unlock()
return nil
}
func (c *Conn) pingDelayDisconnect(ctx context.Context, delay int) error {
// Generating random id.
// Probably we should check for collisions here.
pingID, err := crypto.RandInt64(c.rand)
if err != nil {
return err
}
pong := c.pong(pingID)
defer c.removePong(pingID)
if err := c.writeServiceMessage(ctx, &mt.PingDelayDisconnectRequest{
PingID: pingID,
DisconnectDelay: delay,
}); err != nil {
return errors.Wrap(err, "write")
}
select {
case <-pong:
if c.pingCallback != nil {
c.pingCallback()
}
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (c *Conn) pong(pingID int64) chan struct{} {
ch := make(chan struct{})
c.pingMux.Lock()
c.ping[pingID] = ch
c.pingMux.Unlock()
return ch
}
func (c *Conn) removePong(pingID int64) {
c.pingMux.Lock()
delete(c.ping, pingID)
c.pingMux.Unlock()
}
func (c *Conn) pingLoop(ctx context.Context) error {
// If the client sends these pings once every 60 seconds,
// for example, it may set disconnect_delay equal to 75 seconds.
delay := c.pingInterval + c.pingTimeout
ticker := c.clock.Ticker(c.pingInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "ping loop")
case <-ticker.C():
if err := func() error {
ctx, cancel := context.WithTimeout(ctx, c.pingTimeout)
defer cancel()
return c.pingDelayDisconnect(ctx, int(delay.Seconds()))
}(); err != nil {
return errors.Wrap(err, "disconnect (pong missed)")
}
}
}
}
+213
View File
@@ -0,0 +1,213 @@
package mtproto
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
)
// https://core.telegram.org/mtproto/description#message-identifier-msg-id
// A message is rejected over 300 seconds after it is created or 30 seconds
// before it is created (this is needed to protect from replay attacks).
const (
maxPast = time.Second * 300
maxFuture = time.Second * 30
)
// errRejected is returned on invalid message that should not be processed.
var errRejected = errors.New("message rejected")
func checkMessageID(now time.Time, rawID int64) error {
id := proto.MessageID(rawID)
// Check that message is from server.
switch id.Type() {
case proto.MessageFromServer, proto.MessageServerResponse:
// Valid.
default:
return errors.Wrapf(errRejected, "unexpected type %s", id.Type())
}
created := id.Time()
if created.Before(now) && now.Sub(created) > maxPast {
return errors.Wrap(errRejected, "created too far in past")
}
if created.Sub(now) > maxFuture {
return errors.Wrap(errRejected, "created too far in future")
}
return nil
}
func (c *Conn) decryptMessage(b *bin.Buffer) (*crypto.EncryptedMessageData, error) {
session := c.session()
msg, err := c.cipher.DecryptFromBuffer(session.Key, b)
if err != nil {
return nil, errors.Wrap(err, "decrypt")
}
// Validating message. This protects from replay attacks.
if msg.SessionID != session.ID {
return nil, errors.Wrapf(errRejected, "invalid session (got %d, expected %d)", msg.SessionID, session.ID)
}
if err := checkMessageID(c.clock.Now(), msg.MessageID); err != nil {
return nil, errors.Wrapf(err, "bad message id %d", msg.MessageID)
}
if !c.messageIDBuf.Consume(msg.MessageID) {
return nil, errors.Wrapf(errRejected, "duplicate or too low message id %d", msg.MessageID)
}
return msg, nil
}
func (c *Conn) consumeMessage(ctx context.Context, buf *bin.Buffer) error {
msg, err := c.decryptMessage(buf)
if errors.Is(err, errRejected) {
c.log.Warn("Ignoring rejected message", zap.Error(err))
return nil
}
if err != nil {
return errors.Wrap(err, "consume message")
}
if err := c.handleMessage(msg.MessageID, &bin.Buffer{Buf: msg.Data()}); err != nil {
// Probably we can return here, but this will shutdown whole
// connection which can be unexpected.
c.log.Warn("Error while handling message", zap.Error(err))
// Sending acknowledge even on error. Client should restore
// from missing updates via explicit pts check and getDiff call.
}
needAck := (msg.SeqNo & 0x01) != 0
if needAck {
select {
case <-ctx.Done():
return ctx.Err()
case c.ackSendChan <- msg.MessageID:
}
}
return nil
}
func (c *Conn) noUpdates(err error) bool {
// Checking for read timeout.
var syscall *net.OpError
if errors.As(err, &syscall) && syscall.Timeout() {
// We call SetReadDeadline so such error is expected.
c.log.Debug("No updates")
return true
}
return false
}
func (c *Conn) handleAuthKeyNotFound(ctx context.Context) error {
if c.session().ID == 0 {
// The 404 error can also be caused by zero session id.
// See https://github.com/gotd/td/issues/107
//
// We should recover from this in createAuthKey, but in general
// this code branch should be unreachable.
c.log.Warn("BUG: zero session id found")
}
c.log.Warn("Re-generating keys (server not found key that we provided)")
if err := c.createAuthKey(ctx); err != nil {
return errors.Wrap(err, "unable to create auth key")
}
c.log.Info("Re-created auth keys")
// Request will be retried by ack loop.
// Probably we can speed-up this.
return nil
}
func (c *Conn) readLoop(ctx context.Context) (err error) {
log := c.log.Named("read")
log.Debug("Read loop started")
defer func() {
l := log
if err != nil {
l = log.With(zap.NamedError("reason", err))
}
l.Debug("Read loop done")
}()
var (
// Last error encountered by consumeMessage.
lastErr atomic.Value
// To wait all spawned goroutines
handlers sync.WaitGroup
)
defer handlers.Wait()
for {
// We've tried multiple ways to reduce allocations via reusing buffer,
// but naive implementation induces high idle memory waste.
//
// Proper optimization will probably require total rework of bin.Buffer
// with sharded (by payload size?) pool that can be used after message
// size read (after readLen).
//
// Such optimization can introduce additional complexity overhead and
// is probably not worth it.
buf := &bin.Buffer{}
// Halting if consumeMessage encountered error.
// Should be something critical with crypto.
if err, ok := lastErr.Load().(error); ok && err != nil {
return errors.Wrap(err, "halting")
}
if err := c.conn.Recv(ctx, buf); err != nil {
select {
case <-ctx.Done():
return ctx.Err()
default:
if c.noUpdates(err) {
continue
}
}
var protoErr *codec.ProtocolErr
if errors.As(err, &protoErr) && protoErr.Code == codec.CodeAuthKeyNotFound {
if err := c.handleAuthKeyNotFound(ctx); err != nil {
return errors.Wrap(err, "auth key not found")
}
continue
}
select {
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "read loop")
default:
return errors.Wrap(err, "read")
}
}
handlers.Add(1)
go func() {
defer handlers.Done()
// Spawning goroutine per incoming message to utilize as much
// resources as possible while keeping idle utilization low.
//
// The "worker" model was replaced by this due to idle utilization
// overhead, especially on multi-CPU systems with multiple running
// clients.
if err := c.consumeMessage(ctx, buf); err != nil {
log.Error("Failed to process message", zap.Error(err))
lastErr.Store(errors.Wrap(err, "consume"))
}
}()
}
}
+115
View File
@@ -0,0 +1,115 @@
package mtproto
import (
"context"
"crypto/rand"
"io"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"github.com/gotd/neo"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
func TestCheckMessageID(t *testing.T) {
now := testutil.Date()
t.Run("Good", func(t *testing.T) {
for _, good := range []proto.MessageID{
proto.NewMessageID(now, proto.MessageFromServer),
proto.NewMessageID(now, proto.MessageServerResponse),
proto.NewMessageID(now.Add(time.Second*29), proto.MessageFromServer),
proto.NewMessageID(now.Add(-time.Second*299), proto.MessageFromServer),
} {
t.Run(good.String(), func(t *testing.T) {
require.NoError(t, checkMessageID(now, int64(good)))
})
}
})
t.Run("Bad", func(t *testing.T) {
for _, bad := range []proto.MessageID{
proto.NewMessageID(now, proto.MessageFromClient),
proto.NewMessageID(now.Add(time.Second*31), proto.MessageFromServer),
proto.NewMessageID(now.Add(-time.Second*301), proto.MessageFromServer),
proto.NewMessageID(time.Time{}, proto.MessageFromServer),
proto.NewMessageID(now.AddDate(-10, 0, 0), proto.MessageServerResponse),
proto.NewMessageID(time.Time{}, proto.MessageFromClient),
} {
t.Run(bad.String(), func(t *testing.T) {
require.ErrorIs(t, checkMessageID(now, int64(bad)), errRejected)
})
}
})
}
func benchRead(payloadSize int) func(b *testing.B) {
return func(b *testing.B) {
a := require.New(b)
logger := zap.NewNop()
random := rand.Reader
c := neo.NewTime(time.Now())
var key crypto.Key
_, err := io.ReadFull(random, key[:])
a.NoError(err)
authKey := key.WithID()
payload := make([]byte, payloadSize)
_, err = io.ReadFull(random, payload)
a.NoError(err)
msg := new(bin.Buffer)
serverCipher := crypto.NewServerCipher(random)
id := proto.NewMessageIDGen(c.Now).New(proto.MessageServerResponse)
a.NoError(msg.Encode(&testPayload{
Data: payload,
}))
length := msg.Len()
data := msg.Copy()
a.NoError(serverCipher.Encrypt(authKey, crypto.EncryptedMessageData{
MessageID: id,
SeqNo: 0,
MessageDataLen: int32(length),
MessageDataWithPadding: data,
}, msg))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
conn := Conn{
conn: &constantConn{
data: msg.Raw(),
cancel: cancel,
counter: b.N,
},
handler: nopHandler{},
clock: c,
rand: random,
cipher: crypto.NewClientCipher(random),
log: logger,
messageIDBuf: noopBuf{},
authKey: authKey,
compressThreshold: -1,
}
grp := tdsync.NewCancellableGroup(ctx)
b.ResetTimer()
b.ReportAllocs()
b.SetBytes(int64(payloadSize))
grp.Go(conn.readLoop)
a.ErrorIs(grp.Wait(), context.Canceled)
}
}
func BenchmarkRead(b *testing.B) {
testutil.RunPayloads(b, benchRead)
}
+71
View File
@@ -0,0 +1,71 @@
package mtproto
import (
"context"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
)
// Invoke sends input and decodes result into output.
//
// NOTE: Assuming that call contains content message (seqno increment).
func (c *Conn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
msgID, seqNo := c.nextMsgSeq(true)
req := rpc.Request{
MsgID: msgID,
SeqNo: seqNo,
Input: input,
Output: output,
}
log := c.log.With(
zap.Int64("msg_id", req.MsgID),
)
log.Debug("Invoke start")
defer log.Debug("Invoke end")
if err := c.rpc.Do(ctx, req); err != nil {
var badMsgErr *badMessageError
if errors.As(err, &badMsgErr) && badMsgErr.Code == codeIncorrectServerSalt {
// Should retry with new salt.
c.log.Debug("Setting server salt")
// Store salt from server.
c.storeSalt(badMsgErr.NewSalt)
// Reset saved salts to fetch new.
c.salts.Reset()
c.log.Info("Retrying request after basMsgErr", zap.Int64("msg_id", req.MsgID))
return c.rpc.Do(ctx, req)
}
return err
}
return nil
}
func (c *Conn) dropRPC(req rpc.Request) error {
ctx, cancel := context.WithTimeout(context.Background(),
c.getTimeout(mt.RPCDropAnswerRequestTypeID),
)
defer cancel()
var resp mt.RPCDropAnswerBox
if err := c.Invoke(ctx, &mt.RPCDropAnswerRequest{
ReqMsgID: req.MsgID,
}, &resp); err != nil {
return err
}
switch resp.RpcDropAnswer.(type) {
case *mt.RPCAnswerDropped, *mt.RPCAnswerDroppedRunning:
return nil
case *mt.RPCAnswerUnknown:
return errors.New("answer unknown")
default:
return errors.Errorf("unexpected response type: %T", resp.RpcDropAnswer)
}
}
+52
View File
@@ -0,0 +1,52 @@
package mtproto
import (
"testing"
"github.com/stretchr/testify/assert"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
func TestConn_dropRPC(t *testing.T) {
dropID := int64(10)
tests := []struct {
name string
result bin.Encoder
resultErr error
wantErr bool
}{
{"Dropped", &mt.RPCAnswerDropped{MsgID: dropID}, nil, false},
{"DroppedRunning", &mt.RPCAnswerDroppedRunning{}, nil, false},
{"Unknown", &mt.RPCAnswerUnknown{}, nil, true},
{"Error", nil, testutil.TestError(), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
client := newTestClient(func(msgID int64, seqNo int32, body bin.Encoder) (bin.Encoder, error) {
req, ok := body.(*mt.RPCDropAnswerRequest)
a.True(ok)
if ok {
a.Equal(dropID, req.ReqMsgID)
}
return tt.result, tt.resultErr
})
err := client.dropRPC(rpc.Request{
MsgID: dropID,
SeqNo: 1,
})
if tt.wantErr {
a.Error(err)
} else {
a.NoError(err)
}
})
}
}
+74
View File
@@ -0,0 +1,74 @@
package mtproto
import (
"context"
"time"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
)
func (c *Conn) storeSalt(salt int64) {
c.sessionMux.Lock()
// Copy to log.
oldSalt := c.salt
c.salt = salt
c.sessionMux.Unlock()
if salt != oldSalt {
c.log.Info("Salt updated", zap.Int64("old", oldSalt), zap.Int64("new", salt))
}
}
func (c *Conn) updateSalt() {
salt, ok := c.salts.Get(c.clock.Now().Add(time.Minute * 5))
if !ok {
return
}
c.storeSalt(salt)
}
const defaultSaltsNum = 4
func (c *Conn) getSalts(ctx context.Context) error {
request := &mt.GetFutureSaltsRequest{
Num: defaultSaltsNum,
}
ctx, cancel := context.WithTimeout(ctx, c.getTimeout(request.TypeID()))
defer cancel()
if err := c.writeServiceMessage(ctx, request); err != nil {
return errors.Wrap(err, "request salts")
}
return nil
}
func (c *Conn) saltLoop(ctx context.Context) error {
select {
case <-c.gotSession.Ready():
case <-ctx.Done():
return ctx.Err()
}
// Get salts first time.
if err := c.getSalts(ctx); err != nil {
return err
}
ticker := c.clock.Ticker(c.saltFetchInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C():
if err := c.getSalts(ctx); err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
}
}
+90
View File
@@ -0,0 +1,90 @@
// Package salts contains MTProto server salt storage.
package salts
import (
"sort"
"sync"
"time"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
)
// Salts is a simple struct store server salts.
type Salts struct {
// server salts fetched by getSalts.
salts []mt.FutureSalt
saltsMux sync.Mutex
}
// Get returns next valid salt.
func (s *Salts) Get(deadline time.Time) (int64, bool) {
s.saltsMux.Lock()
defer s.saltsMux.Unlock()
check:
if len(s.salts) < 1 {
return 0, false
}
date := int(deadline.Unix())
if salt := s.salts[len(s.salts)-1]; salt.ValidUntil > date {
return salt.Salt, true
}
// Filter (in place) from SliceTricks.
n := 0
// Check that the salt will be valid until deadline.
for _, salt := range s.salts {
// Filter expired salts.
if salt.ValidUntil > date {
// Keep valid salt.
s.salts[n] = salt
n++
}
}
s.salts = s.salts[:n]
goto check
}
type saltSlice []mt.FutureSalt
func (s saltSlice) Len() int {
return len(s)
}
func (s saltSlice) Less(i, j int) bool {
return s[i].ValidUntil > s[j].ValidUntil
}
func (s saltSlice) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
// Store stores all given salts.
func (s *Salts) Store(salts []mt.FutureSalt) {
s.saltsMux.Lock()
defer s.saltsMux.Unlock()
s.salts = append(s.salts, salts...)
// Filter duplicates.
n := 0
dedup := make(map[int64]struct{}, len(s.salts)+1)
for _, salt := range s.salts {
if _, ok := dedup[salt.Salt]; !ok {
dedup[salt.Salt] = struct{}{}
s.salts[n] = salt
n++
}
}
s.salts = s.salts[:n]
// Sort slice by valid until.
sort.Sort(saltSlice(s.salts))
}
// Reset deletes all stored salts.
func (s *Salts) Reset() {
s.saltsMux.Lock()
s.salts = s.salts[:0]
s.saltsMux.Unlock()
}
+112
View File
@@ -0,0 +1,112 @@
package salts
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
func generateSalts(n int) []mt.FutureSalt {
r := make([]mt.FutureSalt, n)
for i := range r {
since := (i + 1) * 10
r[i] = mt.FutureSalt{
ValidSince: since,
ValidUntil: since + 15,
Salt: int64(i),
}
}
return r
}
func TestSalts(t *testing.T) {
a := require.New(t)
salts := &Salts{}
var testData = []mt.FutureSalt{
{
ValidSince: 10,
ValidUntil: 25,
Salt: 1,
},
{
ValidSince: 20,
ValidUntil: 35,
Salt: 2,
},
{
ValidSince: 30,
ValidUntil: 45,
Salt: 3,
},
}
salts.Store(testData[:2])
a.Len(salts.salts, 2)
salt, ok := salts.Get(time.Unix(11, 0))
a.Equal(int64(1), salt)
a.True(ok)
_, ok = salts.Get(time.Unix(36, 0))
a.False(ok)
salts.Store(testData[:2])
a.Len(salts.salts, 2)
salts.Store(testData[:3])
a.Len(salts.salts, 3)
salt, ok = salts.Get(time.Unix(26, 0))
a.Equal(int64(2), salt)
a.True(ok)
salt, ok = salts.Get(time.Unix(36, 0))
a.Equal(int64(3), salt)
a.True(ok)
salts.Reset()
_, ok = salts.Get(time.Unix(36, 0))
a.False(ok)
}
func TestSalts_Get(t *testing.T) {
salts := &Salts{}
salts.Store(generateSalts(64))
now := time.Unix(11, 0)
testutil.ZeroAlloc(t, func() {
salts.Get(now)
})
}
func BenchmarkSalts_Get(b *testing.B) {
salts := &Salts{}
salts.Store(generateSalts(64))
t := time.Unix(11, 0)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
salts.Get(t)
}
}
func BenchmarkSalts_Store(b *testing.B) {
testData := generateSalts(64)
salts := &Salts{
salts: make([]mt.FutureSalt, 0, len(testData)),
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
salts.Store(testData)
}
}
+37
View File
@@ -0,0 +1,37 @@
package mtproto
import "go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
// Session represents connection state.
type Session struct {
ID int64
Key crypto.AuthKey
Salt int64
}
// Session returns current connection session info.
func (c *Conn) session() Session {
c.updateSalt()
c.sessionMux.RLock()
defer c.sessionMux.RUnlock()
return Session{
Key: c.authKey,
Salt: c.salt,
ID: c.sessionID,
}
}
// newSessionID sets session id to random value.
func (c *Conn) newSessionID() error {
id, err := crypto.RandInt64(c.rand)
if err != nil {
return err
}
c.sessionMux.Lock()
defer c.sessionMux.Unlock()
c.sessionID = id
return nil
}
+58
View File
@@ -0,0 +1,58 @@
package mtproto
import (
// For embedding public keys.
_ "embed"
"sync"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
)
var (
// publicKeys is byte blob of new keys added for PQInnerData encryption (key exchange).
//
// See https://github.com/telegramdesktop/tdesktop/commit/95a7ce4622dc24717dc5b95fc99599dddfd4ff6c.
//
// See https://github.com/tdlib/td/commit/e9e24282378fcdb3a3ce020bee4253b65ac98213.
//go:embed _data/public_keys.pem
publicKeys []byte
parsedKeys struct {
Keys []exchange.PublicKey
Once sync.Once
}
)
//nolint:gochecknoinits
func init() {
makePublicKeys := func(data []byte) ([]exchange.PublicKey, error) {
rsaKeys, err := crypto.ParseRSAPublicKeys(data)
if err != nil {
return nil, err
}
keys := make([]exchange.PublicKey, 0, len(rsaKeys))
for _, key := range rsaKeys {
keys = append(keys, exchange.PublicKey{
RSA: key,
})
}
return keys, nil
}
parsedKeys.Once.Do(func() {
keys, err := makePublicKeys(publicKeys)
if err != nil {
panic(err)
}
parsedKeys.Keys = append(parsedKeys.Keys, keys...)
})
}
// vendoredKeys parses vendored file _data/public_keys.pem as list of
// PEM-encoded public RSA keys.
//
// Most recent key list can be found on https://my.telegram.org/apps.
func vendoredKeys() []exchange.PublicKey {
return parsedKeys.Keys
}
+10
View File
@@ -0,0 +1,10 @@
package mtproto
import "testing"
func TestVendoredKeys(t *testing.T) {
keys := vendoredKeys()
if len(keys) == 0 {
t.Fatal("empty keys")
}
}
+57
View File
@@ -0,0 +1,57 @@
package mtproto
import (
"context"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
func (c *Conn) writeContentMessage(ctx context.Context, msgID int64, seqNo int32, message bin.Encoder) error {
return c.write(ctx, msgID, seqNo, message)
}
func (c *Conn) writeServiceMessage(ctx context.Context, message bin.Encoder) error {
msgID, seqNo := c.nextMsgSeq(false)
return c.write(ctx, msgID, seqNo, message)
}
var bufPool = bin.NewPool(0)
func (c *Conn) write(ctx context.Context, msgID int64, seqNo int32, message bin.Encoder) error {
// Grab shared lock for writing.
// It prevents message sending during key regeneration if server forgot current auth key.
c.exchangeLock.RLock()
defer c.exchangeLock.RUnlock()
b := bufPool.Get()
defer bufPool.Put(b)
if err := c.newEncryptedMessage(msgID, seqNo, message, b); err != nil {
return err
}
if err := c.conn.Send(ctx, b); err != nil {
return err
}
return nil
}
func (c *Conn) nextMsgSeq(content bool) (msgID int64, seqNo int32) {
c.reqMux.Lock()
defer c.reqMux.Unlock()
msgID = c.newMessageID()
// Computing current sequence number (seqno).
// This should be serialized with new message id generation.
//
// See https://github.com/gotd/td/issues/245 for reference.
seqNo = c.sentContentMessages * 2
if content {
seqNo++
c.sentContentMessages++
}
return
}
+57
View File
@@ -0,0 +1,57 @@
package mtproto
import (
"context"
"crypto/rand"
"io"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"github.com/gotd/neo"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
func benchWrite(payloadSize int) func(b *testing.B) {
return func(b *testing.B) {
a := require.New(b)
logger := zap.NewNop()
random := rand.Reader
c := neo.NewTime(time.Now())
var key crypto.Key
_, err := io.ReadFull(random, key[:])
a.NoError(err)
authKey := key.WithID()
payload := make([]byte, payloadSize)
_, err = io.ReadFull(random, payload)
a.NoError(err)
data := &testPayload{Data: payload}
conn := Conn{
conn: &constantConn{},
clock: c,
rand: random,
cipher: crypto.NewClientCipher(random),
log: logger,
authKey: authKey,
compressThreshold: -1,
}
b.ResetTimer()
b.ReportAllocs()
b.SetBytes(int64(payloadSize))
for i := 0; i < b.N; i++ {
_ = conn.write(context.Background(), 1, 1, data)
}
}
}
func BenchmarkWrite(b *testing.B) {
testutil.RunPayloads(b, benchWrite)
}
+45
View File
@@ -0,0 +1,45 @@
package mtproto
import (
"fmt"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
type logType struct {
ID uint32
Name string
}
func (l logType) MarshalLogObject(e zapcore.ObjectEncoder) error {
typeIDStr := fmt.Sprintf("0x%x", l.ID)
e.AddString("type_id", typeIDStr)
if l.Name != "" {
e.AddString("type_name", l.Name)
}
return nil
}
func (c *Conn) logWithBuffer(b *bin.Buffer) *zap.Logger {
return c.logWithType(b).With(zap.Int("size_bytes", b.Len()))
}
func (c *Conn) logWithType(b *bin.Buffer) *zap.Logger {
id, err := b.PeekID()
if err != nil {
// Type info not available.
return c.log
}
return c.logWithTypeID(id)
}
func (c *Conn) logWithTypeID(id uint32) *zap.Logger {
return c.log.With(zap.Inline(logType{
ID: id,
Name: c.types.Get(id),
}))
}
+28
View File
@@ -0,0 +1,28 @@
package mtproto
import (
"testing"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
)
func BenchmarkConn_logWithType(b *testing.B) {
c := Conn{
log: zap.NewNop(),
types: tmap.New(map[uint32]string{
0x3fedd339: "true",
}),
}
buf := bin.Buffer{}
buf.PutID(0x3fedd339)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.logWithType(&buf).Info("Hi!")
}
}
+8
View File
@@ -0,0 +1,8 @@
//go:build !fuzz
// +build !fuzz
package mtproto
import "go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
type Zero = testutil.ZeroRand