store: save updates state in database
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
+24
-10
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/telegram/message"
|
||||
"github.com/gotd/td/telegram/updates"
|
||||
"github.com/gotd/td/telegram/updates/hook"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/zerozap"
|
||||
@@ -36,12 +35,12 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger := zerolog.Ctx(ctx).With().
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Str("component", "telegram_client").
|
||||
Int64("login_id", loginID).
|
||||
Logger()
|
||||
|
||||
zaplog := zap.New(zerozap.New(logger))
|
||||
zaplog := zap.New(zerozap.New(log))
|
||||
|
||||
client := TelegramClient{
|
||||
main: tc,
|
||||
@@ -52,20 +51,32 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
|
||||
dispatcher := tg.NewUpdateDispatcher()
|
||||
dispatcher.OnNewMessage(client.onUpdateNewMessage)
|
||||
|
||||
store := tc.store.GetScopedStore(loginID)
|
||||
|
||||
updatesManager := updates.New(updates.Config{
|
||||
Handler: dispatcher,
|
||||
Logger: zaplog.Named("gaps"),
|
||||
OnChannelTooLong: func(channelID int64) {
|
||||
log.Error().Int64("channel_id", channelID).Msg("OnChannelTooLong")
|
||||
panic("unimplemented channel too long")
|
||||
},
|
||||
Handler: dispatcher,
|
||||
Logger: zaplog.Named("gaps"),
|
||||
Storage: store,
|
||||
AccessHasher: store,
|
||||
})
|
||||
|
||||
client.client = telegram.NewClient(tc.Config.AppID, tc.Config.AppHash, telegram.Options{
|
||||
SessionStorage: tc.store.GetSessionStore(loginID),
|
||||
SessionStorage: store,
|
||||
Logger: zaplog,
|
||||
UpdateHandler: updatesManager,
|
||||
Middlewares: []telegram.Middleware{
|
||||
hook.UpdateHook(updatesManager.Handle),
|
||||
},
|
||||
})
|
||||
client.clientCancel, err = connectTelegramClient(ctx, client.client)
|
||||
go func() {
|
||||
err = updatesManager.Run(ctx, client.client.API(), loginID, updates.AuthOptions{})
|
||||
if err != nil {
|
||||
log.Err(err).Msg("updates manager error")
|
||||
client.clientCancel()
|
||||
}
|
||||
}()
|
||||
return &client, err
|
||||
}
|
||||
|
||||
@@ -138,7 +149,10 @@ func (t *TelegramClient) onUpdateNewMessage(ctx context.Context, e tg.Entities,
|
||||
Type: bridgev2.RemoteEventMessage,
|
||||
LogContext: func(c zerolog.Context) zerolog.Context {
|
||||
return c.
|
||||
Int("message_id", update.Message.GetID())
|
||||
Int("message_id", update.Message.GetID()).
|
||||
Str("sender", string(sender.Sender)).
|
||||
Str("sender_login", string(sender.SenderLogin)).
|
||||
Bool("is_from_me", sender.IsFromMe)
|
||||
},
|
||||
ID: makeMessageID(msg.ID),
|
||||
Sender: sender,
|
||||
|
||||
@@ -175,7 +175,7 @@ func (p *PhoneLogin) SubmitUserInput(ctx context.Context, input map[string]strin
|
||||
func (p *PhoneLogin) handleAuthSuccess(ctx context.Context, authorization *tg.AuthAuthorization) (*bridgev2.LoginStep, error) {
|
||||
// Now that we have the Telegram user ID, store it in the database and
|
||||
// close the login client.
|
||||
sessionStore := p.main.store.GetSessionStore(authorization.User.GetID())
|
||||
sessionStore := p.main.store.GetScopedStore(authorization.User.GetID())
|
||||
var sessionData []byte
|
||||
sessionData, err := p.storage.Bytes(sessionData)
|
||||
if err != nil {
|
||||
|
||||
@@ -36,6 +36,6 @@ func (c *Container) Upgrade(ctx context.Context) error {
|
||||
return c.db.Upgrade(ctx)
|
||||
}
|
||||
|
||||
func (c *Container) GetSessionStore(telegramUserID int64) *SessionStore {
|
||||
return &SessionStore{c.db, telegramUserID}
|
||||
func (c *Container) GetScopedStore(telegramUserID int64) *scopedStore {
|
||||
return &scopedStore{c.db, telegramUserID}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/gotd/td/session"
|
||||
"github.com/gotd/td/telegram/updates"
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
// scopedStore is a wrapper around a database that implements
|
||||
// [session.Storage] scoped to a specific Telegram user ID.
|
||||
type scopedStore struct {
|
||||
db *dbutil.Database
|
||||
telegramUserID int64
|
||||
}
|
||||
|
||||
const (
|
||||
// Session Storage Queries
|
||||
loadSessionQuery = `SELECT session_data FROM telegram_session WHERE user_id=$1`
|
||||
storeSessionQuery = `
|
||||
INSERT INTO telegram_session (user_id, session_data)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (user_id) DO UPDATE SET session_data=excluded.session_data
|
||||
`
|
||||
|
||||
// State Storage Queries
|
||||
allChannelsQuery = "SELECT channel_id, pts FROM telegram_channel_state WHERE user_id=$1"
|
||||
getChannelPtsQuery = "SELECT pts FROM telegram_channel_state WHERE user_id=$1 AND channel_id=$2"
|
||||
setChannelPtsQuery = `
|
||||
INSERT INTO telegram_channel_state (user_id, channel_id, pts)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, channel_id) DO UPDATE SET pts=excluded.pts
|
||||
`
|
||||
getStateQuery = "SELECT pts, qts, date, seq from telegram_user_state WHERE user_id=$1"
|
||||
setStateQuery = `
|
||||
INSERT INTO telegram_user_state (user_id, pts, qts, date, seq)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (user_id) DO UPDATE SET
|
||||
pts=excluded.pts,
|
||||
qts=excluded.qts,
|
||||
date=excluded.date,
|
||||
seq=excluded.seq
|
||||
`
|
||||
setPtsQuery = "UPDATE telegram_user_state SET pts=$1 WHERE user_id=$2"
|
||||
setQtsQuery = "UPDATE telegram_user_state SET qts=$1 WHERE user_id=$2"
|
||||
setDateQuery = "UPDATE telegram_user_state SET date=$1 WHERE user_id=$2"
|
||||
setSeqQuery = "UPDATE telegram_user_state SET seq=$1 WHERE user_id=$2"
|
||||
setDateSeqQuery = "UPDATE telegram_user_state SET date=$1, seq=$2 WHERE user_id=$3"
|
||||
|
||||
// Channel Access Hasher Queries
|
||||
getChannelAccessHashQuery = "SELECT access_hash FROM telegram_channel_access_hashes WHERE user_id=$1 AND channel_id=$2"
|
||||
setChannelAccessHashQuery = `
|
||||
INSERT INTO telegram_channel_access_hashes (user_id, channel_id, access_hash)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, channel_id) DO UPDATE SET access_hash=excluded.access_hash
|
||||
`
|
||||
)
|
||||
|
||||
var _ session.Storage = (*scopedStore)(nil)
|
||||
|
||||
func (s *scopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
|
||||
row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID)
|
||||
err = row.Scan(&sessionData)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) StoreSession(ctx context.Context, data []byte) error {
|
||||
_, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data)
|
||||
return err
|
||||
}
|
||||
|
||||
var _ updates.StateStorage = (*scopedStore)(nil)
|
||||
|
||||
func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
||||
s.assertUserIDMatches(userID)
|
||||
rows, err := s.db.Query(ctx, allChannelsQuery, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var channelID int64
|
||||
var pts int
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&channelID, &pts); err != nil {
|
||||
return err
|
||||
} else if err = f(ctx, channelID, pts); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
err = s.db.QueryRow(ctx, getChannelPtsQuery, userID, channelID).Scan(&pts)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, false, nil
|
||||
}
|
||||
return pts, err == nil, err
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setChannelPtsQuery, userID, channelID, pts)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
err = s.db.QueryRow(ctx, getStateQuery, userID).Scan(&state.Pts, &state.Qts, &state.Date, &state.Seq)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return state, false, nil
|
||||
}
|
||||
return state, err == nil, err
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setStateQuery, userID, state.Pts, state.Qts, state.Date, state.Seq)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setPtsQuery, userID, pts)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setQtsQuery, userID, qts)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setSeqQuery, userID, seq)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setDateQuery, userID, date)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setDateSeqQuery, userID, date, seq)
|
||||
return
|
||||
}
|
||||
|
||||
var _ updates.ChannelAccessHasher = (*scopedStore)(nil)
|
||||
|
||||
func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
err = s.db.QueryRow(ctx, getChannelAccessHashQuery, userID).Scan(&accessHash)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, false, nil
|
||||
}
|
||||
return accessHash, err == nil, err
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setChannelAccessHashQuery, userID, channelID, accessHash)
|
||||
return
|
||||
}
|
||||
|
||||
// Helper Functions
|
||||
|
||||
func (s *scopedStore) assertUserIDMatches(userID int64) {
|
||||
if s.telegramUserID != userID {
|
||||
panic(fmt.Sprintf("scoped store for %d function called with user ID %d", s.telegramUserID, userID))
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gotd/td/session"
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
// SessionStore is a wrapper around a database that implements
|
||||
// [session.Storage] scoped to a specific Telegram user ID.
|
||||
type SessionStore struct {
|
||||
db *dbutil.Database
|
||||
telegramUserID int64
|
||||
}
|
||||
|
||||
var _ session.Storage = (*SessionStore)(nil)
|
||||
|
||||
const (
|
||||
loadSessionQuery = `SELECT session_data FROM telegram_session WHERE user_id=$1`
|
||||
storeSessionQuery = `
|
||||
INSERT INTO telegram_session (user_id, session_data)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (user_id) DO UPDATE SET session_data=excluded.session_data
|
||||
`
|
||||
)
|
||||
|
||||
// LoadSession loads session data from the database.
|
||||
func (s *SessionStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
|
||||
row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID)
|
||||
err = row.Scan(&sessionData)
|
||||
return
|
||||
}
|
||||
|
||||
// StoreSession stores session data for a login into the database.
|
||||
func (s *SessionStore) StoreSession(ctx context.Context, data []byte) error {
|
||||
_, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data)
|
||||
return err
|
||||
}
|
||||
@@ -1,7 +1,32 @@
|
||||
-- v0 -> v1: Latest revision
|
||||
|
||||
-- TODO do I need to have bridge ID here?
|
||||
CREATE TABLE telegram_session (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
session_data BYTEA NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE telegram_user_state (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
pts INTEGER NOT NULL,
|
||||
qts INTEGER NOT NULL,
|
||||
date INTEGER NOT NULL,
|
||||
seq INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE telegram_channel_state (
|
||||
user_id INTEGER,
|
||||
channel_id INTEGER,
|
||||
pts INTEGER NOT NULL,
|
||||
|
||||
PRIMARY KEY (user_id, channel_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_telegram_channel_state_user_id ON telegram_channel_state (user_id);
|
||||
|
||||
CREATE TABLE telegram_channel_access_hashes (
|
||||
user_id INTEGER,
|
||||
channel_id INTEGER,
|
||||
access_hash INTEGER NOT NULL,
|
||||
|
||||
PRIMARY KEY (user_id, channel_id)
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user