Files
mautrix-telegram/pkg/gotd/telegram/query/internal/itergen/collect.go
T
2025-06-27 20:03:37 -07:00

252 lines
6.0 KiB
Go

package main
import (
"flag"
"go/token"
"go/types"
"sort"
"strings"
"github.com/go-faster/errors"
"golang.org/x/tools/go/packages"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
)
// method describes one client method accepted by collector.methods: its
// second argument is a pointer to a request type implementing the collector's
// matcher interface, and its first result is the configured result type.
type method struct {
	// name is the declaration name with the collector prefix trimmed.
	name string
	// f and sig are the method declaration and its signature.
	f   *types.Func
	sig *types.Signature
	// reqType is the request type (element of the pointer argument);
	// resultType is the method's first result type.
	reqType, resultType types.Type
	// fromRequest holds request fields listed in canFillFromRequest.
	fromRequest []RequestArgument
	// params holds the remaining, non-ignored request fields.
	params []Param
}
// collector scans a loaded package for client methods and assembles the
// Config returned by (*collector).Config.
type collector struct {
	// Tables that classify request struct fields by name.
	ignoreFields       map[string]struct{}
	canFillFromRequest map[string]struct{}
	requiredByIter     []string
	required           map[string]string

	// Source package, interface helpers and the request matcher interface.
	pkg    *packages.Package
	ifaces *genutil.Interfaces
	iface  *types.Interface

	// Generation settings copied from collectorConfig.
	resultTypeName, elemName, prefix, pkgName string

	// requestFields accumulates, without duplicates by OriginalName, every
	// request-fillable field seen across all collected methods.
	requestFields []Param
}
// collectorConfig holds the user-tunable settings of a collector.
type collectorConfig struct {
	ResultName, ElemName, Prefix, PkgName string
}

// fromFlags registers one string flag on set for each collectorConfig field,
// storing parsed values directly into c. Defaults target the messages package.
func (c *collectorConfig) fromFlags(set *flag.FlagSet) {
	for _, f := range []struct {
		dst              *string
		name, def, usage string
	}{
		{&c.ResultName, "result", "MessagesMessagesClass", "result type name"},
		{&c.ElemName, "elem", "Elem", "element type name"},
		{&c.Prefix, "prefix", "Messages", "prefix of methods to trim"},
		{&c.PkgName, "package", "messages", "name of package name to generate"},
	} {
		set.StringVar(f.dst, f.name, f.def, f.usage)
	}
}
// newCollector returns a collector over pkg configured by cfg.
//
// Candidate request types are matched against interface{ GetLimit() int }.
// Request struct fields are classified by name:
//   - ignoreFields are skipped entirely;
//   - canFillFromRequest fields become request arguments, and those also in
//     requiredByIter are flagged RequiredByIter;
//   - fields whose name is a key of required are copied into RequiredParams.
func newCollector(pkg *packages.Package, cfg collectorConfig) *collector {
	// Build the matcher interface: interface{ GetLimit() int }.
	limitResult := types.NewTuple(types.NewVar(token.NoPos, nil, "", types.Typ[types.Int]))
	getLimit := types.NewFunc(token.NoPos, nil, "GetLimit",
		types.NewSignatureType(nil, nil, nil, nil, limitResult, false))
	limitMatcher := types.NewInterfaceType([]*types.Func{getLimit}, nil).Complete()

	return &collector{
		ignoreFields: map[string]struct{}{
			// Already handled by match interface.
			"Limit": {},
			// Not real field.
			"Flags": {},
			// Telegram ignores MaxID and MinID sometimes.
			"MaxID": {}, "MinID": {},
			// ExcludePinned used by iterator.
			"ExcludePinned": {},
			// Hash can be used internally, so do not expose it.
			"Hash": {},
		},
		canFillFromRequest: map[string]struct{}{
			"AddOffset":  {},
			"OffsetID":   {},
			"OffsetDate": {},
			"OffsetPeer": {},
			"OffsetRate": {},
			"Offset":     {},
		},
		requiredByIter: []string{"OffsetID", "OffsetDate", "Offset"},
		required: map[string]string{
			"Peer":    "InputPeerClass",
			"Channel": "InputChannelClass",
			"UserID":  "InputUserClass",
		},
		ifaces:         genutil.NewInterfaces(pkg),
		pkg:            pkg,
		iface:          limitMatcher,
		resultTypeName: cfg.ResultName,
		elemName:       cfg.ElemName,
		prefix:         cfg.Prefix,
		pkgName:        cfg.PkgName,
	}
}
// methods scans c.pkg for functions with exactly two arguments and two
// results where the second argument is a pointer to a type implementing
// c.iface and the first result is a named type called c.resultTypeName.
// Each match is returned with its request fields split into fromRequest and
// params (see splitRequestFields). The result is stably sorted by the
// prefix-trimmed name.
//
// As a side effect, request-fillable fields are accumulated in
// c.requestFields.
func (c *collector) methods() ([]method, error) {
	var result []method
	for _, def := range genutil.Funcs(c.pkg, func(f genutil.Func) bool {
		return f.Args().Len() == 2 && f.Results().Len() == 2
	}) {
		ptr, ok := def.Args().At(1).Type().(*types.Pointer)
		if !ok || !types.Implements(ptr, c.iface) {
			continue
		}
		resultType, ok := def.Results().At(0).Type().(*types.Named)
		if !ok || resultType.Obj().Name() != c.resultTypeName {
			continue
		}
		m := method{
			name:       strings.TrimPrefix(def.Decl.Name(), c.prefix),
			f:          def.Decl,
			sig:        def.Sig,
			reqType:    ptr.Elem(),
			resultType: resultType,
		}
		if err := c.splitRequestFields(&m); err != nil {
			return nil, err
		}
		result = append(result, m)
	}
	sort.SliceStable(result, func(i, j int) bool {
		return result[i].name < result[j].name
	})
	return result, nil
}

