move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
package cached
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/hasher"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func (s *ContactsGetContacts) computeHash(v *tg.ContactsContacts) int64 {
|
||||
cts := v.Contacts
|
||||
|
||||
sort.SliceStable(cts, func(i, j int) bool {
|
||||
return cts[i].UserID < cts[j].UserID
|
||||
})
|
||||
h := hasher.Hasher{}
|
||||
for _, contact := range cts {
|
||||
h.Update(uint32(contact.UserID))
|
||||
}
|
||||
|
||||
return h.Sum()
|
||||
}
|
||||
|
||||
func (s *MessagesGetQuickReplies) computeHash(v *tg.MessagesQuickReplies) int64 {
|
||||
r := v.QuickReplies
|
||||
|
||||
sort.SliceStable(r, func(i, j int) bool {
|
||||
return r[i].ShortcutID < r[j].ShortcutID
|
||||
})
|
||||
h := hasher.Hasher{}
|
||||
for _, contact := range r {
|
||||
h.Update(uint32(contact.ShortcutID))
|
||||
}
|
||||
|
||||
return h.Sum()
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package cached contains cached query helpers.
|
||||
package cached
|
||||
@@ -0,0 +1,3 @@
|
||||
package cached
|
||||
|
||||
//go:generate go run go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/cachedgen -package=cached -out=queries.gen.go
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,167 @@
|
||||
// Package participants contains channel participants iteration helper.
|
||||
package participants
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message/peer"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Elem is a channel participants iterator element.
|
||||
type Elem struct {
|
||||
Participant tg.ChannelParticipantClass
|
||||
Entities peer.Entities
|
||||
}
|
||||
|
||||
// Iterator is a channel participants stream iterator.
|
||||
type Iterator struct {
|
||||
// Current state.
|
||||
lastErr error
|
||||
// Buffer state.
|
||||
buf []Elem
|
||||
bufCur int
|
||||
// Request state.
|
||||
limit int
|
||||
lastBatch bool
|
||||
// Offset parameters state.
|
||||
offset int
|
||||
// Remote state.
|
||||
count int
|
||||
totalGot bool
|
||||
|
||||
// Query builder.
|
||||
query Query
|
||||
}
|
||||
|
||||
// NewIterator creates new iterator.
|
||||
func NewIterator(query Query, limit int) *Iterator {
|
||||
return &Iterator{
|
||||
buf: make([]Elem, 0, limit),
|
||||
bufCur: -1,
|
||||
limit: limit,
|
||||
query: query,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset sets Offset request parameter.
|
||||
func (m *Iterator) Offset(offset int) *Iterator {
|
||||
m.offset = offset
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Iterator) apply(r tg.ChannelsChannelParticipantsClass) error {
|
||||
if m.lastBatch {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
participants tg.ChannelParticipantClassArray
|
||||
entities peer.Entities
|
||||
)
|
||||
switch prts := r.(type) {
|
||||
case *tg.ChannelsChannelParticipants: // channels.channelParticipants#f56ee2a8
|
||||
participants = prts.Participants
|
||||
entities = peer.NewEntities(prts.MapUsers().UserToMap(), map[int64]*tg.Chat{}, map[int64]*tg.Channel{})
|
||||
|
||||
m.count = prts.Count
|
||||
m.lastBatch = len(participants) < 1
|
||||
default: // channels.channelParticipantsNotModified#f0173fe9
|
||||
return errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
m.totalGot = true
|
||||
m.offset += len(participants)
|
||||
|
||||
m.bufCur = -1
|
||||
m.buf = m.buf[:0]
|
||||
for i := range participants {
|
||||
m.buf = append(m.buf, Elem{Participant: participants[i], Entities: entities})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Iterator) requestNext(ctx context.Context) error {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
Offset: m.offset,
|
||||
Limit: m.limit,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.apply(r)
|
||||
}
|
||||
|
||||
func (m *Iterator) bufNext() bool {
|
||||
if len(m.buf)-1 <= m.bufCur {
|
||||
return false
|
||||
}
|
||||
|
||||
m.bufCur++
|
||||
return true
|
||||
}
|
||||
|
||||
// Total returns last fetched count of elements.
|
||||
// If count was not fetched before, it requests server using FetchTotal.
|
||||
func (m *Iterator) Total(ctx context.Context) (int, error) {
|
||||
if m.totalGot {
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
return m.FetchTotal(ctx)
|
||||
}
|
||||
|
||||
// FetchTotal fetches and returns count of elements.
|
||||
func (m *Iterator) FetchTotal(ctx context.Context) (int, error) {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
Limit: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "fetch total")
|
||||
}
|
||||
|
||||
switch prts := r.(type) {
|
||||
case *tg.ChannelsChannelParticipants: // channels.channelParticipants#f56ee2a8
|
||||
m.count = prts.Count
|
||||
default: // channels.channelParticipantsNotModified#f0173fe9
|
||||
return 0, errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
|
||||
m.totalGot = true
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
// Next prepares the next message for reading with the Value method.
|
||||
//
|
||||
// Returns true on success, or false if there is no next message or an error happened while preparing it.
|
||||
// Err should be consulted to distinguish between the two cases.
|
||||
func (m *Iterator) Next(ctx context.Context) bool {
|
||||
if m.lastErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.bufNext() {
|
||||
// If buffer is empty, we should fetch next batch.
|
||||
if err := m.requestNext(ctx); err != nil {
|
||||
m.lastErr = err
|
||||
return false
|
||||
}
|
||||
// Try again with new buffer.
|
||||
return m.bufNext()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Value returns current message.
|
||||
func (m *Iterator) Value() Elem {
|
||||
return m.buf[m.bufCur]
|
||||
}
|
||||
|
||||
// Err returns the error, if any, that was encountered during iteration.
|
||||
func (m *Iterator) Err() error {
|
||||
return m.lastErr
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package participants
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func generateParticipants(count int) []tg.ChannelParticipantClass {
|
||||
r := make([]tg.ChannelParticipantClass, 0, count)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
r = append(r, &tg.ChannelParticipant{
|
||||
UserID: int64(i),
|
||||
Date: i,
|
||||
})
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func result(r []tg.ChannelParticipantClass, count int) tg.ChannelsChannelParticipantsClass {
|
||||
return &tg.ChannelsChannelParticipants{
|
||||
Participants: r,
|
||||
Count: count,
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mock := tgmock.NewRequire(t)
|
||||
limit := 10
|
||||
totalRecords := 3 * limit
|
||||
expected := generateParticipants(totalRecords)
|
||||
raw := tg.NewClient(mock)
|
||||
ch := &tg.InputChannel{
|
||||
ChannelID: 10,
|
||||
AccessHash: 10,
|
||||
}
|
||||
|
||||
mock.ExpectCall(&tg.ChannelsGetParticipantsRequest{
|
||||
Channel: ch,
|
||||
Filter: &tg.ChannelParticipantsRecent{},
|
||||
Offset: 0,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[0:limit], totalRecords))
|
||||
mock.ExpectCall(&tg.ChannelsGetParticipantsRequest{
|
||||
Channel: ch,
|
||||
Filter: &tg.ChannelParticipantsRecent{},
|
||||
Offset: limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[limit:2*limit], totalRecords))
|
||||
mock.ExpectCall(&tg.ChannelsGetParticipantsRequest{
|
||||
Channel: ch,
|
||||
Filter: &tg.ChannelParticipantsRecent{},
|
||||
Offset: 2 * limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[2*limit:3*limit], totalRecords))
|
||||
mock.ExpectCall(&tg.ChannelsGetParticipantsRequest{
|
||||
Channel: ch,
|
||||
Filter: &tg.ChannelParticipantsRecent{},
|
||||
Offset: 3 * limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[3*limit:], totalRecords))
|
||||
|
||||
iter := NewQueryBuilder(raw).GetParticipants(ch).BatchSize(10).Iter()
|
||||
i := 0
|
||||
for iter.Next(ctx) {
|
||||
require.Equal(t, expected[i], iter.Value().Participant)
|
||||
i++
|
||||
}
|
||||
require.NoError(t, iter.Err())
|
||||
require.Equal(t, totalRecords, i)
|
||||
|
||||
total, err := iter.Total(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalRecords, total)
|
||||
|
||||
mock.ExpectCall(&tg.ChannelsGetParticipantsRequest{
|
||||
Channel: ch,
|
||||
Filter: &tg.ChannelParticipantsRecent{},
|
||||
Offset: 0,
|
||||
Limit: 1,
|
||||
}).ThenResult(result(expected[:0], totalRecords))
|
||||
total, err = iter.FetchTotal(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalRecords, total)
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
// Code generated by itergen, DO NOT EDIT.
|
||||
|
||||
package participants
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// No-op definition for keeping imports.
|
||||
var _ = context.Background()
|
||||
|
||||
// Request is a parameter for Query.
|
||||
type Request struct {
|
||||
Offset int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// Query is an abstraction for participants request.
|
||||
// NB: iterator mutates returned data (sorts, at least).
|
||||
type Query interface {
|
||||
Query(ctx context.Context, req Request) (tg.ChannelsChannelParticipantsClass, error)
|
||||
}
|
||||
|
||||
// QueryFunc is a function adapter for Query.
|
||||
type QueryFunc func(ctx context.Context, req Request) (tg.ChannelsChannelParticipantsClass, error)
|
||||
|
||||
// Query implements Query interface.
|
||||
func (q QueryFunc) Query(ctx context.Context, req Request) (tg.ChannelsChannelParticipantsClass, error) {
|
||||
return q(ctx, req)
|
||||
}
|
||||
|
||||
// QueryBuilder is a helper to create message queries.
|
||||
type QueryBuilder struct {
|
||||
raw *tg.Client
|
||||
}
|
||||
|
||||
// NewQueryBuilder creates new QueryBuilder.
|
||||
func NewQueryBuilder(raw *tg.Client) *QueryBuilder {
|
||||
return &QueryBuilder{raw: raw}
|
||||
}
|
||||
|
||||
// GetParticipantsQueryBuilder is query builder of ChannelsGetParticipants.
|
||||
type GetParticipantsQueryBuilder struct {
|
||||
raw *tg.Client
|
||||
req tg.ChannelsGetParticipantsRequest
|
||||
batchSize int
|
||||
offset int
|
||||
}
|
||||
|
||||
// GetParticipants creates query builder of ChannelsGetParticipants.
|
||||
func (q *QueryBuilder) GetParticipants(paramChannel tg.InputChannelClass) *GetParticipantsQueryBuilder {
|
||||
b := &GetParticipantsQueryBuilder{
|
||||
raw: q.raw,
|
||||
batchSize: 1,
|
||||
req: tg.ChannelsGetParticipantsRequest{
|
||||
Filter: &tg.ChannelParticipantsRecent{},
|
||||
},
|
||||
}
|
||||
|
||||
b.req.Channel = paramChannel
|
||||
return b
|
||||
}
|
||||
|
||||
// BatchSize sets buffer of message loaded from one request.
|
||||
// Be carefully, when set this limit, because Telegram does not return error if limit is too big,
|
||||
// so results can be incorrect.
|
||||
func (b *GetParticipantsQueryBuilder) BatchSize(batchSize int) *GetParticipantsQueryBuilder {
|
||||
b.batchSize = batchSize
|
||||
return b
|
||||
}
|
||||
|
||||
// Channel sets Channel field of GetParticipants query.
|
||||
func (b *GetParticipantsQueryBuilder) Channel(paramChannel tg.InputChannelClass) *GetParticipantsQueryBuilder {
|
||||
b.req.Channel = paramChannel
|
||||
return b
|
||||
}
|
||||
|
||||
// Filter sets Filter field of GetParticipants query.
|
||||
func (b *GetParticipantsQueryBuilder) Filter(paramFilter tg.ChannelParticipantsFilterClass) *GetParticipantsQueryBuilder {
|
||||
b.req.Filter = paramFilter
|
||||
return b
|
||||
}
|
||||
|
||||
// Admins sets Filter field of GetParticipants query.
|
||||
func (b *GetParticipantsQueryBuilder) Admins() *GetParticipantsQueryBuilder {
|
||||
b.req.Filter = &tg.ChannelParticipantsAdmins{}
|
||||
return b
|
||||
}
|
||||
|
||||
// Banned sets Filter field of GetParticipants query.
|
||||
func (b *GetParticipantsQueryBuilder) Banned(paramQ string) *GetParticipantsQueryBuilder {
|
||||
b.req.Filter = &tg.ChannelParticipantsBanned{
|
||||
Q: paramQ,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Bots sets Filter field of GetParticipants query.
|
||||
func (b *GetParticipantsQueryBuilder) Bots() *GetParticipantsQueryBuilder {
|
||||
b.req.Filter = &tg.ChannelParticipantsBots{}
|
||||
return b
|
||||
}
|
||||
|
||||
// Contacts sets Filter field of GetParticipants query.
|
||||
func (b *GetParticipantsQueryBuilder) Contacts(paramQ string) *GetParticipantsQueryBuilder {
|
||||
b.req.Filter = &tg.ChannelParticipantsContacts{
|
||||
Q: paramQ,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Kicked sets Filter field of GetParticipants query.
|
||||
func (b *GetParticipantsQueryBuilder) Kicked(paramQ string) *GetParticipantsQueryBuilder {
|
||||
b.req.Filter = &tg.ChannelParticipantsKicked{
|
||||
Q: paramQ,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Mentions sets Filter field of GetParticipants query.
|
||||
func (b *GetParticipantsQueryBuilder) Mentions(paramQ string, paramTopMsgID int) *GetParticipantsQueryBuilder {
|
||||
b.req.Filter = &tg.ChannelParticipantsMentions{
|
||||
Q: paramQ,
|
||||
TopMsgID: paramTopMsgID,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Recent sets Filter field of GetParticipants query.
|
||||
func (b *GetParticipantsQueryBuilder) Recent() *GetParticipantsQueryBuilder {
|
||||
b.req.Filter = &tg.ChannelParticipantsRecent{}
|
||||
return b
|
||||
}
|
||||
|
||||
// Search sets Filter field of GetParticipants query.
|
||||
func (b *GetParticipantsQueryBuilder) Search(paramQ string) *GetParticipantsQueryBuilder {
|
||||
b.req.Filter = &tg.ChannelParticipantsSearch{
|
||||
Q: paramQ,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Query implements Query interface.
|
||||
func (b *GetParticipantsQueryBuilder) Query(ctx context.Context, req Request) (tg.ChannelsChannelParticipantsClass, error) {
|
||||
r := &tg.ChannelsGetParticipantsRequest{
|
||||
Limit: req.Limit,
|
||||
}
|
||||
|
||||
r.Channel = b.req.Channel
|
||||
r.Filter = b.req.Filter
|
||||
r.Offset = req.Offset
|
||||
return b.raw.ChannelsGetParticipants(ctx, r)
|
||||
}
|
||||
|
||||
// Iter returns iterator using built query.
|
||||
func (b *GetParticipantsQueryBuilder) Iter() *Iterator {
|
||||
iter := NewIterator(b, b.batchSize)
|
||||
iter = iter.Offset(b.offset)
|
||||
return iter
|
||||
}
|
||||
|
||||
// ForEach calls given callback on each iterator element.
|
||||
func (b *GetParticipantsQueryBuilder) ForEach(ctx context.Context, cb func(context.Context, Elem) error) error {
|
||||
iter := b.Iter()
|
||||
for iter.Next(ctx) {
|
||||
if err := cb(ctx, iter.Value()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return iter.Err()
|
||||
}
|
||||
|
||||
// Count fetches remote state to get number of elements.
|
||||
func (b *GetParticipantsQueryBuilder) Count(ctx context.Context) (int, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "get total")
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Collect creates iterator and collects all elements to slice.
|
||||
func (b *GetParticipantsQueryBuilder) Collect(ctx context.Context) ([]Elem, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get total")
|
||||
}
|
||||
|
||||
r := make([]Elem, 0, c)
|
||||
for iter.Next(ctx) {
|
||||
r = append(r, iter.Value())
|
||||
}
|
||||
|
||||
return r, iter.Err()
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package participants
|
||||
|
||||
//go:generate go run go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/itergen -result=ChannelsChannelParticipantsClass -package=participants -prefix=Channels -out=queries.gen.go
|
||||
@@ -0,0 +1,62 @@
|
||||
package participants
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/photos"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// UserPhotos returns new user photo query builder for participant.
|
||||
func (e Elem) UserPhotos(raw *tg.Client) (*photos.GetUserPhotosQueryBuilder, bool) {
|
||||
user, ok := e.User()
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return photos.NewQueryBuilder(raw).GetUserPhotos(user.AsInput()), true
|
||||
}
|
||||
|
||||
// User tries to get participant user object.
|
||||
func (e Elem) User() (*tg.User, bool) {
|
||||
switch part := e.Participant.(type) {
|
||||
case interface{ GetUserID() int64 }:
|
||||
return e.Entities.User(part.GetUserID())
|
||||
case interface{ GetPeer() tg.PeerClass }:
|
||||
user, ok := part.GetPeer().(*tg.PeerUser)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return e.Entities.User(user.GetUserID())
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// Creator returns participant user object and meta info if participant is a creator of channel.
|
||||
func (e Elem) Creator() (*tg.User, *tg.ChannelParticipantCreator, bool) {
|
||||
part, ok := e.Participant.(*tg.ChannelParticipantCreator)
|
||||
if !ok {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
user, ok := e.User()
|
||||
if !ok {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
return user, part, true
|
||||
}
|
||||
|
||||
// Admin returns participant user object and meta info if participant is admin of channel.
|
||||
func (e Elem) Admin() (*tg.User, *tg.ChannelParticipantAdmin, bool) {
|
||||
part, ok := e.Participant.(*tg.ChannelParticipantAdmin)
|
||||
if !ok {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
user, ok := e.User()
|
||||
if !ok {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
return user, part, true
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package participants
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message/peer"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func TestElem(t *testing.T) {
|
||||
entities := peer.NewEntities(
|
||||
map[int64]*tg.User{10: {}},
|
||||
map[int64]*tg.Chat{},
|
||||
map[int64]*tg.Channel{},
|
||||
)
|
||||
|
||||
type results struct {
|
||||
admin, creator, photos bool
|
||||
}
|
||||
tests := []struct {
|
||||
Name string
|
||||
Part tg.ChannelParticipantClass
|
||||
results
|
||||
}{
|
||||
{"UnknownPlain", &tg.ChannelParticipant{UserID: 45}, results{}},
|
||||
{"UnknownBanned", &tg.ChannelParticipantBanned{Peer: &tg.PeerUser{UserID: 45}},
|
||||
results{}},
|
||||
{"UnknownAdmin", &tg.ChannelParticipantAdmin{UserID: 45}, results{}},
|
||||
{"UnknownCreator", &tg.ChannelParticipantCreator{UserID: 45}, results{}},
|
||||
{"Plain", &tg.ChannelParticipant{UserID: 10}, results{photos: true}},
|
||||
{"Banned", &tg.ChannelParticipantBanned{Peer: &tg.PeerUser{UserID: 10}},
|
||||
results{photos: true}},
|
||||
{"Admin", &tg.ChannelParticipantAdmin{UserID: 10}, results{
|
||||
admin: true,
|
||||
photos: true,
|
||||
}},
|
||||
{"Creator", &tg.ChannelParticipantCreator{UserID: 10}, results{
|
||||
creator: true,
|
||||
photos: true,
|
||||
}},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
var ok bool
|
||||
|
||||
elem := Elem{Participant: test.Part, Entities: entities}
|
||||
_, ok = elem.UserPhotos(nil)
|
||||
a.Equal(test.photos, ok)
|
||||
_, _, ok = elem.Admin()
|
||||
a.Equal(test.admin, ok)
|
||||
_, _, ok = elem.Creator()
|
||||
a.Equal(test.creator, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
// Package blocked contains blocked contacts iteration helper.
|
||||
package blocked
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message/peer"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Elem is a contact iterator element.
|
||||
type Elem struct {
|
||||
Contact tg.PeerBlocked
|
||||
Entities peer.Entities
|
||||
}
|
||||
|
||||
// Iterator is a blocked contacts stream iterator.
|
||||
type Iterator struct {
|
||||
// Current state.
|
||||
lastErr error
|
||||
// Buffer state.
|
||||
buf []Elem
|
||||
bufCur int
|
||||
// Request state.
|
||||
limit int
|
||||
lastBatch bool
|
||||
// Offset parameters state.
|
||||
offset int
|
||||
// Remote state.
|
||||
count int
|
||||
totalGot bool
|
||||
|
||||
// Query builder.
|
||||
query Query
|
||||
}
|
||||
|
||||
// NewIterator creates new iterator.
|
||||
func NewIterator(query Query, limit int) *Iterator {
|
||||
return &Iterator{
|
||||
buf: make([]Elem, 0, limit),
|
||||
bufCur: -1,
|
||||
limit: limit,
|
||||
query: query,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset sets Offset request parameter.
|
||||
func (m *Iterator) Offset(offset int) *Iterator {
|
||||
m.offset = offset
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Iterator) apply(r tg.ContactsBlockedClass) error {
|
||||
if m.lastBatch {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
blocked []tg.PeerBlocked
|
||||
entities peer.Entities
|
||||
)
|
||||
switch ctcs := r.(type) {
|
||||
case *tg.ContactsBlocked: // contacts.blocked#ade1591
|
||||
blocked = ctcs.Blocked
|
||||
entities = peer.EntitiesFromResult(ctcs)
|
||||
|
||||
m.count = len(ctcs.Blocked)
|
||||
m.lastBatch = true
|
||||
case *tg.ContactsBlockedSlice: // contacts.blockedSlice#e1664194
|
||||
blocked = ctcs.Blocked
|
||||
entities = peer.EntitiesFromResult(ctcs)
|
||||
|
||||
m.count = ctcs.Count
|
||||
m.lastBatch = len(ctcs.Blocked) < m.limit
|
||||
default:
|
||||
return errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
m.totalGot = true
|
||||
m.offset += len(blocked)
|
||||
|
||||
m.bufCur = -1
|
||||
m.buf = m.buf[:0]
|
||||
for i := range blocked {
|
||||
m.buf = append(m.buf, Elem{Contact: blocked[i], Entities: entities})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Iterator) requestNext(ctx context.Context) error {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
Offset: m.offset,
|
||||
Limit: m.limit,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.apply(r)
|
||||
}
|
||||
|
||||
func (m *Iterator) bufNext() bool {
|
||||
if len(m.buf)-1 <= m.bufCur {
|
||||
return false
|
||||
}
|
||||
|
||||
m.bufCur++
|
||||
return true
|
||||
}
|
||||
|
||||
// Total returns last fetched count of elements.
|
||||
// If count was not fetched before, it requests server using FetchTotal.
|
||||
func (m *Iterator) Total(ctx context.Context) (int, error) {
|
||||
if m.totalGot {
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
return m.FetchTotal(ctx)
|
||||
}
|
||||
|
||||
// FetchTotal fetches and returns count of elements.
|
||||
func (m *Iterator) FetchTotal(ctx context.Context) (int, error) {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
Limit: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "fetch total")
|
||||
}
|
||||
|
||||
switch ctcs := r.(type) {
|
||||
case *tg.ContactsBlocked: // contacts.blocked#ade1591
|
||||
m.count = len(ctcs.Blocked)
|
||||
case *tg.ContactsBlockedSlice: // contacts.blockedSlice#e1664194
|
||||
m.count = ctcs.Count
|
||||
default:
|
||||
return 0, errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
|
||||
m.totalGot = true
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
// Next prepares the next message for reading with the Value method.
|
||||
// It returns true on success, or false if there is no next message or an error happened while preparing it.
|
||||
// Err should be consulted to distinguish between the two cases.
|
||||
func (m *Iterator) Next(ctx context.Context) bool {
|
||||
if m.lastErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.bufNext() {
|
||||
// If buffer is empty, we should fetch next batch.
|
||||
if err := m.requestNext(ctx); err != nil {
|
||||
m.lastErr = err
|
||||
return false
|
||||
}
|
||||
// Try again with new buffer.
|
||||
return m.bufNext()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Value returns current message.
|
||||
func (m *Iterator) Value() Elem {
|
||||
return m.buf[m.bufCur]
|
||||
}
|
||||
|
||||
// Err returns the error, if any, that was encountered during iteration.
|
||||
func (m *Iterator) Err() error {
|
||||
return m.lastErr
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package blocked
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func generateBlocked(count int) []tg.PeerBlocked {
|
||||
r := make([]tg.PeerBlocked, 0, count)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
r = append(r, tg.PeerBlocked{
|
||||
PeerID: &tg.PeerUser{
|
||||
UserID: int64(i + 1),
|
||||
},
|
||||
Date: i,
|
||||
})
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func result(r []tg.PeerBlocked, count int) tg.ContactsBlockedClass {
|
||||
return &tg.ContactsBlockedSlice{
|
||||
Blocked: r,
|
||||
Count: count,
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mock := tgmock.NewRequire(t)
|
||||
limit := 10
|
||||
totalRecords := 3 * limit
|
||||
expected := generateBlocked(totalRecords)
|
||||
raw := tg.NewClient(mock)
|
||||
|
||||
mock.ExpectCall(&tg.ContactsGetBlockedRequest{
|
||||
Offset: 0,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[0:limit], totalRecords))
|
||||
mock.ExpectCall(&tg.ContactsGetBlockedRequest{
|
||||
Offset: limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[limit:2*limit], totalRecords))
|
||||
mock.ExpectCall(&tg.ContactsGetBlockedRequest{
|
||||
Offset: 2 * limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[2*limit:3*limit], totalRecords))
|
||||
mock.ExpectCall(&tg.ContactsGetBlockedRequest{
|
||||
Offset: 3 * limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[3*limit:], totalRecords))
|
||||
|
||||
iter := NewQueryBuilder(raw).GetBlocked().BatchSize(10).Iter()
|
||||
i := 0
|
||||
for iter.Next(ctx) {
|
||||
require.Equal(t, expected[i], iter.Value().Contact)
|
||||
i++
|
||||
}
|
||||
require.NoError(t, iter.Err())
|
||||
require.Equal(t, totalRecords, i)
|
||||
|
||||
total, err := iter.Total(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalRecords, total)
|
||||
|
||||
mock.ExpectCall(&tg.ContactsGetBlockedRequest{
|
||||
Offset: 0,
|
||||
Limit: 1,
|
||||
}).ThenResult(result(expected[:0], totalRecords))
|
||||
total, err = iter.FetchTotal(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalRecords, total)
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
// Code generated by itergen, DO NOT EDIT.
|
||||
|
||||
package blocked
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// No-op definition for keeping imports.
|
||||
var _ = context.Background()
|
||||
|
||||
// Request is a parameter for Query.
|
||||
type Request struct {
|
||||
Offset int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// Query is an abstraction for blocked request.
|
||||
// NB: iterator mutates returned data (sorts, at least).
|
||||
type Query interface {
|
||||
Query(ctx context.Context, req Request) (tg.ContactsBlockedClass, error)
|
||||
}
|
||||
|
||||
// QueryFunc is a function adapter for Query.
|
||||
type QueryFunc func(ctx context.Context, req Request) (tg.ContactsBlockedClass, error)
|
||||
|
||||
// Query implements Query interface.
|
||||
func (q QueryFunc) Query(ctx context.Context, req Request) (tg.ContactsBlockedClass, error) {
|
||||
return q(ctx, req)
|
||||
}
|
||||
|
||||
// QueryBuilder is a helper to create message queries.
|
||||
type QueryBuilder struct {
|
||||
raw *tg.Client
|
||||
}
|
||||
|
||||
// NewQueryBuilder creates new QueryBuilder.
|
||||
func NewQueryBuilder(raw *tg.Client) *QueryBuilder {
|
||||
return &QueryBuilder{raw: raw}
|
||||
}
|
||||
|
||||
// GetBlockedQueryBuilder is query builder of ContactsGetBlocked.
|
||||
type GetBlockedQueryBuilder struct {
|
||||
raw *tg.Client
|
||||
req tg.ContactsGetBlockedRequest
|
||||
batchSize int
|
||||
offset int
|
||||
}
|
||||
|
||||
// GetBlocked creates query builder of ContactsGetBlocked.
|
||||
func (q *QueryBuilder) GetBlocked() *GetBlockedQueryBuilder {
|
||||
b := &GetBlockedQueryBuilder{
|
||||
raw: q.raw,
|
||||
batchSize: 1,
|
||||
req: tg.ContactsGetBlockedRequest{},
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// BatchSize sets buffer of message loaded from one request.
|
||||
// Be carefully, when set this limit, because Telegram does not return error if limit is too big,
|
||||
// so results can be incorrect.
|
||||
func (b *GetBlockedQueryBuilder) BatchSize(batchSize int) *GetBlockedQueryBuilder {
|
||||
b.batchSize = batchSize
|
||||
return b
|
||||
}
|
||||
|
||||
// MyStoriesFrom sets MyStoriesFrom field of GetBlocked query.
|
||||
func (b *GetBlockedQueryBuilder) MyStoriesFrom(paramMyStoriesFrom bool) *GetBlockedQueryBuilder {
|
||||
b.req.MyStoriesFrom = paramMyStoriesFrom
|
||||
return b
|
||||
}
|
||||
|
||||
// Query implements Query interface.
|
||||
func (b *GetBlockedQueryBuilder) Query(ctx context.Context, req Request) (tg.ContactsBlockedClass, error) {
|
||||
r := &tg.ContactsGetBlockedRequest{
|
||||
Limit: req.Limit,
|
||||
}
|
||||
|
||||
r.MyStoriesFrom = b.req.MyStoriesFrom
|
||||
r.Offset = req.Offset
|
||||
return b.raw.ContactsGetBlocked(ctx, r)
|
||||
}
|
||||
|
||||
// Iter returns iterator using built query.
|
||||
func (b *GetBlockedQueryBuilder) Iter() *Iterator {
|
||||
iter := NewIterator(b, b.batchSize)
|
||||
iter = iter.Offset(b.offset)
|
||||
return iter
|
||||
}
|
||||
|
||||
// ForEach calls given callback on each iterator element.
|
||||
func (b *GetBlockedQueryBuilder) ForEach(ctx context.Context, cb func(context.Context, Elem) error) error {
|
||||
iter := b.Iter()
|
||||
for iter.Next(ctx) {
|
||||
if err := cb(ctx, iter.Value()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return iter.Err()
|
||||
}
|
||||
|
||||
// Count fetches remote state to get number of elements.
|
||||
func (b *GetBlockedQueryBuilder) Count(ctx context.Context) (int, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "get total")
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Collect creates iterator and collects all elements to slice.
|
||||
func (b *GetBlockedQueryBuilder) Collect(ctx context.Context) ([]Elem, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get total")
|
||||
}
|
||||
|
||||
r := make([]Elem, 0, c)
|
||||
for iter.Next(ctx) {
|
||||
r = append(r, iter.Value())
|
||||
}
|
||||
|
||||
return r, iter.Err()
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package blocked
|
||||
|
||||
//go:generate go run go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/itergen -result=ContactsBlockedClass -package=blocked -prefix=Contacts -out=queries.gen.go
|
||||
@@ -0,0 +1,254 @@
|
||||
// Package dialogs contains dialog iteration helper.
|
||||
package dialogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message/peer"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Elem is a dialog iterator element.
|
||||
type Elem struct {
|
||||
Dialog tg.DialogClass
|
||||
Peer tg.InputPeerClass
|
||||
Last tg.NotEmptyMessage
|
||||
Entities peer.Entities
|
||||
}
|
||||
|
||||
// Iterator is a dialog stream iterator.
|
||||
type Iterator struct {
|
||||
// Current state.
|
||||
lastErr error
|
||||
// Buffer state.
|
||||
buf []Elem
|
||||
bufCur int
|
||||
// Request state.
|
||||
limit int
|
||||
lastBatch bool
|
||||
// Offset parameters state.
|
||||
offsetID int
|
||||
offsetDate int
|
||||
offsetPeer tg.InputPeerClass
|
||||
// Remote state.
|
||||
count int
|
||||
totalGot bool
|
||||
|
||||
// Query builder.
|
||||
query Query
|
||||
}
|
||||
|
||||
// NewIterator creates new iterator.
|
||||
func NewIterator(query Query, limit int) *Iterator {
|
||||
return &Iterator{
|
||||
buf: make([]Elem, 0, limit),
|
||||
bufCur: -1,
|
||||
limit: limit,
|
||||
query: query,
|
||||
offsetPeer: &tg.InputPeerEmpty{},
|
||||
}
|
||||
}
|
||||
|
||||
// OffsetID sets OffsetID request parameter.
|
||||
func (m *Iterator) OffsetID(offsetID int) *Iterator {
|
||||
m.offsetID = offsetID
|
||||
return m
|
||||
}
|
||||
|
||||
// OffsetDate sets OffsetDate request parameter.
|
||||
func (m *Iterator) OffsetDate(offsetDate int) *Iterator {
|
||||
m.offsetDate = offsetDate
|
||||
return m
|
||||
}
|
||||
|
||||
// OffsetPeer sets OffsetPeer request parameter.
|
||||
func (m *Iterator) OffsetPeer(offsetPeer tg.InputPeerClass) *Iterator {
|
||||
m.offsetPeer = offsetPeer
|
||||
return m
|
||||
}
|
||||
|
||||
// messageMap is a helper to store messages for multiple peers.
|
||||
type messageMap map[DialogKey]tg.NotEmptyMessage
|
||||
|
||||
func (m messageMap) collect(messages tg.MessageClassArray) error {
|
||||
for _, msg := range messages {
|
||||
nonEmpty, ok := msg.AsNotEmpty()
|
||||
if !ok {
|
||||
// TODO(tdakkota): Maybe should I return error here?
|
||||
continue
|
||||
}
|
||||
|
||||
var key DialogKey
|
||||
if err := key.FromPeer(nonEmpty.GetPeerID()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m[key] = nonEmpty
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Iterator) apply(r tg.MessagesDialogsClass) error {
|
||||
if m.lastBatch {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
messages tg.MessageClassArray
|
||||
dialogs tg.DialogClassArray
|
||||
entities peer.Entities
|
||||
)
|
||||
|
||||
switch dlgs := r.(type) {
|
||||
case *tg.MessagesDialogs: // messages.dialogs#15ba6c40
|
||||
dialogs = dlgs.Dialogs
|
||||
messages = dlgs.Messages
|
||||
entities = peer.EntitiesFromResult(dlgs)
|
||||
|
||||
m.count = len(messages)
|
||||
m.lastBatch = true
|
||||
case *tg.MessagesDialogsSlice: // messages.dialogsSlice#71e094f3
|
||||
dialogs = dlgs.Dialogs
|
||||
messages = dlgs.Messages
|
||||
entities = peer.EntitiesFromResult(dlgs)
|
||||
|
||||
m.count = dlgs.Count
|
||||
m.lastBatch = len(dlgs.Dialogs) == 0
|
||||
default: // messages.dialogsNotModified#f0e3e596
|
||||
return errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
m.totalGot = true
|
||||
|
||||
msgMap := make(messageMap, len(messages))
|
||||
if err := msgMap.collect(messages); err != nil {
|
||||
return errors.Wrap(err, "collect last messages")
|
||||
}
|
||||
|
||||
m.bufCur = -1
|
||||
m.buf = m.buf[:0]
|
||||
|
||||
var last tg.NotEmptyMessage
|
||||
for _, dlg := range dialogs {
|
||||
var key DialogKey
|
||||
if err := key.FromPeer(dlg.GetPeer()); err == nil {
|
||||
last = msgMap[key]
|
||||
}
|
||||
|
||||
p, err := entities.ExtractPeer(dlg.GetPeer())
|
||||
if err != nil {
|
||||
p = &tg.InputPeerEmpty{}
|
||||
}
|
||||
|
||||
m.buf = append(m.buf, Elem{
|
||||
Dialog: dlg,
|
||||
Peer: p,
|
||||
Last: last,
|
||||
Entities: entities,
|
||||
})
|
||||
}
|
||||
|
||||
if !m.lastBatch && len(m.buf) > 0 {
|
||||
if last != nil {
|
||||
m.offsetID = last.GetID()
|
||||
m.offsetDate = last.GetDate()
|
||||
}
|
||||
|
||||
p, err := entities.ExtractPeer(dialogs[len(m.buf)-1].GetPeer())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get offset peer")
|
||||
}
|
||||
m.offsetPeer = p
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Iterator) requestNext(ctx context.Context) error {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
OffsetID: m.offsetID,
|
||||
OffsetDate: m.offsetDate,
|
||||
OffsetPeer: m.offsetPeer,
|
||||
Limit: m.limit,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.apply(r)
|
||||
}
|
||||
|
||||
func (m *Iterator) bufNext() bool {
|
||||
if len(m.buf)-1 <= m.bufCur {
|
||||
return false
|
||||
}
|
||||
|
||||
m.bufCur++
|
||||
return true
|
||||
}
|
||||
|
||||
// Total returns last fetched count of elements.
|
||||
// If count was not fetched before, it requests server using FetchTotal.
|
||||
func (m *Iterator) Total(ctx context.Context) (int, error) {
|
||||
if m.totalGot {
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
return m.FetchTotal(ctx)
|
||||
}
|
||||
|
||||
// FetchTotal fetches and returns count of elements.
|
||||
func (m *Iterator) FetchTotal(ctx context.Context) (int, error) {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
Limit: 1,
|
||||
OffsetPeer: &tg.InputPeerEmpty{},
|
||||
})
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "fetch total")
|
||||
}
|
||||
|
||||
switch dlgs := r.(type) {
|
||||
case *tg.MessagesDialogs: // messages.dialogs#15ba6c40
|
||||
m.count = len(dlgs.Dialogs)
|
||||
case *tg.MessagesDialogsSlice: // messages.dialogsSlice#71e094f3
|
||||
m.count = dlgs.Count
|
||||
default: // messages.dialogsNotModified#f0e3e596
|
||||
return 0, errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
|
||||
m.totalGot = true
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
// Next prepares the next message for reading with the Value method.
|
||||
// It returns true on success, or false if there is no next message or an error happened while preparing it.
|
||||
// Err should be consulted to distinguish between the two cases.
|
||||
func (m *Iterator) Next(ctx context.Context) bool {
|
||||
if m.lastErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.bufNext() {
|
||||
// If buffer is empty, we should fetch next batch.
|
||||
if err := m.requestNext(ctx); err != nil {
|
||||
m.lastErr = err
|
||||
return false
|
||||
}
|
||||
// Try again with new buffer.
|
||||
return m.bufNext()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Value returns current message.
|
||||
func (m *Iterator) Value() Elem {
|
||||
return m.buf[m.bufCur]
|
||||
}
|
||||
|
||||
// Err returns the error, if any, that was encountered during iteration.
|
||||
func (m *Iterator) Err() error {
|
||||
return m.lastErr
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package dialogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func generateDialogs(count int) []tg.DialogClass {
|
||||
r := make([]tg.DialogClass, 0, count)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
r = append(r, &tg.Dialog{
|
||||
Peer: &tg.PeerChannel{ChannelID: int64(i)},
|
||||
})
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func result(r []tg.DialogClass, count int) tg.MessagesDialogsClass {
|
||||
msgs := make([]tg.MessageClass, 0, len(r))
|
||||
for i, dlg := range r {
|
||||
msgs = append(msgs, &tg.Message{
|
||||
ID: i,
|
||||
PeerID: dlg.GetPeer(),
|
||||
})
|
||||
}
|
||||
|
||||
chats := make([]tg.ChatClass, 0, len(r))
|
||||
for i, dlg := range r {
|
||||
id := dlg.GetPeer().(*tg.PeerChannel).ChannelID
|
||||
chats = append(chats, &tg.Channel{
|
||||
ID: id,
|
||||
AccessHash: 10,
|
||||
Photo: &tg.ChatPhoto{
|
||||
PhotoID: int64(i),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return &tg.MessagesDialogsSlice{
|
||||
Dialogs: r,
|
||||
Messages: msgs,
|
||||
Chats: chats,
|
||||
Count: count,
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mock := tgmock.NewRequire(t)
|
||||
limit := 10
|
||||
totalRows := 3 * limit
|
||||
expected := generateDialogs(totalRows)
|
||||
raw := tg.NewClient(mock)
|
||||
|
||||
mock.Expect().ThenResult(result(expected[0:limit], totalRows))
|
||||
mock.Expect().ThenResult(result(expected[limit:2*limit], totalRows))
|
||||
mock.Expect().ThenResult(result(expected[2*limit:3*limit], totalRows))
|
||||
mock.Expect().ThenResult(result(expected[3*limit:], totalRows))
|
||||
|
||||
iter := NewQueryBuilder(raw).GetDialogs().BatchSize(10).Iter()
|
||||
i := 0
|
||||
for iter.Next(ctx) {
|
||||
require.Equal(t, expected[i].GetPeer(), iter.Value().Dialog.GetPeer())
|
||||
i++
|
||||
}
|
||||
require.NoError(t, iter.Err())
|
||||
require.Equal(t, totalRows, i)
|
||||
|
||||
total, err := iter.Total(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalRows, total)
|
||||
|
||||
mock.ExpectCall(&tg.MessagesGetDialogsRequest{
|
||||
OffsetPeer: &tg.InputPeerEmpty{},
|
||||
Limit: 1,
|
||||
}).ThenResult(result(expected[:0], totalRows))
|
||||
total, err = iter.FetchTotal(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalRows, total)
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package dialogs
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// PeerKind represents peer kind.
|
||||
type PeerKind int
|
||||
|
||||
const (
|
||||
// User is a private chat with user.
|
||||
User PeerKind = iota
|
||||
// Chat is a legacy chat.
|
||||
Chat
|
||||
// Channel is a supergroup/channel.
|
||||
Channel
|
||||
)
|
||||
|
||||
// DialogKey is a generic peer key.
|
||||
type DialogKey struct {
|
||||
Kind PeerKind
|
||||
ID int64
|
||||
AccessHash int64
|
||||
}
|
||||
|
||||
// FromInputPeer fills key using given peer.
|
||||
func (d *DialogKey) FromInputPeer(peer tg.InputPeerClass) error {
|
||||
switch v := peer.(type) {
|
||||
case *tg.InputPeerUser:
|
||||
d.Kind = User
|
||||
d.ID = v.UserID
|
||||
d.AccessHash = v.AccessHash
|
||||
case *tg.InputPeerChat:
|
||||
d.Kind = Chat
|
||||
d.ID = v.ChatID
|
||||
case *tg.InputPeerChannel:
|
||||
d.Kind = Channel
|
||||
d.ID = v.ChannelID
|
||||
d.AccessHash = v.AccessHash
|
||||
default:
|
||||
return errors.Errorf("unexpected type %T", peer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FromPeer fills key using given peer.
|
||||
func (d *DialogKey) FromPeer(peer tg.PeerClass) error {
|
||||
switch v := peer.(type) {
|
||||
case *tg.PeerUser:
|
||||
d.Kind = User
|
||||
d.ID = v.UserID
|
||||
case *tg.PeerChat:
|
||||
d.Kind = Chat
|
||||
d.ID = v.ChatID
|
||||
case *tg.PeerChannel:
|
||||
d.Kind = Channel
|
||||
d.ID = v.ChannelID
|
||||
default:
|
||||
return errors.Errorf("unexpected type %T", peer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
// Code generated by itergen, DO NOT EDIT.
|
||||
|
||||
package dialogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// No-op definition for keeping imports.
|
||||
var _ = context.Background()
|
||||
|
||||
// Request is a parameter for Query.
|
||||
type Request struct {
|
||||
OffsetDate int
|
||||
OffsetID int
|
||||
OffsetPeer tg.InputPeerClass
|
||||
Limit int
|
||||
}
|
||||
|
||||
// Query is an abstraction for dialogs request.
|
||||
// NB: iterator mutates returned data (sorts, at least).
|
||||
type Query interface {
|
||||
Query(ctx context.Context, req Request) (tg.MessagesDialogsClass, error)
|
||||
}
|
||||
|
||||
// QueryFunc is a function adapter for Query.
|
||||
type QueryFunc func(ctx context.Context, req Request) (tg.MessagesDialogsClass, error)
|
||||
|
||||
// Query implements Query interface.
|
||||
func (q QueryFunc) Query(ctx context.Context, req Request) (tg.MessagesDialogsClass, error) {
|
||||
return q(ctx, req)
|
||||
}
|
||||
|
||||
// QueryBuilder is a helper to create message queries.
|
||||
type QueryBuilder struct {
|
||||
raw *tg.Client
|
||||
}
|
||||
|
||||
// NewQueryBuilder creates new QueryBuilder.
|
||||
func NewQueryBuilder(raw *tg.Client) *QueryBuilder {
|
||||
return &QueryBuilder{raw: raw}
|
||||
}
|
||||
|
||||
// GetDialogsQueryBuilder is query builder of MessagesGetDialogs.
|
||||
type GetDialogsQueryBuilder struct {
|
||||
raw *tg.Client
|
||||
req tg.MessagesGetDialogsRequest
|
||||
batchSize int
|
||||
offsetDate int
|
||||
offsetID int
|
||||
offsetPeer tg.InputPeerClass
|
||||
}
|
||||
|
||||
// GetDialogs creates query builder of MessagesGetDialogs.
|
||||
func (q *QueryBuilder) GetDialogs() *GetDialogsQueryBuilder {
|
||||
b := &GetDialogsQueryBuilder{
|
||||
raw: q.raw,
|
||||
batchSize: 1,
|
||||
req: tg.MessagesGetDialogsRequest{},
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// BatchSize sets buffer of message loaded from one request.
|
||||
// Be carefully, when set this limit, because Telegram does not return error if limit is too big,
|
||||
// so results can be incorrect.
|
||||
func (b *GetDialogsQueryBuilder) BatchSize(batchSize int) *GetDialogsQueryBuilder {
|
||||
b.batchSize = batchSize
|
||||
return b
|
||||
}
|
||||
|
||||
// OffsetDate sets offsetDate from which iterate start.
|
||||
func (b *GetDialogsQueryBuilder) OffsetDate(offsetDate int) *GetDialogsQueryBuilder {
|
||||
b.offsetDate = offsetDate
|
||||
return b
|
||||
}
|
||||
|
||||
// OffsetID sets offsetID from which iterate start.
|
||||
func (b *GetDialogsQueryBuilder) OffsetID(offsetID int) *GetDialogsQueryBuilder {
|
||||
b.offsetID = offsetID
|
||||
return b
|
||||
}
|
||||
|
||||
// FolderID sets FolderID field of GetDialogs query.
|
||||
func (b *GetDialogsQueryBuilder) FolderID(paramFolderID int) *GetDialogsQueryBuilder {
|
||||
b.req.FolderID = paramFolderID
|
||||
return b
|
||||
}
|
||||
|
||||
// Query implements Query interface.
|
||||
func (b *GetDialogsQueryBuilder) Query(ctx context.Context, req Request) (tg.MessagesDialogsClass, error) {
|
||||
r := &tg.MessagesGetDialogsRequest{
|
||||
Limit: req.Limit,
|
||||
}
|
||||
|
||||
r.FolderID = b.req.FolderID
|
||||
r.OffsetDate = req.OffsetDate
|
||||
r.OffsetID = req.OffsetID
|
||||
r.OffsetPeer = req.OffsetPeer
|
||||
return b.raw.MessagesGetDialogs(ctx, r)
|
||||
}
|
||||
|
||||
// Iter returns iterator using built query.
|
||||
func (b *GetDialogsQueryBuilder) Iter() *Iterator {
|
||||
iter := NewIterator(b, b.batchSize)
|
||||
iter = iter.OffsetDate(b.offsetDate)
|
||||
iter = iter.OffsetID(b.offsetID)
|
||||
return iter
|
||||
}
|
||||
|
||||
// ForEach calls given callback on each iterator element.
|
||||
func (b *GetDialogsQueryBuilder) ForEach(ctx context.Context, cb func(context.Context, Elem) error) error {
|
||||
iter := b.Iter()
|
||||
for iter.Next(ctx) {
|
||||
if err := cb(ctx, iter.Value()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return iter.Err()
|
||||
}
|
||||
|
||||
// Count fetches remote state to get number of elements.
|
||||
func (b *GetDialogsQueryBuilder) Count(ctx context.Context) (int, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "get total")
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Collect creates iterator and collects all elements to slice.
|
||||
func (b *GetDialogsQueryBuilder) Collect(ctx context.Context) ([]Elem, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get total")
|
||||
}
|
||||
|
||||
r := make([]Elem, 0, c)
|
||||
for iter.Next(ctx) {
|
||||
r = append(r, iter.Value())
|
||||
}
|
||||
|
||||
return r, iter.Err()
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package dialogs
|
||||
|
||||
//go:generate go run go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/itergen -result=MessagesDialogsClass -package=dialogs -out=queries.gen.go
|
||||
@@ -0,0 +1,60 @@
|
||||
package dialogs
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message/peer"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/channels/participants"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/messages"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/photos"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Messages returns new messages history query builder for current dialog.
|
||||
func (e Elem) Messages(raw *tg.Client) *messages.GetHistoryQueryBuilder {
|
||||
return messages.NewQueryBuilder(raw).GetHistory(e.Peer)
|
||||
}
|
||||
|
||||
// Search returns new search query builder for current dialog.
|
||||
func (e Elem) Search(raw *tg.Client) *messages.SearchQueryBuilder {
|
||||
return messages.NewQueryBuilder(raw).Search(e.Peer)
|
||||
}
|
||||
|
||||
// Replies returns new replies query builder for current dialog.
|
||||
func (e Elem) Replies(raw *tg.Client) *messages.GetRepliesQueryBuilder {
|
||||
return messages.NewQueryBuilder(raw).GetReplies(e.Peer)
|
||||
}
|
||||
|
||||
// UnreadMentions returns new unread mentions query builder for current dialog.
|
||||
func (e Elem) UnreadMentions(raw *tg.Client) *messages.GetUnreadMentionsQueryBuilder {
|
||||
return messages.NewQueryBuilder(raw).GetUnreadMentions(e.Peer)
|
||||
}
|
||||
|
||||
// RecentLocations returns new live location history query builder for current dialog.
|
||||
func (e Elem) RecentLocations(raw *tg.Client) *messages.GetRecentLocationsQueryBuilder {
|
||||
return messages.NewQueryBuilder(raw).GetRecentLocations(e.Peer)
|
||||
}
|
||||
|
||||
// UserPhotos returns new user photo query builder for current dialog.
|
||||
// If peer is not user, returns false.
|
||||
func (e Elem) UserPhotos(raw *tg.Client) (*photos.GetUserPhotosQueryBuilder, bool) {
|
||||
user, ok := peer.ToInputUser(e.Peer)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return photos.NewQueryBuilder(raw).GetUserPhotos(user), true
|
||||
}
|
||||
|
||||
// Participants returns new channel participants query builder for current dialog.
|
||||
// If peer is not channel, returns false.
|
||||
func (e Elem) Participants(raw *tg.Client) (*participants.GetParticipantsQueryBuilder, bool) {
|
||||
channel, ok := peer.ToInputChannel(e.Peer)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return participants.NewQueryBuilder(raw).GetParticipants(channel), true
|
||||
}
|
||||
|
||||
// Deleted denotes that dialog is deleted.
|
||||
func (e Elem) Deleted() bool {
|
||||
_, ok := e.Peer.(*tg.InputPeerEmpty)
|
||||
return ok
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package dialogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func TestElem(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mock := tgmock.NewRequire(t)
|
||||
raw := tg.NewClient(mock)
|
||||
|
||||
ch := Elem{
|
||||
Peer: &tg.InputPeerChannel{},
|
||||
}
|
||||
testErr := tgerr.New(1337, "TEST_ERROR")
|
||||
|
||||
var err error
|
||||
mock.Expect().ThenRPCErr(testErr)
|
||||
_, err = ch.Messages(raw).Count(ctx)
|
||||
require.Error(t, err)
|
||||
mock.Expect().ThenRPCErr(testErr)
|
||||
_, err = ch.Search(raw).Count(ctx)
|
||||
require.Error(t, err)
|
||||
mock.Expect().ThenRPCErr(testErr)
|
||||
_, err = ch.Replies(raw).Count(ctx)
|
||||
require.Error(t, err)
|
||||
mock.Expect().ThenRPCErr(testErr)
|
||||
_, err = ch.UnreadMentions(raw).Count(ctx)
|
||||
require.Error(t, err)
|
||||
mock.Expect().ThenRPCErr(testErr)
|
||||
_, err = ch.RecentLocations(raw).Count(ctx)
|
||||
require.Error(t, err)
|
||||
|
||||
_, ok := ch.Participants(raw)
|
||||
require.True(t, ok)
|
||||
_, ok = ch.UserPhotos(raw)
|
||||
require.False(t, ok)
|
||||
|
||||
ch = Elem{
|
||||
Peer: &tg.InputPeerUser{},
|
||||
}
|
||||
|
||||
_, ok = ch.Participants(raw)
|
||||
require.False(t, ok)
|
||||
_, ok = ch.UserPhotos(raw)
|
||||
require.True(t, ok)
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
// Package hasher contains Telegram pagination hash implementation.
|
||||
package hasher
|
||||
|
||||
// Hasher implements Telegram pagination hash counting.
|
||||
//
|
||||
// See https://core.telegram.org/api/offsets#hash-generation.
|
||||
type Hasher struct {
|
||||
state uint32
|
||||
}
|
||||
|
||||
// Reset resets the Hasher to its initial state.
|
||||
func (h *Hasher) Reset() {
|
||||
h.state = 0
|
||||
}
|
||||
|
||||
// Update performs state change using given value.
|
||||
func (h *Hasher) Update(value uint32) {
|
||||
h.state = (h.state * 20261) + value
|
||||
}
|
||||
|
||||
// Update64 performs state change using given 64-bit value.
|
||||
func (h *Hasher) Update64(value uint64) {
|
||||
h.Update(uint32(value >> 32))
|
||||
h.Update(uint32(value & 0xFFFFFFFF))
|
||||
}
|
||||
|
||||
// Sum returns final sum.
|
||||
func (h *Hasher) Sum() int64 {
|
||||
return int64(h.state & 0x7FFFFFFF)
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package hasher
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHasher(t *testing.T) {
|
||||
hasher := Hasher{}
|
||||
data := []int{7, 5, 16, 8}
|
||||
|
||||
for i := range data {
|
||||
hasher.Update(uint32(data[i]))
|
||||
}
|
||||
|
||||
require.Equal(t, int64(611477280), hasher.Sum())
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
{{ define "header" }}{{- /*gotype: go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/cachedgen.Config*/ -}}
|
||||
// Code generated by itergen, DO NOT EDIT.
|
||||
|
||||
package {{ $.Package }}
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// No-op definition for keeping imports.
|
||||
var _ = context.Background()
|
||||
|
||||
{{ range $query := $.Queries }}
|
||||
|
||||
type inner{{ $query.Name }} struct {
|
||||
// Last received hash.
|
||||
hash int64
|
||||
// Last received result.
|
||||
value *tg.{{ $query.ResultName }}
|
||||
}
|
||||
|
||||
type {{ $query.Name }} struct {
|
||||
{{- if $query.RequestParams }}
|
||||
// Query to send.
|
||||
req *tg.{{ $query.RequestName }}{{ end }}
|
||||
// Result state.
|
||||
last atomic.Value
|
||||
|
||||
// Reference to RPC client to make requests.
|
||||
raw *tg.Client
|
||||
}
|
||||
|
||||
// New{{ $query.Name }} creates new {{ $query.Name }}.
|
||||
func New{{ $query.Name }}(raw *tg.Client, {{- if $query.RequestParams }}initial *tg.{{ $query.RequestName }}{{- end }}) *{{ $query.Name }} {
|
||||
q := &{{ $query.Name }}{
|
||||
{{- if $query.RequestParams }}
|
||||
req: initial,{{ end }}
|
||||
raw: raw,
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
func (s *{{ $query.Name }}) store(v inner{{ $query.Name }}) {
|
||||
s.last.Store(v)
|
||||
}
|
||||
|
||||
func (s *{{ $query.Name }}) load() (inner{{ $query.Name }}, bool) {
|
||||
v, ok := s.last.Load().(inner{{ $query.Name }})
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Value returns last received result.
|
||||
// NB: May be nil. Returned {{ $query.ResultName }} must not be mutated.
|
||||
func (s *{{ $query.Name }}) Value() *tg.{{ $query.ResultName }} {
|
||||
inner, _ := s.load()
|
||||
return inner.value
|
||||
}
|
||||
|
||||
// Hash returns last received hash.
|
||||
func (s *{{ $query.Name }}) Hash() int64 {
|
||||
inner, _ := s.load()
|
||||
return inner.hash
|
||||
}
|
||||
|
||||
// Get updates data if needed and returns it.
|
||||
func (s *{{ $query.Name }}) Get(ctx context.Context) (*tg.{{ $query.ResultName }}, error) {
|
||||
if _, err := s.Fetch(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Value(), nil
|
||||
}
|
||||
|
||||
// Fetch updates data if needed and returns true if data was modified.
|
||||
func (s *{{ $query.Name }}) Fetch(ctx context.Context) (bool, error) {
|
||||
lastHash := s.Hash()
|
||||
|
||||
{{ if $query.RequestParams -}}
|
||||
req := s.req
|
||||
req.Hash = lastHash
|
||||
{{- else -}}
|
||||
req := lastHash
|
||||
{{- end }}
|
||||
result, err := s.raw.{{ $query.MethodName }}(ctx, req)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "execute {{ $query.MethodName }}")
|
||||
}
|
||||
|
||||
switch variant := result.(type) {
|
||||
case *tg.{{ $query.ResultName }}:
|
||||
{{ if $query.ManualHash -}}
|
||||
hash := s.computeHash(variant)
|
||||
{{- else -}}
|
||||
hash := variant.Hash
|
||||
{{- end }}
|
||||
|
||||
s.store(inner{{ $query.Name }}{
|
||||
hash: hash,
|
||||
value: variant,
|
||||
})
|
||||
return true, nil
|
||||
case *tg.{{ $query.NotModifiedName }}:
|
||||
if lastHash == 0 {
|
||||
return false, errors.Errorf("got unexpected %T result", result)
|
||||
}
|
||||
return false, nil
|
||||
default:
|
||||
return false, errors.Errorf("unexpected type %T", result)
|
||||
}
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
{{ end }}
|
||||
@@ -0,0 +1,143 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/tools/go/packages"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
||||
)
|
||||
|
||||
func isHashField(field *types.Var) bool {
|
||||
basic, ok := field.Type().(*types.Basic)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return basic.Kind() == types.Int64 && field.Name() == "Hash"
|
||||
}
|
||||
|
||||
func hasHashField(st *types.Struct) bool {
|
||||
for i := 0; i < st.NumFields(); i++ {
|
||||
if isHashField(st.Field(i)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type request struct {
|
||||
name string
|
||||
params []Param
|
||||
}
|
||||
|
||||
func isCachedQuery(args *types.Tuple) (request, bool) {
|
||||
arg := args.At(1)
|
||||
switch req := arg.Type().(type) {
|
||||
case *types.Pointer:
|
||||
named, ok := req.Elem().(*types.Named)
|
||||
if !ok {
|
||||
return request{}, false
|
||||
}
|
||||
|
||||
st, ok := named.Underlying().(*types.Struct)
|
||||
if !ok {
|
||||
return request{}, false
|
||||
}
|
||||
|
||||
var r []Param
|
||||
for i := 0; i < st.NumFields(); i++ {
|
||||
field := st.Field(i)
|
||||
if strings.Contains(field.Name(), "Offset") {
|
||||
return request{}, false
|
||||
}
|
||||
|
||||
if isHashField(field) || field.Name() == "Flags" {
|
||||
continue
|
||||
}
|
||||
|
||||
r = append(r, varToParam(field))
|
||||
}
|
||||
|
||||
return request{
|
||||
name: named.Obj().Name(),
|
||||
params: sortParams(r),
|
||||
}, hasHashField(st)
|
||||
case *types.Basic:
|
||||
if req.Kind() != types.Int64 || arg.Name() != "hash" {
|
||||
return request{}, false
|
||||
}
|
||||
return request{}, true
|
||||
default:
|
||||
return request{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func collect(pkg *packages.Package) []CachedQuery {
|
||||
var r []CachedQuery
|
||||
|
||||
for _, def := range genutil.Funcs(pkg, func(f genutil.Func) bool {
|
||||
return f.Args().Len() == 2 && f.Results().Len() == 2
|
||||
}) {
|
||||
args := def.Args()
|
||||
req, ok := isCachedQuery(args)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
resultNamed, ok := def.Results().At(0).Type().(*types.Named)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
result, ok := resultNamed.Underlying().(*types.Interface)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
impls := genutil.Implementations(pkg, result)
|
||||
if len(impls) != 2 {
|
||||
continue
|
||||
}
|
||||
var (
|
||||
notModified *types.Named
|
||||
pure *types.Named
|
||||
)
|
||||
for _, impl := range impls {
|
||||
if notModified == nil && strings.Contains(impl.Obj().Name(), "NotModified") {
|
||||
notModified = impl
|
||||
continue
|
||||
}
|
||||
|
||||
if pure == nil {
|
||||
pure = impl
|
||||
}
|
||||
}
|
||||
if pure == nil || notModified == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pureStruct, ok := pure.Underlying().(*types.Struct)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
r = append(r, CachedQuery{
|
||||
Name: def.Decl.Name(),
|
||||
MethodName: def.Decl.Name(),
|
||||
RequestName: req.name,
|
||||
ManualHash: !hasHashField(pureStruct),
|
||||
RequestParams: req.params,
|
||||
ResultName: pure.Obj().Name(),
|
||||
NotModifiedName: notModified.Obj().Name(),
|
||||
})
|
||||
}
|
||||
sort.SliceStable(r, func(i, j int) bool {
|
||||
return r[i].Name < r[j].Name
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
||||
)
|
||||
|
||||
//go:embed _template/*.tmpl
|
||||
var templates embed.FS
|
||||
|
||||
func generate(ctx context.Context, out io.Writer, pkgName string) error {
|
||||
pkg, err := genutil.Load(ctx, "go.mau.fi/mautrix-telegram/pkg/gotd/tg")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "load")
|
||||
}
|
||||
|
||||
return genutil.WriteTemplate(templates, out, "header", Config{
|
||||
Queries: collect(pkg),
|
||||
Package: pkgName,
|
||||
})
|
||||
}
|
||||
|
||||
func run(ctx context.Context) (err error) {
|
||||
var out io.Writer = os.Stdout
|
||||
|
||||
set := flag.NewFlagSet("gen", flag.ExitOnError)
|
||||
output := set.String("out", "", "output file")
|
||||
pkgName := set.String("package", "cached", "name of package name to generate")
|
||||
if err := set.Parse(os.Args[1:]); err != nil {
|
||||
return errors.Wrap(err, "parse flags")
|
||||
}
|
||||
|
||||
if *output != "" {
|
||||
f, err := os.Create(*output)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "can't create file %q", *output)
|
||||
}
|
||||
defer func() {
|
||||
multierr.AppendInto(&err, f.Close())
|
||||
}()
|
||||
out = f
|
||||
}
|
||||
|
||||
return generate(ctx, out, *pkgName)
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
if err := run(ctx); err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
var out bytes.Buffer
|
||||
if err := generate(context.Background(), &out, "testgen"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
||||
)
|
||||
|
||||
// Param represents request parameter.
|
||||
type Param struct {
|
||||
// Name to use in function declaration.
|
||||
Name string
|
||||
// OriginalName in struct definition.
|
||||
OriginalName string
|
||||
// Go type.
|
||||
Type string
|
||||
}
|
||||
|
||||
func varToParam(field *types.Var) Param {
|
||||
fieldName := field.Name()
|
||||
fieldName = strings.ToLower(fieldName[:1]) + fieldName[1:]
|
||||
return Param{
|
||||
Name: fieldName,
|
||||
OriginalName: field.Name(),
|
||||
Type: genutil.PrintType(field.Type()),
|
||||
}
|
||||
}
|
||||
|
||||
func sortParams(p []Param) []Param {
|
||||
sort.SliceStable(p, func(i, j int) bool {
|
||||
return p[i].Name < p[j].Name
|
||||
})
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// CachedQuery is a RPC cacheable query helper.
|
||||
type CachedQuery struct {
|
||||
// Name of struct to generate.
|
||||
Name string
|
||||
// MethodName is name of method of tg.Client.
|
||||
MethodName string
|
||||
// RequestName is name of request struct.
|
||||
RequestName string
|
||||
// ManualHash determines whether hash must be computed using
|
||||
// hand-written function computeHash or not.
|
||||
// Need to resolve case when Telegram does not return hash with result.
|
||||
ManualHash bool
|
||||
// RequestParams contains additional params to send.
|
||||
RequestParams []Param
|
||||
// ResultName is name of result type.
|
||||
ResultName string
|
||||
// NotModifiedName is name of NotModified result type.
|
||||
NotModifiedName string
|
||||
}
|
||||
|
||||
// Config is codegeneration config to use.
|
||||
type Config struct {
|
||||
// Query helpers to generate.
|
||||
Queries []CachedQuery
|
||||
// ResultName package name
|
||||
Package string
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package genutil is a utility package for query helpers codegeneration.
|
||||
package genutil
|
||||
@@ -0,0 +1,55 @@
|
||||
package genutil
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
|
||||
// Func is a function representation.
|
||||
type Func struct {
|
||||
Sig *types.Signature
|
||||
Decl *types.Func
|
||||
}
|
||||
|
||||
// Results returns function results.
|
||||
func (f Func) Results() *types.Tuple {
|
||||
return f.Sig.Results()
|
||||
}
|
||||
|
||||
// Args returns function arguments.
|
||||
func (f Func) Args() *types.Tuple {
|
||||
return f.Sig.Params()
|
||||
}
|
||||
|
||||
// Funcs collects all function from package using given filter.
|
||||
// Parameter keep may be nil.
|
||||
func Funcs(pkg *packages.Package, keep func(f Func) bool) []Func {
|
||||
var r []Func
|
||||
|
||||
for _, def := range pkg.TypesInfo.Defs {
|
||||
if def == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
f, ok := def.(*types.Func)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
sig, ok := f.Type().(*types.Signature)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
repr := Func{
|
||||
Sig: sig,
|
||||
Decl: f,
|
||||
}
|
||||
|
||||
if keep(repr) {
|
||||
r = append(r, repr)
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package genutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
|
||||
func loadPackages(ctx context.Context, dir, pattern string, environ []string) ([]*packages.Package, error) {
|
||||
return packages.Load(&packages.Config{
|
||||
Context: ctx,
|
||||
Dir: dir,
|
||||
Mode: packages.NeedTypes |
|
||||
packages.NeedTypesInfo |
|
||||
packages.NeedTypesSizes |
|
||||
packages.NeedSyntax |
|
||||
packages.NeedDeps,
|
||||
Env: environ,
|
||||
Fset: token.NewFileSet(),
|
||||
ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) {
|
||||
const mode = parser.AllErrors | parser.ParseComments
|
||||
return parser.ParseFile(fset, filename, src, mode)
|
||||
},
|
||||
}, pattern)
|
||||
}
|
||||
|
||||
// Load loads package using given pattern.
|
||||
func Load(ctx context.Context, pattern string) (*packages.Package, error) {
|
||||
pkgs, err := loadPackages(ctx, "", pattern, os.Environ())
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "load packages")
|
||||
}
|
||||
|
||||
for _, pkg := range pkgs {
|
||||
if pkg.ID == pattern {
|
||||
return pkg, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.Errorf("package %s not found", pattern)
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package genutil
|
||||
|
||||
import "go/types"
|
||||
|
||||
// PrintType prints typename into string without package name.
|
||||
func PrintType(typ types.Type) string {
|
||||
return types.TypeString(typ, func(i *types.Package) string {
|
||||
return i.Name()
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package genutil
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
|
||||
// Implementations finds iface implementations.
|
||||
func Implementations(pkg *packages.Package, iface *types.Interface) []*types.Named {
|
||||
var r []*types.Named
|
||||
|
||||
for _, def := range pkg.TypesInfo.Defs {
|
||||
if def == nil || !def.Exported() {
|
||||
continue
|
||||
}
|
||||
|
||||
named, ok := def.Type().(*types.Named)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if !types.Implements(types.NewPointer(named), iface) {
|
||||
continue
|
||||
}
|
||||
|
||||
r = append(r, named)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// Interfaces is a simple utility struct to find interfaces and implementations.
|
||||
type Interfaces struct {
|
||||
pkg *packages.Package
|
||||
implsCache map[string][]*types.Named
|
||||
}
|
||||
|
||||
// NewInterfaces creates new Interfaces structure.
|
||||
func NewInterfaces(pkg *packages.Package) *Interfaces {
|
||||
return &Interfaces{pkg: pkg, implsCache: map[string][]*types.Named{}}
|
||||
}
|
||||
|
||||
// Interface finds interface by name.
|
||||
func (c *Interfaces) Interface(name string) (*types.Interface, error) {
|
||||
obj := c.pkg.Types.Scope().Lookup(name)
|
||||
if obj == nil {
|
||||
return nil, errors.Errorf("%q not found", name)
|
||||
}
|
||||
|
||||
v, ok := obj.Type().Underlying().(*types.Interface)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("%q has unexpected kind type %T", name, obj.Type().Underlying())
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Implementations finds interface implementations by interface name.
|
||||
func (c *Interfaces) Implementations(name string) ([]*types.Named, error) {
|
||||
impls, ok := c.implsCache[name]
|
||||
if ok {
|
||||
return impls, nil
|
||||
}
|
||||
|
||||
iface, err := c.Interface(name)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "find %q", name)
|
||||
}
|
||||
|
||||
impls = Implementations(c.pkg, iface)
|
||||
c.implsCache[name] = impls
|
||||
return impls, nil
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package genutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"go/format"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"text/template"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/gen"
|
||||
)
|
||||
|
||||
// WriteTemplate loads template from FS and executes it to given output writer.
|
||||
func WriteTemplate(source fs.FS, out io.Writer, name string, data interface{}) error {
|
||||
tmpl := template.New("templates").Funcs(gen.Funcs())
|
||||
tmpl = template.Must(tmpl.ParseFS(source, "_template/*.tmpl"))
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.ExecuteTemplate(&buf, name, data); err != nil {
|
||||
return errors.Wrap(err, "template")
|
||||
}
|
||||
|
||||
formatted, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
if _, cpyErr := io.Copy(os.Stdout, &buf); cpyErr != nil {
|
||||
return errors.Wrapf(cpyErr, "dump generated (original error: %v)", err)
|
||||
}
|
||||
return errors.Wrap(err, "format")
|
||||
}
|
||||
|
||||
_, err = out.Write(formatted)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
{{ define "header" }}{{- /*gotype: go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/itergen.Config*/ -}}
|
||||
// Code generated by itergen, DO NOT EDIT.
|
||||
|
||||
package {{ $.Package }}
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// No-op definition for keeping imports.
|
||||
var _ = context.Background()
|
||||
|
||||
// Request is a parameter for Query.
|
||||
type Request struct {
|
||||
{{- range $arg := $.RequestFields }}
|
||||
{{ $arg.OriginalName }} {{ $arg.Type }}
|
||||
{{- end }}
|
||||
Limit int
|
||||
}
|
||||
|
||||
// Query is an abstraction for {{ $.Package }} request.
|
||||
// NB: iterator mutates returned data (sorts, at least).
|
||||
type Query interface {
|
||||
Query(ctx context.Context, req Request) (tg.{{ $.ResultName }}, error)
|
||||
}
|
||||
|
||||
// QueryFunc is a function adapter for Query.
|
||||
type QueryFunc func(ctx context.Context, req Request) (tg.{{ $.ResultName }}, error)
|
||||
|
||||
// Query implements Query interface.
|
||||
func (q QueryFunc) Query(ctx context.Context, req Request) (tg.{{ $.ResultName }}, error) {
|
||||
return q(ctx, req)
|
||||
}
|
||||
|
||||
|
||||
// QueryBuilder is a helper to create message queries.
|
||||
type QueryBuilder struct {
|
||||
raw *tg.Client
|
||||
}
|
||||
|
||||
// NewQueryBuilder creates new QueryBuilder.
|
||||
func NewQueryBuilder(raw *tg.Client) *QueryBuilder {
|
||||
return &QueryBuilder{raw: raw}
|
||||
}
|
||||
|
||||
{{ range $method := $.Methods }}
|
||||
{{ template "query" $method }}
|
||||
{{- end }}
|
||||
{{ end }}
|
||||
|
||||
{{ define "query" }}{{- /*gotype: go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/gen.Method*/ -}}
|
||||
// {{ $.Name }}QueryBuilder is query builder of {{ $.OriginalName }}.
|
||||
type {{ $.Name }}QueryBuilder struct {
|
||||
raw *tg.Client
|
||||
req {{ $.RequestName }}
|
||||
batchSize int
|
||||
{{- range $mapping := $.AdditionalMapping }}
|
||||
{{ $mapping.Arg.Name }} {{ $mapping.Arg.Type }}
|
||||
{{- end }}
|
||||
}
|
||||
|
||||
// {{ $.Name }} creates query builder of {{ $.OriginalName }}.
|
||||
func (q *QueryBuilder) {{ $.Name }}({{ range $arg := $.RequiredParams }}param{{ $arg.OriginalName }} {{ $arg.Type }},{{ end }}) *{{ $.Name }}QueryBuilder {
|
||||
b := &{{ $.Name }}QueryBuilder{
|
||||
raw: q.raw,
|
||||
batchSize: 1,
|
||||
req:{{ $.RequestName }}{
|
||||
{{- range $f := $.AdditionalParams }}{{ if eq ($f.Type) ("tg.MessagesFilterClass") }}
|
||||
{{ $f.OriginalName }}: &tg.InputMessagesFilterEmpty{},
|
||||
{{- end }}{{ if eq ($f.Type) ("tg.InputPeerClass") }}
|
||||
{{ $f.OriginalName }}: &tg.InputPeerEmpty{},
|
||||
{{- end }}{{ if eq ($f.Type) ("tg.ChannelParticipantsFilterClass") }}
|
||||
{{ $f.OriginalName }}: &tg.ChannelParticipantsRecent{},
|
||||
{{- end }}{{ end }}
|
||||
},
|
||||
}
|
||||
{{ range $arg := $.RequiredParams }}
|
||||
b.req.{{ $arg.OriginalName }} = param{{ $arg.OriginalName }}
|
||||
{{- end }}
|
||||
return b
|
||||
}
|
||||
|
||||
// BatchSize sets buffer of message loaded from one request.
|
||||
// Be carefully, when set this limit, because Telegram does not return error if limit is too big,
|
||||
// so results can be incorrect.
|
||||
func (b *{{ $.Name }}QueryBuilder) BatchSize(batchSize int) *{{ $.Name }}QueryBuilder {
|
||||
b.batchSize = batchSize
|
||||
return b
|
||||
}
|
||||
|
||||
{{- range $mapping := $.AdditionalMapping }}{{ if $mapping.Chain }}
|
||||
// {{ $mapping.Arg.OriginalName }} sets {{ $mapping.Arg.Name }} from which iterate start.
|
||||
func (b *{{ $.Name }}QueryBuilder) {{ $mapping.Arg.OriginalName }}({{ $mapping.Arg.Name }} int) *{{ $.Name }}QueryBuilder {
|
||||
b.{{ $mapping.Arg.Name }} = {{ $mapping.Arg.Name }}
|
||||
return b
|
||||
}
|
||||
{{- end }}{{- end }}
|
||||
|
||||
{{ range $f := $.AdditionalParams }}
|
||||
// {{ $f.OriginalName }} sets {{ $f.OriginalName }} field of {{ $.Name }} query.
|
||||
func (b *{{ $.Name }}QueryBuilder) {{ $f.OriginalName }}(param{{ $f.OriginalName }} {{ $f.Type }}) *{{ $.Name }}QueryBuilder {
|
||||
b.req.{{ $f.OriginalName }} = param{{ $f.OriginalName }}
|
||||
return b
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
{{ range $f := $.SpecialCase }}
|
||||
// {{ $f.ConstructorName }} sets {{ $f.Field.OriginalName }} field of {{ $.Name }} query.
|
||||
func (b *{{ $.Name }}QueryBuilder) {{ $f.ConstructorName }}({{ range $arg := $f.Args }}param{{ $arg.OriginalName }} {{ $arg.Type }},{{ end }}) *{{ $.Name }}QueryBuilder {
|
||||
b.req.{{ $f.Field.OriginalName }} = &{{ $f.ConstructorType }}{
|
||||
{{- range $arg := $f.Args }}
|
||||
{{ $arg.OriginalName }}: param{{ $arg.OriginalName }},
|
||||
{{- end }}
|
||||
}
|
||||
return b
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
|
||||
// Query implements Query interface.
|
||||
func (b *{{ $.Name }}QueryBuilder) Query(ctx context.Context, req Request) ({{ $.ResultName }}, error) {
|
||||
r := &{{ $.RequestName }}{
|
||||
Limit: req.Limit,
|
||||
}
|
||||
{{ range $f := $.AdditionalParams }}
|
||||
r.{{ $f.OriginalName }} = b.req.{{ $f.OriginalName }}
|
||||
{{- end }}
|
||||
{{- range $f := $.AdditionalMapping }}
|
||||
r.{{ $f.Arg.OriginalName }} = req.{{ $f.Arg.OriginalName }}
|
||||
{{- end }}
|
||||
return b.raw.{{ $.OriginalName }}(ctx, r)
|
||||
}
|
||||
|
||||
// Iter returns iterator using built query.
|
||||
func (b *{{ $.Name }}QueryBuilder) Iter() *{{ $.IteratorName }} {
|
||||
iter := New{{ $.IteratorName }}(b, b.batchSize)
|
||||
{{- range $mapping := $.AdditionalMapping }}{{ if $mapping.RequiredByIter }}
|
||||
iter = iter.{{ $mapping.Arg.OriginalName }}(b.{{ $mapping.Arg.Name }})
|
||||
{{- end }}{{- end }}
|
||||
return iter
|
||||
}
|
||||
|
||||
// ForEach calls given callback on each iterator element.
|
||||
func (b *{{ $.Name }}QueryBuilder) ForEach(ctx context.Context, cb func(context.Context, {{ $.ElemName }}) error) error {
|
||||
iter := b.Iter()
|
||||
for iter.Next(ctx) {
|
||||
if err := cb(ctx, iter.Value()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return iter.Err()
|
||||
}
|
||||
|
||||
// Count fetches remote state to get number of elements.
|
||||
func (b *{{ $.Name }}QueryBuilder) Count(ctx context.Context) (int, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "get total")
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Collect creates iterator and collects all elements to slice.
|
||||
func (b *{{ $.Name }}QueryBuilder) Collect(ctx context.Context) ([]{{ $.ElemName }}, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get total")
|
||||
}
|
||||
|
||||
r := make([]{{ $.ElemName }}, 0, c)
|
||||
for iter.Next(ctx) {
|
||||
r = append(r, iter.Value())
|
||||
}
|
||||
|
||||
return r, iter.Err()
|
||||
}
|
||||
{{ end }}
|
||||
@@ -0,0 +1,251 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"golang.org/x/tools/go/packages"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
||||
)
|
||||
|
||||
type method struct {
|
||||
name string
|
||||
f *types.Func
|
||||
sig *types.Signature
|
||||
reqType types.Type
|
||||
resultType types.Type
|
||||
|
||||
fromRequest []RequestArgument
|
||||
params []Param
|
||||
}
|
||||
|
||||
type collector struct {
|
||||
ignoreFields map[string]struct{}
|
||||
canFillFromRequest map[string]struct{}
|
||||
requiredByIter []string
|
||||
required map[string]string
|
||||
|
||||
pkg *packages.Package
|
||||
ifaces *genutil.Interfaces
|
||||
|
||||
iface *types.Interface
|
||||
resultTypeName string
|
||||
elemName string
|
||||
prefix string
|
||||
pkgName string
|
||||
requestFields []Param
|
||||
}
|
||||
|
||||
type collectorConfig struct {
|
||||
ResultName string
|
||||
ElemName string
|
||||
Prefix string
|
||||
PkgName string
|
||||
}
|
||||
|
||||
func (c *collectorConfig) fromFlags(set *flag.FlagSet) {
|
||||
set.StringVar(&c.ResultName, "result", "MessagesMessagesClass", "result type name")
|
||||
set.StringVar(&c.ElemName, "elem", "Elem", "element type name")
|
||||
set.StringVar(&c.Prefix, "prefix", "Messages", "prefix of methods to trim")
|
||||
set.StringVar(&c.PkgName, "package", "messages", "name of package name to generate")
|
||||
}
|
||||
|
||||
func newCollector(pkg *packages.Package, cfg collectorConfig) *collector {
|
||||
intGetter := types.NewSignatureType(nil, nil, nil, nil,
|
||||
types.NewTuple(types.NewVar(0, nil, "", types.Typ[types.Int])), false) // func() int
|
||||
methods := []*types.Func{
|
||||
types.NewFunc(token.NoPos, nil, "GetLimit", intGetter),
|
||||
}
|
||||
match := types.NewInterfaceType(methods, nil).Complete()
|
||||
|
||||
canFillFromRequest := map[string]struct{}{
|
||||
"AddOffset": {},
|
||||
"OffsetID": {},
|
||||
"OffsetDate": {},
|
||||
"OffsetPeer": {},
|
||||
"OffsetRate": {},
|
||||
"Offset": {},
|
||||
}
|
||||
ignoreFields := map[string]struct{}{
|
||||
// Already handled by match interface.
|
||||
"Limit": {},
|
||||
// Not real field.
|
||||
"Flags": {},
|
||||
// Telegram ignores MaxID and MinID sometimes.
|
||||
"MaxID": {}, "MinID": {},
|
||||
// ExcludePinned used by iterator.
|
||||
"ExcludePinned": {},
|
||||
// Hash can be used internally, so do not expose it.
|
||||
"Hash": {},
|
||||
}
|
||||
requiredByIter := []string{
|
||||
"OffsetID",
|
||||
"OffsetDate",
|
||||
"Offset",
|
||||
}
|
||||
required := map[string]string{
|
||||
"Peer": "InputPeerClass",
|
||||
"Channel": "InputChannelClass",
|
||||
"UserID": "InputUserClass",
|
||||
}
|
||||
|
||||
return &collector{
|
||||
ignoreFields: ignoreFields,
|
||||
canFillFromRequest: canFillFromRequest,
|
||||
requiredByIter: requiredByIter,
|
||||
required: required,
|
||||
ifaces: genutil.NewInterfaces(pkg),
|
||||
pkg: pkg,
|
||||
iface: match,
|
||||
resultTypeName: cfg.ResultName,
|
||||
elemName: cfg.ElemName,
|
||||
prefix: cfg.Prefix,
|
||||
pkgName: cfg.PkgName,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *collector) methods() ([]method, error) { // nolint:gocognit
|
||||
var result []method
|
||||
|
||||
for _, def := range genutil.Funcs(c.pkg, func(f genutil.Func) bool {
|
||||
return f.Args().Len() == 2 && f.Results().Len() == 2
|
||||
}) {
|
||||
args := def.Args()
|
||||
results := def.Results()
|
||||
|
||||
ptr, ok := args.At(1).Type().(*types.Pointer)
|
||||
if !ok || !types.Implements(ptr, c.iface) {
|
||||
continue
|
||||
}
|
||||
reqType := ptr.Elem()
|
||||
|
||||
resultType, ok := results.At(0).Type().(*types.Named)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if resultType.Obj().Name() != c.resultTypeName {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimPrefix(def.Decl.Name(), c.prefix)
|
||||
|
||||
m := method{
|
||||
name: name,
|
||||
f: def.Decl,
|
||||
sig: def.Sig,
|
||||
reqType: reqType,
|
||||
resultType: resultType,
|
||||
}
|
||||
|
||||
reqTypeStruct, ok := reqType.Underlying().(*types.Struct)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unexpected type %T", reqType.Underlying())
|
||||
}
|
||||
|
||||
for i := 0; i < reqTypeStruct.NumFields(); i++ {
|
||||
field := reqTypeStruct.Field(i)
|
||||
|
||||
if _, ok := c.ignoreFields[field.Name()]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
param := varToParam(field)
|
||||
if _, ok := c.canFillFromRequest[field.Name()]; ok {
|
||||
requiredByIter := false
|
||||
for _, field := range c.requiredByIter {
|
||||
if field == param.OriginalName {
|
||||
requiredByIter = true
|
||||
break
|
||||
}
|
||||
}
|
||||
m.fromRequest = append(m.fromRequest, RequestArgument{
|
||||
Arg: param,
|
||||
Chain: field.Name() == "OffsetID" || field.Name() == "OffsetDate",
|
||||
RequiredByIter: requiredByIter,
|
||||
})
|
||||
|
||||
skip := false
|
||||
for _, field := range c.requestFields {
|
||||
if field.OriginalName == param.OriginalName {
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !skip {
|
||||
c.requestFields = append(c.requestFields, param)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
m.params = append(m.params, param)
|
||||
}
|
||||
|
||||
result = append(result, m)
|
||||
}
|
||||
|
||||
sort.SliceStable(result, func(i, j int) bool {
|
||||
return result[i].name < result[j].name
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *collector) Config() (Config, error) {
|
||||
methods, err := c.collect()
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(err, "collect")
|
||||
}
|
||||
|
||||
return Config{
|
||||
Methods: methods,
|
||||
Package: c.pkgName,
|
||||
ResultName: c.resultTypeName,
|
||||
RequestFields: sortParams(c.requestFields),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *collector) collect() ([]Method, error) {
|
||||
methods, err := c.methods()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "collect types")
|
||||
}
|
||||
|
||||
result := make([]Method, 0, len(methods))
|
||||
for _, method := range methods {
|
||||
mapping := method.fromRequest
|
||||
sort.SliceStable(mapping, func(i, j int) bool {
|
||||
return mapping[i].Arg.Name < mapping[j].Arg.Name
|
||||
})
|
||||
|
||||
m := Method{
|
||||
Name: method.name,
|
||||
OriginalName: method.f.Name(),
|
||||
RequestName: genutil.PrintType(method.reqType),
|
||||
ResultName: genutil.PrintType(method.resultType),
|
||||
AdditionalMapping: mapping,
|
||||
AdditionalParams: sortParams(method.params),
|
||||
IteratorName: "Iterator",
|
||||
ElemName: c.elemName,
|
||||
}
|
||||
|
||||
for _, field := range method.params {
|
||||
if _, ok := c.required[field.OriginalName]; ok {
|
||||
m.RequiredParams = append(m.RequiredParams, field)
|
||||
}
|
||||
}
|
||||
m.RequiredParams = sortParams(m.RequiredParams)
|
||||
|
||||
cases, err := c.collectSpecial(m)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "collect special")
|
||||
}
|
||||
|
||||
m.SpecialCase = cases
|
||||
result = append(result, m)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
||||
)
|
||||
|
||||
//go:embed _template/*.tmpl
|
||||
var templates embed.FS
|
||||
|
||||
func generate(ctx context.Context, out io.Writer, cfg collectorConfig) error {
|
||||
pkg, err := genutil.Load(ctx, "go.mau.fi/mautrix-telegram/pkg/gotd/tg")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "load")
|
||||
}
|
||||
|
||||
c := newCollector(pkg, cfg)
|
||||
config, err := c.Config()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "collect")
|
||||
}
|
||||
|
||||
return genutil.WriteTemplate(templates, out, "header", config)
|
||||
}
|
||||
|
||||
func run(ctx context.Context) (err error) {
|
||||
var out io.Writer = os.Stdout
|
||||
|
||||
set := flag.NewFlagSet("gen", flag.ExitOnError)
|
||||
output := set.String("out", "", "output file")
|
||||
cfg := collectorConfig{}
|
||||
cfg.fromFlags(set)
|
||||
if err := set.Parse(os.Args[1:]); err != nil {
|
||||
return errors.Wrap(err, "parse flags")
|
||||
}
|
||||
|
||||
if *output != "" {
|
||||
f, err := os.Create(*output)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "can't create file %q", *output)
|
||||
}
|
||||
defer func() {
|
||||
multierr.AppendInto(&err, f.Close())
|
||||
}()
|
||||
out = f
|
||||
}
|
||||
|
||||
return generate(ctx, out, cfg)
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
if err := run(ctx); err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
var out bytes.Buffer
|
||||
if err := generate(context.Background(), &out, collectorConfig{
|
||||
ResultName: "MessagesMessagesClass",
|
||||
ElemName: "Elem",
|
||||
Prefix: "Messages",
|
||||
PkgName: "messages",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
||||
)
|
||||
|
||||
// Param represents request parameter.
|
||||
type Param struct {
|
||||
// Name to use in function declaration.
|
||||
Name string
|
||||
// OriginalName in struct definition.
|
||||
OriginalName string
|
||||
// Go type.
|
||||
Type string
|
||||
}
|
||||
|
||||
func varToParam(field *types.Var) Param {
|
||||
fieldName := field.Name()
|
||||
fieldName = strings.ToLower(fieldName[:1]) + fieldName[1:]
|
||||
return Param{
|
||||
Name: fieldName,
|
||||
OriginalName: field.Name(),
|
||||
Type: genutil.PrintType(field.Type()),
|
||||
}
|
||||
}
|
||||
|
||||
func sortParams(p []Param) []Param {
|
||||
sort.SliceStable(p, func(i, j int) bool {
|
||||
return p[i].Name < p[j].Name
|
||||
})
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// SpecialCaseChain represents special request parameter setter.
|
||||
type SpecialCaseChain struct {
|
||||
// ConstructorName to use in function body.
|
||||
ConstructorName string
|
||||
// ConstructorType to use in function body.
|
||||
ConstructorType string
|
||||
// Field of request struct.
|
||||
Field Param
|
||||
// Args is a slice of arguments. May be empty.
|
||||
Args []Param
|
||||
}
|
||||
|
||||
// RequestArgument represents request parameter passed by iterator.
|
||||
type RequestArgument struct {
|
||||
// Arg describes argument.
|
||||
Arg Param
|
||||
// Chain is flag to generate builder chain setter.
|
||||
Chain bool
|
||||
// RequiredByIter is flag to generate pass to iterator constructor.
|
||||
RequiredByIter bool
|
||||
}
|
||||
|
||||
// Method is a RPC method.
|
||||
type Method struct {
|
||||
// Name to use in function declaration.
|
||||
Name string
|
||||
// OriginalName is name of method of tg.Client.
|
||||
OriginalName string
|
||||
// RequestName is name of request struct.
|
||||
RequestName string
|
||||
// ResultName is name of result type.
|
||||
ResultName string
|
||||
// RequiredParams is a required params for query builder.
|
||||
RequiredParams []Param
|
||||
// AdditionalMapping is names of field from iterator.
|
||||
// Some type doesn't have AddOffset for example, so we customize mapping here.
|
||||
AdditionalMapping []RequestArgument
|
||||
|
||||
// SpecialCase is a slice of special case chains.
|
||||
// Like tg.MessagesFilterClass constructor field setters.
|
||||
SpecialCase []SpecialCaseChain
|
||||
|
||||
// Other parameters of request to pass it in constructor.
|
||||
AdditionalParams []Param
|
||||
|
||||
// IteratorName is name of iterator to build.
|
||||
IteratorName string
|
||||
// ElemName is name of iterator elem.
|
||||
ElemName string
|
||||
}
|
||||
|
||||
// Config is codegeneration config to use.
|
||||
type Config struct {
|
||||
// Methods to generate helpers and query builders.
|
||||
Methods []Method
|
||||
// ResultName package name
|
||||
Package string
|
||||
// ResultName is name of result type.
|
||||
ResultName string
|
||||
// RequestFields is a slice of request struct fields.
|
||||
RequestFields []Param
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
||||
)
|
||||
|
||||
func (c *collector) unpackClass(
|
||||
field Param,
|
||||
typeName, trimPrefix string,
|
||||
) ([]SpecialCaseChain, error) {
|
||||
var r []SpecialCaseChain
|
||||
if field.Type == "tg."+typeName {
|
||||
impls, err := c.ifaces.Implementations(typeName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "find %q constructors", typeName)
|
||||
}
|
||||
for _, impl := range impls {
|
||||
s, ok := impl.Underlying().(*types.Struct)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
cse := SpecialCaseChain{
|
||||
ConstructorName: strings.TrimPrefix(impl.Obj().Name(), trimPrefix),
|
||||
ConstructorType: genutil.PrintType(impl),
|
||||
Field: field,
|
||||
}
|
||||
|
||||
if strings.Contains(cse.ConstructorName, "Empty") {
|
||||
continue
|
||||
}
|
||||
|
||||
for i := 0; i < s.NumFields(); i++ {
|
||||
field := s.Field(i)
|
||||
if field.Name() == "Flags" {
|
||||
continue
|
||||
}
|
||||
|
||||
cse.Args = append(cse.Args, varToParam(field))
|
||||
}
|
||||
|
||||
cse.Args = sortParams(cse.Args)
|
||||
r = append(r, cse)
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *collector) unpackClasses(
|
||||
field Param,
|
||||
classes ...[2]string,
|
||||
) ([]SpecialCaseChain, error) {
|
||||
var r []SpecialCaseChain
|
||||
for _, class := range classes {
|
||||
cases, err := c.unpackClass(field, class[0], class[1])
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "unpack %q", class[0])
|
||||
}
|
||||
r = append(r, cases...)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *collector) collectSpecial(m Method) ([]SpecialCaseChain, error) {
|
||||
var r []SpecialCaseChain
|
||||
for _, field := range m.AdditionalParams {
|
||||
cases, err := c.unpackClasses(field, [][2]string{
|
||||
{"MessagesFilterClass", "InputMessagesFilter"},
|
||||
{"ChannelParticipantsFilterClass", "ChannelParticipants"},
|
||||
}...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r = append(r, cases...)
|
||||
}
|
||||
|
||||
sort.SliceStable(r, func(i, j int) bool {
|
||||
return r[i].ConstructorName < r[j].ConstructorName
|
||||
})
|
||||
return r, nil
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
// Package messages contains message iteration helper.
|
||||
package messages
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message/peer"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Elem is a message iterator element.
|
||||
type Elem struct {
|
||||
Msg tg.NotEmptyMessage
|
||||
Peer tg.InputPeerClass
|
||||
Entities peer.Entities
|
||||
}
|
||||
|
||||
// Iterator is a message stream iterator.
|
||||
type Iterator struct {
|
||||
// Current state.
|
||||
lastErr error
|
||||
// Buffer state.
|
||||
buf []Elem
|
||||
bufCur int
|
||||
// Request state.
|
||||
addOffset int
|
||||
limit int
|
||||
lastBatch bool
|
||||
// Offset parameters state.
|
||||
offsetID int
|
||||
offsetDate int
|
||||
offsetPeer tg.InputPeerClass
|
||||
offsetRate int
|
||||
// Remote state.
|
||||
count int
|
||||
totalGot bool
|
||||
|
||||
// Query builder.
|
||||
query Query
|
||||
}
|
||||
|
||||
// NewIterator creates new iterator.
|
||||
func NewIterator(query Query, limit int) *Iterator {
|
||||
return &Iterator{
|
||||
buf: make([]Elem, 0, limit),
|
||||
bufCur: -1,
|
||||
limit: limit,
|
||||
query: query,
|
||||
offsetPeer: &tg.InputPeerEmpty{},
|
||||
}
|
||||
}
|
||||
|
||||
// OffsetID sets OffsetID request parameter.
|
||||
func (m *Iterator) OffsetID(offsetID int) *Iterator {
|
||||
m.offsetID = offsetID
|
||||
return m
|
||||
}
|
||||
|
||||
// OffsetDate sets OffsetDate request parameter.
|
||||
func (m *Iterator) OffsetDate(offsetDate int) *Iterator {
|
||||
m.offsetDate = offsetDate
|
||||
return m
|
||||
}
|
||||
|
||||
// OffsetRate sets OffsetRate request parameter.
|
||||
func (m *Iterator) OffsetRate(offsetRate int) *Iterator {
|
||||
m.offsetRate = offsetRate
|
||||
return m
|
||||
}
|
||||
|
||||
// OffsetPeer sets OffsetPeer request parameter.
|
||||
func (m *Iterator) OffsetPeer(offsetPeer tg.InputPeerClass) *Iterator {
|
||||
m.offsetPeer = offsetPeer
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Iterator) apply(r tg.MessagesMessagesClass) error {
|
||||
if m.lastBatch {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
messages tg.MessageClassArray
|
||||
entities peer.Entities
|
||||
)
|
||||
switch msgs := r.(type) {
|
||||
case *tg.MessagesMessages: // messages.messages#8c718e87
|
||||
messages = msgs.Messages
|
||||
entities = peer.EntitiesFromResult(msgs)
|
||||
|
||||
m.count = len(messages)
|
||||
m.lastBatch = true
|
||||
case *tg.MessagesMessagesSlice: // messages.messagesSlice#3a54685e
|
||||
messages = msgs.Messages
|
||||
entities = peer.EntitiesFromResult(msgs)
|
||||
|
||||
m.offsetRate = msgs.NextRate
|
||||
m.count = msgs.Count
|
||||
m.lastBatch = len(msgs.Messages) < m.limit
|
||||
case *tg.MessagesChannelMessages: // messages.channelMessages#64479808
|
||||
messages = msgs.Messages
|
||||
entities = peer.EntitiesFromResult(msgs)
|
||||
|
||||
m.count = msgs.Count
|
||||
m.lastBatch = len(msgs.Messages) < m.limit
|
||||
default: // messages.messagesNotModified#74535f21
|
||||
return errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
m.totalGot = true
|
||||
|
||||
// Sort messages to guarantee order and find the last message.
|
||||
messages = messages.SortStable(func(a, b tg.MessageClass) bool {
|
||||
return a.GetID() > b.GetID()
|
||||
})
|
||||
|
||||
// Get the last message (with smallest ID).
|
||||
msg, ok := messages.Last()
|
||||
if !ok {
|
||||
// If Last() returned false, result is empty, so we this is a last batch.
|
||||
m.lastBatch = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update offsetID and offsetDate, if can to prevent duplication in case
|
||||
// when there a lot new messages in a chat/channel between previous and current request.
|
||||
//
|
||||
// Illustration of problem:
|
||||
//
|
||||
// Remote state:
|
||||
// [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
|
||||
// ^ offset = 0
|
||||
//
|
||||
// First request(offset = 0, limit = 5):
|
||||
// [10, 9, 8, 7, 6]
|
||||
// offset = 5
|
||||
//
|
||||
// Remote state:
|
||||
// [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
|
||||
// ^ offset = 5
|
||||
//
|
||||
// Second request(offset = 5, limit = 5):
|
||||
// [10, 9, 8, 7, 6]
|
||||
// offset = 10
|
||||
//
|
||||
m.offsetID = msg.GetID()
|
||||
if nonEmpty, ok := msg.AsNotEmpty(); ok {
|
||||
m.offsetDate = nonEmpty.GetDate()
|
||||
|
||||
p, err := entities.ExtractPeer(nonEmpty.GetPeerID())
|
||||
if err == nil {
|
||||
m.offsetPeer = p
|
||||
}
|
||||
}
|
||||
|
||||
m.bufCur = -1
|
||||
m.buf = m.buf[:0]
|
||||
for _, msg := range messages {
|
||||
nonEmpty, ok := msg.AsNotEmpty()
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
msgPeer, err := entities.ExtractPeer(nonEmpty.GetPeerID())
|
||||
if err != nil {
|
||||
msgPeer = &tg.InputPeerEmpty{}
|
||||
}
|
||||
|
||||
m.buf = append(m.buf, Elem{
|
||||
Msg: nonEmpty,
|
||||
Peer: msgPeer,
|
||||
Entities: entities,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Iterator) requestNext(ctx context.Context) error {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
OffsetID: m.offsetID,
|
||||
AddOffset: m.addOffset,
|
||||
OffsetDate: m.offsetDate,
|
||||
OffsetRate: m.offsetRate,
|
||||
OffsetPeer: m.offsetPeer,
|
||||
Limit: m.limit,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.apply(r)
|
||||
}
|
||||
|
||||
func (m *Iterator) bufNext() bool {
|
||||
if len(m.buf)-1 <= m.bufCur {
|
||||
return false
|
||||
}
|
||||
|
||||
m.bufCur++
|
||||
return true
|
||||
}
|
||||
|
||||
// Total returns last fetched count of elements.
|
||||
// If count was not fetched before, it requests server using FetchTotal.
|
||||
func (m *Iterator) Total(ctx context.Context) (int, error) {
|
||||
if m.totalGot {
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
return m.FetchTotal(ctx)
|
||||
}
|
||||
|
||||
// FetchTotal fetches and returns count of elements.
|
||||
func (m *Iterator) FetchTotal(ctx context.Context) (int, error) {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
Limit: 1,
|
||||
OffsetPeer: &tg.InputPeerEmpty{},
|
||||
})
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "fetch total")
|
||||
}
|
||||
|
||||
switch msgs := r.(type) {
|
||||
case *tg.MessagesMessages: // messages.messages#8c718e87
|
||||
m.count = len(msgs.Messages)
|
||||
case *tg.MessagesMessagesSlice: // messages.messagesSlice#3a54685e
|
||||
m.count = msgs.Count
|
||||
case *tg.MessagesChannelMessages: // messages.channelMessages#64479808
|
||||
m.count = msgs.Count
|
||||
default: // messages.messagesNotModified#74535f21
|
||||
return 0, errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
|
||||
m.totalGot = true
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
// Next prepares the next message for reading with the Value method.
|
||||
// It returns true on success, or false if there is no next message or an error happened while preparing it.
|
||||
// Err should be consulted to distinguish between the two cases.
|
||||
func (m *Iterator) Next(ctx context.Context) bool {
|
||||
if m.lastErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.bufNext() {
|
||||
// If buffer is empty, we should fetch next batch.
|
||||
if err := m.requestNext(ctx); err != nil {
|
||||
m.lastErr = err
|
||||
return false
|
||||
}
|
||||
// Try again with new buffer.
|
||||
return m.bufNext()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Value returns current message.
|
||||
func (m *Iterator) Value() Elem {
|
||||
return m.buf[m.bufCur]
|
||||
}
|
||||
|
||||
// Err returns the error, if any, that was encountered during iteration.
|
||||
func (m *Iterator) Err() error {
|
||||
return m.lastErr
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func generateMessages(count int) []tg.MessageClass {
|
||||
r := make([]tg.MessageClass, 0, count)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
r = append(r, &tg.Message{
|
||||
ID: i,
|
||||
PeerID: &tg.PeerUser{UserID: 10},
|
||||
Message: strconv.Itoa(i),
|
||||
})
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func messagesClass(r []tg.MessageClass, count int) tg.MessagesMessagesClass {
|
||||
return &tg.MessagesChannelMessages{
|
||||
Messages: r,
|
||||
Count: count,
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mock := tgmock.NewRequire(t)
|
||||
limit := 10
|
||||
totalMessages := 3 * limit
|
||||
expected := generateMessages(totalMessages)
|
||||
raw := tg.NewClient(mock)
|
||||
|
||||
mock.ExpectCall(&tg.MessagesSearchRequest{
|
||||
Q: "query",
|
||||
Peer: &tg.InputPeerSelf{},
|
||||
OffsetID: 0,
|
||||
FromID: &tg.InputPeerEmpty{},
|
||||
Filter: &tg.InputMessagesFilterEmpty{},
|
||||
SavedPeerID: &tg.InputPeerEmpty{},
|
||||
Limit: limit,
|
||||
}).ThenResult(messagesClass(expected[2*limit:3*limit], totalMessages))
|
||||
mock.ExpectCall(&tg.MessagesSearchRequest{
|
||||
Q: "query",
|
||||
Peer: &tg.InputPeerSelf{},
|
||||
OffsetID: 20,
|
||||
FromID: &tg.InputPeerEmpty{},
|
||||
Filter: &tg.InputMessagesFilterEmpty{},
|
||||
SavedPeerID: &tg.InputPeerEmpty{},
|
||||
Limit: limit,
|
||||
}).ThenResult(messagesClass(expected[limit:2*limit], totalMessages))
|
||||
mock.ExpectCall(&tg.MessagesSearchRequest{
|
||||
Q: "query",
|
||||
Peer: &tg.InputPeerSelf{},
|
||||
OffsetID: 10,
|
||||
FromID: &tg.InputPeerEmpty{},
|
||||
Filter: &tg.InputMessagesFilterEmpty{},
|
||||
SavedPeerID: &tg.InputPeerEmpty{},
|
||||
Limit: limit,
|
||||
}).ThenResult(messagesClass(expected[:limit], totalMessages))
|
||||
mock.ExpectCall(&tg.MessagesSearchRequest{
|
||||
Q: "query",
|
||||
Peer: &tg.InputPeerSelf{},
|
||||
OffsetID: 0,
|
||||
FromID: &tg.InputPeerEmpty{},
|
||||
Filter: &tg.InputMessagesFilterEmpty{},
|
||||
SavedPeerID: &tg.InputPeerEmpty{},
|
||||
Limit: limit,
|
||||
}).ThenResult(messagesClass(expected[:0], totalMessages))
|
||||
|
||||
iter := NewQueryBuilder(raw).Search(&tg.InputPeerSelf{}).
|
||||
Filter(&tg.InputMessagesFilterEmpty{}).
|
||||
Q("query").BatchSize(10).Iter()
|
||||
i := 0
|
||||
for iter.Next(ctx) {
|
||||
require.Equal(t, expected[len(expected)-i-1], iter.Value().Msg)
|
||||
i++
|
||||
}
|
||||
require.NoError(t, iter.Err())
|
||||
require.Equal(t, totalMessages, i)
|
||||
|
||||
total, err := iter.Total(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalMessages, total)
|
||||
|
||||
mock.ExpectCall(&tg.MessagesSearchRequest{
|
||||
Q: "query",
|
||||
Peer: &tg.InputPeerSelf{},
|
||||
OffsetID: 0,
|
||||
FromID: &tg.InputPeerEmpty{},
|
||||
Filter: &tg.InputMessagesFilterEmpty{},
|
||||
SavedPeerID: &tg.InputPeerEmpty{},
|
||||
Limit: 1,
|
||||
}).ThenResult(messagesClass(expected[:0], totalMessages))
|
||||
total, err = iter.FetchTotal(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalMessages, total)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,3 @@
|
||||
package messages
|
||||
|
||||
//go:generate go run go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/itergen -out=queries.gen.go
|
||||
@@ -0,0 +1,177 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Document returns document object if message has a document attachment (video, voice, audio,
|
||||
// basically every type except photo).
|
||||
func (e Elem) Document() (*tg.Document, bool) {
|
||||
msg, ok := e.Msg.(*tg.Message)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
media, ok := msg.Media.(*tg.MessageMediaDocument)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return media.Document.AsNotEmpty()
|
||||
}
|
||||
|
||||
// Photo returns photo object if message has a photo attachment.
|
||||
func (e Elem) Photo() (*tg.Photo, bool) {
|
||||
msg, ok := e.Msg.(*tg.Message)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
media, ok := msg.Media.(*tg.MessageMediaPhoto)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return media.Photo.AsNotEmpty()
|
||||
}
|
||||
|
||||
// File represents file attachment.
|
||||
type File struct {
|
||||
Name string
|
||||
MIMEType string
|
||||
Location tg.InputFileLocationClass
|
||||
}
|
||||
|
||||
const dateLayout = "2006-01-02_15-04-05"
|
||||
|
||||
func getDocFilename(doc *tg.Document) string {
|
||||
var filename, ext string
|
||||
for _, attr := range doc.Attributes {
|
||||
switch v := attr.(type) {
|
||||
case *tg.DocumentAttributeImageSize:
|
||||
switch doc.MimeType {
|
||||
case "image/png":
|
||||
ext = ".png"
|
||||
case "image/webp":
|
||||
ext = ".webp"
|
||||
case "image/tiff":
|
||||
ext = ".tif"
|
||||
default:
|
||||
ext = ".jpg"
|
||||
}
|
||||
case *tg.DocumentAttributeAnimated:
|
||||
ext = ".gif"
|
||||
case *tg.DocumentAttributeSticker:
|
||||
ext = ".webp"
|
||||
case *tg.DocumentAttributeVideo:
|
||||
switch doc.MimeType {
|
||||
case "video/mpeg":
|
||||
ext = ".mpeg"
|
||||
case "video/webm":
|
||||
ext = ".webm"
|
||||
case "video/ogg":
|
||||
ext = ".ogg"
|
||||
default:
|
||||
ext = ".mp4"
|
||||
}
|
||||
case *tg.DocumentAttributeAudio:
|
||||
switch doc.MimeType {
|
||||
case "audio/webm":
|
||||
ext = ".webm"
|
||||
case "audio/aac":
|
||||
ext = ".aac"
|
||||
case "audio/ogg":
|
||||
ext = ".ogg"
|
||||
default:
|
||||
ext = ".mp3"
|
||||
}
|
||||
case *tg.DocumentAttributeFilename:
|
||||
filename = v.FileName
|
||||
}
|
||||
}
|
||||
|
||||
if filename == "" {
|
||||
filename = fmt.Sprintf(
|
||||
"doc%d_%s%s", doc.GetID(),
|
||||
time.Unix(int64(doc.Date), 0).Format(dateLayout),
|
||||
ext,
|
||||
)
|
||||
}
|
||||
|
||||
return filename
|
||||
}
|
||||
|
||||
type sizedPhoto interface {
|
||||
GetW() int
|
||||
GetH() int
|
||||
GetType() string
|
||||
}
|
||||
|
||||
var (
|
||||
_ sizedPhoto = (*tg.PhotoSize)(nil)
|
||||
_ sizedPhoto = (*tg.PhotoCachedSize)(nil)
|
||||
_ sizedPhoto = (*tg.PhotoSizeProgressive)(nil)
|
||||
)
|
||||
|
||||
// File returns file location if message has a file attachment.
|
||||
func (e Elem) File() (File, bool) {
|
||||
msg, ok := e.Msg.(*tg.Message)
|
||||
if !ok {
|
||||
return File{}, false
|
||||
}
|
||||
|
||||
switch media := msg.Media.(type) {
|
||||
case *tg.MessageMediaPhoto:
|
||||
photo, ok := media.Photo.AsNotEmpty()
|
||||
if !ok {
|
||||
return File{}, false
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf(
|
||||
"photo%d_%s.jpg", photo.GetID(),
|
||||
time.Unix(int64(photo.Date), 0).Format(dateLayout),
|
||||
)
|
||||
|
||||
var (
|
||||
thumbSize string
|
||||
maxW, maxH int
|
||||
)
|
||||
for _, g := range photo.Sizes {
|
||||
// TODO(tdakkota): add helpers to choose photo size.
|
||||
if sz, ok := g.(sizedPhoto); ok && maxW < sz.GetW() && maxH < sz.GetH() {
|
||||
thumbSize = sz.GetType()
|
||||
}
|
||||
}
|
||||
|
||||
if thumbSize == "" {
|
||||
return File{}, false
|
||||
}
|
||||
|
||||
return File{
|
||||
Name: filename,
|
||||
MIMEType: "image/jpeg",
|
||||
Location: &tg.InputPhotoFileLocation{
|
||||
ID: photo.ID,
|
||||
AccessHash: photo.AccessHash,
|
||||
FileReference: photo.FileReference,
|
||||
ThumbSize: thumbSize,
|
||||
},
|
||||
}, true
|
||||
case *tg.MessageMediaDocument:
|
||||
doc, ok := media.Document.AsNotEmpty()
|
||||
if !ok {
|
||||
return File{}, false
|
||||
}
|
||||
|
||||
return File{
|
||||
Name: getDocFilename(doc),
|
||||
MIMEType: doc.MimeType,
|
||||
Location: doc.AsInputDocumentFileLocation(),
|
||||
}, true
|
||||
default:
|
||||
return File{}, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func Test_getDocFilename(t *testing.T) {
|
||||
date := time.Now()
|
||||
f := date.Format(dateLayout)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args *tg.Document
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"Doc",
|
||||
&tg.Document{
|
||||
Date: int(date.Unix()),
|
||||
Attributes: []tg.DocumentAttributeClass{
|
||||
&tg.DocumentAttributeFilename{FileName: "10.jpg"},
|
||||
},
|
||||
},
|
||||
"10.jpg",
|
||||
},
|
||||
{
|
||||
"Gif",
|
||||
&tg.Document{
|
||||
Date: int(date.Unix()),
|
||||
Attributes: []tg.DocumentAttributeClass{
|
||||
&tg.DocumentAttributeAnimated{},
|
||||
},
|
||||
},
|
||||
"doc0_" + f + ".gif",
|
||||
},
|
||||
{
|
||||
"Video",
|
||||
&tg.Document{
|
||||
Date: int(date.Unix()),
|
||||
Attributes: []tg.DocumentAttributeClass{
|
||||
&tg.DocumentAttributeVideo{},
|
||||
},
|
||||
},
|
||||
"doc0_" + f + ".mp4",
|
||||
},
|
||||
{
|
||||
"Photo",
|
||||
&tg.Document{
|
||||
Date: int(date.Unix()),
|
||||
Attributes: []tg.DocumentAttributeClass{
|
||||
&tg.DocumentAttributeImageSize{},
|
||||
},
|
||||
},
|
||||
"doc0_" + f + ".jpg",
|
||||
},
|
||||
{
|
||||
"Audio",
|
||||
&tg.Document{
|
||||
Date: int(date.Unix()),
|
||||
Attributes: []tg.DocumentAttributeClass{
|
||||
&tg.DocumentAttributeAudio{},
|
||||
},
|
||||
},
|
||||
"doc0_" + f + ".mp3",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, getDocFilename(tt.args))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestElem_File(t *testing.T) {
|
||||
type results struct {
|
||||
file, doc, photo bool
|
||||
}
|
||||
tests := []struct {
|
||||
Name string
|
||||
Msg tg.NotEmptyMessage
|
||||
results
|
||||
}{
|
||||
{"EmptyMessage", &tg.Message{}, results{}},
|
||||
{"ServiceMessage", &tg.MessageService{}, results{}},
|
||||
{"EmptyPhoto", &tg.Message{
|
||||
Media: &tg.MessageMediaPhoto{
|
||||
Photo: &tg.PhotoEmpty{},
|
||||
},
|
||||
}, results{}},
|
||||
{"EmptyDoc", &tg.Message{
|
||||
Media: &tg.MessageMediaDocument{
|
||||
Document: &tg.DocumentEmpty{},
|
||||
},
|
||||
}, results{}},
|
||||
{"Photo", &tg.Message{
|
||||
Media: &tg.MessageMediaPhoto{
|
||||
Photo: &tg.Photo{
|
||||
Sizes: []tg.PhotoSizeClass{
|
||||
&tg.PhotoSize{
|
||||
Type: "cock",
|
||||
W: 10,
|
||||
H: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, results{file: true, photo: true}},
|
||||
{"Document", &tg.Message{
|
||||
Media: &tg.MessageMediaDocument{
|
||||
Document: &tg.Document{},
|
||||
},
|
||||
}, results{file: true, doc: true}},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
var ok bool
|
||||
|
||||
elem := Elem{Msg: test.Msg}
|
||||
_, ok = elem.File()
|
||||
a.Equal(test.file, ok)
|
||||
_, ok = elem.Document()
|
||||
a.Equal(test.doc, ok)
|
||||
_, ok = elem.Photo()
|
||||
a.Equal(test.photo, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
// Package featured contains featured stickers iteration helper.
|
||||
package featured
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Elem is a sticker iterator element.
|
||||
type Elem struct {
|
||||
Sticker tg.StickerSetCoveredClass
|
||||
// IDs of new featured stickersets
|
||||
Unread []int64
|
||||
}
|
||||
|
||||
// Iterator is a featured stickers stream iterator.
|
||||
type Iterator struct {
|
||||
// Current state.
|
||||
lastErr error
|
||||
// Buffer state.
|
||||
buf []Elem
|
||||
bufCur int
|
||||
// Request state.
|
||||
limit int
|
||||
lastBatch bool
|
||||
// Offset parameters state.
|
||||
offset int
|
||||
// Remote state.
|
||||
count int
|
||||
totalGot bool
|
||||
|
||||
// Query builder.
|
||||
query Query
|
||||
}
|
||||
|
||||
// NewIterator creates new iterator.
|
||||
func NewIterator(query Query, limit int) *Iterator {
|
||||
return &Iterator{
|
||||
buf: make([]Elem, 0, limit),
|
||||
bufCur: -1,
|
||||
limit: limit,
|
||||
query: query,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset sets Offset request parameter.
|
||||
func (m *Iterator) Offset(offset int) *Iterator {
|
||||
m.offset = offset
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Iterator) apply(r tg.MessagesFeaturedStickersClass) error {
|
||||
if m.lastBatch {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
stickers []tg.StickerSetCoveredClass
|
||||
unread []int64
|
||||
)
|
||||
switch stks := r.(type) {
|
||||
case *tg.MessagesFeaturedStickers: // messages.featuredStickers#b6abc341
|
||||
stickers = stks.Sets
|
||||
unread = stks.Unread
|
||||
|
||||
m.count = stks.Count
|
||||
m.lastBatch = len(stickers) < m.limit
|
||||
default:
|
||||
return errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
m.totalGot = true
|
||||
m.offset += len(stickers)
|
||||
|
||||
m.bufCur = -1
|
||||
m.buf = m.buf[:0]
|
||||
for i := range stickers {
|
||||
m.buf = append(m.buf, Elem{Sticker: stickers[i], Unread: unread})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Iterator) requestNext(ctx context.Context) error {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
Offset: m.offset,
|
||||
Limit: m.limit,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.apply(r)
|
||||
}
|
||||
|
||||
func (m *Iterator) bufNext() bool {
|
||||
if len(m.buf)-1 <= m.bufCur {
|
||||
return false
|
||||
}
|
||||
|
||||
m.bufCur++
|
||||
return true
|
||||
}
|
||||
|
||||
// Total returns last fetched count of elements.
|
||||
// If count was not fetched before, it requests server using FetchTotal.
|
||||
func (m *Iterator) Total(ctx context.Context) (int, error) {
|
||||
if m.totalGot {
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
return m.FetchTotal(ctx)
|
||||
}
|
||||
|
||||
// FetchTotal fetches and returns count of elements.
|
||||
func (m *Iterator) FetchTotal(ctx context.Context) (int, error) {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
Limit: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "fetch total")
|
||||
}
|
||||
|
||||
switch stks := r.(type) {
|
||||
case *tg.MessagesFeaturedStickers: // messages.featuredStickers#b6abc341
|
||||
m.count = stks.Count
|
||||
default:
|
||||
return 0, errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
|
||||
m.totalGot = true
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
// Next prepares the next message for reading with the Value method.
|
||||
// It returns true on success, or false if there is no next message or an error happened while preparing it.
|
||||
// Err should be consulted to distinguish between the two cases.
|
||||
func (m *Iterator) Next(ctx context.Context) bool {
|
||||
if m.lastErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.bufNext() {
|
||||
// If buffer is empty, we should fetch next batch.
|
||||
if err := m.requestNext(ctx); err != nil {
|
||||
m.lastErr = err
|
||||
return false
|
||||
}
|
||||
// Try again with new buffer.
|
||||
return m.bufNext()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Value returns current message.
|
||||
func (m *Iterator) Value() Elem {
|
||||
return m.buf[m.bufCur]
|
||||
}
|
||||
|
||||
// Err returns the error, if any, that was encountered during iteration.
|
||||
func (m *Iterator) Err() error {
|
||||
return m.lastErr
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package featured
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func generateStickers(count int) []tg.StickerSetCoveredClass {
|
||||
r := make([]tg.StickerSetCoveredClass, 0, count)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
r = append(r, &tg.StickerSetCovered{
|
||||
Set: tg.StickerSet{
|
||||
ID: int64(i + 1),
|
||||
AccessHash: int64(i + 1),
|
||||
},
|
||||
Cover: &tg.Document{
|
||||
ID: int64(i + 1),
|
||||
AccessHash: int64(i + 1),
|
||||
FileReference: []uint8{uint8(i)},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func result(r []tg.StickerSetCoveredClass, count int) tg.MessagesFeaturedStickersClass {
|
||||
return &tg.MessagesFeaturedStickers{
|
||||
Sets: r,
|
||||
Count: count,
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mock := tgmock.NewRequire(t)
|
||||
limit := 10
|
||||
totalRecords := 3 * limit
|
||||
expected := generateStickers(totalRecords)
|
||||
raw := tg.NewClient(mock)
|
||||
|
||||
mock.ExpectCall(&tg.MessagesGetOldFeaturedStickersRequest{
|
||||
Offset: 0,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[0:limit], totalRecords))
|
||||
mock.ExpectCall(&tg.MessagesGetOldFeaturedStickersRequest{
|
||||
Offset: limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[limit:2*limit], totalRecords))
|
||||
mock.ExpectCall(&tg.MessagesGetOldFeaturedStickersRequest{
|
||||
Offset: 2 * limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[2*limit:3*limit], totalRecords))
|
||||
mock.ExpectCall(&tg.MessagesGetOldFeaturedStickersRequest{
|
||||
Offset: 3 * limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[3*limit:], totalRecords))
|
||||
|
||||
iter := NewQueryBuilder(raw).GetOldFeaturedStickers().BatchSize(10).Iter()
|
||||
i := 0
|
||||
for iter.Next(ctx) {
|
||||
require.Equal(t, expected[i], iter.Value().Sticker)
|
||||
i++
|
||||
}
|
||||
require.NoError(t, iter.Err())
|
||||
require.Equal(t, totalRecords, i)
|
||||
|
||||
total, err := iter.Total(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalRecords, total)
|
||||
|
||||
mock.ExpectCall(&tg.MessagesGetOldFeaturedStickersRequest{
|
||||
Offset: 0,
|
||||
Limit: 1,
|
||||
}).ThenResult(result(expected[:0], totalRecords))
|
||||
total, err = iter.FetchTotal(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalRecords, total)
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
// Code generated by itergen, DO NOT EDIT.
|
||||
|
||||
package featured
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// No-op definition for keeping imports.
|
||||
var _ = context.Background()
|
||||
|
||||
// Request is a parameter for Query.
|
||||
type Request struct {
|
||||
Offset int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// Query is an abstraction for featured request.
|
||||
// NB: iterator mutates returned data (sorts, at least).
|
||||
type Query interface {
|
||||
Query(ctx context.Context, req Request) (tg.MessagesFeaturedStickersClass, error)
|
||||
}
|
||||
|
||||
// QueryFunc is a function adapter for Query.
|
||||
type QueryFunc func(ctx context.Context, req Request) (tg.MessagesFeaturedStickersClass, error)
|
||||
|
||||
// Query implements Query interface.
|
||||
func (q QueryFunc) Query(ctx context.Context, req Request) (tg.MessagesFeaturedStickersClass, error) {
|
||||
return q(ctx, req)
|
||||
}
|
||||
|
||||
// QueryBuilder is a helper to create message queries.
|
||||
type QueryBuilder struct {
|
||||
raw *tg.Client
|
||||
}
|
||||
|
||||
// NewQueryBuilder creates new QueryBuilder.
|
||||
func NewQueryBuilder(raw *tg.Client) *QueryBuilder {
|
||||
return &QueryBuilder{raw: raw}
|
||||
}
|
||||
|
||||
// GetOldFeaturedStickersQueryBuilder is query builder of MessagesGetOldFeaturedStickers.
|
||||
type GetOldFeaturedStickersQueryBuilder struct {
|
||||
raw *tg.Client
|
||||
req tg.MessagesGetOldFeaturedStickersRequest
|
||||
batchSize int
|
||||
offset int
|
||||
}
|
||||
|
||||
// GetOldFeaturedStickers creates query builder of MessagesGetOldFeaturedStickers.
|
||||
func (q *QueryBuilder) GetOldFeaturedStickers() *GetOldFeaturedStickersQueryBuilder {
|
||||
b := &GetOldFeaturedStickersQueryBuilder{
|
||||
raw: q.raw,
|
||||
batchSize: 1,
|
||||
req: tg.MessagesGetOldFeaturedStickersRequest{},
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// BatchSize sets buffer of message loaded from one request.
|
||||
// Be carefully, when set this limit, because Telegram does not return error if limit is too big,
|
||||
// so results can be incorrect.
|
||||
func (b *GetOldFeaturedStickersQueryBuilder) BatchSize(batchSize int) *GetOldFeaturedStickersQueryBuilder {
|
||||
b.batchSize = batchSize
|
||||
return b
|
||||
}
|
||||
|
||||
// Query implements Query interface.
|
||||
func (b *GetOldFeaturedStickersQueryBuilder) Query(ctx context.Context, req Request) (tg.MessagesFeaturedStickersClass, error) {
|
||||
r := &tg.MessagesGetOldFeaturedStickersRequest{
|
||||
Limit: req.Limit,
|
||||
}
|
||||
|
||||
r.Offset = req.Offset
|
||||
return b.raw.MessagesGetOldFeaturedStickers(ctx, r)
|
||||
}
|
||||
|
||||
// Iter returns iterator using built query.
|
||||
func (b *GetOldFeaturedStickersQueryBuilder) Iter() *Iterator {
|
||||
iter := NewIterator(b, b.batchSize)
|
||||
iter = iter.Offset(b.offset)
|
||||
return iter
|
||||
}
|
||||
|
||||
// ForEach calls given callback on each iterator element.
|
||||
func (b *GetOldFeaturedStickersQueryBuilder) ForEach(ctx context.Context, cb func(context.Context, Elem) error) error {
|
||||
iter := b.Iter()
|
||||
for iter.Next(ctx) {
|
||||
if err := cb(ctx, iter.Value()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return iter.Err()
|
||||
}
|
||||
|
||||
// Count fetches remote state to get number of elements.
|
||||
func (b *GetOldFeaturedStickersQueryBuilder) Count(ctx context.Context) (int, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "get total")
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Collect creates iterator and collects all elements to slice.
|
||||
func (b *GetOldFeaturedStickersQueryBuilder) Collect(ctx context.Context) ([]Elem, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get total")
|
||||
}
|
||||
|
||||
r := make([]Elem, 0, c)
|
||||
for iter.Next(ctx) {
|
||||
r = append(r, iter.Value())
|
||||
}
|
||||
|
||||
return r, iter.Err()
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package featured
|
||||
|
||||
//go:generate go run go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/itergen -result=MessagesFeaturedStickersClass -package=featured -out=queries.gen.go
|
||||
@@ -0,0 +1,174 @@
|
||||
// Package photos contains photos iteration helper.
|
||||
package photos
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message/peer"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Elem is a photo iterator element.
|
||||
type Elem struct {
|
||||
Photo tg.PhotoClass
|
||||
Entities peer.Entities
|
||||
}
|
||||
|
||||
// Iterator is a photo stream iterator.
|
||||
type Iterator struct {
|
||||
// Current state.
|
||||
lastErr error
|
||||
// Buffer state.
|
||||
buf []Elem
|
||||
bufCur int
|
||||
// Request state.
|
||||
limit int
|
||||
lastBatch bool
|
||||
// Offset parameters state.
|
||||
offset int
|
||||
// Remote state.
|
||||
count int
|
||||
totalGot bool
|
||||
|
||||
// Query builder.
|
||||
query Query
|
||||
}
|
||||
|
||||
// NewIterator creates new iterator.
|
||||
func NewIterator(query Query, limit int) *Iterator {
|
||||
return &Iterator{
|
||||
buf: make([]Elem, 0, limit),
|
||||
bufCur: -1,
|
||||
limit: limit,
|
||||
query: query,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset sets Offset request parameter.
|
||||
func (m *Iterator) Offset(offset int) *Iterator {
|
||||
m.offset = offset
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Iterator) apply(r tg.PhotosPhotosClass) error {
|
||||
if m.lastBatch {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
photos []tg.PhotoClass
|
||||
entities peer.Entities
|
||||
)
|
||||
switch phts := r.(type) {
|
||||
case *tg.PhotosPhotos: // photos.photos#8dca6aa5
|
||||
photos = phts.Photos
|
||||
entities = peer.NewEntities(phts.MapUsers().UserToMap(), map[int64]*tg.Chat{}, map[int64]*tg.Channel{})
|
||||
|
||||
m.count = len(phts.Photos)
|
||||
m.lastBatch = true
|
||||
case *tg.PhotosPhotosSlice: // photos.photosSlice#15051f54
|
||||
photos = phts.Photos
|
||||
entities = peer.NewEntities(phts.MapUsers().UserToMap(), map[int64]*tg.Chat{}, map[int64]*tg.Channel{})
|
||||
|
||||
m.count = phts.Count
|
||||
m.lastBatch = len(phts.Photos) < m.limit
|
||||
default:
|
||||
return errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
m.totalGot = true
|
||||
m.offset += len(photos)
|
||||
|
||||
m.bufCur = -1
|
||||
m.buf = m.buf[:0]
|
||||
for i := range photos {
|
||||
m.buf = append(m.buf, Elem{Photo: photos[i], Entities: entities})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Iterator) requestNext(ctx context.Context) error {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
Offset: m.offset,
|
||||
Limit: m.limit,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.apply(r)
|
||||
}
|
||||
|
||||
func (m *Iterator) bufNext() bool {
|
||||
if len(m.buf)-1 <= m.bufCur {
|
||||
return false
|
||||
}
|
||||
|
||||
m.bufCur++
|
||||
return true
|
||||
}
|
||||
|
||||
// Total returns last fetched count of elements.
|
||||
// If count was not fetched before, it requests server using FetchTotal.
|
||||
func (m *Iterator) Total(ctx context.Context) (int, error) {
|
||||
if m.totalGot {
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
return m.FetchTotal(ctx)
|
||||
}
|
||||
|
||||
// FetchTotal fetches and returns count of elements.
|
||||
func (m *Iterator) FetchTotal(ctx context.Context) (int, error) {
|
||||
r, err := m.query.Query(ctx, Request{
|
||||
Limit: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "fetch total")
|
||||
}
|
||||
|
||||
switch phts := r.(type) {
|
||||
case *tg.PhotosPhotos: // photos.photos#8dca6aa5
|
||||
m.count = len(phts.Photos)
|
||||
case *tg.PhotosPhotosSlice: // photos.photosSlice#15051f54
|
||||
m.count = phts.Count
|
||||
default:
|
||||
return 0, errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
|
||||
m.totalGot = true
|
||||
return m.count, nil
|
||||
}
|
||||
|
||||
// Next prepares the next message for reading with the Value method.
|
||||
// It returns true on success, or false if there is no next message or an error happened while preparing it.
|
||||
// Err should be consulted to distinguish between the two cases.
|
||||
func (m *Iterator) Next(ctx context.Context) bool {
|
||||
if m.lastErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.bufNext() {
|
||||
// If buffer is empty, we should fetch next batch.
|
||||
if err := m.requestNext(ctx); err != nil {
|
||||
m.lastErr = err
|
||||
return false
|
||||
}
|
||||
// Try again with new buffer.
|
||||
return m.bufNext()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Value returns current message.
|
||||
func (m *Iterator) Value() Elem {
|
||||
return m.buf[m.bufCur]
|
||||
}
|
||||
|
||||
// Err returns the error, if any, that was encountered during iteration.
|
||||
func (m *Iterator) Err() error {
|
||||
return m.lastErr
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package photos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func generatePhotos(count int) []tg.PhotoClass {
|
||||
r := make([]tg.PhotoClass, 0, count)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
r = append(r, &tg.Photo{
|
||||
ID: int64(i + 1),
|
||||
AccessHash: int64(i + 1),
|
||||
FileReference: []uint8{uint8(i)},
|
||||
})
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func result(r []tg.PhotoClass, count int) tg.PhotosPhotosClass {
|
||||
return &tg.PhotosPhotosSlice{
|
||||
Photos: r,
|
||||
Count: count,
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mock := tgmock.NewRequire(t)
|
||||
limit := 10
|
||||
totalRecords := 3 * limit
|
||||
expected := generatePhotos(totalRecords)
|
||||
raw := tg.NewClient(mock)
|
||||
|
||||
mock.ExpectCall(&tg.PhotosGetUserPhotosRequest{
|
||||
UserID: &tg.InputUserSelf{},
|
||||
Offset: 0,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[0:limit], totalRecords))
|
||||
mock.ExpectCall(&tg.PhotosGetUserPhotosRequest{
|
||||
UserID: &tg.InputUserSelf{},
|
||||
Offset: limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[limit:2*limit], totalRecords))
|
||||
mock.ExpectCall(&tg.PhotosGetUserPhotosRequest{
|
||||
UserID: &tg.InputUserSelf{},
|
||||
Offset: 2 * limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[2*limit:3*limit], totalRecords))
|
||||
mock.ExpectCall(&tg.PhotosGetUserPhotosRequest{
|
||||
UserID: &tg.InputUserSelf{},
|
||||
Offset: 3 * limit,
|
||||
Limit: limit,
|
||||
}).ThenResult(result(expected[3*limit:], totalRecords))
|
||||
|
||||
iter := NewQueryBuilder(raw).GetUserPhotos(&tg.InputUserSelf{}).BatchSize(10).Iter()
|
||||
i := 0
|
||||
for iter.Next(ctx) {
|
||||
require.Equal(t, expected[i], iter.Value().Photo)
|
||||
i++
|
||||
}
|
||||
require.NoError(t, iter.Err())
|
||||
require.Equal(t, totalRecords, i)
|
||||
|
||||
total, err := iter.Total(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalRecords, total)
|
||||
|
||||
mock.ExpectCall(&tg.PhotosGetUserPhotosRequest{
|
||||
UserID: &tg.InputUserSelf{},
|
||||
Offset: 0,
|
||||
Limit: 1,
|
||||
}).ThenResult(result(expected[:0], totalRecords))
|
||||
total, err = iter.FetchTotal(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, totalRecords, total)
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
// Code generated by itergen, DO NOT EDIT.
|
||||
|
||||
package photos
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// No-op definition for keeping imports.
|
||||
var _ = context.Background()
|
||||
|
||||
// Request is a parameter for Query.
|
||||
type Request struct {
|
||||
Offset int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// Query is an abstraction for photos request.
|
||||
// NB: iterator mutates returned data (sorts, at least).
|
||||
type Query interface {
|
||||
Query(ctx context.Context, req Request) (tg.PhotosPhotosClass, error)
|
||||
}
|
||||
|
||||
// QueryFunc is a function adapter for Query.
|
||||
type QueryFunc func(ctx context.Context, req Request) (tg.PhotosPhotosClass, error)
|
||||
|
||||
// Query implements Query interface.
|
||||
func (q QueryFunc) Query(ctx context.Context, req Request) (tg.PhotosPhotosClass, error) {
|
||||
return q(ctx, req)
|
||||
}
|
||||
|
||||
// QueryBuilder is a helper to create message queries.
|
||||
type QueryBuilder struct {
|
||||
raw *tg.Client
|
||||
}
|
||||
|
||||
// NewQueryBuilder creates new QueryBuilder.
|
||||
func NewQueryBuilder(raw *tg.Client) *QueryBuilder {
|
||||
return &QueryBuilder{raw: raw}
|
||||
}
|
||||
|
||||
// GetUserPhotosQueryBuilder is query builder of PhotosGetUserPhotos.
|
||||
type GetUserPhotosQueryBuilder struct {
|
||||
raw *tg.Client
|
||||
req tg.PhotosGetUserPhotosRequest
|
||||
batchSize int
|
||||
offset int
|
||||
}
|
||||
|
||||
// GetUserPhotos creates query builder of PhotosGetUserPhotos.
|
||||
func (q *QueryBuilder) GetUserPhotos(paramUserID tg.InputUserClass) *GetUserPhotosQueryBuilder {
|
||||
b := &GetUserPhotosQueryBuilder{
|
||||
raw: q.raw,
|
||||
batchSize: 1,
|
||||
req: tg.PhotosGetUserPhotosRequest{},
|
||||
}
|
||||
|
||||
b.req.UserID = paramUserID
|
||||
return b
|
||||
}
|
||||
|
||||
// BatchSize sets buffer of message loaded from one request.
|
||||
// Be carefully, when set this limit, because Telegram does not return error if limit is too big,
|
||||
// so results can be incorrect.
|
||||
func (b *GetUserPhotosQueryBuilder) BatchSize(batchSize int) *GetUserPhotosQueryBuilder {
|
||||
b.batchSize = batchSize
|
||||
return b
|
||||
}
|
||||
|
||||
// UserID sets UserID field of GetUserPhotos query.
|
||||
func (b *GetUserPhotosQueryBuilder) UserID(paramUserID tg.InputUserClass) *GetUserPhotosQueryBuilder {
|
||||
b.req.UserID = paramUserID
|
||||
return b
|
||||
}
|
||||
|
||||
// Query implements Query interface.
|
||||
func (b *GetUserPhotosQueryBuilder) Query(ctx context.Context, req Request) (tg.PhotosPhotosClass, error) {
|
||||
r := &tg.PhotosGetUserPhotosRequest{
|
||||
Limit: req.Limit,
|
||||
}
|
||||
|
||||
r.UserID = b.req.UserID
|
||||
r.Offset = req.Offset
|
||||
return b.raw.PhotosGetUserPhotos(ctx, r)
|
||||
}
|
||||
|
||||
// Iter returns iterator using built query.
|
||||
func (b *GetUserPhotosQueryBuilder) Iter() *Iterator {
|
||||
iter := NewIterator(b, b.batchSize)
|
||||
iter = iter.Offset(b.offset)
|
||||
return iter
|
||||
}
|
||||
|
||||
// ForEach calls given callback on each iterator element.
|
||||
func (b *GetUserPhotosQueryBuilder) ForEach(ctx context.Context, cb func(context.Context, Elem) error) error {
|
||||
iter := b.Iter()
|
||||
for iter.Next(ctx) {
|
||||
if err := cb(ctx, iter.Value()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return iter.Err()
|
||||
}
|
||||
|
||||
// Count fetches remote state to get number of elements.
|
||||
func (b *GetUserPhotosQueryBuilder) Count(ctx context.Context) (int, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "get total")
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Collect creates iterator and collects all elements to slice.
|
||||
func (b *GetUserPhotosQueryBuilder) Collect(ctx context.Context) ([]Elem, error) {
|
||||
iter := b.Iter()
|
||||
c, err := iter.Total(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get total")
|
||||
}
|
||||
|
||||
r := make([]Elem, 0, c)
|
||||
for iter.Next(ctx) {
|
||||
r = append(r, iter.Value())
|
||||
}
|
||||
|
||||
return r, iter.Err()
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package photos
|
||||
|
||||
//go:generate go run go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/itergen -result=PhotosPhotosClass -package=photos -prefix=Photos -out=queries.gen.go
|
||||
@@ -0,0 +1,125 @@
|
||||
// Package query contains generic pagination helpers.
|
||||
package query
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/channels/participants"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/contacts/blocked"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/dialogs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/messages"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/messages/stickers/featured"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/photos"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Query is common struct to create query builders.
|
||||
type Query struct {
|
||||
raw *tg.Client
|
||||
}
|
||||
|
||||
// NewQuery creates Query.
|
||||
func NewQuery(raw *tg.Client) *Query {
|
||||
return &Query{raw: raw}
|
||||
}
|
||||
|
||||
// Participants creates participants.QueryBuilder
|
||||
func (q *Query) Participants() *participants.QueryBuilder {
|
||||
return participants.NewQueryBuilder(q.raw)
|
||||
}
|
||||
|
||||
// Blocked creates blocked.QueryBuilder
|
||||
func (q *Query) Blocked() *blocked.QueryBuilder {
|
||||
return blocked.NewQueryBuilder(q.raw)
|
||||
}
|
||||
|
||||
// Photos creates photos.QueryBuilder
|
||||
func (q *Query) Photos() *photos.QueryBuilder {
|
||||
return photos.NewQueryBuilder(q.raw)
|
||||
}
|
||||
|
||||
// Dialogs creates dialogs.QueryBuilder
|
||||
func (q *Query) Dialogs() *dialogs.QueryBuilder {
|
||||
return dialogs.NewQueryBuilder(q.raw)
|
||||
}
|
||||
|
||||
// Messages creates messages.QueryBuilder.
|
||||
func (q *Query) Messages() *messages.QueryBuilder {
|
||||
return messages.NewQueryBuilder(q.raw)
|
||||
}
|
||||
|
||||
// Featured creates featured.QueryBuilder
|
||||
func (q *Query) Featured() *featured.QueryBuilder {
|
||||
return featured.NewQueryBuilder(q.raw)
|
||||
}
|
||||
|
||||
// GetParticipants creates participants.GetParticipantsQueryBuilder.
|
||||
func (q *Query) GetParticipants(channel tg.InputChannelClass) *participants.GetParticipantsQueryBuilder {
|
||||
return participants.NewQueryBuilder(q.raw).GetParticipants(channel)
|
||||
}
|
||||
|
||||
// GetParticipants creates participants.GetParticipantsQueryBuilder.
|
||||
// Shorthand for
|
||||
//
|
||||
// query.NewQuery(raw).GetParticipants(channel)
|
||||
func GetParticipants(raw *tg.Client, channel tg.InputChannelClass) *participants.GetParticipantsQueryBuilder {
|
||||
return NewQuery(raw).GetParticipants(channel)
|
||||
}
|
||||
|
||||
// GetBlocked creates blocked.GetBlockedQueryBuilder.
|
||||
func (q *Query) GetBlocked() *blocked.GetBlockedQueryBuilder {
|
||||
return blocked.NewQueryBuilder(q.raw).GetBlocked()
|
||||
}
|
||||
|
||||
// GetBlocked creates blocked.GetBlockedQueryBuilder.
|
||||
// Shorthand for
|
||||
//
|
||||
// query.NewQuery(raw).GetBlocked()
|
||||
func GetBlocked(raw *tg.Client) *blocked.GetBlockedQueryBuilder {
|
||||
return NewQuery(raw).GetBlocked()
|
||||
}
|
||||
|
||||
// GetUserPhotos creates photos.GetUserPhotosQueryBuilder.
|
||||
func (q *Query) GetUserPhotos(user tg.InputUserClass) *photos.GetUserPhotosQueryBuilder {
|
||||
return photos.NewQueryBuilder(q.raw).GetUserPhotos(user)
|
||||
}
|
||||
|
||||
// GetUserPhotos creates photos.GetUserPhotosQueryBuilder.
|
||||
// Shorthand for
|
||||
//
|
||||
// query.NewQuery(raw).GetUserPhotos(user)
|
||||
func GetUserPhotos(raw *tg.Client, user tg.InputUserClass) *photos.GetUserPhotosQueryBuilder {
|
||||
return NewQuery(raw).GetUserPhotos(user)
|
||||
}
|
||||
|
||||
// GetDialogs creates dialogs.GetDialogsQueryBuilder.
|
||||
func (q *Query) GetDialogs() *dialogs.GetDialogsQueryBuilder {
|
||||
return dialogs.NewQueryBuilder(q.raw).GetDialogs()
|
||||
}
|
||||
|
||||
// Messages creates messages.QueryBuilder.
|
||||
// Shorthand for
|
||||
//
|
||||
// query.NewQuery(raw).Messages()
|
||||
func Messages(raw *tg.Client) *messages.QueryBuilder {
|
||||
return NewQuery(raw).Messages()
|
||||
}
|
||||
|
||||
// GetDialogs creates dialogs.GetDialogsQueryBuilder.
|
||||
// Shorthand for
|
||||
//
|
||||
// query.NewQuery(raw).GetDialogs()
|
||||
func GetDialogs(raw *tg.Client) *dialogs.GetDialogsQueryBuilder {
|
||||
return NewQuery(raw).GetDialogs()
|
||||
}
|
||||
|
||||
// GetOldFeaturedStickers creates featured.QueryBuilder.
|
||||
func (q *Query) GetOldFeaturedStickers() *featured.GetOldFeaturedStickersQueryBuilder {
|
||||
return featured.NewQueryBuilder(q.raw).GetOldFeaturedStickers()
|
||||
}
|
||||
|
||||
// GetOldFeaturedStickers creates featured.QueryBuilder.
|
||||
// Shorthand for
|
||||
//
|
||||
// query.NewQuery(raw).GetOldFeaturedStickers()
|
||||
func GetOldFeaturedStickers(raw *tg.Client) *featured.GetOldFeaturedStickersQueryBuilder {
|
||||
return NewQuery(raw).GetOldFeaturedStickers()
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package query_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/downloader"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/channels/participants"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/dialogs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/messages"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func ExampleQuery_iterAllMessages() {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
client, err := telegram.ClientFromEnvironment(telegram.Options{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// This example iterates over all messages of all dialogs of user and prints them.
|
||||
if err := client.Run(ctx, func(ctx context.Context) error {
|
||||
raw := tg.NewClient(client)
|
||||
cb := func(ctx context.Context, dlg dialogs.Elem) error {
|
||||
// Skip deleted dialogs.
|
||||
if dlg.Deleted() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return dlg.Messages(raw).ForEach(ctx, func(ctx context.Context, elem messages.Elem) error {
|
||||
msg, ok := elem.Msg.(*tg.Message)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
fmt.Println(msg.Message)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return query.GetDialogs(raw).ForEach(ctx, cb)
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleQuery_downloadSaved() {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
client, err := telegram.ClientFromEnvironment(telegram.Options{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// This example downloads all attachments (photo, video, docs, etc.)
|
||||
// from SavedMessages dialog.
|
||||
if err := client.Run(ctx, func(ctx context.Context) error {
|
||||
raw := tg.NewClient(client)
|
||||
d := downloader.NewDownloader()
|
||||
return query.Messages(raw).GetHistory(&tg.InputPeerSelf{}).ForEach(ctx,
|
||||
func(ctx context.Context, elem messages.Elem) error {
|
||||
f, ok := elem.File()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := d.Download(raw, f.Location).ToPath(ctx, f.Name)
|
||||
return err
|
||||
})
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleQuery_getAdmins() {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
client, err := telegram.ClientFromEnvironment(telegram.Options{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// This example iterates over all channels and prints admins.
|
||||
if err := client.Run(ctx, func(ctx context.Context) error {
|
||||
raw := tg.NewClient(client)
|
||||
cb := func(ctx context.Context, dlg dialogs.Elem) error {
|
||||
// Skip deleted dialogs.
|
||||
if dlg.Deleted() {
|
||||
return nil
|
||||
}
|
||||
|
||||
q, ok := dlg.Participants(raw)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return q.ForEach(ctx, func(ctx context.Context, elem participants.Elem) error {
|
||||
user, admin, ok := elem.Admin()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println(user.Username, "admin")
|
||||
if admin.AdminRights.ChangeInfo {
|
||||
fmt.Println("\t+ ChangeInfo")
|
||||
}
|
||||
if admin.AdminRights.PostMessages {
|
||||
fmt.Println("\t+ PostMessages")
|
||||
}
|
||||
if admin.AdminRights.EditMessages {
|
||||
fmt.Println("\t+ EditMessages")
|
||||
}
|
||||
if admin.AdminRights.DeleteMessages {
|
||||
fmt.Println("\t+ DeleteMessages")
|
||||
}
|
||||
if admin.AdminRights.BanUsers {
|
||||
fmt.Println("\t+ BanUsers")
|
||||
}
|
||||
if admin.AdminRights.InviteUsers {
|
||||
fmt.Println("\t+ InviteUsers")
|
||||
}
|
||||
if admin.AdminRights.PinMessages {
|
||||
fmt.Println("\t+ PinMessages")
|
||||
}
|
||||
if admin.AdminRights.AddAdmins {
|
||||
fmt.Println("\t+ AddAdmins")
|
||||
}
|
||||
if admin.AdminRights.Anonymous {
|
||||
fmt.Println("\t+ Anonymous")
|
||||
}
|
||||
if admin.AdminRights.ManageCall {
|
||||
fmt.Println("\t+ ManageCall")
|
||||
}
|
||||
if admin.AdminRights.Other {
|
||||
fmt.Println("\t+ Other")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return query.GetDialogs(raw).ForEach(ctx, cb)
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user