package manager

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/go-faster/errors"
	"go.uber.org/zap"

	"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
	"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
	"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
	"go.mau.fi/mautrix-telegram/pkg/gotd/pool"
	"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
	"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
	"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)

// protoConn is the subset of the underlying MTProto connection that Conn
// uses: sending RPC calls, running the connection, and pinging it.
type protoConn interface {
	Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error
	Run(ctx context.Context, f func(ctx context.Context) error) error
	Ping(ctx context.Context) error
}

//go:generate go tool golang.org/x/tools/cmd/stringer -type=ConnMode

// ConnMode represents connection mode.
//
// Any mode other than ConnModeUpdates causes requests sent through Conn to be
// wrapped in invokeWithoutUpdates.
type ConnMode byte

const (
	// ConnModeUpdates is update connection mode.
	ConnModeUpdates ConnMode = iota
	// ConnModeData is data connection mode.
	ConnModeData
	// ConnModeCDN is CDN connection mode.
	ConnModeCDN
)

// Conn is a Telegram client connection.
type Conn struct {
	// Connection parameters.
	mode ConnMode // immutable
	// MTProto connection.
	proto protoConn // immutable

	// InitConnection parameters.
	appID  int          // immutable
	device DeviceConfig // immutable

	// setup is a callback that is called after initConnection, but before
	// ready signaling. It is necessary to transfer auth from a previous
	// connection to another DC.
	setup SetupCallback // nilable

	// onDead is called when Run fails while its context is still active.
	onDead func() // nilable

	// Wrappers for external world, like logs or PRNG.
	// Should be immutable.
	clock clock.Clock // immutable
	log   *zap.Logger // immutable

	// Handler passed by client.
	handler Handler // immutable

	// State fields, guarded by mux.
	cfg     tg.Config // config received from help.getConfig during init
	ongoing int       // number of Invoke calls currently in flight
	latest  time.Time // time of the last Invoke start/finish or init completion
	mux     sync.Mutex

	sessionInit *tdsync.Ready // immutable; signaled by OnSession
	gotConfig   *tdsync.Ready // immutable; signaled by init once cfg is stored
	dead        *tdsync.Ready // immutable; signaled when Run returns or init fails

	// connBackoff builds the retry policy used by init when the server
	// answers initConnection with FLOOD_WAIT.
	connBackoff func(ctx context.Context) backoff.BackOff // immutable
}

