diff --git a/pkg/gotd/exchange/client.go b/pkg/gotd/exchange/client.go index 59a1e6ad..a5463057 100644 --- a/pkg/gotd/exchange/client.go +++ b/pkg/gotd/exchange/client.go @@ -2,6 +2,7 @@ package exchange import ( "io" + "time" "go.uber.org/zap" @@ -23,4 +24,6 @@ type ClientExchangeResult struct { AuthKey crypto.AuthKey SessionID int64 ServerSalt int64 + + ServerTimeOffset time.Duration } diff --git a/pkg/gotd/exchange/client_flow.go b/pkg/gotd/exchange/client_flow.go index ba8a273a..7c5d3b4f 100644 --- a/pkg/gotd/exchange/client_flow.go +++ b/pkg/gotd/exchange/client_flow.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "math/big" + "time" "github.com/go-faster/errors" "go.uber.org/zap" @@ -118,6 +119,7 @@ Loop: EncryptedData: encryptedData, } c.log.Debug("Sending ReqDHParamsRequest") + reqStart := c.clock.Now() if err := c.writeUnencrypted(ctx, b, reqDHParams); err != nil { return ClientExchangeResult{}, errors.Wrap(err, "write ReqDHParamsRequest") } @@ -126,6 +128,7 @@ Loop: if err := c.conn.Recv(ctx, b); err != nil { return ClientExchangeResult{}, errors.Wrap(err, "read ServerDHParams message") } + roundtripDuration := c.clock.Now().Sub(reqStart) c.log.Debug("Received server ServerDHParams") var plaintextMsg proto.UnencryptedMessage @@ -262,6 +265,8 @@ Loop: AuthKey: crypto.AuthKey{Value: key, ID: authKeyID}, SessionID: sessionID, ServerSalt: serverSalt, + + ServerTimeOffset: time.Unix(int64(innerData.ServerTime), 0).Sub(reqStart.Add(roundtripDuration / 2)), }, nil case *mt.DhGenRetry: // dh_gen_retry#46dc1fb9 return ClientExchangeResult{}, errors.Errorf("retry required: %x", v.NewNonceHash2) diff --git a/pkg/gotd/mtproto/conn.go b/pkg/gotd/mtproto/conn.go index 70a6b46b..2229bfd9 100644 --- a/pkg/gotd/mtproto/conn.go +++ b/pkg/gotd/mtproto/conn.go @@ -31,6 +31,7 @@ type Handler interface { // MessageIDSource is message id generator. type MessageIDSource interface { New(t proto.MessageType) int64 + Reset() } // MessageBuf is message id buffer. @@ -73,6 +74,8 @@ type Conn struct { salt int64 sessionID int64 + serverTimeOffset time.Duration + // server salts fetched by getSalts. salts salts.Salts @@ -127,7 +130,6 @@ func New(dialer Dialer, opt Options) *Conn { rand: opt.Random, cipher: opt.Cipher, log: opt.Logger, - messageID: opt.MessageID, messageIDBuf: proto.NewMessageIDBuf(100), ackSendChan: make(chan int64), @@ -155,6 +157,7 @@ func New(dialer Dialer, opt Options) *Conn { saltFetchInterval: opt.SaltFetchInterval, getTimeout: opt.RequestTimeout, } + conn.messageID = proto.NewMessageIDGen(conn.TimeWithOffset) if conn.rpc == nil { conn.rpc = rpc.New(conn.writeContentMessage, rpc.Options{ Logger: opt.Logger.Named("rpc"), @@ -218,3 +221,40 @@ func (c *Conn) Run(ctx context.Context, f func(ctx context.Context) error) error } return nil } + +func (c *Conn) setServerTimeOffset(offset time.Duration) { + if offset == 0 { + offset = 1 + } + c.sessionMux.Lock() + c.serverTimeOffset = offset + c.sessionMux.Unlock() + if offset > 10*time.Second || offset < -10*time.Second { + c.log.Warn("Updated server time offset (high)", zap.Duration("offset", offset)) + } else { + c.log.Info("Updated server time offset", zap.Duration("offset", offset)) + } +} + +func (c *Conn) hasServerTimeOffset() bool { + c.sessionMux.RLock() + has := c.serverTimeOffset != 0 + c.sessionMux.RUnlock() + return has +} + +func (c *Conn) TimeWithOffset() (t time.Time) { + c.sessionMux.RLock() + t = c.clock.Now().Add(c.serverTimeOffset) + c.sessionMux.RUnlock() + return +} + +func (c *Conn) altTimeWithOffset() (t time.Time) { + c.sessionMux.RLock() + if c.serverTimeOffset != 0 { + t = c.clock.Now().Add(c.serverTimeOffset) + } + c.sessionMux.RUnlock() + return +} diff --git a/pkg/gotd/mtproto/conn_test.go b/pkg/gotd/mtproto/conn_test.go index c66440c1..4e5b106c 100644 --- a/pkg/gotd/mtproto/conn_test.go +++ b/pkg/gotd/mtproto/conn_test.go @@ -9,7 +9,6 @@ import ( "path/filepath" "strings" "testing" - "time" "github.com/stretchr/testify/assert" "go.uber.org/zap" @@ -44,10 +43,9 @@ func newTestClient(h testHandler, opts ...testClientOption) *Conn { }, rpc.Options{}) opt := Options{ - Logger: zap.NewNop(), - Random: rand.New(rand.NewSource(1)), - Key: crypto.Key{}.WithID(), - MessageID: proto.NewMessageIDGen(time.Now), + Logger: zap.NewNop(), + Random: rand.New(rand.NewSource(1)), + Key: crypto.Key{}.WithID(), engine: engine, } diff --git a/pkg/gotd/mtproto/connect.go b/pkg/gotd/mtproto/connect.go index 32c5286c..433b8f1c 100644 --- a/pkg/gotd/mtproto/connect.go +++ b/pkg/gotd/mtproto/connect.go @@ -84,10 +84,17 @@ func (c *Conn) createAuthKey(ctx context.Context) error { } c.sessionMux.Lock() + c.serverTimeOffset = r.ServerTimeOffset + if c.serverTimeOffset == 0 { + // Creating an auth key always calculates the offset and it should never be 0 in practice, + // but default to 1 just in case + c.serverTimeOffset = 1 + } c.authKey = r.AuthKey c.sessionID = r.SessionID c.salt = r.ServerSalt c.sessionMux.Unlock() + c.log.Info("Created auth key", zap.Duration("server_time_offset", r.ServerTimeOffset)) return nil } diff --git a/pkg/gotd/mtproto/handle_bad_msg.go b/pkg/gotd/mtproto/handle_bad_msg.go index 68409a60..9839de76 100644 --- a/pkg/gotd/mtproto/handle_bad_msg.go +++ b/pkg/gotd/mtproto/handle_bad_msg.go @@ -7,11 +7,13 @@ import ( "go.mau.fi/mautrix-telegram/pkg/gotd/bin" "go.mau.fi/mautrix-telegram/pkg/gotd/mt" + "go.mau.fi/mautrix-telegram/pkg/gotd/proto" ) type badMessageError struct { - Code int - NewSalt int64 + Code int + NewSalt int64 + TimeResynced bool } const ( @@ -40,7 +42,8 @@ func (c badMessageError) Error() string { return description } -func (c *Conn) handleBadMsg(b *bin.Buffer) error { +func (c *Conn) handleBadMsg(msgID int64, b *bin.Buffer) error { + now := c.clock.Now() id, err := b.PeekID() if err != nil { return err @@ -51,8 +54,16 @@ func (c *Conn) handleBadMsg(b *bin.Buffer) error { if err := bad.Decode(b); err != nil { return err } + var resynced bool + if !c.hasServerTimeOffset() && (bad.ErrorCode == codeMessageIDTooLow || bad.ErrorCode == codeMessageIDTooHigh) { + created := proto.MessageID(msgID).Time() + c.setServerTimeOffset(created.Sub(now)) + c.messageID.Reset() + c.updateSalt() + resynced = true + } - c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode}) + c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode, TimeResynced: resynced}) return nil case mt.BadServerSaltTypeID: var bad mt.BadServerSalt diff --git a/pkg/gotd/mtproto/handle_message.go b/pkg/gotd/mtproto/handle_message.go index 625bf3e1..310ce3a5 100644 --- a/pkg/gotd/mtproto/handle_message.go +++ b/pkg/gotd/mtproto/handle_message.go @@ -19,7 +19,7 @@ func (c *Conn) handleMessage(msgID int64, b *bin.Buffer) error { case mt.NewSessionCreatedTypeID: return c.handleSessionCreated(b) case mt.BadMsgNotificationTypeID, mt.BadServerSaltTypeID: - return c.handleBadMsg(b) + return c.handleBadMsg(msgID, b) case mt.FutureSaltsTypeID: return c.handleFutureSalts(b) case proto.MessageContainerTypeID: diff --git a/pkg/gotd/mtproto/handle_session_created.go b/pkg/gotd/mtproto/handle_session_created.go index 277f3b5c..629f093b 100644 --- a/pkg/gotd/mtproto/handle_session_created.go +++ b/pkg/gotd/mtproto/handle_session_created.go @@ -24,14 +24,7 @@ func (c *Conn) handleSessionCreated(b *bin.Buffer) error { zap.Time("first_msg_time", created.Local()), ) - if (created.Before(now) && now.Sub(created) > maxPast) || created.Sub(now) > maxFuture { - c.log.Warn("Local clock needs synchronization", - zap.Time("first_msg_time", created), - zap.Time("local", now), - zap.Duration("time_difference", now.Sub(created)), - ) - } - + c.setServerTimeOffset(created.Sub(now)) c.storeSalt(s.ServerSalt) if err := c.handler.OnSession(c.session()); err != nil { return errors.Wrap(err, "handler.OnSession") diff --git a/pkg/gotd/mtproto/options.go b/pkg/gotd/mtproto/options.go index 824ac471..56ac0098 100644 --- a/pkg/gotd/mtproto/options.go +++ b/pkg/gotd/mtproto/options.go @@ -11,7 +11,6 @@ import ( "go.mau.fi/mautrix-telegram/pkg/gotd/clock" "go.mau.fi/mautrix-telegram/pkg/gotd/crypto" "go.mau.fi/mautrix-telegram/pkg/gotd/exchange" - "go.mau.fi/mautrix-telegram/pkg/gotd/proto" "go.mau.fi/mautrix-telegram/pkg/gotd/rpc" "go.mau.fi/mautrix-telegram/pkg/gotd/tmap" ) @@ -64,9 +63,6 @@ type Options struct { // If < 0, compression will be disabled. // If == 0, default value will be used. CompressThreshold int - // MessageID is message id source. Share source between connection to - // reduce collision probability. - MessageID MessageIDSource // Clock is current time source. Defaults to system time. Clock clock.Clock // Types map, used in verbose logging of incoming message. @@ -152,9 +148,6 @@ func (opt *Options) setDefaults() { if opt.Clock == nil { opt.Clock = clock.System } - if opt.MessageID == nil { - opt.MessageID = proto.NewMessageIDGen(opt.Clock.Now) - } if len(opt.PublicKeys) == 0 { opt.setDefaultPublicKeys() } diff --git a/pkg/gotd/mtproto/read.go b/pkg/gotd/mtproto/read.go index 63b49079..7241aebd 100644 --- a/pkg/gotd/mtproto/read.go +++ b/pkg/gotd/mtproto/read.go @@ -38,12 +38,14 @@ func checkMessageID(now time.Time, rawID int64) error { return errors.Wrapf(errRejected, "unexpected type %s", id.Type()) } - created := id.Time() - if created.Before(now) && now.Sub(created) > maxPast { - return errors.Wrap(errRejected, "created too far in past") - } - if created.Sub(now) > maxFuture { - return errors.Wrap(errRejected, "created too far in future") + if !now.IsZero() { + created := id.Time() + if created.Before(now) && now.Sub(created) > maxPast { + return errors.Wrap(errRejected, "created too far in past") + } + if created.Sub(now) > maxFuture { + return errors.Wrap(errRejected, "created too far in future") + } } return nil @@ -60,7 +62,8 @@ func (c *Conn) decryptMessage(b *bin.Buffer) (*crypto.EncryptedMessageData, erro if msg.SessionID != session.ID { return nil, errors.Wrapf(errRejected, "invalid session (got %d, expected %d)", msg.SessionID, session.ID) } - if err := checkMessageID(c.clock.Now(), msg.MessageID); err != nil { + + if err := checkMessageID(c.altTimeWithOffset(), msg.MessageID); err != nil { return nil, errors.Wrapf(err, "bad message id %d", msg.MessageID) } if !c.messageIDBuf.Consume(msg.MessageID) { diff --git a/pkg/gotd/mtproto/rpc.go b/pkg/gotd/mtproto/rpc.go index c6d089f7..56fa5419 100644 --- a/pkg/gotd/mtproto/rpc.go +++ b/pkg/gotd/mtproto/rpc.go @@ -8,6 +8,7 @@ import ( "go.mau.fi/mautrix-telegram/pkg/gotd/bin" "go.mau.fi/mautrix-telegram/pkg/gotd/mt" + "go.mau.fi/mautrix-telegram/pkg/gotd/proto" "go.mau.fi/mautrix-telegram/pkg/gotd/rpc" ) @@ -23,20 +24,29 @@ func (c *Conn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder Output: output, } - if err := c.rpc.Do(ctx, req); err != nil { + for retries := 0; ; retries++ { var badMsgErr *badMessageError - if errors.As(err, &badMsgErr) && badMsgErr.Code == codeIncorrectServerSalt { + err := c.rpc.Do(ctx, req) + if err == nil || retries >= 2 || !errors.As(err, &badMsgErr) { + return err + } else if badMsgErr.Code == codeIncorrectServerSalt { // Store salt from server. c.storeSalt(badMsgErr.NewSalt) // Reset saved salts to fetch new. c.salts.Reset() c.log.Info("Retrying request after updating salt from badMsgErr", zap.Int64("msg_id", req.MsgID)) - return c.rpc.Do(ctx, req) + } else if badMsgErr.TimeResynced { + req.MsgID, req.SeqNo = c.nextMsgSeq(true) + c.log.Info("Retrying request after adjusting time offset from badMsgErr", + zap.Int64("old_msg_id", msgID), + zap.Int64("new_msg_id", req.MsgID), + zap.Stringer("old_msg_id_str", proto.MessageID(msgID)), + zap.Stringer("new_msg_id_str", proto.MessageID(req.MsgID)), + ) + } else { + return err } - return err } - - return nil } func (c *Conn) dropRPC(req rpc.Request) error { diff --git a/pkg/gotd/mtproto/salt.go b/pkg/gotd/mtproto/salt.go index 41a828bf..68637366 100644 --- a/pkg/gotd/mtproto/salt.go +++ b/pkg/gotd/mtproto/salt.go @@ -23,7 +23,7 @@ func (c *Conn) storeSalt(salt int64) { } func (c *Conn) updateSalt() { - salt, ok := c.salts.Get(c.clock.Now().Add(time.Minute * 5)) + salt, ok := c.salts.Get(c.TimeWithOffset().Add(time.Minute * 5)) if !ok { return } diff --git a/pkg/gotd/proto/message_id.go b/pkg/gotd/proto/message_id.go index b101e7df..16d4fd88 100644 --- a/pkg/gotd/proto/message_id.go +++ b/pkg/gotd/proto/message_id.go @@ -152,6 +152,12 @@ func (g *MessageIDGen) New(t MessageType) int64 { return int64(NewMessageIDNano(g.nano, t)) } +func (g *MessageIDGen) Reset() { + g.mux.Lock() + g.nano = 0 + g.mux.Unlock() +} + // NewMessageIDGen creates new message id generator. // // Current time will be provided by now() function. diff --git a/pkg/gotd/telegram/client.go b/pkg/gotd/telegram/client.go index bbd2ce3c..2b2f7cb2 100644 --- a/pkg/gotd/telegram/client.go +++ b/pkg/gotd/telegram/client.go @@ -206,7 +206,6 @@ func NewClient(appID int, appHash string, opt Options) *Client { RetryInterval: opt.RetryInterval, MaxRetries: opt.MaxRetries, CompressThreshold: opt.CompressThreshold, - MessageID: opt.MessageID, ExchangeTimeout: opt.ExchangeTimeout, DialTimeout: opt.DialTimeout, Clock: opt.Clock, diff --git a/pkg/gotd/telegram/options.go b/pkg/gotd/telegram/options.go index bca05ece..e8180767 100644 --- a/pkg/gotd/telegram/options.go +++ b/pkg/gotd/telegram/options.go @@ -12,8 +12,6 @@ import ( "go.mau.fi/mautrix-telegram/pkg/gotd/clock" "go.mau.fi/mautrix-telegram/pkg/gotd/crypto" "go.mau.fi/mautrix-telegram/pkg/gotd/exchange" - "go.mau.fi/mautrix-telegram/pkg/gotd/mtproto" - "go.mau.fi/mautrix-telegram/pkg/gotd/proto" "go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs" "go.mau.fi/mautrix-telegram/pkg/gotd/tg" ) @@ -99,8 +97,7 @@ type Options struct { // Will be sent with session creation request. Device DeviceConfig - MessageID mtproto.MessageIDSource - Clock clock.Clock + Clock clock.Clock PingInterval time.Duration PingTimeout time.Duration @@ -153,9 +150,6 @@ func (opt *Options) setDefaults() { if opt.MigrationTimeout == 0 { opt.MigrationTimeout = time.Second * 15 } - if opt.MessageID == nil { - opt.MessageID = proto.NewMessageIDGen(opt.Clock.Now) - } if opt.UpdateHandler == nil { // No updates handler passed, so no sense to subscribe for updates. // User should explicitly ignore updates using custom UpdateHandler.