move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
+26
View File
@@ -0,0 +1,26 @@
package exchange
import (
"io"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
)
// ClientExchange is a client-side key exchange flow.
type ClientExchange struct {
unencryptedWriter
rand io.Reader
log *zap.Logger
keys []PublicKey
dc int
}
// ClientExchangeResult contains client part of key exchange result.
type ClientExchangeResult struct {
AuthKey crypto.AuthKey
SessionID int64
ServerSalt int64
}
+278
View File
@@ -0,0 +1,278 @@
package exchange
import (
"context"
"crypto/rand"
"math/big"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
)
// Run runs client-side flow.
func (c ClientExchange) Run(ctx context.Context) (ClientExchangeResult, error) {
// 1. DH exchange initiation.
nonce, err := crypto.RandInt128(c.rand)
if err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "client nonce generation")
}
b := new(bin.Buffer)
c.log.Debug("Sending ReqPqMultiRequest")
if err := c.writeUnencrypted(ctx, b, &mt.ReqPqMultiRequest{Nonce: nonce}); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "write ReqPqMultiRequest")
}
// 2. Server sends response of the form
// resPQ#05162463 nonce:int128 server_nonce:int128 pq:string server_public_key_fingerprints:Vector long = ResPQ;
var res mt.ResPQ
if err := c.readUnencrypted(ctx, b, &res); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "read ResPQ response")
}
c.log.Debug("Received server ResPQ")
if res.Nonce != nonce {
return ClientExchangeResult{}, errors.New("ResPQ nonce mismatch")
}
serverNonce := res.ServerNonce
// Selecting first public key that match fingerprint.
var selectedPubKey PublicKey
Loop:
for _, key := range c.keys {
f := key.Fingerprint()
for _, fingerprint := range res.ServerPublicKeyFingerprints {
if fingerprint == f {
selectedPubKey = key
break Loop
}
}
}
if selectedPubKey.Zero() {
return ClientExchangeResult{}, ErrKeyFingerprintNotFound
}
// The pq is a representation of a natural number (in binary big endian format).
// SetBytes is also big endian.
pq := big.NewInt(0).SetBytes(res.Pq)
// Normally pq is less than or equal to 2^63-1.
pqMax := big.NewInt(0).Exp(big.NewInt(2), big.NewInt(63), nil)
if pq.Cmp(pqMax) > 0 {
return ClientExchangeResult{}, errors.New("server provided bad pq")
}
start := c.clock.Now()
// 3. Client decomposes pq into prime factors such that p < q.
// Performing proof of work.
p, q, err := crypto.DecomposePQ(pq, c.rand)
if err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "decompose pq")
}
c.log.Debug("PQ decomposing complete", zap.Duration("took", c.clock.Now().Sub(start)))
// Make a copy of p and q values to reduce allocations.
pBytes := p.Bytes()
qBytes := q.Bytes()
// 4. Client sends query to server.
// req_DH_params#d712e4be nonce:int128 server_nonce:int128 p:string q:string
// public_key_fingerprint:long encrypted_data:string = Server_DH_Params
newNonce, err := crypto.RandInt256(c.rand)
if err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "generate new nonce")
}
var encryptedData []byte
pqInnerData := &mt.PQInnerDataDC{
Pq: res.Pq,
Nonce: nonce,
NewNonce: newNonce,
ServerNonce: serverNonce,
P: pBytes,
Q: qBytes,
DC: c.dc,
}
b.Reset()
if err := pqInnerData.Encode(b); err != nil {
return ClientExchangeResult{}, err
}
// `encrypted_data := RSA_PAD(data, server_public_key);`
data, err := crypto.RSAPad(b.Buf, selectedPubKey.RSA, c.rand)
if err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "encrypted_data generation")
}
encryptedData = data
reqDHParams := &mt.ReqDHParamsRequest{
Nonce: nonce,
ServerNonce: serverNonce,
P: pBytes,
Q: qBytes,
PublicKeyFingerprint: selectedPubKey.Fingerprint(),
EncryptedData: encryptedData,
}
c.log.Debug("Sending ReqDHParamsRequest")
if err := c.writeUnencrypted(ctx, b, reqDHParams); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "write ReqDHParamsRequest")
}
// 5. Server responds with Server_DH_Params.
if err := c.conn.Recv(ctx, b); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "read ServerDHParams message")
}
c.log.Debug("Received server ServerDHParams")
var plaintextMsg proto.UnencryptedMessage
if err := plaintextMsg.Decode(b); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "decode ServerDHParams message")
}
b.ResetTo(plaintextMsg.MessageData)
dhParams, err := mt.DecodeServerDHParams(b)
if err != nil {
return ClientExchangeResult{}, err
}
switch p := dhParams.(type) {
case *mt.ServerDHParamsOk:
// Success.
if p.Nonce != nonce {
return ClientExchangeResult{}, errors.New("ServerDHParamsOk nonce mismatch")
}
if p.ServerNonce != serverNonce {
return ClientExchangeResult{}, errors.New("ServerDHParamsOk server nonce mismatch")
}
key, iv := crypto.TempAESKeys(newNonce.BigInt(), serverNonce.BigInt())
// Decrypting inner data.
data, err := crypto.DecryptExchangeAnswer(p.EncryptedAnswer, key, iv)
if err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "exchange answer decrypt")
}
b.ResetTo(data)
innerData := mt.ServerDHInnerData{}
if err := innerData.Decode(b); err != nil {
return ClientExchangeResult{}, err
}
if innerData.Nonce != nonce {
return ClientExchangeResult{}, errors.New("ServerDHInnerData nonce mismatch")
}
if innerData.ServerNonce != serverNonce {
return ClientExchangeResult{}, errors.New("ServerDHInnerData server nonce mismatch")
}
dhPrime := big.NewInt(0).SetBytes(innerData.DhPrime)
g := big.NewInt(int64(innerData.G))
if err := crypto.CheckDH(innerData.G, dhPrime); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "check DH params")
}
gA := big.NewInt(0).SetBytes(innerData.GA)
// 6. Random number b is computed:
randMax := big.NewInt(0).SetBit(big.NewInt(0), crypto.RSAKeyBits, 1)
bParam, err := rand.Int(c.rand, randMax)
if err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "number b generation")
}
// g_b = g^b mod dh_prime
gB := big.NewInt(0).Exp(g, bParam, dhPrime)
// Checking key exchange parameters.
if err := crypto.CheckDHParams(dhPrime, g, gA, gB); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "key exchange failed: invalid params")
}
clientInnerData := mt.ClientDHInnerData{
ServerNonce: innerData.ServerNonce,
Nonce: innerData.Nonce,
GB: gB.Bytes(),
// first attempt
RetryID: 0,
}
b.Reset()
if err := clientInnerData.Encode(b); err != nil {
return ClientExchangeResult{}, err
}
clientEncrypted, err := crypto.EncryptExchangeAnswer(c.rand, b.Buf, key, iv)
if err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "exchange answer encrypt")
}
setParamsReq := &mt.SetClientDHParamsRequest{
Nonce: nonce,
ServerNonce: reqDHParams.ServerNonce,
EncryptedData: clientEncrypted,
}
c.log.Debug("Sending SetClientDHParamsRequest")
if err := c.writeUnencrypted(ctx, b, setParamsReq); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "write SetClientDHParamsRequest")
}
// 7. Computing auth_key using formula (g_a)^b mod dh_prime
authKey := big.NewInt(0).Exp(gA, bParam, dhPrime)
b.Reset()
if err := c.conn.Recv(ctx, b); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "read DhGen message")
}
c.log.Debug("Received server DhGen")
if err := plaintextMsg.Decode(b); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "decode DhGen message")
}
b.ResetTo(plaintextMsg.MessageData)
dhSetRes, err := mt.DecodeSetClientDHParamsAnswer(b)
if err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "decode DhGen answer")
}
switch v := dhSetRes.(type) {
case *mt.DhGenOk: // dh_gen_ok#3bcbf734
if v.Nonce != nonce {
return ClientExchangeResult{}, errors.New("DhGenOk nonce mismatch")
}
if v.ServerNonce != serverNonce {
return ClientExchangeResult{}, errors.New("DhGenOk server nonce mismatch")
}
var key crypto.Key
authKey.FillBytes(key[:])
authKeyID := key.ID()
// Checking received hash.
nonceHash1 := crypto.NonceHash1(newNonce, key)
serverSalt := crypto.ServerSalt(newNonce, v.ServerNonce)
if nonceHash1 != v.NewNonceHash1 {
return ClientExchangeResult{}, errors.New("key exchange verification failed: hash mismatch")
}
// Generating new session id and salt.
sessionID, err := crypto.NewSessionID(c.rand)
if err != nil {
return ClientExchangeResult{}, err
}
return ClientExchangeResult{
AuthKey: crypto.AuthKey{Value: key, ID: authKeyID},
SessionID: sessionID,
ServerSalt: serverSalt,
}, nil
case *mt.DhGenRetry: // dh_gen_retry#46dc1fb9
return ClientExchangeResult{}, errors.Errorf("retry required: %x", v.NewNonceHash2)
case *mt.DhGenFail: // dh_gen_fail#a69dae02
return ClientExchangeResult{}, errors.Errorf("dh_hen_fail: %x", v.NewNonceHash3)
default:
return ClientExchangeResult{}, errors.Errorf("unexpected SetClientDHParamsRequest result %T", v)
}
case *mt.ServerDHParamsFail:
return ClientExchangeResult{}, errors.New("server respond with server_DH_params_fail")
default:
return ClientExchangeResult{}, errors.Errorf("unexpected ReqDHParamsRequest result %T", p)
}
}
+52
View File
@@ -0,0 +1,52 @@
package exchange
import (
"context"
"crypto/rsa"
"math/rand"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
func TestExchangeTimeout(t *testing.T) {
a := require.New(t)
reader := rand.New(rand.NewSource(1))
key, err := rsa.GenerateKey(reader, crypto.RSAKeyBits)
a.NoError(err)
log := zaptest.NewLogger(t)
i := transport.Intermediate
client, _ := i.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
g := tdsync.NewCancellableGroup(ctx)
g.Go(func(ctx context.Context) error {
_, err := NewExchanger(client, 2).
WithLogger(log.Named("client")).
WithRand(reader).
WithTimeout(1 * time.Second).
Client([]PublicKey{
{
RSA: &key.PublicKey,
},
}).
Run(ctx)
return err
})
err = g.Wait()
var e net.Error
a.ErrorAs(err, &e)
a.True(e.Timeout())
}
+9
View File
@@ -0,0 +1,9 @@
package exchange
import (
"github.com/go-faster/errors"
)
// ErrKeyFingerprintNotFound is returned when client can't find keys by fingerprints
// provided by server during key exchange.
var ErrKeyFingerprintNotFound = errors.New("key fingerprint not found")
+105
View File
@@ -0,0 +1,105 @@
// Package exchange contains Telegram key exchange algorithm flows.
// See https://core.telegram.org/mtproto/auth_key.
package exchange
import (
"io"
"time"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
// DefaultTimeout is default WithTimeout parameter value.
const DefaultTimeout = 1 * time.Minute
// Exchanger is builder for key exchangers.
type Exchanger struct {
conn transport.Conn
clock clock.Clock
rand io.Reader
log *zap.Logger
timeout time.Duration
dc int
}
// WithClock sets exchange flow clock.
func (e Exchanger) WithClock(c clock.Clock) Exchanger {
e.clock = c
return e
}
// WithRand sets exchange flow random source.
func (e Exchanger) WithRand(reader io.Reader) Exchanger {
e.rand = reader
return e
}
// WithLogger sets exchange flow logger.
func (e Exchanger) WithLogger(log *zap.Logger) Exchanger {
e.log = log
return e
}
// WithTimeout sets write/read deadline of every exchange request.
func (e Exchanger) WithTimeout(timeout time.Duration) Exchanger {
e.timeout = timeout
return e
}
// NewExchanger creates new Exchanger.
func NewExchanger(conn transport.Conn, dc int) Exchanger {
return Exchanger{
conn: conn,
clock: clock.System,
rand: crypto.DefaultRand(),
log: zap.NewNop(),
timeout: DefaultTimeout,
dc: dc,
}
}
func (e Exchanger) unencryptedWriter(input, output proto.MessageType) unencryptedWriter {
return unencryptedWriter{
clock: e.clock,
conn: e.conn,
timeout: e.timeout,
input: input,
output: output,
}
}
// Client creates new ClientExchange using parameters from Exchanger.
func (e Exchanger) Client(keys []PublicKey) ClientExchange {
return ClientExchange{
unencryptedWriter: e.unencryptedWriter(
proto.MessageServerResponse,
proto.MessageFromClient,
),
rand: e.rand,
log: e.log,
keys: keys,
dc: e.dc,
}
}
// Server creates new ServerExchange using parameters from Exchanger.
func (e Exchanger) Server(key PrivateKey) ServerExchange {
return ServerExchange{
unencryptedWriter: e.unencryptedWriter(
proto.MessageFromClient,
proto.MessageServerResponse,
),
rand: e.rand,
log: e.log,
rng: TestServerRNG{rand: e.rand},
key: key,
dc: e.dc,
}
}
+114
View File
@@ -0,0 +1,114 @@
package exchange
import (
"context"
"crypto/rsa"
"fmt"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"golang.org/x/sync/errgroup"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
func testExchange(rsaPad bool) func(t *testing.T) {
return func(t *testing.T) {
a := require.New(t)
log := zaptest.NewLogger(t)
dc := 2
reader := rand.New(rand.NewSource(1))
key, err := rsa.GenerateKey(reader, crypto.RSAKeyBits)
a.NoError(err)
privateKey := PrivateKey{
RSA: key,
}
i := transport.Intermediate
client, server := i.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
g := tdsync.NewCancellableGroup(ctx)
g.Go(func(ctx context.Context) error {
_, err := NewExchanger(client, dc).
WithLogger(log.Named("client")).
WithRand(reader).
Client([]PublicKey{privateKey.Public()}).
Run(ctx)
return err
})
g.Go(func(ctx context.Context) error {
_, err := NewExchanger(server, dc).
WithLogger(log.Named("server")).
WithRand(reader).
Server(privateKey).
Run(ctx)
return err
})
a.NoError(g.Wait())
}
}
func TestExchange(t *testing.T) {
t.Run("PQInnerData", testExchange(false))
t.Run("PQInnerDataDC", testExchange(true))
}
func TestExchangeCorpus(t *testing.T) {
privateKey := PrivateKey{
RSA: testutil.RSAPrivateKey(),
}
for i, seed := range []string{
"\xef\x00\x04",
} {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
dc := 2
reader := testutil.Rand([]byte(seed))
log := zaptest.NewLogger(t)
i := transport.Intermediate
client, server := i.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
g, gctx := errgroup.WithContext(ctx)
g.Go(func() error {
_, err := NewExchanger(client, dc).
WithLogger(log.Named("client")).
WithRand(reader).
Client([]PublicKey{privateKey.Public()}).
Run(gctx)
if err != nil {
cancel()
}
return err
})
g.Go(func() error {
_, err := NewExchanger(server, dc).
WithLogger(log.Named("server")).
WithRand(reader).
Server(privateKey).
Run(gctx)
if err != nil {
cancel()
}
return err
})
require.NoError(t, g.Wait())
})
}
}
+67
View File
@@ -0,0 +1,67 @@
//go:build go1.18
package exchange
import (
"context"
"testing"
"time"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
func FuzzValid(f *testing.F) {
f.Add([]byte{1, 2, 3})
f.Fuzz(func(t *testing.T, data []byte) {
const dc = 2
reader := testutil.Rand(data)
privateKey := PrivateKey{
RSA: testutil.RSAPrivateKey(),
}
config := zap.NewProductionConfig()
config.OutputPaths = []string{"stdout"}
log, err := config.Build()
if err != nil {
t.Fatal(err)
}
i := transport.Intermediate
client, server := i.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
g, gctx := errgroup.WithContext(ctx)
g.Go(func() error {
_, err := NewExchanger(client, dc).
WithLogger(log.Named("client")).
WithRand(reader).
Client([]PublicKey{privateKey.Public()}).
Run(gctx)
if err != nil {
cancel()
}
return err
})
g.Go(func() error {
_, err := NewExchanger(server, dc).
WithLogger(log.Named("server")).
WithRand(reader).
Server(privateKey).
Run(gctx)
if err != nil {
cancel()
}
return err
})
if err := g.Wait(); err != nil {
t.Fatal(err)
}
})
}
+81
View File
@@ -0,0 +1,81 @@
package exchange
import (
"crypto/rand"
"encoding/hex"
"io"
"math/big"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
)
// ServerRNG is server-side random number generator.
type ServerRNG interface {
PQ() (pq *big.Int, err error)
GA(g int, dhPrime *big.Int) (a, ga *big.Int, err error)
DhPrime() (p *big.Int, err error)
}
var _ ServerRNG = TestServerRNG{}
// TestServerRNG implements testing-only ServerRNG.
type TestServerRNG struct {
rand io.Reader
}
func (s TestServerRNG) bigFromHex(hexString string) (p *big.Int, err error) {
data, err := hex.DecodeString(hexString)
if err != nil {
return nil, errors.Wrap(err, "decode hex string")
}
return big.NewInt(0).SetBytes(data), nil
}
// PQ always returns testing pq value.
//
// nolint:unparam
func (s TestServerRNG) PQ() (pq *big.Int, err error) {
return big.NewInt(0x17ED48941A08F981), nil
}
// GA returns testing a and g_a params.
func (s TestServerRNG) GA(g int, dhPrime *big.Int) (a, ga *big.Int, err error) {
if err := crypto.CheckGP(g, dhPrime); err != nil {
return nil, nil, err
}
gBig := big.NewInt(int64(g))
one := big.NewInt(1)
dhPrimeMinusOne := big.NewInt(0).Sub(dhPrime, one)
safetyRangeMin := big.NewInt(0).Exp(big.NewInt(2), big.NewInt(crypto.RSAKeyBits-64), nil)
safetyRangeMax := big.NewInt(0).Sub(dhPrime, safetyRangeMin)
randMax := big.NewInt(0).SetBit(big.NewInt(0), crypto.RSAKeyBits, 1)
for {
a, err = rand.Int(s.rand, randMax)
if err != nil {
return
}
ga = big.NewInt(0).Exp(gBig, a, dhPrime)
if crypto.InRange(ga, one, dhPrimeMinusOne) && crypto.InRange(ga, safetyRangeMin, safetyRangeMax) {
return
}
}
}
// DhPrime always returns testing dh_prime.
func (s TestServerRNG) DhPrime() (p *big.Int, err error) {
return s.bigFromHex("C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F" +
"48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37" +
"20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64" +
"2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4" +
"A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754" +
"FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4" +
"E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F" +
"0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B")
}
+46
View File
@@ -0,0 +1,46 @@
package exchange
import (
"crypto/rsa"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
)
// PublicKey is a public Telegram server key.
type PublicKey struct {
// RSA public key.
RSA *rsa.PublicKey
}
// Zero denotes that current PublicKey is zero value.
func (k PublicKey) Zero() bool {
return k.RSA == nil
}
// Fingerprint computes key fingerprint.
func (k PublicKey) Fingerprint() int64 {
return crypto.RSAFingerprint(k.RSA)
}
// PrivateKey is a private Telegram server key.
type PrivateKey struct {
// RSA private key.
RSA *rsa.PrivateKey
}
// Zero denotes that current PublicKey is zero value.
func (k PrivateKey) Zero() bool {
return k.RSA == nil
}
// Fingerprint computes key fingerprint.
func (k PrivateKey) Fingerprint() int64 {
return crypto.RSAFingerprint(&k.RSA.PublicKey)
}
// Public returns PublicKey of this PrivateKey pair.
func (k PrivateKey) Public() PublicKey {
return PublicKey{
RSA: &k.RSA.PublicKey,
}
}
+93
View File
@@ -0,0 +1,93 @@
package exchange
import (
"context"
"time"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
type unencryptedWriter struct {
clock clock.Clock
conn transport.Conn
timeout time.Duration
input proto.MessageType
output proto.MessageType
}
func (w unencryptedWriter) writeUnencrypted(ctx context.Context, b *bin.Buffer, data bin.Encoder) error {
b.Reset()
if err := data.Encode(b); err != nil {
return err
}
msg := proto.UnencryptedMessage{
MessageID: int64(proto.NewMessageID(w.clock.Now(), w.output)),
MessageData: b.Copy(),
}
b.Reset()
if err := msg.Encode(b); err != nil {
return err
}
ctx, cancel := context.WithTimeout(ctx, w.timeout)
defer cancel()
return w.conn.Send(ctx, b)
}
func (w unencryptedWriter) tryRead(ctx context.Context, b *bin.Buffer) error {
ctx, cancel := context.WithTimeout(ctx, w.timeout)
defer cancel()
if err := w.conn.Recv(ctx, b); err != nil {
return err
}
return nil
}
func (w unencryptedWriter) isClient() bool {
return w.output == proto.MessageFromClient
}
func (w unencryptedWriter) readUnencrypted(ctx context.Context, b *bin.Buffer, data bin.Decoder) error {
b.Reset()
for {
if err := w.tryRead(ctx, b); err != nil {
var protocolErr *codec.ProtocolErr
if w.isClient() &&
errors.As(err, &protocolErr) &&
protocolErr.Code == codec.CodeAuthKeyNotFound {
continue
}
return err
}
break
}
var msg proto.UnencryptedMessage
if err := msg.Decode(b); err != nil {
return err
}
if err := w.checkMsgID(msg.MessageID); err != nil {
return err
}
b.ResetTo(msg.MessageData)
return data.Decode(b)
}
func (w unencryptedWriter) checkMsgID(id int64) error {
if proto.MessageID(id).Type() != w.input {
return errors.New("bad msg type")
}
return nil
}
+26
View File
@@ -0,0 +1,26 @@
package exchange
import (
"io"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
)
// ServerExchange is a server-side key exchange flow.
type ServerExchange struct {
unencryptedWriter
rand io.Reader
log *zap.Logger
rng ServerRNG
key PrivateKey
dc int
}
// ServerExchangeResult contains server part of key exchange result.
type ServerExchangeResult struct {
Key crypto.AuthKey
ServerSalt int64
}
+273
View File
@@ -0,0 +1,273 @@
package exchange
import (
"context"
"math/big"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
)
// ServerExchangeError is returned when exchange fails due to
// some security or validation checks.
type ServerExchangeError struct {
Code int32
Err error
}
// Error implements error.
func (s *ServerExchangeError) Error() string {
return s.Err.Error()
}
// Unwrap implements error wrapper interface.
func (s *ServerExchangeError) Unwrap() error {
return s.Err
}
func serverError(code int32, err error) error {
return &ServerExchangeError{
Code: code,
Err: err,
}
}
// req_pq#60469778 or req_pq_multi#be7e8ef1
type reqPQ struct {
Type uint32
Nonce bin.Int128
}
func (r *reqPQ) Decode(b *bin.Buffer) error {
var (
legacy mt.ReqPqRequest
multi mt.ReqPqMultiRequest
)
id, err := b.PeekID()
if err != nil {
return err
}
r.Type = id
switch id {
case legacy.TypeID():
if err := legacy.Decode(b); err != nil {
return err
}
r.Nonce = legacy.Nonce
return nil
case multi.TypeID():
if err := multi.Decode(b); err != nil {
return err
}
r.Nonce = multi.Nonce
return nil
default:
return bin.NewUnexpectedID(id)
}
}
type reqOrDH struct {
Type uint32
DH mt.ReqDHParamsRequest
Req reqPQ
}
func (r *reqOrDH) Decode(b *bin.Buffer) error {
id, err := b.PeekID()
if err != nil {
return err
}
r.Type = id
switch id {
case r.DH.TypeID():
return r.DH.Decode(b)
default:
return r.Req.Decode(b)
}
}
// Run runs server-side flow.
// If b parameter is not nil, it will be used as first read message.
// Otherwise, it will be read from connection.
func (s ServerExchange) Run(ctx context.Context) (ServerExchangeResult, error) {
wrapKeyNotFound := func(err error) error {
return serverError(codec.CodeAuthKeyNotFound, err)
}
// 1. Client sends query to server
// req_pq#60469778 nonce:int128 = ResPQ; // legacy
// req_pq_multi#be7e8ef1 nonce:int128 = ResPQ;
var req reqPQ
b := new(bin.Buffer)
if err := s.readUnencrypted(ctx, b, &req); err != nil {
return ServerExchangeResult{}, err
}
s.log.Debug("Received client ReqPqMultiRequest")
serverNonce, err := crypto.RandInt128(s.rand)
if err != nil {
return ServerExchangeResult{}, errors.Wrap(err, "generate server nonce")
}
// 2. Server sends response of the form
//
// resPQ#05162463 nonce:int128 server_nonce:int128 pq:string server_public_key_fingerprints:Vector long = ResPQ;
pq, err := s.rng.PQ()
if err != nil {
return ServerExchangeResult{}, errors.Wrap(err, "generate pq")
}
SendResPQ:
s.log.Debug("Sending ResPQ", zap.String("pq", pq.String()))
if err := s.writeUnencrypted(ctx, b, &mt.ResPQ{
Pq: pq.Bytes(),
Nonce: req.Nonce,
ServerNonce: serverNonce,
ServerPublicKeyFingerprints: []int64{
s.key.Fingerprint(),
},
}); err != nil {
return ServerExchangeResult{}, err
}
// 4. Client sends query to server
//
// req_DH_params#d712e4be nonce:int128 server_nonce:int128 p:string
// q:string public_key_fingerprint:long encrypted_data:string = Server_DH_Params
var dhParams reqOrDH
if err := s.readUnencrypted(ctx, b, &dhParams); err != nil {
return ServerExchangeResult{}, err
}
switch dhParams.Type {
case mt.ReqPqRequestTypeID, mt.ReqPqMultiRequestTypeID:
// Client can send fake req_pq on start. Ignore it.
//
// Next one should be not fake.
s.log.Debug("Received ReqPQ again")
req = dhParams.Req
goto SendResPQ
default:
s.log.Debug("Received client ReqDHParamsRequest")
}
var innerData mt.PQInnerData
{
r, err := crypto.DecodeRSAPad(dhParams.DH.EncryptedData, s.key.RSA)
if err != nil {
return ServerExchangeResult{}, wrapKeyNotFound(err)
}
b.ResetTo(r)
d, err := mt.DecodePQInnerData(b)
if err != nil {
return ServerExchangeResult{}, err
}
if innerDataDC, ok := d.(*mt.PQInnerDataDC); ok && innerDataDC.DC != s.dc {
err := errors.Errorf(
"wrong DC ID, want %d, got %d",
s.dc, innerDataDC.DC,
)
return ServerExchangeResult{}, serverError(codec.CodeWrongDC, err)
}
innerData = mt.PQInnerData{
Pq: d.GetPq(),
P: d.GetP(),
Q: d.GetQ(),
Nonce: d.GetNonce(),
ServerNonce: d.GetServerNonce(),
NewNonce: d.GetNewNonce(),
}
}
dhPrime, err := s.rng.DhPrime()
if err != nil {
return ServerExchangeResult{}, errors.Wrap(err, "generate dh_prime")
}
g := 3
a, ga, err := s.rng.GA(g, dhPrime)
if err != nil {
return ServerExchangeResult{}, errors.Wrap(err, "generate g_a")
}
data := mt.ServerDHInnerData{
Nonce: req.Nonce,
ServerNonce: serverNonce,
G: g,
GA: ga.Bytes(),
DhPrime: dhPrime.Bytes(),
ServerTime: int(s.clock.Now().Unix()),
}
b.Reset()
if err := data.Encode(b); err != nil {
return ServerExchangeResult{}, err
}
key, iv := crypto.TempAESKeys(innerData.NewNonce.BigInt(), serverNonce.BigInt())
answer, err := crypto.EncryptExchangeAnswer(s.rand, b.Raw(), key, iv)
if err != nil {
return ServerExchangeResult{}, err
}
s.log.Debug("Sending ServerDHParamsOk", zap.Int("g", g))
// 5. Server responds with Server_DH_Params.
if err := s.writeUnencrypted(ctx, b, &mt.ServerDHParamsOk{
Nonce: req.Nonce,
ServerNonce: serverNonce,
EncryptedAnswer: answer,
}); err != nil {
return ServerExchangeResult{}, err
}
var clientDhParams mt.SetClientDHParamsRequest
if err := s.readUnencrypted(ctx, b, &clientDhParams); err != nil {
return ServerExchangeResult{}, err
}
s.log.Debug("Received client SetClientDHParamsRequest")
decrypted, err := crypto.DecryptExchangeAnswer(clientDhParams.EncryptedData, key, iv)
if err != nil {
err = errors.Wrap(err, "decrypt exchange answer")
return ServerExchangeResult{}, wrapKeyNotFound(err)
}
b.ResetTo(decrypted)
var clientInnerData mt.ClientDHInnerData
if err := clientInnerData.Decode(b); err != nil {
return ServerExchangeResult{}, wrapKeyNotFound(err)
}
gB := big.NewInt(0).SetBytes(clientInnerData.GB)
var authKey crypto.Key
if !crypto.FillBytes(big.NewInt(0).Exp(gB, a, dhPrime), authKey[:]) {
err := errors.New("auth_key is too big")
return ServerExchangeResult{}, wrapKeyNotFound(err)
}
// DH key exchange complete
// 8. Server responds in one of three ways:
// dh_gen_ok#3bcbf734 nonce:int128 server_nonce:int128
// new_nonce_hash1:int128 = Set_client_DH_params_answer;
s.log.Debug("Sending DhGenOk")
if err := s.writeUnencrypted(ctx, b, &mt.DhGenOk{
Nonce: req.Nonce,
ServerNonce: serverNonce,
NewNonceHash1: crypto.NonceHash1(innerData.NewNonce, authKey),
}); err != nil {
return ServerExchangeResult{}, err
}
serverSalt := crypto.ServerSalt(innerData.NewNonce, serverNonce)
return ServerExchangeResult{
Key: authKey.WithID(),
ServerSalt: serverSalt,
}, nil
}