// OnSession implements mtproto.Handler.
func (c *Conn) OnSession(session mtproto.Session) error { c.log.Info("SessionInit") c.sessionInit.Signal() // Waiting for config, because OnSession can occur before we set config. select { case <-c.gotConfig.Ready(): case <-c.dead.Ready(): return nil } c.mux.Lock() cfg := c.cfg c.mux.Unlock() return c.handler.OnSession(cfg, session) } func (c *Conn) trackInvoke() func(bin.Encoder, bin.Decoder, *error) { start := c.clock.Now() c.mux.Lock() defer c.mux.Unlock() c.ongoing++ c.latest = start return func(input bin.Encoder, output bin.Decoder, retErr *error) { c.mux.Lock() defer c.mux.Unlock() c.ongoing-- end := c.clock.Now() c.latest = end var reqField zap.Field respField := zap.Skip() if *retErr != nil { respField = zap.Error(*retErr) } else if box, isFile := output.(*tg.UploadFileBox); isFile { if uploadFile, ok := box.File.(*tg.UploadFile); ok { respField = zap.Dict("response_payload", zap.Int("size", len(uploadFile.Bytes)), zap.Int("mtime", uploadFile.Mtime), zap.Stringer("type", uploadFile.Type)) } } else if cdnFile, isCDNFile := output.(*tg.UploadCDNFileBox); isCDNFile { if uploadFile, ok := cdnFile.CdnFile.(*tg.UploadCDNFile); ok { respField = zap.Dict("response_payload", zap.Int("size", len(uploadFile.Bytes))) } } else { respField = zap.Any("response_payload", output) } if f, isFile := input.(*tg.UploadSaveFilePartRequest); isFile { reqField = zap.String("request_payload", fmt.Sprintf("%d bytes for part %d of %d", len(f.Bytes), f.FilePart, f.FileID)) } else { reqField = zap.Any("request_payload", input) } c.log.Debug("Request completed", zap.Duration("duration", end.Sub(start)), zap.Int("ongoing", c.ongoing), reqField, respField, ) } } // Run initialize connection. func (c *Conn) Run(ctx context.Context) (err error) { defer c.dead.Signal() defer func() { if err != nil && ctx.Err() == nil { c.log.Debug("Connection dead", zap.Error(err)) if c.onDead != nil { c.onDead() } } }() return c.proto.Run(ctx, func(ctx context.Context) error { // Signal death on init error. 
Otherwise connection shutdown // deadlocks in OnSession that occurs before init fails. err := c.init(ctx) if err != nil { c.dead.Signal() } return err }) } func (c *Conn) waitSession(ctx context.Context) error { select { case <-c.sessionInit.Ready(): return nil case <-c.dead.Ready(): return pool.ErrConnDead case <-ctx.Done(): return ctx.Err() } } // Ready returns channel to determine connection readiness. // Useful for pooling. func (c *Conn) Ready() <-chan struct{} { return c.sessionInit.Ready() } // Invoke implements Invoker. func (c *Conn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) (retErr error) { // Tracking ongoing invokes. defer c.trackInvoke()(input, output, &retErr) if err := c.waitSession(ctx); err != nil { return errors.Wrap(err, "waitSession") } q := c.wrapRequest(noopDecoder{input}) req := c.wrapRequest(&tg.InvokeWithLayerRequest{ Layer: tg.Layer, Query: q, }) return c.proto.Invoke(ctx, req, output) } // OnMessage implements mtproto.Handler. func (c *Conn) OnMessage(b *bin.Buffer) error { return c.handler.OnMessage(b) } type noopDecoder struct { bin.Encoder } func (n noopDecoder) Decode(b *bin.Buffer) error { return errors.New("not implemented") } func (c *Conn) wrapRequest(req bin.Object) bin.Object { if c.mode != ConnModeUpdates { return &tg.InvokeWithoutUpdatesRequest{ Query: req, } } return req } func (c *Conn) init(ctx context.Context) error { c.log.Debug("Initializing") q := c.wrapRequest(&tg.InitConnectionRequest{ APIID: c.appID, DeviceModel: c.device.DeviceModel, SystemVersion: c.device.SystemVersion, AppVersion: c.device.AppVersion, SystemLangCode: c.device.SystemLangCode, LangPack: c.device.LangPack, LangCode: c.device.LangCode, Proxy: c.device.Proxy, Params: c.device.Params, Query: c.wrapRequest(&tg.HelpGetConfigRequest{}), }) req := c.wrapRequest(&tg.InvokeWithLayerRequest{ Layer: tg.Layer, Query: q, }) var cfg tg.Config if err := backoff.RetryNotify(func() error { if err := c.proto.Invoke(ctx, req, &cfg); err != nil 
{ if tgerr.Is(err, tgerr.ErrFloodWait) { // Server sometimes returns FLOOD_WAIT(0) if you create // multiple connections in short period of time. // // See https://github.com/gotd/td/issues/388. return errors.Wrap(err, "flood wait") } // Not retrying other errors. return backoff.Permanent(errors.Wrap(err, "invoke")) } return nil }, c.connBackoff(ctx), func(err error, duration time.Duration) { c.log.Debug("Retrying connection initialization", zap.Error(err), zap.Duration("duration", duration), ) }); err != nil { return errors.Wrap(err, "initConnection") } if c.setup != nil { if err := c.setup(ctx, c); err != nil { return errors.Wrap(err, "setup") } } c.mux.Lock() c.latest = c.clock.Now() c.cfg = cfg c.mux.Unlock() c.gotConfig.Signal() return nil } // Ping calls ping for underlying protocol connection. func (c *Conn) Ping(ctx context.Context) error { return c.proto.Ping(ctx) }