From 323fe1603e1d53e26a4a4d7e0ec9306776852bff Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 6 Jun 2024 16:59:53 -0600 Subject: [PATCH] store: save updates state in database Signed-off-by: Sumner Evans --- pkg/connector/client.go | 34 ++++-- pkg/connector/login.go | 2 +- pkg/store/container.go | 4 +- pkg/store/scoped_store.go | 179 +++++++++++++++++++++++++++++++ pkg/store/session_store.go | 39 ------- pkg/store/upgrades/00-latest.sql | 27 ++++- 6 files changed, 232 insertions(+), 53 deletions(-) create mode 100644 pkg/store/scoped_store.go delete mode 100644 pkg/store/session_store.go diff --git a/pkg/connector/client.go b/pkg/connector/client.go index c3317190..ceecbb02 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -11,7 +11,6 @@ import ( "github.com/gotd/td/telegram" "github.com/gotd/td/telegram/message" "github.com/gotd/td/telegram/updates" - "github.com/gotd/td/telegram/updates/hook" "github.com/gotd/td/tg" "github.com/rs/zerolog" "go.mau.fi/zerozap" @@ -36,12 +35,12 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge return nil, err } - logger := zerolog.Ctx(ctx).With(). + log := zerolog.Ctx(ctx).With(). Str("component", "telegram_client"). Int64("login_id", loginID). Logger() - zaplog := zap.New(zerozap.New(logger)) + zaplog := zap.New(zerozap.New(log)) client := TelegramClient{ main: tc, @@ -52,20 +51,32 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge dispatcher := tg.NewUpdateDispatcher() dispatcher.OnNewMessage(client.onUpdateNewMessage) + store := tc.store.GetScopedStore(loginID) + updatesManager := updates.New(updates.Config{ - Handler: dispatcher, - Logger: zaplog.Named("gaps"), + OnChannelTooLong: func(channelID int64) { + log.Error().Int64("channel_id", channelID).Msg("OnChannelTooLong") + panic("unimplemented channel too long") + }, + Handler: dispatcher, + Logger: zaplog.Named("gaps"), + Storage: store, + AccessHasher: store, }) client.client = telegram.NewClient(tc.Config.AppID, tc.Config.AppHash, telegram.Options{ - SessionStorage: tc.store.GetSessionStore(loginID), + SessionStorage: store, Logger: zaplog, UpdateHandler: updatesManager, - Middlewares: []telegram.Middleware{ - hook.UpdateHook(updatesManager.Handle), - }, }) client.clientCancel, err = connectTelegramClient(ctx, client.client) + go func() { + err = updatesManager.Run(ctx, client.client.API(), loginID, updates.AuthOptions{}) + if err != nil { + log.Err(err).Msg("updates manager error") + client.clientCancel() + } + }() return &client, err } @@ -138,7 +149,10 @@ func (t *TelegramClient) onUpdateNewMessage(ctx context.Context, e tg.Entities, Type: bridgev2.RemoteEventMessage, LogContext: func(c zerolog.Context) zerolog.Context { return c. - Int("message_id", update.Message.GetID()) + Int("message_id", update.Message.GetID()). + Str("sender", string(sender.Sender)). + Str("sender_login", string(sender.SenderLogin)). + Bool("is_from_me", sender.IsFromMe) }, ID: makeMessageID(msg.ID), Sender: sender, diff --git a/pkg/connector/login.go b/pkg/connector/login.go index a6d83bf6..694feaaa 100644 --- a/pkg/connector/login.go +++ b/pkg/connector/login.go @@ -175,7 +175,7 @@ func (p *PhoneLogin) SubmitUserInput(ctx context.Context, input map[string]strin func (p *PhoneLogin) handleAuthSuccess(ctx context.Context, authorization *tg.AuthAuthorization) (*bridgev2.LoginStep, error) { // Now that we have the Telegram user ID, store it in the database and // close the login client. - sessionStore := p.main.store.GetSessionStore(authorization.User.GetID()) + sessionStore := p.main.store.GetScopedStore(authorization.User.GetID()) var sessionData []byte sessionData, err := p.storage.Bytes(sessionData) if err != nil { diff --git a/pkg/store/container.go b/pkg/store/container.go index ab4e7657..98a99278 100644 --- a/pkg/store/container.go +++ b/pkg/store/container.go @@ -36,6 +36,6 @@ func (c *Container) Upgrade(ctx context.Context) error { return c.db.Upgrade(ctx) } -func (c *Container) GetSessionStore(telegramUserID int64) *SessionStore { - return &SessionStore{c.db, telegramUserID} +func (c *Container) GetScopedStore(telegramUserID int64) *scopedStore { + return &scopedStore{c.db, telegramUserID} } diff --git a/pkg/store/scoped_store.go b/pkg/store/scoped_store.go new file mode 100644 index 00000000..97928dc4 --- /dev/null +++ b/pkg/store/scoped_store.go @@ -0,0 +1,179 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/gotd/td/session" + "github.com/gotd/td/telegram/updates" + "go.mau.fi/util/dbutil" +) + +// scopedStore is a wrapper around a database that implements +// [session.Storage] scoped to a specific Telegram user ID. +type scopedStore struct { + db *dbutil.Database + telegramUserID int64 +} + +const ( + // Session Storage Queries + loadSessionQuery = `SELECT session_data FROM telegram_session WHERE user_id=$1` + storeSessionQuery = ` + INSERT INTO telegram_session (user_id, session_data) + VALUES ($1, $2) + ON CONFLICT (user_id) DO UPDATE SET session_data=excluded.session_data + ` + + // State Storage Queries + allChannelsQuery = "SELECT channel_id, pts FROM telegram_channel_state WHERE user_id=$1" + getChannelPtsQuery = "SELECT pts FROM telegram_channel_state WHERE user_id=$1 AND channel_id=$2" + setChannelPtsQuery = ` + INSERT INTO telegram_channel_state (user_id, channel_id, pts) + VALUES ($1, $2, $3) + ON CONFLICT (user_id, channel_id) DO UPDATE SET pts=excluded.pts + ` + getStateQuery = "SELECT pts, qts, date, seq from telegram_user_state WHERE user_id=$1" + setStateQuery = ` + INSERT INTO telegram_user_state (user_id, pts, qts, date, seq) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (user_id) DO UPDATE SET + pts=excluded.pts, + qts=excluded.qts, + date=excluded.date, + seq=excluded.seq + ` + setPtsQuery = "UPDATE telegram_user_state SET pts=$1 WHERE user_id=$2" + setQtsQuery = "UPDATE telegram_user_state SET qts=$1 WHERE user_id=$2" + setDateQuery = "UPDATE telegram_user_state SET date=$1 WHERE user_id=$2" + setSeqQuery = "UPDATE telegram_user_state SET seq=$1 WHERE user_id=$2" + setDateSeqQuery = "UPDATE telegram_user_state SET date=$1, seq=$2 WHERE user_id=$3" + + // Channel Access Hasher Queries + getChannelAccessHashQuery = "SELECT access_hash FROM telegram_channel_access_hashes WHERE user_id=$1 AND channel_id=$2" + setChannelAccessHashQuery = ` + INSERT INTO telegram_channel_access_hashes (user_id, channel_id, access_hash) + VALUES ($1, $2, $3) + ON CONFLICT (user_id, channel_id) DO UPDATE SET access_hash=excluded.access_hash + ` +) + +var _ session.Storage = (*scopedStore)(nil) + +func (s *scopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) { + row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID) + err = row.Scan(&sessionData) + return +} + +func (s *scopedStore) StoreSession(ctx context.Context, data []byte) error { + _, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data) + return err +} + +var _ updates.StateStorage = (*scopedStore)(nil) + +func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error { + s.assertUserIDMatches(userID) + rows, err := s.db.Query(ctx, allChannelsQuery, userID) + if err != nil { + return err + } + var channelID int64 + var pts int + for rows.Next() { + if err = rows.Scan(&channelID, &pts); err != nil { + return err + } else if err = f(ctx, channelID, pts); err != nil { + return err + } + } + return nil +} + +func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) { + s.assertUserIDMatches(userID) + err = s.db.QueryRow(ctx, getChannelPtsQuery, userID, channelID).Scan(&pts) + if errors.Is(err, sql.ErrNoRows) { + return 0, false, nil + } + return pts, err == nil, err +} + +func (s *scopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) { + s.assertUserIDMatches(userID) + _, err = s.db.Exec(ctx, setChannelPtsQuery, userID, channelID, pts) + return +} + +func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) { + s.assertUserIDMatches(userID) + err = s.db.QueryRow(ctx, getStateQuery, userID).Scan(&state.Pts, &state.Qts, &state.Date, &state.Seq) + if errors.Is(err, sql.ErrNoRows) { + return state, false, nil + } + return state, err == nil, err +} + +func (s *scopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) { + s.assertUserIDMatches(userID) + _, err = s.db.Exec(ctx, setStateQuery, userID, state.Pts, state.Qts, state.Date, state.Seq) + return +} + +func (s *scopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) { + s.assertUserIDMatches(userID) + _, err = s.db.Exec(ctx, setPtsQuery, userID, pts) + return +} + +func (s *scopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) { + s.assertUserIDMatches(userID) + _, err = s.db.Exec(ctx, setQtsQuery, userID, qts) + return +} + +func (s *scopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) { + s.assertUserIDMatches(userID) + _, err = s.db.Exec(ctx, setSeqQuery, userID, seq) + return +} + +func (s *scopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) { + s.assertUserIDMatches(userID) + _, err = s.db.Exec(ctx, setDateQuery, userID, date) + return +} + +func (s *scopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) { + s.assertUserIDMatches(userID) + _, err = s.db.Exec(ctx, setDateSeqQuery, userID, date, seq) + return +} + +var _ updates.ChannelAccessHasher = (*scopedStore)(nil) + +func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) { + s.assertUserIDMatches(userID) + err = s.db.QueryRow(ctx, getChannelAccessHashQuery, userID).Scan(&accessHash) + if errors.Is(err, sql.ErrNoRows) { + return 0, false, nil + } + return accessHash, err == nil, err +} + +func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) { + s.assertUserIDMatches(userID) + _, err = s.db.Exec(ctx, setChannelAccessHashQuery, userID, channelID, accessHash) + return +} + +// Helper Functions + +func (s *scopedStore) assertUserIDMatches(userID int64) { + if s.telegramUserID != userID { + panic(fmt.Sprintf("scoped store for %d function called with user ID %d", s.telegramUserID, userID)) + } +} diff --git a/pkg/store/session_store.go b/pkg/store/session_store.go deleted file mode 100644 index efa85aa6..00000000 --- a/pkg/store/session_store.go +++ /dev/null @@ -1,39 +0,0 @@ -package store - -import ( - "context" - - "github.com/gotd/td/session" - "go.mau.fi/util/dbutil" -) - -// SessionStore is a wrapper around a database that implements -// [session.Storage] scoped to a specific Telegram user ID. -type SessionStore struct { - db *dbutil.Database - telegramUserID int64 -} - -var _ session.Storage = (*SessionStore)(nil) - -const ( - loadSessionQuery = `SELECT session_data FROM telegram_session WHERE user_id=$1` - storeSessionQuery = ` - INSERT INTO telegram_session (user_id, session_data) - VALUES ($1, $2) - ON CONFLICT (user_id) DO UPDATE SET session_data=excluded.session_data - ` -) - -// LoadSession loads session data from the database. -func (s *SessionStore) LoadSession(ctx context.Context) (sessionData []byte, err error) { - row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID) - err = row.Scan(&sessionData) - return -} - -// StoreSession stores session data for a login into the database. -func (s *SessionStore) StoreSession(ctx context.Context, data []byte) error { - _, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data) - return err -} diff --git a/pkg/store/upgrades/00-latest.sql b/pkg/store/upgrades/00-latest.sql index 72ffbb99..327130d3 100644 --- a/pkg/store/upgrades/00-latest.sql +++ b/pkg/store/upgrades/00-latest.sql @@ -1,7 +1,32 @@ -- v0 -> v1: Latest revision --- TODO do I need to have bridge ID here? CREATE TABLE telegram_session ( user_id INTEGER PRIMARY KEY, session_data BYTEA NOT NULL ); + +CREATE TABLE telegram_user_state ( + user_id INTEGER PRIMARY KEY, + pts INTEGER NOT NULL, + qts INTEGER NOT NULL, + date INTEGER NOT NULL, + seq INTEGER NOT NULL +); + +CREATE TABLE telegram_channel_state ( + user_id INTEGER, + channel_id INTEGER, + pts INTEGER NOT NULL, + + PRIMARY KEY (user_id, channel_id) +); + +CREATE INDEX idx_telegram_channel_state_user_id ON telegram_channel_state (user_id); + +CREATE TABLE telegram_channel_access_hashes ( + user_id INTEGER, + channel_id INTEGER, + access_hash INTEGER NOT NULL, + + PRIMARY KEY (user_id, channel_id) +);