treewide: separate user and channel namespaces

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
Sumner Evans
2024-09-24 14:40:41 -06:00
parent 65da56b2a6
commit c6e96682b6
15 changed files with 132 additions and 60 deletions
+7 -3
View File
@@ -181,9 +181,10 @@ SELECT
FROM reaction_old
INNER JOIN message ON reaction_old.msg_mxid=message.mxid;
INSERT INTO telegram_access_hash (user_id, entity_id, access_hash)
INSERT INTO telegram_access_hash (user_id, entity_type, entity_id, access_hash)
SELECT
user_old.tgid,
CASE WHEN id < 0 THEN 'channel' ELSE 'user' END,
CASE WHEN id < 0 THEN -id - 1000000000000 ELSE id END,
hash
FROM telethon_entities_old
@@ -202,8 +203,11 @@ FROM telethon_update_state_old
LEFT JOIN user_old ON user_old.mxid=session_id
WHERE entity_id<>0 AND user_old.tgid IS NOT NULL;
INSERT INTO telegram_username (username, entity_id)
SELECT username, CASE WHEN id < 0 THEN -id - 1000000000000 ELSE id END
INSERT INTO telegram_username (username, entity_type, entity_id)
SELECT
username,
CASE WHEN id < 0 THEN 'channel' ELSE 'user' END,
CASE WHEN id < 0 THEN -id - 1000000000000 ELSE id END
FROM telethon_entities_old
WHERE username<>''
ON CONFLICT DO NOTHING;
+5 -1
View File
@@ -289,11 +289,15 @@ func legacyProvContacts(w http.ResponseWriter, r *http.Request) {
contactsMap := map[int64]*legacyContactInfo{}
for _, contact := range contacts {
id, err := ids.ParseUserID(contact.UserID)
peerType, id, err := ids.ParseUserID(contact.UserID)
if err != nil {
log.Err(err).Msg("Failed to parse user id")
exhttp.WriteJSONResponse(w, http.StatusInternalServerError, resp.WithError("M_UNKNOWN", fmt.Sprintf("Failed to parse user id: %v", err)))
return
} else if peerType != ids.PeerTypeUser {
log.Err(err).Msg("Unexpected peer type")
exhttp.WriteJSONResponse(w, http.StatusInternalServerError, resp.WithError("M_UNKNOWN", fmt.Sprintf("Unexpected peer type: %s", peerType)))
return
}
if contact.UserInfo != nil {
contact.Ghost.UpdateInfo(ctx, contact.UserInfo)
+1 -1
View File
@@ -108,7 +108,7 @@ func main() {
"v0.16.0",
m.LegacyMigrateWithAnotherUpgrader(
legacyMigrateRenameTables, legacyMigrateCopyData, 16,
upgrades.Table, "telegram_version", 1,
upgrades.Table, "telegram_version", 2,
),
true,
)
+4 -3
View File
@@ -182,7 +182,7 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
}
return chatInfo, nil
case ids.PeerTypeChannel:
accessHash, err := t.ScopedStore.GetAccessHash(ctx, id)
accessHash, err := t.ScopedStore.GetAccessHash(ctx, ids.PeerTypeChannel, id)
if err != nil {
return nil, fmt.Errorf("failed to get channel access hash: %w", err)
}
@@ -218,10 +218,11 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
chatInfo.Members.IsFull = false
if !portal.Metadata.(*PortalMetadata).IsSuperGroup {
// Add the channel user
chatInfo.Members.MemberMap[ids.MakeUserID(id)] = bridgev2.ChatMember{
sender := ids.MakeUserID(id)
chatInfo.Members.MemberMap[sender] = bridgev2.ChatMember{
EventSender: bridgev2.EventSender{
SenderLogin: ids.MakeUserLoginID(id),
Sender: ids.MakeUserID(id),
Sender: sender,
},
}
}
+13 -11
View File
@@ -225,8 +225,10 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
return userInfo, nil
},
GetUserInfoByUsername: func(ctx context.Context, username string) (telegramfmt.UserInfo, error) {
if userID, err := client.ScopedStore.GetUserIDByUsername(ctx, username); err != nil {
if peerType, userID, err := client.ScopedStore.GetUserIDByUsername(ctx, username); err != nil {
return telegramfmt.UserInfo{}, err
} else if peerType != ids.PeerTypeUser {
return telegramfmt.UserInfo{}, fmt.Errorf("unexpected peer type: %s", peerType)
} else if ghost, err := tc.Bridge.GetGhostByID(ctx, ids.MakeUserID(userID)); err != nil {
return telegramfmt.UserInfo{}, err
} else {
@@ -297,11 +299,11 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
GetGhostDetails: func(ctx context.Context, ui id.UserID) (networkid.UserID, string, int64, bool) {
if userID, ok := tc.Bridge.Matrix.ParseGhostMXID(ui); !ok {
return "", "", 0, false
} else if telegramUserID, err := ids.ParseUserID(userID); err != nil {
} else if peerType, telegramUserID, err := ids.ParseUserID(userID); err != nil {
return "", "", 0, false
} else if accessHash, err := client.ScopedStore.GetAccessHash(ctx, telegramUserID); err != nil || accessHash == 0 {
} else if accessHash, err := client.ScopedStore.GetAccessHash(ctx, peerType, telegramUserID); err != nil || accessHash == 0 {
return "", "", 0, false
} else if username, err := client.ScopedStore.GetUsername(ctx, telegramUserID); err != nil {
} else if username, err := client.ScopedStore.GetUsername(ctx, peerType, telegramUserID); err != nil {
return "", "", 0, false
} else {
return userID, username, accessHash, true
@@ -407,7 +409,7 @@ func (t *TelegramClient) Disconnect() {
}
func (t *TelegramClient) getInputUser(ctx context.Context, id int64) (*tg.InputUser, error) {
accessHash, err := t.ScopedStore.GetAccessHash(ctx, id)
accessHash, err := t.ScopedStore.GetAccessHash(ctx, ids.PeerTypeUser, id)
if err != nil {
return nil, fmt.Errorf("failed to get access hash for user %d: %w", id, err)
}
@@ -428,11 +430,11 @@ func (t *TelegramClient) getSingleUser(ctx context.Context, id int64) (tg.UserCl
}
func (t *TelegramClient) GetUserInfo(ctx context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) {
id, err := ids.ParseUserID(ghost.ID)
if err != nil {
if peerType, id, err := ids.ParseUserID(ghost.ID); err != nil {
return nil, err
}
if user, err := t.getSingleUser(ctx, id); err != nil {
} else if peerType != ids.PeerTypeUser {
return nil, fmt.Errorf("unexpected peer type: %s", peerType)
} else if user, err := t.getSingleUser(ctx, id); err != nil {
return nil, fmt.Errorf("failed to get user %d: %w", id, err)
} else if user.TypeID() != tg.UserTypeID {
return nil, err
@@ -447,11 +449,11 @@ func (t *TelegramClient) getUserInfoFromTelegramUser(ctx context.Context, u tg.U
return nil, fmt.Errorf("user is %T not *tg.User", user)
}
var identifiers []string
if err := t.ScopedStore.SetAccessHash(ctx, user.ID, user.AccessHash); err != nil {
if err := t.ScopedStore.SetAccessHash(ctx, ids.PeerTypeUser, user.ID, user.AccessHash); err != nil {
return nil, err
}
if !user.Min {
if err := t.ScopedStore.SetUsername(ctx, user.ID, user.Username); err != nil {
if err := t.ScopedStore.SetUsername(ctx, ids.PeerTypeUser, user.ID, user.Username); err != nil {
return nil, err
}
+1 -1
View File
@@ -65,7 +65,7 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med
})
case ids.PeerTypeChannel:
var accessHash int64
accessHash, err = client.ScopedStore.GetAccessHash(ctx, info.ChatID)
accessHash, err = client.ScopedStore.GetAccessHash(ctx, ids.PeerTypeChannel, info.ChatID)
if err != nil {
return nil, fmt.Errorf("failed to get channel access hash: %w", err)
} else {
+13 -2
View File
@@ -13,8 +13,19 @@ func MakeUserID(userID int64) networkid.UserID {
return networkid.UserID(strconv.FormatInt(userID, 10))
}
func ParseUserID(userID networkid.UserID) (int64, error) {
return strconv.ParseInt(string(userID), 10, 64)
func MakeChannelUserID(channelID int64) networkid.UserID {
return networkid.UserID("channel-" + strconv.FormatInt(channelID, 10))
}
func ParseUserID(userID networkid.UserID) (PeerType, int64, error) {
peerType := PeerTypeUser
rawUserID := string(userID)
if strings.HasPrefix(string(userID), "channel-") {
peerType = PeerTypeChannel
rawUserID = strings.TrimPrefix(rawUserID, "channel-")
}
id, err := strconv.ParseInt(rawUserID, 10, 64)
return peerType, id, err
}
func ParseUserLoginID(userID networkid.UserLoginID) (int64, error) {
+1 -1
View File
@@ -457,7 +457,7 @@ func (t *TelegramClient) HandleMatrixReadReceipt(ctx context.Context, msg *bridg
})
case ids.PeerTypeChannel:
var accessHash int64
accessHash, readMessagesErr = t.ScopedStore.GetAccessHash(ctx, id)
accessHash, readMessagesErr = t.ScopedStore.GetAccessHash(ctx, ids.PeerTypeChannel, id)
if readMessagesErr != nil {
return
}
+5 -1
View File
@@ -18,6 +18,7 @@ package matrixfmt
import (
"context"
"fmt"
"github.com/gotd/td/tg"
"maunium.net/go/mautrix/event"
@@ -32,7 +33,10 @@ func toTelegramEntity(br telegramfmt.BodyRange) tg.MessageEntityClass {
if val.Username != "" {
return &tg.MessageEntityMention{Offset: br.Start, Length: br.Length}
} else {
userID, _ := ids.ParseUserID(val.UserID)
peerType, userID, _ := ids.ParseUserID(val.UserID)
if peerType != ids.PeerTypeUser {
panic(fmt.Errorf("unexpected peer type in mention %T", peerType))
}
return &tg.InputMessageEntityMentionName{
Offset: br.Start,
Length: br.Length,
+1 -1
View File
@@ -195,7 +195,7 @@ func (t *Transferer) WithPhoto(pc tg.PhotoClass) *ReadyTransferer {
// given user's photo as the location that will be downloaded by the
// [ReadyTransferer].
func (t *Transferer) WithUserPhoto(ctx context.Context, store *store.ScopedStore, user *tg.User, photoID int64) (*ReadyTransferer, error) {
if accessHash, err := store.GetAccessHash(ctx, user.GetID()); err != nil {
if accessHash, err := store.GetAccessHash(ctx, ids.PeerTypeUser, user.GetID()); err != nil {
return nil, fmt.Errorf("failed to get user access hash for %d: %w", user.GetID(), err)
} else {
return &ReadyTransferer{
+7 -3
View File
@@ -81,10 +81,12 @@ func (t *TelegramClient) ResolveIdentifier(ctx context.Context, identifier strin
return t.getResolveIdentifierResponseForUserID(ctx, userID)
} else if match := usernameRe.FindStringSubmatch(identifier); match != nil && !strings.Contains(identifier, "__") {
// This is a username
userID, err := t.ScopedStore.GetUserIDByUsername(ctx, match[1])
if err == nil || userID != 0 {
entityType, userID, err := t.ScopedStore.GetUserIDByUsername(ctx, match[1])
if entityType == ids.PeerTypeUser && (err == nil || userID != 0) {
// We know this username.
return t.getResolveIdentifierResponseForUserID(ctx, userID)
} else if entityType != ids.PeerTypeUser {
return nil, fmt.Errorf("unexpected peer type: %s", entityType)
} else {
// We don't know this username, try to resolve the username from
// Telegram.
@@ -205,8 +207,10 @@ func (t *TelegramClient) CreateGroup(ctx context.Context, name string, users ...
Title: name,
}
for _, networkUserID := range users {
if userID, err := ids.ParseUserID(networkUserID); err != nil {
if peerType, userID, err := ids.ParseUserID(networkUserID); err != nil {
return nil, fmt.Errorf("failed to parse user ID: %w", err)
} else if peerType != ids.PeerTypeUser {
return nil, fmt.Errorf("unexpected peer type: %s", peerType)
} else if inputUser, err := t.getInputUser(ctx, userID); err != nil {
return nil, fmt.Errorf("failed to get input user: %w", err)
} else {
+27 -23
View File
@@ -8,6 +8,8 @@ import (
"github.com/gotd/td/telegram/updates"
"go.mau.fi/util/dbutil"
"go.mau.fi/mautrix-telegram/pkg/connector/ids"
)
// ScopedStore is a wrapper around a database that implements
@@ -42,22 +44,24 @@ const (
setSeqQuery = "UPDATE telegram_user_state SET seq=$1 WHERE user_id=$2"
setDateSeqQuery = "UPDATE telegram_user_state SET date=$1, seq=$2 WHERE user_id=$3"
getAccessHashQuery = "SELECT access_hash FROM telegram_access_hash WHERE user_id=$1 AND entity_id=$2"
getAccessHashQuery = "SELECT access_hash FROM telegram_access_hash WHERE user_id=$1 AND entity_type=$2 AND entity_id=$3"
setAccessHashQuery = `
INSERT INTO telegram_access_hash (user_id, entity_id, access_hash)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, entity_id) DO UPDATE SET access_hash=excluded.access_hash
INSERT INTO telegram_access_hash (user_id, entity_type, entity_id, access_hash)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, entity_type, entity_id) DO UPDATE SET access_hash=excluded.access_hash
`
// User Username Queries
getUsernameQuery = "SELECT username FROM telegram_username WHERE entity_id=$1"
getUsernameQuery = "SELECT username FROM telegram_username WHERE entity_type=$1 AND entity_id=$2"
setUsernameQuery = `
INSERT INTO telegram_username (username, entity_id)
VALUES ($1, $2)
ON CONFLICT (username) DO UPDATE SET entity_id=excluded.entity_id
INSERT INTO telegram_username (username, entity_type, entity_id)
VALUES ($1, $2, $3)
ON CONFLICT (username) DO UPDATE SET
entity_type=excluded.entity_type,
entity_id=excluded.entity_id
`
getByUsernameQuery = "SELECT entity_id FROM telegram_username WHERE LOWER(username)=$1"
clearUsernameQuery = `DELETE FROM telegram_username WHERE entity_id=$1`
getByUsernameQuery = "SELECT entity_type, entity_id FROM telegram_username WHERE LOWER(username)=$1"
clearUsernameQuery = `DELETE FROM telegram_username WHERE entity_type=$1 AND entity_id=$2`
// User Phone Number Queries
getEntityIDForPhoneNumber = "SELECT entity_id FROM telegram_phone_number WHERE phone_number=$1"
@@ -154,7 +158,7 @@ var _ updates.ChannelAccessHasher = (*ScopedStore)(nil)
// Deprecated: only for interface, don't use directly. Use GetAccessHash instead
func (s *ScopedStore) GetChannelAccessHash(ctx context.Context, userID, channelID int64) (accessHash int64, found bool, err error) {
s.assertUserIDMatches(userID)
accessHash, err = s.GetAccessHash(ctx, channelID)
accessHash, err = s.GetAccessHash(ctx, ids.PeerTypeChannel, channelID)
found = accessHash != 0
return
}
@@ -162,43 +166,43 @@ func (s *ScopedStore) GetChannelAccessHash(ctx context.Context, userID, channelI
// Deprecated: only for interface, don't use directly. Use SetAccessHash instead
func (s *ScopedStore) SetChannelAccessHash(ctx context.Context, userID, channelID, accessHash int64) (err error) {
s.assertUserIDMatches(userID)
return s.SetAccessHash(ctx, channelID, accessHash)
return s.SetAccessHash(ctx, ids.PeerTypeChannel, channelID, accessHash)
}
var ErrNoAccessHash = errors.New("access hash not found")
func (s *ScopedStore) GetAccessHash(ctx context.Context, entityID int64) (accessHash int64, err error) {
err = s.db.QueryRow(ctx, getAccessHashQuery, s.telegramUserID, entityID).Scan(&accessHash)
func (s *ScopedStore) GetAccessHash(ctx context.Context, entityType ids.PeerType, entityID int64) (accessHash int64, err error) {
err = s.db.QueryRow(ctx, getAccessHashQuery, s.telegramUserID, entityType, entityID).Scan(&accessHash)
if errors.Is(err, sql.ErrNoRows) {
err = ErrNoAccessHash
}
return
}
func (s *ScopedStore) SetAccessHash(ctx context.Context, entityID, accessHash int64) (err error) {
_, err = s.db.Exec(ctx, setAccessHashQuery, s.telegramUserID, entityID, accessHash)
func (s *ScopedStore) SetAccessHash(ctx context.Context, entityType ids.PeerType, entityID, accessHash int64) (err error) {
_, err = s.db.Exec(ctx, setAccessHashQuery, s.telegramUserID, entityType, entityID, accessHash)
return
}
func (s *ScopedStore) GetUsername(ctx context.Context, userID int64) (username string, err error) {
err = s.db.QueryRow(ctx, getUsernameQuery, userID).Scan(&username)
func (s *ScopedStore) GetUsername(ctx context.Context, entityType ids.PeerType, userID int64) (username string, err error) {
err = s.db.QueryRow(ctx, getUsernameQuery, entityType, userID).Scan(&username)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (s *ScopedStore) SetUsername(ctx context.Context, userID int64, username string) (err error) {
func (s *ScopedStore) SetUsername(ctx context.Context, entityType ids.PeerType, userID int64, username string) (err error) {
if username == "" {
_, err = s.db.Exec(ctx, clearUsernameQuery, userID)
_, err = s.db.Exec(ctx, clearUsernameQuery, entityType, userID)
} else {
_, err = s.db.Exec(ctx, setUsernameQuery, username, userID)
_, err = s.db.Exec(ctx, setUsernameQuery, username, entityType, userID)
}
return
}
func (s *ScopedStore) GetUserIDByUsername(ctx context.Context, username string) (userID int64, err error) {
err = s.db.QueryRow(ctx, getByUsernameQuery, username).Scan(&userID)
func (s *ScopedStore) GetUserIDByUsername(ctx context.Context, username string) (entityType ids.PeerType, userID int64, err error) {
err = s.db.QueryRow(ctx, getByUsernameQuery, username).Scan(&entityType, &userID)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
+6 -4
View File
@@ -1,4 +1,4 @@
-- v0 -> v1: Latest revision
-- v0 -> v2: Latest revision
CREATE TABLE telegram_user_state (
user_id BIGINT NOT NULL PRIMARY KEY,
@@ -20,15 +20,17 @@ CREATE INDEX telegram_channel_state_user_id_idx ON telegram_channel_state (user_
CREATE TABLE telegram_access_hash (
user_id BIGINT NOT NULL,
entity_type TEXT NOT NULL,
entity_id BIGINT NOT NULL,
access_hash BIGINT NOT NULL,
PRIMARY KEY (user_id, entity_id)
PRIMARY KEY (user_id, entity_type, entity_id)
);
CREATE TABLE telegram_username (
username TEXT NOT NULL,
entity_id BIGINT NOT NULL,
username TEXT NOT NULL,
entity_type TEXT NOT NULL,
entity_id BIGINT NOT NULL,
PRIMARY KEY (username)
);
@@ -0,0 +1,36 @@
-- v2: Separate users and channels into separate namespaces
ALTER TABLE telegram_access_hash RENAME TO telegram_access_hash_old;
ALTER TABLE telegram_username RENAME TO telegram_username_old;
CREATE TABLE telegram_access_hash (
user_id BIGINT NOT NULL,
entity_type TEXT NOT NULL,
entity_id BIGINT NOT NULL,
access_hash BIGINT NOT NULL,
PRIMARY KEY (user_id, entity_type, entity_id)
);
CREATE TABLE telegram_username (
username TEXT NOT NULL,
entity_type TEXT NOT NULL,
entity_id BIGINT NOT NULL,
PRIMARY KEY (username)
);
INSERT INTO telegram_access_hash (user_id, entity_type, entity_id, access_hash)
SELECT user_id, 'user', entity_id, access_hash
FROM telegram_access_hash_old;
INSERT INTO telegram_access_hash (user_id, entity_type, entity_id, access_hash)
SELECT user_id, 'channel', entity_id, access_hash
FROM telegram_access_hash_old;
INSERT INTO telegram_username (username, entity_type, entity_id)
SELECT username, 'user', entity_id
FROM telegram_username_old;
DROP TABLE telegram_access_hash_old;
DROP table telegram_username_old;
+5 -5
View File
@@ -195,7 +195,7 @@ func (t *TelegramClient) getEventSender(msg interface {
}
case *tg.PeerChannel:
return bridgev2.EventSender{
Sender: ids.MakeUserID(from.ChannelID),
Sender: ids.MakeChannelUserID(from.ChannelID),
}
default:
fromID, _ := msg.GetFromID()
@@ -325,7 +325,7 @@ func (t *TelegramClient) updateGhost(ctx context.Context, userID int64, user *tg
}
func (t *TelegramClient) updateChannel(ctx context.Context, channel *tg.Channel) error {
if err := t.ScopedStore.SetAccessHash(ctx, channel.ID, channel.AccessHash); err != nil {
if err := t.ScopedStore.SetAccessHash(ctx, ids.PeerTypeChannel, channel.ID, channel.AccessHash); err != nil {
return err
}
@@ -334,7 +334,7 @@ func (t *TelegramClient) updateChannel(ctx context.Context, channel *tg.Channel)
}
// Update the channel ghost if this is a broadcast channel.
ghost, err := t.main.Bridge.GetGhostByID(ctx, ids.MakeUserID(channel.ID))
ghost, err := t.main.Bridge.GetGhostByID(ctx, ids.MakeChannelUserID(channel.ID))
if err != nil {
return err
}
@@ -496,7 +496,7 @@ func (t *TelegramClient) inputPeerForPortalID(ctx context.Context, portalID netw
}
switch peerType {
case ids.PeerTypeUser:
if accessHash, err := t.ScopedStore.GetAccessHash(ctx, id); err != nil {
if accessHash, err := t.ScopedStore.GetAccessHash(ctx, ids.PeerTypeUser, id); err != nil {
return nil, fmt.Errorf("failed to get user access hash for %d: %w", id, err)
} else {
return &tg.InputPeerUser{UserID: id, AccessHash: accessHash}, nil
@@ -504,7 +504,7 @@ func (t *TelegramClient) inputPeerForPortalID(ctx context.Context, portalID netw
case ids.PeerTypeChat:
return &tg.InputPeerChat{ChatID: id}, nil
case ids.PeerTypeChannel:
if accessHash, err := t.ScopedStore.GetAccessHash(ctx, id); err != nil {
if accessHash, err := t.ScopedStore.GetAccessHash(ctx, ids.PeerTypeChannel, id); err != nil {
return nil, err
} else {
return &tg.InputPeerChannel{ChannelID: id, AccessHash: accessHash}, nil