7a04f298d2
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
252 lines
6.0 KiB
Go
252 lines
6.0 KiB
Go
package main
|
|
|
|
import (
|
|
"flag"
|
|
"go/token"
|
|
"go/types"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/go-faster/errors"
|
|
"golang.org/x/tools/go/packages"
|
|
|
|
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
|
)
|
|
|
|
type method struct {
|
|
name string
|
|
f *types.Func
|
|
sig *types.Signature
|
|
reqType types.Type
|
|
resultType types.Type
|
|
|
|
fromRequest []RequestArgument
|
|
params []Param
|
|
}
|
|
|
|
type collector struct {
|
|
ignoreFields map[string]struct{}
|
|
canFillFromRequest map[string]struct{}
|
|
requiredByIter []string
|
|
required map[string]string
|
|
|
|
pkg *packages.Package
|
|
ifaces *genutil.Interfaces
|
|
|
|
iface *types.Interface
|
|
resultTypeName string
|
|
elemName string
|
|
prefix string
|
|
pkgName string
|
|
requestFields []Param
|
|
}
|
|
|
|
// collectorConfig holds the generator settings: the name of the result type
// to match, the element type name, the method-name prefix to trim, and the
// name of the package to generate. fromFlags populates it from the command line.
type collectorConfig struct {
	ResultName, ElemName, Prefix, PkgName string
}
|
|
|
|
func (c *collectorConfig) fromFlags(set *flag.FlagSet) {
|
|
set.StringVar(&c.ResultName, "result", "MessagesMessagesClass", "result type name")
|
|
set.StringVar(&c.ElemName, "elem", "Elem", "element type name")
|
|
set.StringVar(&c.Prefix, "prefix", "Messages", "prefix of methods to trim")
|
|
set.StringVar(&c.PkgName, "package", "messages", "name of package name to generate")
|
|
}
|
|
|
|
func newCollector(pkg *packages.Package, cfg collectorConfig) *collector {
|
|
intGetter := types.NewSignatureType(nil, nil, nil, nil,
|
|
types.NewTuple(types.NewVar(0, nil, "", types.Typ[types.Int])), false) // func() int
|
|
methods := []*types.Func{
|
|
types.NewFunc(token.NoPos, nil, "GetLimit", intGetter),
|
|
}
|
|
match := types.NewInterfaceType(methods, nil).Complete()
|
|
|
|
canFillFromRequest := map[string]struct{}{
|
|
"AddOffset": {},
|
|
"OffsetID": {},
|
|
"OffsetDate": {},
|
|
"OffsetPeer": {},
|
|
"OffsetRate": {},
|
|
"Offset": {},
|
|
}
|
|
ignoreFields := map[string]struct{}{
|
|
// Already handled by match interface.
|
|
"Limit": {},
|
|
// Not real field.
|
|
"Flags": {},
|
|
// Telegram ignores MaxID and MinID sometimes.
|
|
"MaxID": {}, "MinID": {},
|
|
// ExcludePinned used by iterator.
|
|
"ExcludePinned": {},
|
|
// Hash can be used internally, so do not expose it.
|
|
"Hash": {},
|
|
}
|
|
requiredByIter := []string{
|
|
"OffsetID",
|
|
"OffsetDate",
|
|
"Offset",
|
|
}
|
|
required := map[string]string{
|
|
"Peer": "InputPeerClass",
|
|
"Channel": "InputChannelClass",
|
|
"UserID": "InputUserClass",
|
|
}
|
|
|
|
return &collector{
|
|
ignoreFields: ignoreFields,
|
|
canFillFromRequest: canFillFromRequest,
|
|
requiredByIter: requiredByIter,
|
|
required: required,
|
|
ifaces: genutil.NewInterfaces(pkg),
|
|
pkg: pkg,
|
|
iface: match,
|
|
resultTypeName: cfg.ResultName,
|
|
elemName: cfg.ElemName,
|
|
prefix: cfg.Prefix,
|
|
pkgName: cfg.PkgName,
|
|
}
|
|
}
|
|
|
|
func (c *collector) methods() ([]method, error) { // nolint:gocognit
|
|
var result []method
|
|
|
|
for _, def := range genutil.Funcs(c.pkg, func(f genutil.Func) bool {
|
|
return f.Args().Len() == 2 && f.Results().Len() == 2
|
|
}) {
|
|
args := def.Args()
|
|
results := def.Results()
|
|
|
|
ptr, ok := args.At(1).Type().(*types.Pointer)
|
|
if !ok || !types.Implements(ptr, c.iface) {
|
|
continue
|
|
}
|
|
reqType := ptr.Elem()
|
|
|
|
resultType, ok := results.At(0).Type().(*types.Named)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
if resultType.Obj().Name() != c.resultTypeName {
|
|
continue
|
|
}
|
|
name := strings.TrimPrefix(def.Decl.Name(), c.prefix)
|
|
|
|
m := method{
|
|
name: name,
|
|
f: def.Decl,
|
|
sig: def.Sig,
|
|
reqType: reqType,
|
|
resultType: resultType,
|
|
}
|
|
|
|
reqTypeStruct, ok := reqType.Underlying().(*types.Struct)
|
|
if !ok {
|
|
return nil, errors.Errorf("unexpected type %T", reqType.Underlying())
|
|
}
|
|
|
|
for i := 0; i < reqTypeStruct.NumFields(); i++ {
|
|
field := reqTypeStruct.Field(i)
|
|
|
|
if _, ok := c.ignoreFields[field.Name()]; ok {
|
|
continue
|
|
}
|
|
|
|
param := varToParam(field)
|
|
if _, ok := c.canFillFromRequest[field.Name()]; ok {
|
|
requiredByIter := false
|
|
for _, field := range c.requiredByIter {
|
|
if field == param.OriginalName {
|
|
requiredByIter = true
|
|
break
|
|
}
|
|
}
|
|
m.fromRequest = append(m.fromRequest, RequestArgument{
|
|
Arg: param,
|
|
Chain: field.Name() == "OffsetID" || field.Name() == "OffsetDate",
|
|
RequiredByIter: requiredByIter,
|
|
})
|
|
|
|
skip := false
|
|
for _, field := range c.requestFields {
|
|
if field.OriginalName == param.OriginalName {
|
|
skip = true
|
|
break
|
|
}
|
|
}
|
|
if !skip {
|
|
c.requestFields = append(c.requestFields, param)
|
|
}
|
|
continue
|
|
}
|
|
|
|
m.params = append(m.params, param)
|
|
}
|
|
|
|
result = append(result, m)
|
|
}
|
|
|
|
sort.SliceStable(result, func(i, j int) bool {
|
|
return result[i].name < result[j].name
|
|
})
|
|
return result, nil
|
|
}
|
|
|
|
func (c *collector) Config() (Config, error) {
|
|
methods, err := c.collect()
|
|
if err != nil {
|
|
return Config{}, errors.Wrap(err, "collect")
|
|
}
|
|
|
|
return Config{
|
|
Methods: methods,
|
|
Package: c.pkgName,
|
|
ResultName: c.resultTypeName,
|
|
RequestFields: sortParams(c.requestFields),
|
|
}, nil
|
|
}
|
|
|
|
func (c *collector) collect() ([]Method, error) {
|
|
methods, err := c.methods()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "collect types")
|
|
}
|
|
|
|
result := make([]Method, 0, len(methods))
|
|
for _, method := range methods {
|
|
mapping := method.fromRequest
|
|
sort.SliceStable(mapping, func(i, j int) bool {
|
|
return mapping[i].Arg.Name < mapping[j].Arg.Name
|
|
})
|
|
|
|
m := Method{
|
|
Name: method.name,
|
|
OriginalName: method.f.Name(),
|
|
RequestName: genutil.PrintType(method.reqType),
|
|
ResultName: genutil.PrintType(method.resultType),
|
|
AdditionalMapping: mapping,
|
|
AdditionalParams: sortParams(method.params),
|
|
IteratorName: "Iterator",
|
|
ElemName: c.elemName,
|
|
}
|
|
|
|
for _, field := range method.params {
|
|
if _, ok := c.required[field.OriginalName]; ok {
|
|
m.RequiredParams = append(m.RequiredParams, field)
|
|
}
|
|
}
|
|
m.RequiredParams = sortParams(m.RequiredParams)
|
|
|
|
cases, err := c.collectSpecial(m)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "collect special")
|
|
}
|
|
|
|
m.SpecialCase = cases
|
|
result = append(result, m)
|
|
}
|
|
return result, nil
|
|
}
|