channels: handle messages Matrix <-> TG
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
@@ -42,6 +42,6 @@ func (c *Container) Upgrade(ctx context.Context) error {
|
||||
return c.Database.Upgrade(ctx)
|
||||
}
|
||||
|
||||
func (c *Container) GetScopedStore(telegramUserID int64) *scopedStore {
|
||||
return &scopedStore{c.Database, telegramUserID}
|
||||
func (c *Container) GetScopedStore(telegramUserID int64) *ScopedStore {
|
||||
return &ScopedStore{c.Database, telegramUserID}
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
// scopedStore is a wrapper around a database that implements
|
||||
// ScopedStore is a wrapper around a database that implements
|
||||
// [session.Storage] scoped to a specific Telegram user ID.
|
||||
type scopedStore struct {
|
||||
type ScopedStore struct {
|
||||
db *dbutil.Database
|
||||
telegramUserID int64
|
||||
}
|
||||
@@ -60,22 +60,22 @@ const (
|
||||
`
|
||||
)
|
||||
|
||||
var _ session.Storage = (*scopedStore)(nil)
|
||||
var _ session.Storage = (*ScopedStore)(nil)
|
||||
|
||||
func (s *scopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
|
||||
func (s *ScopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
|
||||
row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID)
|
||||
err = row.Scan(&sessionData)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) StoreSession(ctx context.Context, data []byte) error {
|
||||
func (s *ScopedStore) StoreSession(ctx context.Context, data []byte) error {
|
||||
_, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data)
|
||||
return err
|
||||
}
|
||||
|
||||
var _ updates.StateStorage = (*scopedStore)(nil)
|
||||
var _ updates.StateStorage = (*ScopedStore)(nil)
|
||||
|
||||
func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
||||
func (s *ScopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
||||
s.assertUserIDMatches(userID)
|
||||
rows, err := s.db.Query(ctx, allChannelsQuery, userID)
|
||||
if err != nil {
|
||||
@@ -93,7 +93,7 @@ func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) {
|
||||
func (s *ScopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
err = s.db.QueryRow(ctx, getChannelPtsQuery, userID, channelID).Scan(&pts)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -102,13 +102,13 @@ func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID
|
||||
return pts, err == nil, err
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) {
|
||||
func (s *ScopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setChannelPtsQuery, userID, channelID, pts)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
|
||||
func (s *ScopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
err = s.db.QueryRow(ctx, getStateQuery, userID).Scan(&state.Pts, &state.Qts, &state.Date, &state.Seq)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -117,45 +117,45 @@ func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates
|
||||
return state, err == nil, err
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) {
|
||||
func (s *ScopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setStateQuery, userID, state.Pts, state.Qts, state.Date, state.Seq)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) {
|
||||
func (s *ScopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setPtsQuery, userID, pts)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) {
|
||||
func (s *ScopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setQtsQuery, userID, qts)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) {
|
||||
func (s *ScopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setSeqQuery, userID, seq)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) {
|
||||
func (s *ScopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setDateQuery, userID, date)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) {
|
||||
func (s *ScopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setDateSeqQuery, userID, date, seq)
|
||||
return
|
||||
}
|
||||
|
||||
var _ updates.ChannelAccessHasher = (*scopedStore)(nil)
|
||||
var _ updates.ChannelAccessHasher = (*ScopedStore)(nil)
|
||||
|
||||
func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) {
|
||||
func (s *ScopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
err = s.db.QueryRow(ctx, getChannelAccessHashQuery, userID, channelID).Scan(&accessHash)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -164,7 +164,7 @@ func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, ch
|
||||
return accessHash, err == nil, err
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) {
|
||||
func (s *ScopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setChannelAccessHashQuery, userID, channelID, accessHash)
|
||||
return
|
||||
@@ -172,7 +172,7 @@ func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, ch
|
||||
|
||||
// Helper Functions
|
||||
|
||||
func (s *scopedStore) assertUserIDMatches(userID int64) {
|
||||
func (s *ScopedStore) assertUserIDMatches(userID int64) {
|
||||
if s.telegramUserID != userID {
|
||||
panic(fmt.Sprintf("scoped store for %d function called with user ID %d", s.telegramUserID, userID))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user