move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,119 @@
|
||||
{{ define "header" }}{{- /*gotype: go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/cachedgen.Config*/ -}}
|
||||
// Code generated by itergen, DO NOT EDIT.
|
||||
|
||||
package {{ $.Package }}
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// No-op definition for keeping imports.
|
||||
var _ = context.Background()
|
||||
|
||||
{{ range $query := $.Queries }}
|
||||
|
||||
type inner{{ $query.Name }} struct {
|
||||
// Last received hash.
|
||||
hash int64
|
||||
// Last received result.
|
||||
value *tg.{{ $query.ResultName }}
|
||||
}
|
||||
|
||||
type {{ $query.Name }} struct {
|
||||
{{- if $query.RequestParams }}
|
||||
// Query to send.
|
||||
req *tg.{{ $query.RequestName }}{{ end }}
|
||||
// Result state.
|
||||
last atomic.Value
|
||||
|
||||
// Reference to RPC client to make requests.
|
||||
raw *tg.Client
|
||||
}
|
||||
|
||||
// New{{ $query.Name }} creates new {{ $query.Name }}.
|
||||
func New{{ $query.Name }}(raw *tg.Client, {{- if $query.RequestParams }}initial *tg.{{ $query.RequestName }}{{- end }}) *{{ $query.Name }} {
|
||||
q := &{{ $query.Name }}{
|
||||
{{- if $query.RequestParams }}
|
||||
req: initial,{{ end }}
|
||||
raw: raw,
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
func (s *{{ $query.Name }}) store(v inner{{ $query.Name }}) {
|
||||
s.last.Store(v)
|
||||
}
|
||||
|
||||
func (s *{{ $query.Name }}) load() (inner{{ $query.Name }}, bool) {
|
||||
v, ok := s.last.Load().(inner{{ $query.Name }})
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Value returns last received result.
|
||||
// NB: May be nil. Returned {{ $query.ResultName }} must not be mutated.
|
||||
func (s *{{ $query.Name }}) Value() *tg.{{ $query.ResultName }} {
|
||||
inner, _ := s.load()
|
||||
return inner.value
|
||||
}
|
||||
|
||||
// Hash returns last received hash.
|
||||
func (s *{{ $query.Name }}) Hash() int64 {
|
||||
inner, _ := s.load()
|
||||
return inner.hash
|
||||
}
|
||||
|
||||
// Get updates data if needed and returns it.
|
||||
func (s *{{ $query.Name }}) Get(ctx context.Context) (*tg.{{ $query.ResultName }}, error) {
|
||||
if _, err := s.Fetch(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Value(), nil
|
||||
}
|
||||
|
||||
// Fetch updates data if needed and returns true if data was modified.
|
||||
func (s *{{ $query.Name }}) Fetch(ctx context.Context) (bool, error) {
|
||||
lastHash := s.Hash()
|
||||
|
||||
{{ if $query.RequestParams -}}
|
||||
req := s.req
|
||||
req.Hash = lastHash
|
||||
{{- else -}}
|
||||
req := lastHash
|
||||
{{- end }}
|
||||
result, err := s.raw.{{ $query.MethodName }}(ctx, req)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "execute {{ $query.MethodName }}")
|
||||
}
|
||||
|
||||
switch variant := result.(type) {
|
||||
case *tg.{{ $query.ResultName }}:
|
||||
{{ if $query.ManualHash -}}
|
||||
hash := s.computeHash(variant)
|
||||
{{- else -}}
|
||||
hash := variant.Hash
|
||||
{{- end }}
|
||||
|
||||
s.store(inner{{ $query.Name }}{
|
||||
hash: hash,
|
||||
value: variant,
|
||||
})
|
||||
return true, nil
|
||||
case *tg.{{ $query.NotModifiedName }}:
|
||||
if lastHash == 0 {
|
||||
return false, errors.Errorf("got unexpected %T result", result)
|
||||
}
|
||||
return false, nil
|
||||
default:
|
||||
return false, errors.Errorf("unexpected type %T", result)
|
||||
}
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
{{ end }}
|
||||
@@ -0,0 +1,143 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/tools/go/packages"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
||||
)
|
||||
|
||||
func isHashField(field *types.Var) bool {
|
||||
basic, ok := field.Type().(*types.Basic)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return basic.Kind() == types.Int64 && field.Name() == "Hash"
|
||||
}
|
||||
|
||||
func hasHashField(st *types.Struct) bool {
|
||||
for i := 0; i < st.NumFields(); i++ {
|
||||
if isHashField(st.Field(i)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type request struct {
|
||||
name string
|
||||
params []Param
|
||||
}
|
||||
|
||||
func isCachedQuery(args *types.Tuple) (request, bool) {
|
||||
arg := args.At(1)
|
||||
switch req := arg.Type().(type) {
|
||||
case *types.Pointer:
|
||||
named, ok := req.Elem().(*types.Named)
|
||||
if !ok {
|
||||
return request{}, false
|
||||
}
|
||||
|
||||
st, ok := named.Underlying().(*types.Struct)
|
||||
if !ok {
|
||||
return request{}, false
|
||||
}
|
||||
|
||||
var r []Param
|
||||
for i := 0; i < st.NumFields(); i++ {
|
||||
field := st.Field(i)
|
||||
if strings.Contains(field.Name(), "Offset") {
|
||||
return request{}, false
|
||||
}
|
||||
|
||||
if isHashField(field) || field.Name() == "Flags" {
|
||||
continue
|
||||
}
|
||||
|
||||
r = append(r, varToParam(field))
|
||||
}
|
||||
|
||||
return request{
|
||||
name: named.Obj().Name(),
|
||||
params: sortParams(r),
|
||||
}, hasHashField(st)
|
||||
case *types.Basic:
|
||||
if req.Kind() != types.Int64 || arg.Name() != "hash" {
|
||||
return request{}, false
|
||||
}
|
||||
return request{}, true
|
||||
default:
|
||||
return request{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func collect(pkg *packages.Package) []CachedQuery {
|
||||
var r []CachedQuery
|
||||
|
||||
for _, def := range genutil.Funcs(pkg, func(f genutil.Func) bool {
|
||||
return f.Args().Len() == 2 && f.Results().Len() == 2
|
||||
}) {
|
||||
args := def.Args()
|
||||
req, ok := isCachedQuery(args)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
resultNamed, ok := def.Results().At(0).Type().(*types.Named)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
result, ok := resultNamed.Underlying().(*types.Interface)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
impls := genutil.Implementations(pkg, result)
|
||||
if len(impls) != 2 {
|
||||
continue
|
||||
}
|
||||
var (
|
||||
notModified *types.Named
|
||||
pure *types.Named
|
||||
)
|
||||
for _, impl := range impls {
|
||||
if notModified == nil && strings.Contains(impl.Obj().Name(), "NotModified") {
|
||||
notModified = impl
|
||||
continue
|
||||
}
|
||||
|
||||
if pure == nil {
|
||||
pure = impl
|
||||
}
|
||||
}
|
||||
if pure == nil || notModified == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pureStruct, ok := pure.Underlying().(*types.Struct)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
r = append(r, CachedQuery{
|
||||
Name: def.Decl.Name(),
|
||||
MethodName: def.Decl.Name(),
|
||||
RequestName: req.name,
|
||||
ManualHash: !hasHashField(pureStruct),
|
||||
RequestParams: req.params,
|
||||
ResultName: pure.Obj().Name(),
|
||||
NotModifiedName: notModified.Obj().Name(),
|
||||
})
|
||||
}
|
||||
sort.SliceStable(r, func(i, j int) bool {
|
||||
return r[i].Name < r[j].Name
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
||||
)
|
||||
|
||||
//go:embed _template/*.tmpl
|
||||
var templates embed.FS
|
||||
|
||||
func generate(ctx context.Context, out io.Writer, pkgName string) error {
|
||||
pkg, err := genutil.Load(ctx, "go.mau.fi/mautrix-telegram/pkg/gotd/tg")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "load")
|
||||
}
|
||||
|
||||
return genutil.WriteTemplate(templates, out, "header", Config{
|
||||
Queries: collect(pkg),
|
||||
Package: pkgName,
|
||||
})
|
||||
}
|
||||
|
||||
func run(ctx context.Context) (err error) {
|
||||
var out io.Writer = os.Stdout
|
||||
|
||||
set := flag.NewFlagSet("gen", flag.ExitOnError)
|
||||
output := set.String("out", "", "output file")
|
||||
pkgName := set.String("package", "cached", "name of package name to generate")
|
||||
if err := set.Parse(os.Args[1:]); err != nil {
|
||||
return errors.Wrap(err, "parse flags")
|
||||
}
|
||||
|
||||
if *output != "" {
|
||||
f, err := os.Create(*output)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "can't create file %q", *output)
|
||||
}
|
||||
defer func() {
|
||||
multierr.AppendInto(&err, f.Close())
|
||||
}()
|
||||
out = f
|
||||
}
|
||||
|
||||
return generate(ctx, out, *pkgName)
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
if err := run(ctx); err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
var out bytes.Buffer
|
||||
if err := generate(context.Background(), &out, "testgen"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/internal/genutil"
|
||||
)
|
||||
|
||||
// Param represents request parameter.
|
||||
type Param struct {
|
||||
// Name to use in function declaration.
|
||||
Name string
|
||||
// OriginalName in struct definition.
|
||||
OriginalName string
|
||||
// Go type.
|
||||
Type string
|
||||
}
|
||||
|
||||
func varToParam(field *types.Var) Param {
|
||||
fieldName := field.Name()
|
||||
fieldName = strings.ToLower(fieldName[:1]) + fieldName[1:]
|
||||
return Param{
|
||||
Name: fieldName,
|
||||
OriginalName: field.Name(),
|
||||
Type: genutil.PrintType(field.Type()),
|
||||
}
|
||||
}
|
||||
|
||||
func sortParams(p []Param) []Param {
|
||||
sort.SliceStable(p, func(i, j int) bool {
|
||||
return p[i].Name < p[j].Name
|
||||
})
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// CachedQuery is a RPC cacheable query helper.
|
||||
type CachedQuery struct {
|
||||
// Name of struct to generate.
|
||||
Name string
|
||||
// MethodName is name of method of tg.Client.
|
||||
MethodName string
|
||||
// RequestName is name of request struct.
|
||||
RequestName string
|
||||
// ManualHash determines whether hash must be computed using
|
||||
// hand-written function computeHash or not.
|
||||
// Need to resolve case when Telegram does not return hash with result.
|
||||
ManualHash bool
|
||||
// RequestParams contains additional params to send.
|
||||
RequestParams []Param
|
||||
// ResultName is name of result type.
|
||||
ResultName string
|
||||
// NotModifiedName is name of NotModified result type.
|
||||
NotModifiedName string
|
||||
}
|
||||
|
||||
// Config is codegeneration config to use.
|
||||
type Config struct {
|
||||
// Query helpers to generate.
|
||||
Queries []CachedQuery
|
||||
// ResultName package name
|
||||
Package string
|
||||
}
|
||||
Reference in New Issue
Block a user