move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth/qrlogin"
|
||||
)
|
||||
|
||||
// Auth returns auth client.
|
||||
func (c *Client) Auth() *auth.Client {
|
||||
return auth.NewClient(
|
||||
c.tg, c.rand, c.appID, c.appHash,
|
||||
)
|
||||
}
|
||||
|
||||
// QR returns QR login helper.
|
||||
func (c *Client) QR() qrlogin.QR {
|
||||
return qrlogin.NewQR(
|
||||
c.tg,
|
||||
c.appID,
|
||||
c.appHash,
|
||||
qrlogin.Options{Migrate: c.MigrateTo},
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
// Package auth provides authentication on top of tg.Client.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
// IsKeyUnregistered reports whether err is AUTH_KEY_UNREGISTERED error.
|
||||
//
|
||||
// Deprecated: use IsUnauthorized.
|
||||
func IsKeyUnregistered(err error) bool {
|
||||
return tgerr.Is(err, "AUTH_KEY_UNREGISTERED")
|
||||
}
|
||||
|
||||
// IsUnauthorized reports whether err is any 401 UNAUTHORIZED or is a 406
|
||||
// NOT_ACCEPTABLE with AUTH_KEY_DUPLICATED.
|
||||
//
|
||||
// https://core.telegram.org/api/errors#401-unauthorized
|
||||
// https://core.telegram.org/api/errors#406-not-acceptable
|
||||
func IsUnauthorized(err error) bool {
|
||||
return tgerr.IsCode(err, 401) ||
|
||||
(tgerr.IsCode(err, 406) && tgerr.Is(err, "AUTH_KEY_DUPLICATED"))
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
const (
|
||||
testAppID = 1
|
||||
testAppHash = "hash"
|
||||
)
|
||||
|
||||
func testClient(invoker tg.Invoker) *Client {
|
||||
return &Client{
|
||||
api: tg.NewClient(invoker),
|
||||
rand: rand.Reader,
|
||||
appID: testAppID,
|
||||
appHash: testAppHash,
|
||||
}
|
||||
}
|
||||
|
||||
func mockClient(t *testing.T) (*tgmock.Mock, *Client) {
|
||||
mock := tgmock.New(t)
|
||||
return mock, NewClient(tg.NewClient(mock), testutil.ZeroRand{}, testAppID, testAppHash)
|
||||
}
|
||||
|
||||
func mockTest(cb func(
|
||||
a *require.Assertions,
|
||||
mock *tgmock.Mock,
|
||||
client *Client,
|
||||
)) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
m, client := mockClient(t)
|
||||
|
||||
cb(a, m, client)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Bot performs bot authentication request.
|
||||
func (c *Client) Bot(ctx context.Context, token string) (*tg.AuthAuthorization, error) {
|
||||
auth, err := c.api.AuthImportBotAuthorization(ctx, &tg.AuthImportBotAuthorizationRequest{
|
||||
APIID: c.appID,
|
||||
APIHash: c.appHash,
|
||||
BotAuthToken: token,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err := checkResult(auth)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "check")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func TestClient_AuthBot(t *testing.T) {
|
||||
const token = "12345:token"
|
||||
|
||||
t.Run("AuthAuthorization", func(t *testing.T) {
|
||||
mock := tgmock.New(t)
|
||||
|
||||
testUser := &tg.User{}
|
||||
testUser.SetBot(true)
|
||||
|
||||
mock.ExpectCall(&tg.AuthImportBotAuthorizationRequest{
|
||||
BotAuthToken: token,
|
||||
APIID: testAppID,
|
||||
APIHash: testAppHash,
|
||||
}).ThenResult(&tg.AuthAuthorization{User: testUser})
|
||||
|
||||
result, err := testClient(mock).Bot(context.Background(), token)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testUser, result.User)
|
||||
})
|
||||
|
||||
t.Run("AuthAuthorizationSignUpRequired", func(t *testing.T) {
|
||||
mock := tgmock.New(t)
|
||||
|
||||
mock.ExpectCall(&tg.AuthImportBotAuthorizationRequest{
|
||||
BotAuthToken: token,
|
||||
APIID: testAppID,
|
||||
APIHash: testAppHash,
|
||||
}).ThenResult(&tg.AuthAuthorizationSignUpRequired{})
|
||||
|
||||
result, err := testClient(mock).Bot(context.Background(), token)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Client implements Telegram authentication.
|
||||
type Client struct {
|
||||
api *tg.Client
|
||||
rand io.Reader
|
||||
appID int
|
||||
appHash string
|
||||
}
|
||||
|
||||
// NewClient initializes and returns Telegram authentication client.
|
||||
func NewClient(
|
||||
api *tg.Client,
|
||||
rand io.Reader,
|
||||
appID int,
|
||||
appHash string,
|
||||
) *Client {
|
||||
return &Client{
|
||||
api: api,
|
||||
rand: rand,
|
||||
appID: appID,
|
||||
appHash: appHash,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// NewFlow initializes new authentication flow.
|
||||
func NewFlow(auth UserAuthenticator, opt SendCodeOptions) Flow {
|
||||
return Flow{
|
||||
Auth: auth,
|
||||
Options: opt,
|
||||
}
|
||||
}
|
||||
|
||||
// Flow simplifies boilerplate for authentication flow.
|
||||
type Flow struct {
|
||||
Auth UserAuthenticator
|
||||
Options SendCodeOptions
|
||||
}
|
||||
|
||||
func (f Flow) handleSignUp(ctx context.Context, client FlowClient, phone, hash string, s *SignUpRequired) error {
|
||||
if err := f.Auth.AcceptTermsOfService(ctx, s.TermsOfService); err != nil {
|
||||
return errors.Wrap(err, "confirm TOS")
|
||||
}
|
||||
info, err := f.Auth.SignUp(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "sign up info not provided")
|
||||
}
|
||||
if _, err := client.SignUp(ctx, SignUp{
|
||||
PhoneNumber: phone,
|
||||
PhoneCodeHash: hash,
|
||||
FirstName: info.FirstName,
|
||||
LastName: info.LastName,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "sign up")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run starts authentication flow on client.
|
||||
func (f Flow) Run(ctx context.Context, client FlowClient) error {
|
||||
if f.Auth == nil {
|
||||
return errors.New("no UserAuthenticator provided")
|
||||
}
|
||||
|
||||
phone, err := f.Auth.Phone(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get phone")
|
||||
}
|
||||
|
||||
sentCode, err := client.SendCode(ctx, phone, f.Options)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "send code")
|
||||
}
|
||||
switch s := sentCode.(type) {
|
||||
case *tg.AuthSentCode:
|
||||
hash := s.PhoneCodeHash
|
||||
code, err := f.Auth.Code(ctx, s)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get code")
|
||||
}
|
||||
|
||||
_, signInErr := client.SignIn(ctx, phone, code, hash)
|
||||
if errors.Is(signInErr, ErrPasswordAuthNeeded) {
|
||||
password, err := f.Auth.Password(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get password")
|
||||
}
|
||||
if _, err := client.Password(ctx, password); err != nil {
|
||||
return errors.Wrap(err, "sign in with password")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var signUpRequired *SignUpRequired
|
||||
if errors.As(signInErr, &signUpRequired) {
|
||||
return f.handleSignUp(ctx, client, phone, hash, signUpRequired)
|
||||
}
|
||||
|
||||
if signInErr != nil {
|
||||
return errors.Wrap(signInErr, "sign in")
|
||||
}
|
||||
|
||||
return nil
|
||||
case *tg.AuthSentCodeSuccess:
|
||||
switch a := s.Authorization.(type) {
|
||||
case *tg.AuthAuthorization:
|
||||
// Looks that we are already authorized.
|
||||
return nil
|
||||
case *tg.AuthAuthorizationSignUpRequired:
|
||||
if err := f.handleSignUp(ctx, client, phone, "", &SignUpRequired{
|
||||
TermsOfService: a.TermsOfService,
|
||||
}); err != nil {
|
||||
// TODO: not sure that blank hash will work here
|
||||
return errors.Wrap(err, "sign up after auth sent code success")
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return errors.Errorf("unexpected authorization type: %T", a)
|
||||
}
|
||||
default:
|
||||
return errors.Errorf("unexpected sent code type: %T", sentCode)
|
||||
}
|
||||
}
|
||||
|
||||
// FlowClient abstracts telegram client for Flow.
|
||||
type FlowClient interface {
|
||||
SignIn(ctx context.Context, phone, code, codeHash string) (*tg.AuthAuthorization, error)
|
||||
SendCode(ctx context.Context, phone string, options SendCodeOptions) (tg.AuthSentCodeClass, error)
|
||||
Password(ctx context.Context, password string) (*tg.AuthAuthorization, error)
|
||||
SignUp(ctx context.Context, s SignUp) (*tg.AuthAuthorization, error)
|
||||
}
|
||||
|
||||
// CodeAuthenticator asks user for received authentication code.
|
||||
type CodeAuthenticator interface {
|
||||
Code(ctx context.Context, sentCode *tg.AuthSentCode) (string, error)
|
||||
}
|
||||
|
||||
// CodeAuthenticatorFunc is functional wrapper for CodeAuthenticator.
|
||||
type CodeAuthenticatorFunc func(ctx context.Context, sentCode *tg.AuthSentCode) (string, error)
|
||||
|
||||
// Code implements CodeAuthenticator interface.
|
||||
func (c CodeAuthenticatorFunc) Code(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) {
|
||||
return c(ctx, sentCode)
|
||||
}
|
||||
|
||||
// UserInfo represents user info required for sign up.
|
||||
type UserInfo struct {
|
||||
FirstName string
|
||||
LastName string
|
||||
}
|
||||
|
||||
// UserAuthenticator asks user for phone, password and received authentication code.
|
||||
type UserAuthenticator interface {
|
||||
Phone(ctx context.Context) (string, error)
|
||||
Password(ctx context.Context) (string, error)
|
||||
AcceptTermsOfService(ctx context.Context, tos tg.HelpTermsOfService) error
|
||||
SignUp(ctx context.Context) (UserInfo, error)
|
||||
CodeAuthenticator
|
||||
}
|
||||
|
||||
type noSignUp struct{}
|
||||
|
||||
func (c noSignUp) SignUp(ctx context.Context) (UserInfo, error) {
|
||||
return UserInfo{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (c noSignUp) AcceptTermsOfService(ctx context.Context, tos tg.HelpTermsOfService) error {
|
||||
return &SignUpRequired{TermsOfService: tos}
|
||||
}
|
||||
|
||||
type constantAuth struct {
|
||||
phone, password string
|
||||
CodeAuthenticator
|
||||
noSignUp
|
||||
}
|
||||
|
||||
func (c constantAuth) Phone(ctx context.Context) (string, error) {
|
||||
return c.phone, nil
|
||||
}
|
||||
|
||||
func (c constantAuth) Password(ctx context.Context) (string, error) {
|
||||
return c.password, nil
|
||||
}
|
||||
|
||||
// Constant creates UserAuthenticator with constant phone and password.
|
||||
func Constant(phone, password string, code CodeAuthenticator) UserAuthenticator {
|
||||
return constantAuth{
|
||||
phone: phone,
|
||||
password: password,
|
||||
CodeAuthenticator: code,
|
||||
}
|
||||
}
|
||||
|
||||
type envAuth struct {
|
||||
prefix string
|
||||
CodeAuthenticator
|
||||
noSignUp
|
||||
}
|
||||
|
||||
func (e envAuth) lookup(k string) (string, error) {
|
||||
env := e.prefix + k
|
||||
v, ok := os.LookupEnv(env)
|
||||
if !ok {
|
||||
return "", errors.Errorf("environment variable %q not set", env)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (e envAuth) Phone(ctx context.Context) (string, error) {
|
||||
return e.lookup("PHONE")
|
||||
}
|
||||
|
||||
func (e envAuth) Password(ctx context.Context) (string, error) {
|
||||
p, err := e.lookup("PASSWORD")
|
||||
if err != nil {
|
||||
return "", ErrPasswordNotProvided
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Env creates UserAuthenticator which gets phone and password from environment variables.
|
||||
func Env(prefix string, code CodeAuthenticator) UserAuthenticator {
|
||||
return envAuth{
|
||||
prefix: prefix,
|
||||
CodeAuthenticator: code,
|
||||
noSignUp: noSignUp{},
|
||||
}
|
||||
}
|
||||
|
||||
// ErrPasswordNotProvided means that password requested by Telegram,
|
||||
// but not provided by user.
|
||||
var ErrPasswordNotProvided = errors.New("password requested but not provided")
|
||||
|
||||
type codeOnlyAuth struct {
|
||||
phone string
|
||||
CodeAuthenticator
|
||||
noSignUp
|
||||
}
|
||||
|
||||
func (c codeOnlyAuth) Phone(ctx context.Context) (string, error) {
|
||||
return c.phone, nil
|
||||
}
|
||||
|
||||
func (c codeOnlyAuth) Password(ctx context.Context) (string, error) {
|
||||
return "", ErrPasswordNotProvided
|
||||
}
|
||||
|
||||
// CodeOnly creates UserAuthenticator with constant phone and no password.
|
||||
func CodeOnly(phone string, code CodeAuthenticator) UserAuthenticator {
|
||||
return codeOnlyAuth{
|
||||
phone: phone,
|
||||
CodeAuthenticator: code,
|
||||
}
|
||||
}
|
||||
|
||||
type testAuth struct {
|
||||
dc int
|
||||
phone string
|
||||
}
|
||||
|
||||
func (t testAuth) Phone(ctx context.Context) (string, error) { return t.phone, nil }
|
||||
func (t testAuth) Password(ctx context.Context) (string, error) { return "", ErrPasswordNotProvided }
|
||||
func (t testAuth) Code(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) {
|
||||
type notFlashing interface {
|
||||
GetLength() int
|
||||
}
|
||||
|
||||
length := 5
|
||||
if sentCode != nil {
|
||||
typ, ok := sentCode.Type.(notFlashing)
|
||||
if !ok {
|
||||
return "", errors.Errorf("unexpected type: %T", sentCode.Type)
|
||||
}
|
||||
length = typ.GetLength()
|
||||
}
|
||||
|
||||
return strings.Repeat(strconv.Itoa(t.dc), length), nil
|
||||
}
|
||||
|
||||
func (t testAuth) AcceptTermsOfService(ctx context.Context, tos tg.HelpTermsOfService) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testAuth) SignUp(ctx context.Context) (UserInfo, error) {
|
||||
return UserInfo{
|
||||
FirstName: "Test",
|
||||
LastName: "User",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Test returns UserAuthenticator that authenticates via testing credentials.
|
||||
//
|
||||
// Can be used only with testing server. Will perform sign up if test user is
|
||||
// not registered.
|
||||
func Test(randReader io.Reader, dc int) UserAuthenticator {
|
||||
// 99966XYYYY, X = dc_id, Y = random numbers, code = X repeat 6.
|
||||
// The n value is from 0000 to 9999.
|
||||
n, err := crypto.RandInt64n(randReader, 1000)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
phone := fmt.Sprintf("99966%d%04d", dc, n)
|
||||
|
||||
return TestUser(phone, dc)
|
||||
}
|
||||
|
||||
// TestUser returns UserAuthenticator that authenticates via testing credentials.
|
||||
// Uses given phone to sign in/sign up.
|
||||
//
|
||||
// Can be used only with testing server. Will perform sign up if test user is
|
||||
// not registered.
|
||||
func TestUser(phone string, dc int) UserAuthenticator {
|
||||
return testAuth{
|
||||
dc: dc,
|
||||
phone: phone,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func askCode(code string, err error) auth.CodeAuthenticatorFunc {
|
||||
return func(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) {
|
||||
return code, err
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstantAuth(t *testing.T) {
|
||||
a := require.New(t)
|
||||
authConst := auth.Constant("phone", "password", askCode("123", nil))
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := authConst.Code(ctx, nil)
|
||||
a.NoError(err)
|
||||
a.Equal("123", result)
|
||||
|
||||
result, err = authConst.Phone(ctx)
|
||||
a.NoError(err)
|
||||
a.Equal("phone", result)
|
||||
|
||||
result, err = authConst.Password(ctx)
|
||||
a.NoError(err)
|
||||
a.Equal("password", result)
|
||||
}
|
||||
|
||||
func TestCodeOnlyAuth(t *testing.T) {
|
||||
a := require.New(t)
|
||||
authCodeOnly := auth.CodeOnly("phone", askCode("123", nil))
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := authCodeOnly.Code(ctx, nil)
|
||||
a.NoError(err)
|
||||
a.Equal("123", result)
|
||||
|
||||
result, err = authCodeOnly.Phone(ctx)
|
||||
a.NoError(err)
|
||||
a.Equal("phone", result)
|
||||
|
||||
_, err = authCodeOnly.Password(ctx)
|
||||
a.ErrorIs(err, auth.ErrPasswordNotProvided)
|
||||
}
|
||||
|
||||
func TestEnvAuth(t *testing.T) {
|
||||
a := require.New(t)
|
||||
ctx := context.Background()
|
||||
|
||||
prefix := "TEST_ENV_AUTH_"
|
||||
authEnv := auth.Env(prefix, askCode("123", nil))
|
||||
|
||||
result, err := authEnv.Code(ctx, nil)
|
||||
a.NoError(err)
|
||||
a.Equal("123", result)
|
||||
|
||||
_, err = authEnv.Phone(ctx)
|
||||
a.Error(err)
|
||||
|
||||
_, err = authEnv.Password(ctx)
|
||||
a.ErrorIs(err, auth.ErrPasswordNotProvided)
|
||||
|
||||
// Set envs.
|
||||
testutil.SetEnv(t, prefix+"PHONE", "phone")
|
||||
testutil.SetEnv(t, prefix+"PASSWORD", "password")
|
||||
|
||||
result, err = authEnv.Phone(ctx)
|
||||
a.NoError(err)
|
||||
a.Equal("phone", result)
|
||||
|
||||
result, err = authEnv.Password(ctx)
|
||||
a.NoError(err)
|
||||
a.Equal("password", result)
|
||||
}
|
||||
|
||||
func TestTestAuth(t *testing.T) {
|
||||
a := require.New(t)
|
||||
ctx := context.Background()
|
||||
testAuth := auth.Test(testutil.ZeroRand{}, 2)
|
||||
|
||||
_, err := testAuth.Code(ctx, &tg.AuthSentCode{
|
||||
Type: &tg.AuthSentCodeTypeFlashCall{},
|
||||
})
|
||||
a.Error(err)
|
||||
|
||||
result, err := testAuth.Code(ctx, nil)
|
||||
a.NoError(err)
|
||||
a.Equal("22222", result)
|
||||
|
||||
result, err = testAuth.Code(ctx, &tg.AuthSentCode{
|
||||
Type: &tg.AuthSentCodeTypeApp{
|
||||
Length: 1,
|
||||
},
|
||||
})
|
||||
a.NoError(err)
|
||||
a.Equal("2", result)
|
||||
|
||||
result, err = testAuth.Phone(ctx)
|
||||
a.NoError(err)
|
||||
a.True(strings.HasPrefix(result, "999662"))
|
||||
|
||||
_, err = testAuth.Password(ctx)
|
||||
a.ErrorIs(err, auth.ErrPasswordNotProvided)
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto/srp"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// PasswordHash computes password hash to log in.
|
||||
//
|
||||
// See https://core.telegram.org/api/srp#checking-the-password-with-srp.
|
||||
func PasswordHash(
|
||||
password []byte,
|
||||
srpID int64,
|
||||
srpB, secureRandom []byte,
|
||||
alg tg.PasswordKdfAlgoClass,
|
||||
) (*tg.InputCheckPasswordSRP, error) {
|
||||
s := srp.NewSRP(crypto.DefaultRand())
|
||||
|
||||
algo, ok := alg.(*tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unsupported algo: %T", alg)
|
||||
}
|
||||
|
||||
a, err := s.Hash(password, srpB, secureRandom, srp.Input(*algo))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create SRP answer")
|
||||
}
|
||||
|
||||
return &tg.InputCheckPasswordSRP{
|
||||
SRPID: srpID,
|
||||
A: a.A,
|
||||
M1: a.M1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewPasswordHash computes new password hash to update password.
|
||||
//
|
||||
// Notice that NewPasswordHash mutates given alg.
|
||||
//
|
||||
// See https://core.telegram.org/api/srp#setting-a-new-2fa-password.
|
||||
func NewPasswordHash(
|
||||
password []byte,
|
||||
algo *tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow,
|
||||
) (hash []byte, _ error) {
|
||||
s := srp.NewSRP(crypto.DefaultRand())
|
||||
|
||||
hash, newSalt, err := s.NewHash(password, srp.Input(*algo))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create SRP answer")
|
||||
}
|
||||
algo.Salt1 = newSalt
|
||||
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
var (
|
||||
emptyPassword tg.InputCheckPasswordSRPClass = &tg.InputCheckPasswordEmpty{}
|
||||
)
|
||||
|
||||
// UpdatePasswordOptions is options structure for UpdatePassword.
|
||||
type UpdatePasswordOptions struct {
|
||||
// Hint is new password hint.
|
||||
Hint string
|
||||
// Password is password callback.
|
||||
//
|
||||
// If password was requested and Password is nil, ErrPasswordNotProvided error will be returned.
|
||||
Password func(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
// UpdatePassword sets new cloud password for this account.
|
||||
//
|
||||
// See https://core.telegram.org/api/srp#setting-a-new-2fa-password.
|
||||
func (c *Client) UpdatePassword(
|
||||
ctx context.Context,
|
||||
newPassword string,
|
||||
opts UpdatePasswordOptions,
|
||||
) error {
|
||||
p, err := c.api.AccountGetPassword(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get SRP parameters")
|
||||
}
|
||||
|
||||
algo, ok := p.NewAlgo.(*tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow)
|
||||
if !ok {
|
||||
return errors.Errorf("unsupported algo: %T", p.NewAlgo)
|
||||
}
|
||||
|
||||
newHash, err := NewPasswordHash([]byte(newPassword), algo)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "compute new password hash")
|
||||
}
|
||||
|
||||
var old = emptyPassword
|
||||
if p.HasPassword {
|
||||
if opts.Password == nil {
|
||||
return ErrPasswordNotProvided
|
||||
}
|
||||
|
||||
oldPassword, err := opts.Password(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get password")
|
||||
}
|
||||
|
||||
hash, err := PasswordHash([]byte(oldPassword), p.SRPID, p.SRPB, p.SecureRandom, p.CurrentAlgo)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "compute old password hash")
|
||||
}
|
||||
old = hash
|
||||
}
|
||||
|
||||
if _, err := c.api.AccountUpdatePasswordSettings(ctx, &tg.AccountUpdatePasswordSettingsRequest{
|
||||
Password: old,
|
||||
NewSettings: tg.AccountPasswordInputSettings{
|
||||
NewAlgo: algo,
|
||||
NewPasswordHash: newHash,
|
||||
Hint: opts.Hint,
|
||||
},
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "update password")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetFailedWaitError reports that you recently requested a password reset that was cancel and need to wait until the
|
||||
// specified date before requesting another reset.
|
||||
type ResetFailedWaitError struct {
|
||||
Result tg.AccountResetPasswordFailedWait
|
||||
}
|
||||
|
||||
// Until returns time required to wait.
|
||||
func (r ResetFailedWaitError) Until() time.Duration {
|
||||
retryDate := time.Unix(int64(r.Result.RetryDate), 0)
|
||||
return time.Until(retryDate)
|
||||
}
|
||||
|
||||
// Error implements error.
|
||||
func (r *ResetFailedWaitError) Error() string {
|
||||
return fmt.Sprintf("wait to reset password (%s)", r.Until())
|
||||
}
|
||||
|
||||
// ResetPassword resets cloud password and returns time to wait until reset be performed.
|
||||
// If time is zero, password was successfully reset.
|
||||
//
|
||||
// May return ResetFailedWaitError.
|
||||
//
|
||||
// See https://core.telegram.org/api/srp#password-reset.
|
||||
func (c *Client) ResetPassword(ctx context.Context) (time.Time, error) {
|
||||
r, err := c.api.AccountResetPassword(ctx)
|
||||
if err != nil {
|
||||
return time.Time{}, errors.Wrap(err, "reset password")
|
||||
}
|
||||
switch v := r.(type) {
|
||||
case *tg.AccountResetPasswordFailedWait:
|
||||
return time.Time{}, &ResetFailedWaitError{Result: *v}
|
||||
case *tg.AccountResetPasswordRequestedWait:
|
||||
return time.Unix(int64(v.UntilDate), 0), nil
|
||||
case *tg.AccountResetPasswordOk:
|
||||
return time.Time{}, nil
|
||||
default:
|
||||
return time.Time{}, errors.Errorf("unexpected type %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
// CancelPasswordReset cancels password reset.
|
||||
//
|
||||
// See https://core.telegram.org/api/srp#password-reset.
|
||||
func (c *Client) CancelPasswordReset(ctx context.Context) error {
|
||||
if _, err := c.api.AccountDeclinePasswordReset(ctx); err != nil {
|
||||
return errors.Wrap(err, "cancel password reset")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
|
||||
)
|
||||
|
||||
func ExampleClient_UpdatePassword() {
|
||||
ctx := context.Background()
|
||||
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{})
|
||||
if err := client.Run(ctx, func(ctx context.Context) error {
|
||||
// Updating password.
|
||||
if err := client.Auth().UpdatePassword(ctx, "new_password", auth.UpdatePasswordOptions{
|
||||
// Hint sets new password hint.
|
||||
Hint: "new password hint",
|
||||
// Password will be called if old password is requested by Telegram.
|
||||
//
|
||||
// If password was requested and Password is nil, auth.ErrPasswordNotProvided error will be returned.
|
||||
Password: func(ctx context.Context) (string, error) {
|
||||
return "old_password", nil
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleClient_ResetPassword() {
|
||||
ctx := context.Background()
|
||||
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{})
|
||||
if err := client.Run(ctx, func(ctx context.Context) error {
|
||||
wait, err := client.Auth().ResetPassword(ctx)
|
||||
var waitErr *auth.ResetFailedWaitError
|
||||
switch {
|
||||
case errors.As(err, &waitErr):
|
||||
// Telegram requested wait until making new reset request.
|
||||
fmt.Printf("Wait until %s to reset password.\n", wait.String())
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
// If returned time is zero, password was successfully reset.
|
||||
if wait.IsZero() {
|
||||
fmt.Println("Password was reset.")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Password will be reset on %s.\n", wait.String())
|
||||
return nil
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func TestPasswordHash(t *testing.T) {
|
||||
a := require.New(t)
|
||||
_, err := PasswordHash(nil, 0, nil, nil, nil)
|
||||
a.Error(err, "unsupported algo")
|
||||
}
|
||||
|
||||
var testAlgo = &tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow{
|
||||
Salt1: []uint8{
|
||||
230, 200, 149, 125, 223, 152, 141, 72,
|
||||
},
|
||||
Salt2: []uint8{
|
||||
159, 99, 68, 130, 43, 9, 108, 255, 135, 239, 164, 38, 245, 120, 87, 182,
|
||||
},
|
||||
G: 3,
|
||||
P: []uint8{
|
||||
199, 28, 174, 185, 198, 177, 201, 4, 142, 108, 82, 47, 112, 241, 63, 115,
|
||||
152, 13, 64, 35, 142, 62, 33, 193, 73, 52, 208, 55, 86, 61, 147, 15,
|
||||
72, 25, 138, 10, 167, 193, 64, 88, 34, 148, 147, 210, 37, 48, 244, 219,
|
||||
250, 51, 111, 110, 10, 201, 37, 19, 149, 67, 174, 212, 76, 206, 124, 55,
|
||||
32, 253, 81, 246, 148, 88, 112, 90, 198, 140, 212, 254, 107, 107, 19, 171,
|
||||
220, 151, 70, 81, 41, 105, 50, 132, 84, 241, 143, 175, 140, 89, 95, 100,
|
||||
36, 119, 254, 150, 187, 42, 148, 29, 91, 205, 29, 74, 200, 204, 73, 136,
|
||||
7, 8, 250, 155, 55, 142, 60, 79, 58, 144, 96, 190, 230, 124, 249, 164,
|
||||
164, 166, 149, 129, 16, 81, 144, 126, 22, 39, 83, 181, 107, 15, 107, 65,
|
||||
13, 186, 116, 216, 168, 75, 42, 20, 179, 20, 78, 14, 241, 40, 71, 84,
|
||||
253, 23, 237, 149, 13, 89, 101, 180, 185, 221, 70, 88, 45, 177, 23, 141,
|
||||
22, 156, 107, 196, 101, 176, 214, 255, 156, 163, 146, 143, 239, 91, 154, 228,
|
||||
228, 24, 252, 21, 232, 62, 190, 160, 248, 127, 169, 255, 94, 237, 112, 5,
|
||||
13, 237, 40, 73, 244, 123, 249, 89, 217, 86, 133, 12, 233, 41, 133, 31,
|
||||
13, 129, 21, 246, 53, 177, 5, 238, 46, 78, 21, 208, 75, 36, 84, 191,
|
||||
111, 79, 173, 240, 52, 177, 4, 3, 17, 156, 216, 227, 185, 47, 204, 91,
|
||||
},
|
||||
}
|
||||
|
||||
func TestClient_UpdatePassword(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
expectCall := func(a *require.Assertions, m *tgmock.Mock, hasPassword bool) *tgmock.RequestBuilder {
|
||||
p := &tg.AccountPassword{
|
||||
HasPassword: hasPassword,
|
||||
NewAlgo: testAlgo,
|
||||
NewSecureAlgo: &tg.SecurePasswordKdfAlgoUnknown{},
|
||||
}
|
||||
if hasPassword {
|
||||
p.CurrentAlgo = testAlgo
|
||||
}
|
||||
p.SetFlags()
|
||||
return m.ExpectCall(&tg.AccountGetPasswordRequest{}).
|
||||
ThenResult(p).ExpectFunc(func(b bin.Encoder) {
|
||||
a.IsType(&tg.AccountUpdatePasswordSettingsRequest{}, b)
|
||||
r := b.(*tg.AccountUpdatePasswordSettingsRequest)
|
||||
|
||||
if !hasPassword {
|
||||
a.Equal(emptyPassword, r.Password)
|
||||
} else {
|
||||
a.NotEqual(emptyPassword, r.Password)
|
||||
}
|
||||
a.NotEmpty(r.NewSettings.NewPasswordHash)
|
||||
a.Equal("hint", r.NewSettings.Hint)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("PasswordNotRequired", mockTest(func(
|
||||
a *require.Assertions,
|
||||
m *tgmock.Mock,
|
||||
client *Client,
|
||||
) {
|
||||
m.ExpectCall(&tg.AccountGetPasswordRequest{}).ThenErr(testutil.TestError())
|
||||
a.Error(client.UpdatePassword(ctx, "", UpdatePasswordOptions{}))
|
||||
|
||||
expectCall(a, m, false).ThenTrue()
|
||||
a.NoError(client.UpdatePassword(ctx, "", UpdatePasswordOptions{
|
||||
Hint: "hint",
|
||||
}))
|
||||
}))
|
||||
|
||||
t.Run("PasswordRequired", mockTest(func(
|
||||
a *require.Assertions,
|
||||
m *tgmock.Mock,
|
||||
client *Client,
|
||||
) {
|
||||
m.ExpectCall(&tg.AccountGetPasswordRequest{}).
|
||||
ThenResult(&tg.AccountPassword{
|
||||
HasPassword: true,
|
||||
NewAlgo: testAlgo,
|
||||
CurrentAlgo: testAlgo,
|
||||
NewSecureAlgo: &tg.SecurePasswordKdfAlgoUnknown{},
|
||||
})
|
||||
a.ErrorIs(client.UpdatePassword(ctx, "", UpdatePasswordOptions{}), ErrPasswordNotProvided)
|
||||
|
||||
m.ExpectCall(&tg.AccountGetPasswordRequest{}).
|
||||
ThenResult(&tg.AccountPassword{
|
||||
HasPassword: true,
|
||||
NewAlgo: testAlgo,
|
||||
CurrentAlgo: testAlgo,
|
||||
NewSecureAlgo: &tg.SecurePasswordKdfAlgoUnknown{},
|
||||
})
|
||||
a.ErrorIs(client.UpdatePassword(ctx, "", UpdatePasswordOptions{
|
||||
Hint: "hint",
|
||||
Password: func(ctx context.Context) (string, error) {
|
||||
return "", testutil.TestError()
|
||||
},
|
||||
}), testutil.TestError())
|
||||
|
||||
expectCall(a, m, true).ThenTrue()
|
||||
a.NoError(client.UpdatePassword(ctx, "", UpdatePasswordOptions{
|
||||
Hint: "hint",
|
||||
Password: func(ctx context.Context) (string, error) {
|
||||
return "password", nil
|
||||
},
|
||||
}))
|
||||
}))
|
||||
}
|
||||
|
||||
func TestClient_ResetPassword(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
wait := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC).Unix()
|
||||
mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) {
|
||||
mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenErr(testutil.TestError())
|
||||
_, err := client.ResetPassword(ctx)
|
||||
a.Error(err)
|
||||
|
||||
mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenResult(&tg.AccountResetPasswordFailedWait{
|
||||
RetryDate: int(wait),
|
||||
})
|
||||
var waitErr *ResetFailedWaitError
|
||||
_, err = client.ResetPassword(ctx)
|
||||
a.ErrorAs(err, &waitErr)
|
||||
a.Equal(int(wait), waitErr.Result.RetryDate)
|
||||
a.NotEmpty(waitErr.Error())
|
||||
|
||||
mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenResult(&tg.AccountResetPasswordOk{})
|
||||
r, err := client.ResetPassword(ctx)
|
||||
a.NoError(err)
|
||||
a.True(r.IsZero())
|
||||
|
||||
mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenResult(&tg.AccountResetPasswordRequestedWait{
|
||||
UntilDate: int(wait),
|
||||
})
|
||||
r, err = client.ResetPassword(ctx)
|
||||
a.NoError(err)
|
||||
a.False(r.IsZero())
|
||||
})(t)
|
||||
}
|
||||
|
||||
func TestClient_CancelPasswordReset(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) {
|
||||
mock.ExpectCall(&tg.AccountDeclinePasswordResetRequest{}).ThenErr(testutil.TestError())
|
||||
a.Error(client.CancelPasswordReset(ctx))
|
||||
|
||||
mock.ExpectCall(&tg.AccountDeclinePasswordResetRequest{}).ThenTrue()
|
||||
a.NoError(client.CancelPasswordReset(ctx))
|
||||
})(t)
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package qrlogin
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// AcceptQR accepts given token.
|
||||
//
|
||||
// See https://core.telegram.org/api/qr-login#accepting-a-login-token.
|
||||
func AcceptQR(ctx context.Context, raw *tg.Client, t Token) (*tg.Authorization, error) {
|
||||
auth, err := raw.AuthAcceptLoginToken(ctx, t.token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "accept")
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package qrlogin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// MigrationNeededError reports that Telegram requested DC migration to continue login.
|
||||
type MigrationNeededError struct {
|
||||
MigrateTo *tg.AuthLoginTokenMigrateTo
|
||||
|
||||
// Tried indicates that the migration was attempted.
|
||||
//
|
||||
// Deprecated: do not use. QR login uses migrate function passed via
|
||||
// options.
|
||||
Tried bool
|
||||
}
|
||||
|
||||
// Error implements error.
|
||||
func (m *MigrationNeededError) Error() string {
|
||||
return fmt.Sprintf("migration to %d needed", m.MigrateTo.DCID)
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package qrlogin
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
)
|
||||
|
||||
// Options of QR.
|
||||
type Options struct {
|
||||
Migrate func(ctx context.Context, dcID int) error
|
||||
Clock clock.Clock
|
||||
}
|
||||
|
||||
func (o *Options) setDefaults() {
|
||||
// It's okay to use zero value Migrate.
|
||||
if o.Clock == nil {
|
||||
o.Clock = clock.System
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,175 @@
|
||||
// Package qrlogin provides QR login flow implementation.
|
||||
//
|
||||
// See https://core.telegram.org/api/qr-login.
|
||||
package qrlogin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// QR implements Telegram QR login flow.
|
||||
type QR struct {
|
||||
api *tg.Client
|
||||
appID int
|
||||
appHash string
|
||||
migrate func(ctx context.Context, dcID int) error
|
||||
clock clock.Clock
|
||||
}
|
||||
|
||||
// NewQR creates new QR
|
||||
func NewQR(api *tg.Client, appID int, appHash string, opts Options) QR {
|
||||
opts.setDefaults()
|
||||
return QR{
|
||||
api: api,
|
||||
appID: appID,
|
||||
appHash: appHash,
|
||||
clock: opts.Clock,
|
||||
migrate: opts.Migrate,
|
||||
}
|
||||
}
|
||||
|
||||
// Export exports new login token.
|
||||
//
|
||||
// See https://core.telegram.org/api/qr-login#exporting-a-login-token.
|
||||
func (q QR) Export(ctx context.Context, exceptIDs ...int64) (Token, error) {
|
||||
result, err := q.api.AuthExportLoginToken(ctx, &tg.AuthExportLoginTokenRequest{
|
||||
APIID: q.appID,
|
||||
APIHash: q.appHash,
|
||||
ExceptIDs: exceptIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(err, "export")
|
||||
}
|
||||
|
||||
t, ok := result.(*tg.AuthLoginToken)
|
||||
if !ok {
|
||||
return Token{}, errors.Errorf("unexpected type %T", result)
|
||||
}
|
||||
return NewToken(t.Token, t.Expires), nil
|
||||
}
|
||||
|
||||
// Accept accepts given token.
|
||||
//
|
||||
// See https://core.telegram.org/api/qr-login#accepting-a-login-token.
|
||||
func (q QR) Accept(ctx context.Context, t Token) (*tg.Authorization, error) {
|
||||
return AcceptQR(ctx, q.api, t)
|
||||
}
|
||||
|
||||
// Import imports accepted token.
|
||||
//
|
||||
// See https://core.telegram.org/api/qr-login#confirming-importing-the-login-token.
|
||||
func (q QR) Import(ctx context.Context) (*tg.AuthAuthorization, error) {
|
||||
result, err := q.api.AuthExportLoginToken(ctx, &tg.AuthExportLoginTokenRequest{
|
||||
APIID: q.appID,
|
||||
APIHash: q.appHash,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "import")
|
||||
}
|
||||
|
||||
switch t := result.(type) {
|
||||
case *tg.AuthLoginTokenMigrateTo:
|
||||
if q.migrate == nil {
|
||||
return nil, &MigrationNeededError{
|
||||
MigrateTo: t,
|
||||
}
|
||||
}
|
||||
if err := q.migrate(ctx, t.DCID); err != nil {
|
||||
return nil, errors.Wrap(err, "migrate")
|
||||
}
|
||||
|
||||
res, err := q.api.AuthImportLoginToken(ctx, t.Token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "import")
|
||||
}
|
||||
|
||||
success, ok := res.(*tg.AuthLoginTokenSuccess)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unexpected type %T", res)
|
||||
}
|
||||
|
||||
auth, ok := success.Authorization.(*tg.AuthAuthorization)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unexpected type %T", success.Authorization)
|
||||
}
|
||||
return auth, nil
|
||||
case *tg.AuthLoginTokenSuccess:
|
||||
auth, ok := t.Authorization.(*tg.AuthAuthorization)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unexpected type %T", t.Authorization)
|
||||
}
|
||||
return auth, nil
|
||||
default:
|
||||
return nil, errors.Errorf("unexpected type %T", result)
|
||||
}
|
||||
}
|
||||
|
||||
// LoggedIn is signal channel to notify about tg.UpdateLoginToken.
|
||||
type LoggedIn <-chan struct{}
|
||||
|
||||
// OnLoginToken sets handler for given dispatcher and returns signal channel.
|
||||
func OnLoginToken(d interface {
|
||||
OnLoginToken(tg.LoginTokenHandler)
|
||||
},
|
||||
) LoggedIn {
|
||||
loggedIn := make(chan struct{})
|
||||
d.OnLoginToken(func(ctx context.Context, e tg.Entities, update *tg.UpdateLoginToken) error {
|
||||
select {
|
||||
case loggedIn <- struct{}{}:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return loggedIn
|
||||
}
|
||||
|
||||
// Auth generates new QR login token, shows it and awaits acceptation.
|
||||
//
|
||||
// NB: Show callback may be called more than once if QR expires.
|
||||
func (q QR) Auth(
|
||||
ctx context.Context,
|
||||
loggedIn LoggedIn,
|
||||
show func(ctx context.Context, token Token) error,
|
||||
exceptIDs ...int64,
|
||||
) (*tg.AuthAuthorization, error) {
|
||||
until := func(token Token) time.Duration {
|
||||
return token.Expires().Sub(q.clock.Now()).Truncate(time.Second)
|
||||
}
|
||||
|
||||
token, err := q.Export(ctx, exceptIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
timer := q.clock.Timer(until(token))
|
||||
defer clock.StopTimer(timer)
|
||||
|
||||
for {
|
||||
if err := show(ctx, token); err != nil {
|
||||
return nil, errors.Wrap(err, "show")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-timer.C():
|
||||
t, err := q.Export(ctx, exceptIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token = t
|
||||
timer.Reset(until(token))
|
||||
|
||||
continue
|
||||
case <-loggedIn:
|
||||
}
|
||||
|
||||
return q.Import(ctx)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
package qrlogin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/gotd/neo"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/constant"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func testQR(t *testing.T, migrate func(ctx context.Context, dcID int) error) (*tgmock.Mock, QR) {
|
||||
mock := tgmock.New(t)
|
||||
return mock, NewQR(tg.NewClient(mock), constant.TestAppID, constant.TestAppHash, Options{
|
||||
Migrate: migrate,
|
||||
})
|
||||
}
|
||||
|
||||
var testToken = NewToken([]byte("token"), 0)
|
||||
|
||||
func TestQR_Export(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
a := require.New(t)
|
||||
mock, qr := testQR(t, nil)
|
||||
|
||||
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
|
||||
APIID: constant.TestAppID,
|
||||
APIHash: constant.TestAppHash,
|
||||
ExceptIDs: []int64{0},
|
||||
}).ThenResult(&tg.AuthLoginToken{
|
||||
Expires: 0,
|
||||
Token: testToken.token,
|
||||
})
|
||||
result, err := qr.Export(ctx, 0)
|
||||
a.NoError(err)
|
||||
a.Equal(Token{
|
||||
token: testToken.token,
|
||||
expires: time.Unix(0, 0),
|
||||
}, result)
|
||||
|
||||
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
|
||||
APIID: constant.TestAppID,
|
||||
APIHash: constant.TestAppHash,
|
||||
}).ThenResult(&tg.AuthLoginTokenMigrateTo{
|
||||
Token: testToken.token,
|
||||
})
|
||||
_, err = qr.Export(ctx)
|
||||
a.Error(err)
|
||||
|
||||
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
|
||||
APIID: constant.TestAppID,
|
||||
APIHash: constant.TestAppHash,
|
||||
}).ThenErr(testutil.TestError())
|
||||
_, err = qr.Export(ctx)
|
||||
a.ErrorIs(err, testutil.TestError())
|
||||
}
|
||||
|
||||
func TestQR_Accept(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
a := require.New(t)
|
||||
mock, qr := testQR(t, nil)
|
||||
|
||||
auth := &tg.Authorization{
|
||||
APIID: 1,
|
||||
}
|
||||
mock.ExpectCall(&tg.AuthAcceptLoginTokenRequest{
|
||||
Token: testToken.token,
|
||||
}).ThenResult(auth)
|
||||
result, err := qr.Accept(ctx, testToken)
|
||||
a.NoError(err)
|
||||
a.Equal(auth, result)
|
||||
|
||||
mock.ExpectCall(&tg.AuthAcceptLoginTokenRequest{
|
||||
Token: testToken.token,
|
||||
}).ThenErr(testutil.TestError())
|
||||
_, err = qr.Accept(ctx, testToken)
|
||||
a.ErrorIs(err, testutil.TestError())
|
||||
}
|
||||
|
||||
func TestQR_Import(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
a := require.New(t)
|
||||
mock, qr := testQR(t, nil)
|
||||
|
||||
auth := &tg.AuthAuthorization{
|
||||
User: &tg.User{ID: 10},
|
||||
}
|
||||
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
|
||||
APIID: constant.TestAppID,
|
||||
APIHash: constant.TestAppHash,
|
||||
}).ThenResult(&tg.AuthLoginTokenSuccess{
|
||||
Authorization: auth,
|
||||
})
|
||||
result, err := qr.Import(ctx)
|
||||
a.NoError(err)
|
||||
a.Equal(auth, result)
|
||||
|
||||
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
|
||||
APIID: constant.TestAppID,
|
||||
APIHash: constant.TestAppHash,
|
||||
}).ThenResult(&tg.AuthLoginTokenMigrateTo{
|
||||
DCID: 1,
|
||||
})
|
||||
_, err = qr.Import(ctx)
|
||||
var mig *MigrationNeededError
|
||||
a.ErrorAs(err, &mig)
|
||||
a.Equal(1, mig.MigrateTo.DCID)
|
||||
|
||||
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
|
||||
APIID: constant.TestAppID,
|
||||
APIHash: constant.TestAppHash,
|
||||
}).ThenResult(&tg.AuthLoginToken{
|
||||
Token: testToken.token,
|
||||
})
|
||||
_, err = qr.Import(ctx)
|
||||
a.Error(err)
|
||||
|
||||
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
|
||||
APIID: constant.TestAppID,
|
||||
APIHash: constant.TestAppHash,
|
||||
}).ThenErr(testutil.TestError())
|
||||
_, err = qr.Import(ctx)
|
||||
a.ErrorIs(err, testutil.TestError())
|
||||
}
|
||||
|
||||
func TestQR_Auth(t *testing.T) {
|
||||
a := require.New(t)
|
||||
mock := tgmock.New(t)
|
||||
clock := neo.NewTime(time.Now())
|
||||
|
||||
auth := &tg.AuthAuthorization{
|
||||
User: &tg.User{ID: 10},
|
||||
}
|
||||
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
|
||||
APIID: constant.TestAppID,
|
||||
APIHash: constant.TestAppHash,
|
||||
}).ThenResult(&tg.AuthLoginToken{
|
||||
Expires: int(clock.Now().Add(time.Minute).Unix()),
|
||||
Token: testToken.token,
|
||||
}).ExpectCall(&tg.AuthExportLoginTokenRequest{
|
||||
APIID: constant.TestAppID,
|
||||
APIHash: constant.TestAppHash,
|
||||
}).ThenResult(&tg.AuthLoginToken{
|
||||
Expires: int(clock.Now().Add(2 * time.Minute).Unix()),
|
||||
Token: testToken.token,
|
||||
}).ExpectCall(&tg.AuthExportLoginTokenRequest{
|
||||
APIID: constant.TestAppID,
|
||||
APIHash: constant.TestAppHash,
|
||||
}).ThenResult(&tg.AuthLoginTokenSuccess{
|
||||
Authorization: auth,
|
||||
})
|
||||
|
||||
qr := NewQR(tg.NewClient(mock), constant.TestAppID, constant.TestAppHash, Options{
|
||||
Clock: clock,
|
||||
})
|
||||
|
||||
show := make(chan struct{})
|
||||
done := make(chan error)
|
||||
loggedIn := make(chan struct{})
|
||||
go func() {
|
||||
_, err := qr.Auth(context.Background(), loggedIn, func(ctx context.Context, token Token) error {
|
||||
show <- struct{}{}
|
||||
return nil
|
||||
})
|
||||
done <- err
|
||||
}()
|
||||
|
||||
// Show QR first time.
|
||||
<-show
|
||||
|
||||
// Skip 1 minute, token expires.
|
||||
clock.Travel(time.Minute + 1)
|
||||
|
||||
// Show QR second time.
|
||||
<-show
|
||||
|
||||
// Emulate update, auth done.
|
||||
loggedIn <- struct{}{}
|
||||
|
||||
a.NoError(<-done)
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package qrlogin
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"image"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"rsc.io/qr"
|
||||
)
|
||||
|
||||
// Token represents Telegram QR Login token.
|
||||
type Token struct {
|
||||
token []byte
|
||||
expires time.Time
|
||||
}
|
||||
|
||||
// ParseTokenURL creates Token from given URL.
|
||||
func ParseTokenURL(u string) (Token, error) {
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
}
|
||||
switch {
|
||||
case parsed.Scheme != "tg":
|
||||
return Token{}, errors.Errorf("unexpected scheme %q", parsed.Scheme)
|
||||
case parsed.Host != "login":
|
||||
return Token{}, errors.Errorf("wrong path %q", parsed.Host)
|
||||
}
|
||||
|
||||
q := parsed.Query()
|
||||
if q.Get("token") == "" {
|
||||
return Token{}, errors.New("token is empty")
|
||||
}
|
||||
token, err := base64.URLEncoding.DecodeString(q.Get("token"))
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
}
|
||||
|
||||
return NewToken(token, 0), nil
|
||||
}
|
||||
|
||||
// NewToken creates new Token.
|
||||
func NewToken(token []byte, expires int) Token {
|
||||
return Token{
|
||||
token: token,
|
||||
expires: time.Unix(int64(expires), 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Expires returns token expiration time.
|
||||
func (t Token) Expires() time.Time {
|
||||
return t.expires
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (t Token) String() string {
|
||||
return base64.URLEncoding.EncodeToString(t.token)
|
||||
}
|
||||
|
||||
// URL returns login URL.
|
||||
//
|
||||
// See https://core.telegram.org/api/qr-login#exporting-a-login-token.
|
||||
func (t Token) URL() string {
|
||||
return "tg://login?token=" + base64.URLEncoding.EncodeToString(t.token)
|
||||
}
|
||||
|
||||
// Image returns QR image.
|
||||
func (t Token) Image(level qr.Level) (image.Image, error) {
|
||||
code, err := qr.Encode(t.URL(), level)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "encode")
|
||||
}
|
||||
return code.Image(), nil
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package qrlogin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseTokenURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
u string
|
||||
want Token
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"Valid",
|
||||
"tg://login?token=AQL0cY5hVg_D1OqESdYnJVg5845qbd8FiOLpUUeyvcb28g==",
|
||||
Token{
|
||||
token: []uint8{
|
||||
0x1, 0x2, 0xf4, 0x71, 0x8e, 0x61,
|
||||
0x56, 0xf, 0xc3, 0xd4, 0xea, 0x84,
|
||||
0x49, 0xd6, 0x27, 0x25, 0x58, 0x39,
|
||||
0xf3, 0x8e, 0x6a, 0x6d, 0xdf, 0x5,
|
||||
0x88, 0xe2, 0xe9, 0x51, 0x47, 0xb2,
|
||||
0xbd, 0xc6, 0xf6, 0xf2,
|
||||
},
|
||||
expires: time.Unix(0, 0),
|
||||
},
|
||||
false,
|
||||
},
|
||||
{"InvalidSchema", "vk://login", Token{}, true},
|
||||
{"InvalidPath", "tg://aboba", Token{}, true},
|
||||
{"NoToken", "tg://login", Token{}, true},
|
||||
{"InvalidBase64", "tg://login?token=A", Token{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
|
||||
got, err := ParseTokenURL(tt.u)
|
||||
if tt.wantErr {
|
||||
a.Error(err)
|
||||
} else {
|
||||
a.Equal(tt.want, got)
|
||||
a.NoError(err)
|
||||
a.Equal(tt.want.URL(), tt.u)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// self returns current user.
|
||||
//
|
||||
// You can use tg.User.Bot to check whether current user is bot.
|
||||
func (c *Client) self(ctx context.Context) (*tg.User, error) {
|
||||
users, err := c.api.UsersGetUsers(ctx, []tg.InputUserClass{&tg.InputUserSelf{}})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, ok := tg.UserClassArray(users).FirstAsNotEmpty()
|
||||
if !ok {
|
||||
return nil, errors.Errorf("users response count: %v", users)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func TestClient_self(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) {
|
||||
mock.ExpectCall(&tg.UsersGetUsersRequest{
|
||||
ID: []tg.InputUserClass{&tg.InputUserSelf{}},
|
||||
}).ThenErr(testutil.TestError())
|
||||
_, err := client.self(ctx)
|
||||
a.Error(err)
|
||||
|
||||
mock.ExpectCall(&tg.UsersGetUsersRequest{
|
||||
ID: []tg.InputUserClass{&tg.InputUserSelf{}},
|
||||
}).ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{&tg.UserEmpty{
|
||||
ID: 10,
|
||||
}}})
|
||||
_, err = client.self(ctx)
|
||||
a.Error(err)
|
||||
|
||||
mock.ExpectCall(&tg.UsersGetUsersRequest{
|
||||
ID: []tg.InputUserClass{&tg.InputUserSelf{}},
|
||||
}).ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{&tg.User{
|
||||
Self: true,
|
||||
ID: 10,
|
||||
AccessHash: 10,
|
||||
}}})
|
||||
r, err := client.self(ctx)
|
||||
a.NoError(err)
|
||||
a.Equal(int64(10), r.ID)
|
||||
})(t)
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// SignUpRequired means that log in failed because corresponding account
|
||||
// does not exist, so sign up is required.
|
||||
type SignUpRequired struct {
|
||||
TermsOfService tg.HelpTermsOfService
|
||||
}
|
||||
|
||||
// Is returns true if err is SignUpRequired.
|
||||
func (s *SignUpRequired) Is(err error) bool {
|
||||
_, ok := err.(*SignUpRequired)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *SignUpRequired) Error() string {
|
||||
return "account with provided number does not exist (sign up required)"
|
||||
}
|
||||
|
||||
// checkResult checks that `a` is *tg.AuthAuthorization and returns authorization result or error.
|
||||
func checkResult(a tg.AuthAuthorizationClass) (*tg.AuthAuthorization, error) {
|
||||
switch a := a.(type) {
|
||||
case *tg.AuthAuthorization:
|
||||
return a, nil // ok
|
||||
case *tg.AuthAuthorizationSignUpRequired:
|
||||
return nil, &SignUpRequired{
|
||||
TermsOfService: a.TermsOfService,
|
||||
}
|
||||
default:
|
||||
return nil, errors.Errorf("got unexpected response %T", a)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Status represents authorization status.
|
||||
type Status struct {
|
||||
// Authorized is true if client is authorized.
|
||||
Authorized bool
|
||||
// User is current User object.
|
||||
User *tg.User
|
||||
}
|
||||
|
||||
// Status gets authorization status of client.
|
||||
func (c *Client) Status(ctx context.Context) (*Status, error) {
|
||||
u, err := c.self(ctx)
|
||||
if IsUnauthorized(err) {
|
||||
return &Status{}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Status{
|
||||
Authorized: true,
|
||||
User: u,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IfNecessary runs given auth flow if current session is not authorized.
|
||||
func (c *Client) IfNecessary(ctx context.Context, flow Flow) error {
|
||||
auth, err := c.Status(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get auth status")
|
||||
}
|
||||
if auth.Authorized {
|
||||
return nil
|
||||
}
|
||||
if err := flow.Run(ctx, c); err != nil {
|
||||
return errors.Wrap(err, "auth flow")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test creates and runs auth flow using Test authenticator
|
||||
// if current session is not authorized.
|
||||
func (c *Client) Test(ctx context.Context, dc int) error {
|
||||
return c.IfNecessary(ctx, NewFlow(Test(c.rand, dc), SendCodeOptions{}))
|
||||
}
|
||||
|
||||
// TestUser creates and runs auth flow using TestUser authenticator
|
||||
// if current session is not authorized.
|
||||
func (c *Client) TestUser(ctx context.Context, phone string, dc int) error {
|
||||
return c.IfNecessary(ctx, NewFlow(TestUser(phone, dc), SendCodeOptions{}))
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func TestClient_Status(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Authorized", func(t *testing.T) {
|
||||
mock := tgmock.NewRequire(t)
|
||||
user := &tg.User{
|
||||
Username: "user",
|
||||
}
|
||||
mock.Expect().ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{user}})
|
||||
|
||||
status, err := testClient(mock).Status(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, status.Authorized)
|
||||
require.Equal(t, user, status.User)
|
||||
})
|
||||
|
||||
t.Run("Unauthorized", func(t *testing.T) {
|
||||
mock := tgmock.NewRequire(t)
|
||||
mock.Expect().ThenUnregistered()
|
||||
|
||||
status, err := testClient(mock).Status(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, status.Authorized)
|
||||
})
|
||||
|
||||
t.Run("Error", func(t *testing.T) {
|
||||
mock := tgmock.NewRequire(t)
|
||||
mock.Expect().ThenRPCErr(&tgerr.Error{
|
||||
Code: 500,
|
||||
Message: "BRUH",
|
||||
Type: "BRUH",
|
||||
})
|
||||
|
||||
_, err := testClient(mock).Status(ctx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_IfNecessary(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Authorized", func(t *testing.T) {
|
||||
mock := tgmock.NewRequire(t)
|
||||
testUser := &tg.User{
|
||||
Username: "user",
|
||||
}
|
||||
mock.Expect().ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{testUser}})
|
||||
|
||||
// Pass empty AuthFlow because it should not be called anyway.
|
||||
require.NoError(t, testClient(mock).IfNecessary(ctx, Flow{}))
|
||||
})
|
||||
|
||||
t.Run("Error", func(t *testing.T) {
|
||||
mock := tgmock.NewRequire(t)
|
||||
mock.Expect().ThenRPCErr(&tgerr.Error{
|
||||
Code: 500,
|
||||
Message: "BRUH",
|
||||
Type: "BRUH",
|
||||
})
|
||||
|
||||
// Pass empty AuthFlow because it should not be called anyway.
|
||||
require.Error(t, testClient(mock).IfNecessary(ctx, Flow{}))
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_Test(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Authorized", func(t *testing.T) {
|
||||
mock := tgmock.NewRequire(t)
|
||||
testUser := &tg.User{
|
||||
Username: "user",
|
||||
}
|
||||
mock.Expect().ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{testUser}})
|
||||
|
||||
// Pass empty AuthFlow because it should not be called anyway.
|
||||
require.NoError(t, testClient(mock).Test(ctx, 2))
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_TestUser(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Authorized", func(t *testing.T) {
|
||||
mock := tgmock.NewRequire(t)
|
||||
testUser := &tg.User{
|
||||
Username: "user",
|
||||
}
|
||||
mock.Expect().ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{testUser}})
|
||||
|
||||
// Pass empty AuthFlow because it should not be called anyway.
|
||||
require.NoError(t, testClient(mock).TestUser(ctx, "phone", 2))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
// ErrPasswordInvalid means that password provided to Password is invalid.
|
||||
//
|
||||
// Note that telegram does not trim whitespace characters by default, check
|
||||
// that provided password is expected and clean whitespaces if needed.
|
||||
// You can use strings.TrimSpace(password) for this.
|
||||
var ErrPasswordInvalid = errors.New("invalid password")
|
||||
|
||||
// Password performs login via secure remote password (aka 2FA).
|
||||
//
|
||||
// Method can be called after SignIn to provide password if requested.
|
||||
func (c *Client) Password(ctx context.Context, password string) (*tg.AuthAuthorization, error) {
|
||||
p, err := c.api.AccountGetPassword(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get SRP parameters")
|
||||
}
|
||||
|
||||
a, err := PasswordHash([]byte(password), p.SRPID, p.SRPB, p.SecureRandom, p.CurrentAlgo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "compute password hash")
|
||||
}
|
||||
|
||||
auth, err := c.api.AuthCheckPassword(ctx, &tg.InputCheckPasswordSRP{
|
||||
SRPID: p.SRPID,
|
||||
A: a.A,
|
||||
M1: a.M1,
|
||||
})
|
||||
if tg.IsPasswordHashInvalid(err) {
|
||||
return nil, ErrPasswordInvalid
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "check password")
|
||||
}
|
||||
result, err := checkResult(auth)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "check")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SendCodeOptions defines how to send auth code to user.
|
||||
type SendCodeOptions struct {
|
||||
// AllowFlashCall allows phone verification via phone calls.
|
||||
AllowFlashCall bool
|
||||
// Pass true if the phone number is used on the current device.
|
||||
// Ignored if AllowFlashCall is not set.
|
||||
CurrentNumber bool
|
||||
// If a token that will be included in eventually sent SMSs is required:
|
||||
// required in newer versions of android, to use the android SMS receiver APIs.
|
||||
AllowAppHash bool
|
||||
}
|
||||
|
||||
// SendCode requests code for provided phone number, returning code hash
|
||||
// and error if any. Use AuthFlow to reduce boilerplate.
|
||||
//
|
||||
// This method should be called first in user authentication flow.
|
||||
func (c *Client) SendCode(ctx context.Context, phone string, options SendCodeOptions) (tg.AuthSentCodeClass, error) {
|
||||
var settings tg.CodeSettings
|
||||
if options.AllowAppHash {
|
||||
settings.SetAllowAppHash(true)
|
||||
}
|
||||
if options.AllowFlashCall {
|
||||
settings.SetAllowFlashcall(true)
|
||||
}
|
||||
if options.CurrentNumber {
|
||||
settings.SetCurrentNumber(true)
|
||||
}
|
||||
|
||||
sentCode, err := c.api.AuthSendCode(ctx, &tg.AuthSendCodeRequest{
|
||||
PhoneNumber: phone,
|
||||
APIID: c.appID,
|
||||
APIHash: c.appHash,
|
||||
Settings: settings,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "send code")
|
||||
}
|
||||
return sentCode, nil
|
||||
}
|
||||
|
||||
// ErrPasswordAuthNeeded means that 2FA auth is required.
|
||||
//
|
||||
// Call Client.Password to provide 2FA password.
|
||||
var ErrPasswordAuthNeeded = errors.New("2FA required")
|
||||
|
||||
// SignIn performs sign in with provided user phone, code and code hash.
|
||||
//
|
||||
// If ErrPasswordAuthNeeded is returned, call Password to provide 2FA
|
||||
// password.
|
||||
//
|
||||
// To obtain codeHash, use SendCode.
|
||||
func (c *Client) SignIn(ctx context.Context, phone, code, codeHash string) (*tg.AuthAuthorization, error) {
|
||||
auth, err := c.api.AuthSignIn(ctx, &tg.AuthSignInRequest{
|
||||
PhoneNumber: phone,
|
||||
PhoneCodeHash: codeHash,
|
||||
PhoneCode: code,
|
||||
})
|
||||
if tgerr.Is(err, "SESSION_PASSWORD_NEEDED") {
|
||||
return nil, ErrPasswordAuthNeeded
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sign in")
|
||||
}
|
||||
result, err := checkResult(auth)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "check")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AcceptTOS accepts version of Terms Of Service.
|
||||
func (c *Client) AcceptTOS(ctx context.Context, id tg.DataJSON) error {
|
||||
_, err := c.api.HelpAcceptTermsOfService(ctx, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// SignUp wraps parameters for SignUp.
|
||||
type SignUp struct {
|
||||
PhoneNumber string
|
||||
PhoneCodeHash string
|
||||
FirstName string
|
||||
LastName string
|
||||
}
|
||||
|
||||
// SignUp registers a validated phone number in the system.
|
||||
//
|
||||
// To obtain codeHash, use SendCode.
|
||||
// Use AuthFlow helper to handle authentication flow.
|
||||
func (c *Client) SignUp(ctx context.Context, s SignUp) (*tg.AuthAuthorization, error) {
|
||||
auth, err := c.api.AuthSignUp(ctx, &tg.AuthSignUpRequest{
|
||||
LastName: s.LastName,
|
||||
PhoneCodeHash: s.PhoneCodeHash,
|
||||
PhoneNumber: s.PhoneNumber,
|
||||
FirstName: s.FirstName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "request")
|
||||
}
|
||||
result, err := checkResult(auth)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "check")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
)
|
||||
|
||||
func getHex(t testing.TB, in string) []byte {
|
||||
res, err := hex.DecodeString(in)
|
||||
if err != nil {
|
||||
t.Fatal("failed to get hex", err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func TestClient_AuthSignIn(t *testing.T) {
|
||||
const (
|
||||
phone = "123123"
|
||||
code = "1010"
|
||||
password = "secret"
|
||||
codeHash = "hash"
|
||||
)
|
||||
ctx := context.Background()
|
||||
testUser := &tg.User{ID: 1}
|
||||
invoker := tgmock.Invoker(func(body bin.Encoder) (bin.Encoder, error) {
|
||||
switch req := body.(type) {
|
||||
case *tg.UsersGetUsersRequest:
|
||||
return nil, &tgerr.Error{
|
||||
Code: 401,
|
||||
Message: "AUTH_KEY_UNREGISTERED",
|
||||
Type: "AUTH_KEY_UNREGISTERED",
|
||||
}
|
||||
case *tg.AuthSendCodeRequest:
|
||||
settings := tg.CodeSettings{}
|
||||
settings.SetCurrentNumber(true)
|
||||
assert.Equal(t, &tg.AuthSendCodeRequest{
|
||||
PhoneNumber: phone,
|
||||
APIHash: testAppHash,
|
||||
APIID: testAppID,
|
||||
Settings: settings,
|
||||
}, req)
|
||||
return &tg.AuthSentCode{
|
||||
Type: &tg.AuthSentCodeTypeApp{},
|
||||
PhoneCodeHash: codeHash,
|
||||
}, nil
|
||||
case *tg.AuthSignInRequest:
|
||||
assert.Equal(t, &tg.AuthSignInRequest{
|
||||
PhoneNumber: phone,
|
||||
PhoneCodeHash: codeHash,
|
||||
PhoneCode: code,
|
||||
}, req)
|
||||
return nil, tgerr.New(401, "SESSION_PASSWORD_NEEDED")
|
||||
case *tg.AccountGetPasswordRequest:
|
||||
algo := &tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow{
|
||||
Salt1: getHex(t, "4D11FB6BEC38F9D2546BB0F61E4F1C99A1BC0DB8F0D5F35B1291B37B213123D7ED48F3C6794D495B"),
|
||||
Salt2: getHex(t, "A1B181AAFE88188680AE32860D60BB01"),
|
||||
G: 3,
|
||||
P: getHex(t, "C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F"+
|
||||
"48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37"+
|
||||
"20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64"+
|
||||
"2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4"+
|
||||
"A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754"+
|
||||
"FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4"+
|
||||
"E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F"+
|
||||
"0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B"),
|
||||
}
|
||||
pwd := &tg.AccountPassword{
|
||||
NewAlgo: algo,
|
||||
NewSecureAlgo: &tg.SecurePasswordKdfAlgoPBKDF2HMACSHA512iter100000{},
|
||||
}
|
||||
pwd.SetCurrentAlgo(algo)
|
||||
return pwd, nil
|
||||
case *tg.AuthCheckPasswordRequest:
|
||||
// TODO(ernado): Check actual secure remote password here.
|
||||
switch pwd := req.Password.(type) {
|
||||
case *tg.InputCheckPasswordSRP:
|
||||
assert.NotEmpty(t, pwd.A)
|
||||
assert.NotEmpty(t, pwd.M1)
|
||||
assert.NotEqual(t, pwd.SRPID, 0)
|
||||
default:
|
||||
t.Errorf("unexpectd pwd type %T", pwd)
|
||||
}
|
||||
return &tg.AuthAuthorization{
|
||||
User: testUser,
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("unexpected")
|
||||
})
|
||||
|
||||
t.Run("Manual", func(t *testing.T) {
|
||||
// 1. Request code from server to device.
|
||||
client := testClient(invoker)
|
||||
sentCode, err := client.SendCode(ctx, phone, SendCodeOptions{CurrentNumber: true})
|
||||
require.NoError(t, err)
|
||||
h := sentCode.(*tg.AuthSentCode).PhoneCodeHash
|
||||
require.Equal(t, codeHash, h)
|
||||
|
||||
// 2. Send code from device to server.
|
||||
// Server is responding with 2FA password prompt.
|
||||
_, signInErr := client.SignIn(ctx, phone, code, h)
|
||||
require.ErrorIs(t, signInErr, ErrPasswordAuthNeeded)
|
||||
|
||||
// 3. Provide 2FA password.
|
||||
result, err := client.Password(ctx, password)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testUser, result.User)
|
||||
})
|
||||
|
||||
flow := NewFlow(
|
||||
Constant(phone, password, CodeAuthenticatorFunc(
|
||||
func(ctx context.Context, _ *tg.AuthSentCode) (string, error) {
|
||||
return code, nil
|
||||
},
|
||||
)),
|
||||
SendCodeOptions{CurrentNumber: true},
|
||||
)
|
||||
t.Run("AuthFlow", func(t *testing.T) {
|
||||
require.NoError(t, flow.Run(ctx, testClient(invoker)))
|
||||
})
|
||||
t.Run("IfNecessary", func(t *testing.T) {
|
||||
require.NoError(t, testClient(invoker).IfNecessary(ctx, flow))
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientTestAuth(t *testing.T) {
|
||||
const (
|
||||
codeHash = "hash"
|
||||
dcID = 2
|
||||
)
|
||||
ctx := context.Background()
|
||||
invoker := tgmock.Invoker(func(body bin.Encoder) (bin.Encoder, error) {
|
||||
switch req := body.(type) {
|
||||
case *tg.AuthSendCodeRequest:
|
||||
assert.Equal(t, &tg.AuthSendCodeRequest{
|
||||
PhoneNumber: req.PhoneNumber,
|
||||
APIHash: testAppHash,
|
||||
APIID: testAppID,
|
||||
Settings: tg.CodeSettings{},
|
||||
}, req)
|
||||
return &tg.AuthSentCode{
|
||||
Type: &tg.AuthSentCodeTypeApp{
|
||||
Length: 6,
|
||||
},
|
||||
PhoneCodeHash: codeHash,
|
||||
}, nil
|
||||
case *tg.AuthSignInRequest:
|
||||
if !strings.HasPrefix(req.PhoneNumber, "99966") {
|
||||
t.Fatalf("unexpected phone number %s", req.PhoneNumber)
|
||||
}
|
||||
dcPart := req.PhoneNumber[5:6]
|
||||
assert.Equal(t, strconv.Itoa(dcID), dcPart, "dc part of phone number")
|
||||
assert.Equal(t, &tg.AuthSignInRequest{
|
||||
PhoneNumber: req.PhoneNumber,
|
||||
PhoneCodeHash: codeHash,
|
||||
PhoneCode: strings.Repeat(dcPart, 6),
|
||||
}, req)
|
||||
return &tg.AuthAuthorization{
|
||||
User: &tg.User{ID: 1},
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("unexpected")
|
||||
})
|
||||
require.NoError(t, NewFlow(
|
||||
Test(rand.New(rand.NewSource(1)), dcID),
|
||||
SendCodeOptions{},
|
||||
).Run(ctx, testClient(invoker)))
|
||||
}
|
||||
|
||||
func TestClientTestSignUp(t *testing.T) {
|
||||
const (
|
||||
dcID = 2
|
||||
codeHash = "hash"
|
||||
tosID = "foo"
|
||||
)
|
||||
ctx := context.Background()
|
||||
invoker := tgmock.Invoker(func(body bin.Encoder) (bin.Encoder, error) {
|
||||
switch req := body.(type) {
|
||||
case *tg.AuthSendCodeRequest:
|
||||
assert.Equal(t, &tg.AuthSendCodeRequest{
|
||||
PhoneNumber: req.PhoneNumber,
|
||||
APIHash: testAppHash,
|
||||
APIID: testAppID,
|
||||
Settings: tg.CodeSettings{},
|
||||
}, req)
|
||||
return &tg.AuthSentCode{
|
||||
Type: &tg.AuthSentCodeTypeApp{
|
||||
Length: 6,
|
||||
},
|
||||
PhoneCodeHash: codeHash,
|
||||
}, nil
|
||||
case *tg.AuthSignUpRequest:
|
||||
assert.Equal(t, &tg.AuthSignUpRequest{
|
||||
PhoneNumber: req.PhoneNumber,
|
||||
PhoneCodeHash: codeHash,
|
||||
FirstName: "Test",
|
||||
LastName: "User",
|
||||
}, req)
|
||||
return &tg.AuthAuthorization{
|
||||
User: &tg.User{ID: 1},
|
||||
}, nil
|
||||
case *tg.HelpAcceptTermsOfServiceRequest:
|
||||
return &tg.BoolTrue{}, nil
|
||||
case *tg.AuthSignInRequest:
|
||||
if !strings.HasPrefix(req.PhoneNumber, "99966") {
|
||||
t.Fatalf("unexpected phone number %s", req.PhoneNumber)
|
||||
}
|
||||
dcPart := req.PhoneNumber[5:6]
|
||||
assert.Equal(t, strconv.Itoa(dcID), dcPart, "dc part of phone number")
|
||||
assert.Equal(t, &tg.AuthSignInRequest{
|
||||
PhoneNumber: req.PhoneNumber,
|
||||
PhoneCodeHash: codeHash,
|
||||
PhoneCode: strings.Repeat(dcPart, 6),
|
||||
}, req)
|
||||
|
||||
res := &tg.AuthAuthorizationSignUpRequired{}
|
||||
res.SetTermsOfService(tg.HelpTermsOfService{ID: tg.DataJSON{Data: tosID}})
|
||||
|
||||
return res, nil
|
||||
}
|
||||
return nil, errors.New("unexpected")
|
||||
})
|
||||
require.NoError(t, NewFlow(
|
||||
Test(rand.New(rand.NewSource(1)), dcID),
|
||||
SendCodeOptions{},
|
||||
).Run(ctx, testClient(invoker)))
|
||||
}
|
||||
|
||||
func TestClient_AcceptTOS(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) {
|
||||
mock.Expect().ThenUnregistered()
|
||||
a.Error(client.AcceptTOS(ctx, tg.DataJSON{
|
||||
Data: `{"data":"data"}`,
|
||||
}))
|
||||
|
||||
mock.Expect().ThenTrue()
|
||||
a.NoError(client.AcceptTOS(ctx, tg.DataJSON{
|
||||
Data: `{"data":"data"}`,
|
||||
}))
|
||||
})(t)
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package telegram_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth/qrlogin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func ExampleClient_Auth_codeOnly() {
|
||||
check := func(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
appIDString = os.Getenv("APP_ID")
|
||||
appHash = os.Getenv("APP_HASH")
|
||||
phone = os.Getenv("PHONE")
|
||||
)
|
||||
if appIDString == "" || appHash == "" || phone == "" {
|
||||
log.Fatal("PHONE, APP_ID or APP_HASH is not set")
|
||||
}
|
||||
|
||||
appID, err := strconv.Atoi(appIDString)
|
||||
check(err)
|
||||
|
||||
ctx := context.Background()
|
||||
client := telegram.NewClient(appID, appHash, telegram.Options{})
|
||||
codeAsk := func(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) {
|
||||
fmt.Print("code:")
|
||||
code, err := bufio.NewReader(os.Stdin).ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code = strings.ReplaceAll(code, "\n", "")
|
||||
return code, nil
|
||||
}
|
||||
|
||||
check(client.Run(ctx, func(ctx context.Context) error {
|
||||
return auth.NewFlow(
|
||||
auth.CodeOnly(phone, auth.CodeAuthenticatorFunc(codeAsk)),
|
||||
auth.SendCodeOptions{},
|
||||
).Run(ctx, client.Auth())
|
||||
}))
|
||||
}
|
||||
|
||||
func ExampleClient_Auth_password() {
|
||||
check := func(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
appIDString = os.Getenv("APP_ID")
|
||||
appHash = os.Getenv("APP_HASH")
|
||||
phone = os.Getenv("PHONE")
|
||||
pass = os.Getenv("PASSWORD")
|
||||
)
|
||||
if appIDString == "" || appHash == "" || phone == "" || pass == "" {
|
||||
log.Fatal("PHONE, PASSWORD, APP_ID or APP_HASH is not set")
|
||||
}
|
||||
|
||||
appID, err := strconv.Atoi(appIDString)
|
||||
check(err)
|
||||
|
||||
ctx := context.Background()
|
||||
client := telegram.NewClient(appID, appHash, telegram.Options{})
|
||||
codeAsk := func(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) {
|
||||
fmt.Print("code:")
|
||||
code, err := bufio.NewReader(os.Stdin).ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code = strings.ReplaceAll(code, "\n", "")
|
||||
return code, nil
|
||||
}
|
||||
|
||||
check(client.Run(ctx, func(ctx context.Context) error {
|
||||
return auth.NewFlow(
|
||||
auth.Constant(phone, pass, auth.CodeAuthenticatorFunc(codeAsk)),
|
||||
auth.SendCodeOptions{},
|
||||
).Run(ctx, client.Auth())
|
||||
}))
|
||||
}
|
||||
|
||||
func ExampleClient_Auth_test() {
|
||||
// Example of using test server.
|
||||
const dcID = 2
|
||||
|
||||
ctx := context.Background()
|
||||
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{
|
||||
DC: dcID,
|
||||
DCList: dcs.Test(),
|
||||
})
|
||||
if err := client.Run(ctx, func(ctx context.Context) error {
|
||||
return client.Auth().Test(ctx, dcID)
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleClient_Auth_bot() {
|
||||
ctx := context.Background()
|
||||
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{})
|
||||
if err := client.Run(ctx, func(ctx context.Context) error {
|
||||
// Checking auth status.
|
||||
status, err := client.Auth().Status(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Can be already authenticated if we have valid session in
|
||||
// session storage.
|
||||
if !status.Authorized {
|
||||
// Otherwise, perform bot authentication.
|
||||
if _, err := client.Auth().Bot(ctx, os.Getenv("BOT_TOKEN")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// All good, manually authenticated.
|
||||
return nil
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleQR_Auth() {
|
||||
ctx := context.Background()
|
||||
|
||||
d := tg.NewUpdateDispatcher()
|
||||
loggedIn := qrlogin.OnLoginToken(d)
|
||||
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{
|
||||
UpdateHandler: d,
|
||||
})
|
||||
if err := client.Run(ctx, func(ctx context.Context) error {
|
||||
qr := client.QR()
|
||||
authorization, err := qr.Auth(ctx, loggedIn, func(ctx context.Context, token qrlogin.Token) error {
|
||||
fmt.Printf("Open %s using your phone\n", token.URL())
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u, ok := authorization.User.AsNotEmpty()
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T", authorization.User)
|
||||
}
|
||||
fmt.Println("ID:", u.ID, "Username:", u.Username, "Bot:", u.Bot)
|
||||
return nil
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
)
|
||||
|
||||
// RunUntilCanceled is client callback which
|
||||
// locks until client context is canceled.
|
||||
func RunUntilCanceled(ctx context.Context, client *Client) error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// BotFromEnvironment creates bot client using ClientFromEnvironment
|
||||
// connects to server and authenticates it.
|
||||
//
|
||||
// Variables:
|
||||
// BOT_TOKEN — token from BotFather.
|
||||
func BotFromEnvironment(
|
||||
ctx context.Context,
|
||||
opts Options,
|
||||
setup func(ctx context.Context, client *Client) error,
|
||||
cb func(ctx context.Context, client *Client) error,
|
||||
) error {
|
||||
client, err := ClientFromEnvironment(opts)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create client")
|
||||
}
|
||||
|
||||
if setup != nil {
|
||||
if err := setup(ctx, client); err != nil {
|
||||
return errors.Wrap(err, "setup")
|
||||
}
|
||||
}
|
||||
|
||||
return client.Run(ctx, func(ctx context.Context) error {
|
||||
status, err := client.Auth().Status(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "auth status")
|
||||
}
|
||||
|
||||
if !status.Authorized {
|
||||
if _, err := client.Auth().Bot(ctx, os.Getenv("BOT_TOKEN")); err != nil {
|
||||
return errors.Wrap(err, "login")
|
||||
}
|
||||
}
|
||||
|
||||
return cb(ctx, client)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
func sessionDir() (string, error) {
|
||||
dir, ok := os.LookupEnv("SESSION_DIR")
|
||||
if ok {
|
||||
return filepath.Abs(dir)
|
||||
}
|
||||
|
||||
dir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
dir = "."
|
||||
}
|
||||
|
||||
return filepath.Abs(filepath.Join(dir, ".td"))
|
||||
}
|
||||
|
||||
// OptionsFromEnvironment fills unfilled field in opts parameter
|
||||
// using environment variables.
|
||||
//
|
||||
// Variables:
|
||||
//
|
||||
// SESSION_FILE: path to session file
|
||||
// SESSION_DIR: path to session directory, if SESSION_FILE is not set
|
||||
// ALL_PROXY, NO_PROXY: see https://pkg.go.dev/golang.org/x/net/proxy#FromEnvironment
|
||||
func OptionsFromEnvironment(opts Options) (Options, error) {
|
||||
// Setting up session storage if not provided.
|
||||
if opts.SessionStorage == nil {
|
||||
sessionFile, ok := os.LookupEnv("SESSION_FILE")
|
||||
if !ok {
|
||||
dir, err := sessionDir()
|
||||
if err != nil {
|
||||
return Options{}, errors.Wrap(err, "SESSION_DIR not set or invalid")
|
||||
}
|
||||
sessionFile = filepath.Join(dir, "session.json")
|
||||
}
|
||||
|
||||
dir, _ := filepath.Split(sessionFile)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return Options{}, errors.Wrap(err, "session dir creation")
|
||||
}
|
||||
|
||||
opts.SessionStorage = &session.FileStorage{
|
||||
Path: sessionFile,
|
||||
}
|
||||
}
|
||||
|
||||
if opts.Resolver == nil {
|
||||
opts.Resolver = dcs.Plain(dcs.PlainOptions{
|
||||
Dial: proxy.Dial,
|
||||
})
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// ClientFromEnvironment creates client using OptionsFromEnvironment
|
||||
// but does not connect to server.
|
||||
//
|
||||
// Variables:
|
||||
//
|
||||
// APP_ID: app_id of Telegram app.
|
||||
// APP_HASH: app_hash of Telegram app.
|
||||
func ClientFromEnvironment(opts Options) (*Client, error) {
|
||||
appID, err := strconv.Atoi(os.Getenv("APP_ID"))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "APP_ID not set or invalid")
|
||||
}
|
||||
|
||||
appHash := os.Getenv("APP_HASH")
|
||||
if appHash == "" {
|
||||
return nil, errors.New("no APP_HASH provided")
|
||||
}
|
||||
|
||||
opts, err = OptionsFromEnvironment(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewClient(appID, appHash, opts), nil
|
||||
}
|
||||
|
||||
func retry(ctx context.Context, logger *zap.Logger, cb func(ctx context.Context) error) error {
|
||||
b := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||
|
||||
// List of known retryable RPC error types.
|
||||
retryableErrors := []string{
|
||||
"NEED_MEMBER_INVALID",
|
||||
"AUTH_KEY_UNREGISTERED",
|
||||
"API_ID_PUBLISHED_FLOOD",
|
||||
}
|
||||
|
||||
return backoff.Retry(func() error {
|
||||
if err := cb(ctx); err != nil {
|
||||
logger.Warn("TestClient run failed", zap.Error(err))
|
||||
|
||||
if tgerr.Is(err, retryableErrors...) {
|
||||
return err
|
||||
}
|
||||
if timeout, ok := AsFloodWait(err); ok {
|
||||
timer := clock.System.Timer(timeout + 1*time.Second)
|
||||
defer clock.StopTimer(timer)
|
||||
|
||||
select {
|
||||
case <-timer.C():
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// Possibly server closed connection.
|
||||
return err
|
||||
}
|
||||
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, b)
|
||||
}
|
||||
|
||||
// TestClient creates and authenticates user telegram.Client
|
||||
// using Telegram test server.
|
||||
func TestClient(ctx context.Context, opts Options, cb func(ctx context.Context, client *Client) error) error {
|
||||
if opts.DC == 0 {
|
||||
opts.DC = 2
|
||||
}
|
||||
if opts.DCList.Zero() {
|
||||
opts.DCList = dcs.Test()
|
||||
}
|
||||
|
||||
logger := zap.NewNop()
|
||||
if opts.Logger != nil {
|
||||
logger = opts.Logger.Named("test")
|
||||
}
|
||||
|
||||
// Sometimes testing server can return "AUTH_KEY_UNREGISTERED" error.
|
||||
// It is expected and client implementation is unlikely to cause
|
||||
// such errors, so just doing retries using backoff.
|
||||
return retry(ctx, logger, func(retryCtx context.Context) error {
|
||||
client := NewClient(TestAppID, TestAppHash, opts)
|
||||
return client.Run(retryCtx, func(runCtx context.Context) error {
|
||||
if err := client.Auth().IfNecessary(runCtx, auth.NewFlow(
|
||||
auth.Test(crypto.DefaultRand(), opts.DC),
|
||||
auth.SendCodeOptions{},
|
||||
)); err != nil {
|
||||
return errors.Wrap(err, "auth flow")
|
||||
}
|
||||
|
||||
return cb(runCtx, client)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/pem"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func parseCDNKeys(keys ...tg.CDNPublicKey) ([]*rsa.PublicKey, error) {
|
||||
r := make([]*rsa.PublicKey, 0, len(keys))
|
||||
|
||||
for _, key := range keys {
|
||||
block, _ := pem.Decode([]byte(key.PublicKey))
|
||||
if block == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
key, err := crypto.ParseRSA(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parse RSA from PEM")
|
||||
}
|
||||
|
||||
r = append(r, key)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func Test_parseCDNKeys(t *testing.T) {
|
||||
keys := []string{
|
||||
`-----BEGIN RSA PUBLIC KEY-----
|
||||
MIIBCgKCAQEA+Lf3PvgE1yxbJUCMaEAk
|
||||
V0QySTVpnaDjiednB5RbtNWjCeqSVakYHbqqGMIIv5WCGdFdrqOfMNcNSstPtSU6
|
||||
R9UmRw6tquOIykpSuUOje9H+4XVIKqujyL2ISdK+4ZOMl4hCMkqauw4bP1Sbr03v
|
||||
ZRQbU6qEA04V4j879BAyBVhr3WG9+Zi+t5XfGSTgSExPYEl8rZNHYNV5RB+BuroV
|
||||
H2HLTOpT/mJVfikYpgjfWF5ldezV4Wo9LSH0cZGSFIaeJl8d0A8Eiy5B9gtBO8mL
|
||||
+XfQRKOOmr7a4BM4Ro2de5rr2i2od7hYXd3DO9FRSl4y1zA8Am48Rfd95WHF3N/O
|
||||
mQIDAQAB
|
||||
-----END RSA PUBLIC KEY-----`,
|
||||
`-----BEGIN RSA PUBLIC KEY-----
|
||||
MIIBCgKCAQEAyu5PXyfp+VFLc2hKJsq/cvQ+wq9V2s1iGMMwcrkXrKAqX0S5QEcY
|
||||
W9b6pV5LulbsvNcxp/YniiSL4FsAja28B9fH//Y+AolWASomCB0NSVHwS1Pqfe3m
|
||||
GdLTwDmqU17tSWk/48+Kfn4B+WT85ZIKt8bOnABwnM1AtykX0zKwzm9yKcTX0MeY
|
||||
rwzgiOQax6J1cfgtLdxl8HVKT6wCOS1e43zpXMU+UoWqRqIan+J6q+ubi1yF4PWl
|
||||
DyDgJSw8uxlhNNMP4tAnshIRZ1ZZ25O/g58jw1qz5XMztZwLNA2pUxaFtyy1LdHC
|
||||
FRX7DdwIA/FdOzfWyXYLlCFaSX8K/6CnSQIDAQAB
|
||||
-----END RSA PUBLIC KEY-----`,
|
||||
}
|
||||
|
||||
cdnKeys := make([]tg.CDNPublicKey, 0, len(keys))
|
||||
for i, key := range keys {
|
||||
cdnKeys = append(cdnKeys, tg.CDNPublicKey{
|
||||
DCID: i + 1,
|
||||
PublicKey: key,
|
||||
})
|
||||
}
|
||||
|
||||
publicKeys, err := parseCDNKeys(cdnKeys...)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, publicKeys, 2)
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/constant"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Available MTProto default server addresses.
|
||||
//
|
||||
// See https://my.telegram.org/apps.
|
||||
const (
|
||||
AddrProduction = "149.154.167.50:443"
|
||||
AddrTest = "149.154.167.40:443"
|
||||
)
|
||||
|
||||
// Test-only credentials. Can be used with AddrTest and TestAuth to
|
||||
// test authentication.
|
||||
//
|
||||
// Reference:
|
||||
// - https://github.com/telegramdesktop/tdesktop/blob/5f665b8ecb48802cd13cfb48ec834b946459274a/docs/api_credentials.md
|
||||
const (
|
||||
TestAppID = constant.TestAppID
|
||||
TestAppHash = constant.TestAppHash
|
||||
)
|
||||
|
||||
// Config returns current config.
|
||||
func (c *Client) Config() tg.Config {
|
||||
return c.cfg.Load()
|
||||
}
|
||||
|
||||
func (c *Client) fetchConfig(ctx context.Context) {
|
||||
cfg, err := c.tg.HelpGetConfig(ctx)
|
||||
if err != nil {
|
||||
c.log.Warn("Got error on config update", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
c.cfg.Store(*cfg)
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func TestClient_fetchConfig(t *testing.T) {
|
||||
a := require.New(t)
|
||||
cfg := &tg.Config{
|
||||
ThisDC: 10,
|
||||
}
|
||||
client := newTestClient(func(id int64, body bin.Encoder) (bin.Encoder, error) {
|
||||
a.IsType(&tg.HelpGetConfigRequest{}, body)
|
||||
return cfg, nil
|
||||
})
|
||||
|
||||
a.NoError(client.processUpdates(&tg.Updates{
|
||||
Updates: []tg.UpdateClass{&tg.UpdateConfig{}},
|
||||
}))
|
||||
|
||||
a.Equal(*cfg, client.Config())
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/oteltg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/pool"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/manager"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/version"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// UpdateHandler will be called on received updates from Telegram.
|
||||
type UpdateHandler interface {
|
||||
Handle(ctx context.Context, u tg.UpdatesClass) error
|
||||
}
|
||||
|
||||
// UpdateHandlerFunc type is an adapter to allow the use of
|
||||
// ordinary function as update handler.
|
||||
//
|
||||
// UpdateHandlerFunc(f) is an UpdateHandler that calls f.
|
||||
type UpdateHandlerFunc func(ctx context.Context, u tg.UpdatesClass) error
|
||||
|
||||
// Handle calls f(ctx, u)
|
||||
func (f UpdateHandlerFunc) Handle(ctx context.Context, u tg.UpdatesClass) error {
|
||||
return f(ctx, u)
|
||||
}
|
||||
|
||||
type clientStorage interface {
|
||||
Load(ctx context.Context) (*session.Data, error)
|
||||
Save(ctx context.Context, data *session.Data) error
|
||||
}
|
||||
|
||||
type clientConn interface {
|
||||
Run(ctx context.Context) error
|
||||
Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Client represents a MTProto client to Telegram.
|
||||
type Client struct {
|
||||
// Put migration in the header of the structure to ensure 64-bit alignment,
|
||||
// otherwise it will cause the atomic operation of connsCounter to panic.
|
||||
// DO NOT change the order of members arbitrarily.
|
||||
// Ref: https://pkg.go.dev/sync/atomic#pkg-note-BUG
|
||||
|
||||
// Connection factory fields.
|
||||
connsCounter atomic.Int64
|
||||
create connConstructor // immutable
|
||||
resolver dcs.Resolver // immutable
|
||||
onDead func() // immutable
|
||||
onAuthError func(error) // immutable
|
||||
onConnected func() // immutable
|
||||
newConnBackoff func() backoff.BackOff // immutable
|
||||
defaultMode manager.ConnMode // immutable
|
||||
|
||||
// Migration state.
|
||||
migrationTimeout time.Duration // immutable
|
||||
migration chan struct{}
|
||||
|
||||
// tg provides RPC calls via Client. Uses invoker below.
|
||||
tg *tg.Client // immutable
|
||||
// invoker implements tg.Invoker on top of Client and mw.
|
||||
invoker tg.Invoker // immutable
|
||||
// mw is list of middlewares used in invoker, can be blank.
|
||||
mw []Middleware // immutable
|
||||
|
||||
// Telegram device information.
|
||||
device DeviceConfig // immutable
|
||||
|
||||
// MTProto options.
|
||||
opts mtproto.Options // immutable
|
||||
|
||||
// DCList state.
|
||||
// Domain list (for websocket)
|
||||
domains map[int]string // immutable
|
||||
// Denotes to use Test DCs.
|
||||
testDC bool // immutable
|
||||
|
||||
// Connection state. Guarded by connMux.
|
||||
session *pool.SyncSession
|
||||
cfg *manager.AtomicConfig
|
||||
conn clientConn
|
||||
connBackoff atomic.Pointer[backoff.BackOff]
|
||||
connMux sync.Mutex
|
||||
|
||||
// Restart signal channel.
|
||||
restart chan struct{} // immutable
|
||||
|
||||
// Connections to non-primary DC.
|
||||
subConns map[int]CloseInvoker
|
||||
subConnsMux sync.Mutex
|
||||
sessions map[int]*pool.SyncSession
|
||||
sessionsMux sync.Mutex
|
||||
|
||||
// Wrappers for external world, like logs or PRNG.
|
||||
rand io.Reader // immutable
|
||||
log *zap.Logger // immutable
|
||||
clock clock.Clock // immutable
|
||||
|
||||
// Client context. Will be canceled by Run on exit.
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Client config.
|
||||
appID int // immutable
|
||||
appHash string // immutable
|
||||
// Session storage.
|
||||
storage clientStorage // immutable, nillable
|
||||
|
||||
// Ready signal channel, sends signal when client connection is ready.
|
||||
// Resets on reconnect.
|
||||
ready *tdsync.ResetReady // immutable
|
||||
|
||||
// Telegram updates handler.
|
||||
updateHandler UpdateHandler // immutable
|
||||
// Denotes that no update mode is enabled.
|
||||
noUpdatesMode bool // immutable
|
||||
|
||||
// Tracing.
|
||||
tracer trace.Tracer
|
||||
|
||||
// onTransfer is called in transfer.
|
||||
onTransfer AuthTransferHandler
|
||||
|
||||
// onSelfError is called on error calling Self().
|
||||
onSelfError func(ctx context.Context, err error) error
|
||||
}
|
||||
|
||||
// NewClient creates new unstarted client.
|
||||
func NewClient(appID int, appHash string, opt Options) *Client {
|
||||
opt.setDefaults()
|
||||
|
||||
mode := manager.ConnModeUpdates
|
||||
if opt.NoUpdates {
|
||||
mode = manager.ConnModeData
|
||||
}
|
||||
client := &Client{
|
||||
rand: opt.Random,
|
||||
log: opt.Logger,
|
||||
appID: appID,
|
||||
appHash: appHash,
|
||||
updateHandler: opt.UpdateHandler,
|
||||
session: pool.NewSyncSession(pool.Session{
|
||||
DC: opt.DC,
|
||||
}),
|
||||
domains: opt.DCList.Domains,
|
||||
testDC: opt.DCList.Test,
|
||||
cfg: manager.NewAtomicConfig(tg.Config{
|
||||
DCOptions: opt.DCList.Options,
|
||||
}),
|
||||
create: defaultConstructor(),
|
||||
resolver: opt.Resolver,
|
||||
defaultMode: mode,
|
||||
newConnBackoff: opt.ReconnectionBackoff,
|
||||
onDead: opt.OnDead,
|
||||
onAuthError: opt.OnAuthError,
|
||||
onConnected: opt.OnConnected,
|
||||
clock: opt.Clock,
|
||||
device: opt.Device,
|
||||
migrationTimeout: opt.MigrationTimeout,
|
||||
noUpdatesMode: opt.NoUpdates,
|
||||
mw: opt.Middlewares,
|
||||
onTransfer: opt.OnTransfer,
|
||||
onSelfError: opt.OnSelfError,
|
||||
}
|
||||
if opt.TracerProvider != nil {
|
||||
client.tracer = opt.TracerProvider.Tracer(oteltg.Name)
|
||||
}
|
||||
client.init()
|
||||
|
||||
// Including version into client logger to help with debugging.
|
||||
if v := version.GetVersion(); v != "" {
|
||||
client.log = client.log.With(zap.String("v", v))
|
||||
}
|
||||
|
||||
if opt.SessionStorage != nil {
|
||||
client.storage = &session.Loader{
|
||||
Storage: opt.SessionStorage,
|
||||
}
|
||||
}
|
||||
if opt.CustomSessionStorage != nil {
|
||||
client.storage = opt.CustomSessionStorage
|
||||
}
|
||||
|
||||
client.opts = mtproto.Options{
|
||||
PublicKeys: opt.PublicKeys,
|
||||
Random: opt.Random,
|
||||
Logger: opt.Logger,
|
||||
AckBatchSize: opt.AckBatchSize,
|
||||
AckInterval: opt.AckInterval,
|
||||
RetryInterval: opt.RetryInterval,
|
||||
MaxRetries: opt.MaxRetries,
|
||||
CompressThreshold: opt.CompressThreshold,
|
||||
MessageID: opt.MessageID,
|
||||
ExchangeTimeout: opt.ExchangeTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
Clock: opt.Clock,
|
||||
PingInterval: opt.PingInterval,
|
||||
PingCallback: opt.PingCallback,
|
||||
PingTimeout: opt.PingTimeout,
|
||||
|
||||
Handler: SessionNotifierHandler{opt.OnSession},
|
||||
|
||||
Types: getTypesMapping(),
|
||||
|
||||
Tracer: client.tracer,
|
||||
}
|
||||
client.conn = client.createPrimaryConn(nil)
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// SessionNotifierHandler is a handler which notifies when a session is
|
||||
// received.
|
||||
type SessionNotifierHandler struct {
|
||||
onSession func()
|
||||
}
|
||||
|
||||
// OnSession implements Handler.
|
||||
func (n SessionNotifierHandler) OnSession(s mtproto.Session) error {
|
||||
if n.onSession != nil {
|
||||
n.onSession()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnMessage implements Handler
|
||||
func (n SessionNotifierHandler) OnMessage(b *bin.Buffer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// init sets fields which needs explicit initialization, like maps or channels.
|
||||
func (c *Client) init() {
|
||||
if c.domains == nil {
|
||||
c.domains = map[int]string{}
|
||||
}
|
||||
if c.cfg == nil {
|
||||
c.cfg = manager.NewAtomicConfig(tg.Config{})
|
||||
}
|
||||
c.ready = tdsync.NewResetReady()
|
||||
c.restart = make(chan struct{})
|
||||
c.migration = make(chan struct{}, 1)
|
||||
c.sessions = map[int]*pool.SyncSession{}
|
||||
c.subConns = map[int]CloseInvoker{}
|
||||
c.invoker = chainMiddlewares(InvokeFunc(c.invokeDirect), c.mw...)
|
||||
c.tg = tg.NewClient(c.invoker)
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package telegram_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/downloader"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/uploader"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/cluster"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services/file"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
type clusterSetup struct {
|
||||
TB testing.TB
|
||||
Cluster *cluster.Cluster
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
type clientSetup struct {
|
||||
TB testing.TB
|
||||
Options telegram.Options
|
||||
Complete func()
|
||||
}
|
||||
|
||||
var user = &tg.User{
|
||||
ID: 10,
|
||||
AccessHash: 10,
|
||||
Username: "username",
|
||||
}
|
||||
|
||||
func testCluster(
|
||||
p dcs.Protocol,
|
||||
ws bool,
|
||||
setup func(q clusterSetup),
|
||||
run func(ctx context.Context, c clientSetup) error,
|
||||
) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := zaptest.NewLogger(t)
|
||||
defer func() { _ = log.Sync() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
|
||||
c := cluster.NewCluster(cluster.Options{
|
||||
Web: ws,
|
||||
Logger: log.Named("cluster"),
|
||||
Protocol: p,
|
||||
})
|
||||
setup(clusterSetup{
|
||||
TB: t,
|
||||
Cluster: c,
|
||||
Logger: log,
|
||||
})
|
||||
g.Go(c.Up)
|
||||
|
||||
g.Go(func(ctx context.Context) error {
|
||||
select {
|
||||
// Wait for cluster readiness.
|
||||
case <-c.Ready():
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return run(ctx, clientSetup{
|
||||
TB: t,
|
||||
Options: telegram.Options{
|
||||
UpdateHandler: telegram.UpdateHandlerFunc(func(ctx context.Context, u tg.UpdatesClass) error {
|
||||
// No-op update handler.
|
||||
return nil
|
||||
}),
|
||||
PublicKeys: c.Keys(),
|
||||
Resolver: c.Resolver(),
|
||||
Logger: log.Named("client"),
|
||||
SessionStorage: &session.StorageMemory{},
|
||||
DCList: c.List(),
|
||||
},
|
||||
Complete: cancel,
|
||||
})
|
||||
})
|
||||
|
||||
log.Debug("Waiting")
|
||||
if err := g.Wait(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testAllTransports(t *testing.T, test func(p dcs.Protocol) func(t *testing.T)) {
|
||||
t.Run("Abridged", test(transport.Abridged))
|
||||
t.Run("Intermediate", test(transport.Intermediate))
|
||||
t.Run("PaddedIntermediate", test(transport.PaddedIntermediate))
|
||||
t.Run("Full", test(transport.Full))
|
||||
}
|
||||
|
||||
func testTransport(p dcs.Protocol) func(t *testing.T) {
|
||||
testMessage := "ну че там с деньгами?"
|
||||
|
||||
return testCluster(p, false, func(s clusterSetup) {
|
||||
h := tgtest.TestTransport(s.TB, s.Logger.Named("handler"), testMessage)
|
||||
d := s.Cluster.Dispatch(2, "server")
|
||||
d.Handle(tg.MessagesSendMessageRequestTypeID, h)
|
||||
d.Handle(tg.UsersGetUsersRequestTypeID, h)
|
||||
}, func(ctx context.Context, c clientSetup) error {
|
||||
opts := c.Options
|
||||
opts.AckBatchSize = 1
|
||||
opts.AckInterval = time.Millisecond * 50
|
||||
opts.RetryInterval = time.Millisecond * 50
|
||||
logger := opts.Logger
|
||||
|
||||
dispatcher := tg.NewUpdateDispatcher()
|
||||
opts.UpdateHandler = dispatcher
|
||||
client := telegram.NewClient(1, "hash", opts)
|
||||
|
||||
waitForMessage := make(chan struct{})
|
||||
dispatcher.OnNewMessage(func(ctx context.Context, entities tg.Entities, update *tg.UpdateNewMessage) error {
|
||||
message := update.Message.(*tg.Message).Message
|
||||
logger.Info("Got message", zap.String("text", message))
|
||||
assert.Equal(c.TB, testMessage, message)
|
||||
if err := client.SendMessage(ctx, &tg.MessagesSendMessageRequest{
|
||||
Peer: &tg.InputPeerUser{},
|
||||
Message: "какими деньгами?",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("Closing waitForMessage")
|
||||
close(waitForMessage)
|
||||
return nil
|
||||
})
|
||||
|
||||
return client.Run(ctx, func(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.TB.Error("Failed to wait for message")
|
||||
return ctx.Err()
|
||||
case <-waitForMessage:
|
||||
logger.Info("Returning")
|
||||
c.Complete()
|
||||
return nil
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientE2E(t *testing.T) {
|
||||
testAllTransports(t, testTransport)
|
||||
}
|
||||
|
||||
func testMigrate(p dcs.Protocol) func(t *testing.T) {
|
||||
test := func(ws bool) func(t *testing.T) {
|
||||
wait := make(chan struct{}, 1)
|
||||
return testCluster(p, ws, func(s clusterSetup) {
|
||||
c := s.Cluster
|
||||
c.Common().Vector(tg.UsersGetUsersRequestTypeID, user)
|
||||
c.Dispatch(1, "server").HandleFunc(tg.MessagesSendMessageRequestTypeID,
|
||||
func(server *tgtest.Server, req *tgtest.Request) error {
|
||||
m := &tg.MessagesSendMessageRequest{}
|
||||
if err := m.Decode(req.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case wait <- struct{}{}:
|
||||
case <-req.RequestCtx.Done():
|
||||
return req.RequestCtx.Err()
|
||||
}
|
||||
return server.SendGZIP(req, &tg.Updates{})
|
||||
},
|
||||
)
|
||||
c.Dispatch(2, "migrate").HandleFunc(tg.MessagesSendMessageRequestTypeID,
|
||||
func(server *tgtest.Server, req *tgtest.Request) error {
|
||||
m := &tg.MessagesSendMessageRequest{}
|
||||
if err := m.Decode(req.Buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return server.SendErr(req, tgerr.New(303, "NETWORK_MIGRATE_1"))
|
||||
},
|
||||
)
|
||||
}, func(ctx context.Context, c clientSetup) error {
|
||||
opts := c.Options
|
||||
opts.MigrationTimeout = time.Minute
|
||||
client := telegram.NewClient(1, "hash", opts)
|
||||
return client.Run(ctx, func(ctx context.Context) error {
|
||||
if err := client.SendMessage(ctx, &tg.MessagesSendMessageRequest{
|
||||
Peer: &tg.InputPeerUser{},
|
||||
Message: "abc",
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "send")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-wait:
|
||||
c.Complete()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
return func(t *testing.T) {
|
||||
t.Run("TCP", test(false))
|
||||
t.Run("Websocket", test(true))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
t.Run("Intermediate", testMigrate(transport.Intermediate))
|
||||
}
|
||||
|
||||
func testFiles(p dcs.Protocol) func(t *testing.T) {
|
||||
test := func(ws bool) func(t *testing.T) {
|
||||
return testCluster(p, ws, func(s clusterSetup) {
|
||||
c := s.Cluster
|
||||
c.Common().Vector(tg.UsersGetUsersRequestTypeID, user)
|
||||
f := file.NewService(file.Config{
|
||||
HashPartSize: 1024,
|
||||
})
|
||||
f.Register(c.Dispatch(2, "DC"))
|
||||
}, func(ctx context.Context, c clientSetup) error {
|
||||
client := telegram.NewClient(1, "hash", c.Options)
|
||||
defer c.Complete()
|
||||
return client.Run(ctx, func(ctx context.Context) error {
|
||||
raw := tg.NewClient(client)
|
||||
upd := uploader.NewUploader(raw)
|
||||
dwn := downloader.NewDownloader()
|
||||
|
||||
payloads := [][]byte{
|
||||
[]byte("data"),
|
||||
bytes.Repeat([]byte{10}, 1337),
|
||||
bytes.Repeat([]byte{42}, 16384),
|
||||
}
|
||||
|
||||
for _, payload := range payloads {
|
||||
f, err := upd.FromBytes(ctx, "10.jpg", payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vf, ok := f.(interface {
|
||||
GetID() (value int64)
|
||||
})
|
||||
if !ok {
|
||||
return errors.Errorf("%T", f)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
_, err = dwn.Download(raw, &tg.InputFileLocation{
|
||||
VolumeID: vf.GetID(),
|
||||
LocalID: 10,
|
||||
}).WithVerify(true).Stream(ctx, &b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !bytes.Equal(payload, b.Bytes()) {
|
||||
c.TB.Error("must be equal")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
return func(t *testing.T) {
|
||||
t.Run("TCP", test(false))
|
||||
t.Run("Websocket", test(true))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFiles(t *testing.T) {
|
||||
t.Run("Intermediate", testFiles(transport.Intermediate))
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package telegram_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/e2etest"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
func tryConnect(ctx context.Context, opts telegram.Options) error {
|
||||
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, opts)
|
||||
return client.Run(ctx, func(ctx context.Context) error {
|
||||
_, err := client.API().HelpGetNearestDC(ctx)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func testTransportExternal(resolver dcs.Resolver, storage session.Storage) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
log := zaptest.NewLogger(t)
|
||||
defer func() { _ = log.Sync() }()
|
||||
|
||||
require.NoError(t, tryConnect(ctx, telegram.Options{
|
||||
Logger: log.Named("client"),
|
||||
SessionStorage: storage,
|
||||
Resolver: resolver,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalE2EConnect(t *testing.T) {
|
||||
testutil.SkipExternal(t)
|
||||
// To re-use session.
|
||||
storage := &session.StorageMemory{}
|
||||
|
||||
tcp := func(p dcs.Protocol) func(t *testing.T) {
|
||||
return testTransportExternal(dcs.Plain(dcs.PlainOptions{Protocol: p}), storage)
|
||||
}
|
||||
|
||||
t.Run("Abridged", tcp(transport.Abridged))
|
||||
t.Run("Intermediate", tcp(transport.Intermediate))
|
||||
t.Run("PaddedIntermediate", tcp(transport.PaddedIntermediate))
|
||||
t.Run("Full", tcp(transport.Full))
|
||||
|
||||
wsOpts := dcs.WebsocketOptions{}
|
||||
t.Run("Websocket", testTransportExternal(dcs.Websocket(wsOpts), storage))
|
||||
}
|
||||
|
||||
const dialog = `— Да?
|
||||
— Алё!
|
||||
— Да да?
|
||||
— Ну как там с деньгами?
|
||||
— А?
|
||||
— Как с деньгами-то там?
|
||||
— Чё с деньгами?
|
||||
— Чё?
|
||||
— Куда ты звонишь?
|
||||
— Тебе звоню.
|
||||
— Кому?
|
||||
— Ну тебе.`
|
||||
|
||||
func TestExternalE2EUsersDialog(t *testing.T) {
|
||||
testutil.SkipExternal(t)
|
||||
if v, _ := strconv.ParseBool(os.Getenv("GOTD_E2E_DIALOGS_BROKEN")); v {
|
||||
// TODO(ernado): enable when fixed
|
||||
t.Skip("Dialogs are broken.")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
log := zaptest.NewLogger(t).WithOptions(zap.IncreaseLevel(zapcore.InfoLevel))
|
||||
|
||||
cfg := e2etest.TestOptions{
|
||||
Logger: log,
|
||||
}
|
||||
suite := e2etest.NewSuite(t, cfg)
|
||||
|
||||
auth := make(chan *tg.User, 1)
|
||||
g := tdsync.NewLogGroup(ctx, log.Named("group"))
|
||||
|
||||
g.Go("echobot", func(ctx context.Context) error {
|
||||
if err := e2etest.NewEchoBot(suite, auth).Run(ctx); err != nil {
|
||||
return errors.Wrap(err, "echo bot")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
user, ok := <-auth
|
||||
if ok {
|
||||
g.Go("terentyev", func(ctx context.Context) error {
|
||||
defer g.Cancel()
|
||||
if err := e2etest.NewUser(suite, strings.Split(dialog, "\n"), user.Username).Run(ctx); err != nil {
|
||||
return errors.Wrap(err, "user")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
require.NoError(t, g.Wait())
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package telegram_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
)
|
||||
|
||||
type mtg struct {
|
||||
path string
|
||||
addr string
|
||||
}
|
||||
|
||||
type signalWriter struct {
|
||||
io.Writer
|
||||
wait *tdsync.Ready
|
||||
}
|
||||
|
||||
func (s signalWriter) Write(p []byte) (n int, err error) {
|
||||
s.wait.Signal()
|
||||
return s.Writer.Write(p)
|
||||
}
|
||||
|
||||
func (m mtg) run(ctx context.Context, secret string, out, err io.Writer, wait *tdsync.Ready) error {
|
||||
cmd := exec.CommandContext(ctx, m.path, "simple-run", "-d", m.addr, secret)
|
||||
cmd.Stdout = signalWriter{Writer: out, wait: wait}
|
||||
cmd.Stderr = signalWriter{Writer: err, wait: wait}
|
||||
cmd.Env = append([]string{"MTG_DEBUG=true", "MTG_TEST_DC=true"}, os.Environ()...)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (m mtg) generateSecret(ctx context.Context, _ string) ([]byte, error) {
|
||||
args := []string{"generate-secret", "google.com"}
|
||||
|
||||
o, err := exec.CommandContext(ctx, m.path, args...).Output()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "execute command")
|
||||
}
|
||||
output := strings.TrimSpace(string(o))
|
||||
|
||||
r, err := base64.RawURLEncoding.DecodeString(output)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "decode secret %q", output)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func testMTProxy(secretType string, m mtg, storage session.Storage) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
secret, err := m.generateSecret(ctx, secretType)
|
||||
a.NoError(err)
|
||||
|
||||
// Store mtg logs to buffer and print it only if test failed.
|
||||
w := &bytes.Buffer{}
|
||||
t.Cleanup(func() {
|
||||
if t.Failed() {
|
||||
_, _ = io.Copy(os.Stdout, w)
|
||||
}
|
||||
})
|
||||
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
ready := tdsync.NewReady()
|
||||
g.Go(func(ctx context.Context) error {
|
||||
err := m.run(ctx, hex.EncodeToString(secret), w, w, ready)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
})
|
||||
g.Go(func(ctx context.Context) error {
|
||||
defer g.Cancel()
|
||||
select {
|
||||
case <-ready.Ready():
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
resolver, err := dcs.MTProxy(m.addr, secret, dcs.MTProxyOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tryConnect(ctx, telegram.Options{
|
||||
Resolver: resolver,
|
||||
Logger: logger,
|
||||
SessionStorage: storage,
|
||||
DCList: dcs.Prod(),
|
||||
})
|
||||
})
|
||||
|
||||
a.NoError(g.Wait())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalE2EMTProxy(t *testing.T) {
|
||||
addr, ok := os.LookupEnv("GOTD_MTPROXY_ADDR")
|
||||
if !ok {
|
||||
t.Skip("Skipped. Set GOTD_MTPROXY_ADDR to enable external e2e mtproxy test.")
|
||||
}
|
||||
|
||||
mtgPath, err := exec.LookPath("mtg")
|
||||
if err != nil {
|
||||
t.Fatal("mtg binary not found", err)
|
||||
}
|
||||
|
||||
// To re-use session.
|
||||
storage := &session.StorageMemory{}
|
||||
m := mtg{path: mtgPath, addr: addr}
|
||||
// TODO(tdakkota): test all proxy types (mtg v2 supports only faketls)
|
||||
for _, secretType := range []string{"tls"} {
|
||||
t.Run(strings.Title(secretType), testMTProxy(secretType, m, storage))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
|
||||
)
|
||||
|
||||
type testHandler func(id int64, body bin.Encoder) (bin.Encoder, error)
|
||||
|
||||
type testConn struct {
|
||||
id atomic.Int64
|
||||
engine *rpc.Engine
|
||||
ready *tdsync.Ready
|
||||
}
|
||||
|
||||
func (t *testConn) Ping(ctx context.Context) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (t *testConn) Ready() <-chan struct{} {
|
||||
return t.ready.Ready()
|
||||
}
|
||||
|
||||
func (t *testConn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
id := t.id.Inc() - 1
|
||||
return t.engine.Do(ctx, rpc.Request{
|
||||
Input: input,
|
||||
Output: output,
|
||||
MsgID: id,
|
||||
})
|
||||
}
|
||||
|
||||
func (testConn) Run(ctx context.Context) error { return nil }
|
||||
|
||||
func newTestClient(h testHandler) *Client {
|
||||
var engine *rpc.Engine
|
||||
|
||||
engine = rpc.New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
|
||||
if response, err := h(msgID, in); err != nil {
|
||||
engine.NotifyError(msgID, err)
|
||||
} else {
|
||||
var b bin.Buffer
|
||||
if err := b.Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.NotifyResult(msgID, &b)
|
||||
}
|
||||
return nil
|
||||
}, rpc.Options{})
|
||||
|
||||
ready := tdsync.NewReady()
|
||||
ready.Signal()
|
||||
client := &Client{
|
||||
log: zap.NewNop(),
|
||||
rand: rand.New(rand.NewSource(1)),
|
||||
appID: TestAppID,
|
||||
appHash: TestAppHash,
|
||||
conn: &testConn{engine: engine, ready: ready},
|
||||
ctx: context.Background(),
|
||||
cancel: func() {},
|
||||
updateHandler: UpdateHandlerFunc(func(ctx context.Context, u tg.UpdatesClass) error { return nil }),
|
||||
onTransfer: noopOnTransfer,
|
||||
}
|
||||
client.init()
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func mockClient(cb func(mock *tgmock.Mock, client *Client)) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
t.Helper()
|
||||
mock := tgmock.NewRequire(t)
|
||||
client := newTestClient(testHandler(mock.Handler()))
|
||||
cb(mock, client)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureErrorIfCantConnect(t *testing.T) {
|
||||
testErr := testutil.TestError()
|
||||
dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
opts := Options{
|
||||
Resolver: dcs.Plain(dcs.PlainOptions{Dial: dialer}),
|
||||
ReconnectionBackoff: func() backoff.BackOff {
|
||||
return backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Nanosecond), 2)
|
||||
},
|
||||
}
|
||||
|
||||
err := NewClient(1, "hash", opts).Run(context.Background(),
|
||||
func(ctx context.Context) error {
|
||||
return nil
|
||||
})
|
||||
require.ErrorIs(t, err, testErr)
|
||||
}
|
||||
|
||||
// newCorpusTracer will save incoming messages to corpus folder.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// client.trace.OnMessage = newCorpusTracer(t)
|
||||
//
|
||||
// nolint: deadcode,unused // optional
|
||||
func newCorpusTracer(t testing.TB) func(b *bin.Buffer) {
|
||||
types := tmap.New(
|
||||
mt.TypesMap(),
|
||||
tg.TypesMap(),
|
||||
proto.TypesMap(),
|
||||
)
|
||||
dir := filepath.Join("..", "_fuzz", "handle_message", "corpus")
|
||||
|
||||
return func(b *bin.Buffer) {
|
||||
id, _ := b.PeekID()
|
||||
h := md5.Sum(b.Buf)
|
||||
name := types.Get(id)
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
if idx := strings.Index(name, "#"); idx > 0 {
|
||||
// Removing type id from name.
|
||||
name = name[:idx]
|
||||
}
|
||||
base := fmt.Sprintf("trace_%x_%s_%x",
|
||||
id, name, h,
|
||||
)
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(dir, base), b.Buf, 0600))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
// SessionStorage is alias of mtproto.SessionStorage.
|
||||
type SessionStorage = session.Storage
|
||||
|
||||
// FileSessionStorage is alias of mtproto.FileSessionStorage.
|
||||
type FileSessionStorage = session.FileStorage
|
||||
|
||||
// Error represents RPC error returned to request.
|
||||
type Error = tgerr.Error
|
||||
@@ -0,0 +1,103 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/pool"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/manager"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
type clientHandler struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
func (c clientHandler) OnSession(cfg tg.Config, s mtproto.Session) (err error) {
|
||||
err = c.client.onSession(cfg, s)
|
||||
if c.client.opts.Handler != nil {
|
||||
c.client.opts.Handler.OnSession(s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c clientHandler) OnMessage(b *bin.Buffer) error {
|
||||
return c.client.handleUpdates(b)
|
||||
}
|
||||
|
||||
func (c *Client) asHandler() manager.Handler {
|
||||
return clientHandler{
|
||||
client: c,
|
||||
}
|
||||
}
|
||||
|
||||
type connConstructor func(
|
||||
create mtproto.Dialer,
|
||||
mode manager.ConnMode,
|
||||
appID int,
|
||||
opts mtproto.Options,
|
||||
connOpts manager.ConnOptions,
|
||||
) pool.Conn
|
||||
|
||||
func defaultConstructor() connConstructor {
|
||||
return func(
|
||||
create mtproto.Dialer,
|
||||
mode manager.ConnMode,
|
||||
appID int,
|
||||
opts mtproto.Options,
|
||||
connOpts manager.ConnOptions,
|
||||
) pool.Conn {
|
||||
return manager.CreateConn(create, mode, appID, opts, connOpts)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) dcList() dcs.List {
|
||||
cfg := c.cfg.Load()
|
||||
return dcs.List{
|
||||
Options: cfg.DCOptions,
|
||||
Domains: c.domains,
|
||||
Test: c.testDC,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) primaryDC(dc int) mtproto.Dialer {
|
||||
return func(ctx context.Context) (transport.Conn, error) {
|
||||
return c.resolver.Primary(ctx, dc, c.dcList())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) createPrimaryConn(setup manager.SetupCallback) pool.Conn {
|
||||
return c.createConn(0, c.defaultMode, setup, c.onDead, c.onAuthError)
|
||||
}
|
||||
|
||||
func (c *Client) createConn(
|
||||
id int64,
|
||||
mode manager.ConnMode,
|
||||
setup manager.SetupCallback,
|
||||
onDead func(),
|
||||
onAuthError func(error),
|
||||
) pool.Conn {
|
||||
opts, s := c.session.Options(c.opts)
|
||||
opts.Logger = c.log.Named("conn").With(
|
||||
zap.Int64("conn_id", id),
|
||||
zap.Int("dc_id", s.DC),
|
||||
)
|
||||
|
||||
return c.create(
|
||||
c.primaryDC(s.DC), mode, c.appID,
|
||||
opts, manager.ConnOptions{
|
||||
DC: s.DC,
|
||||
Test: c.testDC,
|
||||
Device: c.device,
|
||||
Handler: c.asHandler(),
|
||||
Setup: setup,
|
||||
OnDead: onDead,
|
||||
OnAuthError: onAuthError,
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
func (c *Client) runUntilRestart(ctx context.Context) error {
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
g.Go(c.conn.Run)
|
||||
|
||||
// If we don't need updates, so there is no reason to subscribe for it.
|
||||
if !c.noUpdatesMode {
|
||||
g.Go(func(ctx context.Context) error {
|
||||
// Call method which requires authorization, to subscribe for updates.
|
||||
// See https://core.telegram.org/api/updates#subscribing-to-updates.
|
||||
self, err := c.Self(ctx)
|
||||
if err != nil {
|
||||
// Ignore unauthorized errors.
|
||||
if !auth.IsUnauthorized(err) {
|
||||
c.log.Warn("Got error on self", zap.Error(err))
|
||||
} else if c.onAuthError != nil {
|
||||
c.onAuthError(err)
|
||||
}
|
||||
if h := c.onSelfError; h != nil {
|
||||
// Help with https://github.com/gotd/td/issues/1458.
|
||||
if err := h(ctx, err); err != nil {
|
||||
return errors.Wrap(err, "onSelfError")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
c.log.Info("Got self", zap.String("username", self.Username))
|
||||
if c.onConnected != nil {
|
||||
c.onConnected()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
g.Go(func(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.restart:
|
||||
c.log.Debug("Restart triggered")
|
||||
// Should call cancel() to cancel group.
|
||||
g.Cancel()
|
||||
|
||||
return nil
|
||||
}
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (c *Client) isPermanentError(err error) bool {
|
||||
// See https://github.com/gotd/td/issues/1458.
|
||||
if errors.Is(err, exchange.ErrKeyFingerprintNotFound) {
|
||||
return true
|
||||
}
|
||||
if tgerr.Is(err, "AUTH_KEY_UNREGISTERED", "SESSION_EXPIRED", "AUTH_KEY_DUPLICATED") {
|
||||
return true
|
||||
}
|
||||
if auth.IsUnauthorized(err) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) reconnectUntilClosed(ctx context.Context) error {
|
||||
// Note that we currently have no timeout on connection, so this is
|
||||
// potentially eternal.
|
||||
b := tdsync.SyncBackoff(backoff.WithContext(c.newConnBackoff(), ctx))
|
||||
c.connBackoff.Store(&b)
|
||||
|
||||
return backoff.RetryNotify(func() error {
|
||||
if err := c.runUntilRestart(ctx); err != nil {
|
||||
if c.isPermanentError(err) {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}, b, func(err error, timeout time.Duration) {
|
||||
c.log.Info("Restarting connection", zap.Error(err), zap.Duration("backoff", timeout))
|
||||
|
||||
c.connMux.Lock()
|
||||
c.conn = c.createPrimaryConn(nil)
|
||||
c.connMux.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) onReady() {
|
||||
c.log.Debug("Ready")
|
||||
c.ready.Signal()
|
||||
|
||||
if b := c.connBackoff.Load(); b != nil {
|
||||
// Reconnect faster next time.
|
||||
(*b).Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) resetReady() {
|
||||
c.ready.Reset()
|
||||
}
|
||||
|
||||
// Run starts client session and blocks until connection close.
|
||||
// The f callback is called on successful session initialization and Run
|
||||
// will return on f() result.
|
||||
//
|
||||
// Context of callback will be canceled if fatal error is detected.
|
||||
// The ctx is used for background operations like updates handling or pools.
|
||||
//
|
||||
// See `examples/bg-run` and `contrib/gb` package for classic approach without
|
||||
// explicit callback, with Connect and defer close().
|
||||
func (c *Client) Run(ctx context.Context, f func(ctx context.Context) error) (err error) {
|
||||
if c.ctx != nil {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return errors.Wrap(c.ctx.Err(), "client already closed")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Setting up client context for background operations like updates
|
||||
// handling or pool creation.
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
c.log.Info("Starting")
|
||||
defer c.log.Info("Closed")
|
||||
// Cancel client on exit.
|
||||
defer c.cancel()
|
||||
defer func() {
|
||||
c.subConnsMux.Lock()
|
||||
defer c.subConnsMux.Unlock()
|
||||
|
||||
for _, conn := range c.subConns {
|
||||
if closeErr := conn.Close(); !errors.Is(closeErr, context.Canceled) {
|
||||
multierr.AppendInto(&err, closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
c.resetReady()
|
||||
if err := c.restoreConnection(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
g.Go(c.reconnectUntilClosed)
|
||||
g.Go(func(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.cancel()
|
||||
return ctx.Err()
|
||||
case <-c.ctx.Done():
|
||||
return c.ctx.Err()
|
||||
}
|
||||
})
|
||||
g.Go(func(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.ready.Ready():
|
||||
if err := f(ctx); err != nil {
|
||||
return errors.Wrap(err, "callback")
|
||||
}
|
||||
// Should call cancel() to cancel ctx.
|
||||
// This will terminate c.conn.Run().
|
||||
c.log.Debug("Callback returned, stopping")
|
||||
g.Cancel()
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err := g.Wait(); !errors.Is(err, context.Canceled) {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
)
|
||||
|
||||
type fingerprintNotFoundConn struct{}
|
||||
|
||||
func (m fingerprintNotFoundConn) Run(context.Context) error {
|
||||
return exchange.ErrKeyFingerprintNotFound
|
||||
}
|
||||
|
||||
func (m fingerprintNotFoundConn) Invoke(context.Context, bin.Encoder, bin.Decoder) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m fingerprintNotFoundConn) Ping(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m fingerprintNotFoundConn) Ready() <-chan struct{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestClient_reconnectUntilClosed(t *testing.T) {
|
||||
client := Client{
|
||||
newConnBackoff: func() backoff.BackOff {
|
||||
return backoff.NewConstantBackOff(time.Nanosecond)
|
||||
},
|
||||
log: zap.NewNop(),
|
||||
}
|
||||
client.init()
|
||||
client.conn = fingerprintNotFoundConn{}
|
||||
|
||||
ctx := context.Background()
|
||||
require.ErrorIs(t, client.reconnectUntilClosed(ctx), exchange.ErrKeyFingerprintNotFound)
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"math/big"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
var dnsKey struct {
|
||||
rsa.PublicKey
|
||||
once sync.Once
|
||||
eBig *big.Int
|
||||
}
|
||||
|
||||
//nolint:gochecknoinits
|
||||
func init() {
|
||||
dnsKey.once.Do(func() {
|
||||
k, err := crypto.ParseRSAPublicKeys([]byte(`-----BEGIN RSA PUBLIC KEY-----
|
||||
MIIBCgKCAQEAyr+18Rex2ohtVy8sroGPBwXD3DOoKCSpjDqYoXgCqB7ioln4eDCF
|
||||
fOBUlfXUEvM/fnKCpF46VkAftlb4VuPDeQSS/ZxZYEGqHaywlroVnXHIjgqoxiAd
|
||||
192xRGreuXIaUKmkwlM9JID9WS2jUsTpzQ91L8MEPLJ/4zrBwZua8W5fECwCCh2c
|
||||
9G5IzzBm+otMS/YKwmR1olzRCyEkyAEjXWqBI9Ftv5eG8m0VkBzOG655WIYdyV0H
|
||||
fDK/NWcvGqa0w/nriMD6mDjKOryamw0OP9QuYgMN0C9xMW9y8SmP4h92OAWodTYg
|
||||
Y1hZCxdv6cs5UnW9+PWvS+WIbkh+GaWYxwIDAQAB
|
||||
-----END RSA PUBLIC KEY-----`))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
dnsKey.PublicKey = *k[0]
|
||||
dnsKey.eBig = big.NewInt(int64(dnsKey.E))
|
||||
})
|
||||
}
|
||||
|
||||
// parseDNSList decodes raw encrypted simple config.
|
||||
//
|
||||
// Notice that parseDNSList does not decode base64, user should do it manually.
|
||||
func parseDNSList(input [256]byte) (tg.HelpConfigSimple, error) {
|
||||
// See https://github.com/tdlib/td/blob/master/td/telegram/ConfigManager.cpp#L148.
|
||||
x := new(big.Int).SetBytes(input[:])
|
||||
y := new(big.Int).Exp(x, dnsKey.eBig, dnsKey.N)
|
||||
|
||||
dataRSA := make([]byte, 256)
|
||||
if !crypto.FillBytes(y, dataRSA) {
|
||||
return tg.HelpConfigSimple{}, errors.New("dataRSA has invalid size")
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(dataRSA[:32])
|
||||
if err != nil {
|
||||
return tg.HelpConfigSimple{}, err
|
||||
}
|
||||
d := cipher.NewCBCDecrypter(block, dataRSA[16:32])
|
||||
dataCBC := dataRSA[32:]
|
||||
d.CryptBlocks(dataCBC, dataCBC)
|
||||
|
||||
decrypted := dataCBC[:len(dataCBC)-16]
|
||||
decryptedHash := sha256.Sum256(decrypted)
|
||||
hash := dataCBC[len(dataCBC)-16:]
|
||||
|
||||
if !bytes.Equal(decryptedHash[:16], hash) {
|
||||
return tg.HelpConfigSimple{}, errors.New("hash mismatch")
|
||||
}
|
||||
|
||||
var cfg tg.HelpConfigSimple
|
||||
if err := cfg.Decode(&bin.Buffer{Buf: decrypted[4:]}); err != nil {
|
||||
return tg.HelpConfigSimple{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
type sortByLen []string
|
||||
|
||||
func (s sortByLen) Len() int {
|
||||
return len(s)
|
||||
}
|
||||
|
||||
func (s sortByLen) Less(i, j int) bool {
|
||||
return len(s[i]) > len(s[j])
|
||||
}
|
||||
|
||||
func (s sortByLen) Swap(i, j int) {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
|
||||
// DNSConfig is DC connection config obtained from DNS.
|
||||
type DNSConfig struct {
|
||||
// Date field of HelpConfigSimple.
|
||||
Date int
|
||||
// Expires field of HelpConfigSimple.
|
||||
Expires int
|
||||
// Rules field of HelpConfigSimple.
|
||||
Rules []tg.AccessPointRule
|
||||
}
|
||||
|
||||
// Options returns DC options from this config.
|
||||
func (d DNSConfig) Options() (r []tg.DCOption) {
|
||||
convertIP := func(ip int) string {
|
||||
return net.IPv4(
|
||||
byte(ip),
|
||||
byte(ip>>8),
|
||||
byte(ip>>16),
|
||||
byte(ip>>24),
|
||||
).String()
|
||||
}
|
||||
for _, rule := range d.Rules {
|
||||
for _, ip := range rule.IPs {
|
||||
switch ip := ip.(type) {
|
||||
case *tg.IPPort:
|
||||
r = append(r, tg.DCOption{
|
||||
ID: rule.DCID,
|
||||
IPAddress: convertIP(ip.Ipv4),
|
||||
Port: ip.Port,
|
||||
})
|
||||
case *tg.IPPortSecret:
|
||||
r = append(r, tg.DCOption{
|
||||
TCPObfuscatedOnly: true,
|
||||
ID: rule.DCID,
|
||||
IPAddress: convertIP(ip.Ipv4),
|
||||
Port: ip.Port,
|
||||
Secret: ip.Secret,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ParseDNSConfig parses tg.HelpConfigSimple from TXT response.
|
||||
func ParseDNSConfig(txt []string) (DNSConfig, error) {
|
||||
encoding := base64.StdEncoding
|
||||
const (
|
||||
decodedLen = 256
|
||||
encodedLen = 344
|
||||
)
|
||||
sort.Sort(sortByLen(txt))
|
||||
|
||||
var totalLength int
|
||||
for i := range txt {
|
||||
totalLength += len(txt[i])
|
||||
}
|
||||
if totalLength != encodedLen {
|
||||
return DNSConfig{}, errors.Errorf("invalid input length %d", totalLength)
|
||||
}
|
||||
|
||||
var (
|
||||
encoded [encodedLen]byte
|
||||
decoded [decodedLen]byte
|
||||
)
|
||||
n := 0
|
||||
for i := range txt {
|
||||
n += copy(encoded[n:], txt[i])
|
||||
}
|
||||
|
||||
if _, err := encoding.Decode(decoded[:], encoded[:]); err != nil {
|
||||
return DNSConfig{}, errors.Wrap(err, "decode")
|
||||
}
|
||||
|
||||
cfg, err := parseDNSList(decoded)
|
||||
if err != nil {
|
||||
return DNSConfig{}, errors.Wrap(err, "decrypt config")
|
||||
}
|
||||
|
||||
return DNSConfig{
|
||||
Date: cfg.Date,
|
||||
Expires: cfg.Expires,
|
||||
Rules: cfg.Rules,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testTXTResponse() []string {
|
||||
return []string{
|
||||
"LcmEoukF2bVjKwz3E+J9BsDdL+rv9lGqLQWIGXrWACT2ESk5xuOpA6Cz6klKRbhbwSiHOd2zC5PiR57j/OJHPpj4i+tw==",
|
||||
"umjjLFLpOKtPeW9zHLq2ypbMzg/zkqvPhvhr0bxrLZlgPQ04l2GpO/4qZgAx3tk3BDHbY6/gmG1e8eaFBq3YSqR5SZ5hQ1Cm5f4/" +
|
||||
"o67GYcPJClaf1TiHq3wVfsQ5OLnyJRw9A2ZfUfzIXxoSklPJrVdF/4hM1ZdUE0eWDAbmYf7JCeao8ecVVwKndd4CZHZS9wyf1T7DIUh95VpQ" +
|
||||
"sn2klLPA6gA/2YNXOh9gITvjZrKuXLwwh9hBHhPvxv",
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParseDNSConfig(t *testing.T) {
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
|
||||
cfg, err := ParseDNSConfig(testTXTResponse())
|
||||
a.NoError(err)
|
||||
a.Equal(1565541126, cfg.Expires)
|
||||
a.Equal(1562949126, cfg.Date)
|
||||
a.Len(cfg.Rules, 1)
|
||||
|
||||
rule := cfg.Rules[0]
|
||||
a.Equal(2, rule.DCID)
|
||||
})
|
||||
|
||||
t.Run("Bad", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
}{
|
||||
{"Empty", nil},
|
||||
{"InvalidHash", func() []string {
|
||||
r := testTXTResponse()
|
||||
first := r[0]
|
||||
r[0] = string(first[0]+1) + first[1:]
|
||||
return r
|
||||
}()},
|
||||
{"InvalidBase64", func() []string {
|
||||
r := testTXTResponse()
|
||||
first := r[0]
|
||||
r[0] = string('#') + first[1:]
|
||||
return r
|
||||
}()},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := ParseDNSConfig(tt.input)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkDNSConfig(b *testing.B) {
|
||||
message := testTXTResponse()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
var (
|
||||
err error
|
||||
cfgSink DNSConfig
|
||||
)
|
||||
for i := 0; i < b.N; i++ {
|
||||
cfgSink, err = ParseDNSConfig(message)
|
||||
if cfgSink.Date == 0 || err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSConfig_Options(t *testing.T) {
|
||||
a := require.New(t)
|
||||
|
||||
cfg, err := ParseDNSConfig(testTXTResponse())
|
||||
a.NoError(err)
|
||||
|
||||
options := cfg.Options()
|
||||
a.Len(options, 1)
|
||||
option := options[0]
|
||||
a.Equal(2, option.ID)
|
||||
a.Equal(14544, option.Port)
|
||||
a.Equal("98.210.59.139", option.IPAddress)
|
||||
a.True(option.TCPObfuscatedOnly)
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package dcs contains Telegram DCs list and some helpers.
|
||||
package dcs
|
||||
@@ -0,0 +1,52 @@
|
||||
package dcs_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
)
|
||||
|
||||
func ExampleDialFunc() {
|
||||
// Dial using proxy from environment.
|
||||
|
||||
// Creating connection.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
client := telegram.NewClient(1, "appHash", telegram.Options{
|
||||
Resolver: dcs.Plain(dcs.PlainOptions{Dial: proxy.Dial}),
|
||||
})
|
||||
|
||||
_ = client.Run(ctx, func(ctx context.Context) error {
|
||||
fmt.Println("Started")
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func ExampleDialFunc_dialer() {
|
||||
// Dial using SOCKS5 proxy.
|
||||
|
||||
sock5, _ := proxy.SOCKS5("tcp", "IP:PORT", &proxy.Auth{
|
||||
User: "YOURUSERNAME",
|
||||
Password: "YOURPASSWORD",
|
||||
}, proxy.Direct)
|
||||
dc := sock5.(proxy.ContextDialer)
|
||||
|
||||
// Creating connection.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
client := telegram.NewClient(1, "appHash", telegram.Options{
|
||||
Resolver: dcs.Plain(dcs.PlainOptions{
|
||||
Dial: dc.DialContext,
|
||||
}),
|
||||
})
|
||||
|
||||
_ = client.Run(ctx, func(ctx context.Context) error {
|
||||
fmt.Println("Started")
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// FindDCs searches DCs candidates from given config.
|
||||
func FindDCs(opts []tg.DCOption, dcID int, preferIPv6 bool) []tg.DCOption {
|
||||
// Preallocate slice.
|
||||
candidates := make([]tg.DCOption, 0, 32)
|
||||
|
||||
for _, candidateDC := range opts {
|
||||
if candidateDC.ID != dcID {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, candidateDC)
|
||||
}
|
||||
|
||||
if len(candidates) < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
l, r := candidates[i], candidates[j]
|
||||
|
||||
// If we prefer IPv6 and left is IPv6 and right is not, so then
|
||||
// left is smaller (would be before right).
|
||||
if preferIPv6 {
|
||||
if l.Ipv6 && !r.Ipv6 {
|
||||
return true
|
||||
}
|
||||
if !l.Ipv6 && r.Ipv6 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Also we prefer static addresses.
|
||||
return l.Static && !r.Static
|
||||
})
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// FindPrimaryDCs searches new primary DC from given config.
|
||||
// Unlike FindDC, it filters CDNs and MediaOnly servers, returns error
|
||||
// if not found.
|
||||
func FindPrimaryDCs(opts []tg.DCOption, dcID int, preferIPv6 bool) []tg.DCOption {
|
||||
candidates := FindDCs(opts, dcID, preferIPv6)
|
||||
// Filter (in place) from SliceTricks.
|
||||
n := 0
|
||||
for _, opt := range candidates {
|
||||
if !opt.MediaOnly && !opt.CDN {
|
||||
candidates[n] = opt
|
||||
n++
|
||||
}
|
||||
}
|
||||
return candidates[:n]
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func TestFindDCs(t *testing.T) {
|
||||
dcOptions := []tg.DCOption{
|
||||
{ID: 1, Ipv6: false},
|
||||
{ID: 1, Ipv6: true},
|
||||
{ID: 1, Ipv6: false, Static: true},
|
||||
|
||||
{ID: 2, Ipv6: true, Static: true},
|
||||
{ID: 2, Ipv6: true},
|
||||
{ID: 2, Ipv6: false},
|
||||
}
|
||||
for i := range dcOptions {
|
||||
dcOptions[i].IPAddress = fmt.Sprintf("DC: %d, Index: %d", dcOptions[i].ID, i)
|
||||
}
|
||||
|
||||
a := require.New(t)
|
||||
dc := FindDCs(dcOptions, -2, false)
|
||||
a.Empty(dc)
|
||||
dc = FindDCs(dcOptions, -2, true)
|
||||
a.Empty(dc)
|
||||
|
||||
// Prefer IPv6.
|
||||
dc = FindDCs(dcOptions, 1, true)
|
||||
a.True(dc[0].Ipv6)
|
||||
|
||||
// Prefer static.
|
||||
dc = FindDCs(dcOptions, 1, false)
|
||||
a.True(dc[0].Static)
|
||||
|
||||
// Prefer static and IPv6.
|
||||
dc = FindDCs(dcOptions, 2, true)
|
||||
a.True(dc[0].Static)
|
||||
a.True(dc[0].Ipv6)
|
||||
}
|
||||
|
||||
func TestFindPrimaryDCs(t *testing.T) {
|
||||
dcOptions := []tg.DCOption{
|
||||
{ID: 1, Ipv6: false},
|
||||
{ID: 1, Ipv6: true},
|
||||
{ID: 1, Ipv6: false, Static: true},
|
||||
|
||||
{ID: 2, Ipv6: true, Static: true, MediaOnly: true},
|
||||
{ID: 2, Ipv6: true, CDN: true},
|
||||
{ID: 2, Ipv6: false, CDN: true},
|
||||
}
|
||||
for i := range dcOptions {
|
||||
dcOptions[i].IPAddress = fmt.Sprintf("DC: %d, Index: %d", dcOptions[i].ID, i)
|
||||
}
|
||||
a := require.New(t)
|
||||
dc := FindPrimaryDCs(dcOptions, -2, false)
|
||||
a.Empty(dc)
|
||||
dc = FindPrimaryDCs(dcOptions, -2, true)
|
||||
a.Empty(dc)
|
||||
|
||||
// Prefer IPv6.
|
||||
dc = FindPrimaryDCs(dcOptions, 1, true)
|
||||
a.True(dc[0].Ipv6)
|
||||
|
||||
// Prefer static.
|
||||
dc = FindPrimaryDCs(dcOptions, 1, false)
|
||||
a.True(dc[0].Static)
|
||||
|
||||
// Filter CDN/MediaOnly/TCPo.
|
||||
dc = FindPrimaryDCs(dcOptions, 2, false)
|
||||
a.Empty(dc)
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package dcs
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
|
||||
// List is a list of Telegram DC addresses and domains.
|
||||
type List struct {
|
||||
Options []tg.DCOption
|
||||
Domains map[int]string
|
||||
Test bool
|
||||
}
|
||||
|
||||
// Zero returns true if this List is zero value.
|
||||
func (d List) Zero() bool {
|
||||
return d.Options == nil && d.Domains == nil && !d.Test
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy/obfuscator"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
var _ Resolver = mtProxy{}
|
||||
|
||||
type mtProxy struct {
|
||||
dial DialFunc
|
||||
protocol protocol
|
||||
addr, network string
|
||||
|
||||
secret mtproxy.Secret
|
||||
tag [4]byte
|
||||
rand io.Reader
|
||||
}
|
||||
|
||||
func (m mtProxy) Primary(ctx context.Context, dc int, _ List) (transport.Conn, error) {
|
||||
return m.resolve(ctx, dc)
|
||||
}
|
||||
|
||||
func (m mtProxy) MediaOnly(ctx context.Context, dc int, _ List) (transport.Conn, error) {
|
||||
if dc > 0 {
|
||||
dc *= -1
|
||||
}
|
||||
return m.resolve(ctx, dc)
|
||||
}
|
||||
|
||||
func (m mtProxy) CDN(ctx context.Context, dc int, _ List) (transport.Conn, error) {
|
||||
return m.resolve(ctx, dc)
|
||||
}
|
||||
|
||||
func (m mtProxy) resolve(ctx context.Context, dc int) (transport.Conn, error) {
|
||||
c, err := m.dial(ctx, m.network, m.addr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "connect to the MTProxy %q", m.addr)
|
||||
}
|
||||
|
||||
conn, err := m.handshakeConn(c, dc)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "handshake")
|
||||
return nil, multierr.Combine(err, c.Close())
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// handshakeConn inits given net.Conn as MTProto connection.
|
||||
func (m mtProxy) handshakeConn(c net.Conn, dc int) (transport.Conn, error) {
|
||||
var obsConn *obfuscator.Conn
|
||||
switch m.secret.Type {
|
||||
case mtproxy.Simple, mtproxy.Secured:
|
||||
obsConn = obfuscator.Obfuscated2(m.rand, c)
|
||||
case mtproxy.TLS:
|
||||
obsConn = obfuscator.FakeTLS(m.rand, c)
|
||||
default:
|
||||
return nil, errors.Errorf("unknown MTProxy secret type: %d", m.secret.Type)
|
||||
}
|
||||
|
||||
secret := m.secret
|
||||
if err := obsConn.Handshake(m.tag, dc, secret); err != nil {
|
||||
return nil, errors.Wrap(err, "MTProxy handshake")
|
||||
}
|
||||
|
||||
transportConn, err := m.protocol.Handshake(obsConn)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "transport handshake")
|
||||
}
|
||||
|
||||
return transportConn, nil
|
||||
}
|
||||
|
||||
// MTProxyOptions is MTProxy resolver creation options.
|
||||
type MTProxyOptions struct {
|
||||
// Dial specifies the dial function for creating unencrypted TCP connections.
|
||||
// If Dial is nil, then the resolver dials using package net.
|
||||
Dial DialFunc
|
||||
// Network to use. Defaults to "tcp"
|
||||
Network string
|
||||
// Random source for MTProxy obfuscator.
|
||||
Rand io.Reader
|
||||
}
|
||||
|
||||
func (m *MTProxyOptions) setDefaults() {
|
||||
if m.Dial == nil {
|
||||
var d net.Dialer
|
||||
m.Dial = d.DialContext
|
||||
}
|
||||
if m.Network == "" {
|
||||
m.Network = "tcp"
|
||||
}
|
||||
if m.Rand == nil {
|
||||
m.Rand = crypto.DefaultRand()
|
||||
}
|
||||
}
|
||||
|
||||
// MTProxy creates MTProxy obfuscated DC resolver.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports#transport-obfuscation.
|
||||
func MTProxy(addr string, secret []byte, opts MTProxyOptions) (Resolver, error) {
|
||||
s, err := mtproxy.ParseSecret(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cdc codec.Codec = codec.PaddedIntermediate{}
|
||||
tag := codec.PaddedIntermediateClientStart
|
||||
|
||||
// FIXME(tdakkota): some proxies forces to use Padded (Secure) Intermediate
|
||||
// even if secret denotes to use another transport type.
|
||||
if s.Type != mtproxy.TLS {
|
||||
if c, ok := s.ExpectedCodec(); ok {
|
||||
cdc = c
|
||||
tag = [4]byte{s.Tag, s.Tag, s.Tag, s.Tag}
|
||||
}
|
||||
}
|
||||
|
||||
opts.setDefaults()
|
||||
return mtProxy{
|
||||
dial: opts.Dial,
|
||||
addr: addr,
|
||||
network: opts.Network,
|
||||
protocol: transport.NewProtocol(func() transport.Codec { return codec.NoHeader{Codec: cdc} }),
|
||||
secret: s,
|
||||
tag: tag,
|
||||
rand: opts.Rand,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy/obfuscator"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
var _ Resolver = plain{}
|
||||
|
||||
type plain struct {
|
||||
dial DialFunc
|
||||
protocol Protocol
|
||||
rand io.Reader
|
||||
network string
|
||||
noObfuscated bool
|
||||
preferIPv6 bool
|
||||
}
|
||||
|
||||
func (p plain) Primary(ctx context.Context, dc int, list List) (transport.Conn, error) {
|
||||
candidates := FindPrimaryDCs(list.Options, dc, p.preferIPv6)
|
||||
if p.noObfuscated {
|
||||
n := 0
|
||||
for _, x := range candidates {
|
||||
if !x.TCPObfuscatedOnly {
|
||||
candidates[n] = x
|
||||
n++
|
||||
}
|
||||
}
|
||||
candidates = candidates[:n]
|
||||
}
|
||||
return p.connect(ctx, dc, list.Test, candidates)
|
||||
}
|
||||
|
||||
func (p plain) MediaOnly(ctx context.Context, dc int, list List) (transport.Conn, error) {
|
||||
candidates := FindDCs(list.Options, dc, p.preferIPv6)
|
||||
// Filter (in place) from SliceTricks.
|
||||
n := 0
|
||||
for _, x := range candidates {
|
||||
if x.MediaOnly {
|
||||
candidates[n] = x
|
||||
n++
|
||||
}
|
||||
}
|
||||
return p.connect(ctx, dc, list.Test, candidates[:n])
|
||||
}
|
||||
|
||||
func (p plain) CDN(ctx context.Context, dc int, list List) (transport.Conn, error) {
|
||||
return nil, errors.Errorf("can't resolve %d: CDN is unsupported", dc)
|
||||
}
|
||||
|
||||
func (p plain) dialTransport(ctx context.Context, test bool, dc tg.DCOption) (_ transport.Conn, rerr error) {
|
||||
addr := net.JoinHostPort(dc.IPAddress, strconv.Itoa(dc.Port))
|
||||
|
||||
conn, err := p.dial(ctx, p.network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if rerr != nil {
|
||||
multierr.AppendInto(&rerr, conn.Close())
|
||||
}
|
||||
}()
|
||||
|
||||
proto := p.protocol
|
||||
if dc.TCPObfuscatedOnly {
|
||||
dcID := dc.ID
|
||||
if test {
|
||||
if dcID < 0 {
|
||||
dcID -= 10000
|
||||
} else {
|
||||
dcID += 10000
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
cdc codec.Codec = codec.Intermediate{}
|
||||
tag = codec.IntermediateClientStart
|
||||
obfs = obfuscator.Obfuscated2
|
||||
)
|
||||
|
||||
secret, err := mtproxy.ParseSecret(dc.Secret)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "check DC secret")
|
||||
}
|
||||
|
||||
if secret.Type == mtproxy.TLS {
|
||||
obfs = obfuscator.FakeTLS
|
||||
} else if c, ok := secret.ExpectedCodec(); ok {
|
||||
tag = [4]byte{secret.Tag, secret.Tag, secret.Tag, secret.Tag}
|
||||
cdc = c
|
||||
}
|
||||
|
||||
obfsConn := obfs(p.rand, conn)
|
||||
if err := obfsConn.Handshake(tag, dcID, secret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn = obfsConn
|
||||
|
||||
proto = transport.NewProtocol(func() transport.Codec {
|
||||
return codec.NoHeader{Codec: cdc}
|
||||
})
|
||||
}
|
||||
|
||||
transportConn, err := proto.Handshake(conn)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "transport handshake")
|
||||
}
|
||||
|
||||
return transportConn, nil
|
||||
}
|
||||
|
||||
func (p plain) connect(ctx context.Context, dc int, test bool, dcOptions []tg.DCOption) (transport.Conn, error) {
|
||||
switch len(dcOptions) {
|
||||
case 0:
|
||||
return nil, errors.Errorf("no addresses for DC %d", dc)
|
||||
case 1:
|
||||
return p.dialTransport(ctx, test, dcOptions[0])
|
||||
}
|
||||
|
||||
type dialResult struct {
|
||||
conn transport.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
// We use unbuffered channel to ensure that only one connection will be returned
|
||||
// and all other will be closed.
|
||||
results := make(chan dialResult)
|
||||
tryDial := func(ctx context.Context, option tg.DCOption) {
|
||||
conn, err := p.dialTransport(ctx, test, option)
|
||||
select {
|
||||
case results <- dialResult{
|
||||
conn: conn,
|
||||
err: err,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dialCtx, dialCancel := context.WithCancel(ctx)
|
||||
defer dialCancel()
|
||||
|
||||
for _, dcOption := range dcOptions {
|
||||
go tryDial(dialCtx, dcOption)
|
||||
}
|
||||
|
||||
remain := len(dcOptions)
|
||||
var rErr error
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case result := <-results:
|
||||
remain--
|
||||
if result.err != nil {
|
||||
rErr = multierr.Append(rErr, result.err)
|
||||
if remain == 0 {
|
||||
return nil, rErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
return result.conn, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PlainOptions is plain resolver creation options.
|
||||
type PlainOptions struct {
|
||||
// Protocol is the transport protocol to use. Defaults to intermediate.
|
||||
Protocol Protocol
|
||||
// Dial specifies the dial function for creating unencrypted TCP connections.
|
||||
// If Dial is nil, then the resolver dials using package net.
|
||||
Dial DialFunc
|
||||
// Random source for TCPObfuscated DCs.
|
||||
Rand io.Reader
|
||||
// Network to use. Defaults to "tcp".
|
||||
Network string
|
||||
// NoObfuscated denotes to filter out TCP Obfuscated Only DCs.
|
||||
NoObfuscated bool
|
||||
// PreferIPv6 gives IPv6 DCs higher precedence.
|
||||
// Default is to prefer IPv4 DCs over IPv6.
|
||||
PreferIPv6 bool
|
||||
}
|
||||
|
||||
func (m *PlainOptions) setDefaults() {
|
||||
if m.Protocol == nil {
|
||||
m.Protocol = transport.Intermediate
|
||||
}
|
||||
if m.Dial == nil {
|
||||
var d net.Dialer
|
||||
m.Dial = d.DialContext
|
||||
}
|
||||
if m.Rand == nil {
|
||||
m.Rand = crypto.DefaultRand()
|
||||
}
|
||||
if m.Network == "" {
|
||||
m.Network = "tcp"
|
||||
}
|
||||
}
|
||||
|
||||
// Plain creates plain DC resolver.
|
||||
func Plain(opts PlainOptions) Resolver {
|
||||
opts.setDefaults()
|
||||
return plain{
|
||||
dial: opts.Dial,
|
||||
protocol: opts.Protocol,
|
||||
rand: opts.Rand,
|
||||
network: opts.Network,
|
||||
noObfuscated: opts.NoObfuscated,
|
||||
preferIPv6: opts.PreferIPv6,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Prod returns production DC list.
|
||||
func Prod() List {
|
||||
// https://github.com/telegramdesktop/tdesktop/blob/dev/Telegram/SourceFiles/mtproto/mtproto_dc_options.cpp
|
||||
// Also available with client.API().HelpGetConfig(ctx) [tg.DCOption].
|
||||
// TODO(ernado): automate update from HelpGetConfig.
|
||||
return List{
|
||||
Options: []tg.DCOption{
|
||||
{
|
||||
ID: 1,
|
||||
IPAddress: "149.154.175.52",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Static: true,
|
||||
ID: 1,
|
||||
IPAddress: "149.154.175.53",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Ipv6: true,
|
||||
ID: 1,
|
||||
IPAddress: "2001:0b28:f23d:f001:0000:0000:0000:000a",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
IPAddress: "149.154.167.41",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Static: true,
|
||||
ID: 2,
|
||||
IPAddress: "149.154.167.41",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
MediaOnly: true,
|
||||
ID: 2,
|
||||
IPAddress: "149.154.167.222",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Ipv6: true,
|
||||
ID: 2,
|
||||
IPAddress: "2001:067c:04e8:f002:0000:0000:0000:000a",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Ipv6: true,
|
||||
MediaOnly: true,
|
||||
ID: 2,
|
||||
IPAddress: "2001:067c:04e8:f002:0000:0000:0000:000b",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
IPAddress: "149.154.175.100",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Static: true,
|
||||
ID: 3,
|
||||
IPAddress: "149.154.175.100",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Ipv6: true,
|
||||
ID: 3,
|
||||
IPAddress: "2001:0b28:f23d:f003:0000:0000:0000:000a",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
ID: 4,
|
||||
IPAddress: "149.154.167.91",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Static: true,
|
||||
ID: 4,
|
||||
IPAddress: "149.154.167.91",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Ipv6: true,
|
||||
ID: 4,
|
||||
IPAddress: "2001:067c:04e8:f004:0000:0000:0000:000a",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
MediaOnly: true,
|
||||
ID: 4,
|
||||
IPAddress: "149.154.166.120",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Ipv6: true,
|
||||
MediaOnly: true,
|
||||
ID: 4,
|
||||
IPAddress: "2001:067c:04e8:f004:0000:0000:0000:000b",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Ipv6: true,
|
||||
ID: 5,
|
||||
IPAddress: "2001:0b28:f23f:f005:0000:0000:0000:000a",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
ID: 5,
|
||||
IPAddress: "91.108.56.191",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
Static: true,
|
||||
ID: 5,
|
||||
IPAddress: "91.108.56.191",
|
||||
Port: 443,
|
||||
},
|
||||
},
|
||||
Domains: map[int]string{
|
||||
1: "wss://pluto.web.telegram.org/apiws",
|
||||
2: "wss://venus.web.telegram.org/apiws",
|
||||
3: "wss://aurora.web.telegram.org/apiws",
|
||||
4: "wss://vesta.web.telegram.org/apiws",
|
||||
5: "wss://flora.web.telegram.org/apiws",
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProd(t *testing.T) {
|
||||
require.NotEmpty(t, Prod())
|
||||
|
||||
// Check copying.
|
||||
a := Prod().Options
|
||||
a[0].IPAddress = "10"
|
||||
b := Prod().Options
|
||||
require.NotEqual(t, "10", b[0].IPAddress)
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
type protocol interface {
|
||||
Handshake(conn net.Conn) (transport.Conn, error)
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
)
|
||||
|
||||
var _ Resolver = DefaultResolver()
|
||||
|
||||
// Resolver resolves DC and creates transport MTProto connection.
|
||||
type Resolver interface {
|
||||
Primary(ctx context.Context, dc int, list List) (transport.Conn, error)
|
||||
MediaOnly(ctx context.Context, dc int, list List) (transport.Conn, error)
|
||||
CDN(ctx context.Context, dc int, list List) (transport.Conn, error)
|
||||
}
|
||||
|
||||
// DialFunc connects to the address on the named network.
|
||||
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// Protocol is MTProto transport protocol.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/mtproto-transports
|
||||
type Protocol interface {
|
||||
Codec() transport.Codec
|
||||
Handshake(conn net.Conn) (transport.Conn, error)
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
//go:build js && wasm
|
||||
// +build js,wasm
|
||||
|
||||
package dcs
|
||||
|
||||
// DefaultResolver returns default DC resolver for current platform.
|
||||
func DefaultResolver() Resolver {
|
||||
return Websocket(WebsocketOptions{})
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
//go:build !js || !wasm
|
||||
// +build !js !wasm
|
||||
|
||||
package dcs
|
||||
|
||||
// DefaultResolver returns default DC resolver for current platform.
|
||||
func DefaultResolver() Resolver {
|
||||
return Plain(PlainOptions{})
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package dcs
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
|
||||
// Staging returns staging DC list.
|
||||
//
|
||||
// Deprecated: Use Test().
|
||||
func Staging() List {
|
||||
return Test()
|
||||
}
|
||||
|
||||
// Test returns test DC list.
|
||||
func Test() List {
|
||||
return List{
|
||||
Options: []tg.DCOption{
|
||||
{
|
||||
ID: 1,
|
||||
IPAddress: "149.154.175.10",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
Ipv6: true,
|
||||
IPAddress: "2001:0b28:f23d:f001:0000:0000:0000:000e",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
IPAddress: "149.154.167.40",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Ipv6: true,
|
||||
IPAddress: "2001:067c:04e8:f002:0000:0000:0000:000e",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
IPAddress: "149.154.175.117",
|
||||
Port: 443,
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Ipv6: true,
|
||||
IPAddress: "2001:0b28:f23d:f003:0000:0000:0000:000e",
|
||||
Port: 443,
|
||||
},
|
||||
},
|
||||
Domains: map[int]string{
|
||||
1: "wss://pluto.web.telegram.org/apiws_test",
|
||||
2: "wss://venus.web.telegram.org/apiws_test",
|
||||
3: "wss://aurora.web.telegram.org/apiws_test",
|
||||
4: "wss://vesta.web.telegram.org/apiws_test",
|
||||
5: "wss://flora.web.telegram.org/apiws_test",
|
||||
},
|
||||
Test: true,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTestDCs(t *testing.T) {
|
||||
require.NotEmpty(t, Prod())
|
||||
|
||||
// Check copying.
|
||||
a := Test().Options
|
||||
a[0].IPAddress = "10"
|
||||
b := Test().Options
|
||||
require.NotEqual(t, "10", b[0].IPAddress)
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package dcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy/obfuscator"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/wsutil"
|
||||
)
|
||||
|
||||
var _ Resolver = ws{}
|
||||
|
||||
type ws struct {
|
||||
dialOptions *websocket.DialOptions
|
||||
protocol protocol
|
||||
|
||||
tag [4]byte
|
||||
rand io.Reader
|
||||
}
|
||||
|
||||
func (w ws) connect(ctx context.Context, dc int, domains map[int]string) (transport.Conn, error) {
|
||||
addr, ok := domains[dc]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("domain for %d not found", dc)
|
||||
}
|
||||
|
||||
conn, resp, err := websocket.Dial(ctx, addr, w.dialOptions)
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "dial ws")
|
||||
}
|
||||
obsConn := obfuscator.Obfuscated2(w.rand, wsutil.NetConn(conn))
|
||||
|
||||
if err := obsConn.Handshake(w.tag, dc, mtproxy.Secret{
|
||||
Secret: nil,
|
||||
Type: mtproxy.Simple,
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "handshake")
|
||||
}
|
||||
|
||||
transportConn, err := w.protocol.Handshake(obsConn)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "transport handshake")
|
||||
}
|
||||
|
||||
return transportConn, nil
|
||||
}
|
||||
|
||||
func (w ws) Primary(ctx context.Context, dc int, list List) (transport.Conn, error) {
|
||||
return w.connect(ctx, dc, list.Domains)
|
||||
}
|
||||
|
||||
func (w ws) MediaOnly(ctx context.Context, dc int, list List) (transport.Conn, error) {
|
||||
return nil, errors.Errorf("can't resolve %d: MediaOnly is unsupported", dc)
|
||||
}
|
||||
|
||||
func (w ws) CDN(ctx context.Context, dc int, list List) (transport.Conn, error) {
|
||||
return nil, errors.Errorf("can't resolve %d: CDN is unsupported", dc)
|
||||
}
|
||||
|
||||
// WebsocketOptions is Websocket resolver creation options.
|
||||
type WebsocketOptions struct {
|
||||
// Dialer specifies the websocket dialer.
|
||||
// If Dialer is nil, then the resolver dials using websocket.DefaultDialer.
|
||||
DialOptions *websocket.DialOptions
|
||||
// Random source for MTProxy obfuscator.
|
||||
Rand io.Reader
|
||||
}
|
||||
|
||||
func (m *WebsocketOptions) setDefaults() {
|
||||
if m.DialOptions == nil {
|
||||
m.DialOptions = &websocket.DialOptions{Subprotocols: []string{
|
||||
"binary",
|
||||
}}
|
||||
}
|
||||
if m.Rand == nil {
|
||||
m.Rand = crypto.DefaultRand()
|
||||
}
|
||||
}
|
||||
|
||||
// Websocket creates Websocket DC resolver.
|
||||
//
|
||||
// See https://core.telegram.org/mtproto/transports#websocket.
|
||||
func Websocket(opts WebsocketOptions) Resolver {
|
||||
cdc := codec.Intermediate{}
|
||||
opts.setDefaults()
|
||||
|
||||
return ws{
|
||||
dialOptions: opts.DialOptions,
|
||||
protocol: transport.NewProtocol(func() transport.Codec { return codec.NoHeader{Codec: cdc} }),
|
||||
tag: cdc.ObfuscatedTag(),
|
||||
rand: opts.Rand,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
package telegram
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/manager"
|
||||
|
||||
// DeviceConfig is config which send when Telegram connection session created.
|
||||
type DeviceConfig = manager.DeviceConfig
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package telegram implements Telegram client.
|
||||
package telegram
|
||||
@@ -0,0 +1,87 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Builder is a download builder.
|
||||
type Builder struct {
|
||||
downloader *Downloader
|
||||
|
||||
schema schema
|
||||
hashes []tg.FileHash
|
||||
verify bool
|
||||
threads int
|
||||
}
|
||||
|
||||
func newBuilder(downloader *Downloader, schema schema) *Builder {
|
||||
return &Builder{
|
||||
schema: schema,
|
||||
threads: 1,
|
||||
downloader: downloader,
|
||||
}
|
||||
}
|
||||
|
||||
// WithThreads sets downloading goroutines limit.
|
||||
func (b *Builder) WithThreads(threads int) *Builder {
|
||||
if threads > 0 {
|
||||
b.threads = threads
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithVerify sets verify parameter.
|
||||
// If verify is true, file hashes will be checked
|
||||
// Verify is true by default for CDN downloads.
|
||||
func (b *Builder) WithVerify(verify bool) *Builder {
|
||||
b.verify = verify
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) reader() *reader {
|
||||
if b.verify {
|
||||
return verifiedReader(b.schema, newVerifier(b.schema, b.hashes...))
|
||||
}
|
||||
|
||||
return plainReader(b.schema, b.downloader.partSize)
|
||||
}
|
||||
|
||||
// Stream downloads file to given io.Writer.
|
||||
// NB: in this mode download can't be parallel.
|
||||
func (b *Builder) Stream(ctx context.Context, output io.Writer) (tg.StorageFileTypeClass, error) {
|
||||
return b.downloader.stream(ctx, b.reader(), output)
|
||||
}
|
||||
|
||||
// StreamToReader streams a file to the returned [io.Reader].
|
||||
// NB: in this mode download can't be parallel.
|
||||
func (b *Builder) StreamToReader(ctx context.Context) (tg.StorageFileTypeClass, io.Reader, error) {
|
||||
var tgDC int
|
||||
ctx = context.WithValue(ctx, "tg_dc", &tgDC)
|
||||
return b.downloader.streamToReader(ctx, b.reader())
|
||||
}
|
||||
|
||||
// Parallel downloads file to given io.WriterAt.
|
||||
func (b *Builder) Parallel(ctx context.Context, output io.WriterAt) (tg.StorageFileTypeClass, error) {
|
||||
return b.downloader.parallel(ctx, b.reader(), b.threads, output)
|
||||
}
|
||||
|
||||
// ToPath downloads file to given path.
|
||||
func (b *Builder) ToPath(ctx context.Context, path string) (_ tg.StorageFileTypeClass, err error) {
|
||||
f, err := os.Create(filepath.Clean(path))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create output file")
|
||||
}
|
||||
defer func() {
|
||||
multierr.AppendInto(&err, f.Close())
|
||||
}()
|
||||
|
||||
return b.Parallel(ctx, f)
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// ExpiredTokenError error is returned when Downloader get expired file token for CDN.
|
||||
// See https://core.telegram.org/constructor/upload.fileCdnRedirect.
|
||||
type ExpiredTokenError struct {
|
||||
*tg.UploadCDNFileReuploadNeeded
|
||||
}
|
||||
|
||||
// Error implements error interface.
|
||||
func (r *ExpiredTokenError) Error() string {
|
||||
return "redirect to master DC for requesting new file token"
|
||||
}
|
||||
|
||||
// cdn is a CDN DC download schema.
|
||||
// See https://core.telegram.org/cdn#getting-files-from-a-cdn.
|
||||
type cdn struct {
|
||||
cdn CDN
|
||||
client Client
|
||||
pool *bin.Pool
|
||||
redirect *tg.UploadFileCDNRedirect
|
||||
}
|
||||
|
||||
var _ schema = cdn{}
|
||||
|
||||
// decrypt decrypts file chunk from Telegram CDN.
|
||||
// See https://core.telegram.org/cdn#decrypting-files.
|
||||
func (c cdn) decrypt(src []byte, offset int64) ([]byte, error) {
|
||||
block, err := aes.NewCipher(c.redirect.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create cipher")
|
||||
}
|
||||
|
||||
if block.BlockSize() != len(c.redirect.EncryptionIv) {
|
||||
return nil, errors.Errorf(
|
||||
"invalid IV or key length, block size %d != IV %d",
|
||||
block.BlockSize(), len(c.redirect.EncryptionIv),
|
||||
)
|
||||
}
|
||||
|
||||
// Copy IV to buffer from Pool.
|
||||
iv := c.pool.GetSize(len(c.redirect.EncryptionIv))
|
||||
defer c.pool.Put(iv)
|
||||
copy(iv.Buf, c.redirect.EncryptionIv)
|
||||
|
||||
// For IV, it should use the value of encryption_iv, modified in the following manner:
|
||||
// for each offset replace the last 4 bytes of the encryption_iv with offset / 16 in big-endian.
|
||||
binary.BigEndian.PutUint32(iv.Buf[iv.Len()-4:], uint32(offset/16))
|
||||
|
||||
dst := make([]byte, len(src))
|
||||
cipher.NewCTR(block, iv.Buf).XORKeyStream(dst, src)
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
func (c cdn) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
|
||||
r, err := c.cdn.UploadGetCDNFile(ctx, &tg.UploadGetCDNFileRequest{
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
FileToken: c.redirect.FileToken,
|
||||
})
|
||||
if err != nil {
|
||||
return chunk{}, err
|
||||
}
|
||||
|
||||
switch result := r.(type) {
|
||||
case *tg.UploadCDNFile:
|
||||
data, err := c.decrypt(result.Bytes, offset)
|
||||
if err != nil {
|
||||
return chunk{}, err
|
||||
}
|
||||
|
||||
return chunk{
|
||||
data: data,
|
||||
}, nil
|
||||
case *tg.UploadCDNFileReuploadNeeded:
|
||||
return chunk{}, &ExpiredTokenError{UploadCDNFileReuploadNeeded: result}
|
||||
default:
|
||||
return chunk{}, errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
func (c cdn) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
|
||||
return c.client.UploadGetCDNFileHashes(ctx, &tg.UploadGetCDNFileHashesRequest{
|
||||
FileToken: c.redirect.FileToken,
|
||||
Offset: offset,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func Test_cdn_decrypt(t *testing.T) {
|
||||
testdata := make([]byte, 32)
|
||||
tests := []struct {
|
||||
name string
|
||||
key, iv []byte
|
||||
err bool
|
||||
}{
|
||||
{"Bad key", []byte{10}, nil, true},
|
||||
{"Bad IV", make([]byte, 32), nil, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
c := &cdn{
|
||||
redirect: &tg.UploadFileCDNRedirect{
|
||||
EncryptionKey: test.key,
|
||||
EncryptionIv: test.iv,
|
||||
},
|
||||
}
|
||||
_, err := c.decrypt(testdata, 0)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// CDN represents Telegram RPC client to CDN server.
|
||||
type CDN interface {
|
||||
UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error)
|
||||
}
|
||||
|
||||
// Client represents Telegram RPC client.
|
||||
type Client interface {
|
||||
UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error)
|
||||
UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error)
|
||||
|
||||
UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error)
|
||||
UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error)
|
||||
|
||||
UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error)
|
||||
}
|
||||
|
||||
type chunk struct {
|
||||
data []byte
|
||||
tag tg.StorageFileTypeClass
|
||||
}
|
||||
|
||||
// schema is simple interface for different download schemas.
|
||||
type schema interface {
|
||||
Chunk(ctx context.Context, offset int64, limit int) (chunk, error)
|
||||
Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error)
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// Package downloader contains downloading files helpers.
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Downloader is Telegram file downloader.
|
||||
type Downloader struct {
|
||||
partSize int
|
||||
pool *bin.Pool
|
||||
}
|
||||
|
||||
const defaultPartSize = 512 * 1024 // 512 kb
|
||||
|
||||
// NewDownloader creates new Downloader.
|
||||
func NewDownloader() *Downloader {
|
||||
return new(Downloader).WithPartSize(defaultPartSize)
|
||||
}
|
||||
|
||||
// WithPartSize sets chunk size.
|
||||
// Must be divisible by 4KB.
|
||||
//
|
||||
// See https://core.telegram.org/api/files#downloading-files.
|
||||
func (d *Downloader) WithPartSize(partSize int) *Downloader {
|
||||
d.partSize = partSize
|
||||
d.pool = bin.NewPool(partSize)
|
||||
return d
|
||||
}
|
||||
|
||||
// Download creates Builder for plain downloads.
|
||||
func (d *Downloader) Download(rpc Client, location tg.InputFileLocationClass) *Builder {
|
||||
return newBuilder(d, master{
|
||||
client: rpc,
|
||||
precise: true,
|
||||
allowCDN: false,
|
||||
location: location,
|
||||
})
|
||||
}
|
||||
|
||||
// Web creates Builder for web files downloads.
|
||||
func (d *Downloader) Web(rpc Client, location tg.InputWebFileLocationClass) *Builder {
|
||||
return newBuilder(d, web{
|
||||
client: rpc,
|
||||
location: location,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type mock struct {
|
||||
data []byte
|
||||
hashes mockHashes
|
||||
migrate bool
|
||||
err bool
|
||||
hashesErr bool
|
||||
redirect *tg.UploadFileCDNRedirect
|
||||
}
|
||||
|
||||
var testErr = testutil.TestError()
|
||||
|
||||
func (m mock) getPart(offset int64, limit int) []byte {
|
||||
length := len(m.data)
|
||||
if offset >= int64(length) {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
size := length - int(offset)
|
||||
if size > limit {
|
||||
size = limit
|
||||
}
|
||||
|
||||
r := make([]byte, size)
|
||||
copy(r, m.data[offset:])
|
||||
return r
|
||||
}
|
||||
|
||||
func (m mock) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) {
|
||||
if m.err {
|
||||
return nil, testErr
|
||||
}
|
||||
|
||||
if m.migrate {
|
||||
return m.redirect, nil
|
||||
}
|
||||
|
||||
return &tg.UploadFile{
|
||||
Bytes: m.getPart(request.Offset, request.Limit),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m mock) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error) {
|
||||
if m.hashesErr {
|
||||
return nil, testErr
|
||||
}
|
||||
|
||||
return m.hashes.Hashes(ctx, request.Offset)
|
||||
}
|
||||
|
||||
func (m mock) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mock) UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error) {
|
||||
if m.err {
|
||||
return nil, testErr
|
||||
}
|
||||
|
||||
if m.migrate {
|
||||
return &tg.UploadCDNFileReuploadNeeded{
|
||||
RequestToken: []byte{1, 2, 3},
|
||||
}, nil
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(m.redirect.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CDN mock cipher creation")
|
||||
}
|
||||
|
||||
iv := make([]byte, len(m.redirect.EncryptionIv))
|
||||
copy(iv, m.redirect.EncryptionIv)
|
||||
binary.BigEndian.PutUint32(iv[len(iv)-4:], uint32(request.Offset/16))
|
||||
|
||||
part := m.getPart(request.Offset, request.Limit)
|
||||
r := make([]byte, len(part))
|
||||
cipher.NewCTR(block, iv).XORKeyStream(r, part)
|
||||
return &tg.UploadCDNFile{
|
||||
Bytes: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m mock) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error) {
|
||||
if m.hashesErr {
|
||||
return nil, testErr
|
||||
}
|
||||
|
||||
return m.hashes.Hashes(ctx, request.Offset)
|
||||
}
|
||||
|
||||
func (m mock) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) {
|
||||
if m.err {
|
||||
return nil, testErr
|
||||
}
|
||||
|
||||
return &tg.UploadWebFile{
|
||||
Bytes: m.getPart(int64(request.Offset), request.Limit),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func countHashes(data []byte, partSize int) (r [][]tg.FileHash) {
|
||||
actions := data
|
||||
batchSize := partSize
|
||||
batches := make([][]byte, 0, (len(actions)+batchSize-1)/batchSize)
|
||||
|
||||
for batchSize < len(actions) {
|
||||
actions, batches = actions[batchSize:], append(batches, actions[0:batchSize:batchSize])
|
||||
}
|
||||
batches = append(batches, actions)
|
||||
|
||||
currentRange := make([]tg.FileHash, 0, 10)
|
||||
offset := 0
|
||||
for _, batch := range batches {
|
||||
if len(currentRange) >= 10 {
|
||||
r = append(r, currentRange)
|
||||
currentRange = make([]tg.FileHash, 0, 10)
|
||||
}
|
||||
currentRange = append(currentRange, tg.FileHash{
|
||||
Offset: int64(offset),
|
||||
Limit: partSize,
|
||||
Hash: crypto.SHA256(batch),
|
||||
})
|
||||
offset += len(batch)
|
||||
|
||||
if len(batch) < partSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
r = append(r, currentRange)
|
||||
return
|
||||
}
|
||||
|
||||
func Test_countHashes(t *testing.T) {
|
||||
a := require.New(t)
|
||||
data := bytes.Repeat([]byte{1, 2, 3, 4, 5}, 10)
|
||||
hashes := countHashes(data, 4)
|
||||
|
||||
a.NotEmpty(hashes)
|
||||
for _, hashRange := range hashes {
|
||||
for _, hash := range hashRange {
|
||||
from := hash.Offset
|
||||
to := int(hash.Offset) + hash.Limit
|
||||
if to > len(data) {
|
||||
to = len(data)
|
||||
}
|
||||
a.Equal(crypto.SHA256(data[from:to]), hash.Hash)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloader(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
key := make([]byte, 32)
|
||||
iv := make([]byte, aes.BlockSize)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
redirect := &tg.UploadFileCDNRedirect{
|
||||
DCID: 1,
|
||||
FileToken: []byte{10},
|
||||
EncryptionKey: key,
|
||||
EncryptionIv: iv,
|
||||
}
|
||||
|
||||
testData := make([]byte, defaultPartSize*2)
|
||||
if _, err := io.ReadFull(rand.Reader, testData); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
migrate bool
|
||||
err bool
|
||||
hashesErr bool
|
||||
}{
|
||||
{"5b", []byte{1, 2, 3, 4, 5}, false, false, false},
|
||||
{strconv.Itoa(len(testData)) + "b", testData, false, false, false},
|
||||
{"Error", []byte{}, false, true, false},
|
||||
{"HashesError", []byte{}, false, true, true},
|
||||
{"Migrate", []byte{}, true, false, false},
|
||||
}
|
||||
schemas := []struct {
|
||||
name string
|
||||
creator func(c Client, cdn CDN) *Builder
|
||||
}{
|
||||
{"Master", func(c Client, cdn CDN) *Builder {
|
||||
return NewDownloader().Download(c, nil)
|
||||
}},
|
||||
{"Web", func(c Client, cdn CDN) *Builder {
|
||||
return NewDownloader().Web(c, nil)
|
||||
}},
|
||||
}
|
||||
ways := []struct {
|
||||
name string
|
||||
action func(b *Builder) ([]byte, error)
|
||||
}{
|
||||
{"Stream", func(b *Builder) ([]byte, error) {
|
||||
output := new(bytes.Buffer)
|
||||
_, err := b.Stream(ctx, output)
|
||||
return output.Bytes(), err
|
||||
}},
|
||||
{"Parallel", func(b *Builder) ([]byte, error) {
|
||||
output := new(syncio.BufWriterAt)
|
||||
_, err := b.WithThreads(runtime.GOMAXPROCS(0)).Parallel(ctx, output)
|
||||
return output.Bytes(), err
|
||||
}},
|
||||
{"Parallel-OneThread", func(b *Builder) ([]byte, error) {
|
||||
output := new(syncio.BufWriterAt)
|
||||
_, err := b.WithThreads(1).Parallel(ctx, output)
|
||||
return output.Bytes(), err
|
||||
}},
|
||||
}
|
||||
options := []struct {
|
||||
name string
|
||||
action func(b *Builder) *Builder
|
||||
}{
|
||||
{"NoVerify", func(b *Builder) *Builder {
|
||||
return b.WithVerify(false)
|
||||
}},
|
||||
{"Verify", func(b *Builder) *Builder {
|
||||
return b.WithVerify(true)
|
||||
}},
|
||||
}
|
||||
|
||||
for _, schema := range schemas {
|
||||
t.Run(schema.name, func(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
// Telegram can't redirect web file downloads.
|
||||
if schema.name == "Web" && test.migrate {
|
||||
continue
|
||||
}
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
for _, option := range options {
|
||||
// Telegram can't return hashes for web files.
|
||||
if schema.name == "Web" && option.name == "Verify" {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Run(option.name, func(t *testing.T) {
|
||||
for _, way := range ways {
|
||||
t.Run(way.name, func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
client := &mock{
|
||||
data: test.data,
|
||||
hashes: mockHashes{
|
||||
ranges: countHashes(test.data, 128*1024),
|
||||
},
|
||||
migrate: test.migrate,
|
||||
err: test.err,
|
||||
redirect: redirect,
|
||||
}
|
||||
|
||||
b := schema.creator(client, client)
|
||||
b = option.action(b)
|
||||
data, err := way.action(b)
|
||||
switch {
|
||||
case test.migrate:
|
||||
a.Error(err)
|
||||
case test.err:
|
||||
a.Error(err)
|
||||
default:
|
||||
a.NoError(err)
|
||||
a.True(bytes.Equal(test.data, data))
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// RedirectError error is returned when Downloader get CDN redirect.
|
||||
// See https://core.telegram.org/constructor/upload.fileCdnRedirect.
|
||||
type RedirectError struct {
|
||||
Redirect *tg.UploadFileCDNRedirect
|
||||
}
|
||||
|
||||
// Error implements error interface.
|
||||
func (r *RedirectError) Error() string {
|
||||
return "redirect to CDN DC " + strconv.Itoa(r.Redirect.DCID)
|
||||
}
|
||||
|
||||
// master is a master DC download schema.
|
||||
// See https://core.telegram.org/api/files#downloading-files.
|
||||
type master struct {
|
||||
client Client
|
||||
|
||||
precise bool
|
||||
allowCDN bool
|
||||
location tg.InputFileLocationClass
|
||||
}
|
||||
|
||||
var _ schema = master{}
|
||||
|
||||
func (c master) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
|
||||
req := &tg.UploadGetFileRequest{
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
Location: c.location,
|
||||
}
|
||||
req.SetCDNSupported(c.allowCDN)
|
||||
req.SetPrecise(c.precise)
|
||||
|
||||
r, err := c.client.UploadGetFile(ctx, req)
|
||||
if err != nil {
|
||||
return chunk{}, err
|
||||
}
|
||||
|
||||
switch result := r.(type) {
|
||||
case *tg.UploadFile:
|
||||
return chunk{data: result.Bytes, tag: result.Type}, nil
|
||||
case *tg.UploadFileCDNRedirect:
|
||||
return chunk{}, &RedirectError{Redirect: result}
|
||||
default:
|
||||
return chunk{}, errors.Errorf("unexpected type %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
func (c master) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
|
||||
return c.client.UploadGetFileHashes(ctx, &tg.UploadGetFileHashesRequest{
|
||||
Location: c.location,
|
||||
Offset: offset,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// nolint:gocognit
|
||||
func (d *Downloader) parallel(
|
||||
ctx context.Context, r *reader,
|
||||
threads int, w io.WriterAt,
|
||||
) (tg.StorageFileTypeClass, error) {
|
||||
var typ tg.StorageFileTypeClass
|
||||
typOnce := &sync.Once{}
|
||||
|
||||
ready := tdsync.NewReady()
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
toWrite := make(chan block, threads)
|
||||
|
||||
stop := func(t tg.StorageFileTypeClass) {
|
||||
typOnce.Do(func() {
|
||||
typ = t
|
||||
})
|
||||
ready.Signal()
|
||||
}
|
||||
|
||||
// Download loop
|
||||
g.Go(func(ctx context.Context) error {
|
||||
downloads := tdsync.NewCancellableGroup(ctx)
|
||||
defer close(toWrite)
|
||||
|
||||
for i := 0; i < threads; i++ {
|
||||
downloads.Go(func(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ready.Ready():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
b, err := r.Next(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get file")
|
||||
}
|
||||
|
||||
// If returned chunk is zero, that means we read all file.
|
||||
n := len(b.data)
|
||||
if n < 1 {
|
||||
stop(b.tag)
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case toWrite <- b:
|
||||
}
|
||||
|
||||
if b.last() {
|
||||
stop(b.tag)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return downloads.Wait()
|
||||
})
|
||||
|
||||
// Write loop
|
||||
g.Go(writeAtLoop(syncio.NewWriterAt(w), toWrite))
|
||||
|
||||
return typ, g.Wait()
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
type block struct {
|
||||
chunk
|
||||
offset int64
|
||||
partSize int
|
||||
}
|
||||
|
||||
// last compares partSize and chunk length to determine last part.
|
||||
func (b block) last() bool {
|
||||
// If returned chunk is smaller than requested part, it seems
|
||||
// it is last part.
|
||||
return len(b.data) < b.partSize
|
||||
}
|
||||
|
||||
type reader struct {
|
||||
sch schema // immutable
|
||||
verifier *verifier // immutable
|
||||
partSize int // immutable
|
||||
|
||||
offset int64
|
||||
offsetMux sync.Mutex
|
||||
}
|
||||
|
||||
func verifiedReader(sch schema, verifier *verifier) *reader {
|
||||
return &reader{
|
||||
sch: sch,
|
||||
verifier: verifier,
|
||||
}
|
||||
}
|
||||
|
||||
func plainReader(sch schema, partSize int) *reader {
|
||||
return &reader{
|
||||
sch: sch,
|
||||
partSize: partSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *reader) Next(ctx context.Context) (block, error) {
|
||||
if r.verifier != nil {
|
||||
return r.nextHashed(ctx)
|
||||
}
|
||||
|
||||
return r.nextPlain(ctx)
|
||||
}
|
||||
|
||||
func (r *reader) nextHashed(ctx context.Context) (block, error) {
|
||||
// Fetch next hashes.
|
||||
hash, ok, err := r.verifier.next(ctx)
|
||||
if err != nil {
|
||||
return block{}, err
|
||||
}
|
||||
if !ok {
|
||||
return block{}, nil
|
||||
}
|
||||
|
||||
// Get next chunk.
|
||||
b, err := r.next(ctx, hash.Offset, hash.Limit)
|
||||
if err != nil {
|
||||
return block{}, err
|
||||
}
|
||||
|
||||
// Verify chunk.
|
||||
if !r.verifier.verify(hash, b.data) {
|
||||
return block{}, ErrHashMismatch
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (r *reader) nextPlain(ctx context.Context) (block, error) {
|
||||
r.offsetMux.Lock()
|
||||
offset := r.offset
|
||||
r.offset += int64(r.partSize)
|
||||
r.offsetMux.Unlock()
|
||||
|
||||
return r.next(ctx, offset, r.partSize)
|
||||
}
|
||||
|
||||
func (r *reader) next(ctx context.Context, offset int64, limit int) (block, error) {
|
||||
for {
|
||||
ch, err := r.sch.Chunk(ctx, offset, limit)
|
||||
|
||||
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
|
||||
if flood || tgerr.Is(err, tg.ErrTimeout) {
|
||||
continue
|
||||
}
|
||||
return block{}, errors.Wrap(err, "get next chunk")
|
||||
}
|
||||
|
||||
return block{
|
||||
chunk: ch,
|
||||
offset: offset,
|
||||
partSize: r.partSize,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
)
|
||||
|
||||
func writeAtLoop(w io.WriterAt, toWrite <-chan block) func(context.Context) error {
|
||||
return func(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case part, ok := <-toWrite:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := w.WriteAt(part.data, part.offset)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "write output")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeLoop(w io.Writer, toWrite <-chan block) func(context.Context) error {
|
||||
return func(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case part, ok := <-toWrite:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := w.Write(part.data)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "write output")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func (d *Downloader) stream(ctx context.Context, r *reader, w io.Writer) (tg.StorageFileTypeClass, error) {
|
||||
var typ tg.StorageFileTypeClass
|
||||
|
||||
g := tdsync.NewCancellableGroup(ctx)
|
||||
toWrite := make(chan block, 1)
|
||||
|
||||
stop := func(t tg.StorageFileTypeClass) {
|
||||
typ = t
|
||||
close(toWrite)
|
||||
}
|
||||
// Download loop
|
||||
g.Go(func(ctx context.Context) error {
|
||||
for {
|
||||
b, err := r.Next(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get file")
|
||||
}
|
||||
|
||||
n := len(b.data)
|
||||
if n < 1 {
|
||||
stop(b.tag)
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case toWrite <- b:
|
||||
}
|
||||
|
||||
if b.last() {
|
||||
stop(b.tag)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Write loop
|
||||
g.Go(writeLoop(w, toWrite))
|
||||
|
||||
return typ, g.Wait()
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type streamReader struct {
|
||||
ctx context.Context
|
||||
reader *reader
|
||||
curBlock block
|
||||
last bool
|
||||
}
|
||||
|
||||
var _ io.Reader = (*streamReader)(nil)
|
||||
|
||||
func (s *streamReader) Read(p []byte) (n int, err error) {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return 0, s.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if len(s.curBlock.data) == 0 {
|
||||
if s.last {
|
||||
return 0, io.EOF
|
||||
} else {
|
||||
s.curBlock, err = s.reader.Next(s.ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
s.last = s.curBlock.last()
|
||||
}
|
||||
}
|
||||
|
||||
n = copy(p, s.curBlock.data)
|
||||
s.curBlock.data = s.curBlock.data[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Downloader) streamToReader(ctx context.Context, r *reader) (tg.StorageFileTypeClass, io.Reader, error) {
|
||||
first, err := r.Next(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return first.tag, &streamReader{ctx, r, first, first.last()}, nil
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
// ErrHashMismatch means that download hash verification was failed.
|
||||
var ErrHashMismatch = errors.New("file hash mismatch")
|
||||
|
||||
type verifier struct {
|
||||
client schema
|
||||
|
||||
hashes []tg.FileHash
|
||||
offset int64
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newVerifier(client schema, hashes ...tg.FileHash) *verifier {
|
||||
r := make([]tg.FileHash, len(hashes))
|
||||
|
||||
copy(r, hashes)
|
||||
sort.SliceStable(r, func(i, j int) bool {
|
||||
return r[i].Offset < r[j].Offset
|
||||
})
|
||||
|
||||
return &verifier{client: client, hashes: r}
|
||||
}
|
||||
|
||||
func (v *verifier) pop() (tg.FileHash, bool) {
|
||||
if len(v.hashes) < 1 {
|
||||
return tg.FileHash{}, false
|
||||
}
|
||||
|
||||
// Pop and move.
|
||||
hash := v.hashes[0]
|
||||
copy(v.hashes, v.hashes[1:])
|
||||
v.hashes[len(v.hashes)-1] = tg.FileHash{}
|
||||
v.hashes = v.hashes[:len(v.hashes)-1]
|
||||
|
||||
return hash, true
|
||||
}
|
||||
|
||||
func (v *verifier) update(hashes ...tg.FileHash) (tg.FileHash, bool) {
|
||||
// If result is empty and queue is empty, so we can't return next hash.
|
||||
if len(hashes) < 1 {
|
||||
return tg.FileHash{}, false
|
||||
}
|
||||
|
||||
// Sort hashes by offset.
|
||||
// Usually Telegram server returns sorted parts, but...
|
||||
// you never known what can they do.
|
||||
sort.SliceStable(hashes, func(i, j int) bool {
|
||||
return hashes[i].Offset < hashes[j].Offset
|
||||
})
|
||||
|
||||
last := hashes[len(hashes)-1]
|
||||
// Check if we have reached the end.
|
||||
// If current state offset is equal the last offset + limit (right border)
|
||||
// then we got all hashes.
|
||||
if last.Offset == v.offset-int64(last.Limit) {
|
||||
return tg.FileHash{}, false
|
||||
}
|
||||
|
||||
// Otherwise, we update current offset and add hashes to the end of queue.
|
||||
v.offset = last.Offset + int64(last.Limit)
|
||||
v.hashes = append(v.hashes, hashes...)
|
||||
return v.pop()
|
||||
}
|
||||
|
||||
func (v *verifier) next(ctx context.Context) (tg.FileHash, bool, error) {
|
||||
v.mux.Lock()
|
||||
defer v.mux.Unlock()
|
||||
|
||||
hash, ok := v.pop()
|
||||
if ok {
|
||||
return hash, ok, nil
|
||||
}
|
||||
|
||||
for {
|
||||
hashes, err := v.client.Hashes(ctx, v.offset)
|
||||
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
|
||||
if flood || tgerr.Is(err, tg.ErrTimeout) {
|
||||
continue
|
||||
}
|
||||
return tg.FileHash{}, false, errors.Wrap(err, "get hashes")
|
||||
}
|
||||
|
||||
hash, ok = v.update(hashes...)
|
||||
return hash, ok, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (v *verifier) verify(hash tg.FileHash, data []byte) bool {
|
||||
return bytes.Equal(crypto.SHA256(data), hash.Hash)
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
var hashRanges = [][]tg.FileHash{
|
||||
{
|
||||
tg.FileHash{Offset: 0, Limit: 131072},
|
||||
tg.FileHash{Offset: 131072, Limit: 131072},
|
||||
tg.FileHash{Offset: 262144, Limit: 131072},
|
||||
tg.FileHash{Offset: 393216, Limit: 131072},
|
||||
tg.FileHash{Offset: 524288, Limit: 131072},
|
||||
tg.FileHash{Offset: 655360, Limit: 131072},
|
||||
tg.FileHash{Offset: 786432, Limit: 131072},
|
||||
tg.FileHash{Offset: 917504, Limit: 131072},
|
||||
}, {
|
||||
tg.FileHash{Offset: 1048576, Limit: 131072},
|
||||
tg.FileHash{Offset: 1179648, Limit: 131072},
|
||||
tg.FileHash{Offset: 1310720, Limit: 131072},
|
||||
tg.FileHash{Offset: 1441792, Limit: 131072},
|
||||
tg.FileHash{Offset: 1572864, Limit: 131072},
|
||||
tg.FileHash{Offset: 1703936, Limit: 131072},
|
||||
tg.FileHash{Offset: 1835008, Limit: 131072},
|
||||
tg.FileHash{Offset: 1966080, Limit: 131072},
|
||||
}, {
|
||||
tg.FileHash{Offset: 2097152, Limit: 131072},
|
||||
tg.FileHash{Offset: 2228224, Limit: 131072},
|
||||
tg.FileHash{Offset: 2359296, Limit: 131072},
|
||||
tg.FileHash{Offset: 2490368, Limit: 131072},
|
||||
tg.FileHash{Offset: 2621440, Limit: 131072},
|
||||
tg.FileHash{Offset: 2752512, Limit: 131072},
|
||||
tg.FileHash{Offset: 2883584, Limit: 131072},
|
||||
tg.FileHash{Offset: 3014656, Limit: 131072},
|
||||
},
|
||||
}
|
||||
|
||||
type mockHashes struct {
|
||||
ranges [][]tg.FileHash
|
||||
}
|
||||
|
||||
func (m mockHashes) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockHashes) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
|
||||
for _, r := range m.ranges {
|
||||
last := r[len(r)-1]
|
||||
if last.Offset+int64(last.Limit) <= offset {
|
||||
continue
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
return m.ranges[len(m.ranges)-1], nil
|
||||
}
|
||||
|
||||
func TestVerifier(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ranges [][]tg.FileHash
|
||||
// Hashes returned from CDN redirect, for example.
|
||||
predefined []tg.FileHash
|
||||
expected [][]tg.FileHash
|
||||
}{
|
||||
{"NoPredefined", hashRanges, nil, hashRanges},
|
||||
{"Predefined", hashRanges[1:], hashRanges[0], hashRanges},
|
||||
{"OnlyPredefined", hashRanges[:1], hashRanges[0], hashRanges[:1]},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
m := mockHashes{ranges: test.ranges}
|
||||
v := newVerifier(m, test.predefined...)
|
||||
|
||||
hashes := make([]tg.FileHash, 0, len(test.predefined))
|
||||
for {
|
||||
hash, ok, err := v.next(ctx)
|
||||
a.NoError(err)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
hashes = append(hashes, hash)
|
||||
}
|
||||
|
||||
i := 0
|
||||
for _, hashRange := range test.expected {
|
||||
for _, expected := range hashRange {
|
||||
a.Equal(expected, hashes[i])
|
||||
i++
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
var errHashesNotSupported = errors.New("this schema does not support hashes fetch")
|
||||
|
||||
// web is a web file download schema.
|
||||
// See https://core.telegram.org/api/files#downloading-webfiles.
|
||||
type web struct {
|
||||
client Client
|
||||
|
||||
location tg.InputWebFileLocationClass
|
||||
}
|
||||
|
||||
var _ schema = web{}
|
||||
|
||||
func (w web) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
|
||||
file, err := w.client.UploadGetWebFile(ctx, &tg.UploadGetWebFileRequest{
|
||||
Location: w.location,
|
||||
Offset: int(offset),
|
||||
Limit: limit,
|
||||
})
|
||||
if err != nil {
|
||||
return chunk{}, err
|
||||
}
|
||||
|
||||
return chunk{data: file.Bytes, tag: file.FileType}, nil
|
||||
}
|
||||
|
||||
func (w web) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
|
||||
return nil, errHashesNotSupported
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
// ErrFloodWait is error type of "FLOOD_WAIT" error.
|
||||
const ErrFloodWait = tgerr.ErrFloodWait
|
||||
|
||||
// AsFloodWait returns wait duration and true boolean if err is
|
||||
// the "FLOOD_WAIT" error.
|
||||
//
|
||||
// Client should wait for that duration before issuing new requests with
|
||||
// same method.
|
||||
var AsFloodWait = tgerr.AsFloodWait
|
||||
@@ -0,0 +1,22 @@
|
||||
package telegram_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
func TestAsFloodWait(t *testing.T) {
|
||||
err := func() error {
|
||||
return errors.Wrap(tgerr.New(400, "FLOOD_WAIT_10"), "perform operation")
|
||||
}()
|
||||
|
||||
d, ok := telegram.AsFloodWait(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, time.Second*10, d)
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func (c *Client) updateInterceptor(updates ...tg.UpdateClass) {
|
||||
for _, update := range updates {
|
||||
switch update.(type) {
|
||||
case *tg.UpdateConfig, *tg.UpdateDCOptions:
|
||||
c.fetchConfig(c.ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) processUpdates(updates tg.UpdatesClass) error {
|
||||
switch u := updates.(type) {
|
||||
case *tg.Updates:
|
||||
c.updateInterceptor(u.Updates...)
|
||||
return c.updateHandler.Handle(c.ctx, u)
|
||||
case *tg.UpdatesCombined:
|
||||
c.updateInterceptor(u.Updates...)
|
||||
return c.updateHandler.Handle(c.ctx, u)
|
||||
case *tg.UpdateShort:
|
||||
c.updateInterceptor(u.Update)
|
||||
return c.updateHandler.Handle(c.ctx, u)
|
||||
case *tg.UpdateShortMessage, *tg.UpdateShortChatMessage, *tg.UpdateShortSentMessage, *tg.UpdatesTooLong:
|
||||
return c.updateHandler.Handle(c.ctx, u)
|
||||
default:
|
||||
c.log.Warn("Ignoring update", zap.String("update_type", fmt.Sprintf("%T", u)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handleUpdates(b *bin.Buffer) error {
|
||||
updates, err := tg.DecodeUpdates(b)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "decode updates")
|
||||
}
|
||||
return c.processUpdates(updates)
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type mockHandler struct {
|
||||
LastUpdate tg.UpdatesClass
|
||||
}
|
||||
|
||||
func (m *mockHandler) Handle(ctx context.Context, u tg.UpdatesClass) error {
|
||||
m.LastUpdate = u
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestClient_processUpdates(t *testing.T) {
|
||||
msg := &tg.Message{
|
||||
ID: 1,
|
||||
}
|
||||
upd := &tg.Updates{
|
||||
Updates: []tg.UpdateClass{&tg.UpdateNewMessage{
|
||||
Message: msg,
|
||||
}},
|
||||
}
|
||||
|
||||
t.Run("Handle", func(t *testing.T) {
|
||||
mock := &mockHandler{}
|
||||
c := new(Client)
|
||||
c.updateHandler = mock
|
||||
|
||||
err := c.processUpdates(upd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, upd, mock.LastUpdate)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
// Package deeplink contains deeplink parsing helpers.
|
||||
package deeplink
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
)
|
||||
|
||||
// Type is an enum type of Telegram deeplinks types.
|
||||
type Type string
|
||||
|
||||
const (
|
||||
// Resolve is deeplink like
|
||||
//
|
||||
// tg:resolve?domain={domain}
|
||||
// tg://resolve?domain={domain}
|
||||
// https://t.me/{domain}
|
||||
// https://telegram.me/{domain}
|
||||
//
|
||||
Resolve Type = "resolve"
|
||||
|
||||
// Join is deeplink like
|
||||
//
|
||||
// tg:join?invite={hash}
|
||||
// tg://join?invite={hash}
|
||||
// https://t.me/joinchat/{hash}
|
||||
// https://telegram.me/joinchat/{hash}
|
||||
// t.me/+{hash}
|
||||
//
|
||||
Join Type = "join"
|
||||
)
|
||||
|
||||
// DeepLink represents Telegram deeplink.
|
||||
type DeepLink struct {
|
||||
Type Type
|
||||
Args url.Values
|
||||
}
|
||||
|
||||
func ensureParam(query url.Values, key string) error {
|
||||
if query.Get(key) == "" {
|
||||
return errors.Errorf("should have %q query parameter", key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d DeepLink) validate() error {
|
||||
switch d.Type {
|
||||
case Resolve:
|
||||
return ensureParam(d.Args, "domain")
|
||||
case Join:
|
||||
return ensureParam(d.Args, "invite")
|
||||
default:
|
||||
return errors.Errorf("unsupported deeplink %q", d.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func parseTg(u *url.URL) (DeepLink, error) {
|
||||
query := u.Query()
|
||||
switch Type(u.Hostname()) {
|
||||
case Resolve:
|
||||
return DeepLink{
|
||||
Type: Resolve,
|
||||
Args: query,
|
||||
}, nil
|
||||
case Join:
|
||||
return DeepLink{
|
||||
Type: Join,
|
||||
Args: query,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return DeepLink{}, errors.Errorf("unsupported deeplink %q", u.String())
|
||||
}
|
||||
|
||||
func parseHTTPS(u *url.URL) (DeepLink, error) {
|
||||
cleanInviteHash := func(root string) string {
|
||||
hash := strings.Trim(root, "+ ")
|
||||
if u.RawPath == "" {
|
||||
hash = url.PathEscape(hash)
|
||||
}
|
||||
return hash
|
||||
}
|
||||
|
||||
query := url.Values{}
|
||||
p := strings.TrimPrefix(u.Path, "/")
|
||||
p = strings.TrimSuffix(p, "/")
|
||||
split := strings.Split(p, "/")
|
||||
var (
|
||||
root = split[0]
|
||||
base string
|
||||
)
|
||||
if len(split) > 1 {
|
||||
base = split[1]
|
||||
}
|
||||
|
||||
switch root {
|
||||
case "joinchat":
|
||||
query.Set("invite", cleanInviteHash(base))
|
||||
return DeepLink{
|
||||
Type: Join,
|
||||
Args: query,
|
||||
}, nil
|
||||
case "":
|
||||
return DeepLink{}, errors.Errorf("unsupported deeplink %q", u.String())
|
||||
}
|
||||
|
||||
switch root[0] {
|
||||
case ' ', '+':
|
||||
query.Set("invite", cleanInviteHash(root))
|
||||
return DeepLink{
|
||||
Type: Join,
|
||||
Args: query,
|
||||
}, nil
|
||||
default:
|
||||
if err := ValidateDomain(root); err != nil {
|
||||
return DeepLink{}, err
|
||||
}
|
||||
query.Set("domain", root)
|
||||
return DeepLink{
|
||||
Type: Resolve,
|
||||
Args: query,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func hasTelegramPrefix(link string) bool {
|
||||
return strings.HasPrefix(link, "t.me") ||
|
||||
strings.HasPrefix(link, "telegram.me") ||
|
||||
strings.HasPrefix(link, "telegram.dog")
|
||||
}
|
||||
|
||||
// IsDeeplinkLike returns true if string may be a valid deeplink.
|
||||
func IsDeeplinkLike(link string) bool {
|
||||
return strings.HasPrefix(link, "tg:") ||
|
||||
hasTelegramPrefix(link) ||
|
||||
strings.HasPrefix(link, "https://")
|
||||
}
|
||||
|
||||
// Parse parses and returns deeplink.
|
||||
func Parse(link string) (DeepLink, error) {
|
||||
switch {
|
||||
// Normalize case like t.me/gotd.
|
||||
case hasTelegramPrefix(link):
|
||||
link = strings.TrimSuffix("https://"+link, "/")
|
||||
// Normalize case like tg:resolve?domain=gotd.
|
||||
case !strings.HasPrefix(link, "tg://") && strings.HasPrefix(link, "tg:"):
|
||||
link = "tg://" + strings.TrimPrefix(link, "tg:")
|
||||
}
|
||||
|
||||
u, err := url.Parse(link)
|
||||
if err != nil {
|
||||
return DeepLink{}, errors.Wrapf(err, "invalid URL %q", link)
|
||||
}
|
||||
|
||||
var d DeepLink
|
||||
switch {
|
||||
case u.Scheme == "https":
|
||||
switch strings.TrimPrefix(u.Hostname(), "www.") {
|
||||
case "t.me", "telegram.me", "telegram.dog":
|
||||
d, err = parseHTTPS(u)
|
||||
default:
|
||||
return DeepLink{}, errors.Errorf("invalid domain %q", link)
|
||||
}
|
||||
case u.Scheme == "tg":
|
||||
d, err = parseTg(u)
|
||||
default:
|
||||
return DeepLink{}, errors.Errorf("invalid deeplink %q", link)
|
||||
}
|
||||
if err != nil {
|
||||
return DeepLink{}, err
|
||||
}
|
||||
if err := d.validate(); err != nil {
|
||||
return DeepLink{}, err
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// Expect parses deeplink and check type its type.
|
||||
func Expect(link string, typ Type) (DeepLink, error) {
|
||||
l, err := Parse(link)
|
||||
if err != nil {
|
||||
return l, err
|
||||
}
|
||||
if l.Type != typ {
|
||||
return l, errors.Errorf("unexpected deeplink type %q", l.Type)
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
//go:build go1.18
|
||||
|
||||
package deeplink
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func addSuites(f *testing.F, suites map[string][]testCase) {
|
||||
for _, suite := range suites {
|
||||
for _, test := range suite {
|
||||
f.Add(test.input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzParse(f *testing.F) {
|
||||
for _, typeSuite := range typeSuites {
|
||||
addSuites(f, typeSuite)
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, link string) {
|
||||
_, err := Parse(link)
|
||||
if err != nil {
|
||||
t.Skip(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package deeplink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testCase struct {
|
||||
link DeepLink
|
||||
input string
|
||||
wantErr bool
|
||||
}
|
||||
|
||||
func join(arg string) DeepLink {
|
||||
return DeepLink{
|
||||
Type: Join,
|
||||
Args: map[string][]string{
|
||||
"invite": {arg},
|
||||
},
|
||||
}
|
||||
}
|
||||
func resolve(arg string) DeepLink {
|
||||
return DeepLink{
|
||||
Type: Resolve,
|
||||
Args: map[string][]string{
|
||||
"domain": {arg},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func joinSuite() map[string][]testCase {
|
||||
expect := join("AAAAAAAAAAAAAAAAAA")
|
||||
return map[string][]testCase{
|
||||
"Test": {
|
||||
{expect, `t.me/joinchat/AAAAAAAAAAAAAAAAAA`, false},
|
||||
{expect, `t.me/joinchat/AAAAAAAAAAAAAAAAAA/`, false},
|
||||
{expect, `t.me/+AAAAAAAAAAAAAAAAAA`, false},
|
||||
{expect, `t.me/+AAAAAAAAAAAAAAAAAA/`, false},
|
||||
{expect, `t.me/ +AAAAAAAAAAAAAAAAAA/`, false},
|
||||
{expect, `https://t.me/joinchat/AAAAAAAAAAAAAAAAAA`, false},
|
||||
{expect, `https://t.me/joinchat/AAAAAAAAAAAAAAAAAA/`, false},
|
||||
{expect, `tg:join?invite=AAAAAAAAAAAAAAAAAA`, false},
|
||||
{expect, `tg://join?invite=AAAAAAAAAAAAAAAAAA`, false},
|
||||
|
||||
{DeepLink{}, `https://t.co/joinchat/AAAAAAAAAAAAAAAAAA`, true},
|
||||
{DeepLink{}, `rt://join?invite=AAAAAAAAAAAAAAAAAA`, true},
|
||||
},
|
||||
"TDLib": {
|
||||
// t.me/+<hash>
|
||||
// Positive
|
||||
{join("aba%20aba"), "t.me/+aba%20aba", false},
|
||||
{join("aba0aba"), "t.me/+aba%30aba", false},
|
||||
{join("123456a"), "t.me/+123456a", false},
|
||||
{join("12345678901"), "t.me/%2012345678901", false},
|
||||
// Negative
|
||||
{DeepLink{}, "t.me/+?invite=abcdef", true},
|
||||
{DeepLink{}, "t.me/+", true},
|
||||
{DeepLink{}, "t.me/+/abcdef", true},
|
||||
{DeepLink{}, "t.me/ ?/abcdef", true},
|
||||
{DeepLink{}, "t.me/+?abcdef", true},
|
||||
{DeepLink{}, "t.me/+#abcdef", true},
|
||||
{DeepLink{}, "t.me/ /123456/123123/12/31/a/s//21w/?asdas#test", true},
|
||||
|
||||
// t.me/joinchat/<hash>
|
||||
// Positive
|
||||
{join("abacaba"), "t.me/joinchat/abacaba", false},
|
||||
{join("aba%20aba"), "t.me/joinchat/aba%20aba", false},
|
||||
{join("aba0aba"), "t.me/joinchat/aba%30aba", false},
|
||||
{join("123456a"), "t.me/joinchat/123456a", false},
|
||||
{join("12345678901"), "t.me/joinchat/12345678901", false},
|
||||
{join("123456"), "t.me/joinchat/123456", false},
|
||||
{join("123456"), "t.me/joinchat/123456/123123/12/31/a/s//21w/?asdas#test", false},
|
||||
// Negative
|
||||
{DeepLink{}, "t.me/joinchat?invite=abcdef", true},
|
||||
{DeepLink{}, "t.me/joinchat", true},
|
||||
{DeepLink{}, "t.me/joinchat/", true},
|
||||
{DeepLink{}, "t.me/joinchat//abcdef", true},
|
||||
{DeepLink{}, "t.me/joinchat?/abcdef", true},
|
||||
{DeepLink{}, "t.me/joinchat/?abcdef", true},
|
||||
{DeepLink{}, "t.me/joinchat/#abcdef", true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func resolveSuite() map[string][]testCase {
|
||||
expect := resolve("gotd_ru")
|
||||
return map[string][]testCase{
|
||||
"Test": {
|
||||
{expect, `t.me/gotd_ru`, false},
|
||||
{expect, `t.me/gotd_ru/`, false},
|
||||
{expect, `https://t.me/gotd_ru`, false},
|
||||
{expect, `https://t.me/gotd_ru/`, false},
|
||||
{expect, `tg:resolve?domain=gotd_ru`, false},
|
||||
{expect, `tg:resolve?&&&&&&&domain=gotd_ru`, false},
|
||||
{expect, `tg://resolve?domain=gotd_ru`, false},
|
||||
|
||||
{DeepLink{}, `https://t.co/gotd_ru`, true},
|
||||
{DeepLink{}, `rt://join?invite=AAAAAAAAAAAAAAAAAA`, true},
|
||||
},
|
||||
"TDLib": {
|
||||
// t.me/<domain>
|
||||
// Positive
|
||||
{resolve("a"), "t.me/a", false},
|
||||
{resolve("abcdefghijklmnopqrstuvwxyz123456"), "t.me/abcdefghijklmnopqrstuvwxyz123456", false},
|
||||
{resolve("Aasdf"), "t.me/Aasdf", false},
|
||||
{resolve("asdf0"), "t.me/asdf0", false},
|
||||
{resolve("username"), "t.me/username/0/a//s/as?gam=asd", false},
|
||||
{resolve("username"), "t.me/username/aasdas?test=1", false},
|
||||
{resolve("username"), "t.me/username/0", false},
|
||||
{resolve("telecram"), "https://telegram.dog/tele%63ram", false},
|
||||
// Negative
|
||||
{DeepLink{}, "t.me/abcdefghijklmnopqrstuvwxyz1234567", true},
|
||||
{DeepLink{}, "t.me/abcdefghijklmnop-qrstuvwxyz", true},
|
||||
{DeepLink{}, "t.me/abcdefghijklmnop~qrstuvwxyz", true},
|
||||
{DeepLink{}, "t.me/_asdf", true},
|
||||
{DeepLink{}, "t.me/0asdf", true},
|
||||
{DeepLink{}, "t.me/9asdf", true},
|
||||
{DeepLink{}, "t.me/asdf_", true},
|
||||
{DeepLink{}, "t.me/asd__fg", true},
|
||||
{DeepLink{}, "t.me//username", true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var typeSuites = map[string]map[string][]testCase{
|
||||
"Join": joinSuite(),
|
||||
"Resolve": resolveSuite(),
|
||||
}
|
||||
|
||||
func TestParseDeeplink(t *testing.T) {
|
||||
runSuite := func(suite []testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for i, test := range suite {
|
||||
t.Run(fmt.Sprintf("Test%d (%s)", i, test.input), func(t *testing.T) {
|
||||
a := require.New(t)
|
||||
d, err := Parse(test.input)
|
||||
|
||||
if test.wantErr {
|
||||
a.Error(err, test.input)
|
||||
} else {
|
||||
a.NoError(err, test.input)
|
||||
a.Equal(test.link, d, test.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for typeName, typeSuite := range typeSuites {
|
||||
t.Run(typeName, func(t *testing.T) {
|
||||
for suiteName, suite := range typeSuite {
|
||||
t.Run(suiteName, runSuite(suite))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package deeplink
|
||||
|
||||
import (
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/ascii"
|
||||
)
|
||||
|
||||
// ValidateDomain validate given domain (user) name
|
||||
func ValidateDomain(domain string) error {
|
||||
return checkDomainSymbols(domain)
|
||||
}
|
||||
|
||||
// checkDomainSymbols check that domain contains only a-z, A-Z, 0-9 and '_'
|
||||
// symbols.
|
||||
func checkDomainSymbols(domain string) error {
|
||||
switch {
|
||||
case domain == "":
|
||||
return errors.New("is empty")
|
||||
case len(domain) > 32:
|
||||
return errors.New("is too big")
|
||||
case !ascii.IsLatinLower(rune(domain[0])):
|
||||
return errors.New("must start with lower letter")
|
||||
case domain[len(domain)-1] == '_':
|
||||
return errors.New("must not end with '_'")
|
||||
}
|
||||
|
||||
for i, r := range domain {
|
||||
switch {
|
||||
case !ascii.IsLatinLetter(r) && !ascii.IsDigit(r) && r != '_':
|
||||
case i > 0 && domain[i] == '_' && domain[i] == domain[i-1]:
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
return errors.Errorf("unexpected %c at %d", r, i)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package deeplink
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain string
|
||||
wantErr bool
|
||||
}{
|
||||
{"a", false},
|
||||
{"abcdefghijklmnopqrstuvwxyz123456", false},
|
||||
{"Aasdf", false},
|
||||
{"asdf0", false},
|
||||
{"", true},
|
||||
{"asdf_", true},
|
||||
{"asd__fg", true},
|
||||
{"_asdf", true},
|
||||
{"0asdf", true},
|
||||
{"9asdf", true},
|
||||
{"abcdefghijklmnopqrstuvwxyz1234567", true},
|
||||
{"abcdefghijklmnop-qrstuvwxyz", true},
|
||||
{"abcdefghijklmnop~qrstuvwxyz", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.domain, func(t *testing.T) {
|
||||
err := ValidateDomain(tt.domain)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package e2etest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
|
||||
)
|
||||
|
||||
func (s *Suite) createFlow(ctx context.Context) (auth.Flow, error) {
|
||||
var ua auth.UserAuthenticator
|
||||
for {
|
||||
ua = auth.Test(s.rand, s.dc)
|
||||
phone, err := ua.Phone(ctx)
|
||||
if err != nil {
|
||||
return auth.Flow{}, err
|
||||
}
|
||||
|
||||
s.usedMux.Lock()
|
||||
if _, ok := s.used[phone]; !ok {
|
||||
s.used[phone] = struct{}{}
|
||||
s.usedMux.Unlock()
|
||||
break
|
||||
}
|
||||
s.usedMux.Unlock()
|
||||
}
|
||||
|
||||
return auth.NewFlow(ua, auth.SendCodeOptions{}), nil
|
||||
}
|
||||
|
||||
// Authenticate authenticates client on test server.
|
||||
func (s *Suite) Authenticate(ctx context.Context, client auth.FlowClient) error {
|
||||
for {
|
||||
flow, err := s.createFlow(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create flow")
|
||||
}
|
||||
|
||||
if err := flow.Run(ctx, client); err != nil {
|
||||
if errors.Is(err, auth.ErrPasswordNotProvided) {
|
||||
continue
|
||||
}
|
||||
|
||||
return errors.Wrap(err, "run flow")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RetryAuthenticate authenticates client on test server.
|
||||
func (s *Suite) RetryAuthenticate(ctx context.Context, client auth.FlowClient) error {
|
||||
bck := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||
return backoff.Retry(func() error {
|
||||
return s.Authenticate(ctx, client)
|
||||
}, bck)
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package e2etest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
type mockFlow struct {
|
||||
flag bool
|
||||
}
|
||||
|
||||
var _ auth.FlowClient = &mockFlow{}
|
||||
|
||||
func (m *mockFlow) SignIn(context.Context, string, string, string) (*tg.AuthAuthorization, error) {
|
||||
// Ensure retry.
|
||||
if !m.flag {
|
||||
m.flag = true
|
||||
return nil, auth.ErrPasswordAuthNeeded
|
||||
}
|
||||
|
||||
return m.Password(context.Background(), "")
|
||||
}
|
||||
|
||||
func (m *mockFlow) SendCode(context.Context, string, auth.SendCodeOptions) (tg.AuthSentCodeClass, error) {
|
||||
return &tg.AuthSentCode{
|
||||
PhoneCodeHash: "hash",
|
||||
Type: &tg.AuthSentCodeTypeApp{},
|
||||
Timeout: 10,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockFlow) Password(context.Context, string) (*tg.AuthAuthorization, error) {
|
||||
return &tg.AuthAuthorization{
|
||||
User: &tg.User{
|
||||
ID: 10,
|
||||
Username: "aboba",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockFlow) SignUp(context.Context, auth.SignUp) (*tg.AuthAuthorization, error) {
|
||||
return nil, errors.New("must not be called")
|
||||
}
|
||||
|
||||
func TestSuite_Authenticate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger := zaptest.NewLogger(t)
|
||||
s := NewSuite(t, TestOptions{
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
flow := &mockFlow{}
|
||||
require.NoError(t, s.Authenticate(ctx, flow))
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
// Package e2etest contains some helpers to make external E2E tests
|
||||
// using Telegram test server.
|
||||
package e2etest
|
||||
@@ -0,0 +1,190 @@
|
||||
package e2etest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
// EchoBot is a simple echo message bot.
|
||||
type EchoBot struct {
|
||||
suite *Suite
|
||||
|
||||
logger *zap.Logger
|
||||
auth chan<- *tg.User
|
||||
}
|
||||
|
||||
// NewEchoBot creates new echo bot.
|
||||
func NewEchoBot(suite *Suite, auth chan<- *tg.User) EchoBot {
|
||||
return EchoBot{
|
||||
suite: suite,
|
||||
logger: suite.logger.Named("echobot"),
|
||||
auth: auth,
|
||||
}
|
||||
}
|
||||
|
||||
type users struct {
|
||||
users map[int64]*tg.User
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func newUsers() *users {
|
||||
return &users{
|
||||
users: map[int64]*tg.User{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *users) empty() (r bool) {
|
||||
m.lock.RLock()
|
||||
r = len(m.users) < 1
|
||||
m.lock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (m *users) add(list ...tg.UserClass) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
tg.UserClassArray(list).FillNotEmptyMap(m.users)
|
||||
}
|
||||
|
||||
func (m *users) get(id int64) (r *tg.User) {
|
||||
m.lock.RLock()
|
||||
r = m.users[id]
|
||||
m.lock.RUnlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (b EchoBot) login(ctx context.Context, client *telegram.Client) (*tg.User, error) {
|
||||
if err := b.suite.RetryAuthenticate(ctx, client.Auth()); err != nil {
|
||||
return nil, errors.Wrap(err, "authenticate")
|
||||
}
|
||||
|
||||
var me *tg.User
|
||||
if err := retry(ctx, func() (err error) {
|
||||
me, err = client.Self(ctx)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
expectedUsername := "echobot" + strconv.FormatInt(me.ID, 10)
|
||||
raw := tg.NewClient(retryInvoker{prev: client})
|
||||
_, err := raw.AccountUpdateUsername(ctx, expectedUsername)
|
||||
if err != nil {
|
||||
if !tgerr.Is(err, tg.ErrUsernameNotModified) {
|
||||
return nil, errors.Wrap(err, "update username")
|
||||
}
|
||||
}
|
||||
me, err = retryResult(ctx, func() (*tg.User, error) {
|
||||
return client.Self(ctx)
|
||||
})
|
||||
if me.Username != expectedUsername {
|
||||
return nil, errors.Errorf("expected username %q, got %q", expectedUsername, me.Username)
|
||||
}
|
||||
|
||||
return me, nil
|
||||
}
|
||||
|
||||
func (b EchoBot) handler(client *telegram.Client) tg.NewMessageHandler {
|
||||
dialogsUsers := newUsers()
|
||||
|
||||
raw := tg.NewClient(client)
|
||||
sender := message.NewSender(raw)
|
||||
return func(ctx context.Context, entities tg.Entities, update *tg.UpdateNewMessage) error {
|
||||
if filterMessage(update) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m, ok := update.Message.(interface{ GetMessage() string }); ok {
|
||||
b.logger.Named("dispatcher").
|
||||
Info("Got new message update", zap.String("message", m.GetMessage()))
|
||||
}
|
||||
|
||||
if dialogsUsers.empty() {
|
||||
dialogs, err := retryResult(ctx, func() (tg.MessagesDialogsClass, error) {
|
||||
dialogs, err := raw.MessagesGetDialogs(ctx, &tg.MessagesGetDialogsRequest{
|
||||
Limit: 100,
|
||||
OffsetPeer: &tg.InputPeerEmpty{},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get dialogs")
|
||||
}
|
||||
return dialogs, nil
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get dialogs")
|
||||
}
|
||||
if dlg, ok := dialogs.AsModified(); ok {
|
||||
dialogsUsers.add(dlg.GetUsers()...)
|
||||
}
|
||||
}
|
||||
|
||||
switch m := update.Message.(type) {
|
||||
case *tg.Message:
|
||||
switch peer := m.PeerID.(type) {
|
||||
case *tg.PeerUser:
|
||||
user := entities.Users[peer.UserID]
|
||||
if user == nil {
|
||||
user = dialogsUsers.get(peer.UserID)
|
||||
}
|
||||
|
||||
b.logger.Info("Got message",
|
||||
zap.String("text", m.Message),
|
||||
zap.Int64("user_id", user.ID),
|
||||
zap.String("user_first_name", user.FirstName),
|
||||
zap.String("username", user.Username),
|
||||
)
|
||||
|
||||
if err := retry(ctx, func() error {
|
||||
_, err := sender.To(user.AsInputPeer()).Text(ctx, m.Message)
|
||||
return err
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "send message")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Run setups and starts echo bot.
|
||||
func (b EchoBot) Run(ctx context.Context) error {
|
||||
dispatcher := tg.NewUpdateDispatcher()
|
||||
client := b.suite.Client(b.logger, dispatcher)
|
||||
dispatcher.OnNewMessage(b.handler(client))
|
||||
|
||||
return client.Run(ctx, func(ctx context.Context) error {
|
||||
defer close(b.auth)
|
||||
|
||||
me, err := b.login(ctx, client)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "login")
|
||||
}
|
||||
|
||||
b.logger.Info("Logged in",
|
||||
zap.String("user", me.Username),
|
||||
zap.Int64("id", me.ID),
|
||||
)
|
||||
|
||||
select {
|
||||
case b.auth <- me:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package e2etest
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func filterMessage(update *tg.UpdateNewMessage) bool {
|
||||
if v, ok := update.Message.(interface{ GetOut() bool }); ok && v.GetOut() {
|
||||
return true
|
||||
}
|
||||
|
||||
if v, ok := update.Message.(interface{ GetPeerID() tg.PeerClass }); ok && v.GetPeerID() == nil {
|
||||
return true
|
||||
}
|
||||
if _, ok := update.Message.(*tg.MessageService); ok {
|
||||
return true
|
||||
}
|
||||
if v, ok := update.Message.(interface{ GetMessage() string }); ok && strings.HasPrefix(v.GetMessage(), "Login code:") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package e2etest
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/constant"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
)
|
||||
|
||||
// TestOptions contains some common test server settings.
|
||||
type TestOptions struct {
|
||||
AppID int
|
||||
AppHash string
|
||||
DC int
|
||||
Random io.Reader
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
func (opt *TestOptions) setDefaults() {
|
||||
if opt.AppID == 0 {
|
||||
opt.AppID = constant.TestAppID
|
||||
}
|
||||
if opt.AppHash == "" {
|
||||
opt.AppHash = constant.TestAppHash
|
||||
}
|
||||
if opt.DC == 0 {
|
||||
opt.DC = 2
|
||||
}
|
||||
if opt.Random == nil {
|
||||
opt.Random = crypto.DefaultRand()
|
||||
}
|
||||
if opt.Logger == nil {
|
||||
opt.Logger = zap.NewNop()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package e2etest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
type retryInvoker struct {
|
||||
prev tg.Invoker
|
||||
}
|
||||
|
||||
func retryResult[T any](ctx context.Context, cb func() (T, error)) (T, error) {
|
||||
var zero T
|
||||
return backoff.RetryWithData[T](func() (T, error) {
|
||||
res, err := cb()
|
||||
if err != nil {
|
||||
if tgerr.IsCode(err, -500) {
|
||||
return zero, err
|
||||
}
|
||||
if tgerr.Is(err, "CONNECTION_NOT_INITED") {
|
||||
return zero, err
|
||||
}
|
||||
if ok, err := tgerr.FloodWait(ctx, err); ok {
|
||||
return zero, err
|
||||
}
|
||||
return zero, backoff.Permanent(err)
|
||||
}
|
||||
return res, nil
|
||||
}, backoff.WithContext(backoff.NewConstantBackOff(time.Millisecond*500), ctx))
|
||||
}
|
||||
|
||||
func retry(ctx context.Context, cb func() error) error {
|
||||
return backoff.Retry(func() error {
|
||||
if err := cb(); err != nil {
|
||||
if tgerr.IsCode(err, -500) {
|
||||
return err
|
||||
}
|
||||
if tgerr.Is(err, "CONNECTION_NOT_INITED") {
|
||||
return err
|
||||
}
|
||||
if ok, err := tgerr.FloodWait(ctx, err); ok {
|
||||
return err
|
||||
}
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
|
||||
}
|
||||
|
||||
func (w retryInvoker) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
return retry(ctx, func() error {
|
||||
return w.prev.Invoke(ctx, input, output)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package e2etest
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
)
|
||||
|
||||
// Suite is struct which contains external E2E test parameters.
|
||||
type Suite struct {
|
||||
TB require.TestingT
|
||||
appID int
|
||||
appHash string
|
||||
dc int
|
||||
logger *zap.Logger
|
||||
|
||||
rand io.Reader
|
||||
// already used phone numbers
|
||||
used map[string]struct{}
|
||||
usedMux sync.Mutex
|
||||
}
|
||||
|
||||
// NewSuite creates new Suite.
|
||||
func NewSuite(tb require.TestingT, config TestOptions) *Suite {
|
||||
config.setDefaults()
|
||||
return &Suite{
|
||||
TB: tb,
|
||||
appID: config.AppID,
|
||||
appHash: config.AppHash,
|
||||
dc: config.DC,
|
||||
logger: config.Logger,
|
||||
rand: config.Random,
|
||||
used: map[string]struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
// Client creates new *telegram.Client using this suite.
|
||||
func (s *Suite) Client(logger *zap.Logger, handler telegram.UpdateHandler) *telegram.Client {
|
||||
return telegram.NewClient(s.appID, s.appHash, telegram.Options{
|
||||
DC: s.dc,
|
||||
DCList: dcs.Test(),
|
||||
Logger: logger,
|
||||
UpdateHandler: handler,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package e2etest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
// User is a simple user bot.
|
||||
type User struct {
|
||||
suite *Suite
|
||||
text []string
|
||||
username string
|
||||
|
||||
logger *zap.Logger
|
||||
message chan string
|
||||
}
|
||||
|
||||
// NewUser creates new User bot.
|
||||
func NewUser(suite *Suite, text []string, username string) User {
|
||||
return User{
|
||||
suite: suite,
|
||||
text: text,
|
||||
username: username,
|
||||
logger: suite.logger.Named("terentyev"),
|
||||
message: make(chan string, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (u User) messageHandler(ctx context.Context, entities tg.Entities, update *tg.UpdateNewMessage) error {
|
||||
if filterMessage(update) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m, ok := update.Message.(interface{ GetMessage() string }); ok {
|
||||
u.logger.Named("dispatcher").
|
||||
Info("Got new message update", zap.String("message", m.GetMessage()))
|
||||
}
|
||||
|
||||
msg, ok := update.Message.(*tg.Message)
|
||||
if !ok {
|
||||
return errors.Errorf("unexpected type %T", update.Message)
|
||||
}
|
||||
|
||||
select {
|
||||
case u.message <- msg.Message:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Run setups and starts user bot.
|
||||
func (u User) Run(ctx context.Context) error {
|
||||
dispatcher := tg.NewUpdateDispatcher()
|
||||
dispatcher.OnNewMessage(u.messageHandler)
|
||||
client := u.suite.Client(u.logger, dispatcher)
|
||||
sender := message.NewSender(tg.NewClient(retryInvoker{prev: client}))
|
||||
|
||||
return client.Run(ctx, func(ctx context.Context) error {
|
||||
if err := u.suite.RetryAuthenticate(ctx, client.Auth()); err != nil {
|
||||
return errors.Wrap(err, "authenticate")
|
||||
}
|
||||
|
||||
peer, err := sender.Resolve(u.username).AsInputPeer(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "resolve bot username %q", u.username)
|
||||
}
|
||||
|
||||
for _, line := range u.text {
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
_, err = sender.To(peer).Text(ctx, line)
|
||||
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
|
||||
if flood {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case gotMessage := <-u.message:
|
||||
require.Equal(u.suite.TB, line, gotMessage)
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,257 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/pool"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
type protoConn interface {
|
||||
Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error
|
||||
Run(ctx context.Context, f func(ctx context.Context) error) error
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
|
||||
//go:generate go run -modfile=../../../_tools/go.mod golang.org/x/tools/cmd/stringer -type=ConnMode
|
||||
|
||||
// ConnMode represents connection mode.
|
||||
type ConnMode byte
|
||||
|
||||
const (
|
||||
// ConnModeUpdates is update connection mode.
|
||||
ConnModeUpdates ConnMode = iota
|
||||
// ConnModeData is data connection mode.
|
||||
ConnModeData
|
||||
// ConnModeCDN is CDN connection mode.
|
||||
ConnModeCDN
|
||||
)
|
||||
|
||||
// Conn is a Telegram client connection.
|
||||
type Conn struct {
|
||||
// Connection parameters.
|
||||
mode ConnMode // immutable
|
||||
// MTProto connection.
|
||||
proto protoConn // immutable
|
||||
|
||||
// InitConnection parameters.
|
||||
appID int // immutable
|
||||
device DeviceConfig // immutable
|
||||
|
||||
// setup is callback which called after initConnection, but before ready signaling.
|
||||
// This is necessary to transfer auth from previous connection to another DC.
|
||||
setup SetupCallback // nilable
|
||||
|
||||
// onDead is called on connection death.
|
||||
onDead func()
|
||||
|
||||
// Wrappers for external world, like logs or PRNG.
|
||||
// Should be immutable.
|
||||
clock clock.Clock // immutable
|
||||
log *zap.Logger // immutable
|
||||
|
||||
// Handler passed by client.
|
||||
handler Handler // immutable
|
||||
|
||||
// State fields.
|
||||
cfg tg.Config
|
||||
ongoing int
|
||||
latest time.Time
|
||||
mux sync.Mutex
|
||||
|
||||
sessionInit *tdsync.Ready // immutable
|
||||
gotConfig *tdsync.Ready // immutable
|
||||
dead *tdsync.Ready // immutable
|
||||
|
||||
connBackoff func(ctx context.Context) backoff.BackOff // immutable
|
||||
}
|
||||
|
||||
// OnSession implements mtproto.Handler.
|
||||
func (c *Conn) OnSession(session mtproto.Session) error {
|
||||
c.log.Info("SessionInit")
|
||||
c.sessionInit.Signal()
|
||||
|
||||
// Waiting for config, because OnSession can occur before we set config.
|
||||
select {
|
||||
case <-c.gotConfig.Ready():
|
||||
case <-c.dead.Ready():
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mux.Lock()
|
||||
cfg := c.cfg
|
||||
c.mux.Unlock()
|
||||
|
||||
return c.handler.OnSession(cfg, session)
|
||||
}
|
||||
|
||||
func (c *Conn) trackInvoke() func() {
|
||||
start := c.clock.Now()
|
||||
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
c.ongoing++
|
||||
c.latest = start
|
||||
|
||||
return func() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
c.ongoing--
|
||||
end := c.clock.Now()
|
||||
c.latest = end
|
||||
|
||||
c.log.Debug("Invoke",
|
||||
zap.Duration("duration", end.Sub(start)),
|
||||
zap.Int("ongoing", c.ongoing),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Run initialize connection.
|
||||
func (c *Conn) Run(ctx context.Context) (err error) {
|
||||
defer c.dead.Signal()
|
||||
defer func() {
|
||||
if err != nil && ctx.Err() == nil {
|
||||
c.log.Debug("Connection dead", zap.Error(err))
|
||||
if c.onDead != nil {
|
||||
c.onDead()
|
||||
}
|
||||
}
|
||||
}()
|
||||
return c.proto.Run(ctx, func(ctx context.Context) error {
|
||||
// Signal death on init error. Otherwise connection shutdown
|
||||
// deadlocks in OnSession that occurs before init fails.
|
||||
err := c.init(ctx)
|
||||
if err != nil {
|
||||
c.dead.Signal()
|
||||
}
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) waitSession(ctx context.Context) error {
|
||||
select {
|
||||
case <-c.sessionInit.Ready():
|
||||
return nil
|
||||
case <-c.dead.Ready():
|
||||
return pool.ErrConnDead
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Ready returns channel to determine connection readiness.
|
||||
// Useful for pooling.
|
||||
func (c *Conn) Ready() <-chan struct{} {
|
||||
return c.sessionInit.Ready()
|
||||
}
|
||||
|
||||
// Invoke implements Invoker.
|
||||
func (c *Conn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
|
||||
// Tracking ongoing invokes.
|
||||
defer c.trackInvoke()()
|
||||
if err := c.waitSession(ctx); err != nil {
|
||||
return errors.Wrap(err, "waitSession")
|
||||
}
|
||||
|
||||
return c.proto.Invoke(ctx, c.wrapRequest(noopDecoder{input}), output)
|
||||
}
|
||||
|
||||
// OnMessage implements mtproto.Handler.
|
||||
func (c *Conn) OnMessage(b *bin.Buffer) error {
|
||||
return c.handler.OnMessage(b)
|
||||
}
|
||||
|
||||
type noopDecoder struct {
|
||||
bin.Encoder
|
||||
}
|
||||
|
||||
func (n noopDecoder) Decode(b *bin.Buffer) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) wrapRequest(req bin.Object) bin.Object {
|
||||
if c.mode != ConnModeUpdates {
|
||||
return &tg.InvokeWithoutUpdatesRequest{
|
||||
Query: req,
|
||||
}
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func (c *Conn) init(ctx context.Context) error {
|
||||
c.log.Debug("Initializing")
|
||||
|
||||
q := c.wrapRequest(&tg.InitConnectionRequest{
|
||||
APIID: c.appID,
|
||||
DeviceModel: c.device.DeviceModel,
|
||||
SystemVersion: c.device.SystemVersion,
|
||||
AppVersion: c.device.AppVersion,
|
||||
SystemLangCode: c.device.SystemLangCode,
|
||||
LangPack: c.device.LangPack,
|
||||
LangCode: c.device.LangCode,
|
||||
Proxy: c.device.Proxy,
|
||||
Params: c.device.Params,
|
||||
Query: c.wrapRequest(&tg.HelpGetConfigRequest{}),
|
||||
})
|
||||
req := c.wrapRequest(&tg.InvokeWithLayerRequest{
|
||||
Layer: tg.Layer,
|
||||
Query: q,
|
||||
})
|
||||
|
||||
var cfg tg.Config
|
||||
if err := backoff.RetryNotify(func() error {
|
||||
if err := c.proto.Invoke(ctx, req, &cfg); err != nil {
|
||||
if tgerr.Is(err, tgerr.ErrFloodWait) {
|
||||
// Server sometimes returns FLOOD_WAIT(0) if you create
|
||||
// multiple connections in short period of time.
|
||||
//
|
||||
// See https://github.com/gotd/td/issues/388.
|
||||
return errors.Wrap(err, "flood wait")
|
||||
}
|
||||
// Not retrying other errors.
|
||||
return backoff.Permanent(errors.Wrap(err, "invoke"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}, c.connBackoff(ctx), func(err error, duration time.Duration) {
|
||||
c.log.Debug("Retrying connection initialization",
|
||||
zap.Error(err), zap.Duration("duration", duration),
|
||||
)
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "initConnection")
|
||||
}
|
||||
|
||||
if c.setup != nil {
|
||||
if err := c.setup(ctx, c); err != nil {
|
||||
return errors.Wrap(err, "setup")
|
||||
}
|
||||
}
|
||||
|
||||
c.mux.Lock()
|
||||
c.latest = c.clock.Now()
|
||||
c.cfg = cfg
|
||||
c.mux.Unlock()
|
||||
|
||||
c.gotConfig.Signal()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping calls ping for underlying protocol connection.
|
||||
func (c *Conn) Ping(ctx context.Context) error {
|
||||
return c.proto.Ping(ctx)
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
// Code generated by "stringer -type=ConnMode"; DO NOT EDIT.
|
||||
|
||||
package manager
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[ConnModeUpdates-0]
|
||||
_ = x[ConnModeData-1]
|
||||
_ = x[ConnModeCDN-2]
|
||||
}
|
||||
|
||||
const _ConnMode_name = "ConnModeUpdatesConnModeDataConnModeCDN"
|
||||
|
||||
var _ConnMode_index = [...]uint8{0, 15, 27, 38}
|
||||
|
||||
func (i ConnMode) String() string {
|
||||
if i >= ConnMode(len(_ConnMode_index)-1) {
|
||||
return "ConnMode(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _ConnMode_name[_ConnMode_index[i]:_ConnMode_index[i+1]]
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// SetupCallback is an optional setup connection callback.
|
||||
type SetupCallback = func(ctx context.Context, invoker tg.Invoker) error
|
||||
|
||||
// ConnOptions is a Telegram client connection options.
|
||||
type ConnOptions struct {
|
||||
DC int
|
||||
Test bool
|
||||
Device DeviceConfig
|
||||
Handler Handler
|
||||
Setup SetupCallback
|
||||
OnDead func()
|
||||
OnAuthError func(error)
|
||||
Backoff func(ctx context.Context) backoff.BackOff
|
||||
}
|
||||
|
||||
func defaultBackoff(c clock.Clock) func(ctx context.Context) backoff.BackOff {
|
||||
return func(ctx context.Context) backoff.BackOff {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
b.Clock = c
|
||||
b.MaxElapsedTime = time.Second * 30
|
||||
b.MaxInterval = time.Second * 5
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// setDefaults sets default values.
|
||||
func (c *ConnOptions) setDefaults(connClock clock.Clock) {
|
||||
if c.DC == 0 {
|
||||
c.DC = 2
|
||||
}
|
||||
// It's okay to use zero value Test.
|
||||
c.Device.SetDefaults()
|
||||
if c.Handler == nil {
|
||||
c.Handler = NoopHandler{}
|
||||
}
|
||||
if c.Backoff == nil {
|
||||
c.Backoff = defaultBackoff(connClock)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateConn creates new connection.
|
||||
func CreateConn(
|
||||
create mtproto.Dialer,
|
||||
mode ConnMode,
|
||||
appID int,
|
||||
opts mtproto.Options,
|
||||
connOpts ConnOptions,
|
||||
) *Conn {
|
||||
connOpts.setDefaults(opts.Clock)
|
||||
conn := &Conn{
|
||||
mode: mode,
|
||||
appID: appID,
|
||||
device: connOpts.Device,
|
||||
clock: opts.Clock,
|
||||
handler: connOpts.Handler,
|
||||
sessionInit: tdsync.NewReady(),
|
||||
gotConfig: tdsync.NewReady(),
|
||||
dead: tdsync.NewReady(),
|
||||
setup: connOpts.Setup,
|
||||
onDead: connOpts.OnDead,
|
||||
connBackoff: connOpts.Backoff,
|
||||
}
|
||||
|
||||
conn.log = opts.Logger
|
||||
opts.DC = connOpts.DC
|
||||
if connOpts.Test {
|
||||
// New key exchange algorithm requires DC ID and uses mapping like MTProxy.
|
||||
// +10000 for test DC, *-1 for media-only.
|
||||
opts.DC += 10000
|
||||
}
|
||||
opts.Handler = conn
|
||||
opts.Logger = conn.log.Named("mtproto")
|
||||
opts.OnError = func(err error) {
|
||||
if auth.IsUnauthorized(err) && connOpts.OnAuthError != nil {
|
||||
connOpts.OnAuthError(err)
|
||||
}
|
||||
}
|
||||
conn.proto = mtproto.New(create, opts)
|
||||
|
||||
return conn
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/version"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// DeviceConfig is config which send when Telegram connection session created.
|
||||
type DeviceConfig struct {
|
||||
// Device model.
|
||||
DeviceModel string
|
||||
// Operating system version.
|
||||
SystemVersion string
|
||||
// Application version.
|
||||
AppVersion string
|
||||
// Code for the language used on the device's OS, ISO 639-1 standard.
|
||||
SystemLangCode string
|
||||
// Language pack to use.
|
||||
LangPack string
|
||||
// Code for the language used on the client, ISO 639-1 standard.
|
||||
LangCode string
|
||||
// Info about an MTProto proxy.
|
||||
Proxy tg.InputClientProxy
|
||||
// Additional initConnection parameters. For now, only the tz_offset field is supported,
|
||||
// for specifying timezone offset in seconds.
|
||||
Params tg.JSONValueClass
|
||||
}
|
||||
|
||||
// SetDefaults sets default values.
|
||||
func (c *DeviceConfig) SetDefaults() {
|
||||
const notAvailable = "n/a"
|
||||
|
||||
// Strings must be non-empty, so set notAvailable if default value is empty.
|
||||
set := func(to *string, value string) {
|
||||
if value != "" {
|
||||
*to = value
|
||||
} else {
|
||||
*to = notAvailable
|
||||
}
|
||||
}
|
||||
|
||||
if c.DeviceModel == "" {
|
||||
set(&c.DeviceModel, runtime.Version())
|
||||
}
|
||||
if c.SystemVersion == "" {
|
||||
set(&c.SystemVersion, runtime.GOOS)
|
||||
}
|
||||
if c.AppVersion == "" {
|
||||
set(&c.AppVersion, version.GetVersion())
|
||||
}
|
||||
if c.SystemLangCode == "" {
|
||||
c.SystemLangCode = "en"
|
||||
}
|
||||
if c.LangCode == "" {
|
||||
c.LangCode = "en"
|
||||
}
|
||||
// It's okay to use zero value Proxy.
|
||||
// It's okay to use zero value Params.
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package manager contains connection management utilities.
|
||||
package manager
|
||||
@@ -0,0 +1,26 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Handler abstracts updates and session handler.
|
||||
type Handler interface {
|
||||
OnSession(cfg tg.Config, s mtproto.Session) error
|
||||
OnMessage(b *bin.Buffer) error
|
||||
}
|
||||
|
||||
// NoopHandler is a noop handler.
|
||||
type NoopHandler struct{}
|
||||
|
||||
// OnSession implements Handler.
|
||||
func (n NoopHandler) OnSession(cfg tg.Config, s mtproto.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnMessage implements Handler
|
||||
func (n NoopHandler) OnMessage(b *bin.Buffer) error {
|
||||
return nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user