// splitRequestFields walks the fields of m.reqType's underlying struct in
// declaration order. Fields in c.ignoreFields are skipped, fields in
// c.canFillFromRequest go to m.fromRequest (and are recorded once in
// c.requestFields), and every other field goes to m.params.
func (c *collector) splitRequestFields(m *method) error {
	reqStruct, ok := m.reqType.Underlying().(*types.Struct)
	if !ok {
		return errors.Errorf("unexpected type %T", m.reqType.Underlying())
	}
	for i := 0; i < reqStruct.NumFields(); i++ {
		field := reqStruct.Field(i)
		fieldName := field.Name()
		if _, ignored := c.ignoreFields[fieldName]; ignored {
			continue
		}
		param := varToParam(field)
		if _, fillable := c.canFillFromRequest[fieldName]; !fillable {
			m.params = append(m.params, param)
			continue
		}
		m.fromRequest = append(m.fromRequest, RequestArgument{
			Arg: param,
			// Chain is set only for the OffsetID and OffsetDate fields.
			Chain:          fieldName == "OffsetID" || fieldName == "OffsetDate",
			RequiredByIter: c.fieldRequiredByIter(param.OriginalName),
		})
		c.rememberRequestField(param)
	}
	return nil
}

// fieldRequiredByIter reports whether name is listed in c.requiredByIter.
func (c *collector) fieldRequiredByIter(name string) bool {
	for _, required := range c.requiredByIter {
		if required == name {
			return true
		}
	}
	return false
}

// rememberRequestField appends p to c.requestFields unless a field with the
// same OriginalName is already recorded, so each field appears once across
// all collected methods.
func (c *collector) rememberRequestField(p Param) {
	for _, known := range c.requestFields {
		if known.OriginalName == p.OriginalName {
			return
		}
	}
	c.requestFields = append(c.requestFields, p)
}
// Config runs the collection and returns the generator configuration: the
// collected methods, the target package and result type names, and the
// request fields accumulated during collection (passed through sortParams).
func (c *collector) Config() (Config, error) {
	collected, err := c.collect()
	if err != nil {
		return Config{}, errors.Wrap(err, "collect")
	}
	cfg := Config{
		Methods:    collected,
		Package:    c.pkgName,
		ResultName: c.resultTypeName,
	}
	// Read requestFields only after collect has populated it.
	cfg.RequestFields = sortParams(c.requestFields)
	return cfg, nil
}
// collect converts every method found by methods into a generator Method.
// For each one, its request-filled arguments are stably sorted by Arg.Name,
// its additional and required parameters are passed through sortParams, and
// its special cases are attached from collectSpecial.
func (c *collector) collect() ([]Method, error) {
	found, err := c.methods()
	if err != nil {
		return nil, errors.Wrap(err, "collect types")
	}

	out := make([]Method, 0, len(found))
	for _, src := range found {
		// Sorting in place is fine: src.fromRequest is not read again elsewhere here.
		fromReq := src.fromRequest
		sort.SliceStable(fromReq, func(a, b int) bool {
			return fromReq[a].Arg.Name < fromReq[b].Arg.Name
		})

		gen := Method{
			Name:              src.name,
			OriginalName:      src.f.Name(),
			RequestName:       genutil.PrintType(src.reqType),
			ResultName:        genutil.PrintType(src.resultType),
			AdditionalMapping: fromReq,
			AdditionalParams:  sortParams(src.params),
			IteratorName:      "Iterator",
			ElemName:          c.elemName,
		}

		// Parameters whose original name is a key of c.required are required.
		for _, p := range src.params {
			if _, isRequired := c.required[p.OriginalName]; isRequired {
				gen.RequiredParams = append(gen.RequiredParams, p)
			}
		}
		gen.RequiredParams = sortParams(gen.RequiredParams)

		special, err := c.collectSpecial(gen)
		if err != nil {
			return nil, errors.Wrap(err, "collect special")
		}
		gen.SpecialCase = special

		out = append(out, gen)
	}
	return out, nil
}