move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
+145
View File
@@ -0,0 +1,145 @@
package codec
import (
"encoding/binary"
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// AbridgedClientStart is starting bytes sent by client in Abridged mode.
//
// Note that server does not respond with it.
var AbridgedClientStart = [1]byte{0xef}
// Abridged is intermediate MTProto transport.
//
// See https://core.telegram.org/mtproto/mtproto-transports#abridged
type Abridged struct{}
var (
_ TaggedCodec = Abridged{}
)
// WriteHeader sends protocol tag.
func (i Abridged) WriteHeader(w io.Writer) error {
if _, err := w.Write(AbridgedClientStart[:]); err != nil {
return errors.Wrap(err, "write abridged header")
}
return nil
}
// ReadHeader reads protocol tag.
func (i Abridged) ReadHeader(r io.Reader) error {
var b [1]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return errors.Wrap(err, "read abridged header")
}
if b != AbridgedClientStart {
return ErrProtocolHeaderMismatch
}
return nil
}
// ObfuscatedTag returns protocol tag for obfuscation.
func (i Abridged) ObfuscatedTag() (r [4]byte) {
d := AbridgedClientStart[0]
return [4]byte{d, d, d, d}
}
// Write encode to writer message from given buffer.
func (i Abridged) Write(w io.Writer, b *bin.Buffer) error {
if err := checkOutgoingMessage(b); err != nil {
return err
}
if err := checkAlign(b, 4); err != nil {
return err
}
if err := writeAbridged(w, b); err != nil {
return errors.Wrap(err, "write abridged")
}
return nil
}
// Read fills buffer with received message.
func (i Abridged) Read(r io.Reader, b *bin.Buffer) error {
if err := readAbridged(r, b); err != nil {
return errors.Wrap(err, "read abridged")
}
return checkProtocolError(b)
}
func writeAbridged(w io.Writer, b *bin.Buffer) error {
length := b.Len()
// Re-using b.Buf if possible to reduce allocations.
b.Expand(4)
b.Buf = b.Buf[:length]
// Re-using b.Buf if possible to reduce allocations.
inner := bin.Buffer{Buf: b.Buf[length:length]}
encodeLength := b.Len() >> 2
// `0x7f == 127`, literally use one bit to distinguish length byte size.
if encodeLength < 127 {
// Payloads are wrapped in the following envelope:
//
// Length: payload length, divided by four, and encoded as a single byte,
// only if the resulting packet length is a value between 0x01..0x7e.
inner.Put([]byte{byte(encodeLength)})
} else {
// If the packet length divided by four is bigger than or equal to 127 (>= 0x7f),
// the following envelope must be used, instead:
//
var buf [5]byte
// Header: A single byte of value 0x7f
buf[0] = 0x7f
// Length: payload length, divided by four, and encoded as 3 length bytes (little endian)
binary.LittleEndian.PutUint32(buf[1:], uint32(encodeLength))
inner.Put(buf[:4])
}
if _, err := w.Write(inner.Buf); err != nil {
return err
}
if _, err := w.Write(b.Raw()); err != nil {
return err
}
return nil
}
func readAbridged(r io.Reader, b *bin.Buffer) error {
b.ResetN(bin.Word)
_, err := io.ReadFull(r, b.Buf[:1])
if err != nil {
return err
}
if b.Buf[0] >= 127 {
_, err := io.ReadFull(r, b.Buf[0:3])
if err != nil {
return err
}
}
n, err := b.Int()
if err != nil {
return err
}
b.ResetN(n << 2)
if _, err := io.ReadFull(r, b.Buf); err != nil {
return errors.Wrap(err, "read payload")
}
return nil
}
+69
View File
@@ -0,0 +1,69 @@
package codec
import (
"bytes"
"encoding/binary"
"strings"
"testing"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"github.com/stretchr/testify/require"
)
func TestAbridged(t *testing.T) {
bigHeader := func(l int) (packet []byte) {
// header + 3 bytes of LE
var buf [4]byte
binary.LittleEndian.PutUint32(buf[:], uint32(l>>2))
packet = append([]byte{127}, buf[0:3]...)
return
}
tests := []struct {
payloadName string
testData func() (payload string, packet []byte)
}{
{"Small-4b", func() (payload string, packet []byte) {
payload = "abcd"
packet = append([]byte{byte(len(payload) >> 2)}, payload...)
return
}},
{"Medium-124b", func() (payload string, packet []byte) {
payload = strings.Repeat("a", 124)
packet = append([]byte{byte(len(payload) >> 2)}, payload...)
return
}},
{"Big-1kb", func() (payload string, packet []byte) {
payload = strings.Repeat("a", 1024)
packet = bigHeader(len(payload))
require.Equal(t, []byte{127, 0, 1, 0}, packet)
packet = append(packet, payload...)
return
}},
}
for _, test := range tests {
payload, packet := test.testData()
t.Run(test.payloadName, func(t *testing.T) {
t.Run("Write", func(t *testing.T) {
b := bytes.NewBuffer(nil)
err := writeAbridged(b, &bin.Buffer{Buf: []byte(payload)})
require.NoError(t, err)
require.Equal(t, packet, b.Bytes())
})
t.Run("Read", func(t *testing.T) {
b := &bin.Buffer{}
err := readAbridged(bytes.NewReader(packet), b)
require.NoError(t, err)
require.Equal(t, payload, string(b.Raw()))
})
})
}
}
+86
View File
@@ -0,0 +1,86 @@
package codec
import (
"bytes"
"crypto/rand"
"io"
"testing"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
func benchWrite(codec Codec) func(payloadSize int) func(b *testing.B) {
return func(payloadSize int) func(b *testing.B) {
return func(b *testing.B) {
buf := bin.Buffer{Buf: make([]byte, payloadSize)}
if _, err := io.ReadFull(rand.Reader, buf.Buf); err != nil {
b.Fatal(err)
}
b.ReportAllocs()
b.SetBytes(int64(buf.Len() + 4))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := codec.Write(io.Discard, &buf); err != nil {
b.Fatal(err)
}
}
}
}
}
func BenchmarkWrite(b *testing.B) {
b.Run("Abridged", func(b *testing.B) {
testutil.RunPayloads(b, benchWrite(Abridged{}))
})
b.Run("Intermediate", func(b *testing.B) {
testutil.RunPayloads(b, benchWrite(Intermediate{}))
})
b.Run("PaddedIntermediate", func(b *testing.B) {
testutil.RunPayloads(b, benchWrite(PaddedIntermediate{}))
})
}
func benchRead(codec Codec) func(payloadSize int) func(b *testing.B) {
return func(payloadSize int) func(b *testing.B) {
return func(b *testing.B) {
buf := bin.Buffer{Buf: make([]byte, payloadSize)}
if _, err := io.ReadFull(rand.Reader, buf.Buf); err != nil {
b.Fatal(err)
}
out := new(bytes.Buffer)
if err := codec.Write(out, &buf); err != nil {
b.Fatal(err)
}
raw := out.Bytes()
reader := bytes.NewReader(nil)
b.ReportAllocs()
b.SetBytes(int64(buf.Len() + 4))
b.ResetTimer()
for i := 0; i < b.N; i++ {
reader.Reset(raw)
if err := codec.Read(reader, &buf); err != nil {
b.Fatal(err)
}
buf.Reset()
}
}
}
}
func BenchmarkRead(b *testing.B) {
b.Run("Abridged", func(b *testing.B) {
testutil.RunPayloads(b, benchRead(Abridged{}))
})
b.Run("Intermediate", func(b *testing.B) {
testutil.RunPayloads(b, benchRead(Intermediate{}))
})
b.Run("PaddedIntermediate", func(b *testing.B) {
testutil.RunPayloads(b, benchRead(PaddedIntermediate{}))
})
}
+44
View File
@@ -0,0 +1,44 @@
package codec
import (
"encoding/binary"
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// Codec is MTProto transport protocol encoding abstraction.
type Codec interface {
// WriteHeader sends protocol tag if needed.
WriteHeader(w io.Writer) error
// ReadHeader reads protocol tag if needed.
ReadHeader(r io.Reader) error
// Write encode to writer message from given buffer.
Write(w io.Writer, b *bin.Buffer) error
// Read fills buffer with received message.
Read(r io.Reader, b *bin.Buffer) error
}
// TaggedCodec is codec with protocol tag.
type TaggedCodec interface {
Codec
// ObfuscatedTag returns protocol tag for obfuscation.
ObfuscatedTag() [4]byte
}
// readLen reads 32-bit integer and validates it as message length.
func readLen(r io.Reader, b *bin.Buffer) (int, error) {
b.ResetN(bin.Word)
if _, err := io.ReadFull(r, b.Buf[:bin.Word]); err != nil {
return 0, errors.Wrap(err, "read length")
}
n := int(binary.LittleEndian.Uint32(b.Buf[:bin.Word]))
if n <= 0 || n > maxMessageSize {
return 0, invalidMsgLenErr{n: n}
}
return n, nil
}
+197
View File
@@ -0,0 +1,197 @@
package codec
import (
"bytes"
"encoding/binary"
"io"
"strings"
"testing"
"testing/iotest"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
type codecTest struct {
name string
align int
create func() Codec
}
func codecs() []codecTest {
return []codecTest{
{"Abridged", 4, func() Codec {
return Abridged{}
}},
{"Intermediate", 4, func() Codec {
return Intermediate{}
}},
{"PaddedIntermediate", 4, func() Codec {
return PaddedIntermediate{}
}},
{"Full", 0, func() Codec {
return &Full{}
}},
{"NoHeaderIntermediate", 0, func() Codec {
return NoHeader{Codec: Intermediate{}}
}},
}
}
type payload struct {
name string
testData string
mustFail bool
readTestOnly bool
}
func payloads() []payload {
var code [4]byte
binary.LittleEndian.PutUint32(code[:], CodeTransportFlood)
return []payload{
{"Empty", "", true, false},
{"Protocol error", string(code[:]), true, true},
{"Small 8b", "abcdabcd", false, false},
{"Medium 1kb", strings.Repeat("a", 1024), false, false},
}
}
func testGood(c codecTest, p payload) func(t *testing.T) {
return func(t *testing.T) {
t.Run("One message", func(t *testing.T) {
a := require.New(t)
codec := c.create()
buf := bytes.NewBuffer(nil)
payload := &bin.Buffer{Buf: []byte(p.testData)}
l := payload.Len()
// Encode
a.NoError(codec.Write(buf, payload))
a.Equal(l, payload.Len(), "Codec must not change buffer length")
// Decode
payload.Reset()
a.NoError(codec.Read(buf, payload))
a.Equal(p.testData, string(payload.Buf))
})
t.Run("Two messages", func(t *testing.T) {
a := require.New(t)
codec := c.create()
buf := bytes.NewBuffer(nil)
payload := &bin.Buffer{Buf: []byte(p.testData)}
l := payload.Len()
// Encode twice
a.NoError(codec.Write(buf, payload))
payload.ResetTo([]byte(p.testData))
a.NoError(codec.Write(buf, payload))
a.Equal(l, payload.Len(), "Codec must not change buffer length")
// Decode twice
payload.Reset()
a.NoError(codec.Read(buf, payload))
a.Equal(p.testData, string(payload.Buf))
payload.Reset()
a.NoError(codec.Read(buf, payload))
a.Equal(p.testData, string(payload.Buf))
})
}
}
func testBad(c codecTest, p payload) func(t *testing.T) {
return func(t *testing.T) {
if !p.readTestOnly {
t.Run("Write", func(t *testing.T) {
payload := &bin.Buffer{Buf: []byte(p.testData)}
require.Error(t, c.create().Write(io.Discard, payload))
})
}
t.Run("Read", func(t *testing.T) {
reader := bytes.NewBufferString(p.testData)
require.Error(t, c.create().Read(reader, &bin.Buffer{}))
})
}
}
func testHeaderTag(c codecTest) func(t *testing.T) {
e := io.ErrClosedPipe
return func(t *testing.T) {
t.Run("GoodTag", func(t *testing.T) {
a := require.New(t)
buf := bytes.NewBuffer(nil)
a.NoError(c.create().WriteHeader(buf))
a.NoError(c.create().ReadHeader(buf))
})
if tagged, ok := c.create().(TaggedCodec); ok {
t.Run("ReadError", func(t *testing.T) {
r := iotest.ErrReader(e)
require.ErrorIs(t, c.create().ReadHeader(r), e)
})
t.Run("WriteError", func(t *testing.T) {
w := testutil.ErrWriter(e)
require.ErrorIs(t, c.create().WriteHeader(w), e)
})
t.Run("BadTag", func(t *testing.T) {
tag := tagged.ObfuscatedTag()
buf := bytes.NewBuffer(tag[:])
tag[0] = 0
require.ErrorIs(t, c.create().ReadHeader(buf), ErrProtocolHeaderMismatch)
})
}
}
}
func testCodec(c codecTest) func(t *testing.T) {
e := io.ErrClosedPipe
return func(t *testing.T) {
t.Run("ReadError", func(t *testing.T) {
r := iotest.ErrReader(e)
require.ErrorIs(t, c.create().Read(r, &bin.Buffer{}), e)
})
t.Run("WriteError", func(t *testing.T) {
w := testutil.ErrWriter(e)
require.ErrorIs(t,
c.create().Write(
w,
&bin.Buffer{Buf: make([]byte, 16)},
),
e,
)
})
if c.align != 0 {
t.Run("AlignError", func(t *testing.T) {
require.Error(t,
c.create().Write(
io.Discard,
&bin.Buffer{Buf: make([]byte, c.align-1)},
),
)
})
}
}
}
func TestCodecs(t *testing.T) {
for _, c := range codecs() {
t.Run(c.name, func(t *testing.T) {
for _, p := range payloads() {
if p.mustFail {
t.Run(p.name, testBad(c, p))
} else {
t.Run(p.name, testGood(c, p))
}
}
t.Run("Header", testHeaderTag(c))
testCodec(c)(t)
})
}
}
+4
View File
@@ -0,0 +1,4 @@
// Package codec contains MTProto transport encoding implementations.
//
// See https://core.telegram.org/mtproto/mtproto-transports
package codec
+106
View File
@@ -0,0 +1,106 @@
package codec
import (
"fmt"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
const (
// CodeAuthKeyNotFound means that specified auth key ID cannot be found by the DC.
// Also, may be returned during key exchange.
CodeAuthKeyNotFound = 404
// CodeWrongDC means that current DC is wrong.
// Usually returned by server when key exchange sends wrong DC ID.
CodeWrongDC = 444
// CodeTransportFlood means that too many transport connections are
// established to the same IP in a too short lapse of time, or if any
// of the container/service message limits are reached.
CodeTransportFlood = 429
)
// ProtocolErr represents protocol level error.
type ProtocolErr struct {
Code int32
}
func (p ProtocolErr) Error() string {
switch p.Code {
case CodeAuthKeyNotFound:
return "auth key not found"
case CodeTransportFlood:
return "transport flood"
case CodeWrongDC:
return "wrong DC"
default:
return fmt.Sprintf("protocol error %d", p.Code)
}
}
// Can be bigger that 1mb.
//
// See https://github.com/gotd/td/issues/412
//
// See https://github.com/tdlib/td/blob/550ccc8d9bbbe9cff1dc618aef5764d2cbd2cd91/td/mtproto/TcpTransport.cpp#L53
const maxMessageSize = 1 << 24 // 16 MB
func checkOutgoingMessage(b *bin.Buffer) error {
length := b.Len()
if length > maxMessageSize || length == 0 {
return invalidMsgLenErr{n: length}
}
return nil
}
func checkAlign(b *bin.Buffer, n int) error {
length := b.Len()
if length%n != 0 {
return alignedPayloadExpectedErr{expected: n}
}
return nil
}
func checkProtocolError(b *bin.Buffer) error {
if b.Len() != bin.Word {
return nil
}
code, err := b.Int32()
if err != nil {
return err
}
return &ProtocolErr{Code: -code}
}
type alignedPayloadExpectedErr struct {
expected int
}
func (e alignedPayloadExpectedErr) Error() string {
return fmt.Sprintf("payload is not aligned, expected align by %d", e.expected)
}
func (e alignedPayloadExpectedErr) Is(err error) bool {
_, ok := err.(alignedPayloadExpectedErr)
return ok
}
type invalidMsgLenErr struct {
n int
}
func (e invalidMsgLenErr) Error() string {
return fmt.Sprintf("invalid message length %d", e.n)
}
func (e invalidMsgLenErr) Is(err error) bool {
_, ok := err.(invalidMsgLenErr)
return ok
}
// ErrProtocolHeaderMismatch means that received protocol header
// is mismatched with expected.
var ErrProtocolHeaderMismatch = errors.New("protocol header mismatch")
+124
View File
@@ -0,0 +1,124 @@
package codec
import (
"hash/crc32"
"io"
"sync/atomic"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// Full is full MTProto transport.
//
// See https://core.telegram.org/mtproto/mtproto-transports#full
type Full struct {
wSeqNo int64
rSeqNo int64
}
// WriteHeader sends protocol tag.
func (i *Full) WriteHeader(w io.Writer) (err error) {
return nil
}
// ReadHeader reads protocol tag.
func (i *Full) ReadHeader(r io.Reader) (err error) {
return nil
}
// Write encode to writer message from given buffer.
func (i *Full) Write(w io.Writer, b *bin.Buffer) error {
if err := checkOutgoingMessage(b); err != nil {
return err
}
if err := writeFull(w, int(atomic.AddInt64(&i.wSeqNo, 1)-1), b); err != nil {
return errors.Wrap(err, "write full")
}
return nil
}
// Read fills buffer with received message.
func (i *Full) Read(r io.Reader, b *bin.Buffer) error {
if err := readFull(r, int(atomic.AddInt64(&i.rSeqNo, 1)-1), b); err != nil {
return errors.Wrap(err, "read full")
}
return checkProtocolError(b)
}
func writeFull(w io.Writer, seqNo int, b *bin.Buffer) error {
write := bin.Buffer{Buf: make([]byte, 0, 4+4+b.Len()+4)}
// Length: length+seqno+payload+crc length encoded as 4 length bytes
// (little endian, the length of the length field must be included, too)
write.PutInt(4 + 4 + b.Len() + 4)
// Seqno: the TCP sequence number for this TCP connection (different from the MTProto sequence number):
// the first packet sent is numbered 0, the next one 1, etc.
write.PutInt(seqNo)
// payload: MTProto payload
write.Put(b.Raw())
// crc: 4 CRC32 bytes computed using length, sequence number, and payload together.
crc := crc32.ChecksumIEEE(write.Raw())
write.PutUint32(crc)
if _, err := w.Write(write.Raw()); err != nil {
return err
}
return nil
}
var errSeqNoMismatch = errors.New("seq_no mismatch")
var errCRCMismatch = errors.New("crc mismatch")
func readFull(r io.Reader, seqNo int, b *bin.Buffer) error {
n, err := readLen(r, b)
if err != nil {
return errors.Wrap(err, "len")
}
// Put length, because it need to count CRC.
b.PutInt(n)
b.Expand(n - bin.Word)
inner := &bin.Buffer{Buf: b.Buf[bin.Word:n]}
// Reads tail of packet to the buffer.
// Length already read.
if _, err := io.ReadFull(r, inner.Buf); err != nil {
return errors.Wrap(err, "read seqno, buffer and crc")
}
serverSeqNo, err := inner.Int()
if err != nil {
return err
}
if serverSeqNo != seqNo {
return errSeqNoMismatch
}
payloadLength := n - 3*bin.Word
inner.Skip(payloadLength)
// Cut only crc part.
crc, err := inner.Uint32()
if err != nil {
return err
}
// Compute crc using all buffer without last 4 bytes from server.
clientCRC := crc32.ChecksumIEEE(b.Buf[0 : n-bin.Word])
// Compare computed and read CRCs.
if crc != clientCRC {
return errCRCMismatch
}
// n
// Length | SeqNo | payload | CRC |
// Word | Word | ....... | Word |
copy(b.Buf, b.Buf[2*bin.Word:n-bin.Word])
b.Buf = b.Buf[:payloadLength]
return nil
}
+40
View File
@@ -0,0 +1,40 @@
package codec
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
func fullTestData() (packet, payload []byte) {
return []byte{
15, 0, 0, 0, // length
1, 0, 0, 0, // seqNo
97, 98, 99, // payload
78, 214, 109, 148, // crc
}, []byte("abc")
}
func TestFull(t *testing.T) {
packet, payload := fullTestData()
t.Run("write", func(t *testing.T) {
b := bytes.NewBuffer(nil)
buf := &bin.Buffer{Buf: payload}
err := writeFull(b, 1, buf)
require.NoError(t, err)
require.Equal(t, packet, b.Bytes())
})
t.Run("read", func(t *testing.T) {
b := &bin.Buffer{}
err := readFull(bytes.NewBuffer(packet), 1, b)
require.NoError(t, err)
require.Equal(t, payload, b.Buf)
})
}
+115
View File
@@ -0,0 +1,115 @@
package codec
import (
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// IntermediateClientStart is starting bytes sent by client in Intermediate mode.
//
// Note that server does not respond with it.
var IntermediateClientStart = [4]byte{0xee, 0xee, 0xee, 0xee}
// Intermediate is intermediate MTProto transport.
//
// See https://core.telegram.org/mtproto/mtproto-transports#intermediate
type Intermediate struct{}
var (
_ TaggedCodec = Intermediate{}
)
// WriteHeader sends protocol tag.
func (i Intermediate) WriteHeader(w io.Writer) (err error) {
if _, err := w.Write(IntermediateClientStart[:]); err != nil {
return errors.Wrap(err, "write intermediate header")
}
return nil
}
// ReadHeader reads protocol tag.
func (i Intermediate) ReadHeader(r io.Reader) (err error) {
var b [4]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return errors.Wrap(err, "read intermediate header")
}
if b != IntermediateClientStart {
return ErrProtocolHeaderMismatch
}
return nil
}
// ObfuscatedTag returns protocol tag for obfuscation.
func (i Intermediate) ObfuscatedTag() [4]byte {
return IntermediateClientStart
}
// Write encode to writer message from given buffer.
func (i Intermediate) Write(w io.Writer, b *bin.Buffer) error {
if err := checkOutgoingMessage(b); err != nil {
return err
}
if err := checkAlign(b, 4); err != nil {
return err
}
if err := writeIntermediate(w, b); err != nil {
return errors.Wrap(err, "write intermediate")
}
return nil
}
// Read fills buffer with received message.
func (i Intermediate) Read(r io.Reader, b *bin.Buffer) error {
if err := readIntermediate(r, b, false); err != nil {
return errors.Wrap(err, "read intermediate")
}
return checkProtocolError(b)
}
// writeIntermediate encodes b as payload to w.
func writeIntermediate(w io.Writer, b *bin.Buffer) error {
length := b.Len()
// Re-using b.Buf if possible to reduce allocations.
b.Expand(4)
b.Buf = b.Buf[:length]
inner := bin.Buffer{Buf: b.Buf[length:length]}
inner.PutInt(b.Len())
if _, err := w.Write(inner.Buf); err != nil {
return err
}
if _, err := w.Write(b.Buf); err != nil {
return err
}
return nil
}
// readIntermediate reads payload from r to b.
func readIntermediate(r io.Reader, b *bin.Buffer, padding bool) error {
n, err := readLen(r, b)
if err != nil {
return err
}
b.ResetN(n)
if _, err := io.ReadFull(r, b.Buf); err != nil {
return errors.Wrap(err, "read payload")
}
if padding {
paddingLength := n % 4
b.Buf = b.Buf[:n-paddingLength]
}
return nil
}
+45
View File
@@ -0,0 +1,45 @@
package codec
import (
"bytes"
"testing"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
func TestIntermediate(t *testing.T) {
t.Run("Ok", func(t *testing.T) {
msg := bytes.Repeat([]byte{1, 2, 3}, 100)
buf := new(bytes.Buffer)
if err := writeIntermediate(buf, &bin.Buffer{Buf: msg}); err != nil {
t.Fatal(err)
}
var out bin.Buffer
if err := readIntermediate(buf, &out, false); err != nil {
t.Fatal(err)
}
require.Equal(t, msg, out.Buf)
})
t.Run("BigMessage", func(t *testing.T) {
codec := Intermediate{}
t.Run("Read", func(t *testing.T) {
var b bin.Buffer
b.PutInt(maxMessageSize + 10)
var out bin.Buffer
if err := codec.Read(&b, &out); !errors.Is(err, invalidMsgLenErr{}) {
t.Error(err)
}
})
t.Run("Write", func(t *testing.T) {
buf := make([]byte, maxMessageSize+10)
if err := codec.Write(nil, &bin.Buffer{Buf: buf}); !errors.Is(err, invalidMsgLenErr{}) {
t.Error(err)
}
})
})
}
+18
View File
@@ -0,0 +1,18 @@
package codec
import "io"
// NoHeader wraps codec to skip WriteHeader.
type NoHeader struct {
Codec
}
// WriteHeader implements Codec.
func (NoHeader) WriteHeader(io.Writer) error {
return nil
}
// ReadHeader implements Codec.
func (NoHeader) ReadHeader(io.Reader) error {
return nil
}
+20
View File
@@ -0,0 +1,20 @@
package codec
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
func TestNoHeader(t *testing.T) {
a := require.New(t)
cdc := NoHeader{
Codec: Intermediate{},
}
buf := bytes.Buffer{}
a.NoError(cdc.WriteHeader(&buf))
a.Equal(0, buf.Len())
a.NoError(cdc.ReadHeader(&buf))
}
+106
View File
@@ -0,0 +1,106 @@
package codec
import (
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
)
// PaddedIntermediateClientStart is starting bytes sent by client in Padded intermediate mode.
//
// Note that server does not respond with it.
var PaddedIntermediateClientStart = [4]byte{0xdd, 0xdd, 0xdd, 0xdd}
// PaddedIntermediate is intermediate MTProto transport.
//
// See https://core.telegram.org/mtproto/mtproto-transports#padded-intermediate
type PaddedIntermediate struct{}
var (
_ TaggedCodec = PaddedIntermediate{}
)
// WriteHeader sends protocol tag.
func (i PaddedIntermediate) WriteHeader(w io.Writer) error {
if _, err := w.Write(PaddedIntermediateClientStart[:]); err != nil {
return errors.Wrap(err, "write padded intermediate header")
}
return nil
}
// ReadHeader reads protocol tag.
func (i PaddedIntermediate) ReadHeader(r io.Reader) error {
var b [4]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return errors.Wrap(err, "read padded intermediate header")
}
if b != PaddedIntermediateClientStart {
return ErrProtocolHeaderMismatch
}
return nil
}
// ObfuscatedTag returns protocol tag for obfuscation.
func (i PaddedIntermediate) ObfuscatedTag() [4]byte {
return PaddedIntermediateClientStart
}
// Write encode to writer message from given buffer.
func (i PaddedIntermediate) Write(w io.Writer, b *bin.Buffer) error {
if err := checkOutgoingMessage(b); err != nil {
return err
}
if err := checkAlign(b, 4); err != nil {
return err
}
if err := writePaddedIntermediate(crypto.DefaultRand(), w, b); err != nil {
return errors.Wrap(err, "write padded intermediate")
}
return nil
}
// Read fills buffer with received message.
func (i PaddedIntermediate) Read(r io.Reader, b *bin.Buffer) error {
if err := readPaddedIntermediate(r, b); err != nil {
return errors.Wrap(err, "read padded intermediate")
}
return checkProtocolError(b)
}
func writePaddedIntermediate(randSource io.Reader, w io.Writer, b *bin.Buffer) error {
length := b.Len()
b.Expand(4)
defer func() {
b.Buf = b.Buf[:length]
}()
_, err := io.ReadFull(randSource, b.Buf[length:length+4])
if err != nil {
return err
}
n := int(b.Buf[length-1]) % 4
b.Buf = b.Buf[:length+n]
return writeIntermediate(w, b)
}
func readPaddedIntermediate(r io.Reader, b *bin.Buffer) error {
if err := readIntermediate(r, b, true); err != nil {
return err
}
padding := b.Len() % 4
b.Buf = b.Buf[:b.Len()-padding]
return nil
}
+96
View File
@@ -0,0 +1,96 @@
package proto
import (
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// MessageContainerTypeID is TL type id of MessageContainer.
const MessageContainerTypeID = 0x73f1f8dc
// MessageContainer contains slice of messages.
type MessageContainer struct {
Messages []Message
}
// Encode implements bin.Decoder.
func (m *MessageContainer) Encode(b *bin.Buffer) error {
b.PutID(MessageContainerTypeID)
b.PutInt(len(m.Messages))
for _, msg := range m.Messages {
if err := msg.Encode(b); err != nil {
return err
}
}
return nil
}
// Decode implements bin.Decoder.
func (m *MessageContainer) Decode(b *bin.Buffer) error {
if err := b.ConsumeID(MessageContainerTypeID); err != nil {
return errors.Wrap(err, "consume id of message container")
}
n, err := b.Int()
if err != nil {
return err
}
for i := 0; i < n; i++ {
var msg Message
if err := msg.Decode(b); err != nil {
return err
}
m.Messages = append(m.Messages, msg)
}
return nil
}
// Message is element of MessageContainer.
type Message struct {
ID int64
SeqNo int
Bytes int
Body []byte
}
// Encode implements bin.Encoder.
func (m *Message) Encode(b *bin.Buffer) error {
if m.Bytes < 0 || m.Bytes > 1024*1024 {
return errors.Errorf("message length %d is invalid", m.Bytes)
}
b.PutLong(m.ID)
b.PutInt(m.SeqNo)
b.PutInt(m.Bytes)
b.Put(m.Body)
return nil
}
// Decode implements bin.Decoder.
func (m *Message) Decode(b *bin.Buffer) error {
{
v, err := b.Long()
if err != nil {
return err
}
m.ID = v
}
{
v, err := b.Int()
if err != nil {
return err
}
m.SeqNo = v
}
{
v, err := b.Int()
if err != nil {
return err
}
m.Bytes = v
}
if m.Bytes < 0 || m.Bytes > 1024*1024 {
return errors.New("message length is too big")
}
m.Body = make([]byte, m.Bytes)
return b.ConsumeN(m.Body, m.Bytes)
}
+4
View File
@@ -0,0 +1,4 @@
// Package proto implements MTProto 2.0 primitives.
//
// See https://core.telegram.org/mtproto/description for reference.
package proto
+184
View File
@@ -0,0 +1,184 @@
package proto
import (
"bytes"
"fmt"
"io"
"sync"
"sync/atomic"
"github.com/go-faster/errors"
"github.com/klauspost/compress/gzip"
"go.uber.org/multierr"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
type gzipPool struct {
writers sync.Pool
readers sync.Pool
}
func newGzipPool() *gzipPool {
return &gzipPool{
writers: sync.Pool{
New: func() interface{} {
return gzip.NewWriter(nil)
},
},
readers: sync.Pool{},
}
}
func (g *gzipPool) GetWriter(w io.Writer) *gzip.Writer {
writer := g.writers.Get().(*gzip.Writer)
writer.Reset(w)
return writer
}
func (g *gzipPool) PutWriter(w *gzip.Writer) {
g.writers.Put(w)
}
func (g *gzipPool) GetReader(r io.Reader) (*gzip.Reader, error) {
reader, ok := g.readers.Get().(*gzip.Reader)
if !ok {
r, err := gzip.NewReader(r)
if err != nil {
return nil, err
}
return r, nil
}
if err := reader.Reset(r); err != nil {
g.readers.Put(reader)
return nil, err
}
return reader, nil
}
func (g *gzipPool) PutReader(w *gzip.Reader) {
g.readers.Put(w)
}
// GZIP represents a Packed Object.
//
// Used to replace any other object (or rather, a serialization thereof)
// with its archived (gzipped) representation.
type GZIP struct {
Data []byte
}
// GZIPTypeID is TL type id of GZIP.
const GZIPTypeID = 0x3072cfa1
var (
gzipRWPool = newGzipPool()
gzipBufPool = sync.Pool{New: func() interface{} {
return bytes.NewBuffer(nil)
}}
)
// Encode implements bin.Encoder.
func (g GZIP) Encode(b *bin.Buffer) (rErr error) {
b.PutID(GZIPTypeID)
// Writing compressed data to buf.
buf := gzipBufPool.Get().(*bytes.Buffer)
buf.Reset()
defer gzipBufPool.Put(buf)
w := gzipRWPool.GetWriter(buf)
defer func() {
if closeErr := w.Close(); closeErr != nil {
closeErr = errors.Wrap(closeErr, "close")
multierr.AppendInto(&rErr, closeErr)
}
gzipRWPool.PutWriter(w)
}()
if _, err := w.Write(g.Data); err != nil {
return errors.Wrap(err, "compress")
}
if err := w.Close(); err != nil {
return errors.Wrap(err, "close")
}
// Writing compressed data as bytes.
b.PutBytes(buf.Bytes())
return nil
}
type countReader struct {
reader io.Reader
read int64
}
func (c *countReader) Total() int64 {
return atomic.LoadInt64(&c.read)
}
func (c *countReader) Read(p []byte) (n int, err error) {
n, err = c.reader.Read(p)
atomic.AddInt64(&c.read, int64(n))
return n, err
}
// DecompressionBombErr means that GZIP decode detected decompression bomb
// which decompressed payload is significantly higher than initial compressed
// size and stopped decompression to prevent OOM.
type DecompressionBombErr struct {
Compressed int
Decompressed int
}
func (d *DecompressionBombErr) Error() string {
return fmt.Sprintf("payload too big (expanded %d bytes to greater than %d)",
d.Compressed, d.Decompressed,
)
}
// Decode implements bin.Decoder.
func (g *GZIP) Decode(b *bin.Buffer) (rErr error) {
if err := b.ConsumeID(GZIPTypeID); err != nil {
return err
}
buf, err := b.Bytes()
if err != nil {
return err
}
r, err := gzipRWPool.GetReader(bytes.NewReader(buf))
if err != nil {
return errors.Wrap(err, "gzip error")
}
defer func() {
if closeErr := r.Close(); closeErr != nil {
closeErr = errors.Wrap(closeErr, "close")
multierr.AppendInto(&rErr, closeErr)
}
gzipRWPool.PutReader(r)
}()
// Apply mitigation for reading too much data which can result in OOM.
const maxUncompressedSize = 1024 * 1024 * 10 // 10 mb
reader := &countReader{
reader: io.LimitReader(r, maxUncompressedSize),
}
if g.Data, err = io.ReadAll(reader); err != nil {
return errors.Wrap(err, "decompress")
}
if reader.Total() >= maxUncompressedSize {
// Read limit reached, possible decompression bomb detected.
return errors.Wrap(&DecompressionBombErr{
Compressed: maxUncompressedSize,
Decompressed: int(reader.Total()),
}, "decompress")
}
if err := r.Close(); err != nil {
return errors.Wrap(err, "checksum")
}
return nil
}
+86
View File
@@ -0,0 +1,86 @@
package proto
import (
"bytes"
"crypto/rand"
"io"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
func TestGZIP_Encode(t *testing.T) {
data := bytes.Repeat([]byte{1, 2, 3}, 100)
g := &GZIP{
Data: data,
}
var b bin.Buffer
require.NoError(t, b.Encode(g))
var decoded GZIP
require.NoError(t, b.Decode(&decoded))
require.Equal(t, data, decoded.Data)
}
func TestGZIP_Decode(t *testing.T) {
g := &GZIP{
Data: make([]byte, 1024*1024*15),
}
var b bin.Buffer
require.NoError(t, b.Encode(g))
var (
decoded GZIP
target *DecompressionBombErr
)
require.ErrorAs(t, b.Decode(&decoded), &target)
require.Less(t, len(decoded.Data), len(g.Data))
}
func benchmarkGZIPEncode(payloadSize int) func(b *testing.B) {
return func(b *testing.B) {
g := &GZIP{Data: make([]byte, payloadSize)}
_, err := io.ReadFull(rand.Reader, g.Data)
require.NoError(b, err)
var buf bin.Buffer
b.ReportAllocs()
b.SetBytes(int64(payloadSize))
b.ResetTimer()
for i := 0; i < b.N; i++ {
buf.Reset()
_ = g.Encode(&buf)
}
}
}
func BenchmarkGZIP_Encode(b *testing.B) {
testutil.RunPayloads(b, benchmarkGZIPEncode)
}
func benchmarkGZIPDecode(payloadSize int) func(b *testing.B) {
return func(b *testing.B) {
g := &GZIP{Data: make([]byte, payloadSize)}
_, err := io.ReadFull(rand.Reader, g.Data)
require.NoError(b, err)
var buf bin.Buffer
require.NoError(b, g.Encode(&buf))
b.ReportAllocs()
b.SetBytes(int64(payloadSize))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = g.Decode(&bin.Buffer{Buf: buf.Buf})
}
}
}
func BenchmarkGZIP_Decode(b *testing.B) {
testutil.RunPayloads(b, benchmarkGZIPDecode)
}
+220
View File
@@ -0,0 +1,220 @@
package proto
import (
"fmt"
"sync"
"time"
)
// Message identifiers are coupled to message creation time.
//
// https://core.telegram.org/mtproto/description#message-identifier-msg-id
const (
yieldClient = 0
yieldServerResponse = 1
yieldFromServer = 3
messageIDModulo = 4
)
func newMessageID(nowNano int64, yield int) int64 {
const nano = 1e9
// Must approximately equal unixtime*2^32.
// Important: to counter replay-attacks the lower 32 bits of msg_id
// passed by the client must not be empty and must present a
// fractional part of the time point when the message was created.
intPart := nowNano / nano
fracPart := nowNano % nano
// Ensure that fracPart % 4 == 0.
fracPart &= -messageIDModulo
// Adding modulo 4 yield to ensure message type.
fracPart += int64(yield)
return (intPart << 32) | fracPart
}
// MessageID represents 64-bit message id.
type MessageID int64
func (id MessageID) String() string {
return fmt.Sprintf("%x (%s, %s)",
int64(id), id.Type(), id.Time().Format(time.RFC3339),
)
}
// MessageType is type of message determined by message id.
//
// A message is rejected over 300 seconds after it is created or
// 30 seconds before it is created (this is needed to protect from replay attacks).
//
// The identifier of a message container must be strictly greater than those of
// its nested messages.
type MessageType byte
const (
// MessageUnknown reports that message id has unknown time and probably
// should be ignored.
MessageUnknown MessageType = iota
// MessageFromClient is client message identifiers.
MessageFromClient
// MessageServerResponse is a response to a client message.
MessageServerResponse
// MessageFromServer is a message from the server.
MessageFromServer
)
func (m MessageType) String() string {
switch m {
case MessageFromClient:
return "FromClient"
case MessageServerResponse:
return "ServerResponse"
case MessageFromServer:
return "FromServer"
default:
return "Unknown"
}
}
// Time returns approximate time when MessageID were generated.
func (id MessageID) Time() time.Time {
intPart := int64(id) >> 32
fracPart := int64(int32(id))
return time.Unix(intPart, fracPart).UTC()
}
// Type returns message type.
func (id MessageID) Type() MessageType {
switch id % messageIDModulo {
case yieldClient:
return MessageFromClient
case yieldServerResponse:
return MessageServerResponse
case yieldFromServer:
return MessageFromServer
default:
return MessageUnknown
}
}
// NewMessageID returns new message id for provided time and type.
func NewMessageID(now time.Time, typ MessageType) MessageID {
return NewMessageIDNano(now.UnixNano(), typ)
}
// NewMessageIDNano returns new message id for provided current unix
// nanoseconds and type.
func NewMessageIDNano(nano int64, typ MessageType) MessageID {
var yield int
switch typ {
case MessageFromClient:
yield = yieldClient
case MessageFromServer:
yield = yieldFromServer
case MessageServerResponse:
yield = yieldServerResponse
default:
yield = yieldClient
}
return MessageID(newMessageID(nano, yield))
}
// MessageIDGen is message id generator that provides collision prevention.
//
// The main reason of such structure is that now() can return same time during
// multiple calls and that leads to duplicate message id.
type MessageIDGen struct {
mux sync.Mutex
nano int64
now func() time.Time
}
// New generates new message id for provided type, protecting from collisions
// that are caused by low system time resolution.
func (g *MessageIDGen) New(t MessageType) int64 {
g.mux.Lock()
defer g.mux.Unlock()
// Minimum resolution is required because id is only approximately
// equal to unix nano time, some part is replaced by message type.
const minResolutionNanos = 10
nano := g.now().UnixNano()
if nano > g.nano {
g.nano = nano
} else {
g.nano += minResolutionNanos
}
return int64(NewMessageIDNano(g.nano, t))
}
// NewMessageIDGen creates new message id generator.
//
// Current time will be provided by now() function.
//
// This generator compensates time resolution problem removing
// probability of id collision.
//
// Such problem can be observed for relatively high RPS, sequential calls to
// time.Now() will return same time which leads to equal ids.
func NewMessageIDGen(now func() time.Time) *MessageIDGen {
return &MessageIDGen{
now: now,
}
}
// MessageIDBuf stores last N message ids and is used in replay attack mitigation.
type MessageIDBuf struct {
mux sync.Mutex
buf []int64
}
// NewMessageIDBuf initializes new message id buffer for last N stored values.
func NewMessageIDBuf(n int) *MessageIDBuf {
return &MessageIDBuf{
buf: make([]int64, n),
}
}
// Consume returns false if message should be discarded.
func (b *MessageIDBuf) Consume(newID int64) bool {
// In addition, the identifiers (msg_id) of the last N messages received
// from the other side must be stored, and if a message comes in with an
// msg_id lower than all or equal to any of the stored values, that message
// is to be ignored. Otherwise, the new message msg_id is added to the set,
// and, if the number of stored msg_id values is greater than N, the oldest
// (i. e. the lowest) is discarded.
//
// https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
b.mux.Lock()
defer b.mux.Unlock()
var (
minIDx int
minID int64
)
for i, id := range b.buf {
if id == newID {
// Equal to stored value.
return false
}
// Searching for minimum value.
if id < minID {
minIDx = i
minID = id
}
}
if newID < minID {
// Lower than all stored values.
return false
}
// Message is accepted. Replacing lowest message id with new id.
b.buf[minIDx] = newID
return true
}
+121
View File
@@ -0,0 +1,121 @@
package proto
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/gotd/neo"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
)
func TestMessageID(t *testing.T) {
now := time.Date(2018, 10, 10, 23, 42, 6, 13600, time.UTC)
id := MessageID(newMessageID(now.UnixNano(), 0))
if id.Type() != MessageFromClient {
t.Fatal("invalid type")
}
if id != 0x5bbe8e4e00003520 {
t.Error("mismatch")
}
delta := id.Time().Sub(now)
if delta < 0 {
delta *= -1
}
if delta > time.Second {
t.Fatal("unexpected time drift")
}
t.Run("NewMessageID", func(t *testing.T) {
if NewMessageID(now, MessageFromServer).Type() != MessageFromServer {
t.Error("Mismatch")
}
if NewMessageID(now, 100).Type() != MessageFromClient {
t.Error("Mismatch")
}
})
t.Run("String", func(t *testing.T) {
require.Equal(t, "5bbe8e4e00003520 (FromClient, 2018-10-10T23:42:06Z)", id.String())
})
}
func BenchmarkNewMessageID(b *testing.B) {
// Note that most overhead will be from time.Now() calls.
// Just ensuring that NewMessageID itself is reasonably fast.
now := testutil.Date()
for i := 0; i < b.N; i++ {
if NewMessageID(now, MessageFromServer).Type() != MessageFromServer {
b.Fatal("Mismatch")
}
}
}
func TestMessageIDGen(t *testing.T) {
date := testutil.Date()
clock := neo.NewTime(date)
gen := NewMessageIDGen(clock.Now)
met := make(map[int64]bool)
for i := 0; i < 1000; i++ {
if i%10 == 0 {
clock.Travel(time.Millisecond * 100)
}
id := gen.New(MessageFromClient)
if met[id] {
t.Fatal("met")
}
met[id] = true
}
}
func BenchmarkMsgIDGen_New(b *testing.B) {
b.ReportAllocs()
date := testutil.Date()
var dateCalls int
now := func() time.Time {
if dateCalls%100 == 0 {
date = date.Add(time.Millisecond)
}
return date
}
gen := NewMessageIDGen(now)
for i := 0; i < b.N; i++ {
_ = gen.New(MessageFromServer)
}
}
func TestNewMessageIDBuf(t *testing.T) {
t.Run("Zero", func(t *testing.T) {
buf := NewMessageIDBuf(10)
assert.False(t, buf.Consume(0))
})
t.Run("Ok", func(t *testing.T) {
buf := NewMessageIDBuf(10)
assert.True(t, buf.Consume(1))
assert.False(t, buf.Consume(1))
t.Run("Sequence", func(t *testing.T) {
for i := 2; i <= 20; i++ {
assert.True(t, buf.Consume(int64(i)))
}
assert.False(t, buf.Consume(-1))
})
})
}
func BenchmarkMessageIDBuf(b *testing.B) {
buf := NewMessageIDBuf(100)
for i := 0; i < b.N; i++ {
buf.Consume(int64(i))
}
}
+10
View File
@@ -0,0 +1,10 @@
package proto
// TypesMap returns mapping from type ids to TL type names.
func TypesMap() map[uint32]string {
return map[uint32]string{
MessageContainerTypeID: "message_container",
ResultTypeID: "rpc_result",
GZIPTypeID: "gzip",
}
}
+41
View File
@@ -0,0 +1,41 @@
package proto
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// ResultTypeID is TL type id of Result.
const ResultTypeID = 0xf35c6d01
// Result is rpc_result#f35c6d01.
type Result struct {
RequestMessageID int64
Result []byte
}
// Encode implements bin.Encoder.
func (r *Result) Encode(b *bin.Buffer) error {
b.PutID(ResultTypeID)
b.PutLong(r.RequestMessageID)
b.Put(r.Result)
return nil
}
// Decode implements bin.Decoder.
func (r *Result) Decode(b *bin.Buffer) error {
if err := b.ConsumeID(ResultTypeID); err != nil {
return err
}
{
v, err := b.Long()
if err != nil {
return err
}
r.RequestMessageID = v
}
r.Result = append(r.Result[:0], b.Buf...)
b.Skip(len(b.Buf))
return nil
}
+55
View File
@@ -0,0 +1,55 @@
package proto
import (
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
)
// UnencryptedMessage is plaintext message.
type UnencryptedMessage struct {
MessageID int64
MessageData []byte
}
// Decode implements bin.Decoder.
func (u *UnencryptedMessage) Decode(b *bin.Buffer) error {
{
// Reading auth_key_id that should be always equal to zero.
id, err := b.Long()
if err != nil {
return err
}
if id != 0 {
return errors.Errorf("unexpected auth_key_id %d of plaintext message", id)
}
}
{
v, err := b.Long()
if err != nil {
return err
}
u.MessageID = v
}
// Reading data.
dataLen, err := b.Int32()
if err != nil {
return err
}
u.MessageData = append(u.MessageData[:0], make([]byte, dataLen)...)
if err := b.ConsumeN(u.MessageData, int(dataLen)); err != nil {
return errors.Wrap(err, "consume payload")
}
return nil
}
// Encode implements bin.Encoder.
func (u UnencryptedMessage) Encode(b *bin.Buffer) error {
b.PutLong(0)
b.PutLong(u.MessageID)
b.PutInt32(int32(len(u.MessageData)))
b.Put(u.MessageData)
return nil
}
@@ -0,0 +1,26 @@
package proto
import (
"testing"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"github.com/stretchr/testify/require"
)
func TestUnencryptedMessage_Encode(t *testing.T) {
d := UnencryptedMessage{
MessageID: 3401235567,
MessageData: []byte{1, 2, 3, 100, 112},
}
b := new(bin.Buffer)
if err := d.Encode(b); err != nil {
t.Fatal(err)
}
decoded := UnencryptedMessage{}
if err := decoded.Decode(b); err != nil {
t.Fatal(err)
}
require.Equal(t, d, decoded)
require.Zero(t, b.Len(), "buffer should be consumed")
}