move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
@@ -0,0 +1,119 @@
{{ define "header" }}{{- /*gotype: go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/cachedgen.Config*/ -}}
// Code generated by itergen, DO NOT EDIT.
package {{ $.Package }}
import (
"context"
"sync/atomic"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// No-op definition for keeping imports.
var _ = context.Background()
{{ range $query := $.Queries }}
type inner{{ $query.Name }} struct {
// Last received hash.
hash int64
// Last received result.
value *tg.{{ $query.ResultName }}
}
type {{ $query.Name }} struct {
{{- if $query.RequestParams }}
// Query to send.
req *tg.{{ $query.RequestName }}{{ end }}
// Result state.
last atomic.Value
// Reference to RPC client to make requests.
raw *tg.Client
}
// New{{ $query.Name }} creates new {{ $query.Name }}.
func New{{ $query.Name }}(raw *tg.Client, {{- if $query.RequestParams }}initial *tg.{{ $query.RequestName }}{{- end }}) *{{ $query.Name }} {
q := &{{ $query.Name }}{
{{- if $query.RequestParams }}
req: initial,{{ end }}
raw: raw,
}
return q
}
func (s *{{ $query.Name }}) store(v inner{{ $query.Name }}) {
s.last.Store(v)
}
func (s *{{ $query.Name }}) load() (inner{{ $query.Name }}, bool) {
v, ok := s.last.Load().(inner{{ $query.Name }})
return v, ok
}
// Value returns last received result.
// NB: May be nil. Returned {{ $query.ResultName }} must not be mutated.
func (s *{{ $query.Name }}) Value() *tg.{{ $query.ResultName }} {
inner, _ := s.load()
return inner.value
}
// Hash returns last received hash.
func (s *{{ $query.Name }}) Hash() int64 {
inner, _ := s.load()
return inner.hash
}
// Get updates data if needed and returns it.
func (s *{{ $query.Name }}) Get(ctx context.Context) (*tg.{{ $query.ResultName }}, error) {
if _, err := s.Fetch(ctx); err != nil {
return nil, err
}
return s.Value(), nil
}
// Fetch updates data if needed and returns true if data was modified.
func (s *{{ $query.Name }}) Fetch(ctx context.Context) (bool, error) {
lastHash := s.Hash()
{{ if $query.RequestParams -}}
req := s.req
req.Hash = lastHash
{{- else -}}
req := lastHash
{{- end }}
result, err := s.raw.{{ $query.MethodName }}(ctx, req)
if err != nil {
return false, errors.Wrap(err, "execute {{ $query.MethodName }}")
}
switch variant := result.(type) {
case *tg.{{ $query.ResultName }}:
{{ if $query.ManualHash -}}
hash := s.computeHash(variant)
{{- else -}}
hash := variant.Hash
{{- end }}
s.store(inner{{ $query.Name }}{
hash: hash,
value: variant,
})
return true, nil
case *tg.{{ $query.NotModifiedName }}:
if lastHash == 0 {
return false, errors.Errorf("got unexpected %T result", result)
}
return false, nil
default:
return false, errors.Errorf("unexpected type %T", result)
}
}
{{ end }}
{{ end }}
@@ -0,0 +1,143 @@
package main
import (
"go/types"
"sort"
"strings"
"golang.org/x/tools/go/packages"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
)
func isHashField(field *types.Var) bool {
basic, ok := field.Type().(*types.Basic)
if !ok {
return false
}
return basic.Kind() == types.Int64 && field.Name() == "Hash"
}
func hasHashField(st *types.Struct) bool {
for i := 0; i < st.NumFields(); i++ {
if isHashField(st.Field(i)) {
return true
}
}
return false
}
type request struct {
name string
params []Param
}
func isCachedQuery(args *types.Tuple) (request, bool) {
arg := args.At(1)
switch req := arg.Type().(type) {
case *types.Pointer:
named, ok := req.Elem().(*types.Named)
if !ok {
return request{}, false
}
st, ok := named.Underlying().(*types.Struct)
if !ok {
return request{}, false
}
var r []Param
for i := 0; i < st.NumFields(); i++ {
field := st.Field(i)
if strings.Contains(field.Name(), "Offset") {
return request{}, false
}
if isHashField(field) || field.Name() == "Flags" {
continue
}
r = append(r, varToParam(field))
}
return request{
name: named.Obj().Name(),
params: sortParams(r),
}, hasHashField(st)
case *types.Basic:
if req.Kind() != types.Int64 || arg.Name() != "hash" {
return request{}, false
}
return request{}, true
default:
return request{}, false
}
}
func collect(pkg *packages.Package) []CachedQuery {
var r []CachedQuery
for _, def := range genutil.Funcs(pkg, func(f genutil.Func) bool {
return f.Args().Len() == 2 && f.Results().Len() == 2
}) {
args := def.Args()
req, ok := isCachedQuery(args)
if !ok {
continue
}
resultNamed, ok := def.Results().At(0).Type().(*types.Named)
if !ok {
continue
}
result, ok := resultNamed.Underlying().(*types.Interface)
if !ok {
continue
}
impls := genutil.Implementations(pkg, result)
if len(impls) != 2 {
continue
}
var (
notModified *types.Named
pure *types.Named
)
for _, impl := range impls {
if notModified == nil && strings.Contains(impl.Obj().Name(), "NotModified") {
notModified = impl
continue
}
if pure == nil {
pure = impl
}
}
if pure == nil || notModified == nil {
continue
}
pureStruct, ok := pure.Underlying().(*types.Struct)
if !ok {
continue
}
r = append(r, CachedQuery{
Name: def.Decl.Name(),
MethodName: def.Decl.Name(),
RequestName: req.name,
ManualHash: !hasHashField(pureStruct),
RequestParams: req.params,
ResultName: pure.Obj().Name(),
NotModifiedName: notModified.Obj().Name(),
})
}
sort.SliceStable(r, func(i, j int) bool {
return r[i].Name < r[j].Name
})
return r
}
@@ -0,0 +1,65 @@
package main
import (
"context"
"embed"
"flag"
"fmt"
"io"
"os"
"os/signal"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
)
//go:embed _template/*.tmpl
var templates embed.FS
func generate(ctx context.Context, out io.Writer, pkgName string) error {
pkg, err := genutil.Load(ctx, "go.mau.fi/mautrix-telegram/pkg/gotd/tg")
if err != nil {
return errors.Wrap(err, "load")
}
return genutil.WriteTemplate(templates, out, "header", Config{
Queries: collect(pkg),
Package: pkgName,
})
}
func run(ctx context.Context) (err error) {
var out io.Writer = os.Stdout
set := flag.NewFlagSet("gen", flag.ExitOnError)
output := set.String("out", "", "output file")
pkgName := set.String("package", "cached", "name of package name to generate")
if err := set.Parse(os.Args[1:]); err != nil {
return errors.Wrap(err, "parse flags")
}
if *output != "" {
f, err := os.Create(*output)
if err != nil {
return errors.Wrapf(err, "can't create file %q", *output)
}
defer func() {
multierr.AppendInto(&err, f.Close())
}()
out = f
}
return generate(ctx, out, *pkgName)
}
func main() {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
if err := run(ctx); err != nil {
fmt.Println(err)
return
}
}
@@ -0,0 +1,14 @@
package main
import (
"bytes"
"context"
"testing"
)
func TestGenerate(t *testing.T) {
var out bytes.Buffer
if err := generate(context.Background(), &out, "testgen"); err != nil {
t.Fatal(err)
}
}
@@ -0,0 +1,65 @@
package main
import (
"go/types"
"sort"
"strings"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
)
// Param represents request parameter.
type Param struct {
// Name to use in function declaration.
Name string
// OriginalName in struct definition.
OriginalName string
// Go type.
Type string
}
func varToParam(field *types.Var) Param {
fieldName := field.Name()
fieldName = strings.ToLower(fieldName[:1]) + fieldName[1:]
return Param{
Name: fieldName,
OriginalName: field.Name(),
Type: genutil.PrintType(field.Type()),
}
}
func sortParams(p []Param) []Param {
sort.SliceStable(p, func(i, j int) bool {
return p[i].Name < p[j].Name
})
return p
}
// CachedQuery is a RPC cacheable query helper.
type CachedQuery struct {
// Name of struct to generate.
Name string
// MethodName is name of method of tg.Client.
MethodName string
// RequestName is name of request struct.
RequestName string
// ManualHash determines whether hash must be computed using
// hand-written function computeHash or not.
// Need to resolve case when Telegram does not return hash with result.
ManualHash bool
// RequestParams contains additional params to send.
RequestParams []Param
// ResultName is name of result type.
ResultName string
// NotModifiedName is name of NotModified result type.
NotModifiedName string
}
// Config is codegeneration config to use.
type Config struct {
// Query helpers to generate.
Queries []CachedQuery
// ResultName package name
Package string
}
@@ -0,0 +1,2 @@
// Package genutil is a utility package for query helpers codegeneration.
package genutil
@@ -0,0 +1,55 @@
package genutil
import (
"go/types"
"golang.org/x/tools/go/packages"
)
// Func is a function representation.
type Func struct {
Sig *types.Signature
Decl *types.Func
}
// Results returns function results.
func (f Func) Results() *types.Tuple {
return f.Sig.Results()
}
// Args returns function arguments.
func (f Func) Args() *types.Tuple {
return f.Sig.Params()
}
// Funcs collects all function from package using given filter.
// Parameter keep may be nil.
func Funcs(pkg *packages.Package, keep func(f Func) bool) []Func {
var r []Func
for _, def := range pkg.TypesInfo.Defs {
if def == nil {
continue
}
f, ok := def.(*types.Func)
if !ok {
continue
}
sig, ok := f.Type().(*types.Signature)
if !ok {
continue
}
repr := Func{
Sig: sig,
Decl: f,
}
if keep(repr) {
r = append(r, repr)
}
}
return r
}
@@ -0,0 +1,46 @@
package genutil
import (
"context"
"go/ast"
"go/parser"
"go/token"
"os"
"github.com/go-faster/errors"
"golang.org/x/tools/go/packages"
)
func loadPackages(ctx context.Context, dir, pattern string, environ []string) ([]*packages.Package, error) {
return packages.Load(&packages.Config{
Context: ctx,
Dir: dir,
Mode: packages.NeedTypes |
packages.NeedTypesInfo |
packages.NeedTypesSizes |
packages.NeedSyntax |
packages.NeedDeps,
Env: environ,
Fset: token.NewFileSet(),
ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) {
const mode = parser.AllErrors | parser.ParseComments
return parser.ParseFile(fset, filename, src, mode)
},
}, pattern)
}
// Load loads package using given pattern.
func Load(ctx context.Context, pattern string) (*packages.Package, error) {
pkgs, err := loadPackages(ctx, "", pattern, os.Environ())
if err != nil {
return nil, errors.Wrap(err, "load packages")
}
for _, pkg := range pkgs {
if pkg.ID == pattern {
return pkg, nil
}
}
return nil, errors.Errorf("package %s not found", pattern)
}
@@ -0,0 +1,10 @@
package genutil
import "go/types"
// PrintType prints typename into string without package name.
func PrintType(typ types.Type) string {
return types.TypeString(typ, func(i *types.Package) string {
return i.Name()
})
}
@@ -0,0 +1,75 @@
package genutil
import (
"go/types"
"github.com/go-faster/errors"
"golang.org/x/tools/go/packages"
)
// Implementations finds iface implementations.
func Implementations(pkg *packages.Package, iface *types.Interface) []*types.Named {
var r []*types.Named
for _, def := range pkg.TypesInfo.Defs {
if def == nil || !def.Exported() {
continue
}
named, ok := def.Type().(*types.Named)
if !ok {
continue
}
if !types.Implements(types.NewPointer(named), iface) {
continue
}
r = append(r, named)
}
return r
}
// Interfaces is a simple utility struct to find interfaces and implementations.
type Interfaces struct {
pkg *packages.Package
implsCache map[string][]*types.Named
}
// NewInterfaces creates new Interfaces structure.
func NewInterfaces(pkg *packages.Package) *Interfaces {
return &Interfaces{pkg: pkg, implsCache: map[string][]*types.Named{}}
}
// Interface finds interface by name.
func (c *Interfaces) Interface(name string) (*types.Interface, error) {
obj := c.pkg.Types.Scope().Lookup(name)
if obj == nil {
return nil, errors.Errorf("%q not found", name)
}
v, ok := obj.Type().Underlying().(*types.Interface)
if !ok {
return nil, errors.Errorf("%q has unexpected kind type %T", name, obj.Type().Underlying())
}
return v, nil
}
// Implementations finds interface implementations by interface name.
func (c *Interfaces) Implementations(name string) ([]*types.Named, error) {
impls, ok := c.implsCache[name]
if ok {
return impls, nil
}
iface, err := c.Interface(name)
if err != nil {
return nil, errors.Wrapf(err, "find %q", name)
}
impls = Implementations(c.pkg, iface)
c.implsCache[name] = impls
return impls, nil
}
@@ -0,0 +1,35 @@
package genutil
import (
"bytes"
"go/format"
"io"
"io/fs"
"os"
"text/template"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/gen"
)
// WriteTemplate loads template from FS and executes it to given output writer.
func WriteTemplate(source fs.FS, out io.Writer, name string, data interface{}) error {
tmpl := template.New("templates").Funcs(gen.Funcs())
tmpl = template.Must(tmpl.ParseFS(source, "_template/*.tmpl"))
var buf bytes.Buffer
if err := tmpl.ExecuteTemplate(&buf, name, data); err != nil {
return errors.Wrap(err, "template")
}
formatted, err := format.Source(buf.Bytes())
if err != nil {
if _, cpyErr := io.Copy(os.Stdout, &buf); cpyErr != nil {
return errors.Wrapf(cpyErr, "dump generated (original error: %v)", err)
}
return errors.Wrap(err, "format")
}
_, err = out.Write(formatted)
return err
}
@@ -0,0 +1,183 @@
{{ define "header" }}{{- /*gotype: go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/itergen.Config*/ -}}
// Code generated by itergen, DO NOT EDIT.
package {{ $.Package }}
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// No-op definition for keeping imports.
var _ = context.Background()
// Request is a parameter for Query.
type Request struct {
{{- range $arg := $.RequestFields }}
{{ $arg.OriginalName }} {{ $arg.Type }}
{{- end }}
Limit int
}
// Query is an abstraction for {{ $.Package }} request.
// NB: iterator mutates returned data (sorts, at least).
type Query interface {
Query(ctx context.Context, req Request) (tg.{{ $.ResultName }}, error)
}
// QueryFunc is a function adapter for Query.
type QueryFunc func(ctx context.Context, req Request) (tg.{{ $.ResultName }}, error)
// Query implements Query interface.
func (q QueryFunc) Query(ctx context.Context, req Request) (tg.{{ $.ResultName }}, error) {
return q(ctx, req)
}
// QueryBuilder is a helper to create message queries.
type QueryBuilder struct {
raw *tg.Client
}
// NewQueryBuilder creates new QueryBuilder.
func NewQueryBuilder(raw *tg.Client) *QueryBuilder {
return &QueryBuilder{raw: raw}
}
{{ range $method := $.Methods }}
{{ template "query" $method }}
{{- end }}
{{ end }}
{{ define "query" }}{{- /*gotype: go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/gen.Method*/ -}}
// {{ $.Name }}QueryBuilder is query builder of {{ $.OriginalName }}.
type {{ $.Name }}QueryBuilder struct {
raw *tg.Client
req {{ $.RequestName }}
batchSize int
{{- range $mapping := $.AdditionalMapping }}
{{ $mapping.Arg.Name }} {{ $mapping.Arg.Type }}
{{- end }}
}
// {{ $.Name }} creates query builder of {{ $.OriginalName }}.
func (q *QueryBuilder) {{ $.Name }}({{ range $arg := $.RequiredParams }}param{{ $arg.OriginalName }} {{ $arg.Type }},{{ end }}) *{{ $.Name }}QueryBuilder {
b := &{{ $.Name }}QueryBuilder{
raw: q.raw,
batchSize: 1,
req:{{ $.RequestName }}{
{{- range $f := $.AdditionalParams }}{{ if eq ($f.Type) ("tg.MessagesFilterClass") }}
{{ $f.OriginalName }}: &tg.InputMessagesFilterEmpty{},
{{- end }}{{ if eq ($f.Type) ("tg.InputPeerClass") }}
{{ $f.OriginalName }}: &tg.InputPeerEmpty{},
{{- end }}{{ if eq ($f.Type) ("tg.ChannelParticipantsFilterClass") }}
{{ $f.OriginalName }}: &tg.ChannelParticipantsRecent{},
{{- end }}{{ end }}
},
}
{{ range $arg := $.RequiredParams }}
b.req.{{ $arg.OriginalName }} = param{{ $arg.OriginalName }}
{{- end }}
return b
}
// BatchSize sets buffer of message loaded from one request.
// Be carefully, when set this limit, because Telegram does not return error if limit is too big,
// so results can be incorrect.
func (b *{{ $.Name }}QueryBuilder) BatchSize(batchSize int) *{{ $.Name }}QueryBuilder {
b.batchSize = batchSize
return b
}
{{- range $mapping := $.AdditionalMapping }}{{ if $mapping.Chain }}
// {{ $mapping.Arg.OriginalName }} sets {{ $mapping.Arg.Name }} from which iterate start.
func (b *{{ $.Name }}QueryBuilder) {{ $mapping.Arg.OriginalName }}({{ $mapping.Arg.Name }} int) *{{ $.Name }}QueryBuilder {
b.{{ $mapping.Arg.Name }} = {{ $mapping.Arg.Name }}
return b
}
{{- end }}{{- end }}
{{ range $f := $.AdditionalParams }}
// {{ $f.OriginalName }} sets {{ $f.OriginalName }} field of {{ $.Name }} query.
func (b *{{ $.Name }}QueryBuilder) {{ $f.OriginalName }}(param{{ $f.OriginalName }} {{ $f.Type }}) *{{ $.Name }}QueryBuilder {
b.req.{{ $f.OriginalName }} = param{{ $f.OriginalName }}
return b
}
{{ end }}
{{ range $f := $.SpecialCase }}
// {{ $f.ConstructorName }} sets {{ $f.Field.OriginalName }} field of {{ $.Name }} query.
func (b *{{ $.Name }}QueryBuilder) {{ $f.ConstructorName }}({{ range $arg := $f.Args }}param{{ $arg.OriginalName }} {{ $arg.Type }},{{ end }}) *{{ $.Name }}QueryBuilder {
b.req.{{ $f.Field.OriginalName }} = &{{ $f.ConstructorType }}{
{{- range $arg := $f.Args }}
{{ $arg.OriginalName }}: param{{ $arg.OriginalName }},
{{- end }}
}
return b
}
{{ end }}
// Query implements Query interface.
func (b *{{ $.Name }}QueryBuilder) Query(ctx context.Context, req Request) ({{ $.ResultName }}, error) {
r := &{{ $.RequestName }}{
Limit: req.Limit,
}
{{ range $f := $.AdditionalParams }}
r.{{ $f.OriginalName }} = b.req.{{ $f.OriginalName }}
{{- end }}
{{- range $f := $.AdditionalMapping }}
r.{{ $f.Arg.OriginalName }} = req.{{ $f.Arg.OriginalName }}
{{- end }}
return b.raw.{{ $.OriginalName }}(ctx, r)
}
// Iter returns iterator using built query.
func (b *{{ $.Name }}QueryBuilder) Iter() *{{ $.IteratorName }} {
iter := New{{ $.IteratorName }}(b, b.batchSize)
{{- range $mapping := $.AdditionalMapping }}{{ if $mapping.RequiredByIter }}
iter = iter.{{ $mapping.Arg.OriginalName }}(b.{{ $mapping.Arg.Name }})
{{- end }}{{- end }}
return iter
}
// ForEach calls given callback on each iterator element.
func (b *{{ $.Name }}QueryBuilder) ForEach(ctx context.Context, cb func(context.Context, {{ $.ElemName }}) error) error {
iter := b.Iter()
for iter.Next(ctx) {
if err := cb(ctx, iter.Value()); err != nil {
return err
}
}
return iter.Err()
}
// Count fetches remote state to get number of elements.
func (b *{{ $.Name }}QueryBuilder) Count(ctx context.Context) (int, error) {
iter := b.Iter()
c, err := iter.Total(ctx)
if err != nil {
return 0, errors.Wrap(err, "get total")
}
return c, nil
}
// Collect creates iterator and collects all elements to slice.
func (b *{{ $.Name }}QueryBuilder) Collect(ctx context.Context) ([]{{ $.ElemName }}, error) {
iter := b.Iter()
c, err := iter.Total(ctx)
if err != nil {
return nil, errors.Wrap(err, "get total")
}
r := make([]{{ $.ElemName }}, 0, c)
for iter.Next(ctx) {
r = append(r, iter.Value())
}
return r, iter.Err()
}
{{ end }}
@@ -0,0 +1,251 @@
package main
import (
"flag"
"go/token"
"go/types"
"sort"
"strings"
"github.com/go-faster/errors"
"golang.org/x/tools/go/packages"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
)
type method struct {
name string
f *types.Func
sig *types.Signature
reqType types.Type
resultType types.Type
fromRequest []RequestArgument
params []Param
}
type collector struct {
ignoreFields map[string]struct{}
canFillFromRequest map[string]struct{}
requiredByIter []string
required map[string]string
pkg *packages.Package
ifaces *genutil.Interfaces
iface *types.Interface
resultTypeName string
elemName string
prefix string
pkgName string
requestFields []Param
}
type collectorConfig struct {
ResultName string
ElemName string
Prefix string
PkgName string
}
func (c *collectorConfig) fromFlags(set *flag.FlagSet) {
set.StringVar(&c.ResultName, "result", "MessagesMessagesClass", "result type name")
set.StringVar(&c.ElemName, "elem", "Elem", "element type name")
set.StringVar(&c.Prefix, "prefix", "Messages", "prefix of methods to trim")
set.StringVar(&c.PkgName, "package", "messages", "name of package name to generate")
}
func newCollector(pkg *packages.Package, cfg collectorConfig) *collector {
intGetter := types.NewSignatureType(nil, nil, nil, nil,
types.NewTuple(types.NewVar(0, nil, "", types.Typ[types.Int])), false) // func() int
methods := []*types.Func{
types.NewFunc(token.NoPos, nil, "GetLimit", intGetter),
}
match := types.NewInterfaceType(methods, nil).Complete()
canFillFromRequest := map[string]struct{}{
"AddOffset": {},
"OffsetID": {},
"OffsetDate": {},
"OffsetPeer": {},
"OffsetRate": {},
"Offset": {},
}
ignoreFields := map[string]struct{}{
// Already handled by match interface.
"Limit": {},
// Not real field.
"Flags": {},
// Telegram ignores MaxID and MinID sometimes.
"MaxID": {}, "MinID": {},
// ExcludePinned used by iterator.
"ExcludePinned": {},
// Hash can be used internally, so do not expose it.
"Hash": {},
}
requiredByIter := []string{
"OffsetID",
"OffsetDate",
"Offset",
}
required := map[string]string{
"Peer": "InputPeerClass",
"Channel": "InputChannelClass",
"UserID": "InputUserClass",
}
return &collector{
ignoreFields: ignoreFields,
canFillFromRequest: canFillFromRequest,
requiredByIter: requiredByIter,
required: required,
ifaces: genutil.NewInterfaces(pkg),
pkg: pkg,
iface: match,
resultTypeName: cfg.ResultName,
elemName: cfg.ElemName,
prefix: cfg.Prefix,
pkgName: cfg.PkgName,
}
}
func (c *collector) methods() ([]method, error) { // nolint:gocognit
var result []method
for _, def := range genutil.Funcs(c.pkg, func(f genutil.Func) bool {
return f.Args().Len() == 2 && f.Results().Len() == 2
}) {
args := def.Args()
results := def.Results()
ptr, ok := args.At(1).Type().(*types.Pointer)
if !ok || !types.Implements(ptr, c.iface) {
continue
}
reqType := ptr.Elem()
resultType, ok := results.At(0).Type().(*types.Named)
if !ok {
continue
}
if resultType.Obj().Name() != c.resultTypeName {
continue
}
name := strings.TrimPrefix(def.Decl.Name(), c.prefix)
m := method{
name: name,
f: def.Decl,
sig: def.Sig,
reqType: reqType,
resultType: resultType,
}
reqTypeStruct, ok := reqType.Underlying().(*types.Struct)
if !ok {
return nil, errors.Errorf("unexpected type %T", reqType.Underlying())
}
for i := 0; i < reqTypeStruct.NumFields(); i++ {
field := reqTypeStruct.Field(i)
if _, ok := c.ignoreFields[field.Name()]; ok {
continue
}
param := varToParam(field)
if _, ok := c.canFillFromRequest[field.Name()]; ok {
requiredByIter := false
for _, field := range c.requiredByIter {
if field == param.OriginalName {
requiredByIter = true
break
}
}
m.fromRequest = append(m.fromRequest, RequestArgument{
Arg: param,
Chain: field.Name() == "OffsetID" || field.Name() == "OffsetDate",
RequiredByIter: requiredByIter,
})
skip := false
for _, field := range c.requestFields {
if field.OriginalName == param.OriginalName {
skip = true
break
}
}
if !skip {
c.requestFields = append(c.requestFields, param)
}
continue
}
m.params = append(m.params, param)
}
result = append(result, m)
}
sort.SliceStable(result, func(i, j int) bool {
return result[i].name < result[j].name
})
return result, nil
}
func (c *collector) Config() (Config, error) {
methods, err := c.collect()
if err != nil {
return Config{}, errors.Wrap(err, "collect")
}
return Config{
Methods: methods,
Package: c.pkgName,
ResultName: c.resultTypeName,
RequestFields: sortParams(c.requestFields),
}, nil
}
func (c *collector) collect() ([]Method, error) {
methods, err := c.methods()
if err != nil {
return nil, errors.Wrap(err, "collect types")
}
result := make([]Method, 0, len(methods))
for _, method := range methods {
mapping := method.fromRequest
sort.SliceStable(mapping, func(i, j int) bool {
return mapping[i].Arg.Name < mapping[j].Arg.Name
})
m := Method{
Name: method.name,
OriginalName: method.f.Name(),
RequestName: genutil.PrintType(method.reqType),
ResultName: genutil.PrintType(method.resultType),
AdditionalMapping: mapping,
AdditionalParams: sortParams(method.params),
IteratorName: "Iterator",
ElemName: c.elemName,
}
for _, field := range method.params {
if _, ok := c.required[field.OriginalName]; ok {
m.RequiredParams = append(m.RequiredParams, field)
}
}
m.RequiredParams = sortParams(m.RequiredParams)
cases, err := c.collectSpecial(m)
if err != nil {
return nil, errors.Wrap(err, "collect special")
}
m.SpecialCase = cases
result = append(result, m)
}
return result, nil
}
@@ -0,0 +1,69 @@
package main
import (
"context"
"embed"
"flag"
"fmt"
"io"
"os"
"os/signal"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
)
//go:embed _template/*.tmpl
var templates embed.FS
func generate(ctx context.Context, out io.Writer, cfg collectorConfig) error {
pkg, err := genutil.Load(ctx, "go.mau.fi/mautrix-telegram/pkg/gotd/tg")
if err != nil {
return errors.Wrap(err, "load")
}
c := newCollector(pkg, cfg)
config, err := c.Config()
if err != nil {
return errors.Wrap(err, "collect")
}
return genutil.WriteTemplate(templates, out, "header", config)
}
func run(ctx context.Context) (err error) {
var out io.Writer = os.Stdout
set := flag.NewFlagSet("gen", flag.ExitOnError)
output := set.String("out", "", "output file")
cfg := collectorConfig{}
cfg.fromFlags(set)
if err := set.Parse(os.Args[1:]); err != nil {
return errors.Wrap(err, "parse flags")
}
if *output != "" {
f, err := os.Create(*output)
if err != nil {
return errors.Wrapf(err, "can't create file %q", *output)
}
defer func() {
multierr.AppendInto(&err, f.Close())
}()
out = f
}
return generate(ctx, out, cfg)
}
func main() {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
if err := run(ctx); err != nil {
fmt.Println(err)
return
}
}
@@ -0,0 +1,19 @@
package main
import (
"bytes"
"context"
"testing"
)
func TestGenerate(t *testing.T) {
var out bytes.Buffer
if err := generate(context.Background(), &out, collectorConfig{
ResultName: "MessagesMessagesClass",
ElemName: "Elem",
Prefix: "Messages",
PkgName: "messages",
}); err != nil {
t.Fatal(err)
}
}
@@ -0,0 +1,100 @@
package main
import (
"go/types"
"sort"
"strings"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
)
// Param represents request parameter.
type Param struct {
// Name to use in function declaration.
Name string
// OriginalName in struct definition.
OriginalName string
// Go type.
Type string
}
func varToParam(field *types.Var) Param {
fieldName := field.Name()
fieldName = strings.ToLower(fieldName[:1]) + fieldName[1:]
return Param{
Name: fieldName,
OriginalName: field.Name(),
Type: genutil.PrintType(field.Type()),
}
}
func sortParams(p []Param) []Param {
sort.SliceStable(p, func(i, j int) bool {
return p[i].Name < p[j].Name
})
return p
}
// SpecialCaseChain represents special request parameter setter.
type SpecialCaseChain struct {
// ConstructorName to use in function body.
ConstructorName string
// ConstructorType to use in function body.
ConstructorType string
// Field of request struct.
Field Param
// Args is a slice of arguments. May be empty.
Args []Param
}
// RequestArgument represents request parameter passed by iterator.
type RequestArgument struct {
// Arg describes argument.
Arg Param
// Chain is flag to generate builder chain setter.
Chain bool
// RequiredByIter is flag to generate pass to iterator constructor.
RequiredByIter bool
}
// Method is a RPC method.
type Method struct {
// Name to use in function declaration.
Name string
// OriginalName is name of method of tg.Client.
OriginalName string
// RequestName is name of request struct.
RequestName string
// ResultName is name of result type.
ResultName string
// RequiredParams is a required params for query builder.
RequiredParams []Param
// AdditionalMapping is names of field from iterator.
// Some type doesn't have AddOffset for example, so we customize mapping here.
AdditionalMapping []RequestArgument
// SpecialCase is a slice of special case chains.
// Like tg.MessagesFilterClass constructor field setters.
SpecialCase []SpecialCaseChain
// Other parameters of request to pass it in constructor.
AdditionalParams []Param
// IteratorName is name of iterator to build.
IteratorName string
// ElemName is name of iterator elem.
ElemName string
}
// Config is codegeneration config to use.
type Config struct {
// Methods to generate helpers and query builders.
Methods []Method
// ResultName package name
Package string
// ResultName is name of result type.
ResultName string
// RequestFields is a slice of request struct fields.
RequestFields []Param
}
@@ -0,0 +1,90 @@
package main
import (
"go/types"
"sort"
"strings"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
)
func (c *collector) unpackClass(
field Param,
typeName, trimPrefix string,
) ([]SpecialCaseChain, error) {
var r []SpecialCaseChain
if field.Type == "tg."+typeName {
impls, err := c.ifaces.Implementations(typeName)
if err != nil {
return nil, errors.Wrapf(err, "find %q constructors", typeName)
}
for _, impl := range impls {
s, ok := impl.Underlying().(*types.Struct)
if !ok {
continue
}
cse := SpecialCaseChain{
ConstructorName: strings.TrimPrefix(impl.Obj().Name(), trimPrefix),
ConstructorType: genutil.PrintType(impl),
Field: field,
}
if strings.Contains(cse.ConstructorName, "Empty") {
continue
}
for i := 0; i < s.NumFields(); i++ {
field := s.Field(i)
if field.Name() == "Flags" {
continue
}
cse.Args = append(cse.Args, varToParam(field))
}
cse.Args = sortParams(cse.Args)
r = append(r, cse)
}
}
return r, nil
}
func (c *collector) unpackClasses(
field Param,
classes ...[2]string,
) ([]SpecialCaseChain, error) {
var r []SpecialCaseChain
for _, class := range classes {
cases, err := c.unpackClass(field, class[0], class[1])
if err != nil {
return nil, errors.Wrapf(err, "unpack %q", class[0])
}
r = append(r, cases...)
}
return r, nil
}
func (c *collector) collectSpecial(m Method) ([]SpecialCaseChain, error) {
var r []SpecialCaseChain
for _, field := range m.AdditionalParams {
cases, err := c.unpackClasses(field, [][2]string{
{"MessagesFilterClass", "InputMessagesFilter"},
{"ChannelParticipantsFilterClass", "ChannelParticipants"},
}...)
if err != nil {
return nil, err
}
r = append(r, cases...)
}
sort.SliceStable(r, func(i, j int) bool {
return r[i].ConstructorName < r[j].ConstructorName
})
return r, nil
}