move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
+145
View File
@@ -0,0 +1,145 @@
package codec
import (
"encoding/binary"
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// AbridgedClientStart is starting bytes sent by client in Abridged mode.
//
// Note that server does not respond with it.
var AbridgedClientStart = [1]byte{0xef}
// Abridged is intermediate MTProto transport.
//
// See https://core.telegram.org/mtproto/mtproto-transports#abridged
type Abridged struct{}
var (
_ TaggedCodec = Abridged{}
)
// WriteHeader sends protocol tag.
func (i Abridged) WriteHeader(w io.Writer) error {
if _, err := w.Write(AbridgedClientStart[:]); err != nil {
return errors.Wrap(err, "write abridged header")
}
return nil
}
// ReadHeader reads protocol tag.
func (i Abridged) ReadHeader(r io.Reader) error {
var b [1]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return errors.Wrap(err, "read abridged header")
}
if b != AbridgedClientStart {
return ErrProtocolHeaderMismatch
}
return nil
}
// ObfuscatedTag returns protocol tag for obfuscation.
func (i Abridged) ObfuscatedTag() (r [4]byte) {
d := AbridgedClientStart[0]
return [4]byte{d, d, d, d}
}
// Write encode to writer message from given buffer.
func (i Abridged) Write(w io.Writer, b *bin.Buffer) error {
if err := checkOutgoingMessage(b); err != nil {
return err
}
if err := checkAlign(b, 4); err != nil {
return err
}
if err := writeAbridged(w, b); err != nil {
return errors.Wrap(err, "write abridged")
}
return nil
}
// Read fills buffer with received message.
func (i Abridged) Read(r io.Reader, b *bin.Buffer) error {
if err := readAbridged(r, b); err != nil {
return errors.Wrap(err, "read abridged")
}
return checkProtocolError(b)
}
func writeAbridged(w io.Writer, b *bin.Buffer) error {
length := b.Len()
// Re-using b.Buf if possible to reduce allocations.
b.Expand(4)
b.Buf = b.Buf[:length]
// Re-using b.Buf if possible to reduce allocations.
inner := bin.Buffer{Buf: b.Buf[length:length]}
encodeLength := b.Len() >> 2
// `0x7f == 127`, literally use one bit to distinguish length byte size.
if encodeLength < 127 {
// Payloads are wrapped in the following envelope:
//
// Length: payload length, divided by four, and encoded as a single byte,
// only if the resulting packet length is a value between 0x01..0x7e.
inner.Put([]byte{byte(encodeLength)})
} else {
// If the packet length divided by four is bigger than or equal to 127 (>= 0x7f),
// the following envelope must be used, instead:
//
var buf [5]byte
// Header: A single byte of value 0x7f
buf[0] = 0x7f
// Length: payload length, divided by four, and encoded as 3 length bytes (little endian)
binary.LittleEndian.PutUint32(buf[1:], uint32(encodeLength))
inner.Put(buf[:4])
}
if _, err := w.Write(inner.Buf); err != nil {
return err
}
if _, err := w.Write(b.Raw()); err != nil {
return err
}
return nil
}
func readAbridged(r io.Reader, b *bin.Buffer) error {
b.ResetN(bin.Word)
_, err := io.ReadFull(r, b.Buf[:1])
if err != nil {
return err
}
if b.Buf[0] >= 127 {
_, err := io.ReadFull(r, b.Buf[0:3])
if err != nil {
return err
}
}
n, err := b.Int()
if err != nil {
return err
}
b.ResetN(n << 2)
if _, err := io.ReadFull(r, b.Buf); err != nil {
return errors.Wrap(err, "read payload")
}
return nil
}
+69
View File
@@ -0,0 +1,69 @@
package codec
import (
"bytes"
"encoding/binary"
"strings"
"testing"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"github.com/stretchr/testify/require"
)
func TestAbridged(t *testing.T) {
bigHeader := func(l int) (packet []byte) {
// header + 3 bytes of LE
var buf [4]byte
binary.LittleEndian.PutUint32(buf[:], uint32(l>>2))
packet = append([]byte{127}, buf[0:3]...)
return
}
tests := []struct {
payloadName string
testData func() (payload string, packet []byte)
}{
{"Small-4b", func() (payload string, packet []byte) {
payload = "abcd"
packet = append([]byte{byte(len(payload) >> 2)}, payload...)
return
}},
{"Medium-124b", func() (payload string, packet []byte) {
payload = strings.Repeat("a", 124)
packet = append([]byte{byte(len(payload) >> 2)}, payload...)
return
}},
{"Big-1kb", func() (payload string, packet []byte) {
payload = strings.Repeat("a", 1024)
packet = bigHeader(len(payload))
require.Equal(t, []byte{127, 0, 1, 0}, packet)
packet = append(packet, payload...)
return
}},
}
for _, test := range tests {
payload, packet := test.testData()
t.Run(test.payloadName, func(t *testing.T) {
t.Run("Write", func(t *testing.T) {
b := bytes.NewBuffer(nil)
err := writeAbridged(b, &bin.Buffer{Buf: []byte(payload)})
require.NoError(t, err)
require.Equal(t, packet, b.Bytes())
})
t.Run("Read", func(t *testing.T) {
b := &bin.Buffer{}
err := readAbridged(bytes.NewReader(packet), b)
require.NoError(t, err)
require.Equal(t, payload, string(b.Raw()))
})
})
}
}
+86
View File
@@ -0,0 +1,86 @@
package codec
import (
"bytes"
"crypto/rand"
"io"
"testing"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
func benchWrite(codec Codec) func(payloadSize int) func(b *testing.B) {
return func(payloadSize int) func(b *testing.B) {
return func(b *testing.B) {
buf := bin.Buffer{Buf: make([]byte, payloadSize)}
if _, err := io.ReadFull(rand.Reader, buf.Buf); err != nil {
b.Fatal(err)
}
b.ReportAllocs()
b.SetBytes(int64(buf.Len() + 4))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := codec.Write(io.Discard, &buf); err != nil {
b.Fatal(err)
}
}
}
}
}
func BenchmarkWrite(b *testing.B) {
b.Run("Abridged", func(b *testing.B) {
testutil.RunPayloads(b, benchWrite(Abridged{}))
})
b.Run("Intermediate", func(b *testing.B) {
testutil.RunPayloads(b, benchWrite(Intermediate{}))
})
b.Run("PaddedIntermediate", func(b *testing.B) {
testutil.RunPayloads(b, benchWrite(PaddedIntermediate{}))
})
}
func benchRead(codec Codec) func(payloadSize int) func(b *testing.B) {
return func(payloadSize int) func(b *testing.B) {
return func(b *testing.B) {
buf := bin.Buffer{Buf: make([]byte, payloadSize)}
if _, err := io.ReadFull(rand.Reader, buf.Buf); err != nil {
b.Fatal(err)
}
out := new(bytes.Buffer)
if err := codec.Write(out, &buf); err != nil {
b.Fatal(err)
}
raw := out.Bytes()
reader := bytes.NewReader(nil)
b.ReportAllocs()
b.SetBytes(int64(buf.Len() + 4))
b.ResetTimer()
for i := 0; i < b.N; i++ {
reader.Reset(raw)
if err := codec.Read(reader, &buf); err != nil {
b.Fatal(err)
}
buf.Reset()
}
}
}
}
func BenchmarkRead(b *testing.B) {
b.Run("Abridged", func(b *testing.B) {
testutil.RunPayloads(b, benchRead(Abridged{}))
})
b.Run("Intermediate", func(b *testing.B) {
testutil.RunPayloads(b, benchRead(Intermediate{}))
})
b.Run("PaddedIntermediate", func(b *testing.B) {
testutil.RunPayloads(b, benchRead(PaddedIntermediate{}))
})
}
+44
View File
@@ -0,0 +1,44 @@
package codec
import (
"encoding/binary"
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// Codec is MTProto transport protocol encoding abstraction.
type Codec interface {
// WriteHeader sends protocol tag if needed.
WriteHeader(w io.Writer) error
// ReadHeader reads protocol tag if needed.
ReadHeader(r io.Reader) error
// Write encode to writer message from given buffer.
Write(w io.Writer, b *bin.Buffer) error
// Read fills buffer with received message.
Read(r io.Reader, b *bin.Buffer) error
}
// TaggedCodec is codec with protocol tag.
type TaggedCodec interface {
Codec
// ObfuscatedTag returns protocol tag for obfuscation.
ObfuscatedTag() [4]byte
}
// readLen reads 32-bit integer and validates it as message length.
func readLen(r io.Reader, b *bin.Buffer) (int, error) {
b.ResetN(bin.Word)
if _, err := io.ReadFull(r, b.Buf[:bin.Word]); err != nil {
return 0, errors.Wrap(err, "read length")
}
n := int(binary.LittleEndian.Uint32(b.Buf[:bin.Word]))
if n <= 0 || n > maxMessageSize {
return 0, invalidMsgLenErr{n: n}
}
return n, nil
}
+197
View File
@@ -0,0 +1,197 @@
package codec
import (
"bytes"
"encoding/binary"
"io"
"strings"
"testing"
"testing/iotest"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
type codecTest struct {
name string
align int
create func() Codec
}
func codecs() []codecTest {
return []codecTest{
{"Abridged", 4, func() Codec {
return Abridged{}
}},
{"Intermediate", 4, func() Codec {
return Intermediate{}
}},
{"PaddedIntermediate", 4, func() Codec {
return PaddedIntermediate{}
}},
{"Full", 0, func() Codec {
return &Full{}
}},
{"NoHeaderIntermediate", 0, func() Codec {
return NoHeader{Codec: Intermediate{}}
}},
}
}
type payload struct {
name string
testData string
mustFail bool
readTestOnly bool
}
func payloads() []payload {
var code [4]byte
binary.LittleEndian.PutUint32(code[:], CodeTransportFlood)
return []payload{
{"Empty", "", true, false},
{"Protocol error", string(code[:]), true, true},
{"Small 8b", "abcdabcd", false, false},
{"Medium 1kb", strings.Repeat("a", 1024), false, false},
}
}
func testGood(c codecTest, p payload) func(t *testing.T) {
return func(t *testing.T) {
t.Run("One message", func(t *testing.T) {
a := require.New(t)
codec := c.create()
buf := bytes.NewBuffer(nil)
payload := &bin.Buffer{Buf: []byte(p.testData)}
l := payload.Len()
// Encode
a.NoError(codec.Write(buf, payload))
a.Equal(l, payload.Len(), "Codec must not change buffer length")
// Decode
payload.Reset()
a.NoError(codec.Read(buf, payload))
a.Equal(p.testData, string(payload.Buf))
})
t.Run("Two messages", func(t *testing.T) {
a := require.New(t)
codec := c.create()
buf := bytes.NewBuffer(nil)
payload := &bin.Buffer{Buf: []byte(p.testData)}
l := payload.Len()
// Encode twice
a.NoError(codec.Write(buf, payload))
payload.ResetTo([]byte(p.testData))
a.NoError(codec.Write(buf, payload))
a.Equal(l, payload.Len(), "Codec must not change buffer length")
// Decode twice
payload.Reset()
a.NoError(codec.Read(buf, payload))
a.Equal(p.testData, string(payload.Buf))
payload.Reset()
a.NoError(codec.Read(buf, payload))
a.Equal(p.testData, string(payload.Buf))
})
}
}
func testBad(c codecTest, p payload) func(t *testing.T) {
return func(t *testing.T) {
if !p.readTestOnly {
t.Run("Write", func(t *testing.T) {
payload := &bin.Buffer{Buf: []byte(p.testData)}
require.Error(t, c.create().Write(io.Discard, payload))
})
}
t.Run("Read", func(t *testing.T) {
reader := bytes.NewBufferString(p.testData)
require.Error(t, c.create().Read(reader, &bin.Buffer{}))
})
}
}
func testHeaderTag(c codecTest) func(t *testing.T) {
e := io.ErrClosedPipe
return func(t *testing.T) {
t.Run("GoodTag", func(t *testing.T) {
a := require.New(t)
buf := bytes.NewBuffer(nil)
a.NoError(c.create().WriteHeader(buf))
a.NoError(c.create().ReadHeader(buf))
})
if tagged, ok := c.create().(TaggedCodec); ok {
t.Run("ReadError", func(t *testing.T) {
r := iotest.ErrReader(e)
require.ErrorIs(t, c.create().ReadHeader(r), e)
})
t.Run("WriteError", func(t *testing.T) {
w := testutil.ErrWriter(e)
require.ErrorIs(t, c.create().WriteHeader(w), e)
})
t.Run("BadTag", func(t *testing.T) {
tag := tagged.ObfuscatedTag()
buf := bytes.NewBuffer(tag[:])
tag[0] = 0
require.ErrorIs(t, c.create().ReadHeader(buf), ErrProtocolHeaderMismatch)
})
}
}
}
func testCodec(c codecTest) func(t *testing.T) {
e := io.ErrClosedPipe
return func(t *testing.T) {
t.Run("ReadError", func(t *testing.T) {
r := iotest.ErrReader(e)
require.ErrorIs(t, c.create().Read(r, &bin.Buffer{}), e)
})
t.Run("WriteError", func(t *testing.T) {
w := testutil.ErrWriter(e)
require.ErrorIs(t,
c.create().Write(
w,
&bin.Buffer{Buf: make([]byte, 16)},
),
e,
)
})
if c.align != 0 {
t.Run("AlignError", func(t *testing.T) {
require.Error(t,
c.create().Write(
io.Discard,
&bin.Buffer{Buf: make([]byte, c.align-1)},
),
)
})
}
}
}
func TestCodecs(t *testing.T) {
for _, c := range codecs() {
t.Run(c.name, func(t *testing.T) {
for _, p := range payloads() {
if p.mustFail {
t.Run(p.name, testBad(c, p))
} else {
t.Run(p.name, testGood(c, p))
}
}
t.Run("Header", testHeaderTag(c))
testCodec(c)(t)
})
}
}
+4
View File
@@ -0,0 +1,4 @@
// Package codec contains MTProto transport encoding implementations.
//
// See https://core.telegram.org/mtproto/mtproto-transports
package codec
+106
View File
@@ -0,0 +1,106 @@
package codec
import (
"fmt"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
const (
// CodeAuthKeyNotFound means that specified auth key ID cannot be found by the DC.
// Also, may be returned during key exchange.
CodeAuthKeyNotFound = 404
// CodeWrongDC means that current DC is wrong.
// Usually returned by server when key exchange sends wrong DC ID.
CodeWrongDC = 444
// CodeTransportFlood means that too many transport connections are
// established to the same IP in a too short lapse of time, or if any
// of the container/service message limits are reached.
CodeTransportFlood = 429
)
// ProtocolErr represents protocol level error.
type ProtocolErr struct {
Code int32
}
func (p ProtocolErr) Error() string {
switch p.Code {
case CodeAuthKeyNotFound:
return "auth key not found"
case CodeTransportFlood:
return "transport flood"
case CodeWrongDC:
return "wrong DC"
default:
return fmt.Sprintf("protocol error %d", p.Code)
}
}
// Can be bigger that 1mb.
//
// See https://github.com/gotd/td/issues/412
//
// See https://github.com/tdlib/td/blob/550ccc8d9bbbe9cff1dc618aef5764d2cbd2cd91/td/mtproto/TcpTransport.cpp#L53
const maxMessageSize = 1 << 24 // 16 MB
func checkOutgoingMessage(b *bin.Buffer) error {
length := b.Len()
if length > maxMessageSize || length == 0 {
return invalidMsgLenErr{n: length}
}
return nil
}
func checkAlign(b *bin.Buffer, n int) error {
length := b.Len()
if length%n != 0 {
return alignedPayloadExpectedErr{expected: n}
}
return nil
}
func checkProtocolError(b *bin.Buffer) error {
if b.Len() != bin.Word {
return nil
}
code, err := b.Int32()
if err != nil {
return err
}
return &ProtocolErr{Code: -code}
}
type alignedPayloadExpectedErr struct {
expected int
}
func (e alignedPayloadExpectedErr) Error() string {
return fmt.Sprintf("payload is not aligned, expected align by %d", e.expected)
}
func (e alignedPayloadExpectedErr) Is(err error) bool {
_, ok := err.(alignedPayloadExpectedErr)
return ok
}
type invalidMsgLenErr struct {
n int
}
func (e invalidMsgLenErr) Error() string {
return fmt.Sprintf("invalid message length %d", e.n)
}
func (e invalidMsgLenErr) Is(err error) bool {
_, ok := err.(invalidMsgLenErr)
return ok
}
// ErrProtocolHeaderMismatch means that received protocol header
// is mismatched with expected.
var ErrProtocolHeaderMismatch = errors.New("protocol header mismatch")
+124
View File
@@ -0,0 +1,124 @@
package codec
import (
"hash/crc32"
"io"
"sync/atomic"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// Full is full MTProto transport.
//
// See https://core.telegram.org/mtproto/mtproto-transports#full
type Full struct {
wSeqNo int64
rSeqNo int64
}
// WriteHeader sends protocol tag.
func (i *Full) WriteHeader(w io.Writer) (err error) {
return nil
}
// ReadHeader reads protocol tag.
func (i *Full) ReadHeader(r io.Reader) (err error) {
return nil
}
// Write encode to writer message from given buffer.
func (i *Full) Write(w io.Writer, b *bin.Buffer) error {
if err := checkOutgoingMessage(b); err != nil {
return err
}
if err := writeFull(w, int(atomic.AddInt64(&i.wSeqNo, 1)-1), b); err != nil {
return errors.Wrap(err, "write full")
}
return nil
}
// Read fills buffer with received message.
func (i *Full) Read(r io.Reader, b *bin.Buffer) error {
if err := readFull(r, int(atomic.AddInt64(&i.rSeqNo, 1)-1), b); err != nil {
return errors.Wrap(err, "read full")
}
return checkProtocolError(b)
}
func writeFull(w io.Writer, seqNo int, b *bin.Buffer) error {
write := bin.Buffer{Buf: make([]byte, 0, 4+4+b.Len()+4)}
// Length: length+seqno+payload+crc length encoded as 4 length bytes
// (little endian, the length of the length field must be included, too)
write.PutInt(4 + 4 + b.Len() + 4)
// Seqno: the TCP sequence number for this TCP connection (different from the MTProto sequence number):
// the first packet sent is numbered 0, the next one 1, etc.
write.PutInt(seqNo)
// payload: MTProto payload
write.Put(b.Raw())
// crc: 4 CRC32 bytes computed using length, sequence number, and payload together.
crc := crc32.ChecksumIEEE(write.Raw())
write.PutUint32(crc)
if _, err := w.Write(write.Raw()); err != nil {
return err
}
return nil
}
var errSeqNoMismatch = errors.New("seq_no mismatch")
var errCRCMismatch = errors.New("crc mismatch")
func readFull(r io.Reader, seqNo int, b *bin.Buffer) error {
n, err := readLen(r, b)
if err != nil {
return errors.Wrap(err, "len")
}
// Put length, because it need to count CRC.
b.PutInt(n)
b.Expand(n - bin.Word)
inner := &bin.Buffer{Buf: b.Buf[bin.Word:n]}
// Reads tail of packet to the buffer.
// Length already read.
if _, err := io.ReadFull(r, inner.Buf); err != nil {
return errors.Wrap(err, "read seqno, buffer and crc")
}
serverSeqNo, err := inner.Int()
if err != nil {
return err
}
if serverSeqNo != seqNo {
return errSeqNoMismatch
}
payloadLength := n - 3*bin.Word
inner.Skip(payloadLength)
// Cut only crc part.
crc, err := inner.Uint32()
if err != nil {
return err
}
// Compute crc using all buffer without last 4 bytes from server.
clientCRC := crc32.ChecksumIEEE(b.Buf[0 : n-bin.Word])
// Compare computed and read CRCs.
if crc != clientCRC {
return errCRCMismatch
}
// n
// Length | SeqNo | payload | CRC |
// Word | Word | ....... | Word |
copy(b.Buf, b.Buf[2*bin.Word:n-bin.Word])
b.Buf = b.Buf[:payloadLength]
return nil
}
+40
View File
@@ -0,0 +1,40 @@
package codec
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
func fullTestData() (packet, payload []byte) {
return []byte{
15, 0, 0, 0, // length
1, 0, 0, 0, // seqNo
97, 98, 99, // payload
78, 214, 109, 148, // crc
}, []byte("abc")
}
func TestFull(t *testing.T) {
packet, payload := fullTestData()
t.Run("write", func(t *testing.T) {
b := bytes.NewBuffer(nil)
buf := &bin.Buffer{Buf: payload}
err := writeFull(b, 1, buf)
require.NoError(t, err)
require.Equal(t, packet, b.Bytes())
})
t.Run("read", func(t *testing.T) {
b := &bin.Buffer{}
err := readFull(bytes.NewBuffer(packet), 1, b)
require.NoError(t, err)
require.Equal(t, payload, b.Buf)
})
}
+115
View File
@@ -0,0 +1,115 @@
package codec
import (
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// IntermediateClientStart is starting bytes sent by client in Intermediate mode.
//
// Note that server does not respond with it.
var IntermediateClientStart = [4]byte{0xee, 0xee, 0xee, 0xee}
// Intermediate is intermediate MTProto transport.
//
// See https://core.telegram.org/mtproto/mtproto-transports#intermediate
type Intermediate struct{}
var (
_ TaggedCodec = Intermediate{}
)
// WriteHeader sends protocol tag.
func (i Intermediate) WriteHeader(w io.Writer) (err error) {
if _, err := w.Write(IntermediateClientStart[:]); err != nil {
return errors.Wrap(err, "write intermediate header")
}
return nil
}
// ReadHeader reads protocol tag.
func (i Intermediate) ReadHeader(r io.Reader) (err error) {
var b [4]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return errors.Wrap(err, "read intermediate header")
}
if b != IntermediateClientStart {
return ErrProtocolHeaderMismatch
}
return nil
}
// ObfuscatedTag returns protocol tag for obfuscation.
func (i Intermediate) ObfuscatedTag() [4]byte {
return IntermediateClientStart
}
// Write encode to writer message from given buffer.
func (i Intermediate) Write(w io.Writer, b *bin.Buffer) error {
if err := checkOutgoingMessage(b); err != nil {
return err
}
if err := checkAlign(b, 4); err != nil {
return err
}
if err := writeIntermediate(w, b); err != nil {
return errors.Wrap(err, "write intermediate")
}
return nil
}
// Read fills buffer with received message.
func (i Intermediate) Read(r io.Reader, b *bin.Buffer) error {
if err := readIntermediate(r, b, false); err != nil {
return errors.Wrap(err, "read intermediate")
}
return checkProtocolError(b)
}
// writeIntermediate encodes b as payload to w.
func writeIntermediate(w io.Writer, b *bin.Buffer) error {
length := b.Len()
// Re-using b.Buf if possible to reduce allocations.
b.Expand(4)
b.Buf = b.Buf[:length]
inner := bin.Buffer{Buf: b.Buf[length:length]}
inner.PutInt(b.Len())
if _, err := w.Write(inner.Buf); err != nil {
return err
}
if _, err := w.Write(b.Buf); err != nil {
return err
}
return nil
}
// readIntermediate reads payload from r to b.
func readIntermediate(r io.Reader, b *bin.Buffer, padding bool) error {
n, err := readLen(r, b)
if err != nil {
return err
}
b.ResetN(n)
if _, err := io.ReadFull(r, b.Buf); err != nil {
return errors.Wrap(err, "read payload")
}
if padding {
paddingLength := n % 4
b.Buf = b.Buf[:n-paddingLength]
}
return nil
}
+45
View File
@@ -0,0 +1,45 @@
package codec
import (
"bytes"
"testing"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
func TestIntermediate(t *testing.T) {
t.Run("Ok", func(t *testing.T) {
msg := bytes.Repeat([]byte{1, 2, 3}, 100)
buf := new(bytes.Buffer)
if err := writeIntermediate(buf, &bin.Buffer{Buf: msg}); err != nil {
t.Fatal(err)
}
var out bin.Buffer
if err := readIntermediate(buf, &out, false); err != nil {
t.Fatal(err)
}
require.Equal(t, msg, out.Buf)
})
t.Run("BigMessage", func(t *testing.T) {
codec := Intermediate{}
t.Run("Read", func(t *testing.T) {
var b bin.Buffer
b.PutInt(maxMessageSize + 10)
var out bin.Buffer
if err := codec.Read(&b, &out); !errors.Is(err, invalidMsgLenErr{}) {
t.Error(err)
}
})
t.Run("Write", func(t *testing.T) {
buf := make([]byte, maxMessageSize+10)
if err := codec.Write(nil, &bin.Buffer{Buf: buf}); !errors.Is(err, invalidMsgLenErr{}) {
t.Error(err)
}
})
})
}
+18
View File
@@ -0,0 +1,18 @@
package codec
import "io"
// NoHeader wraps codec to skip WriteHeader.
type NoHeader struct {
Codec
}
// WriteHeader implements Codec.
func (NoHeader) WriteHeader(io.Writer) error {
return nil
}
// ReadHeader implements Codec.
func (NoHeader) ReadHeader(io.Reader) error {
return nil
}
+20
View File
@@ -0,0 +1,20 @@
package codec
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
func TestNoHeader(t *testing.T) {
a := require.New(t)
cdc := NoHeader{
Codec: Intermediate{},
}
buf := bytes.Buffer{}
a.NoError(cdc.WriteHeader(&buf))
a.Equal(0, buf.Len())
a.NoError(cdc.ReadHeader(&buf))
}
+106
View File
@@ -0,0 +1,106 @@
package codec
import (
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
)
// PaddedIntermediateClientStart is starting bytes sent by client in Padded intermediate mode.
//
// Note that server does not respond with it.
var PaddedIntermediateClientStart = [4]byte{0xdd, 0xdd, 0xdd, 0xdd}
// PaddedIntermediate is intermediate MTProto transport.
//
// See https://core.telegram.org/mtproto/mtproto-transports#padded-intermediate
type PaddedIntermediate struct{}
var (
_ TaggedCodec = PaddedIntermediate{}
)
// WriteHeader sends protocol tag.
func (i PaddedIntermediate) WriteHeader(w io.Writer) error {
if _, err := w.Write(PaddedIntermediateClientStart[:]); err != nil {
return errors.Wrap(err, "write padded intermediate header")
}
return nil
}
// ReadHeader reads protocol tag.
func (i PaddedIntermediate) ReadHeader(r io.Reader) error {
var b [4]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return errors.Wrap(err, "read padded intermediate header")
}
if b != PaddedIntermediateClientStart {
return ErrProtocolHeaderMismatch
}
return nil
}
// ObfuscatedTag returns protocol tag for obfuscation.
func (i PaddedIntermediate) ObfuscatedTag() [4]byte {
return PaddedIntermediateClientStart
}
// Write encode to writer message from given buffer.
func (i PaddedIntermediate) Write(w io.Writer, b *bin.Buffer) error {
if err := checkOutgoingMessage(b); err != nil {
return err
}
if err := checkAlign(b, 4); err != nil {
return err
}
if err := writePaddedIntermediate(crypto.DefaultRand(), w, b); err != nil {
return errors.Wrap(err, "write padded intermediate")
}
return nil
}
// Read fills buffer with received message.
func (i PaddedIntermediate) Read(r io.Reader, b *bin.Buffer) error {
if err := readPaddedIntermediate(r, b); err != nil {
return errors.Wrap(err, "read padded intermediate")
}
return checkProtocolError(b)
}
func writePaddedIntermediate(randSource io.Reader, w io.Writer, b *bin.Buffer) error {
length := b.Len()
b.Expand(4)
defer func() {
b.Buf = b.Buf[:length]
}()
_, err := io.ReadFull(randSource, b.Buf[length:length+4])
if err != nil {
return err
}
n := int(b.Buf[length-1]) % 4
b.Buf = b.Buf[:length+n]
return writeIntermediate(w, b)
}
func readPaddedIntermediate(r io.Reader, b *bin.Buffer) error {
if err := readIntermediate(r, b, true); err != nil {
return err
}
padding := b.Len() % 4
b.Buf = b.Buf[:b.Len()-padding]
return nil
}