move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
package pool
|
||||
|
||||
import "go.uber.org/zap"
|
||||
|
||||
// DCOptions is a Telegram data center connections pool options.
|
||||
type DCOptions struct {
|
||||
// Logger is instance of zap.Logger. No logs by default.
|
||||
Logger *zap.Logger
|
||||
// MTProto options for connections.
|
||||
// Opened connection limit to the DC.
|
||||
MaxOpenConnections int64
|
||||
}
|
||||
|
||||
func (d *DCOptions) setDefaults() {
|
||||
if d.Logger == nil {
|
||||
d.Logger = zap.NewNop()
|
||||
}
|
||||
// It's okay to use zero value for MaxOpenConnections.
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
// Package pool contains Telegram connections pool.
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
)
|
||||
|
||||
// DC represents connection pool to one data center.
|
||||
type DC struct {
|
||||
id int
|
||||
|
||||
// Connection constructor.
|
||||
newConn func() Conn
|
||||
|
||||
// Wrappers for external world, like logs or PRNG.
|
||||
log *zap.Logger // immutable
|
||||
|
||||
// DC context. Will be canceled by Run on exit.
|
||||
ctx context.Context // immutable
|
||||
cancel context.CancelFunc // immutable
|
||||
|
||||
// Connections supervisor.
|
||||
grp *tdsync.Supervisor
|
||||
// Free connections.
|
||||
free []*poolConn
|
||||
// Total connections.
|
||||
total int64
|
||||
// Connection id monotonic counter.
|
||||
nextConn atomic.Int64
|
||||
freeReq *reqMap
|
||||
// DC mutex.
|
||||
mu sync.Mutex
|
||||
|
||||
// Limit of connections.
|
||||
max int64 // immutable
|
||||
|
||||
// Signal connection for cases when all connections are dead, but all requests waiting for
|
||||
// free connection in 3rd acquire case.
|
||||
stuck *tdsync.ResetReady
|
||||
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// NewDC creates new uninitialized DC.
|
||||
func NewDC(ctx context.Context, id int, newConn func() Conn, opts DCOptions) *DC {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
opts.setDefaults()
|
||||
return &DC{
|
||||
id: id,
|
||||
newConn: newConn,
|
||||
log: opts.Logger,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
grp: tdsync.NewSupervisor(ctx),
|
||||
freeReq: newReqMap(),
|
||||
max: opts.MaxOpenConnections,
|
||||
stuck: tdsync.NewResetReady(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DC) createConnection(id int64) *poolConn {
|
||||
conn := &poolConn{
|
||||
Conn: c.newConn(),
|
||||
id: id,
|
||||
dc: c,
|
||||
deleted: atomic.NewBool(false),
|
||||
dead: tdsync.NewReady(),
|
||||
}
|
||||
|
||||
c.grp.Go(func(groupCtx context.Context) (err error) {
|
||||
defer c.dead(conn, err)
|
||||
return conn.Run(groupCtx)
|
||||
})
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func (c *DC) dead(r *poolConn, deadErr error) {
|
||||
if r.deleted.Swap(true) {
|
||||
return // Already deleted.
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.total--
|
||||
remaining := c.total
|
||||
if remaining < 0 {
|
||||
panic("unreachable: remaining can't be less than zero")
|
||||
}
|
||||
|
||||
idx := -1
|
||||
for i, conn := range c.free {
|
||||
// Search connection by pointer.
|
||||
if conn.id == r.id {
|
||||
idx = i
|
||||
}
|
||||
}
|
||||
|
||||
if idx >= 0 {
|
||||
// Delete by index from slice tricks.
|
||||
copy(c.free[idx:], c.free[idx+1:])
|
||||
// Delete reference to prevent resource leaking.
|
||||
c.free[len(c.free)-1] = nil
|
||||
c.free = c.free[:len(c.free)-1]
|
||||
}
|
||||
|
||||
r.dead.Signal()
|
||||
c.stuck.Reset()
|
||||
|
||||
c.log.Debug("Connection died",
|
||||
zap.Int64("remaining", remaining),
|
||||
zap.Int64("conn_id", r.id),
|
||||
zap.Error(deadErr),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *DC) pop() (r *poolConn, ok bool) {
|
||||
l := len(c.free)
|
||||
if l > 0 {
|
||||
r, c.free = c.free[l-1], c.free[:l-1]
|
||||
|
||||
return r, true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *DC) release(r *poolConn) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.freeReq.transfer(r) {
|
||||
c.log.Debug("Transfer connection to requester", zap.Int64("conn_id", r.id))
|
||||
return
|
||||
}
|
||||
c.log.Debug("Connection released", zap.Int64("conn_id", r.id))
|
||||
c.free = append(c.free, r)
|
||||
}
|
||||
|
||||
var errDCIsClosed = errors.New("DC is closed")
|
||||
|
||||
func (c *DC) acquire(ctx context.Context) (r *poolConn, err error) { // nolint:gocyclo
|
||||
retry:
|
||||
c.mu.Lock()
|
||||
// 1st case: have free connections.
|
||||
if r, ok := c.pop(); ok {
|
||||
c.mu.Unlock()
|
||||
select {
|
||||
case <-r.Dead():
|
||||
c.dead(r, nil)
|
||||
goto retry
|
||||
default:
|
||||
}
|
||||
c.log.Debug("Re-using free connection", zap.Int64("conn_id", r.id))
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// 2nd case: no free connections, but can create one.
|
||||
// c.max < 1 means unlimited
|
||||
if c.max < 1 || c.total < c.max {
|
||||
c.total++
|
||||
c.mu.Unlock()
|
||||
|
||||
id := c.nextConn.Inc()
|
||||
c.log.Debug("Creating new connection",
|
||||
zap.Int64("conn_id", id),
|
||||
)
|
||||
conn := c.createConnection(id)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-c.ctx.Done():
|
||||
return nil, errors.Wrap(c.ctx.Err(), "DC closed")
|
||||
case <-conn.Ready():
|
||||
return conn, nil
|
||||
case <-conn.Dead():
|
||||
c.dead(conn, nil)
|
||||
goto retry
|
||||
}
|
||||
}
|
||||
|
||||
// 3rd case: no free connections, can't create yet one, wait for free.
|
||||
key, ch := c.freeReq.request()
|
||||
c.mu.Unlock()
|
||||
c.log.Debug("Waiting for free connect", zap.Int64("request_id", int64(key)))
|
||||
|
||||
select {
|
||||
case conn := <-ch:
|
||||
c.log.Debug("Got connection for request",
|
||||
zap.Int64("conn_id", conn.id),
|
||||
zap.Int64("request_id", int64(key)),
|
||||
)
|
||||
return conn, nil
|
||||
case <-c.stuck.Ready():
|
||||
c.log.Debug("Some connection dead, try to create new connection, cancel waiting")
|
||||
|
||||
c.freeReq.delete(key)
|
||||
select {
|
||||
default:
|
||||
case conn, ok := <-ch:
|
||||
if ok && conn != nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
goto retry
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
case <-c.ctx.Done():
|
||||
err = errors.Wrap(c.ctx.Err(), "DC closed")
|
||||
}
|
||||
|
||||
// Executed only if at least one of context is Done.
|
||||
c.freeReq.delete(key)
|
||||
select {
|
||||
default:
|
||||
case conn, ok := <-ch:
|
||||
if ok && conn != nil {
|
||||
c.release(conn)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Invoke sends MTProto request using one of pool connection.
|
||||
func (c *DC) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
if c.closed.Load() {
|
||||
return errDCIsClosed
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := c.acquire(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrConnDead) {
|
||||
continue
|
||||
}
|
||||
return errors.Wrap(err, "acquire connection")
|
||||
}
|
||||
|
||||
c.log.Debug("DC Invoke")
|
||||
err = conn.Invoke(ctx, input, output)
|
||||
c.release(conn)
|
||||
if err != nil {
|
||||
c.log.Debug("DC Invoke failed", zap.Error(err))
|
||||
return errors.Wrap(err, "invoke pool")
|
||||
}
|
||||
|
||||
c.log.Debug("DC Invoke complete")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Close waits while all ongoing requests will be done or until given context is done.
|
||||
// Then, closes the DC.
|
||||
func (c *DC) Close() error {
|
||||
if c.closed.Swap(true) {
|
||||
return errors.New("DC already closed")
|
||||
}
|
||||
c.log.Debug("Closing DC")
|
||||
defer c.log.Debug("DC closed")
|
||||
|
||||
c.cancel()
|
||||
return c.grp.Wait()
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
)
|
||||
|
||||
// ErrConnDead means that connection is closed and can't be used anymore.
|
||||
var ErrConnDead = errors.New("connection dead")
|
||||
|
||||
// Conn represents Telegram MTProto connection.
|
||||
type Conn interface {
|
||||
Run(ctx context.Context) error
|
||||
Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error
|
||||
Ping(ctx context.Context) error
|
||||
Ready() <-chan struct{}
|
||||
}
|
||||
|
||||
type poolConn struct {
|
||||
Conn
|
||||
id int64 // immutable
|
||||
dc *DC // immutable
|
||||
deleted *atomic.Bool
|
||||
dead *tdsync.Ready
|
||||
}
|
||||
|
||||
func (p *poolConn) Dead() <-chan struct{} {
|
||||
return p.dead.Ready()
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
ready *tdsync.Ready
|
||||
readyOnRun bool
|
||||
}
|
||||
|
||||
func (mockConn) Ping(ctx context.Context) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func newMockConn(readyOnRun bool) mockConn {
|
||||
return mockConn{
|
||||
ready: tdsync.NewReady(),
|
||||
readyOnRun: readyOnRun,
|
||||
}
|
||||
}
|
||||
|
||||
func (m mockConn) Run(ctx context.Context) error {
|
||||
if m.readyOnRun {
|
||||
m.ready.Signal()
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
func (m mockConn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockConn) Ready() <-chan struct{} {
|
||||
return m.ready.Ready()
|
||||
}
|
||||
|
||||
func TestDC_acquire(t *testing.T) {
|
||||
t.Run("AcquireRelease", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
ctx := context.Background()
|
||||
|
||||
created := 0
|
||||
p := NewDC(ctx, 2, func() Conn {
|
||||
created++
|
||||
return newMockConn(true)
|
||||
}, DCOptions{
|
||||
MaxOpenConnections: 1,
|
||||
})
|
||||
defer func() {
|
||||
a.NoError(p.Close())
|
||||
}()
|
||||
|
||||
c, err := p.acquire(ctx)
|
||||
a.NoError(err)
|
||||
a.NotNil(c)
|
||||
a.Equal(1, created, "Pool must create new connection")
|
||||
|
||||
p.release(c)
|
||||
|
||||
_, err = p.acquire(ctx)
|
||||
a.NoError(err)
|
||||
a.Equal(1, created, "Pool must re-use connection")
|
||||
|
||||
p.release(c)
|
||||
})
|
||||
t.Run("CancelWhileWait", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
ctx := context.Background()
|
||||
|
||||
created := 0
|
||||
p := NewDC(ctx, 2, func() Conn {
|
||||
created++
|
||||
return newMockConn(true)
|
||||
}, DCOptions{
|
||||
MaxOpenConnections: 1,
|
||||
})
|
||||
defer func() {
|
||||
a.NoError(p.Close())
|
||||
}()
|
||||
|
||||
c, err := p.acquire(ctx)
|
||||
a.NoError(err)
|
||||
a.NotNil(c)
|
||||
a.Equal(1, created, "Pool must create new connection")
|
||||
|
||||
canceledCtx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
c2, err := p.acquire(canceledCtx)
|
||||
a.ErrorIs(err, context.Canceled)
|
||||
a.Nil(c2)
|
||||
a.Empty(p.freeReq.m)
|
||||
})
|
||||
t.Run("Dead", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
ctx := context.Background()
|
||||
|
||||
created := 0
|
||||
p := NewDC(ctx, 2, func() Conn {
|
||||
created++
|
||||
return newMockConn(true)
|
||||
}, DCOptions{
|
||||
MaxOpenConnections: 1,
|
||||
})
|
||||
defer func() {
|
||||
a.NoError(p.Close())
|
||||
}()
|
||||
|
||||
c, err := p.acquire(ctx)
|
||||
a.NoError(err)
|
||||
a.NotNil(c)
|
||||
a.Equal(1, created, "Pool must create new connection")
|
||||
|
||||
p.release(c)
|
||||
c.dead.Signal()
|
||||
|
||||
_, err = p.acquire(ctx)
|
||||
a.NoError(err)
|
||||
a.Equal(2, created, "Pool must not re-use dead connection")
|
||||
})
|
||||
t.Run("CancelWhileCreate", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
ctx := context.Background()
|
||||
|
||||
created := 0
|
||||
p := NewDC(ctx, 2, func() Conn {
|
||||
created++
|
||||
return newMockConn(false)
|
||||
}, DCOptions{
|
||||
MaxOpenConnections: 1,
|
||||
})
|
||||
defer func() {
|
||||
a.NoError(p.Close())
|
||||
}()
|
||||
|
||||
canceledCtx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
c, err := p.acquire(canceledCtx)
|
||||
a.ErrorIs(err, context.Canceled)
|
||||
a.Nil(c)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type reqKey int64
|
||||
|
||||
type reqMap struct {
|
||||
m map[reqKey]chan *poolConn
|
||||
mux sync.Mutex
|
||||
_ [4]byte
|
||||
|
||||
nextRequest atomic.Int64
|
||||
}
|
||||
|
||||
func newReqMap() *reqMap {
|
||||
return &reqMap{
|
||||
m: map[reqKey]chan *poolConn{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *reqMap) request() (key reqKey, ch chan *poolConn) {
|
||||
key = reqKey(r.nextRequest.Inc())
|
||||
ch = make(chan *poolConn, 1)
|
||||
|
||||
r.mux.Lock()
|
||||
r.m[key] = ch
|
||||
r.mux.Unlock()
|
||||
return key, ch
|
||||
}
|
||||
|
||||
func (r *reqMap) transfer(c *poolConn) bool {
|
||||
r.mux.Lock()
|
||||
if len(r.m) < 1 { // no requests
|
||||
r.mux.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
var ch chan *poolConn
|
||||
var k reqKey
|
||||
for k, ch = range r.m { // Get one from map.
|
||||
break
|
||||
}
|
||||
delete(r.m, k) // Remove from pending requests.
|
||||
r.mux.Unlock()
|
||||
|
||||
if ch == nil {
|
||||
panic("unreachable: channel can't be nil due to map not empty")
|
||||
}
|
||||
|
||||
ch <- c
|
||||
close(ch)
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *reqMap) delete(key reqKey) {
|
||||
r.mux.Lock()
|
||||
delete(r.m, key)
|
||||
r.mux.Unlock()
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
)
|
||||
|
||||
func TestReqMap(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
req := newReqMap()
|
||||
|
||||
key, ch := req.request()
|
||||
g.Go(func(ctx context.Context) error {
|
||||
defer req.delete(key)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
g.Go(func(ctx context.Context) error {
|
||||
require.True(t, req.transfer(&poolConn{}))
|
||||
require.False(t, req.transfer(&poolConn{}))
|
||||
return nil
|
||||
})
|
||||
|
||||
require.NoError(t, g.Wait())
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
|
||||
)
|
||||
|
||||
// Session represents DC session.
|
||||
type Session struct {
|
||||
DC int
|
||||
AuthKey crypto.AuthKey
|
||||
Salt int64
|
||||
}
|
||||
|
||||
// SyncSession is synchronization helper for Session.
|
||||
type SyncSession struct {
|
||||
data Session
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSyncSession creates new SyncSession.
|
||||
func NewSyncSession(data Session) *SyncSession {
|
||||
return &SyncSession{
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// Store saves given Session.
|
||||
func (s *SyncSession) Store(data Session) {
|
||||
s.mux.Lock()
|
||||
s.data = data
|
||||
s.mux.Unlock()
|
||||
}
|
||||
|
||||
// Migrate changes current DC and its addr, zeroes AuthKey and Salt.
|
||||
func (s *SyncSession) Migrate(dc int) {
|
||||
s.mux.Lock()
|
||||
s.data.DC = dc
|
||||
s.data.AuthKey = crypto.AuthKey{}
|
||||
s.data.Salt = 0
|
||||
s.mux.Unlock()
|
||||
}
|
||||
|
||||
// Options fills Key and Salt field of given Options using stored session and returns it.
|
||||
func (s *SyncSession) Options(opts mtproto.Options) (mtproto.Options, Session) {
|
||||
s.mux.RLock()
|
||||
data := s.data
|
||||
s.mux.RUnlock()
|
||||
|
||||
opts.Key = data.AuthKey
|
||||
opts.Salt = data.Salt
|
||||
return opts, data
|
||||
}
|
||||
|
||||
// Load gets session and returns it.
|
||||
func (s *SyncSession) Load() (data Session) {
|
||||
s.mux.RLock()
|
||||
data = s.data
|
||||
s.mux.RUnlock()
|
||||
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user