diff --git a/pkg/connector/client.go b/pkg/connector/client.go index deaff28c..07f6a738 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -21,11 +21,13 @@ import ( "go.mau.fi/mautrix-telegram/pkg/connector/ids" "go.mau.fi/mautrix-telegram/pkg/connector/media" "go.mau.fi/mautrix-telegram/pkg/connector/msgconv" + "go.mau.fi/mautrix-telegram/pkg/connector/store" "go.mau.fi/mautrix-telegram/pkg/connector/util" ) type TelegramClient struct { main *TelegramConnector + ScopedStore *store.ScopedStore telegramUserID int64 loginID networkid.UserLoginID userID networkid.UserID @@ -103,13 +105,28 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge UpdateDispatcher: tg.NewUpdateDispatcher(), EntityHandler: client.onEntityUpdate, } - dispatcher.OnNewMessage(client.onUpdateNewMessage) - dispatcher.OnNewChannelMessage(client.onUpdateNewChannelMessage) + dispatcher.OnNewMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateNewMessage) error { + return client.onUpdateNewMessage(ctx, update) + }) + dispatcher.OnNewChannelMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateNewChannelMessage) error { + fmt.Printf("%+v\n", update) + return client.onUpdateNewMessage(ctx, update) + }) dispatcher.OnUserName(client.onUserName) - dispatcher.OnDeleteMessages(client.onDeleteMessages) - dispatcher.OnEditMessage(client.onMessageEdit) + dispatcher.OnDeleteMessages(func(ctx context.Context, e tg.Entities, update *tg.UpdateDeleteMessages) error { + return client.onDeleteMessages(ctx, update) + }) + dispatcher.OnDeleteChannelMessages(func(ctx context.Context, e tg.Entities, update *tg.UpdateDeleteChannelMessages) error { + return client.onDeleteMessages(ctx, update) + }) + dispatcher.OnEditMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateEditMessage) error { + return client.onMessageEdit(ctx, update) + }) + dispatcher.OnEditChannelMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateEditChannelMessage) error { + return client.onMessageEdit(ctx, update) + }) - store := tc.Store.GetScopedStore(telegramUserID) + client.ScopedStore = tc.Store.GetScopedStore(telegramUserID) updatesManager := updates.New(updates.Config{ OnChannelTooLong: func(channelID int64) { @@ -118,12 +135,12 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge }, Handler: dispatcher, Logger: zaplog.Named("gaps"), - Storage: store, - AccessHasher: store, + Storage: client.ScopedStore, + AccessHasher: client.ScopedStore, }) client.client = telegram.NewClient(tc.Config.AppID, tc.Config.AppHash, telegram.Options{ - SessionStorage: store, + SessionStorage: client.ScopedStore, Logger: zaplog, UpdateHandler: updatesManager, }) @@ -184,7 +201,7 @@ func (t *TelegramClient) Disconnect() { } func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { - fmt.Printf("%+v\n", portal) + fmt.Printf("get chat info %+v\n", portal) peerType, id, err := ids.ParsePortalID(portal.ID) if err != nil { return nil, err @@ -253,6 +270,54 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta } } + for _, user := range fullChat.Users { + memberList.Members = append(memberList.Members, bridgev2.ChatMember{ + EventSender: bridgev2.EventSender{ + IsFromMe: user.GetID() == t.telegramUserID, + SenderLogin: ids.MakeUserLoginID(user.GetID()), + Sender: ids.MakeUserID(user.GetID()), + }, + }) + } + case ids.PeerTypeChannel: + accessHash, found, err := t.ScopedStore.GetChannelAccessHash(ctx, t.telegramUserID, id) + if err != nil { + return nil, fmt.Errorf("failed to get channel access hash: %w", err) + } else if !found { + return nil, fmt.Errorf("channel access hash not found for %d", id) + } + fullChat, err := t.client.API().ChannelsGetFullChannel(ctx, &tg.InputChannel{ChannelID: id, AccessHash: accessHash}) + if err != nil { + return nil, err + } + for _, c := range fullChat.Chats { + if c.GetID() == id { + switch chat := c.(type) { + case *tg.Chat: + name = chat.Title + case *tg.Channel: + name = chat.Title + } + break + } + } + + chatFull, ok := fullChat.FullChat.(*tg.ChatFull) + if !ok { + return nil, fmt.Errorf("full chat is not %T", chatFull) + } + + if photo, ok := chatFull.GetChatPhoto(); ok { + avatar = &bridgev2.Avatar{ + ID: ids.MakeAvatarID(photo.GetID()), + Get: func(ctx context.Context) (data []byte, err error) { + data, _, err = media.NewTransferer(t.client.API()).WithPhoto(photo).Download(ctx) + return + }, + } + } + + memberList.IsFull = false for _, user := range fullChat.Users { memberList.Members = append(memberList.Members, bridgev2.ChatMember{ EventSender: bridgev2.EventSender{ diff --git a/pkg/connector/directdownload.go b/pkg/connector/directdownload.go index f0c41065..ceaac3ee 100644 --- a/pkg/connector/directdownload.go +++ b/pkg/connector/directdownload.go @@ -49,18 +49,26 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med &tg.InputMessageID{ID: int(info.MessageID)}, }) case ids.PeerTypeChannel: - // TODO test this - messages, err = client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{ - Channel: &tg.InputChannel{ChannelID: info.ChatID}, - ID: []tg.InputMessageClass{ - &tg.InputMessageID{ID: int(info.MessageID)}, - }, - }) + var accessHash int64 + var found bool + accessHash, found, err = client.ScopedStore.GetChannelAccessHash(ctx, client.telegramUserID, info.ChatID) + if err != nil { + return nil, fmt.Errorf("failed to get channel access hash: %w", err) + } else if !found { + return nil, fmt.Errorf("channel access hash not found for %d", info.ChatID) + } else { + messages, err = client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{ + Channel: &tg.InputChannel{ChannelID: info.ChatID, AccessHash: accessHash}, + ID: []tg.InputMessageClass{ + &tg.InputMessageID{ID: int(info.MessageID)}, + }, + }) + } default: return nil, fmt.Errorf("unknown peer type %s", info.PeerType) } if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get messages for %+v: %w", info, err) } var msgMedia tg.MessageMediaClass diff --git a/pkg/connector/ids/ids.go b/pkg/connector/ids/ids.go index 3a20994a..3386fb44 100644 --- a/pkg/connector/ids/ids.go +++ b/pkg/connector/ids/ids.go @@ -92,27 +92,6 @@ func ParsePortalID(portalID networkid.PortalID) (pt PeerType, id int64, err erro return } -func InputPeerForPortalID(portalID networkid.PortalID) (tg.InputPeerClass, error) { - peerType, id, err := ParsePortalID(portalID) - if err != nil { - return nil, err - } - switch peerType { - case PeerTypeUser: - return &tg.InputPeerUser{UserID: id}, nil - case PeerTypeChat: - return &tg.InputPeerChat{ChatID: id}, nil - case PeerTypeChannel: - return &tg.InputPeerChannel{ChannelID: id}, nil - default: - panic("invalid peer type") - } -} - -func InputPeerForPortalKey(portalKey networkid.PortalKey) (tg.InputPeerClass, error) { - return InputPeerForPortalID(portalKey.ID) -} - func MakeAvatarID(photoID int64) networkid.AvatarID { return networkid.AvatarID(strconv.FormatInt(photoID, 10)) } diff --git a/pkg/connector/matrix.go b/pkg/connector/matrix.go index c84f4631..742bad6d 100644 --- a/pkg/connector/matrix.go +++ b/pkg/connector/matrix.go @@ -39,12 +39,11 @@ func getMediaFilenameAndCaption(content *event.MessageEventContent) (filename, c } func (t *TelegramClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (resp *bridgev2.MatrixMessageResponse, err error) { - sender := message.NewSender(t.client.API()) - peer, err := ids.InputPeerForPortalID(msg.Portal.ID) + peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID) if err != nil { return nil, err } - builder := sender.To(peer) + builder := message.NewSender(t.client.API()).To(peer) // TODO handle sticker @@ -173,8 +172,13 @@ func (t *TelegramClient) HandleMatrixMessageRemove(ctx context.Context, msg *bri return err } else if messageID, err := ids.ParseMessageID(dbMsg.ID); err != nil { return err + } else if peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID); err != nil { + return err } else { - _, err = message.NewSender(t.client.API()).Self().Revoke().Messages(ctx, messageID) + _, err := message.NewSender(t.client.API()). + To(peer). + Revoke(). + Messages(ctx, messageID) return err } } @@ -224,7 +228,7 @@ func (t *TelegramClient) appendEmojiID(reactionList []tg.ReactionClass, emojiID } func (t *TelegramClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.MatrixReaction) (reaction *database.Reaction, err error) { - peer, err := ids.InputPeerForPortalID(msg.Portal.ID) + peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID) if err != nil { return nil, err } @@ -255,7 +259,7 @@ func (t *TelegramClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2 } func (t *TelegramClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error { - peer, err := ids.InputPeerForPortalID(msg.Portal.ID) + peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID) if err != nil { return err } diff --git a/pkg/connector/store/container.go b/pkg/connector/store/container.go index 2bb6611d..b8892097 100644 --- a/pkg/connector/store/container.go +++ b/pkg/connector/store/container.go @@ -42,6 +42,6 @@ func (c *Container) Upgrade(ctx context.Context) error { return c.Database.Upgrade(ctx) } -func (c *Container) GetScopedStore(telegramUserID int64) *scopedStore { - return &scopedStore{c.Database, telegramUserID} +func (c *Container) GetScopedStore(telegramUserID int64) *ScopedStore { + return &ScopedStore{c.Database, telegramUserID} } diff --git a/pkg/connector/store/scoped_store.go b/pkg/connector/store/scoped_store.go index f13e7f96..baa67d44 100644 --- a/pkg/connector/store/scoped_store.go +++ b/pkg/connector/store/scoped_store.go @@ -11,9 +11,9 @@ import ( "go.mau.fi/util/dbutil" ) -// scopedStore is a wrapper around a database that implements +// ScopedStore is a wrapper around a database that implements // [session.Storage] scoped to a specific Telegram user ID. -type scopedStore struct { +type ScopedStore struct { db *dbutil.Database telegramUserID int64 } @@ -60,22 +60,22 @@ const ( ` ) -var _ session.Storage = (*scopedStore)(nil) +var _ session.Storage = (*ScopedStore)(nil) -func (s *scopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) { +func (s *ScopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) { row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID) err = row.Scan(&sessionData) return } -func (s *scopedStore) StoreSession(ctx context.Context, data []byte) error { +func (s *ScopedStore) StoreSession(ctx context.Context, data []byte) error { _, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data) return err } -var _ updates.StateStorage = (*scopedStore)(nil) +var _ updates.StateStorage = (*ScopedStore)(nil) -func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error { +func (s *ScopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error { s.assertUserIDMatches(userID) rows, err := s.db.Query(ctx, allChannelsQuery, userID) if err != nil { @@ -93,7 +93,7 @@ func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func( return nil } -func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) { +func (s *ScopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) { s.assertUserIDMatches(userID) err = s.db.QueryRow(ctx, getChannelPtsQuery, userID, channelID).Scan(&pts) if errors.Is(err, sql.ErrNoRows) { @@ -102,13 +102,13 @@ func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID return pts, err == nil, err } -func (s *scopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) { +func (s *ScopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) { s.assertUserIDMatches(userID) _, err = s.db.Exec(ctx, setChannelPtsQuery, userID, channelID, pts) return } -func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) { +func (s *ScopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) { s.assertUserIDMatches(userID) err = s.db.QueryRow(ctx, getStateQuery, userID).Scan(&state.Pts, &state.Qts, &state.Date, &state.Seq) if errors.Is(err, sql.ErrNoRows) { @@ -117,45 +117,45 @@ func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates return state, err == nil, err } -func (s *scopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) { +func (s *ScopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) { s.assertUserIDMatches(userID) _, err = s.db.Exec(ctx, setStateQuery, userID, state.Pts, state.Qts, state.Date, state.Seq) return } -func (s *scopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) { +func (s *ScopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) { s.assertUserIDMatches(userID) _, err = s.db.Exec(ctx, setPtsQuery, userID, pts) return } -func (s *scopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) { +func (s *ScopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) { s.assertUserIDMatches(userID) _, err = s.db.Exec(ctx, setQtsQuery, userID, qts) return } -func (s *scopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) { +func (s *ScopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) { s.assertUserIDMatches(userID) _, err = s.db.Exec(ctx, setSeqQuery, userID, seq) return } -func (s *scopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) { +func (s *ScopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) { s.assertUserIDMatches(userID) _, err = s.db.Exec(ctx, setDateQuery, userID, date) return } -func (s *scopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) { +func (s *ScopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) { s.assertUserIDMatches(userID) _, err = s.db.Exec(ctx, setDateSeqQuery, userID, date, seq) return } -var _ updates.ChannelAccessHasher = (*scopedStore)(nil) +var _ updates.ChannelAccessHasher = (*ScopedStore)(nil) -func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) { +func (s *ScopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) { s.assertUserIDMatches(userID) err = s.db.QueryRow(ctx, getChannelAccessHashQuery, userID, channelID).Scan(&accessHash) if errors.Is(err, sql.ErrNoRows) { @@ -164,7 +164,7 @@ func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, ch return accessHash, err == nil, err } -func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) { +func (s *ScopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) { s.assertUserIDMatches(userID) _, err = s.db.Exec(ctx, setChannelAccessHashQuery, userID, channelID, accessHash) return @@ -172,7 +172,7 @@ func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, ch // Helper Functions -func (s *scopedStore) assertUserIDMatches(userID int64) { +func (s *ScopedStore) assertUserIDMatches(userID int64) { if s.telegramUserID != userID { panic(fmt.Sprintf("scoped store for %d function called with user ID %d", s.telegramUserID, userID)) } diff --git a/pkg/connector/telegram.go b/pkg/connector/telegram.go index b6a0ca11..a9675257 100644 --- a/pkg/connector/telegram.go +++ b/pkg/connector/telegram.go @@ -20,7 +20,15 @@ import ( "go.mau.fi/mautrix-telegram/pkg/connector/util" ) -func (t *TelegramClient) onUpdateNewMessage(ctx context.Context, e tg.Entities, update *tg.UpdateNewMessage) error { +type IGetMessage interface { + GetMessage() tg.MessageClass +} + +type IGetMessages interface { + GetMessages() []int +} + +func (t *TelegramClient) onUpdateNewMessage(ctx context.Context, update IGetMessage) error { log := zerolog.Ctx(ctx) switch msg := update.GetMessage().(type) { case *tg.Message: @@ -40,7 +48,7 @@ func (t *TelegramClient) onUpdateNewMessage(ctx context.Context, e tg.Entities, Type: bridgev2.RemoteEventMessage, LogContext: func(c zerolog.Context) zerolog.Context { return c. - Int("message_id", update.Message.GetID()). + Int("message_id", msg.GetID()). Str("sender", string(sender.Sender)). Str("sender_login", string(sender.SenderLogin)). Bool("is_from_me", sender.IsFromMe) @@ -141,11 +149,6 @@ func (t *TelegramClient) getEventSender(msg messageWithSender) (sender bridgev2. return } -func (t *TelegramClient) onUpdateNewChannelMessage(ctx context.Context, e tg.Entities, update *tg.UpdateNewChannelMessage) error { - fmt.Printf("update new channel message %+v\n", update) - return nil -} - func (t *TelegramClient) onUserName(ctx context.Context, e tg.Entities, update *tg.UpdateUserName) error { ghost, err := t.main.Bridge.GetGhostByID(ctx, ids.MakeUserID(update.UserID)) if err != nil { @@ -159,8 +162,8 @@ func (t *TelegramClient) onUserName(ctx context.Context, e tg.Entities, update * return nil } -func (t *TelegramClient) onDeleteMessages(ctx context.Context, e tg.Entities, update *tg.UpdateDeleteMessages) error { - for _, messageID := range update.Messages { +func (t *TelegramClient) onDeleteMessages(ctx context.Context, update IGetMessages) error { + for _, messageID := range update.GetMessages() { parts, err := t.main.Bridge.DB.Message.GetAllPartsByID(ctx, t.loginID, ids.MakeMessageID(messageID)) if err != nil { return err @@ -198,9 +201,9 @@ func (t *TelegramClient) onEntityUpdate(ctx context.Context, e tg.Entities) erro return nil } -func (t *TelegramClient) onMessageEdit(ctx context.Context, e tg.Entities, update *tg.UpdateEditMessage) error { +func (t *TelegramClient) onMessageEdit(ctx context.Context, update IGetMessage) error { fmt.Printf("message edit %+v\n", update) - msg, ok := update.Message.(*tg.Message) + msg, ok := update.GetMessage().(*tg.Message) if !ok { return fmt.Errorf("edit message is not *tg.Message") } @@ -261,7 +264,7 @@ func (t *TelegramClient) handleTelegramReactions(ctx context.Context, msg *tg.Me // return // TODO should calls to this be limited? - } else if peer, err := ids.InputPeerForPortalKey(ids.MakePortalKey(msg.PeerID)); err != nil { + } else if peer, err := t.inputPeerForPortalID(ctx, ids.MakePortalKey(msg.PeerID).ID); err != nil { return err } else { reactions, err := t.client.API().MessagesGetMessageReactionsList(ctx, &tg.MessagesGetMessageReactionsListRequest{ @@ -300,6 +303,29 @@ func (t *TelegramClient) handleTelegramReactions(ctx context.Context, msg *tg.Me return t.handleTelegramParsedReactionsLocked(ctx, dbMsg, reactions, customEmojiIDs, isFull, nil, nil) } +func (t *TelegramClient) inputPeerForPortalID(ctx context.Context, portalID networkid.PortalID) (tg.InputPeerClass, error) { + peerType, id, err := ids.ParsePortalID(portalID) + if err != nil { + return nil, err + } + switch peerType { + case ids.PeerTypeUser: + return &tg.InputPeerUser{UserID: id}, nil + case ids.PeerTypeChat: + return &tg.InputPeerChat{ChatID: id}, nil + case ids.PeerTypeChannel: + accessHash, found, err := t.ScopedStore.GetChannelAccessHash(ctx, t.telegramUserID, id) + if err != nil { + return nil, err + } else if !found { + return nil, fmt.Errorf("channel access hash not found for %d", id) + } + return &tg.InputPeerChannel{ChannelID: id, AccessHash: accessHash}, nil + default: + panic("invalid peer type") + } +} + func splitDMReactionCounts(res []tg.ReactionCount, theirUserID, myUserID int64) (reactions []tg.MessagePeerReaction) { for _, item := range res { if item.Count == 2 || item.ChosenOrder > 0 {