store: refactor access hash and session tables

* Move sessions to user_login metadata, as that data rarely changes after login.
* Merge user and channel access hashes. Those IDs don't conflict.
* Split usernames into a new table to allow better `ON CONFLICT` updates
  (when a username moves to another entity, we want the old row to be replaced).
  Usernames also don't need to be scoped to a login.
This commit is contained in:
Tulir Asokan
2024-08-22 16:21:24 +03:00
parent e611c87342
commit b25c09fc53
14 changed files with 129 additions and 195 deletions
+2 -6
View File
@@ -22,11 +22,9 @@ func (t *TelegramClient) getDMChatInfo(ctx context.Context, userID int64) (*brid
Members: &bridgev2.ChatMemberList{IsFull: true},
CanBackfill: true,
}
accessHash, found, err := t.ScopedStore.GetUserAccessHash(ctx, userID)
accessHash, err := t.ScopedStore.GetAccessHash(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to get access hash for user %d: %w", userID, err)
} else if !found {
return nil, fmt.Errorf("access hash not found for user %d", userID)
}
users, err := t.client.API().UsersGetUsers(ctx, []tg.InputUserClass{&tg.InputUser{
UserID: userID,
@@ -198,11 +196,9 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
}
return chatInfo, nil
case ids.PeerTypeChannel:
accessHash, found, err := t.ScopedStore.GetChannelAccessHash(ctx, t.telegramUserID, id)
accessHash, err := t.ScopedStore.GetAccessHash(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get channel access hash: %w", err)
} else if !found {
return nil, fmt.Errorf("channel access hash not found for %d", id)
}
inputChannel := &tg.InputChannel{ChannelID: id, AccessHash: accessHash}
fullChat, err := APICallWithUpdates(ctx, t, func() (*tg.MessagesChatFull, error) {
+23 -21
View File
@@ -180,14 +180,14 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
})
client.client = telegram.NewClient(tc.Config.AppID, tc.Config.AppHash, telegram.Options{
SessionStorage: client.ScopedStore,
Logger: zaplog,
UpdateHandler: client.updatesManager,
OnDead: client.onDead,
OnSession: client.onSession,
OnAuthError: client.onAuthError,
PingTimeout: time.Duration(tc.Config.Ping.TimeoutSeconds) * time.Second,
PingInterval: time.Duration(tc.Config.Ping.IntervalSeconds) * time.Second,
CustomSessionStorage: &login.Metadata.(*UserLoginMetadata).Session,
Logger: zaplog,
UpdateHandler: client.updatesManager,
OnDead: client.onDead,
OnSession: client.onSession,
OnAuthError: client.onAuthError,
PingTimeout: time.Duration(tc.Config.Ping.TimeoutSeconds) * time.Second,
PingInterval: time.Duration(tc.Config.Ping.IntervalSeconds) * time.Second,
})
client.telegramFmtParams = &telegramfmt.FormatParams{
@@ -288,8 +288,13 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
if err != nil {
return "", "", 0, false
}
username, accessHash, found, err := tc.Store.GetScopedStore(telegramUserID).GetUserMetadata(ctx, telegramUserID)
if err != nil || !found {
ss := tc.Store.GetScopedStore(telegramUserID)
accessHash, err := ss.GetAccessHash(ctx, telegramUserID)
if err != nil || accessHash == 0 {
return "", "", 0, false
}
username, err := ss.GetUsername(ctx, telegramUserID)
if err != nil {
return "", "", 0, false
}
return userID, username, accessHash, true
@@ -390,11 +395,9 @@ func (t *TelegramClient) GetUserInfo(ctx context.Context, ghost *bridgev2.Ghost)
if err != nil {
return nil, err
}
accessHash, found, err := t.ScopedStore.GetUserAccessHash(ctx, id)
accessHash, err := t.ScopedStore.GetAccessHash(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get access hash for user %d: %w", id, err)
} else if !found {
return nil, fmt.Errorf("access hash not found for user %d", id)
}
users, err := t.client.API().UsersGetUsers(ctx, []tg.InputUserClass{&tg.InputUser{
UserID: id,
@@ -419,17 +422,16 @@ func (t *TelegramClient) getUserInfoFromTelegramUser(ctx context.Context, u tg.U
return nil, fmt.Errorf("user is %T not *tg.User", user)
}
var identifiers []string
if user.Min {
if err := t.ScopedStore.SetUserAccessHash(ctx, user.ID, user.AccessHash); err != nil {
return nil, err
}
} else {
if err := t.ScopedStore.SetUserMetadata(ctx, user.ID, user.Username, user.AccessHash); err != nil {
if err := t.ScopedStore.SetAccessHash(ctx, user.ID, user.AccessHash); err != nil {
return nil, err
}
if !user.Min {
if err := t.ScopedStore.SetUsername(ctx, user.ID, user.Username); err != nil {
return nil, err
}
if username, ok := user.GetUsername(); ok {
identifiers = append(identifiers, fmt.Sprintf("telegram:%s", username))
if user.Username != "" {
identifiers = append(identifiers, fmt.Sprintf("telegram:%s", user.Username))
}
for _, username := range user.Usernames {
identifiers = append(identifiers, fmt.Sprintf("telegram:%s", username.Username))
+37 -2
View File
@@ -1,10 +1,13 @@
package connector
import (
"context"
_ "embed"
"fmt"
"slices"
"github.com/gotd/td/crypto"
"github.com/gotd/td/session"
up "go.mau.fi/util/configupgrade"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/database"
@@ -118,6 +121,38 @@ type MessageMetadata struct {
ContentURI id.ContentURIString `json:"content_uri,omitempty"`
}
type UserLoginMetadata struct {
Phone string `json:"phone"`
type UserLoginSession struct {
AuthKey []byte `json:"auth_key,omitempty"`
Datacenter int `json:"dc_id,omitempty"`
ServerAddress string `json:"server_address,omitempty"`
ServerPort int `json:"port,omitempty"`
Salt int64 `json:"salt,omitempty"`
}
type UserLoginMetadata struct {
Phone string `json:"phone"`
Session UserLoginSession `json:"session"`
}
func (s *UserLoginSession) Load(_ context.Context) (*session.Data, error) {
if len(s.AuthKey) != 256 {
return nil, nil
}
keyID := crypto.Key(s.AuthKey).ID()
return &session.Data{
DC: s.Datacenter,
Addr: s.ServerAddress,
AuthKey: s.AuthKey,
AuthKeyID: keyID[:],
Salt: s.Salt,
}, nil
}
func (s *UserLoginSession) Save(ctx context.Context, data *session.Data) error {
s.Datacenter = data.DC
s.ServerAddress = data.Addr
s.AuthKey = data.AuthKey
s.Salt = data.Salt
// TODO save UserLogin to database?
return nil
}
+1 -4
View File
@@ -65,12 +65,9 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med
})
case ids.PeerTypeChannel:
var accessHash int64
var found bool
accessHash, found, err = client.ScopedStore.GetChannelAccessHash(ctx, client.telegramUserID, info.ChatID)
accessHash, err = client.ScopedStore.GetAccessHash(ctx, info.ChatID)
if err != nil {
return nil, fmt.Errorf("failed to get channel access hash: %w", err)
} else if !found {
return nil, fmt.Errorf("channel access hash not found for %d", info.ChatID)
} else {
messages, err = APICallWithUpdates(ctx, client, func() (tg.ModifiedMessagesMessages, error) {
m, err := client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{
+5 -16
View File
@@ -21,7 +21,6 @@ import (
"errors"
"fmt"
"github.com/gotd/td/session"
"github.com/gotd/td/telegram"
"github.com/gotd/td/telegram/auth"
"github.com/gotd/td/tg"
@@ -64,7 +63,7 @@ const (
type PhoneLogin struct {
user *bridgev2.User
main *TelegramConnector
storage *session.StorageMemory
authData UserLoginSession
client *telegram.Client
clientCancel context.CancelFunc
@@ -99,10 +98,9 @@ func (p *PhoneLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) {
func (p *PhoneLogin) SubmitUserInput(ctx context.Context, input map[string]string) (*bridgev2.LoginStep, error) {
if phone, ok := input[phoneNumberStep]; ok {
p.phone = phone
p.storage = &session.StorageMemory{}
p.client = telegram.NewClient(p.main.Config.AppID, p.main.Config.AppHash, telegram.Options{
SessionStorage: p.storage,
Logger: zap.New(zerozap.New(zerolog.Ctx(ctx).With().Str("component", "telegram_login_client").Logger())),
CustomSessionStorage: &p.authData,
Logger: zap.New(zerozap.New(zerolog.Ctx(ctx).With().Str("component", "telegram_login_client").Logger())),
})
var err error
p.clientCancel, err = connectTelegramClient(context.Background(), p.client)
@@ -180,23 +178,14 @@ func (p *PhoneLogin) SubmitUserInput(ctx context.Context, input map[string]strin
func (p *PhoneLogin) handleAuthSuccess(ctx context.Context, authorization *tg.AuthAuthorization) (*bridgev2.LoginStep, error) {
// Now that we have the Telegram user ID, store it in the database and
// close the login client.
sessionStore := p.main.Store.GetScopedStore(authorization.User.GetID())
var sessionData []byte
sessionData, err := p.storage.Bytes(sessionData)
if err != nil {
return nil, err
}
err = sessionStore.StoreSession(ctx, sessionData)
if err != nil {
return nil, err
}
p.clientCancel()
userLoginID := ids.MakeUserLoginID(authorization.User.GetID())
ul, err := p.user.NewLogin(ctx, &database.UserLogin{
ID: userLoginID,
Metadata: UserLoginMetadata{
Phone: p.phone,
Phone: p.phone,
Session: p.authData,
},
}, nil)
if err != nil {
+1 -5
View File
@@ -457,13 +457,9 @@ func (t *TelegramClient) HandleMatrixReadReceipt(ctx context.Context, msg *bridg
})
case ids.PeerTypeChannel:
var accessHash int64
var found bool
accessHash, found, readMessagesErr = t.ScopedStore.GetChannelAccessHash(ctx, t.telegramUserID, id)
accessHash, readMessagesErr = t.ScopedStore.GetAccessHash(ctx, id)
if readMessagesErr != nil {
return
} else if !found {
readMessagesErr = fmt.Errorf("channel access hash not found for %d", id)
return
}
_, readMessagesErr = t.client.API().ChannelsReadHistory(ctx, &tg.ChannelsReadHistoryRequest{
Channel: &tg.InputChannel{ChannelID: id, AccessHash: accessHash},
+1 -3
View File
@@ -195,10 +195,8 @@ func (t *Transferer) WithPhoto(pc tg.PhotoClass) *ReadyTransferer {
// given user's photo as the location that will be downloaded by the
// [ReadyTransferer].
func (t *Transferer) WithUserPhoto(ctx context.Context, store *store.ScopedStore, user *tg.User, photoID int64) (*ReadyTransferer, error) {
if accessHash, found, err := store.GetUserAccessHash(ctx, user.GetID()); err != nil {
if accessHash, err := store.GetAccessHash(ctx, user.GetID()); err != nil {
return nil, fmt.Errorf("failed to get user access hash for %d: %w", user.GetID(), err)
} else if !found {
return nil, fmt.Errorf("user access hash not found for %d", user.GetID())
} else {
return &ReadyTransferer{
inner: t,
+37 -85
View File
@@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"github.com/gotd/td/session"
"github.com/gotd/td/telegram/updates"
"go.mau.fi/util/dbutil"
)
@@ -19,14 +18,6 @@ type ScopedStore struct {
}
const (
// Session Storage Queries
loadSessionQuery = `SELECT session_data FROM telegram_session WHERE user_id=$1`
storeSessionQuery = `
INSERT INTO telegram_session (user_id, session_data)
VALUES ($1, $2)
ON CONFLICT (user_id) DO UPDATE SET session_data=excluded.session_data
`
// State Storage Queries
allChannelsQuery = "SELECT channel_id, pts FROM telegram_channel_state WHERE user_id=$1"
getChannelPtsQuery = "SELECT pts FROM telegram_channel_state WHERE user_id=$1 AND channel_id=$2"
@@ -51,54 +42,23 @@ const (
setSeqQuery = "UPDATE telegram_user_state SET seq=$1 WHERE user_id=$2"
setDateSeqQuery = "UPDATE telegram_user_state SET date=$1, seq=$2 WHERE user_id=$3"
// Channel Access Hasher Queries
getChannelAccessHashQuery = "SELECT access_hash FROM telegram_channel_access_hashes WHERE user_id=$1 AND channel_id=$2"
setChannelAccessHashQuery = `
INSERT INTO telegram_channel_access_hashes (user_id, channel_id, access_hash)
getAccessHashQuery = "SELECT access_hash FROM telegram_access_hash WHERE user_id=$1 AND entity_id=$2"
setAccessHashQuery = `
INSERT INTO telegram_access_hash (user_id, entity_id, access_hash)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, channel_id) DO UPDATE SET access_hash=excluded.access_hash
`
// User Access Hash Queries
getUserAccessHashQuery = "SELECT access_hash FROM telegram_user_metadata WHERE receiver_id=$1 AND user_id=$2"
setUserAccessHashQuery = `
INSERT INTO telegram_user_metadata (receiver_id, user_id, access_hash)
VALUES ($1, $2, $3)
ON CONFLICT (receiver_id, user_id) DO UPDATE SET access_hash=excluded.access_hash
ON CONFLICT (user_id, entity_id) DO UPDATE SET access_hash=excluded.access_hash
`
// User Username Queries
getUserUsernameQuery = "SELECT username FROM telegram_user_metadata WHERE receiver_id=$1 AND user_id=$2"
setUserUsernameQuery = `
INSERT INTO telegram_user_metadata (receiver_id, user_id, username)
VALUES ($1, $2, $3)
ON CONFLICT (receiver_id, user_id) DO UPDATE SET username=excluded.username
`
// User Metadata Queries
getUserMetadataQuery = "SELECT username, access_hash FROM telegram_user_metadata WHERE receiver_id=$1 AND user_id=$2"
setUserMetadataQuery = `
INSERT INTO telegram_user_metadata (receiver_id, user_id, username, access_hash)
VALUES ($1, $2, $3, $4)
ON CONFLICT (receiver_id, user_id) DO UPDATE SET
username=excluded.username,
access_hash=excluded.access_hash
getUsernameQuery = "SELECT username FROM telegram_username WHERE entity_id=$1"
setUsernameQuery = `
INSERT INTO telegram_username (username, entity_id)
VALUES ($1, $2)
ON CONFLICT (username) DO UPDATE SET entity_id=excluded.entity_id
`
clearUsernameQuery = `DELETE FROM telegram_username WHERE entity_id=$1`
)
var _ session.Storage = (*ScopedStore)(nil)
func (s *ScopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID)
err = row.Scan(&sessionData)
return
}
func (s *ScopedStore) StoreSession(ctx context.Context, data []byte) error {
_, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data)
return err
}
var _ updates.StateStorage = (*ScopedStore)(nil)
func (s *ScopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
@@ -181,57 +141,49 @@ func (s *ScopedStore) SetDateSeq(ctx context.Context, userID int64, date int, se
var _ updates.ChannelAccessHasher = (*ScopedStore)(nil)
func (s *ScopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) {
// Deprecated: only for interface, don't use directly. Use GetAccessHash instead
func (s *ScopedStore) GetChannelAccessHash(ctx context.Context, userID, channelID int64) (accessHash int64, found bool, err error) {
s.assertUserIDMatches(userID)
err = s.db.QueryRow(ctx, getChannelAccessHashQuery, userID, channelID).Scan(&accessHash)
if errors.Is(err, sql.ErrNoRows) {
return 0, false, nil
}
return accessHash, err == nil, err
accessHash, err = s.GetAccessHash(ctx, channelID)
found = accessHash != 0
return
}
func (s *ScopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) {
// Deprecated: only for interface, don't use directly. Use SetAccessHash instead
func (s *ScopedStore) SetChannelAccessHash(ctx context.Context, userID, channelID, accessHash int64) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setChannelAccessHashQuery, userID, channelID, accessHash)
return s.SetAccessHash(ctx, channelID, accessHash)
}
var ErrNoAccessHash = errors.New("access hash not found")
func (s *ScopedStore) GetAccessHash(ctx context.Context, userID int64) (accessHash int64, err error) {
err = s.db.QueryRow(ctx, getAccessHashQuery, s.telegramUserID, userID).Scan(&accessHash)
if errors.Is(err, sql.ErrNoRows) {
err = ErrNoAccessHash
}
return
}
func (s *ScopedStore) GetUserAccessHash(ctx context.Context, userID int64) (accessHash int64, found bool, err error) {
err = s.db.QueryRow(ctx, getUserAccessHashQuery, s.telegramUserID, userID).Scan(&accessHash)
if errors.Is(err, sql.ErrNoRows) {
return 0, false, nil
}
return accessHash, err == nil, err
}
func (s *ScopedStore) SetUserAccessHash(ctx context.Context, userID, accessHash int64) (err error) {
_, err = s.db.Exec(ctx, setUserAccessHashQuery, s.telegramUserID, userID, accessHash)
func (s *ScopedStore) SetAccessHash(ctx context.Context, userID, accessHash int64) (err error) {
_, err = s.db.Exec(ctx, setAccessHashQuery, s.telegramUserID, userID, accessHash)
return
}
func (s *ScopedStore) GetUserUsername(ctx context.Context, userID int64) (username string, found bool, err error) {
err = s.db.QueryRow(ctx, getUserUsernameQuery, s.telegramUserID, userID).Scan(&username)
func (s *ScopedStore) GetUsername(ctx context.Context, userID int64) (username string, err error) {
err = s.db.QueryRow(ctx, getUsernameQuery, userID).Scan(&username)
if errors.Is(err, sql.ErrNoRows) {
return "", false, nil
err = nil
}
return username, err == nil, err
}
func (s *ScopedStore) SetUserUsername(ctx context.Context, userID int64, username string) (err error) {
_, err = s.db.Exec(ctx, setUserUsernameQuery, s.telegramUserID, userID, username)
return
}
func (s *ScopedStore) GetUserMetadata(ctx context.Context, userID int64) (username string, accessHash int64, found bool, err error) {
err = s.db.QueryRow(ctx, getUserMetadataQuery, s.telegramUserID, userID).Scan(&username, &accessHash)
if errors.Is(err, sql.ErrNoRows) {
return "", 0, false, nil
func (s *ScopedStore) SetUsername(ctx context.Context, userID int64, username string) (err error) {
if username == "" {
_, err = s.db.Exec(ctx, clearUsernameQuery, userID)
} else {
_, err = s.db.Exec(ctx, setUsernameQuery, username, userID)
}
return username, accessHash, err == nil, err
}
func (s *ScopedStore) SetUserMetadata(ctx context.Context, userID int64, username string, accessHash int64) (err error) {
_, err = s.db.Exec(ctx, setUserMetadataQuery, s.telegramUserID, userID, username, accessHash)
return
}
+15 -24
View File
@@ -1,12 +1,7 @@
-- v0 -> v3: Latest revision
CREATE TABLE telegram_session (
user_id BIGINT PRIMARY KEY,
session_data BYTEA NOT NULL
);
-- v0 -> v1: Latest revision
CREATE TABLE telegram_user_state (
user_id BIGINT PRIMARY KEY,
user_id BIGINT NOT NULL PRIMARY KEY,
pts BIGINT NOT NULL,
qts BIGINT NOT NULL,
date BIGINT NOT NULL,
@@ -14,39 +9,35 @@ CREATE TABLE telegram_user_state (
);
CREATE TABLE telegram_channel_state (
user_id BIGINT,
channel_id BIGINT,
user_id BIGINT NOT NULL,
channel_id BIGINT NOT NULL,
pts BIGINT NOT NULL,
PRIMARY KEY (user_id, channel_id)
);
CREATE INDEX idx_telegram_channel_state_user_id ON telegram_channel_state (user_id);
CREATE INDEX telegram_channel_state_user_id_idx ON telegram_channel_state (user_id);
CREATE TABLE telegram_channel_access_hashes (
user_id BIGINT,
channel_id BIGINT,
CREATE TABLE telegram_access_hash (
user_id BIGINT NOT NULL,
entity_id BIGINT NOT NULL,
access_hash BIGINT NOT NULL,
PRIMARY KEY (user_id, channel_id)
PRIMARY KEY (user_id, entity_id)
);
CREATE TABLE telegram_user_metadata (
receiver_id BIGINT,
user_id BIGINT,
CREATE TABLE telegram_username (
username TEXT NOT NULL,
entity_id BIGINT NOT NULL,
access_hash BIGINT NOT NULL,
username TEXT,
PRIMARY KEY (receiver_id, user_id)
PRIMARY KEY (username)
);
CREATE INDEX telegram_username_entity_idx ON telegram_username (entity_id);
CREATE TABLE telegram_file (
id TEXT PRIMARY KEY,
mxc TEXT NOT NULL,
mime_type TEXT,
size BIGINT
);
-- TODO this will be unnecessary once the queries switch to reading telegram_user_metadata
CREATE INDEX idx_ghost_username ON ghost ((metadata->>'username'));
@@ -1,3 +0,0 @@
-- v2: Add index for ghost username metadata field
CREATE INDEX idx_ghost_username ON ghost ((metadata->>'username'));
@@ -1,15 +0,0 @@
-- v3: Move the user access hash to a table so it can be per-user
CREATE TABLE telegram_user_metadata (
receiver_id INTEGER,
user_id INTEGER,
access_hash INTEGER NOT NULL,
username TEXT,
PRIMARY KEY (receiver_id, user_id)
);
INSERT INTO telegram_user_metadata (receiver_id, user_id, access_hash, username)
SELECT ul.id, g.id, g.metadata->>'access_hash', g.metadata->>'username'
FROM user_login ul, ghost g;
+4 -8
View File
@@ -370,23 +370,19 @@ func (t *TelegramClient) inputPeerForPortalID(ctx context.Context, portalID netw
}
switch peerType {
case ids.PeerTypeUser:
if accessHash, found, err := t.ScopedStore.GetUserAccessHash(ctx, id); err != nil {
if accessHash, err := t.ScopedStore.GetAccessHash(ctx, id); err != nil {
return nil, fmt.Errorf("failed to get user access hash for %d: %w", id, err)
} else if !found {
return nil, fmt.Errorf("user access hash not found for %d", id)
} else {
return &tg.InputPeerUser{UserID: id, AccessHash: accessHash}, nil
}
case ids.PeerTypeChat:
return &tg.InputPeerChat{ChatID: id}, nil
case ids.PeerTypeChannel:
accessHash, found, err := t.ScopedStore.GetChannelAccessHash(ctx, t.telegramUserID, id)
if err != nil {
if accessHash, err := t.ScopedStore.GetAccessHash(ctx, id); err != nil {
return nil, err
} else if !found {
return nil, fmt.Errorf("channel access hash not found for %d", id)
} else {
return &tg.InputPeerChannel{ChannelID: id, AccessHash: accessHash}, nil
}
return &tg.InputPeerChannel{ChannelID: id, AccessHash: accessHash}, nil
default:
panic("invalid peer type")
}