move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// Codec is MTProto transport protocol encoding abstraction.
|
||||
type Codec interface {
|
||||
// WriteHeader sends protocol tag if needed.
|
||||
WriteHeader(w io.Writer) error
|
||||
// ReadHeader reads protocol tag if needed.
|
||||
ReadHeader(r io.Reader) error
|
||||
// Write encode to writer message from given buffer.
|
||||
Write(w io.Writer, b *bin.Buffer) error
|
||||
// Read fills buffer with received message.
|
||||
Read(r io.Reader, b *bin.Buffer) error
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
// Conn is transport connection.
|
||||
type Conn interface {
|
||||
Send(ctx context.Context, b *bin.Buffer) error
|
||||
Recv(ctx context.Context, b *bin.Buffer) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
var _ Conn = (*connection)(nil)
|
||||
|
||||
// connection is MTProto connection.
|
||||
type connection struct {
|
||||
conn net.Conn
|
||||
codec Codec
|
||||
|
||||
readMux sync.Mutex
|
||||
writeMux sync.Mutex
|
||||
}
|
||||
|
||||
// Send sends message from buffer using MTProto connection.
|
||||
func (c *connection) Send(ctx context.Context, b *bin.Buffer) error {
|
||||
// Serializing access to deadlines.
|
||||
c.writeMux.Lock()
|
||||
defer c.writeMux.Unlock()
|
||||
|
||||
if err := c.conn.SetWriteDeadline(time.Time{}); err != nil {
|
||||
return errors.Wrap(err, "reset write deadline")
|
||||
}
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := c.conn.SetWriteDeadline(deadline); err != nil {
|
||||
return errors.Wrap(err, "set write deadline")
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.codec.Write(c.conn, b); err != nil {
|
||||
return errors.Wrap(err, "write")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Recv reads message to buffer using MTProto connection.
|
||||
func (c *connection) Recv(ctx context.Context, b *bin.Buffer) error {
|
||||
// Serializing access to deadlines.
|
||||
c.readMux.Lock()
|
||||
defer c.readMux.Unlock()
|
||||
|
||||
if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return errors.Wrap(err, "reset read deadline")
|
||||
}
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := c.conn.SetReadDeadline(deadline); err != nil {
|
||||
return errors.Wrap(err, "set read deadline")
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.codec.Read(c.conn, b); err != nil {
|
||||
return errors.Wrap(err, "read")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes MTProto connection.
|
||||
func (c *connection) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
)
|
||||
|
||||
func TestConnection(t *testing.T) {
|
||||
leftConn, rightConn := net.Pipe()
|
||||
intermediate := codec.Intermediate{}
|
||||
|
||||
left := connection{
|
||||
conn: leftConn,
|
||||
codec: intermediate,
|
||||
}
|
||||
right := connection{
|
||||
conn: rightConn,
|
||||
codec: intermediate,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
buf := bytes.Repeat([]byte{1, 2, 3, 4}, 50)
|
||||
done := make(chan []byte)
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
var b bin.Buffer
|
||||
if err := right.Recv(ctx, &b); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- b.Buf
|
||||
}()
|
||||
|
||||
if err := left.Send(ctx, &bin.Buffer{Buf: buf}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
require.Equal(t, buf, <-done)
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
)
|
||||
|
||||
func detectCodec(c io.Reader) (Codec, io.Reader, error) {
|
||||
var buf [4]byte
|
||||
if _, err := io.ReadFull(c, buf[:1]); err != nil {
|
||||
return nil, nil, errors.Wrap(err, "read first byte")
|
||||
}
|
||||
|
||||
if buf[0] == codec.AbridgedClientStart[0] {
|
||||
return Abridged.Codec(), c, nil
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(c, buf[1:4]); err != nil {
|
||||
return nil, nil, errors.Wrap(err, "read header")
|
||||
}
|
||||
switch buf {
|
||||
case codec.IntermediateClientStart:
|
||||
return Intermediate.Codec(), c, nil
|
||||
case codec.PaddedIntermediateClientStart:
|
||||
return PaddedIntermediate.Codec(), c, nil
|
||||
default:
|
||||
buffered := bytes.NewReader(buf[:])
|
||||
r := io.MultiReader(buffered, c)
|
||||
return Full.Codec(), r, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
)
|
||||
|
||||
func Test_detectCodec(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
resultType Codec
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"Abridged",
|
||||
codec.AbridgedClientStart[:],
|
||||
codec.Abridged{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"Intermediate",
|
||||
codec.IntermediateClientStart[:],
|
||||
codec.Intermediate{}, false,
|
||||
},
|
||||
{
|
||||
"PaddedIntermediate",
|
||||
codec.PaddedIntermediateClientStart[:],
|
||||
codec.PaddedIntermediate{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"Full",
|
||||
[]byte{'g', 'o', 't', 'd'},
|
||||
&codec.Full{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"EOF-first",
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"EOF-second",
|
||||
[]byte{'a'},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
|
||||
r, _, err := detectCodec(bytes.NewReader(test.data))
|
||||
if test.wantErr {
|
||||
a.Nil(r)
|
||||
a.Error(err)
|
||||
} else {
|
||||
a.NotNil(r)
|
||||
a.NoError(err)
|
||||
a.IsType(test.resultType, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
// Package transport contains different MTProto transport
|
||||
// implementations.
|
||||
package transport
|
||||
@@ -0,0 +1,92 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
// Listener is a simple net.Listener wrapper for listening
|
||||
// MTProto transport connections.
|
||||
type Listener struct {
|
||||
codec func() Codec
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
// Listen creates new Listener using given net.Listener.
|
||||
// Transport codec will be detected automatically.
|
||||
func Listen(listener net.Listener) Listener {
|
||||
return ListenCodec(nil, listener)
|
||||
}
|
||||
|
||||
// ListenCodec creates new Listener using given net.Listener.
|
||||
// Listener will always use given Codec constructor.
|
||||
func ListenCodec(codec func() Codec, listener net.Listener) Listener {
|
||||
return Listener{
|
||||
codec: codec,
|
||||
listener: &onceCloseListener{Listener: listener},
|
||||
}
|
||||
}
|
||||
|
||||
type wrappedConn struct {
|
||||
reader io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (w wrappedConn) Read(b []byte) (int, error) {
|
||||
return w.reader.Read(b)
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (l Listener) Accept() (_ Conn, rErr error) {
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if rErr != nil {
|
||||
multierr.AppendInto(&rErr, conn.Close())
|
||||
}
|
||||
}()
|
||||
|
||||
// If codec provided explicitly, use it.
|
||||
if l.codec != nil {
|
||||
codec := l.codec()
|
||||
|
||||
if err := codec.ReadHeader(conn); err != nil {
|
||||
return nil, errors.Wrap(err, "read header")
|
||||
}
|
||||
|
||||
return &connection{
|
||||
conn: conn,
|
||||
codec: codec,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Otherwise try to detect codec.
|
||||
transportCodec, reader, err := detectCodec(conn)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "detect codec")
|
||||
}
|
||||
|
||||
return &connection{
|
||||
conn: wrappedConn{
|
||||
reader: reader,
|
||||
Conn: conn,
|
||||
},
|
||||
codec: transportCodec,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the listener.
|
||||
// Any blocked Accept operations will be unblocked and return errors.
|
||||
func (l Listener) Close() error {
|
||||
return l.listener.Close()
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l Listener) Addr() net.Addr {
|
||||
return l.listener.Addr()
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
type mockListener struct {
|
||||
connData []byte
|
||||
acceptErr error
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
type mockConn struct {
|
||||
reader bytes.Reader
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (m *mockConn) Read(b []byte) (int, error) {
|
||||
return m.reader.Read(b)
|
||||
}
|
||||
|
||||
func (m *mockConn) Write(b []byte) (int, error) {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
func (m *mockConn) Close() error {
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConn) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockListener) Accept() (net.Conn, error) {
|
||||
return &mockConn{
|
||||
reader: *bytes.NewReader(m.connData),
|
||||
closed: false,
|
||||
}, m.acceptErr
|
||||
}
|
||||
|
||||
func (m mockListener) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockListener) Addr() net.Addr {
|
||||
return m.addr
|
||||
}
|
||||
|
||||
func TestListener_Accept(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
codec func() Codec
|
||||
wantErr bool
|
||||
}{
|
||||
{"DetectCodec", codec.PaddedIntermediateClientStart[:], nil, false},
|
||||
{"PassCodec", codec.AbridgedClientStart[:], Abridged.Codec, false},
|
||||
{"InvalidCodec", codec.PaddedIntermediateClientStart[:], Abridged.Codec, true},
|
||||
{"FirstByteError", nil, nil, true},
|
||||
{"HeaderError", make([]byte, 3), nil, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
m := mockListener{
|
||||
connData: tt.data,
|
||||
}
|
||||
|
||||
l := ListenCodec(tt.codec, m)
|
||||
defer func() {
|
||||
a.NoError(l.Close())
|
||||
}()
|
||||
|
||||
conn, err := l.Accept()
|
||||
if tt.wantErr {
|
||||
a.Error(err)
|
||||
if c, ok := conn.(*connection); ok {
|
||||
a.True(c.conn.(*mockConn).closed)
|
||||
}
|
||||
} else {
|
||||
a.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("AcceptError", func(t *testing.T) {
|
||||
e := testutil.TestError()
|
||||
m := Listener{
|
||||
listener: mockListener{
|
||||
acceptErr: e,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := m.Accept()
|
||||
require.ErrorIs(t, err, e)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListener_Addr(t *testing.T) {
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 443,
|
||||
Zone: "",
|
||||
}
|
||||
|
||||
l := Listener{listener: mockListener{addr: addr}}
|
||||
require.Equal(t, addr, l.Addr())
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy/obfuscated2"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
)
|
||||
|
||||
type obfListener struct {
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
type obfConn struct {
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *obfConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *obfConn) Write(p []byte) (int, error) {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
|
||||
// ObfuscatedListener creates new obfuscated2 listener using given net.Listener.
|
||||
//
|
||||
// Useful for creating Telegram servers:
|
||||
//
|
||||
// transport.Listen(transport.ObfuscatedListener(ln))
|
||||
func ObfuscatedListener(listener net.Listener) net.Listener {
|
||||
return obfListener{listener: listener}
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (l obfListener) Accept() (_ net.Conn, err error) {
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
multierr.AppendInto(&err, conn.Close())
|
||||
}
|
||||
}()
|
||||
|
||||
rw, md, err := obfuscated2.Accept(conn, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "accept")
|
||||
}
|
||||
|
||||
var tag *bytes.Reader
|
||||
if md.Protocol[0] == codec.AbridgedClientStart[0] {
|
||||
// Abridged sends only byte for tag.
|
||||
tag = bytes.NewReader(md.Protocol[:1])
|
||||
} else {
|
||||
tag = bytes.NewReader(md.Protocol[:])
|
||||
}
|
||||
|
||||
accepted := &obfConn{
|
||||
reader: io.MultiReader(tag, rw),
|
||||
writer: rw,
|
||||
Conn: conn,
|
||||
}
|
||||
|
||||
return accepted, nil
|
||||
}
|
||||
|
||||
// Close closes the listener.
|
||||
// Any blocked Accept operations will be unblocked and return errors.
|
||||
func (l obfListener) Close() error {
|
||||
return l.listener.Close()
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l obfListener) Addr() net.Addr {
|
||||
return l.listener.Addr()
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// onceCloseListener wraps a net.Listener, protecting it from
|
||||
// multiple Close calls.
|
||||
type onceCloseListener struct {
|
||||
net.Listener
|
||||
once sync.Once
|
||||
err error
|
||||
}
|
||||
|
||||
func (o *onceCloseListener) Close() error {
|
||||
o.once.Do(func() {
|
||||
o.err = o.Listener.Close()
|
||||
})
|
||||
return o.err
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
)
|
||||
|
||||
type closeMockListener struct {
|
||||
closed int
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *closeMockListener) Accept() (net.Conn, error) {
|
||||
panic("unexpected call")
|
||||
}
|
||||
|
||||
func (m *closeMockListener) Addr() net.Addr {
|
||||
panic("unexpected call")
|
||||
}
|
||||
|
||||
func (m *closeMockListener) Close() error {
|
||||
m.closed++
|
||||
return m.err
|
||||
}
|
||||
|
||||
func Test_onceCloseListener_Close(t *testing.T) {
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
m := &closeMockListener{}
|
||||
once := onceCloseListener{Listener: m}
|
||||
require.NoError(t, once.Close())
|
||||
require.NoError(t, once.Close())
|
||||
require.Equal(t, 1, m.closed)
|
||||
})
|
||||
|
||||
t.Run("With Error", func(t *testing.T) {
|
||||
testErr := testutil.TestError()
|
||||
m := &closeMockListener{err: testErr}
|
||||
once := onceCloseListener{Listener: m}
|
||||
require.Equal(t, testErr, once.Close())
|
||||
require.Equal(t, testErr, once.Close())
|
||||
require.Equal(t, 1, m.closed)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
)
|
||||
|
||||
// Protocol is MTProto transport protocol.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports
|
||||
type Protocol struct {
|
||||
codec func() Codec
|
||||
}
|
||||
|
||||
// NewProtocol creates new transport protocol using user Codec constructor.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports
|
||||
func NewProtocol(getCodec func() Codec) Protocol {
|
||||
return Protocol{
|
||||
codec: getCodec,
|
||||
}
|
||||
}
|
||||
|
||||
// Telegram transport protocols.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports
|
||||
var (
|
||||
// Abridged is abridged transport protocol.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports#abridged
|
||||
Abridged = NewProtocol(func() Codec { return codec.Abridged{} })
|
||||
|
||||
// Intermediate is intermediate transport protocol.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports#intermediate
|
||||
Intermediate = NewProtocol(func() Codec { return codec.Intermediate{} })
|
||||
|
||||
// PaddedIntermediate is padded intermediate transport protocol.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports#padded-intermediate
|
||||
PaddedIntermediate = NewProtocol(func() Codec { return codec.PaddedIntermediate{} })
|
||||
|
||||
// Full is full transport protocol.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports#full
|
||||
Full = NewProtocol(func() Codec { return &codec.Full{} })
|
||||
)
|
||||
|
||||
// Codec creates new codec using protocol settings.
|
||||
func (p Protocol) Codec() Codec {
|
||||
return p.codec()
|
||||
}
|
||||
|
||||
// CodecNoHeader is Codec without header.
|
||||
func (p Protocol) CodecNoHeader() Codec {
|
||||
return codec.NoHeader{Codec: p.codec()}
|
||||
}
|
||||
|
||||
// Handshake inits given net.Conn as MTProto connection.
|
||||
func (p Protocol) Handshake(conn net.Conn) (Conn, error) {
|
||||
connCodec := p.codec()
|
||||
if err := connCodec.WriteHeader(conn); err != nil {
|
||||
return nil, errors.Wrap(err, "write header")
|
||||
}
|
||||
|
||||
return &connection{
|
||||
conn: conn,
|
||||
codec: connCodec,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Pipe creates an in-memory MTProto connection.
|
||||
func (p Protocol) Pipe() (a, b Conn) {
|
||||
p1, p2 := net.Pipe()
|
||||
|
||||
return &connection{
|
||||
conn: p1,
|
||||
codec: p.codec(),
|
||||
}, &connection{
|
||||
conn: p2,
|
||||
codec: p.codec(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
)
|
||||
|
||||
func TestProtocol_Pipe(t *testing.T) {
|
||||
a := require.New(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
payload := []byte("abcdabcd")
|
||||
test := func(c1, c2 Conn) {
|
||||
go func() {
|
||||
b1 := &bin.Buffer{Buf: payload}
|
||||
a.NoError(c1.Send(ctx, b1))
|
||||
}()
|
||||
|
||||
b2 := &bin.Buffer{}
|
||||
a.NoError(c2.Recv(ctx, b2))
|
||||
a.Equal(payload, b2.Buf)
|
||||
}
|
||||
|
||||
c1, c2 := Intermediate.Pipe()
|
||||
defer func() {
|
||||
a.NoError(c1.Close())
|
||||
a.NoError(c2.Close())
|
||||
}()
|
||||
|
||||
test(c1, c2)
|
||||
test(c2, c1)
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy/obfuscated2"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/wsutil"
|
||||
)
|
||||
|
||||
type wsListener struct {
|
||||
addr net.Addr
|
||||
ch chan *wsServerConn
|
||||
closed *tdsync.Ready
|
||||
}
|
||||
|
||||
// WebsocketListener creates new MTProto Websocket listener.
|
||||
func WebsocketListener(addr net.Addr) (net.Listener, http.Handler) {
|
||||
l := wsListener{
|
||||
addr: addr,
|
||||
ch: make(chan *wsServerConn, 1),
|
||||
closed: tdsync.NewReady(),
|
||||
}
|
||||
return l, l
|
||||
}
|
||||
|
||||
func (l wsListener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
wsConn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
Subprotocols: []string{"binary"},
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(400)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = wsConn.Close(websocket.StatusNormalClosure, "Close")
|
||||
}()
|
||||
|
||||
conn := wsutil.NetConn(wsConn)
|
||||
rw, md, err := obfuscated2.Accept(conn, nil)
|
||||
if err != nil {
|
||||
w.WriteHeader(400)
|
||||
return
|
||||
}
|
||||
|
||||
var tag *bytes.Reader
|
||||
if md.Protocol[0] == codec.AbridgedClientStart[0] {
|
||||
// Abridged sends only byte for tag.
|
||||
tag = bytes.NewReader(md.Protocol[:1])
|
||||
} else {
|
||||
tag = bytes.NewReader(md.Protocol[:])
|
||||
}
|
||||
|
||||
accepted := &wsServerConn{
|
||||
closed: *tdsync.NewReady(),
|
||||
// Add codec tag in the begin of stream to emulate TCP fully.
|
||||
// MTProto sends codec tag in plain TCP connections, but not in obfuscated2 (Websocket/MTProxy).
|
||||
reader: io.MultiReader(tag, rw),
|
||||
writer: rw,
|
||||
Conn: conn,
|
||||
}
|
||||
|
||||
reqCtx := r.Context().Done()
|
||||
closed := l.closed.Ready()
|
||||
|
||||
// Pass connection to the Accept().
|
||||
select {
|
||||
case <-reqCtx:
|
||||
return
|
||||
case <-closed:
|
||||
return
|
||||
case l.ch <- accepted:
|
||||
}
|
||||
|
||||
// Await close or shutdown.
|
||||
select {
|
||||
case <-reqCtx:
|
||||
return
|
||||
case <-closed:
|
||||
return
|
||||
case <-accepted.closed.Ready():
|
||||
}
|
||||
}
|
||||
|
||||
func (l wsListener) Accept() (net.Conn, error) {
|
||||
r := l.closed.Ready()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r:
|
||||
return nil, net.ErrClosed
|
||||
case conn := <-l.ch:
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l wsListener) Close() error {
|
||||
l.closed.Signal()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l wsListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
type wsServerConn struct {
|
||||
closed tdsync.Ready
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *wsServerConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *wsServerConn) Write(p []byte) (int, error) {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
|
||||
func (c *wsServerConn) Close() error {
|
||||
c.closed.Signal()
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package transport_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
func TestWebsocketListener(t *testing.T) {
|
||||
a := require.New(t)
|
||||
ctx := context.Background()
|
||||
|
||||
var handler http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
listener, h := transport.WebsocketListener(srv.Listener.Addr())
|
||||
handler = h
|
||||
list := dcs.List{
|
||||
Domains: map[int]string{
|
||||
2: srv.URL,
|
||||
},
|
||||
}
|
||||
|
||||
server := transport.Listen(listener)
|
||||
defer server.Close()
|
||||
done := make(chan struct{})
|
||||
|
||||
grp, ctx := errgroup.WithContext(ctx)
|
||||
grp.Go(func() error {
|
||||
defer close(done)
|
||||
|
||||
conn, err := server.Accept()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "accept")
|
||||
}
|
||||
|
||||
var b bin.Buffer
|
||||
if err := conn.Recv(ctx, &b); err != nil {
|
||||
return errors.Wrap(err, "recv")
|
||||
}
|
||||
|
||||
if err := conn.Send(ctx, &b); err != nil {
|
||||
return errors.Wrap(err, "send")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
rs := dcs.Websocket(dcs.WebsocketOptions{})
|
||||
conn, err := rs.Primary(ctx, 2, list)
|
||||
a.NoError(err)
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(rand.Reader, 1024))
|
||||
a.NoError(err)
|
||||
a.NoError(conn.Send(ctx, &bin.Buffer{Buf: data}))
|
||||
|
||||
var b bin.Buffer
|
||||
a.NoError(conn.Recv(ctx, &b))
|
||||
a.Equal(data, b.Buf)
|
||||
|
||||
a.NoError(grp.Wait())
|
||||
}
|
||||
Reference in New Issue
Block a user