move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,145 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// AbridgedClientStart is starting bytes sent by client in Abridged mode.
|
||||
//
|
||||
// Note that server does not respond with it.
|
||||
var AbridgedClientStart = [1]byte{0xef}
|
||||
|
||||
// Abridged is intermediate MTProto transport.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports#abridged
|
||||
type Abridged struct{}
|
||||
|
||||
var (
|
||||
_ TaggedCodec = Abridged{}
|
||||
)
|
||||
|
||||
// WriteHeader sends protocol tag.
|
||||
func (i Abridged) WriteHeader(w io.Writer) error {
|
||||
if _, err := w.Write(AbridgedClientStart[:]); err != nil {
|
||||
return errors.Wrap(err, "write abridged header")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadHeader reads protocol tag.
|
||||
func (i Abridged) ReadHeader(r io.Reader) error {
|
||||
var b [1]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
return errors.Wrap(err, "read abridged header")
|
||||
}
|
||||
|
||||
if b != AbridgedClientStart {
|
||||
return ErrProtocolHeaderMismatch
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ObfuscatedTag returns protocol tag for obfuscation.
|
||||
func (i Abridged) ObfuscatedTag() (r [4]byte) {
|
||||
d := AbridgedClientStart[0]
|
||||
return [4]byte{d, d, d, d}
|
||||
}
|
||||
|
||||
// Write encode to writer message from given buffer.
|
||||
func (i Abridged) Write(w io.Writer, b *bin.Buffer) error {
|
||||
if err := checkOutgoingMessage(b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := checkAlign(b, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeAbridged(w, b); err != nil {
|
||||
return errors.Wrap(err, "write abridged")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read fills buffer with received message.
|
||||
func (i Abridged) Read(r io.Reader, b *bin.Buffer) error {
|
||||
if err := readAbridged(r, b); err != nil {
|
||||
return errors.Wrap(err, "read abridged")
|
||||
}
|
||||
|
||||
return checkProtocolError(b)
|
||||
}
|
||||
|
||||
func writeAbridged(w io.Writer, b *bin.Buffer) error {
|
||||
length := b.Len()
|
||||
// Re-using b.Buf if possible to reduce allocations.
|
||||
b.Expand(4)
|
||||
b.Buf = b.Buf[:length]
|
||||
|
||||
// Re-using b.Buf if possible to reduce allocations.
|
||||
inner := bin.Buffer{Buf: b.Buf[length:length]}
|
||||
|
||||
encodeLength := b.Len() >> 2
|
||||
// `0x7f == 127`, literally use one bit to distinguish length byte size.
|
||||
if encodeLength < 127 {
|
||||
// Payloads are wrapped in the following envelope:
|
||||
//
|
||||
// Length: payload length, divided by four, and encoded as a single byte,
|
||||
// only if the resulting packet length is a value between 0x01..0x7e.
|
||||
inner.Put([]byte{byte(encodeLength)})
|
||||
} else {
|
||||
// If the packet length divided by four is bigger than or equal to 127 (>= 0x7f),
|
||||
// the following envelope must be used, instead:
|
||||
//
|
||||
var buf [5]byte
|
||||
// Header: A single byte of value 0x7f
|
||||
buf[0] = 0x7f
|
||||
// Length: payload length, divided by four, and encoded as 3 length bytes (little endian)
|
||||
binary.LittleEndian.PutUint32(buf[1:], uint32(encodeLength))
|
||||
inner.Put(buf[:4])
|
||||
}
|
||||
|
||||
if _, err := w.Write(inner.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write(b.Raw()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readAbridged(r io.Reader, b *bin.Buffer) error {
|
||||
b.ResetN(bin.Word)
|
||||
|
||||
_, err := io.ReadFull(r, b.Buf[:1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if b.Buf[0] >= 127 {
|
||||
_, err := io.ReadFull(r, b.Buf[0:3])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
n, err := b.Int()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.ResetN(n << 2)
|
||||
if _, err := io.ReadFull(r, b.Buf); err != nil {
|
||||
return errors.Wrap(err, "read payload")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAbridged(t *testing.T) {
|
||||
bigHeader := func(l int) (packet []byte) {
|
||||
// header + 3 bytes of LE
|
||||
var buf [4]byte
|
||||
binary.LittleEndian.PutUint32(buf[:], uint32(l>>2))
|
||||
|
||||
packet = append([]byte{127}, buf[0:3]...)
|
||||
return
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
payloadName string
|
||||
testData func() (payload string, packet []byte)
|
||||
}{
|
||||
{"Small-4b", func() (payload string, packet []byte) {
|
||||
payload = "abcd"
|
||||
packet = append([]byte{byte(len(payload) >> 2)}, payload...)
|
||||
return
|
||||
}},
|
||||
|
||||
{"Medium-124b", func() (payload string, packet []byte) {
|
||||
payload = strings.Repeat("a", 124)
|
||||
packet = append([]byte{byte(len(payload) >> 2)}, payload...)
|
||||
return
|
||||
}},
|
||||
|
||||
{"Big-1kb", func() (payload string, packet []byte) {
|
||||
payload = strings.Repeat("a", 1024)
|
||||
packet = bigHeader(len(payload))
|
||||
require.Equal(t, []byte{127, 0, 1, 0}, packet)
|
||||
packet = append(packet, payload...)
|
||||
return
|
||||
}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
payload, packet := test.testData()
|
||||
t.Run(test.payloadName, func(t *testing.T) {
|
||||
t.Run("Write", func(t *testing.T) {
|
||||
b := bytes.NewBuffer(nil)
|
||||
err := writeAbridged(b, &bin.Buffer{Buf: []byte(payload)})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, packet, b.Bytes())
|
||||
})
|
||||
|
||||
t.Run("Read", func(t *testing.T) {
|
||||
b := &bin.Buffer{}
|
||||
err := readAbridged(bytes.NewReader(packet), b)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, payload, string(b.Raw()))
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
func benchWrite(codec Codec) func(payloadSize int) func(b *testing.B) {
|
||||
return func(payloadSize int) func(b *testing.B) {
|
||||
return func(b *testing.B) {
|
||||
buf := bin.Buffer{Buf: make([]byte, payloadSize)}
|
||||
if _, err := io.ReadFull(rand.Reader, buf.Buf); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(buf.Len() + 4))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := codec.Write(io.Discard, &buf); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWrite(b *testing.B) {
|
||||
b.Run("Abridged", func(b *testing.B) {
|
||||
testutil.RunPayloads(b, benchWrite(Abridged{}))
|
||||
})
|
||||
b.Run("Intermediate", func(b *testing.B) {
|
||||
testutil.RunPayloads(b, benchWrite(Intermediate{}))
|
||||
})
|
||||
b.Run("PaddedIntermediate", func(b *testing.B) {
|
||||
testutil.RunPayloads(b, benchWrite(PaddedIntermediate{}))
|
||||
})
|
||||
}
|
||||
|
||||
func benchRead(codec Codec) func(payloadSize int) func(b *testing.B) {
|
||||
return func(payloadSize int) func(b *testing.B) {
|
||||
return func(b *testing.B) {
|
||||
buf := bin.Buffer{Buf: make([]byte, payloadSize)}
|
||||
if _, err := io.ReadFull(rand.Reader, buf.Buf); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
out := new(bytes.Buffer)
|
||||
if err := codec.Write(out, &buf); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
raw := out.Bytes()
|
||||
reader := bytes.NewReader(nil)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(buf.Len() + 4))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
reader.Reset(raw)
|
||||
if err := codec.Read(reader, &buf); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
buf.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRead(b *testing.B) {
|
||||
b.Run("Abridged", func(b *testing.B) {
|
||||
testutil.RunPayloads(b, benchRead(Abridged{}))
|
||||
})
|
||||
b.Run("Intermediate", func(b *testing.B) {
|
||||
testutil.RunPayloads(b, benchRead(Intermediate{}))
|
||||
})
|
||||
b.Run("PaddedIntermediate", func(b *testing.B) {
|
||||
testutil.RunPayloads(b, benchRead(PaddedIntermediate{}))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// Codec is MTProto transport protocol encoding abstraction.
|
||||
type Codec interface {
|
||||
// WriteHeader sends protocol tag if needed.
|
||||
WriteHeader(w io.Writer) error
|
||||
// ReadHeader reads protocol tag if needed.
|
||||
ReadHeader(r io.Reader) error
|
||||
// Write encode to writer message from given buffer.
|
||||
Write(w io.Writer, b *bin.Buffer) error
|
||||
// Read fills buffer with received message.
|
||||
Read(r io.Reader, b *bin.Buffer) error
|
||||
}
|
||||
|
||||
// TaggedCodec is codec with protocol tag.
|
||||
type TaggedCodec interface {
|
||||
Codec
|
||||
// ObfuscatedTag returns protocol tag for obfuscation.
|
||||
ObfuscatedTag() [4]byte
|
||||
}
|
||||
|
||||
// readLen reads 32-bit integer and validates it as message length.
|
||||
func readLen(r io.Reader, b *bin.Buffer) (int, error) {
|
||||
b.ResetN(bin.Word)
|
||||
if _, err := io.ReadFull(r, b.Buf[:bin.Word]); err != nil {
|
||||
return 0, errors.Wrap(err, "read length")
|
||||
}
|
||||
n := int(binary.LittleEndian.Uint32(b.Buf[:bin.Word]))
|
||||
|
||||
if n <= 0 || n > maxMessageSize {
|
||||
return 0, invalidMsgLenErr{n: n}
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
type codecTest struct {
|
||||
name string
|
||||
align int
|
||||
create func() Codec
|
||||
}
|
||||
|
||||
func codecs() []codecTest {
|
||||
return []codecTest{
|
||||
{"Abridged", 4, func() Codec {
|
||||
return Abridged{}
|
||||
}},
|
||||
{"Intermediate", 4, func() Codec {
|
||||
return Intermediate{}
|
||||
}},
|
||||
{"PaddedIntermediate", 4, func() Codec {
|
||||
return PaddedIntermediate{}
|
||||
}},
|
||||
{"Full", 0, func() Codec {
|
||||
return &Full{}
|
||||
}},
|
||||
{"NoHeaderIntermediate", 0, func() Codec {
|
||||
return NoHeader{Codec: Intermediate{}}
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
type payload struct {
|
||||
name string
|
||||
testData string
|
||||
mustFail bool
|
||||
readTestOnly bool
|
||||
}
|
||||
|
||||
func payloads() []payload {
|
||||
var code [4]byte
|
||||
binary.LittleEndian.PutUint32(code[:], CodeTransportFlood)
|
||||
return []payload{
|
||||
{"Empty", "", true, false},
|
||||
{"Protocol error", string(code[:]), true, true},
|
||||
{"Small 8b", "abcdabcd", false, false},
|
||||
{"Medium 1kb", strings.Repeat("a", 1024), false, false},
|
||||
}
|
||||
}
|
||||
|
||||
func testGood(c codecTest, p payload) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
t.Run("One message", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
codec := c.create()
|
||||
buf := bytes.NewBuffer(nil)
|
||||
payload := &bin.Buffer{Buf: []byte(p.testData)}
|
||||
|
||||
l := payload.Len()
|
||||
// Encode
|
||||
a.NoError(codec.Write(buf, payload))
|
||||
a.Equal(l, payload.Len(), "Codec must not change buffer length")
|
||||
|
||||
// Decode
|
||||
payload.Reset()
|
||||
a.NoError(codec.Read(buf, payload))
|
||||
a.Equal(p.testData, string(payload.Buf))
|
||||
})
|
||||
|
||||
t.Run("Two messages", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
codec := c.create()
|
||||
buf := bytes.NewBuffer(nil)
|
||||
payload := &bin.Buffer{Buf: []byte(p.testData)}
|
||||
|
||||
l := payload.Len()
|
||||
// Encode twice
|
||||
a.NoError(codec.Write(buf, payload))
|
||||
payload.ResetTo([]byte(p.testData))
|
||||
a.NoError(codec.Write(buf, payload))
|
||||
a.Equal(l, payload.Len(), "Codec must not change buffer length")
|
||||
|
||||
// Decode twice
|
||||
payload.Reset()
|
||||
a.NoError(codec.Read(buf, payload))
|
||||
a.Equal(p.testData, string(payload.Buf))
|
||||
payload.Reset()
|
||||
a.NoError(codec.Read(buf, payload))
|
||||
a.Equal(p.testData, string(payload.Buf))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testBad(c codecTest, p payload) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
if !p.readTestOnly {
|
||||
t.Run("Write", func(t *testing.T) {
|
||||
payload := &bin.Buffer{Buf: []byte(p.testData)}
|
||||
require.Error(t, c.create().Write(io.Discard, payload))
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Read", func(t *testing.T) {
|
||||
reader := bytes.NewBufferString(p.testData)
|
||||
require.Error(t, c.create().Read(reader, &bin.Buffer{}))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testHeaderTag(c codecTest) func(t *testing.T) {
|
||||
e := io.ErrClosedPipe
|
||||
return func(t *testing.T) {
|
||||
t.Run("GoodTag", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
buf := bytes.NewBuffer(nil)
|
||||
a.NoError(c.create().WriteHeader(buf))
|
||||
a.NoError(c.create().ReadHeader(buf))
|
||||
})
|
||||
|
||||
if tagged, ok := c.create().(TaggedCodec); ok {
|
||||
t.Run("ReadError", func(t *testing.T) {
|
||||
r := iotest.ErrReader(e)
|
||||
require.ErrorIs(t, c.create().ReadHeader(r), e)
|
||||
})
|
||||
t.Run("WriteError", func(t *testing.T) {
|
||||
w := testutil.ErrWriter(e)
|
||||
require.ErrorIs(t, c.create().WriteHeader(w), e)
|
||||
})
|
||||
t.Run("BadTag", func(t *testing.T) {
|
||||
tag := tagged.ObfuscatedTag()
|
||||
buf := bytes.NewBuffer(tag[:])
|
||||
tag[0] = 0
|
||||
require.ErrorIs(t, c.create().ReadHeader(buf), ErrProtocolHeaderMismatch)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testCodec(c codecTest) func(t *testing.T) {
|
||||
e := io.ErrClosedPipe
|
||||
|
||||
return func(t *testing.T) {
|
||||
t.Run("ReadError", func(t *testing.T) {
|
||||
r := iotest.ErrReader(e)
|
||||
require.ErrorIs(t, c.create().Read(r, &bin.Buffer{}), e)
|
||||
})
|
||||
|
||||
t.Run("WriteError", func(t *testing.T) {
|
||||
w := testutil.ErrWriter(e)
|
||||
require.ErrorIs(t,
|
||||
c.create().Write(
|
||||
w,
|
||||
&bin.Buffer{Buf: make([]byte, 16)},
|
||||
),
|
||||
e,
|
||||
)
|
||||
})
|
||||
|
||||
if c.align != 0 {
|
||||
t.Run("AlignError", func(t *testing.T) {
|
||||
require.Error(t,
|
||||
c.create().Write(
|
||||
io.Discard,
|
||||
&bin.Buffer{Buf: make([]byte, c.align-1)},
|
||||
),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodecs(t *testing.T) {
|
||||
for _, c := range codecs() {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
for _, p := range payloads() {
|
||||
if p.mustFail {
|
||||
t.Run(p.name, testBad(c, p))
|
||||
} else {
|
||||
t.Run(p.name, testGood(c, p))
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Header", testHeaderTag(c))
|
||||
testCodec(c)(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
// Package codec contains MTProto transport encoding implementations.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports
|
||||
package codec
|
||||
@@ -0,0 +1,106 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
const (
|
||||
// CodeAuthKeyNotFound means that specified auth key ID cannot be found by the DC.
|
||||
// Also, may be returned during key exchange.
|
||||
CodeAuthKeyNotFound = 404
|
||||
|
||||
// CodeWrongDC means that current DC is wrong.
|
||||
// Usually returned by server when key exchange sends wrong DC ID.
|
||||
CodeWrongDC = 444
|
||||
|
||||
// CodeTransportFlood means that too many transport connections are
|
||||
// established to the same IP in a too short lapse of time, or if any
|
||||
// of the container/service message limits are reached.
|
||||
CodeTransportFlood = 429
|
||||
)
|
||||
|
||||
// ProtocolErr represents protocol level error.
|
||||
type ProtocolErr struct {
|
||||
Code int32
|
||||
}
|
||||
|
||||
func (p ProtocolErr) Error() string {
|
||||
switch p.Code {
|
||||
case CodeAuthKeyNotFound:
|
||||
return "auth key not found"
|
||||
case CodeTransportFlood:
|
||||
return "transport flood"
|
||||
case CodeWrongDC:
|
||||
return "wrong DC"
|
||||
default:
|
||||
return fmt.Sprintf("protocol error %d", p.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Can be bigger that 1mb.
|
||||
//
|
||||
// See https://github.com/gotd/td/issues/412
|
||||
//
|
||||
// See https://github.com/tdlib/td/blob/550ccc8d9bbbe9cff1dc618aef5764d2cbd2cd91/td/mtproto/TcpTransport.cpp#L53
|
||||
const maxMessageSize = 1 << 24 // 16 MB
|
||||
|
||||
func checkOutgoingMessage(b *bin.Buffer) error {
|
||||
length := b.Len()
|
||||
if length > maxMessageSize || length == 0 {
|
||||
return invalidMsgLenErr{n: length}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkAlign(b *bin.Buffer, n int) error {
|
||||
length := b.Len()
|
||||
if length%n != 0 {
|
||||
return alignedPayloadExpectedErr{expected: n}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkProtocolError(b *bin.Buffer) error {
|
||||
if b.Len() != bin.Word {
|
||||
return nil
|
||||
}
|
||||
code, err := b.Int32()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return &ProtocolErr{Code: -code}
|
||||
}
|
||||
|
||||
type alignedPayloadExpectedErr struct {
|
||||
expected int
|
||||
}
|
||||
|
||||
func (e alignedPayloadExpectedErr) Error() string {
|
||||
return fmt.Sprintf("payload is not aligned, expected align by %d", e.expected)
|
||||
}
|
||||
|
||||
func (e alignedPayloadExpectedErr) Is(err error) bool {
|
||||
_, ok := err.(alignedPayloadExpectedErr)
|
||||
return ok
|
||||
}
|
||||
|
||||
type invalidMsgLenErr struct {
|
||||
n int
|
||||
}
|
||||
|
||||
func (e invalidMsgLenErr) Error() string {
|
||||
return fmt.Sprintf("invalid message length %d", e.n)
|
||||
}
|
||||
|
||||
func (e invalidMsgLenErr) Is(err error) bool {
|
||||
_, ok := err.(invalidMsgLenErr)
|
||||
return ok
|
||||
}
|
||||
|
||||
// ErrProtocolHeaderMismatch means that received protocol header
|
||||
// is mismatched with expected.
|
||||
var ErrProtocolHeaderMismatch = errors.New("protocol header mismatch")
|
||||
@@ -0,0 +1,124 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// Full is full MTProto transport.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports#full
|
||||
type Full struct {
|
||||
wSeqNo int64
|
||||
rSeqNo int64
|
||||
}
|
||||
|
||||
// WriteHeader sends protocol tag.
|
||||
func (i *Full) WriteHeader(w io.Writer) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadHeader reads protocol tag.
|
||||
func (i *Full) ReadHeader(r io.Reader) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write encode to writer message from given buffer.
|
||||
func (i *Full) Write(w io.Writer, b *bin.Buffer) error {
|
||||
if err := checkOutgoingMessage(b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeFull(w, int(atomic.AddInt64(&i.wSeqNo, 1)-1), b); err != nil {
|
||||
return errors.Wrap(err, "write full")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read fills buffer with received message.
|
||||
func (i *Full) Read(r io.Reader, b *bin.Buffer) error {
|
||||
if err := readFull(r, int(atomic.AddInt64(&i.rSeqNo, 1)-1), b); err != nil {
|
||||
return errors.Wrap(err, "read full")
|
||||
}
|
||||
|
||||
return checkProtocolError(b)
|
||||
}
|
||||
|
||||
func writeFull(w io.Writer, seqNo int, b *bin.Buffer) error {
|
||||
write := bin.Buffer{Buf: make([]byte, 0, 4+4+b.Len()+4)}
|
||||
// Length: length+seqno+payload+crc length encoded as 4 length bytes
|
||||
// (little endian, the length of the length field must be included, too)
|
||||
write.PutInt(4 + 4 + b.Len() + 4)
|
||||
// Seqno: the TCP sequence number for this TCP connection (different from the MTProto sequence number):
|
||||
// the first packet sent is numbered 0, the next one 1, etc.
|
||||
write.PutInt(seqNo)
|
||||
// payload: MTProto payload
|
||||
write.Put(b.Raw())
|
||||
// crc: 4 CRC32 bytes computed using length, sequence number, and payload together.
|
||||
crc := crc32.ChecksumIEEE(write.Raw())
|
||||
write.PutUint32(crc)
|
||||
|
||||
if _, err := w.Write(write.Raw()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var errSeqNoMismatch = errors.New("seq_no mismatch")
|
||||
var errCRCMismatch = errors.New("crc mismatch")
|
||||
|
||||
func readFull(r io.Reader, seqNo int, b *bin.Buffer) error {
|
||||
n, err := readLen(r, b)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "len")
|
||||
}
|
||||
|
||||
// Put length, because it need to count CRC.
|
||||
b.PutInt(n)
|
||||
b.Expand(n - bin.Word)
|
||||
inner := &bin.Buffer{Buf: b.Buf[bin.Word:n]}
|
||||
|
||||
// Reads tail of packet to the buffer.
|
||||
// Length already read.
|
||||
if _, err := io.ReadFull(r, inner.Buf); err != nil {
|
||||
return errors.Wrap(err, "read seqno, buffer and crc")
|
||||
}
|
||||
|
||||
serverSeqNo, err := inner.Int()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if serverSeqNo != seqNo {
|
||||
return errSeqNoMismatch
|
||||
}
|
||||
|
||||
payloadLength := n - 3*bin.Word
|
||||
inner.Skip(payloadLength)
|
||||
|
||||
// Cut only crc part.
|
||||
crc, err := inner.Uint32()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compute crc using all buffer without last 4 bytes from server.
|
||||
clientCRC := crc32.ChecksumIEEE(b.Buf[0 : n-bin.Word])
|
||||
// Compare computed and read CRCs.
|
||||
if crc != clientCRC {
|
||||
return errCRCMismatch
|
||||
}
|
||||
|
||||
// n
|
||||
// Length | SeqNo | payload | CRC |
|
||||
// Word | Word | ....... | Word |
|
||||
copy(b.Buf, b.Buf[2*bin.Word:n-bin.Word])
|
||||
b.Buf = b.Buf[:payloadLength]
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
func fullTestData() (packet, payload []byte) {
|
||||
return []byte{
|
||||
15, 0, 0, 0, // length
|
||||
1, 0, 0, 0, // seqNo
|
||||
97, 98, 99, // payload
|
||||
78, 214, 109, 148, // crc
|
||||
}, []byte("abc")
|
||||
}
|
||||
|
||||
func TestFull(t *testing.T) {
|
||||
packet, payload := fullTestData()
|
||||
t.Run("write", func(t *testing.T) {
|
||||
b := bytes.NewBuffer(nil)
|
||||
|
||||
buf := &bin.Buffer{Buf: payload}
|
||||
err := writeFull(b, 1, buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, packet, b.Bytes())
|
||||
})
|
||||
|
||||
t.Run("read", func(t *testing.T) {
|
||||
b := &bin.Buffer{}
|
||||
err := readFull(bytes.NewBuffer(packet), 1, b)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, payload, b.Buf)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// IntermediateClientStart is starting bytes sent by client in Intermediate mode.
|
||||
//
|
||||
// Note that server does not respond with it.
|
||||
var IntermediateClientStart = [4]byte{0xee, 0xee, 0xee, 0xee}
|
||||
|
||||
// Intermediate is intermediate MTProto transport.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports#intermediate
|
||||
type Intermediate struct{}
|
||||
|
||||
var (
|
||||
_ TaggedCodec = Intermediate{}
|
||||
)
|
||||
|
||||
// WriteHeader sends protocol tag.
|
||||
func (i Intermediate) WriteHeader(w io.Writer) (err error) {
|
||||
if _, err := w.Write(IntermediateClientStart[:]); err != nil {
|
||||
return errors.Wrap(err, "write intermediate header")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadHeader reads protocol tag.
|
||||
func (i Intermediate) ReadHeader(r io.Reader) (err error) {
|
||||
var b [4]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
return errors.Wrap(err, "read intermediate header")
|
||||
}
|
||||
|
||||
if b != IntermediateClientStart {
|
||||
return ErrProtocolHeaderMismatch
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ObfuscatedTag returns protocol tag for obfuscation.
|
||||
func (i Intermediate) ObfuscatedTag() [4]byte {
|
||||
return IntermediateClientStart
|
||||
}
|
||||
|
||||
// Write encode to writer message from given buffer.
|
||||
func (i Intermediate) Write(w io.Writer, b *bin.Buffer) error {
|
||||
if err := checkOutgoingMessage(b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := checkAlign(b, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeIntermediate(w, b); err != nil {
|
||||
return errors.Wrap(err, "write intermediate")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read fills buffer with received message.
|
||||
func (i Intermediate) Read(r io.Reader, b *bin.Buffer) error {
|
||||
if err := readIntermediate(r, b, false); err != nil {
|
||||
return errors.Wrap(err, "read intermediate")
|
||||
}
|
||||
|
||||
return checkProtocolError(b)
|
||||
}
|
||||
|
||||
// writeIntermediate encodes b as payload to w.
|
||||
func writeIntermediate(w io.Writer, b *bin.Buffer) error {
|
||||
length := b.Len()
|
||||
// Re-using b.Buf if possible to reduce allocations.
|
||||
b.Expand(4)
|
||||
b.Buf = b.Buf[:length]
|
||||
|
||||
inner := bin.Buffer{Buf: b.Buf[length:length]}
|
||||
inner.PutInt(b.Len())
|
||||
if _, err := w.Write(inner.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write(b.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readIntermediate reads payload from r to b.
|
||||
func readIntermediate(r io.Reader, b *bin.Buffer, padding bool) error {
|
||||
n, err := readLen(r, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.ResetN(n)
|
||||
if _, err := io.ReadFull(r, b.Buf); err != nil {
|
||||
return errors.Wrap(err, "read payload")
|
||||
}
|
||||
|
||||
if padding {
|
||||
paddingLength := n % 4
|
||||
b.Buf = b.Buf[:n-paddingLength]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
func TestIntermediate(t *testing.T) {
|
||||
t.Run("Ok", func(t *testing.T) {
|
||||
msg := bytes.Repeat([]byte{1, 2, 3}, 100)
|
||||
buf := new(bytes.Buffer)
|
||||
if err := writeIntermediate(buf, &bin.Buffer{Buf: msg}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var out bin.Buffer
|
||||
if err := readIntermediate(buf, &out, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Equal(t, msg, out.Buf)
|
||||
})
|
||||
t.Run("BigMessage", func(t *testing.T) {
|
||||
codec := Intermediate{}
|
||||
t.Run("Read", func(t *testing.T) {
|
||||
var b bin.Buffer
|
||||
b.PutInt(maxMessageSize + 10)
|
||||
|
||||
var out bin.Buffer
|
||||
if err := codec.Read(&b, &out); !errors.Is(err, invalidMsgLenErr{}) {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
t.Run("Write", func(t *testing.T) {
|
||||
buf := make([]byte, maxMessageSize+10)
|
||||
|
||||
if err := codec.Write(nil, &bin.Buffer{Buf: buf}); !errors.Is(err, invalidMsgLenErr{}) {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package codec
|
||||
|
||||
import "io"
|
||||
|
||||
// NoHeader wraps codec to skip WriteHeader.
|
||||
type NoHeader struct {
|
||||
Codec
|
||||
}
|
||||
|
||||
// WriteHeader implements Codec.
|
||||
func (NoHeader) WriteHeader(io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadHeader implements Codec.
|
||||
func (NoHeader) ReadHeader(io.Reader) error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNoHeader(t *testing.T) {
|
||||
a := require.New(t)
|
||||
cdc := NoHeader{
|
||||
Codec: Intermediate{},
|
||||
}
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
a.NoError(cdc.WriteHeader(&buf))
|
||||
a.Equal(0, buf.Len())
|
||||
a.NoError(cdc.ReadHeader(&buf))
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
)
|
||||
|
||||
// PaddedIntermediateClientStart is starting bytes sent by client in Padded intermediate mode.
|
||||
//
|
||||
// Note that server does not respond with it.
|
||||
var PaddedIntermediateClientStart = [4]byte{0xdd, 0xdd, 0xdd, 0xdd}
|
||||
|
||||
// PaddedIntermediate is intermediate MTProto transport.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports#padded-intermediate
|
||||
type PaddedIntermediate struct{}
|
||||
|
||||
var (
|
||||
_ TaggedCodec = PaddedIntermediate{}
|
||||
)
|
||||
|
||||
// WriteHeader sends protocol tag.
|
||||
func (i PaddedIntermediate) WriteHeader(w io.Writer) error {
|
||||
if _, err := w.Write(PaddedIntermediateClientStart[:]); err != nil {
|
||||
return errors.Wrap(err, "write padded intermediate header")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadHeader reads protocol tag.
|
||||
func (i PaddedIntermediate) ReadHeader(r io.Reader) error {
|
||||
var b [4]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
return errors.Wrap(err, "read padded intermediate header")
|
||||
}
|
||||
|
||||
if b != PaddedIntermediateClientStart {
|
||||
return ErrProtocolHeaderMismatch
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ObfuscatedTag returns protocol tag for obfuscation.
|
||||
func (i PaddedIntermediate) ObfuscatedTag() [4]byte {
|
||||
return PaddedIntermediateClientStart
|
||||
}
|
||||
|
||||
// Write encode to writer message from given buffer.
|
||||
func (i PaddedIntermediate) Write(w io.Writer, b *bin.Buffer) error {
|
||||
if err := checkOutgoingMessage(b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := checkAlign(b, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writePaddedIntermediate(crypto.DefaultRand(), w, b); err != nil {
|
||||
return errors.Wrap(err, "write padded intermediate")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read fills buffer with received message.
|
||||
func (i PaddedIntermediate) Read(r io.Reader, b *bin.Buffer) error {
|
||||
if err := readPaddedIntermediate(r, b); err != nil {
|
||||
return errors.Wrap(err, "read padded intermediate")
|
||||
}
|
||||
|
||||
return checkProtocolError(b)
|
||||
}
|
||||
|
||||
func writePaddedIntermediate(randSource io.Reader, w io.Writer, b *bin.Buffer) error {
|
||||
length := b.Len()
|
||||
|
||||
b.Expand(4)
|
||||
defer func() {
|
||||
b.Buf = b.Buf[:length]
|
||||
}()
|
||||
|
||||
_, err := io.ReadFull(randSource, b.Buf[length:length+4])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n := int(b.Buf[length-1]) % 4
|
||||
b.Buf = b.Buf[:length+n]
|
||||
|
||||
return writeIntermediate(w, b)
|
||||
}
|
||||
|
||||
func readPaddedIntermediate(r io.Reader, b *bin.Buffer) error {
|
||||
if err := readIntermediate(r, b, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
padding := b.Len() % 4
|
||||
b.Buf = b.Buf[:b.Len()-padding]
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user