move gotd fork into repo. (#111)

- update to latest telegram layer
- remove some references to fields in tg.Entities that don't exist in
the schema
- originally added here:
https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e
  - referenced here
-
https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3
-
https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
Adam Van Ymeren
2025-06-27 20:03:37 -07:00
committed by GitHub
parent 0952df0244
commit 7a04f298d2
19264 changed files with 1539697 additions and 84 deletions
+23
View File
@@ -0,0 +1,23 @@
package telegram
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth/qrlogin"
)
// Auth returns auth client.
func (c *Client) Auth() *auth.Client {
return auth.NewClient(
c.tg, c.rand, c.appID, c.appHash,
)
}
// QR returns QR login helper.
func (c *Client) QR() qrlogin.QR {
return qrlogin.NewQR(
c.tg,
c.appID,
c.appHash,
qrlogin.Options{Migrate: c.MigrateTo},
)
}
+23
View File
@@ -0,0 +1,23 @@
// Package auth provides authentication on top of tg.Client.
package auth
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
// IsKeyUnregistered reports whether err is AUTH_KEY_UNREGISTERED error.
//
// Deprecated: use IsUnauthorized.
func IsKeyUnregistered(err error) bool {
return tgerr.Is(err, "AUTH_KEY_UNREGISTERED")
}
// IsUnauthorized reports whether err is any 401 UNAUTHORIZED or is a 406
// NOT_ACCEPTABLE with AUTH_KEY_DUPLICATED.
//
// https://core.telegram.org/api/errors#401-unauthorized
// https://core.telegram.org/api/errors#406-not-acceptable
func IsUnauthorized(err error) bool {
return tgerr.IsCode(err, 401) ||
(tgerr.IsCode(err, 406) && tgerr.Is(err, "AUTH_KEY_DUPLICATED"))
}
+44
View File
@@ -0,0 +1,44 @@
package auth
import (
"crypto/rand"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
)
const (
testAppID = 1
testAppHash = "hash"
)
func testClient(invoker tg.Invoker) *Client {
return &Client{
api: tg.NewClient(invoker),
rand: rand.Reader,
appID: testAppID,
appHash: testAppHash,
}
}
func mockClient(t *testing.T) (*tgmock.Mock, *Client) {
mock := tgmock.New(t)
return mock, NewClient(tg.NewClient(mock), testutil.ZeroRand{}, testAppID, testAppHash)
}
func mockTest(cb func(
a *require.Assertions,
mock *tgmock.Mock,
client *Client,
)) func(t *testing.T) {
return func(t *testing.T) {
a := require.New(t)
m, client := mockClient(t)
cb(a, m, client)
}
}
+26
View File
@@ -0,0 +1,26 @@
package auth
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Bot performs bot authentication request.
func (c *Client) Bot(ctx context.Context, token string) (*tg.AuthAuthorization, error) {
auth, err := c.api.AuthImportBotAuthorization(ctx, &tg.AuthImportBotAuthorizationRequest{
APIID: c.appID,
APIHash: c.appHash,
BotAuthToken: token,
})
if err != nil {
return nil, err
}
result, err := checkResult(auth)
if err != nil {
return nil, errors.Wrap(err, "check")
}
return result, nil
}
+46
View File
@@ -0,0 +1,46 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
)
func TestClient_AuthBot(t *testing.T) {
const token = "12345:token"
t.Run("AuthAuthorization", func(t *testing.T) {
mock := tgmock.New(t)
testUser := &tg.User{}
testUser.SetBot(true)
mock.ExpectCall(&tg.AuthImportBotAuthorizationRequest{
BotAuthToken: token,
APIID: testAppID,
APIHash: testAppHash,
}).ThenResult(&tg.AuthAuthorization{User: testUser})
result, err := testClient(mock).Bot(context.Background(), token)
require.NoError(t, err)
require.Equal(t, testUser, result.User)
})
t.Run("AuthAuthorizationSignUpRequired", func(t *testing.T) {
mock := tgmock.New(t)
mock.ExpectCall(&tg.AuthImportBotAuthorizationRequest{
BotAuthToken: token,
APIID: testAppID,
APIHash: testAppHash,
}).ThenResult(&tg.AuthAuthorizationSignUpRequired{})
result, err := testClient(mock).Bot(context.Background(), token)
require.Error(t, err)
require.Nil(t, result)
})
}
+30
View File
@@ -0,0 +1,30 @@
package auth
import (
"io"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Client implements Telegram authentication.
type Client struct {
api *tg.Client
rand io.Reader
appID int
appHash string
}
// NewClient initializes and returns Telegram authentication client.
func NewClient(
api *tg.Client,
rand io.Reader,
appID int,
appHash string,
) *Client {
return &Client{
api: api,
rand: rand,
appID: appID,
appHash: appHash,
}
}
+307
View File
@@ -0,0 +1,307 @@
package auth
import (
"context"
"fmt"
"io"
"os"
"strconv"
"strings"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// NewFlow initializes new authentication flow.
func NewFlow(auth UserAuthenticator, opt SendCodeOptions) Flow {
return Flow{
Auth: auth,
Options: opt,
}
}
// Flow simplifies boilerplate for authentication flow.
type Flow struct {
Auth UserAuthenticator
Options SendCodeOptions
}
func (f Flow) handleSignUp(ctx context.Context, client FlowClient, phone, hash string, s *SignUpRequired) error {
if err := f.Auth.AcceptTermsOfService(ctx, s.TermsOfService); err != nil {
return errors.Wrap(err, "confirm TOS")
}
info, err := f.Auth.SignUp(ctx)
if err != nil {
return errors.Wrap(err, "sign up info not provided")
}
if _, err := client.SignUp(ctx, SignUp{
PhoneNumber: phone,
PhoneCodeHash: hash,
FirstName: info.FirstName,
LastName: info.LastName,
}); err != nil {
return errors.Wrap(err, "sign up")
}
return nil
}
// Run starts authentication flow on client.
func (f Flow) Run(ctx context.Context, client FlowClient) error {
if f.Auth == nil {
return errors.New("no UserAuthenticator provided")
}
phone, err := f.Auth.Phone(ctx)
if err != nil {
return errors.Wrap(err, "get phone")
}
sentCode, err := client.SendCode(ctx, phone, f.Options)
if err != nil {
return errors.Wrap(err, "send code")
}
switch s := sentCode.(type) {
case *tg.AuthSentCode:
hash := s.PhoneCodeHash
code, err := f.Auth.Code(ctx, s)
if err != nil {
return errors.Wrap(err, "get code")
}
_, signInErr := client.SignIn(ctx, phone, code, hash)
if errors.Is(signInErr, ErrPasswordAuthNeeded) {
password, err := f.Auth.Password(ctx)
if err != nil {
return errors.Wrap(err, "get password")
}
if _, err := client.Password(ctx, password); err != nil {
return errors.Wrap(err, "sign in with password")
}
return nil
}
var signUpRequired *SignUpRequired
if errors.As(signInErr, &signUpRequired) {
return f.handleSignUp(ctx, client, phone, hash, signUpRequired)
}
if signInErr != nil {
return errors.Wrap(signInErr, "sign in")
}
return nil
case *tg.AuthSentCodeSuccess:
switch a := s.Authorization.(type) {
case *tg.AuthAuthorization:
// Looks that we are already authorized.
return nil
case *tg.AuthAuthorizationSignUpRequired:
if err := f.handleSignUp(ctx, client, phone, "", &SignUpRequired{
TermsOfService: a.TermsOfService,
}); err != nil {
// TODO: not sure that blank hash will work here
return errors.Wrap(err, "sign up after auth sent code success")
}
return nil
default:
return errors.Errorf("unexpected authorization type: %T", a)
}
default:
return errors.Errorf("unexpected sent code type: %T", sentCode)
}
}
// FlowClient abstracts telegram client for Flow.
type FlowClient interface {
SignIn(ctx context.Context, phone, code, codeHash string) (*tg.AuthAuthorization, error)
SendCode(ctx context.Context, phone string, options SendCodeOptions) (tg.AuthSentCodeClass, error)
Password(ctx context.Context, password string) (*tg.AuthAuthorization, error)
SignUp(ctx context.Context, s SignUp) (*tg.AuthAuthorization, error)
}
// CodeAuthenticator asks user for received authentication code.
type CodeAuthenticator interface {
Code(ctx context.Context, sentCode *tg.AuthSentCode) (string, error)
}
// CodeAuthenticatorFunc is functional wrapper for CodeAuthenticator.
type CodeAuthenticatorFunc func(ctx context.Context, sentCode *tg.AuthSentCode) (string, error)
// Code implements CodeAuthenticator interface.
func (c CodeAuthenticatorFunc) Code(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) {
return c(ctx, sentCode)
}
// UserInfo represents user info required for sign up.
type UserInfo struct {
FirstName string
LastName string
}
// UserAuthenticator asks user for phone, password and received authentication code.
type UserAuthenticator interface {
Phone(ctx context.Context) (string, error)
Password(ctx context.Context) (string, error)
AcceptTermsOfService(ctx context.Context, tos tg.HelpTermsOfService) error
SignUp(ctx context.Context) (UserInfo, error)
CodeAuthenticator
}
type noSignUp struct{}
func (c noSignUp) SignUp(ctx context.Context) (UserInfo, error) {
return UserInfo{}, errors.New("not implemented")
}
func (c noSignUp) AcceptTermsOfService(ctx context.Context, tos tg.HelpTermsOfService) error {
return &SignUpRequired{TermsOfService: tos}
}
type constantAuth struct {
phone, password string
CodeAuthenticator
noSignUp
}
func (c constantAuth) Phone(ctx context.Context) (string, error) {
return c.phone, nil
}
func (c constantAuth) Password(ctx context.Context) (string, error) {
return c.password, nil
}
// Constant creates UserAuthenticator with constant phone and password.
func Constant(phone, password string, code CodeAuthenticator) UserAuthenticator {
return constantAuth{
phone: phone,
password: password,
CodeAuthenticator: code,
}
}
type envAuth struct {
prefix string
CodeAuthenticator
noSignUp
}
func (e envAuth) lookup(k string) (string, error) {
env := e.prefix + k
v, ok := os.LookupEnv(env)
if !ok {
return "", errors.Errorf("environment variable %q not set", env)
}
return v, nil
}
func (e envAuth) Phone(ctx context.Context) (string, error) {
return e.lookup("PHONE")
}
func (e envAuth) Password(ctx context.Context) (string, error) {
p, err := e.lookup("PASSWORD")
if err != nil {
return "", ErrPasswordNotProvided
}
return p, nil
}
// Env creates UserAuthenticator which gets phone and password from environment variables.
func Env(prefix string, code CodeAuthenticator) UserAuthenticator {
return envAuth{
prefix: prefix,
CodeAuthenticator: code,
noSignUp: noSignUp{},
}
}
// ErrPasswordNotProvided means that password requested by Telegram,
// but not provided by user.
var ErrPasswordNotProvided = errors.New("password requested but not provided")
type codeOnlyAuth struct {
phone string
CodeAuthenticator
noSignUp
}
func (c codeOnlyAuth) Phone(ctx context.Context) (string, error) {
return c.phone, nil
}
func (c codeOnlyAuth) Password(ctx context.Context) (string, error) {
return "", ErrPasswordNotProvided
}
// CodeOnly creates UserAuthenticator with constant phone and no password.
func CodeOnly(phone string, code CodeAuthenticator) UserAuthenticator {
return codeOnlyAuth{
phone: phone,
CodeAuthenticator: code,
}
}
type testAuth struct {
dc int
phone string
}
func (t testAuth) Phone(ctx context.Context) (string, error) { return t.phone, nil }
func (t testAuth) Password(ctx context.Context) (string, error) { return "", ErrPasswordNotProvided }
func (t testAuth) Code(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) {
type notFlashing interface {
GetLength() int
}
length := 5
if sentCode != nil {
typ, ok := sentCode.Type.(notFlashing)
if !ok {
return "", errors.Errorf("unexpected type: %T", sentCode.Type)
}
length = typ.GetLength()
}
return strings.Repeat(strconv.Itoa(t.dc), length), nil
}
func (t testAuth) AcceptTermsOfService(ctx context.Context, tos tg.HelpTermsOfService) error {
return nil
}
func (t testAuth) SignUp(ctx context.Context) (UserInfo, error) {
return UserInfo{
FirstName: "Test",
LastName: "User",
}, nil
}
// Test returns UserAuthenticator that authenticates via testing credentials.
//
// Can be used only with testing server. Will perform sign up if test user is
// not registered.
func Test(randReader io.Reader, dc int) UserAuthenticator {
// 99966XYYYY, X = dc_id, Y = random numbers, code = X repeat 6.
// The n value is from 0000 to 9999.
n, err := crypto.RandInt64n(randReader, 1000)
if err != nil {
panic(err)
}
phone := fmt.Sprintf("99966%d%04d", dc, n)
return TestUser(phone, dc)
}
// TestUser returns UserAuthenticator that authenticates via testing credentials.
// Uses given phone to sign in/sign up.
//
// Can be used only with testing server. Will perform sign up if test user is
// not registered.
func TestUser(phone string, dc int) UserAuthenticator {
return testAuth{
dc: dc,
phone: phone,
}
}
+114
View File
@@ -0,0 +1,114 @@
package auth_test
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func askCode(code string, err error) auth.CodeAuthenticatorFunc {
return func(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) {
return code, err
}
}
func TestConstantAuth(t *testing.T) {
a := require.New(t)
authConst := auth.Constant("phone", "password", askCode("123", nil))
ctx := context.Background()
result, err := authConst.Code(ctx, nil)
a.NoError(err)
a.Equal("123", result)
result, err = authConst.Phone(ctx)
a.NoError(err)
a.Equal("phone", result)
result, err = authConst.Password(ctx)
a.NoError(err)
a.Equal("password", result)
}
func TestCodeOnlyAuth(t *testing.T) {
a := require.New(t)
authCodeOnly := auth.CodeOnly("phone", askCode("123", nil))
ctx := context.Background()
result, err := authCodeOnly.Code(ctx, nil)
a.NoError(err)
a.Equal("123", result)
result, err = authCodeOnly.Phone(ctx)
a.NoError(err)
a.Equal("phone", result)
_, err = authCodeOnly.Password(ctx)
a.ErrorIs(err, auth.ErrPasswordNotProvided)
}
func TestEnvAuth(t *testing.T) {
a := require.New(t)
ctx := context.Background()
prefix := "TEST_ENV_AUTH_"
authEnv := auth.Env(prefix, askCode("123", nil))
result, err := authEnv.Code(ctx, nil)
a.NoError(err)
a.Equal("123", result)
_, err = authEnv.Phone(ctx)
a.Error(err)
_, err = authEnv.Password(ctx)
a.ErrorIs(err, auth.ErrPasswordNotProvided)
// Set envs.
testutil.SetEnv(t, prefix+"PHONE", "phone")
testutil.SetEnv(t, prefix+"PASSWORD", "password")
result, err = authEnv.Phone(ctx)
a.NoError(err)
a.Equal("phone", result)
result, err = authEnv.Password(ctx)
a.NoError(err)
a.Equal("password", result)
}
func TestTestAuth(t *testing.T) {
a := require.New(t)
ctx := context.Background()
testAuth := auth.Test(testutil.ZeroRand{}, 2)
_, err := testAuth.Code(ctx, &tg.AuthSentCode{
Type: &tg.AuthSentCodeTypeFlashCall{},
})
a.Error(err)
result, err := testAuth.Code(ctx, nil)
a.NoError(err)
a.Equal("22222", result)
result, err = testAuth.Code(ctx, &tg.AuthSentCode{
Type: &tg.AuthSentCodeTypeApp{
Length: 1,
},
})
a.NoError(err)
a.Equal("2", result)
result, err = testAuth.Phone(ctx)
a.NoError(err)
a.True(strings.HasPrefix(result, "999662"))
_, err = testAuth.Password(ctx)
a.ErrorIs(err, auth.ErrPasswordNotProvided)
}
+179
View File
@@ -0,0 +1,179 @@
package auth
import (
"context"
"fmt"
"time"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto/srp"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// PasswordHash computes password hash to log in.
//
// See https://core.telegram.org/api/srp#checking-the-password-with-srp.
func PasswordHash(
password []byte,
srpID int64,
srpB, secureRandom []byte,
alg tg.PasswordKdfAlgoClass,
) (*tg.InputCheckPasswordSRP, error) {
s := srp.NewSRP(crypto.DefaultRand())
algo, ok := alg.(*tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow)
if !ok {
return nil, errors.Errorf("unsupported algo: %T", alg)
}
a, err := s.Hash(password, srpB, secureRandom, srp.Input(*algo))
if err != nil {
return nil, errors.Wrap(err, "create SRP answer")
}
return &tg.InputCheckPasswordSRP{
SRPID: srpID,
A: a.A,
M1: a.M1,
}, nil
}
// NewPasswordHash computes new password hash to update password.
//
// Notice that NewPasswordHash mutates given alg.
//
// See https://core.telegram.org/api/srp#setting-a-new-2fa-password.
func NewPasswordHash(
password []byte,
algo *tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow,
) (hash []byte, _ error) {
s := srp.NewSRP(crypto.DefaultRand())
hash, newSalt, err := s.NewHash(password, srp.Input(*algo))
if err != nil {
return nil, errors.Wrap(err, "create SRP answer")
}
algo.Salt1 = newSalt
return hash, nil
}
var (
emptyPassword tg.InputCheckPasswordSRPClass = &tg.InputCheckPasswordEmpty{}
)
// UpdatePasswordOptions is options structure for UpdatePassword.
type UpdatePasswordOptions struct {
// Hint is new password hint.
Hint string
// Password is password callback.
//
// If password was requested and Password is nil, ErrPasswordNotProvided error will be returned.
Password func(ctx context.Context) (string, error)
}
// UpdatePassword sets new cloud password for this account.
//
// See https://core.telegram.org/api/srp#setting-a-new-2fa-password.
func (c *Client) UpdatePassword(
ctx context.Context,
newPassword string,
opts UpdatePasswordOptions,
) error {
p, err := c.api.AccountGetPassword(ctx)
if err != nil {
return errors.Wrap(err, "get SRP parameters")
}
algo, ok := p.NewAlgo.(*tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow)
if !ok {
return errors.Errorf("unsupported algo: %T", p.NewAlgo)
}
newHash, err := NewPasswordHash([]byte(newPassword), algo)
if err != nil {
return errors.Wrap(err, "compute new password hash")
}
var old = emptyPassword
if p.HasPassword {
if opts.Password == nil {
return ErrPasswordNotProvided
}
oldPassword, err := opts.Password(ctx)
if err != nil {
return errors.Wrap(err, "get password")
}
hash, err := PasswordHash([]byte(oldPassword), p.SRPID, p.SRPB, p.SecureRandom, p.CurrentAlgo)
if err != nil {
return errors.Wrap(err, "compute old password hash")
}
old = hash
}
if _, err := c.api.AccountUpdatePasswordSettings(ctx, &tg.AccountUpdatePasswordSettingsRequest{
Password: old,
NewSettings: tg.AccountPasswordInputSettings{
NewAlgo: algo,
NewPasswordHash: newHash,
Hint: opts.Hint,
},
}); err != nil {
return errors.Wrap(err, "update password")
}
return nil
}
// ResetFailedWaitError reports that you recently requested a password reset that was cancel and need to wait until the
// specified date before requesting another reset.
type ResetFailedWaitError struct {
Result tg.AccountResetPasswordFailedWait
}
// Until returns time required to wait.
func (r ResetFailedWaitError) Until() time.Duration {
retryDate := time.Unix(int64(r.Result.RetryDate), 0)
return time.Until(retryDate)
}
// Error implements error.
func (r *ResetFailedWaitError) Error() string {
return fmt.Sprintf("wait to reset password (%s)", r.Until())
}
// ResetPassword resets cloud password and returns time to wait until reset be performed.
// If time is zero, password was successfully reset.
//
// May return ResetFailedWaitError.
//
// See https://core.telegram.org/api/srp#password-reset.
func (c *Client) ResetPassword(ctx context.Context) (time.Time, error) {
r, err := c.api.AccountResetPassword(ctx)
if err != nil {
return time.Time{}, errors.Wrap(err, "reset password")
}
switch v := r.(type) {
case *tg.AccountResetPasswordFailedWait:
return time.Time{}, &ResetFailedWaitError{Result: *v}
case *tg.AccountResetPasswordRequestedWait:
return time.Unix(int64(v.UntilDate), 0), nil
case *tg.AccountResetPasswordOk:
return time.Time{}, nil
default:
return time.Time{}, errors.Errorf("unexpected type %T", v)
}
}
// CancelPasswordReset cancels password reset.
//
// See https://core.telegram.org/api/srp#password-reset.
func (c *Client) CancelPasswordReset(ctx context.Context) error {
if _, err := c.api.AccountDeclinePasswordReset(ctx); err != nil {
return errors.Wrap(err, "cancel password reset")
}
return nil
}
@@ -0,0 +1,62 @@
package auth_test
import (
"context"
"fmt"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
)
func ExampleClient_UpdatePassword() {
ctx := context.Background()
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{})
if err := client.Run(ctx, func(ctx context.Context) error {
// Updating password.
if err := client.Auth().UpdatePassword(ctx, "new_password", auth.UpdatePasswordOptions{
// Hint sets new password hint.
Hint: "new password hint",
// Password will be called if old password is requested by Telegram.
//
// If password was requested and Password is nil, auth.ErrPasswordNotProvided error will be returned.
Password: func(ctx context.Context) (string, error) {
return "old_password", nil
},
}); err != nil {
return err
}
return nil
}); err != nil {
panic(err)
}
}
func ExampleClient_ResetPassword() {
ctx := context.Background()
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{})
if err := client.Run(ctx, func(ctx context.Context) error {
wait, err := client.Auth().ResetPassword(ctx)
var waitErr *auth.ResetFailedWaitError
switch {
case errors.As(err, &waitErr):
// Telegram requested wait until making new reset request.
fmt.Printf("Wait until %s to reset password.\n", wait.String())
case err != nil:
return err
}
// If returned time is zero, password was successfully reset.
if wait.IsZero() {
fmt.Println("Password was reset.")
return nil
}
fmt.Printf("Password will be reset on %s.\n", wait.String())
return nil
}); err != nil {
panic(err)
}
}
+169
View File
@@ -0,0 +1,169 @@
package auth
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
)
func TestPasswordHash(t *testing.T) {
a := require.New(t)
_, err := PasswordHash(nil, 0, nil, nil, nil)
a.Error(err, "unsupported algo")
}
var testAlgo = &tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow{
Salt1: []uint8{
230, 200, 149, 125, 223, 152, 141, 72,
},
Salt2: []uint8{
159, 99, 68, 130, 43, 9, 108, 255, 135, 239, 164, 38, 245, 120, 87, 182,
},
G: 3,
P: []uint8{
199, 28, 174, 185, 198, 177, 201, 4, 142, 108, 82, 47, 112, 241, 63, 115,
152, 13, 64, 35, 142, 62, 33, 193, 73, 52, 208, 55, 86, 61, 147, 15,
72, 25, 138, 10, 167, 193, 64, 88, 34, 148, 147, 210, 37, 48, 244, 219,
250, 51, 111, 110, 10, 201, 37, 19, 149, 67, 174, 212, 76, 206, 124, 55,
32, 253, 81, 246, 148, 88, 112, 90, 198, 140, 212, 254, 107, 107, 19, 171,
220, 151, 70, 81, 41, 105, 50, 132, 84, 241, 143, 175, 140, 89, 95, 100,
36, 119, 254, 150, 187, 42, 148, 29, 91, 205, 29, 74, 200, 204, 73, 136,
7, 8, 250, 155, 55, 142, 60, 79, 58, 144, 96, 190, 230, 124, 249, 164,
164, 166, 149, 129, 16, 81, 144, 126, 22, 39, 83, 181, 107, 15, 107, 65,
13, 186, 116, 216, 168, 75, 42, 20, 179, 20, 78, 14, 241, 40, 71, 84,
253, 23, 237, 149, 13, 89, 101, 180, 185, 221, 70, 88, 45, 177, 23, 141,
22, 156, 107, 196, 101, 176, 214, 255, 156, 163, 146, 143, 239, 91, 154, 228,
228, 24, 252, 21, 232, 62, 190, 160, 248, 127, 169, 255, 94, 237, 112, 5,
13, 237, 40, 73, 244, 123, 249, 89, 217, 86, 133, 12, 233, 41, 133, 31,
13, 129, 21, 246, 53, 177, 5, 238, 46, 78, 21, 208, 75, 36, 84, 191,
111, 79, 173, 240, 52, 177, 4, 3, 17, 156, 216, 227, 185, 47, 204, 91,
},
}
func TestClient_UpdatePassword(t *testing.T) {
ctx := context.Background()
expectCall := func(a *require.Assertions, m *tgmock.Mock, hasPassword bool) *tgmock.RequestBuilder {
p := &tg.AccountPassword{
HasPassword: hasPassword,
NewAlgo: testAlgo,
NewSecureAlgo: &tg.SecurePasswordKdfAlgoUnknown{},
}
if hasPassword {
p.CurrentAlgo = testAlgo
}
p.SetFlags()
return m.ExpectCall(&tg.AccountGetPasswordRequest{}).
ThenResult(p).ExpectFunc(func(b bin.Encoder) {
a.IsType(&tg.AccountUpdatePasswordSettingsRequest{}, b)
r := b.(*tg.AccountUpdatePasswordSettingsRequest)
if !hasPassword {
a.Equal(emptyPassword, r.Password)
} else {
a.NotEqual(emptyPassword, r.Password)
}
a.NotEmpty(r.NewSettings.NewPasswordHash)
a.Equal("hint", r.NewSettings.Hint)
})
}
t.Run("PasswordNotRequired", mockTest(func(
a *require.Assertions,
m *tgmock.Mock,
client *Client,
) {
m.ExpectCall(&tg.AccountGetPasswordRequest{}).ThenErr(testutil.TestError())
a.Error(client.UpdatePassword(ctx, "", UpdatePasswordOptions{}))
expectCall(a, m, false).ThenTrue()
a.NoError(client.UpdatePassword(ctx, "", UpdatePasswordOptions{
Hint: "hint",
}))
}))
t.Run("PasswordRequired", mockTest(func(
a *require.Assertions,
m *tgmock.Mock,
client *Client,
) {
m.ExpectCall(&tg.AccountGetPasswordRequest{}).
ThenResult(&tg.AccountPassword{
HasPassword: true,
NewAlgo: testAlgo,
CurrentAlgo: testAlgo,
NewSecureAlgo: &tg.SecurePasswordKdfAlgoUnknown{},
})
a.ErrorIs(client.UpdatePassword(ctx, "", UpdatePasswordOptions{}), ErrPasswordNotProvided)
m.ExpectCall(&tg.AccountGetPasswordRequest{}).
ThenResult(&tg.AccountPassword{
HasPassword: true,
NewAlgo: testAlgo,
CurrentAlgo: testAlgo,
NewSecureAlgo: &tg.SecurePasswordKdfAlgoUnknown{},
})
a.ErrorIs(client.UpdatePassword(ctx, "", UpdatePasswordOptions{
Hint: "hint",
Password: func(ctx context.Context) (string, error) {
return "", testutil.TestError()
},
}), testutil.TestError())
expectCall(a, m, true).ThenTrue()
a.NoError(client.UpdatePassword(ctx, "", UpdatePasswordOptions{
Hint: "hint",
Password: func(ctx context.Context) (string, error) {
return "password", nil
},
}))
}))
}
func TestClient_ResetPassword(t *testing.T) {
ctx := context.Background()
wait := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC).Unix()
mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) {
mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenErr(testutil.TestError())
_, err := client.ResetPassword(ctx)
a.Error(err)
mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenResult(&tg.AccountResetPasswordFailedWait{
RetryDate: int(wait),
})
var waitErr *ResetFailedWaitError
_, err = client.ResetPassword(ctx)
a.ErrorAs(err, &waitErr)
a.Equal(int(wait), waitErr.Result.RetryDate)
a.NotEmpty(waitErr.Error())
mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenResult(&tg.AccountResetPasswordOk{})
r, err := client.ResetPassword(ctx)
a.NoError(err)
a.True(r.IsZero())
mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenResult(&tg.AccountResetPasswordRequestedWait{
UntilDate: int(wait),
})
r, err = client.ResetPassword(ctx)
a.NoError(err)
a.False(r.IsZero())
})(t)
}
func TestClient_CancelPasswordReset(t *testing.T) {
ctx := context.Background()
mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) {
mock.ExpectCall(&tg.AccountDeclinePasswordResetRequest{}).ThenErr(testutil.TestError())
a.Error(client.CancelPasswordReset(ctx))
mock.ExpectCall(&tg.AccountDeclinePasswordResetRequest{}).ThenTrue()
a.NoError(client.CancelPasswordReset(ctx))
})(t)
}
+20
View File
@@ -0,0 +1,20 @@
package qrlogin
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// AcceptQR accepts given token.
//
// See https://core.telegram.org/api/qr-login#accepting-a-login-token.
func AcceptQR(ctx context.Context, raw *tg.Client, t Token) (*tg.Authorization, error) {
auth, err := raw.AuthAcceptLoginToken(ctx, t.token)
if err != nil {
return nil, errors.Wrap(err, "accept")
}
return auth, nil
}
+23
View File
@@ -0,0 +1,23 @@
package qrlogin
import (
"fmt"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// MigrationNeededError reports that Telegram requested DC migration to continue login.
type MigrationNeededError struct {
MigrateTo *tg.AuthLoginTokenMigrateTo
// Tried indicates that the migration was attempted.
//
// Deprecated: do not use. QR login uses migrate function passed via
// options.
Tried bool
}
// Error implements error.
func (m *MigrationNeededError) Error() string {
return fmt.Sprintf("migration to %d needed", m.MigrateTo.DCID)
}
+20
View File
@@ -0,0 +1,20 @@
package qrlogin
import (
"context"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
)
// Options of QR.
type Options struct {
Migrate func(ctx context.Context, dcID int) error
Clock clock.Clock
}
func (o *Options) setDefaults() {
// It's okay to use zero value Migrate.
if o.Clock == nil {
o.Clock = clock.System
}
}
+175
View File
@@ -0,0 +1,175 @@
// Package qrlogin provides QR login flow implementation.
//
// See https://core.telegram.org/api/qr-login.
package qrlogin
import (
"context"
"time"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// QR implements Telegram QR login flow.
type QR struct {
api *tg.Client
appID int
appHash string
migrate func(ctx context.Context, dcID int) error
clock clock.Clock
}
// NewQR creates new QR
func NewQR(api *tg.Client, appID int, appHash string, opts Options) QR {
opts.setDefaults()
return QR{
api: api,
appID: appID,
appHash: appHash,
clock: opts.Clock,
migrate: opts.Migrate,
}
}
// Export exports new login token.
//
// See https://core.telegram.org/api/qr-login#exporting-a-login-token.
func (q QR) Export(ctx context.Context, exceptIDs ...int64) (Token, error) {
result, err := q.api.AuthExportLoginToken(ctx, &tg.AuthExportLoginTokenRequest{
APIID: q.appID,
APIHash: q.appHash,
ExceptIDs: exceptIDs,
})
if err != nil {
return Token{}, errors.Wrap(err, "export")
}
t, ok := result.(*tg.AuthLoginToken)
if !ok {
return Token{}, errors.Errorf("unexpected type %T", result)
}
return NewToken(t.Token, t.Expires), nil
}
// Accept accepts given token.
//
// See https://core.telegram.org/api/qr-login#accepting-a-login-token.
func (q QR) Accept(ctx context.Context, t Token) (*tg.Authorization, error) {
return AcceptQR(ctx, q.api, t)
}
// Import imports accepted token.
//
// See https://core.telegram.org/api/qr-login#confirming-importing-the-login-token.
func (q QR) Import(ctx context.Context) (*tg.AuthAuthorization, error) {
result, err := q.api.AuthExportLoginToken(ctx, &tg.AuthExportLoginTokenRequest{
APIID: q.appID,
APIHash: q.appHash,
})
if err != nil {
return nil, errors.Wrap(err, "import")
}
switch t := result.(type) {
case *tg.AuthLoginTokenMigrateTo:
if q.migrate == nil {
return nil, &MigrationNeededError{
MigrateTo: t,
}
}
if err := q.migrate(ctx, t.DCID); err != nil {
return nil, errors.Wrap(err, "migrate")
}
res, err := q.api.AuthImportLoginToken(ctx, t.Token)
if err != nil {
return nil, errors.Wrap(err, "import")
}
success, ok := res.(*tg.AuthLoginTokenSuccess)
if !ok {
return nil, errors.Errorf("unexpected type %T", res)
}
auth, ok := success.Authorization.(*tg.AuthAuthorization)
if !ok {
return nil, errors.Errorf("unexpected type %T", success.Authorization)
}
return auth, nil
case *tg.AuthLoginTokenSuccess:
auth, ok := t.Authorization.(*tg.AuthAuthorization)
if !ok {
return nil, errors.Errorf("unexpected type %T", t.Authorization)
}
return auth, nil
default:
return nil, errors.Errorf("unexpected type %T", result)
}
}
// LoggedIn is signal channel to notify about tg.UpdateLoginToken.
type LoggedIn <-chan struct{}
// OnLoginToken sets handler for given dispatcher and returns signal channel.
func OnLoginToken(d interface {
OnLoginToken(tg.LoginTokenHandler)
},
) LoggedIn {
loggedIn := make(chan struct{})
d.OnLoginToken(func(ctx context.Context, e tg.Entities, update *tg.UpdateLoginToken) error {
select {
case loggedIn <- struct{}{}:
return nil
default:
}
return nil
})
return loggedIn
}
// Auth generates new QR login token, shows it and awaits acceptation.
//
// NB: Show callback may be called more than once if QR expires.
func (q QR) Auth(
ctx context.Context,
loggedIn LoggedIn,
show func(ctx context.Context, token Token) error,
exceptIDs ...int64,
) (*tg.AuthAuthorization, error) {
until := func(token Token) time.Duration {
return token.Expires().Sub(q.clock.Now()).Truncate(time.Second)
}
token, err := q.Export(ctx, exceptIDs...)
if err != nil {
return nil, err
}
timer := q.clock.Timer(until(token))
defer clock.StopTimer(timer)
for {
if err := show(ctx, token); err != nil {
return nil, errors.Wrap(err, "show")
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-timer.C():
t, err := q.Export(ctx, exceptIDs...)
if err != nil {
return nil, err
}
token = t
timer.Reset(until(token))
continue
case <-loggedIn:
}
return q.Import(ctx)
}
}
@@ -0,0 +1,187 @@
package qrlogin
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/gotd/neo"
"go.mau.fi/mautrix-telegram/pkg/gotd/constant"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
)
func testQR(t *testing.T, migrate func(ctx context.Context, dcID int) error) (*tgmock.Mock, QR) {
mock := tgmock.New(t)
return mock, NewQR(tg.NewClient(mock), constant.TestAppID, constant.TestAppHash, Options{
Migrate: migrate,
})
}
var testToken = NewToken([]byte("token"), 0)
func TestQR_Export(t *testing.T) {
ctx := context.Background()
a := require.New(t)
mock, qr := testQR(t, nil)
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
APIID: constant.TestAppID,
APIHash: constant.TestAppHash,
ExceptIDs: []int64{0},
}).ThenResult(&tg.AuthLoginToken{
Expires: 0,
Token: testToken.token,
})
result, err := qr.Export(ctx, 0)
a.NoError(err)
a.Equal(Token{
token: testToken.token,
expires: time.Unix(0, 0),
}, result)
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
APIID: constant.TestAppID,
APIHash: constant.TestAppHash,
}).ThenResult(&tg.AuthLoginTokenMigrateTo{
Token: testToken.token,
})
_, err = qr.Export(ctx)
a.Error(err)
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
APIID: constant.TestAppID,
APIHash: constant.TestAppHash,
}).ThenErr(testutil.TestError())
_, err = qr.Export(ctx)
a.ErrorIs(err, testutil.TestError())
}
func TestQR_Accept(t *testing.T) {
ctx := context.Background()
a := require.New(t)
mock, qr := testQR(t, nil)
auth := &tg.Authorization{
APIID: 1,
}
mock.ExpectCall(&tg.AuthAcceptLoginTokenRequest{
Token: testToken.token,
}).ThenResult(auth)
result, err := qr.Accept(ctx, testToken)
a.NoError(err)
a.Equal(auth, result)
mock.ExpectCall(&tg.AuthAcceptLoginTokenRequest{
Token: testToken.token,
}).ThenErr(testutil.TestError())
_, err = qr.Accept(ctx, testToken)
a.ErrorIs(err, testutil.TestError())
}
func TestQR_Import(t *testing.T) {
ctx := context.Background()
a := require.New(t)
mock, qr := testQR(t, nil)
auth := &tg.AuthAuthorization{
User: &tg.User{ID: 10},
}
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
APIID: constant.TestAppID,
APIHash: constant.TestAppHash,
}).ThenResult(&tg.AuthLoginTokenSuccess{
Authorization: auth,
})
result, err := qr.Import(ctx)
a.NoError(err)
a.Equal(auth, result)
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
APIID: constant.TestAppID,
APIHash: constant.TestAppHash,
}).ThenResult(&tg.AuthLoginTokenMigrateTo{
DCID: 1,
})
_, err = qr.Import(ctx)
var mig *MigrationNeededError
a.ErrorAs(err, &mig)
a.Equal(1, mig.MigrateTo.DCID)
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
APIID: constant.TestAppID,
APIHash: constant.TestAppHash,
}).ThenResult(&tg.AuthLoginToken{
Token: testToken.token,
})
_, err = qr.Import(ctx)
a.Error(err)
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
APIID: constant.TestAppID,
APIHash: constant.TestAppHash,
}).ThenErr(testutil.TestError())
_, err = qr.Import(ctx)
a.ErrorIs(err, testutil.TestError())
}
func TestQR_Auth(t *testing.T) {
a := require.New(t)
mock := tgmock.New(t)
clock := neo.NewTime(time.Now())
auth := &tg.AuthAuthorization{
User: &tg.User{ID: 10},
}
mock.ExpectCall(&tg.AuthExportLoginTokenRequest{
APIID: constant.TestAppID,
APIHash: constant.TestAppHash,
}).ThenResult(&tg.AuthLoginToken{
Expires: int(clock.Now().Add(time.Minute).Unix()),
Token: testToken.token,
}).ExpectCall(&tg.AuthExportLoginTokenRequest{
APIID: constant.TestAppID,
APIHash: constant.TestAppHash,
}).ThenResult(&tg.AuthLoginToken{
Expires: int(clock.Now().Add(2 * time.Minute).Unix()),
Token: testToken.token,
}).ExpectCall(&tg.AuthExportLoginTokenRequest{
APIID: constant.TestAppID,
APIHash: constant.TestAppHash,
}).ThenResult(&tg.AuthLoginTokenSuccess{
Authorization: auth,
})
qr := NewQR(tg.NewClient(mock), constant.TestAppID, constant.TestAppHash, Options{
Clock: clock,
})
show := make(chan struct{})
done := make(chan error)
loggedIn := make(chan struct{})
go func() {
_, err := qr.Auth(context.Background(), loggedIn, func(ctx context.Context, token Token) error {
show <- struct{}{}
return nil
})
done <- err
}()
// Show QR first time.
<-show
// Skip 1 minute, token expires.
clock.Travel(time.Minute + 1)
// Show QR second time.
<-show
// Emulate update, auth done.
loggedIn <- struct{}{}
a.NoError(<-done)
}
+76
View File
@@ -0,0 +1,76 @@
package qrlogin
import (
"encoding/base64"
"image"
"net/url"
"time"
"github.com/go-faster/errors"
"rsc.io/qr"
)
// Token represents Telegram QR Login token.
type Token struct {
token []byte
expires time.Time
}
// ParseTokenURL creates Token from given URL.
func ParseTokenURL(u string) (Token, error) {
parsed, err := url.Parse(u)
if err != nil {
return Token{}, err
}
switch {
case parsed.Scheme != "tg":
return Token{}, errors.Errorf("unexpected scheme %q", parsed.Scheme)
case parsed.Host != "login":
return Token{}, errors.Errorf("wrong path %q", parsed.Host)
}
q := parsed.Query()
if q.Get("token") == "" {
return Token{}, errors.New("token is empty")
}
token, err := base64.URLEncoding.DecodeString(q.Get("token"))
if err != nil {
return Token{}, err
}
return NewToken(token, 0), nil
}
// NewToken creates new Token.
func NewToken(token []byte, expires int) Token {
return Token{
token: token,
expires: time.Unix(int64(expires), 0),
}
}
// Expires returns token expiration time.
func (t Token) Expires() time.Time {
return t.expires
}
// String implements fmt.Stringer.
func (t Token) String() string {
return base64.URLEncoding.EncodeToString(t.token)
}
// URL returns login URL.
//
// See https://core.telegram.org/api/qr-login#exporting-a-login-token.
func (t Token) URL() string {
return "tg://login?token=" + base64.URLEncoding.EncodeToString(t.token)
}
// Image returns QR image.
func (t Token) Image(level qr.Level) (image.Image, error) {
code, err := qr.Encode(t.URL(), level)
if err != nil {
return nil, errors.Wrap(err, "encode")
}
return code.Image(), nil
}
@@ -0,0 +1,52 @@
package qrlogin
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestParseTokenURL(t *testing.T) {
tests := []struct {
name string
u string
want Token
wantErr bool
}{
{
"Valid",
"tg://login?token=AQL0cY5hVg_D1OqESdYnJVg5845qbd8FiOLpUUeyvcb28g==",
Token{
token: []uint8{
0x1, 0x2, 0xf4, 0x71, 0x8e, 0x61,
0x56, 0xf, 0xc3, 0xd4, 0xea, 0x84,
0x49, 0xd6, 0x27, 0x25, 0x58, 0x39,
0xf3, 0x8e, 0x6a, 0x6d, 0xdf, 0x5,
0x88, 0xe2, 0xe9, 0x51, 0x47, 0xb2,
0xbd, 0xc6, 0xf6, 0xf2,
},
expires: time.Unix(0, 0),
},
false,
},
{"InvalidSchema", "vk://login", Token{}, true},
{"InvalidPath", "tg://aboba", Token{}, true},
{"NoToken", "tg://login", Token{}, true},
{"InvalidBase64", "tg://login?token=A", Token{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := require.New(t)
got, err := ParseTokenURL(tt.u)
if tt.wantErr {
a.Error(err)
} else {
a.Equal(tt.want, got)
a.NoError(err)
a.Equal(tt.want.URL(), tt.u)
}
})
}
}
+26
View File
@@ -0,0 +1,26 @@
package auth
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// self returns current user.
//
// You can use tg.User.Bot to check whether current user is bot.
func (c *Client) self(ctx context.Context) (*tg.User, error) {
users, err := c.api.UsersGetUsers(ctx, []tg.InputUserClass{&tg.InputUserSelf{}})
if err != nil {
return nil, err
}
user, ok := tg.UserClassArray(users).FirstAsNotEmpty()
if !ok {
return nil, errors.Errorf("users response count: %v", users)
}
return user, nil
}
+42
View File
@@ -0,0 +1,42 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
)
func TestClient_self(t *testing.T) {
ctx := context.Background()
mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) {
mock.ExpectCall(&tg.UsersGetUsersRequest{
ID: []tg.InputUserClass{&tg.InputUserSelf{}},
}).ThenErr(testutil.TestError())
_, err := client.self(ctx)
a.Error(err)
mock.ExpectCall(&tg.UsersGetUsersRequest{
ID: []tg.InputUserClass{&tg.InputUserSelf{}},
}).ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{&tg.UserEmpty{
ID: 10,
}}})
_, err = client.self(ctx)
a.Error(err)
mock.ExpectCall(&tg.UsersGetUsersRequest{
ID: []tg.InputUserClass{&tg.InputUserSelf{}},
}).ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{&tg.User{
Self: true,
ID: 10,
AccessHash: 10,
}}})
r, err := client.self(ctx)
a.NoError(err)
a.Equal(int64(10), r.ID)
})(t)
}
+37
View File
@@ -0,0 +1,37 @@
package auth
import (
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// SignUpRequired means that log in failed because corresponding account
// does not exist, so sign up is required.
type SignUpRequired struct {
TermsOfService tg.HelpTermsOfService
}
// Is returns true if err is SignUpRequired.
func (s *SignUpRequired) Is(err error) bool {
_, ok := err.(*SignUpRequired)
return ok
}
func (s *SignUpRequired) Error() string {
return "account with provided number does not exist (sign up required)"
}
// checkResult checks that `a` is *tg.AuthAuthorization and returns authorization result or error.
func checkResult(a tg.AuthAuthorizationClass) (*tg.AuthAuthorization, error) {
switch a := a.(type) {
case *tg.AuthAuthorization:
return a, nil // ok
case *tg.AuthAuthorizationSignUpRequired:
return nil, &SignUpRequired{
TermsOfService: a.TermsOfService,
}
default:
return nil, errors.Errorf("got unexpected response %T", a)
}
}
+60
View File
@@ -0,0 +1,60 @@
package auth
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Status represents authorization status.
type Status struct {
// Authorized is true if client is authorized.
Authorized bool
// User is current User object.
User *tg.User
}
// Status gets authorization status of client.
func (c *Client) Status(ctx context.Context) (*Status, error) {
u, err := c.self(ctx)
if IsUnauthorized(err) {
return &Status{}, nil
}
if err != nil {
return nil, err
}
return &Status{
Authorized: true,
User: u,
}, nil
}
// IfNecessary runs given auth flow if current session is not authorized.
func (c *Client) IfNecessary(ctx context.Context, flow Flow) error {
auth, err := c.Status(ctx)
if err != nil {
return errors.Wrap(err, "get auth status")
}
if auth.Authorized {
return nil
}
if err := flow.Run(ctx, c); err != nil {
return errors.Wrap(err, "auth flow")
}
return nil
}
// Test creates and runs auth flow using Test authenticator
// if current session is not authorized.
func (c *Client) Test(ctx context.Context, dc int) error {
return c.IfNecessary(ctx, NewFlow(Test(c.rand, dc), SendCodeOptions{}))
}
// TestUser creates and runs auth flow using TestUser authenticator
// if current session is not authorized.
func (c *Client) TestUser(ctx context.Context, phone string, dc int) error {
return c.IfNecessary(ctx, NewFlow(TestUser(phone, dc), SendCodeOptions{}))
}
+107
View File
@@ -0,0 +1,107 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
)
func TestClient_Status(t *testing.T) {
ctx := context.Background()
t.Run("Authorized", func(t *testing.T) {
mock := tgmock.NewRequire(t)
user := &tg.User{
Username: "user",
}
mock.Expect().ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{user}})
status, err := testClient(mock).Status(ctx)
require.NoError(t, err)
require.True(t, status.Authorized)
require.Equal(t, user, status.User)
})
t.Run("Unauthorized", func(t *testing.T) {
mock := tgmock.NewRequire(t)
mock.Expect().ThenUnregistered()
status, err := testClient(mock).Status(ctx)
require.NoError(t, err)
require.False(t, status.Authorized)
})
t.Run("Error", func(t *testing.T) {
mock := tgmock.NewRequire(t)
mock.Expect().ThenRPCErr(&tgerr.Error{
Code: 500,
Message: "BRUH",
Type: "BRUH",
})
_, err := testClient(mock).Status(ctx)
require.Error(t, err)
})
}
func TestClient_IfNecessary(t *testing.T) {
ctx := context.Background()
t.Run("Authorized", func(t *testing.T) {
mock := tgmock.NewRequire(t)
testUser := &tg.User{
Username: "user",
}
mock.Expect().ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{testUser}})
// Pass empty AuthFlow because it should not be called anyway.
require.NoError(t, testClient(mock).IfNecessary(ctx, Flow{}))
})
t.Run("Error", func(t *testing.T) {
mock := tgmock.NewRequire(t)
mock.Expect().ThenRPCErr(&tgerr.Error{
Code: 500,
Message: "BRUH",
Type: "BRUH",
})
// Pass empty AuthFlow because it should not be called anyway.
require.Error(t, testClient(mock).IfNecessary(ctx, Flow{}))
})
}
func TestClient_Test(t *testing.T) {
ctx := context.Background()
t.Run("Authorized", func(t *testing.T) {
mock := tgmock.NewRequire(t)
testUser := &tg.User{
Username: "user",
}
mock.Expect().ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{testUser}})
// Pass empty AuthFlow because it should not be called anyway.
require.NoError(t, testClient(mock).Test(ctx, 2))
})
}
func TestClient_TestUser(t *testing.T) {
ctx := context.Background()
t.Run("Authorized", func(t *testing.T) {
mock := tgmock.NewRequire(t)
testUser := &tg.User{
Username: "user",
}
mock.Expect().ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{testUser}})
// Pass empty AuthFlow because it should not be called anyway.
require.NoError(t, testClient(mock).TestUser(ctx, "phone", 2))
})
}
+154
View File
@@ -0,0 +1,154 @@
package auth
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
// ErrPasswordInvalid means that password provided to Password is invalid.
//
// Note that telegram does not trim whitespace characters by default, check
// that provided password is expected and clean whitespaces if needed.
// You can use strings.TrimSpace(password) for this.
var ErrPasswordInvalid = errors.New("invalid password")
// Password performs login via secure remote password (aka 2FA).
//
// Method can be called after SignIn to provide password if requested.
func (c *Client) Password(ctx context.Context, password string) (*tg.AuthAuthorization, error) {
p, err := c.api.AccountGetPassword(ctx)
if err != nil {
return nil, errors.Wrap(err, "get SRP parameters")
}
a, err := PasswordHash([]byte(password), p.SRPID, p.SRPB, p.SecureRandom, p.CurrentAlgo)
if err != nil {
return nil, errors.Wrap(err, "compute password hash")
}
auth, err := c.api.AuthCheckPassword(ctx, &tg.InputCheckPasswordSRP{
SRPID: p.SRPID,
A: a.A,
M1: a.M1,
})
if tg.IsPasswordHashInvalid(err) {
return nil, ErrPasswordInvalid
}
if err != nil {
return nil, errors.Wrap(err, "check password")
}
result, err := checkResult(auth)
if err != nil {
return nil, errors.Wrap(err, "check")
}
return result, nil
}
// SendCodeOptions defines how to send auth code to user.
type SendCodeOptions struct {
// AllowFlashCall allows phone verification via phone calls.
AllowFlashCall bool
// Pass true if the phone number is used on the current device.
// Ignored if AllowFlashCall is not set.
CurrentNumber bool
// If a token that will be included in eventually sent SMSs is required:
// required in newer versions of android, to use the android SMS receiver APIs.
AllowAppHash bool
}
// SendCode requests code for provided phone number, returning code hash
// and error if any. Use AuthFlow to reduce boilerplate.
//
// This method should be called first in user authentication flow.
func (c *Client) SendCode(ctx context.Context, phone string, options SendCodeOptions) (tg.AuthSentCodeClass, error) {
var settings tg.CodeSettings
if options.AllowAppHash {
settings.SetAllowAppHash(true)
}
if options.AllowFlashCall {
settings.SetAllowFlashcall(true)
}
if options.CurrentNumber {
settings.SetCurrentNumber(true)
}
sentCode, err := c.api.AuthSendCode(ctx, &tg.AuthSendCodeRequest{
PhoneNumber: phone,
APIID: c.appID,
APIHash: c.appHash,
Settings: settings,
})
if err != nil {
return nil, errors.Wrap(err, "send code")
}
return sentCode, nil
}
// ErrPasswordAuthNeeded means that 2FA auth is required.
//
// Call Client.Password to provide 2FA password.
var ErrPasswordAuthNeeded = errors.New("2FA required")
// SignIn performs sign in with provided user phone, code and code hash.
//
// If ErrPasswordAuthNeeded is returned, call Password to provide 2FA
// password.
//
// To obtain codeHash, use SendCode.
func (c *Client) SignIn(ctx context.Context, phone, code, codeHash string) (*tg.AuthAuthorization, error) {
auth, err := c.api.AuthSignIn(ctx, &tg.AuthSignInRequest{
PhoneNumber: phone,
PhoneCodeHash: codeHash,
PhoneCode: code,
})
if tgerr.Is(err, "SESSION_PASSWORD_NEEDED") {
return nil, ErrPasswordAuthNeeded
}
if err != nil {
return nil, errors.Wrap(err, "sign in")
}
result, err := checkResult(auth)
if err != nil {
return nil, errors.Wrap(err, "check")
}
return result, nil
}
// AcceptTOS accepts version of Terms Of Service.
func (c *Client) AcceptTOS(ctx context.Context, id tg.DataJSON) error {
_, err := c.api.HelpAcceptTermsOfService(ctx, id)
return err
}
// SignUp wraps parameters for SignUp.
type SignUp struct {
PhoneNumber string
PhoneCodeHash string
FirstName string
LastName string
}
// SignUp registers a validated phone number in the system.
//
// To obtain codeHash, use SendCode.
// Use AuthFlow helper to handle authentication flow.
func (c *Client) SignUp(ctx context.Context, s SignUp) (*tg.AuthAuthorization, error) {
auth, err := c.api.AuthSignUp(ctx, &tg.AuthSignUpRequest{
LastName: s.LastName,
PhoneCodeHash: s.PhoneCodeHash,
PhoneNumber: s.PhoneNumber,
FirstName: s.FirstName,
})
if err != nil {
return nil, errors.Wrap(err, "request")
}
result, err := checkResult(auth)
if err != nil {
return nil, errors.Wrap(err, "check")
}
return result, nil
}
+254
View File
@@ -0,0 +1,254 @@
package auth
import (
"context"
"encoding/hex"
"math/rand"
"strconv"
"strings"
"testing"
"github.com/go-faster/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
)
func getHex(t testing.TB, in string) []byte {
res, err := hex.DecodeString(in)
if err != nil {
t.Fatal("failed to get hex", err)
}
return res
}
func TestClient_AuthSignIn(t *testing.T) {
const (
phone = "123123"
code = "1010"
password = "secret"
codeHash = "hash"
)
ctx := context.Background()
testUser := &tg.User{ID: 1}
invoker := tgmock.Invoker(func(body bin.Encoder) (bin.Encoder, error) {
switch req := body.(type) {
case *tg.UsersGetUsersRequest:
return nil, &tgerr.Error{
Code: 401,
Message: "AUTH_KEY_UNREGISTERED",
Type: "AUTH_KEY_UNREGISTERED",
}
case *tg.AuthSendCodeRequest:
settings := tg.CodeSettings{}
settings.SetCurrentNumber(true)
assert.Equal(t, &tg.AuthSendCodeRequest{
PhoneNumber: phone,
APIHash: testAppHash,
APIID: testAppID,
Settings: settings,
}, req)
return &tg.AuthSentCode{
Type: &tg.AuthSentCodeTypeApp{},
PhoneCodeHash: codeHash,
}, nil
case *tg.AuthSignInRequest:
assert.Equal(t, &tg.AuthSignInRequest{
PhoneNumber: phone,
PhoneCodeHash: codeHash,
PhoneCode: code,
}, req)
return nil, tgerr.New(401, "SESSION_PASSWORD_NEEDED")
case *tg.AccountGetPasswordRequest:
algo := &tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow{
Salt1: getHex(t, "4D11FB6BEC38F9D2546BB0F61E4F1C99A1BC0DB8F0D5F35B1291B37B213123D7ED48F3C6794D495B"),
Salt2: getHex(t, "A1B181AAFE88188680AE32860D60BB01"),
G: 3,
P: getHex(t, "C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F"+
"48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37"+
"20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64"+
"2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4"+
"A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754"+
"FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4"+
"E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F"+
"0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B"),
}
pwd := &tg.AccountPassword{
NewAlgo: algo,
NewSecureAlgo: &tg.SecurePasswordKdfAlgoPBKDF2HMACSHA512iter100000{},
}
pwd.SetCurrentAlgo(algo)
return pwd, nil
case *tg.AuthCheckPasswordRequest:
// TODO(ernado): Check actual secure remote password here.
switch pwd := req.Password.(type) {
case *tg.InputCheckPasswordSRP:
assert.NotEmpty(t, pwd.A)
assert.NotEmpty(t, pwd.M1)
assert.NotEqual(t, pwd.SRPID, 0)
default:
t.Errorf("unexpectd pwd type %T", pwd)
}
return &tg.AuthAuthorization{
User: testUser,
}, nil
}
return nil, errors.New("unexpected")
})
t.Run("Manual", func(t *testing.T) {
// 1. Request code from server to device.
client := testClient(invoker)
sentCode, err := client.SendCode(ctx, phone, SendCodeOptions{CurrentNumber: true})
require.NoError(t, err)
h := sentCode.(*tg.AuthSentCode).PhoneCodeHash
require.Equal(t, codeHash, h)
// 2. Send code from device to server.
// Server is responding with 2FA password prompt.
_, signInErr := client.SignIn(ctx, phone, code, h)
require.ErrorIs(t, signInErr, ErrPasswordAuthNeeded)
// 3. Provide 2FA password.
result, err := client.Password(ctx, password)
require.NoError(t, err)
require.Equal(t, testUser, result.User)
})
flow := NewFlow(
Constant(phone, password, CodeAuthenticatorFunc(
func(ctx context.Context, _ *tg.AuthSentCode) (string, error) {
return code, nil
},
)),
SendCodeOptions{CurrentNumber: true},
)
t.Run("AuthFlow", func(t *testing.T) {
require.NoError(t, flow.Run(ctx, testClient(invoker)))
})
t.Run("IfNecessary", func(t *testing.T) {
require.NoError(t, testClient(invoker).IfNecessary(ctx, flow))
})
}
func TestClientTestAuth(t *testing.T) {
const (
codeHash = "hash"
dcID = 2
)
ctx := context.Background()
invoker := tgmock.Invoker(func(body bin.Encoder) (bin.Encoder, error) {
switch req := body.(type) {
case *tg.AuthSendCodeRequest:
assert.Equal(t, &tg.AuthSendCodeRequest{
PhoneNumber: req.PhoneNumber,
APIHash: testAppHash,
APIID: testAppID,
Settings: tg.CodeSettings{},
}, req)
return &tg.AuthSentCode{
Type: &tg.AuthSentCodeTypeApp{
Length: 6,
},
PhoneCodeHash: codeHash,
}, nil
case *tg.AuthSignInRequest:
if !strings.HasPrefix(req.PhoneNumber, "99966") {
t.Fatalf("unexpected phone number %s", req.PhoneNumber)
}
dcPart := req.PhoneNumber[5:6]
assert.Equal(t, strconv.Itoa(dcID), dcPart, "dc part of phone number")
assert.Equal(t, &tg.AuthSignInRequest{
PhoneNumber: req.PhoneNumber,
PhoneCodeHash: codeHash,
PhoneCode: strings.Repeat(dcPart, 6),
}, req)
return &tg.AuthAuthorization{
User: &tg.User{ID: 1},
}, nil
}
return nil, errors.New("unexpected")
})
require.NoError(t, NewFlow(
Test(rand.New(rand.NewSource(1)), dcID),
SendCodeOptions{},
).Run(ctx, testClient(invoker)))
}
func TestClientTestSignUp(t *testing.T) {
const (
dcID = 2
codeHash = "hash"
tosID = "foo"
)
ctx := context.Background()
invoker := tgmock.Invoker(func(body bin.Encoder) (bin.Encoder, error) {
switch req := body.(type) {
case *tg.AuthSendCodeRequest:
assert.Equal(t, &tg.AuthSendCodeRequest{
PhoneNumber: req.PhoneNumber,
APIHash: testAppHash,
APIID: testAppID,
Settings: tg.CodeSettings{},
}, req)
return &tg.AuthSentCode{
Type: &tg.AuthSentCodeTypeApp{
Length: 6,
},
PhoneCodeHash: codeHash,
}, nil
case *tg.AuthSignUpRequest:
assert.Equal(t, &tg.AuthSignUpRequest{
PhoneNumber: req.PhoneNumber,
PhoneCodeHash: codeHash,
FirstName: "Test",
LastName: "User",
}, req)
return &tg.AuthAuthorization{
User: &tg.User{ID: 1},
}, nil
case *tg.HelpAcceptTermsOfServiceRequest:
return &tg.BoolTrue{}, nil
case *tg.AuthSignInRequest:
if !strings.HasPrefix(req.PhoneNumber, "99966") {
t.Fatalf("unexpected phone number %s", req.PhoneNumber)
}
dcPart := req.PhoneNumber[5:6]
assert.Equal(t, strconv.Itoa(dcID), dcPart, "dc part of phone number")
assert.Equal(t, &tg.AuthSignInRequest{
PhoneNumber: req.PhoneNumber,
PhoneCodeHash: codeHash,
PhoneCode: strings.Repeat(dcPart, 6),
}, req)
res := &tg.AuthAuthorizationSignUpRequired{}
res.SetTermsOfService(tg.HelpTermsOfService{ID: tg.DataJSON{Data: tosID}})
return res, nil
}
return nil, errors.New("unexpected")
})
require.NoError(t, NewFlow(
Test(rand.New(rand.NewSource(1)), dcID),
SendCodeOptions{},
).Run(ctx, testClient(invoker)))
}
func TestClient_AcceptTOS(t *testing.T) {
ctx := context.Background()
mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) {
mock.Expect().ThenUnregistered()
a.Error(client.AcceptTOS(ctx, tg.DataJSON{
Data: `{"data":"data"}`,
}))
mock.Expect().ThenTrue()
a.NoError(client.AcceptTOS(ctx, tg.DataJSON{
Data: `{"data":"data"}`,
}))
})(t)
}
+166
View File
@@ -0,0 +1,166 @@
package telegram_test
import (
"bufio"
"context"
"fmt"
"log"
"os"
"strconv"
"strings"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth/qrlogin"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func ExampleClient_Auth_codeOnly() {
check := func(err error) {
if err != nil {
panic(err)
}
}
var (
appIDString = os.Getenv("APP_ID")
appHash = os.Getenv("APP_HASH")
phone = os.Getenv("PHONE")
)
if appIDString == "" || appHash == "" || phone == "" {
log.Fatal("PHONE, APP_ID or APP_HASH is not set")
}
appID, err := strconv.Atoi(appIDString)
check(err)
ctx := context.Background()
client := telegram.NewClient(appID, appHash, telegram.Options{})
codeAsk := func(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) {
fmt.Print("code:")
code, err := bufio.NewReader(os.Stdin).ReadString('\n')
if err != nil {
return "", err
}
code = strings.ReplaceAll(code, "\n", "")
return code, nil
}
check(client.Run(ctx, func(ctx context.Context) error {
return auth.NewFlow(
auth.CodeOnly(phone, auth.CodeAuthenticatorFunc(codeAsk)),
auth.SendCodeOptions{},
).Run(ctx, client.Auth())
}))
}
func ExampleClient_Auth_password() {
check := func(err error) {
if err != nil {
panic(err)
}
}
var (
appIDString = os.Getenv("APP_ID")
appHash = os.Getenv("APP_HASH")
phone = os.Getenv("PHONE")
pass = os.Getenv("PASSWORD")
)
if appIDString == "" || appHash == "" || phone == "" || pass == "" {
log.Fatal("PHONE, PASSWORD, APP_ID or APP_HASH is not set")
}
appID, err := strconv.Atoi(appIDString)
check(err)
ctx := context.Background()
client := telegram.NewClient(appID, appHash, telegram.Options{})
codeAsk := func(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) {
fmt.Print("code:")
code, err := bufio.NewReader(os.Stdin).ReadString('\n')
if err != nil {
return "", err
}
code = strings.ReplaceAll(code, "\n", "")
return code, nil
}
check(client.Run(ctx, func(ctx context.Context) error {
return auth.NewFlow(
auth.Constant(phone, pass, auth.CodeAuthenticatorFunc(codeAsk)),
auth.SendCodeOptions{},
).Run(ctx, client.Auth())
}))
}
func ExampleClient_Auth_test() {
// Example of using test server.
const dcID = 2
ctx := context.Background()
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{
DC: dcID,
DCList: dcs.Test(),
})
if err := client.Run(ctx, func(ctx context.Context) error {
return client.Auth().Test(ctx, dcID)
}); err != nil {
panic(err)
}
}
func ExampleClient_Auth_bot() {
ctx := context.Background()
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{})
if err := client.Run(ctx, func(ctx context.Context) error {
// Checking auth status.
status, err := client.Auth().Status(ctx)
if err != nil {
return err
}
// Can be already authenticated if we have valid session in
// session storage.
if !status.Authorized {
// Otherwise, perform bot authentication.
if _, err := client.Auth().Bot(ctx, os.Getenv("BOT_TOKEN")); err != nil {
return err
}
}
// All good, manually authenticated.
return nil
}); err != nil {
panic(err)
}
}
func ExampleQR_Auth() {
ctx := context.Background()
d := tg.NewUpdateDispatcher()
loggedIn := qrlogin.OnLoginToken(d)
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{
UpdateHandler: d,
})
if err := client.Run(ctx, func(ctx context.Context) error {
qr := client.QR()
authorization, err := qr.Auth(ctx, loggedIn, func(ctx context.Context, token qrlogin.Token) error {
fmt.Printf("Open %s using your phone\n", token.URL())
return nil
})
if err != nil {
return err
}
u, ok := authorization.User.AsNotEmpty()
if !ok {
return fmt.Errorf("unexpected type %T", authorization.User)
}
fmt.Println("ID:", u.ID, "Username:", u.Username, "Bot:", u.Bot)
return nil
}); err != nil {
panic(err)
}
}
+53
View File
@@ -0,0 +1,53 @@
package telegram
import (
"context"
"os"
"github.com/go-faster/errors"
)
// RunUntilCanceled is client callback which
// locks until client context is canceled.
func RunUntilCanceled(ctx context.Context, client *Client) error {
<-ctx.Done()
return ctx.Err()
}
// BotFromEnvironment creates bot client using ClientFromEnvironment
// connects to server and authenticates it.
//
// Variables:
// BOT_TOKEN — token from BotFather.
func BotFromEnvironment(
ctx context.Context,
opts Options,
setup func(ctx context.Context, client *Client) error,
cb func(ctx context.Context, client *Client) error,
) error {
client, err := ClientFromEnvironment(opts)
if err != nil {
return errors.Wrap(err, "create client")
}
if setup != nil {
if err := setup(ctx, client); err != nil {
return errors.Wrap(err, "setup")
}
}
return client.Run(ctx, func(ctx context.Context) error {
status, err := client.Auth().Status(ctx)
if err != nil {
return errors.Wrap(err, "auth status")
}
if !status.Authorized {
if _, err := client.Auth().Bot(ctx, os.Getenv("BOT_TOKEN")); err != nil {
return errors.Wrap(err, "login")
}
}
return cb(ctx, client)
})
}
+174
View File
@@ -0,0 +1,174 @@
package telegram
import (
"context"
"io"
"os"
"path/filepath"
"strconv"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/go-faster/errors"
"go.uber.org/zap"
"golang.org/x/net/proxy"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
func sessionDir() (string, error) {
dir, ok := os.LookupEnv("SESSION_DIR")
if ok {
return filepath.Abs(dir)
}
dir, err := os.UserHomeDir()
if err != nil {
dir = "."
}
return filepath.Abs(filepath.Join(dir, ".td"))
}
// OptionsFromEnvironment fills unfilled field in opts parameter
// using environment variables.
//
// Variables:
//
// SESSION_FILE: path to session file
// SESSION_DIR: path to session directory, if SESSION_FILE is not set
// ALL_PROXY, NO_PROXY: see https://pkg.go.dev/golang.org/x/net/proxy#FromEnvironment
func OptionsFromEnvironment(opts Options) (Options, error) {
// Setting up session storage if not provided.
if opts.SessionStorage == nil {
sessionFile, ok := os.LookupEnv("SESSION_FILE")
if !ok {
dir, err := sessionDir()
if err != nil {
return Options{}, errors.Wrap(err, "SESSION_DIR not set or invalid")
}
sessionFile = filepath.Join(dir, "session.json")
}
dir, _ := filepath.Split(sessionFile)
if err := os.MkdirAll(dir, 0700); err != nil {
return Options{}, errors.Wrap(err, "session dir creation")
}
opts.SessionStorage = &session.FileStorage{
Path: sessionFile,
}
}
if opts.Resolver == nil {
opts.Resolver = dcs.Plain(dcs.PlainOptions{
Dial: proxy.Dial,
})
}
return opts, nil
}
// ClientFromEnvironment creates client using OptionsFromEnvironment
// but does not connect to server.
//
// Variables:
//
// APP_ID: app_id of Telegram app.
// APP_HASH: app_hash of Telegram app.
func ClientFromEnvironment(opts Options) (*Client, error) {
appID, err := strconv.Atoi(os.Getenv("APP_ID"))
if err != nil {
return nil, errors.Wrap(err, "APP_ID not set or invalid")
}
appHash := os.Getenv("APP_HASH")
if appHash == "" {
return nil, errors.New("no APP_HASH provided")
}
opts, err = OptionsFromEnvironment(opts)
if err != nil {
return nil, err
}
return NewClient(appID, appHash, opts), nil
}
func retry(ctx context.Context, logger *zap.Logger, cb func(ctx context.Context) error) error {
b := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
// List of known retryable RPC error types.
retryableErrors := []string{
"NEED_MEMBER_INVALID",
"AUTH_KEY_UNREGISTERED",
"API_ID_PUBLISHED_FLOOD",
}
return backoff.Retry(func() error {
if err := cb(ctx); err != nil {
logger.Warn("TestClient run failed", zap.Error(err))
if tgerr.Is(err, retryableErrors...) {
return err
}
if timeout, ok := AsFloodWait(err); ok {
timer := clock.System.Timer(timeout + 1*time.Second)
defer clock.StopTimer(timer)
select {
case <-timer.C():
return err
case <-ctx.Done():
return ctx.Err()
}
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
// Possibly server closed connection.
return err
}
return backoff.Permanent(err)
}
return nil
}, b)
}
// TestClient creates and authenticates user telegram.Client
// using Telegram test server.
func TestClient(ctx context.Context, opts Options, cb func(ctx context.Context, client *Client) error) error {
if opts.DC == 0 {
opts.DC = 2
}
if opts.DCList.Zero() {
opts.DCList = dcs.Test()
}
logger := zap.NewNop()
if opts.Logger != nil {
logger = opts.Logger.Named("test")
}
// Sometimes testing server can return "AUTH_KEY_UNREGISTERED" error.
// It is expected and client implementation is unlikely to cause
// such errors, so just doing retries using backoff.
return retry(ctx, logger, func(retryCtx context.Context) error {
client := NewClient(TestAppID, TestAppHash, opts)
return client.Run(retryCtx, func(runCtx context.Context) error {
if err := client.Auth().IfNecessary(runCtx, auth.NewFlow(
auth.Test(crypto.DefaultRand(), opts.DC),
auth.SendCodeOptions{},
)); err != nil {
return errors.Wrap(err, "auth flow")
}
return cb(runCtx, client)
})
})
}
+31
View File
@@ -0,0 +1,31 @@
package telegram
import (
"crypto/rsa"
"encoding/pem"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func parseCDNKeys(keys ...tg.CDNPublicKey) ([]*rsa.PublicKey, error) {
r := make([]*rsa.PublicKey, 0, len(keys))
for _, key := range keys {
block, _ := pem.Decode([]byte(key.PublicKey))
if block == nil {
continue
}
key, err := crypto.ParseRSA(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "parse RSA from PEM")
}
r = append(r, key)
}
return r, nil
}
+43
View File
@@ -0,0 +1,43 @@
package telegram
import (
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func Test_parseCDNKeys(t *testing.T) {
keys := []string{
`-----BEGIN RSA PUBLIC KEY-----
MIIBCgKCAQEA+Lf3PvgE1yxbJUCMaEAk
V0QySTVpnaDjiednB5RbtNWjCeqSVakYHbqqGMIIv5WCGdFdrqOfMNcNSstPtSU6
R9UmRw6tquOIykpSuUOje9H+4XVIKqujyL2ISdK+4ZOMl4hCMkqauw4bP1Sbr03v
ZRQbU6qEA04V4j879BAyBVhr3WG9+Zi+t5XfGSTgSExPYEl8rZNHYNV5RB+BuroV
H2HLTOpT/mJVfikYpgjfWF5ldezV4Wo9LSH0cZGSFIaeJl8d0A8Eiy5B9gtBO8mL
+XfQRKOOmr7a4BM4Ro2de5rr2i2od7hYXd3DO9FRSl4y1zA8Am48Rfd95WHF3N/O
mQIDAQAB
-----END RSA PUBLIC KEY-----`,
`-----BEGIN RSA PUBLIC KEY-----
MIIBCgKCAQEAyu5PXyfp+VFLc2hKJsq/cvQ+wq9V2s1iGMMwcrkXrKAqX0S5QEcY
W9b6pV5LulbsvNcxp/YniiSL4FsAja28B9fH//Y+AolWASomCB0NSVHwS1Pqfe3m
GdLTwDmqU17tSWk/48+Kfn4B+WT85ZIKt8bOnABwnM1AtykX0zKwzm9yKcTX0MeY
rwzgiOQax6J1cfgtLdxl8HVKT6wCOS1e43zpXMU+UoWqRqIan+J6q+ubi1yF4PWl
DyDgJSw8uxlhNNMP4tAnshIRZ1ZZ25O/g58jw1qz5XMztZwLNA2pUxaFtyy1LdHC
FRX7DdwIA/FdOzfWyXYLlCFaSX8K/6CnSQIDAQAB
-----END RSA PUBLIC KEY-----`,
}
cdnKeys := make([]tg.CDNPublicKey, 0, len(keys))
for i, key := range keys {
cdnKeys = append(cdnKeys, tg.CDNPublicKey{
DCID: i + 1,
PublicKey: key,
})
}
publicKeys, err := parseCDNKeys(cdnKeys...)
require.NoError(t, err)
require.Len(t, publicKeys, 2)
}
+43
View File
@@ -0,0 +1,43 @@
package telegram
import (
"context"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/constant"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Available MTProto default server addresses.
//
// See https://my.telegram.org/apps.
const (
AddrProduction = "149.154.167.50:443"
AddrTest = "149.154.167.40:443"
)
// Test-only credentials. Can be used with AddrTest and TestAuth to
// test authentication.
//
// Reference:
// - https://github.com/telegramdesktop/tdesktop/blob/5f665b8ecb48802cd13cfb48ec834b946459274a/docs/api_credentials.md
const (
TestAppID = constant.TestAppID
TestAppHash = constant.TestAppHash
)
// Config returns current config.
func (c *Client) Config() tg.Config {
return c.cfg.Load()
}
func (c *Client) fetchConfig(ctx context.Context) {
cfg, err := c.tg.HelpGetConfig(ctx)
if err != nil {
c.log.Warn("Got error on config update", zap.Error(err))
return
}
c.cfg.Store(*cfg)
}
+27
View File
@@ -0,0 +1,27 @@
package telegram
import (
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func TestClient_fetchConfig(t *testing.T) {
a := require.New(t)
cfg := &tg.Config{
ThisDC: 10,
}
client := newTestClient(func(id int64, body bin.Encoder) (bin.Encoder, error) {
a.IsType(&tg.HelpGetConfigRequest{}, body)
return cfg, nil
})
a.NoError(client.processUpdates(&tg.Updates{
Updates: []tg.UpdateClass{&tg.UpdateConfig{}},
}))
a.Equal(*cfg, client.Config())
}
+262
View File
@@ -0,0 +1,262 @@
package telegram
import (
"context"
"io"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"go.opentelemetry.io/otel/trace"
"go.uber.org/atomic"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
"go.mau.fi/mautrix-telegram/pkg/gotd/oteltg"
"go.mau.fi/mautrix-telegram/pkg/gotd/pool"
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/manager"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/version"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// UpdateHandler will be called on received updates from Telegram.
type UpdateHandler interface {
Handle(ctx context.Context, u tg.UpdatesClass) error
}
// UpdateHandlerFunc type is an adapter to allow the use of
// ordinary function as update handler.
//
// UpdateHandlerFunc(f) is an UpdateHandler that calls f.
type UpdateHandlerFunc func(ctx context.Context, u tg.UpdatesClass) error
// Handle calls f(ctx, u)
func (f UpdateHandlerFunc) Handle(ctx context.Context, u tg.UpdatesClass) error {
return f(ctx, u)
}
type clientStorage interface {
Load(ctx context.Context) (*session.Data, error)
Save(ctx context.Context, data *session.Data) error
}
type clientConn interface {
Run(ctx context.Context) error
Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error
Ping(ctx context.Context) error
}
// Client represents a MTProto client to Telegram.
type Client struct {
// Put migration in the header of the structure to ensure 64-bit alignment,
// otherwise it will cause the atomic operation of connsCounter to panic.
// DO NOT change the order of members arbitrarily.
// Ref: https://pkg.go.dev/sync/atomic#pkg-note-BUG
// Connection factory fields.
connsCounter atomic.Int64
create connConstructor // immutable
resolver dcs.Resolver // immutable
onDead func() // immutable
onAuthError func(error) // immutable
onConnected func() // immutable
newConnBackoff func() backoff.BackOff // immutable
defaultMode manager.ConnMode // immutable
// Migration state.
migrationTimeout time.Duration // immutable
migration chan struct{}
// tg provides RPC calls via Client. Uses invoker below.
tg *tg.Client // immutable
// invoker implements tg.Invoker on top of Client and mw.
invoker tg.Invoker // immutable
// mw is list of middlewares used in invoker, can be blank.
mw []Middleware // immutable
// Telegram device information.
device DeviceConfig // immutable
// MTProto options.
opts mtproto.Options // immutable
// DCList state.
// Domain list (for websocket)
domains map[int]string // immutable
// Denotes to use Test DCs.
testDC bool // immutable
// Connection state. Guarded by connMux.
session *pool.SyncSession
cfg *manager.AtomicConfig
conn clientConn
connBackoff atomic.Pointer[backoff.BackOff]
connMux sync.Mutex
// Restart signal channel.
restart chan struct{} // immutable
// Connections to non-primary DC.
subConns map[int]CloseInvoker
subConnsMux sync.Mutex
sessions map[int]*pool.SyncSession
sessionsMux sync.Mutex
// Wrappers for external world, like logs or PRNG.
rand io.Reader // immutable
log *zap.Logger // immutable
clock clock.Clock // immutable
// Client context. Will be canceled by Run on exit.
ctx context.Context
cancel context.CancelFunc
// Client config.
appID int // immutable
appHash string // immutable
// Session storage.
storage clientStorage // immutable, nillable
// Ready signal channel, sends signal when client connection is ready.
// Resets on reconnect.
ready *tdsync.ResetReady // immutable
// Telegram updates handler.
updateHandler UpdateHandler // immutable
// Denotes that no update mode is enabled.
noUpdatesMode bool // immutable
// Tracing.
tracer trace.Tracer
// onTransfer is called in transfer.
onTransfer AuthTransferHandler
// onSelfError is called on error calling Self().
onSelfError func(ctx context.Context, err error) error
}
// NewClient creates new unstarted client.
func NewClient(appID int, appHash string, opt Options) *Client {
opt.setDefaults()
mode := manager.ConnModeUpdates
if opt.NoUpdates {
mode = manager.ConnModeData
}
client := &Client{
rand: opt.Random,
log: opt.Logger,
appID: appID,
appHash: appHash,
updateHandler: opt.UpdateHandler,
session: pool.NewSyncSession(pool.Session{
DC: opt.DC,
}),
domains: opt.DCList.Domains,
testDC: opt.DCList.Test,
cfg: manager.NewAtomicConfig(tg.Config{
DCOptions: opt.DCList.Options,
}),
create: defaultConstructor(),
resolver: opt.Resolver,
defaultMode: mode,
newConnBackoff: opt.ReconnectionBackoff,
onDead: opt.OnDead,
onAuthError: opt.OnAuthError,
onConnected: opt.OnConnected,
clock: opt.Clock,
device: opt.Device,
migrationTimeout: opt.MigrationTimeout,
noUpdatesMode: opt.NoUpdates,
mw: opt.Middlewares,
onTransfer: opt.OnTransfer,
onSelfError: opt.OnSelfError,
}
if opt.TracerProvider != nil {
client.tracer = opt.TracerProvider.Tracer(oteltg.Name)
}
client.init()
// Including version into client logger to help with debugging.
if v := version.GetVersion(); v != "" {
client.log = client.log.With(zap.String("v", v))
}
if opt.SessionStorage != nil {
client.storage = &session.Loader{
Storage: opt.SessionStorage,
}
}
if opt.CustomSessionStorage != nil {
client.storage = opt.CustomSessionStorage
}
client.opts = mtproto.Options{
PublicKeys: opt.PublicKeys,
Random: opt.Random,
Logger: opt.Logger,
AckBatchSize: opt.AckBatchSize,
AckInterval: opt.AckInterval,
RetryInterval: opt.RetryInterval,
MaxRetries: opt.MaxRetries,
CompressThreshold: opt.CompressThreshold,
MessageID: opt.MessageID,
ExchangeTimeout: opt.ExchangeTimeout,
DialTimeout: opt.DialTimeout,
Clock: opt.Clock,
PingInterval: opt.PingInterval,
PingCallback: opt.PingCallback,
PingTimeout: opt.PingTimeout,
Handler: SessionNotifierHandler{opt.OnSession},
Types: getTypesMapping(),
Tracer: client.tracer,
}
client.conn = client.createPrimaryConn(nil)
return client
}
// SessionNotifierHandler is a handler which notifies when a session is
// received.
type SessionNotifierHandler struct {
onSession func()
}
// OnSession implements Handler.
func (n SessionNotifierHandler) OnSession(s mtproto.Session) error {
if n.onSession != nil {
n.onSession()
}
return nil
}
// OnMessage implements Handler
func (n SessionNotifierHandler) OnMessage(b *bin.Buffer) error {
return nil
}
// init sets fields which needs explicit initialization, like maps or channels.
func (c *Client) init() {
if c.domains == nil {
c.domains = map[int]string{}
}
if c.cfg == nil {
c.cfg = manager.NewAtomicConfig(tg.Config{})
}
c.ready = tdsync.NewResetReady()
c.restart = make(chan struct{})
c.migration = make(chan struct{}, 1)
c.sessions = map[int]*pool.SyncSession{}
c.subConns = map[int]CloseInvoker{}
c.invoker = chainMiddlewares(InvokeFunc(c.invokeDirect), c.mw...)
c.tg = tg.NewClient(c.invoker)
}
+294
View File
@@ -0,0 +1,294 @@
package telegram_test
import (
"bytes"
"context"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/downloader"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/uploader"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/cluster"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgtest/services/file"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
type clusterSetup struct {
TB testing.TB
Cluster *cluster.Cluster
Logger *zap.Logger
}
type clientSetup struct {
TB testing.TB
Options telegram.Options
Complete func()
}
var user = &tg.User{
ID: 10,
AccessHash: 10,
Username: "username",
}
func testCluster(
p dcs.Protocol,
ws bool,
setup func(q clusterSetup),
run func(ctx context.Context, c clientSetup) error,
) func(t *testing.T) {
return func(t *testing.T) {
t.Parallel()
log := zaptest.NewLogger(t)
defer func() { _ = log.Sync() }()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
g := tdsync.NewCancellableGroup(ctx)
c := cluster.NewCluster(cluster.Options{
Web: ws,
Logger: log.Named("cluster"),
Protocol: p,
})
setup(clusterSetup{
TB: t,
Cluster: c,
Logger: log,
})
g.Go(c.Up)
g.Go(func(ctx context.Context) error {
select {
// Wait for cluster readiness.
case <-c.Ready():
case <-ctx.Done():
return ctx.Err()
}
return run(ctx, clientSetup{
TB: t,
Options: telegram.Options{
UpdateHandler: telegram.UpdateHandlerFunc(func(ctx context.Context, u tg.UpdatesClass) error {
// No-op update handler.
return nil
}),
PublicKeys: c.Keys(),
Resolver: c.Resolver(),
Logger: log.Named("client"),
SessionStorage: &session.StorageMemory{},
DCList: c.List(),
},
Complete: cancel,
})
})
log.Debug("Waiting")
if err := g.Wait(); err != nil && !errors.Is(err, context.Canceled) {
require.NoError(t, err)
}
}
}
func testAllTransports(t *testing.T, test func(p dcs.Protocol) func(t *testing.T)) {
t.Run("Abridged", test(transport.Abridged))
t.Run("Intermediate", test(transport.Intermediate))
t.Run("PaddedIntermediate", test(transport.PaddedIntermediate))
t.Run("Full", test(transport.Full))
}
func testTransport(p dcs.Protocol) func(t *testing.T) {
testMessage := "ну че там с деньгами?"
return testCluster(p, false, func(s clusterSetup) {
h := tgtest.TestTransport(s.TB, s.Logger.Named("handler"), testMessage)
d := s.Cluster.Dispatch(2, "server")
d.Handle(tg.MessagesSendMessageRequestTypeID, h)
d.Handle(tg.UsersGetUsersRequestTypeID, h)
}, func(ctx context.Context, c clientSetup) error {
opts := c.Options
opts.AckBatchSize = 1
opts.AckInterval = time.Millisecond * 50
opts.RetryInterval = time.Millisecond * 50
logger := opts.Logger
dispatcher := tg.NewUpdateDispatcher()
opts.UpdateHandler = dispatcher
client := telegram.NewClient(1, "hash", opts)
waitForMessage := make(chan struct{})
dispatcher.OnNewMessage(func(ctx context.Context, entities tg.Entities, update *tg.UpdateNewMessage) error {
message := update.Message.(*tg.Message).Message
logger.Info("Got message", zap.String("text", message))
assert.Equal(c.TB, testMessage, message)
if err := client.SendMessage(ctx, &tg.MessagesSendMessageRequest{
Peer: &tg.InputPeerUser{},
Message: "какими деньгами?",
}); err != nil {
return err
}
logger.Info("Closing waitForMessage")
close(waitForMessage)
return nil
})
return client.Run(ctx, func(ctx context.Context) error {
select {
case <-ctx.Done():
c.TB.Error("Failed to wait for message")
return ctx.Err()
case <-waitForMessage:
logger.Info("Returning")
c.Complete()
return nil
}
})
})
}
func TestClientE2E(t *testing.T) {
testAllTransports(t, testTransport)
}
func testMigrate(p dcs.Protocol) func(t *testing.T) {
test := func(ws bool) func(t *testing.T) {
wait := make(chan struct{}, 1)
return testCluster(p, ws, func(s clusterSetup) {
c := s.Cluster
c.Common().Vector(tg.UsersGetUsersRequestTypeID, user)
c.Dispatch(1, "server").HandleFunc(tg.MessagesSendMessageRequestTypeID,
func(server *tgtest.Server, req *tgtest.Request) error {
m := &tg.MessagesSendMessageRequest{}
if err := m.Decode(req.Buf); err != nil {
return err
}
select {
case wait <- struct{}{}:
case <-req.RequestCtx.Done():
return req.RequestCtx.Err()
}
return server.SendGZIP(req, &tg.Updates{})
},
)
c.Dispatch(2, "migrate").HandleFunc(tg.MessagesSendMessageRequestTypeID,
func(server *tgtest.Server, req *tgtest.Request) error {
m := &tg.MessagesSendMessageRequest{}
if err := m.Decode(req.Buf); err != nil {
return err
}
return server.SendErr(req, tgerr.New(303, "NETWORK_MIGRATE_1"))
},
)
}, func(ctx context.Context, c clientSetup) error {
opts := c.Options
opts.MigrationTimeout = time.Minute
client := telegram.NewClient(1, "hash", opts)
return client.Run(ctx, func(ctx context.Context) error {
if err := client.SendMessage(ctx, &tg.MessagesSendMessageRequest{
Peer: &tg.InputPeerUser{},
Message: "abc",
}); err != nil {
return errors.Wrap(err, "send")
}
select {
case <-wait:
c.Complete()
case <-ctx.Done():
return ctx.Err()
}
return nil
})
})
}
return func(t *testing.T) {
t.Run("TCP", test(false))
t.Run("Websocket", test(true))
}
}
func TestMigrate(t *testing.T) {
t.Run("Intermediate", testMigrate(transport.Intermediate))
}
func testFiles(p dcs.Protocol) func(t *testing.T) {
test := func(ws bool) func(t *testing.T) {
return testCluster(p, ws, func(s clusterSetup) {
c := s.Cluster
c.Common().Vector(tg.UsersGetUsersRequestTypeID, user)
f := file.NewService(file.Config{
HashPartSize: 1024,
})
f.Register(c.Dispatch(2, "DC"))
}, func(ctx context.Context, c clientSetup) error {
client := telegram.NewClient(1, "hash", c.Options)
defer c.Complete()
return client.Run(ctx, func(ctx context.Context) error {
raw := tg.NewClient(client)
upd := uploader.NewUploader(raw)
dwn := downloader.NewDownloader()
payloads := [][]byte{
[]byte("data"),
bytes.Repeat([]byte{10}, 1337),
bytes.Repeat([]byte{42}, 16384),
}
for _, payload := range payloads {
f, err := upd.FromBytes(ctx, "10.jpg", payload)
if err != nil {
return err
}
vf, ok := f.(interface {
GetID() (value int64)
})
if !ok {
return errors.Errorf("%T", f)
}
var b bytes.Buffer
_, err = dwn.Download(raw, &tg.InputFileLocation{
VolumeID: vf.GetID(),
LocalID: 10,
}).WithVerify(true).Stream(ctx, &b)
if err != nil {
return err
}
if !bytes.Equal(payload, b.Bytes()) {
c.TB.Error("must be equal")
}
}
return nil
})
})
}
return func(t *testing.T) {
t.Run("TCP", test(false))
t.Run("Websocket", test(true))
}
}
func TestFiles(t *testing.T) {
t.Run("Intermediate", testFiles(transport.Intermediate))
}
+120
View File
@@ -0,0 +1,120 @@
package telegram_test
import (
"context"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest"
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/e2etest"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
func tryConnect(ctx context.Context, opts telegram.Options) error {
client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, opts)
return client.Run(ctx, func(ctx context.Context) error {
_, err := client.API().HelpGetNearestDC(ctx)
return err
})
}
func testTransportExternal(resolver dcs.Resolver, storage session.Storage) func(t *testing.T) {
return func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
log := zaptest.NewLogger(t)
defer func() { _ = log.Sync() }()
require.NoError(t, tryConnect(ctx, telegram.Options{
Logger: log.Named("client"),
SessionStorage: storage,
Resolver: resolver,
}))
}
}
func TestExternalE2EConnect(t *testing.T) {
testutil.SkipExternal(t)
// To re-use session.
storage := &session.StorageMemory{}
tcp := func(p dcs.Protocol) func(t *testing.T) {
return testTransportExternal(dcs.Plain(dcs.PlainOptions{Protocol: p}), storage)
}
t.Run("Abridged", tcp(transport.Abridged))
t.Run("Intermediate", tcp(transport.Intermediate))
t.Run("PaddedIntermediate", tcp(transport.PaddedIntermediate))
t.Run("Full", tcp(transport.Full))
wsOpts := dcs.WebsocketOptions{}
t.Run("Websocket", testTransportExternal(dcs.Websocket(wsOpts), storage))
}
const dialog = `— Да?
— Алё!
— Да да?
— Ну как там с деньгами?
А?
— Как с деньгами-то там?
— Чё с деньгами?
— Чё?
— Куда ты звонишь?
— Тебе звоню.
— Кому?
— Ну тебе.`
func TestExternalE2EUsersDialog(t *testing.T) {
testutil.SkipExternal(t)
if v, _ := strconv.ParseBool(os.Getenv("GOTD_E2E_DIALOGS_BROKEN")); v {
// TODO(ernado): enable when fixed
t.Skip("Dialogs are broken.")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
log := zaptest.NewLogger(t).WithOptions(zap.IncreaseLevel(zapcore.InfoLevel))
cfg := e2etest.TestOptions{
Logger: log,
}
suite := e2etest.NewSuite(t, cfg)
auth := make(chan *tg.User, 1)
g := tdsync.NewLogGroup(ctx, log.Named("group"))
g.Go("echobot", func(ctx context.Context) error {
if err := e2etest.NewEchoBot(suite, auth).Run(ctx); err != nil {
return errors.Wrap(err, "echo bot")
}
return nil
})
user, ok := <-auth
if ok {
g.Go("terentyev", func(ctx context.Context) error {
defer g.Cancel()
if err := e2etest.NewUser(suite, strings.Split(dialog, "\n"), user.Username).Run(ctx); err != nil {
return errors.Wrap(err, "user")
}
return nil
})
}
require.NoError(t, g.Wait())
}
+138
View File
@@ -0,0 +1,138 @@
package telegram_test
import (
"bytes"
"context"
"encoding/base64"
"encoding/hex"
"io"
"os"
"os/exec"
"strings"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
)
type mtg struct {
path string
addr string
}
type signalWriter struct {
io.Writer
wait *tdsync.Ready
}
func (s signalWriter) Write(p []byte) (n int, err error) {
s.wait.Signal()
return s.Writer.Write(p)
}
func (m mtg) run(ctx context.Context, secret string, out, err io.Writer, wait *tdsync.Ready) error {
cmd := exec.CommandContext(ctx, m.path, "simple-run", "-d", m.addr, secret)
cmd.Stdout = signalWriter{Writer: out, wait: wait}
cmd.Stderr = signalWriter{Writer: err, wait: wait}
cmd.Env = append([]string{"MTG_DEBUG=true", "MTG_TEST_DC=true"}, os.Environ()...)
return cmd.Run()
}
func (m mtg) generateSecret(ctx context.Context, _ string) ([]byte, error) {
args := []string{"generate-secret", "google.com"}
o, err := exec.CommandContext(ctx, m.path, args...).Output()
if err != nil {
return nil, errors.Wrap(err, "execute command")
}
output := strings.TrimSpace(string(o))
r, err := base64.RawURLEncoding.DecodeString(output)
if err != nil {
return nil, errors.Wrapf(err, "decode secret %q", output)
}
return r, nil
}
func testMTProxy(secretType string, m mtg, storage session.Storage) func(t *testing.T) {
return func(t *testing.T) {
a := require.New(t)
logger := zaptest.NewLogger(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
secret, err := m.generateSecret(ctx, secretType)
a.NoError(err)
// Store mtg logs to buffer and print it only if test failed.
w := &bytes.Buffer{}
t.Cleanup(func() {
if t.Failed() {
_, _ = io.Copy(os.Stdout, w)
}
})
g := tdsync.NewCancellableGroup(ctx)
ready := tdsync.NewReady()
g.Go(func(ctx context.Context) error {
err := m.run(ctx, hex.EncodeToString(secret), w, w, ready)
select {
case <-ctx.Done():
return nil
default:
return err
}
})
g.Go(func(ctx context.Context) error {
defer g.Cancel()
select {
case <-ready.Ready():
case <-ctx.Done():
return ctx.Err()
}
resolver, err := dcs.MTProxy(m.addr, secret, dcs.MTProxyOptions{})
if err != nil {
return err
}
return tryConnect(ctx, telegram.Options{
Resolver: resolver,
Logger: logger,
SessionStorage: storage,
DCList: dcs.Prod(),
})
})
a.NoError(g.Wait())
}
}
func TestExternalE2EMTProxy(t *testing.T) {
addr, ok := os.LookupEnv("GOTD_MTPROXY_ADDR")
if !ok {
t.Skip("Skipped. Set GOTD_MTPROXY_ADDR to enable external e2e mtproxy test.")
}
mtgPath, err := exec.LookPath("mtg")
if err != nil {
t.Fatal("mtg binary not found", err)
}
// To re-use session.
storage := &session.StorageMemory{}
m := mtg{path: mtgPath, addr: addr}
// TODO(tdakkota): test all proxy types (mtg v2 supports only faketls)
for _, secretType := range []string{"tls"} {
t.Run(strings.Title(secretType), testMTProxy(secretType, m, storage))
}
}
+155
View File
@@ -0,0 +1,155 @@
package telegram
import (
"context"
"crypto/md5"
"fmt"
"math/rand"
"net"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/go-faster/errors"
"github.com/cenkalti/backoff/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgmock"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
)
type testHandler func(id int64, body bin.Encoder) (bin.Encoder, error)
type testConn struct {
id atomic.Int64
engine *rpc.Engine
ready *tdsync.Ready
}
func (t *testConn) Ping(ctx context.Context) error {
return errors.New("not implemented")
}
func (t *testConn) Ready() <-chan struct{} {
return t.ready.Ready()
}
func (t *testConn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
id := t.id.Inc() - 1
return t.engine.Do(ctx, rpc.Request{
Input: input,
Output: output,
MsgID: id,
})
}
func (testConn) Run(ctx context.Context) error { return nil }
func newTestClient(h testHandler) *Client {
var engine *rpc.Engine
engine = rpc.New(func(ctx context.Context, msgID int64, seqNo int32, in bin.Encoder) error {
if response, err := h(msgID, in); err != nil {
engine.NotifyError(msgID, err)
} else {
var b bin.Buffer
if err := b.Encode(response); err != nil {
return err
}
return engine.NotifyResult(msgID, &b)
}
return nil
}, rpc.Options{})
ready := tdsync.NewReady()
ready.Signal()
client := &Client{
log: zap.NewNop(),
rand: rand.New(rand.NewSource(1)),
appID: TestAppID,
appHash: TestAppHash,
conn: &testConn{engine: engine, ready: ready},
ctx: context.Background(),
cancel: func() {},
updateHandler: UpdateHandlerFunc(func(ctx context.Context, u tg.UpdatesClass) error { return nil }),
onTransfer: noopOnTransfer,
}
client.init()
return client
}
func mockClient(cb func(mock *tgmock.Mock, client *Client)) func(t *testing.T) {
return func(t *testing.T) {
t.Helper()
mock := tgmock.NewRequire(t)
client := newTestClient(testHandler(mock.Handler()))
cb(mock, client)
}
}
func TestEnsureErrorIfCantConnect(t *testing.T) {
testErr := testutil.TestError()
dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, testErr
}
opts := Options{
Resolver: dcs.Plain(dcs.PlainOptions{Dial: dialer}),
ReconnectionBackoff: func() backoff.BackOff {
return backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Nanosecond), 2)
},
}
err := NewClient(1, "hash", opts).Run(context.Background(),
func(ctx context.Context) error {
return nil
})
require.ErrorIs(t, err, testErr)
}
// newCorpusTracer will save incoming messages to corpus folder.
//
// Usage:
//
// client.trace.OnMessage = newCorpusTracer(t)
//
// nolint: deadcode,unused // optional
func newCorpusTracer(t testing.TB) func(b *bin.Buffer) {
types := tmap.New(
mt.TypesMap(),
tg.TypesMap(),
proto.TypesMap(),
)
dir := filepath.Join("..", "_fuzz", "handle_message", "corpus")
return func(b *bin.Buffer) {
id, _ := b.PeekID()
h := md5.Sum(b.Buf)
name := types.Get(id)
if name == "" {
name = "unknown"
}
if idx := strings.Index(name, "#"); idx > 0 {
// Removing type id from name.
name = name[:idx]
}
base := fmt.Sprintf("trace_%x_%s_%x",
id, name, h,
)
assert.NoError(t, os.WriteFile(filepath.Join(dir, base), b.Buf, 0600))
}
}
+15
View File
@@ -0,0 +1,15 @@
package telegram
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/session"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
// SessionStorage is alias of mtproto.SessionStorage.
type SessionStorage = session.Storage
// FileSessionStorage is alias of mtproto.FileSessionStorage.
type FileSessionStorage = session.FileStorage
// Error represents RPC error returned to request.
type Error = tgerr.Error
+103
View File
@@ -0,0 +1,103 @@
package telegram
import (
"context"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
"go.mau.fi/mautrix-telegram/pkg/gotd/pool"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/manager"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
type clientHandler struct {
client *Client
}
func (c clientHandler) OnSession(cfg tg.Config, s mtproto.Session) (err error) {
err = c.client.onSession(cfg, s)
if c.client.opts.Handler != nil {
c.client.opts.Handler.OnSession(s)
}
return
}
func (c clientHandler) OnMessage(b *bin.Buffer) error {
return c.client.handleUpdates(b)
}
func (c *Client) asHandler() manager.Handler {
return clientHandler{
client: c,
}
}
type connConstructor func(
create mtproto.Dialer,
mode manager.ConnMode,
appID int,
opts mtproto.Options,
connOpts manager.ConnOptions,
) pool.Conn
func defaultConstructor() connConstructor {
return func(
create mtproto.Dialer,
mode manager.ConnMode,
appID int,
opts mtproto.Options,
connOpts manager.ConnOptions,
) pool.Conn {
return manager.CreateConn(create, mode, appID, opts, connOpts)
}
}
func (c *Client) dcList() dcs.List {
cfg := c.cfg.Load()
return dcs.List{
Options: cfg.DCOptions,
Domains: c.domains,
Test: c.testDC,
}
}
func (c *Client) primaryDC(dc int) mtproto.Dialer {
return func(ctx context.Context) (transport.Conn, error) {
return c.resolver.Primary(ctx, dc, c.dcList())
}
}
func (c *Client) createPrimaryConn(setup manager.SetupCallback) pool.Conn {
return c.createConn(0, c.defaultMode, setup, c.onDead, c.onAuthError)
}
func (c *Client) createConn(
id int64,
mode manager.ConnMode,
setup manager.SetupCallback,
onDead func(),
onAuthError func(error),
) pool.Conn {
opts, s := c.session.Options(c.opts)
opts.Logger = c.log.Named("conn").With(
zap.Int64("conn_id", id),
zap.Int("dc_id", s.DC),
)
return c.create(
c.primaryDC(s.DC), mode, c.appID,
opts, manager.ConnOptions{
DC: s.DC,
Test: c.testDC,
Device: c.device,
Handler: c.asHandler(),
Setup: setup,
OnDead: onDead,
OnAuthError: onAuthError,
},
)
}
+193
View File
@@ -0,0 +1,193 @@
package telegram
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
func (c *Client) runUntilRestart(ctx context.Context) error {
g := tdsync.NewCancellableGroup(ctx)
g.Go(c.conn.Run)
// If we don't need updates, so there is no reason to subscribe for it.
if !c.noUpdatesMode {
g.Go(func(ctx context.Context) error {
// Call method which requires authorization, to subscribe for updates.
// See https://core.telegram.org/api/updates#subscribing-to-updates.
self, err := c.Self(ctx)
if err != nil {
// Ignore unauthorized errors.
if !auth.IsUnauthorized(err) {
c.log.Warn("Got error on self", zap.Error(err))
} else if c.onAuthError != nil {
c.onAuthError(err)
}
if h := c.onSelfError; h != nil {
// Help with https://github.com/gotd/td/issues/1458.
if err := h(ctx, err); err != nil {
return errors.Wrap(err, "onSelfError")
}
}
return nil
}
c.log.Info("Got self", zap.String("username", self.Username))
if c.onConnected != nil {
c.onConnected()
}
return nil
})
}
g.Go(func(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-c.restart:
c.log.Debug("Restart triggered")
// Should call cancel() to cancel group.
g.Cancel()
return nil
}
})
return g.Wait()
}
func (c *Client) isPermanentError(err error) bool {
// See https://github.com/gotd/td/issues/1458.
if errors.Is(err, exchange.ErrKeyFingerprintNotFound) {
return true
}
if tgerr.Is(err, "AUTH_KEY_UNREGISTERED", "SESSION_EXPIRED", "AUTH_KEY_DUPLICATED") {
return true
}
if auth.IsUnauthorized(err) {
return true
}
return false
}
func (c *Client) reconnectUntilClosed(ctx context.Context) error {
// Note that we currently have no timeout on connection, so this is
// potentially eternal.
b := tdsync.SyncBackoff(backoff.WithContext(c.newConnBackoff(), ctx))
c.connBackoff.Store(&b)
return backoff.RetryNotify(func() error {
if err := c.runUntilRestart(ctx); err != nil {
if c.isPermanentError(err) {
return backoff.Permanent(err)
}
return err
}
return nil
}, b, func(err error, timeout time.Duration) {
c.log.Info("Restarting connection", zap.Error(err), zap.Duration("backoff", timeout))
c.connMux.Lock()
c.conn = c.createPrimaryConn(nil)
c.connMux.Unlock()
})
}
func (c *Client) onReady() {
c.log.Debug("Ready")
c.ready.Signal()
if b := c.connBackoff.Load(); b != nil {
// Reconnect faster next time.
(*b).Reset()
}
}
func (c *Client) resetReady() {
c.ready.Reset()
}
// Run starts client session and blocks until connection close.
// The f callback is called on successful session initialization and Run
// will return on f() result.
//
// Context of callback will be canceled if fatal error is detected.
// The ctx is used for background operations like updates handling or pools.
//
// See `examples/bg-run` and `contrib/gb` package for classic approach without
// explicit callback, with Connect and defer close().
func (c *Client) Run(ctx context.Context, f func(ctx context.Context) error) (err error) {
if c.ctx != nil {
select {
case <-c.ctx.Done():
return errors.Wrap(c.ctx.Err(), "client already closed")
default:
}
}
// Setting up client context for background operations like updates
// handling or pool creation.
c.ctx, c.cancel = context.WithCancel(ctx)
c.log.Info("Starting")
defer c.log.Info("Closed")
// Cancel client on exit.
defer c.cancel()
defer func() {
c.subConnsMux.Lock()
defer c.subConnsMux.Unlock()
for _, conn := range c.subConns {
if closeErr := conn.Close(); !errors.Is(closeErr, context.Canceled) {
multierr.AppendInto(&err, closeErr)
}
}
}()
c.resetReady()
if err := c.restoreConnection(ctx); err != nil {
return err
}
g := tdsync.NewCancellableGroup(ctx)
g.Go(c.reconnectUntilClosed)
g.Go(func(ctx context.Context) error {
select {
case <-ctx.Done():
c.cancel()
return ctx.Err()
case <-c.ctx.Done():
return c.ctx.Err()
}
})
g.Go(func(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-c.ready.Ready():
if err := f(ctx); err != nil {
return errors.Wrap(err, "callback")
}
// Should call cancel() to cancel ctx.
// This will terminate c.conn.Run().
c.log.Debug("Callback returned, stopping")
g.Cancel()
return nil
}
})
if err := g.Wait(); !errors.Is(err, context.Canceled) {
return err
}
return nil
}
+46
View File
@@ -0,0 +1,46 @@
package telegram
import (
"context"
"testing"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
)
type fingerprintNotFoundConn struct{}
func (m fingerprintNotFoundConn) Run(context.Context) error {
return exchange.ErrKeyFingerprintNotFound
}
func (m fingerprintNotFoundConn) Invoke(context.Context, bin.Encoder, bin.Decoder) error {
return nil
}
func (m fingerprintNotFoundConn) Ping(context.Context) error {
return nil
}
func (m fingerprintNotFoundConn) Ready() <-chan struct{} {
return nil
}
func TestClient_reconnectUntilClosed(t *testing.T) {
client := Client{
newConnBackoff: func() backoff.BackOff {
return backoff.NewConstantBackOff(time.Nanosecond)
},
log: zap.NewNop(),
}
client.init()
client.conn = fingerprintNotFoundConn{}
ctx := context.Background()
require.ErrorIs(t, client.reconnectUntilClosed(ctx), exchange.ErrKeyFingerprintNotFound)
}
+181
View File
@@ -0,0 +1,181 @@
package dcs
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"math/big"
"net"
"sort"
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
var dnsKey struct {
rsa.PublicKey
once sync.Once
eBig *big.Int
}
//nolint:gochecknoinits
func init() {
dnsKey.once.Do(func() {
k, err := crypto.ParseRSAPublicKeys([]byte(`-----BEGIN RSA PUBLIC KEY-----
MIIBCgKCAQEAyr+18Rex2ohtVy8sroGPBwXD3DOoKCSpjDqYoXgCqB7ioln4eDCF
fOBUlfXUEvM/fnKCpF46VkAftlb4VuPDeQSS/ZxZYEGqHaywlroVnXHIjgqoxiAd
192xRGreuXIaUKmkwlM9JID9WS2jUsTpzQ91L8MEPLJ/4zrBwZua8W5fECwCCh2c
9G5IzzBm+otMS/YKwmR1olzRCyEkyAEjXWqBI9Ftv5eG8m0VkBzOG655WIYdyV0H
fDK/NWcvGqa0w/nriMD6mDjKOryamw0OP9QuYgMN0C9xMW9y8SmP4h92OAWodTYg
Y1hZCxdv6cs5UnW9+PWvS+WIbkh+GaWYxwIDAQAB
-----END RSA PUBLIC KEY-----`))
if err != nil {
panic(err)
}
dnsKey.PublicKey = *k[0]
dnsKey.eBig = big.NewInt(int64(dnsKey.E))
})
}
// parseDNSList decodes raw encrypted simple config.
//
// Notice that parseDNSList does not decode base64, user should do it manually.
func parseDNSList(input [256]byte) (tg.HelpConfigSimple, error) {
// See https://github.com/tdlib/td/blob/master/td/telegram/ConfigManager.cpp#L148.
x := new(big.Int).SetBytes(input[:])
y := new(big.Int).Exp(x, dnsKey.eBig, dnsKey.N)
dataRSA := make([]byte, 256)
if !crypto.FillBytes(y, dataRSA) {
return tg.HelpConfigSimple{}, errors.New("dataRSA has invalid size")
}
block, err := aes.NewCipher(dataRSA[:32])
if err != nil {
return tg.HelpConfigSimple{}, err
}
d := cipher.NewCBCDecrypter(block, dataRSA[16:32])
dataCBC := dataRSA[32:]
d.CryptBlocks(dataCBC, dataCBC)
decrypted := dataCBC[:len(dataCBC)-16]
decryptedHash := sha256.Sum256(decrypted)
hash := dataCBC[len(dataCBC)-16:]
if !bytes.Equal(decryptedHash[:16], hash) {
return tg.HelpConfigSimple{}, errors.New("hash mismatch")
}
var cfg tg.HelpConfigSimple
if err := cfg.Decode(&bin.Buffer{Buf: decrypted[4:]}); err != nil {
return tg.HelpConfigSimple{}, err
}
return cfg, nil
}
type sortByLen []string
func (s sortByLen) Len() int {
return len(s)
}
func (s sortByLen) Less(i, j int) bool {
return len(s[i]) > len(s[j])
}
func (s sortByLen) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
// DNSConfig is DC connection config obtained from DNS.
type DNSConfig struct {
// Date field of HelpConfigSimple.
Date int
// Expires field of HelpConfigSimple.
Expires int
// Rules field of HelpConfigSimple.
Rules []tg.AccessPointRule
}
// Options returns DC options from this config.
func (d DNSConfig) Options() (r []tg.DCOption) {
convertIP := func(ip int) string {
return net.IPv4(
byte(ip),
byte(ip>>8),
byte(ip>>16),
byte(ip>>24),
).String()
}
for _, rule := range d.Rules {
for _, ip := range rule.IPs {
switch ip := ip.(type) {
case *tg.IPPort:
r = append(r, tg.DCOption{
ID: rule.DCID,
IPAddress: convertIP(ip.Ipv4),
Port: ip.Port,
})
case *tg.IPPortSecret:
r = append(r, tg.DCOption{
TCPObfuscatedOnly: true,
ID: rule.DCID,
IPAddress: convertIP(ip.Ipv4),
Port: ip.Port,
Secret: ip.Secret,
})
}
}
}
return r
}
// ParseDNSConfig parses tg.HelpConfigSimple from TXT response.
func ParseDNSConfig(txt []string) (DNSConfig, error) {
encoding := base64.StdEncoding
const (
decodedLen = 256
encodedLen = 344
)
sort.Sort(sortByLen(txt))
var totalLength int
for i := range txt {
totalLength += len(txt[i])
}
if totalLength != encodedLen {
return DNSConfig{}, errors.Errorf("invalid input length %d", totalLength)
}
var (
encoded [encodedLen]byte
decoded [decodedLen]byte
)
n := 0
for i := range txt {
n += copy(encoded[n:], txt[i])
}
if _, err := encoding.Decode(decoded[:], encoded[:]); err != nil {
return DNSConfig{}, errors.Wrap(err, "decode")
}
cfg, err := parseDNSList(decoded)
if err != nil {
return DNSConfig{}, errors.Wrap(err, "decrypt config")
}
return DNSConfig{
Date: cfg.Date,
Expires: cfg.Expires,
Rules: cfg.Rules,
}, nil
}
+92
View File
@@ -0,0 +1,92 @@
package dcs
import (
"testing"
"github.com/stretchr/testify/require"
)
func testTXTResponse() []string {
return []string{
"LcmEoukF2bVjKwz3E+J9BsDdL+rv9lGqLQWIGXrWACT2ESk5xuOpA6Cz6klKRbhbwSiHOd2zC5PiR57j/OJHPpj4i+tw==",
"umjjLFLpOKtPeW9zHLq2ypbMzg/zkqvPhvhr0bxrLZlgPQ04l2GpO/4qZgAx3tk3BDHbY6/gmG1e8eaFBq3YSqR5SZ5hQ1Cm5f4/" +
"o67GYcPJClaf1TiHq3wVfsQ5OLnyJRw9A2ZfUfzIXxoSklPJrVdF/4hM1ZdUE0eWDAbmYf7JCeao8ecVVwKndd4CZHZS9wyf1T7DIUh95VpQ" +
"sn2klLPA6gA/2YNXOh9gITvjZrKuXLwwh9hBHhPvxv",
}
}
func Test_ParseDNSConfig(t *testing.T) {
t.Run("Good", func(t *testing.T) {
a := require.New(t)
cfg, err := ParseDNSConfig(testTXTResponse())
a.NoError(err)
a.Equal(1565541126, cfg.Expires)
a.Equal(1562949126, cfg.Date)
a.Len(cfg.Rules, 1)
rule := cfg.Rules[0]
a.Equal(2, rule.DCID)
})
t.Run("Bad", func(t *testing.T) {
tests := []struct {
name string
input []string
}{
{"Empty", nil},
{"InvalidHash", func() []string {
r := testTXTResponse()
first := r[0]
r[0] = string(first[0]+1) + first[1:]
return r
}()},
{"InvalidBase64", func() []string {
r := testTXTResponse()
first := r[0]
r[0] = string('#') + first[1:]
return r
}()},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ParseDNSConfig(tt.input)
require.Error(t, err)
})
}
})
}
func BenchmarkDNSConfig(b *testing.B) {
message := testTXTResponse()
b.ResetTimer()
b.ReportAllocs()
var (
err error
cfgSink DNSConfig
)
for i := 0; i < b.N; i++ {
cfgSink, err = ParseDNSConfig(message)
if cfgSink.Date == 0 || err != nil {
b.Fatal(err)
}
}
}
func TestDNSConfig_Options(t *testing.T) {
a := require.New(t)
cfg, err := ParseDNSConfig(testTXTResponse())
a.NoError(err)
options := cfg.Options()
a.Len(options, 1)
option := options[0]
a.Equal(2, option.ID)
a.Equal(14544, option.Port)
a.Equal("98.210.59.139", option.IPAddress)
a.True(option.TCPObfuscatedOnly)
}
+2
View File
@@ -0,0 +1,2 @@
// Package dcs contains Telegram DCs list and some helpers.
package dcs
+52
View File
@@ -0,0 +1,52 @@
package dcs_test
import (
"context"
"fmt"
"time"
"golang.org/x/net/proxy"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
)
func ExampleDialFunc() {
// Dial using proxy from environment.
// Creating connection.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
client := telegram.NewClient(1, "appHash", telegram.Options{
Resolver: dcs.Plain(dcs.PlainOptions{Dial: proxy.Dial}),
})
_ = client.Run(ctx, func(ctx context.Context) error {
fmt.Println("Started")
return nil
})
}
func ExampleDialFunc_dialer() {
// Dial using SOCKS5 proxy.
sock5, _ := proxy.SOCKS5("tcp", "IP:PORT", &proxy.Auth{
User: "YOURUSERNAME",
Password: "YOURPASSWORD",
}, proxy.Direct)
dc := sock5.(proxy.ContextDialer)
// Creating connection.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
client := telegram.NewClient(1, "appHash", telegram.Options{
Resolver: dcs.Plain(dcs.PlainOptions{
Dial: dc.DialContext,
}),
})
_ = client.Run(ctx, func(ctx context.Context) error {
fmt.Println("Started")
return nil
})
}
+60
View File
@@ -0,0 +1,60 @@
package dcs
import (
"sort"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// FindDCs searches DCs candidates from given config.
func FindDCs(opts []tg.DCOption, dcID int, preferIPv6 bool) []tg.DCOption {
// Preallocate slice.
candidates := make([]tg.DCOption, 0, 32)
for _, candidateDC := range opts {
if candidateDC.ID != dcID {
continue
}
candidates = append(candidates, candidateDC)
}
if len(candidates) < 1 {
return nil
}
sort.Slice(candidates, func(i, j int) bool {
l, r := candidates[i], candidates[j]
// If we prefer IPv6 and left is IPv6 and right is not, so then
// left is smaller (would be before right).
if preferIPv6 {
if l.Ipv6 && !r.Ipv6 {
return true
}
if !l.Ipv6 && r.Ipv6 {
return false
}
}
// Also we prefer static addresses.
return l.Static && !r.Static
})
return candidates
}
// FindPrimaryDCs searches new primary DC from given config.
// Unlike FindDC, it filters CDNs and MediaOnly servers, returns error
// if not found.
func FindPrimaryDCs(opts []tg.DCOption, dcID int, preferIPv6 bool) []tg.DCOption {
candidates := FindDCs(opts, dcID, preferIPv6)
// Filter (in place) from SliceTricks.
n := 0
for _, opt := range candidates {
if !opt.MediaOnly && !opt.CDN {
candidates[n] = opt
n++
}
}
return candidates[:n]
}
+76
View File
@@ -0,0 +1,76 @@
package dcs
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func TestFindDCs(t *testing.T) {
dcOptions := []tg.DCOption{
{ID: 1, Ipv6: false},
{ID: 1, Ipv6: true},
{ID: 1, Ipv6: false, Static: true},
{ID: 2, Ipv6: true, Static: true},
{ID: 2, Ipv6: true},
{ID: 2, Ipv6: false},
}
for i := range dcOptions {
dcOptions[i].IPAddress = fmt.Sprintf("DC: %d, Index: %d", dcOptions[i].ID, i)
}
a := require.New(t)
dc := FindDCs(dcOptions, -2, false)
a.Empty(dc)
dc = FindDCs(dcOptions, -2, true)
a.Empty(dc)
// Prefer IPv6.
dc = FindDCs(dcOptions, 1, true)
a.True(dc[0].Ipv6)
// Prefer static.
dc = FindDCs(dcOptions, 1, false)
a.True(dc[0].Static)
// Prefer static and IPv6.
dc = FindDCs(dcOptions, 2, true)
a.True(dc[0].Static)
a.True(dc[0].Ipv6)
}
func TestFindPrimaryDCs(t *testing.T) {
dcOptions := []tg.DCOption{
{ID: 1, Ipv6: false},
{ID: 1, Ipv6: true},
{ID: 1, Ipv6: false, Static: true},
{ID: 2, Ipv6: true, Static: true, MediaOnly: true},
{ID: 2, Ipv6: true, CDN: true},
{ID: 2, Ipv6: false, CDN: true},
}
for i := range dcOptions {
dcOptions[i].IPAddress = fmt.Sprintf("DC: %d, Index: %d", dcOptions[i].ID, i)
}
a := require.New(t)
dc := FindPrimaryDCs(dcOptions, -2, false)
a.Empty(dc)
dc = FindPrimaryDCs(dcOptions, -2, true)
a.Empty(dc)
// Prefer IPv6.
dc = FindPrimaryDCs(dcOptions, 1, true)
a.True(dc[0].Ipv6)
// Prefer static.
dc = FindPrimaryDCs(dcOptions, 1, false)
a.True(dc[0].Static)
// Filter CDN/MediaOnly/TCPo.
dc = FindPrimaryDCs(dcOptions, 2, false)
a.Empty(dc)
}
+15
View File
@@ -0,0 +1,15 @@
package dcs
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
// List is a list of Telegram DC addresses and domains.
type List struct {
Options []tg.DCOption
Domains map[int]string
Test bool
}
// Zero returns true if this List is zero value.
func (d List) Zero() bool {
return d.Options == nil && d.Domains == nil && !d.Test
}
+140
View File
@@ -0,0 +1,140 @@
package dcs
import (
"context"
"io"
"net"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy/obfuscator"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
var _ Resolver = mtProxy{}
type mtProxy struct {
dial DialFunc
protocol protocol
addr, network string
secret mtproxy.Secret
tag [4]byte
rand io.Reader
}
func (m mtProxy) Primary(ctx context.Context, dc int, _ List) (transport.Conn, error) {
return m.resolve(ctx, dc)
}
func (m mtProxy) MediaOnly(ctx context.Context, dc int, _ List) (transport.Conn, error) {
if dc > 0 {
dc *= -1
}
return m.resolve(ctx, dc)
}
func (m mtProxy) CDN(ctx context.Context, dc int, _ List) (transport.Conn, error) {
return m.resolve(ctx, dc)
}
func (m mtProxy) resolve(ctx context.Context, dc int) (transport.Conn, error) {
c, err := m.dial(ctx, m.network, m.addr)
if err != nil {
return nil, errors.Wrapf(err, "connect to the MTProxy %q", m.addr)
}
conn, err := m.handshakeConn(c, dc)
if err != nil {
err = errors.Wrap(err, "handshake")
return nil, multierr.Combine(err, c.Close())
}
return conn, nil
}
// handshakeConn inits given net.Conn as MTProto connection.
func (m mtProxy) handshakeConn(c net.Conn, dc int) (transport.Conn, error) {
var obsConn *obfuscator.Conn
switch m.secret.Type {
case mtproxy.Simple, mtproxy.Secured:
obsConn = obfuscator.Obfuscated2(m.rand, c)
case mtproxy.TLS:
obsConn = obfuscator.FakeTLS(m.rand, c)
default:
return nil, errors.Errorf("unknown MTProxy secret type: %d", m.secret.Type)
}
secret := m.secret
if err := obsConn.Handshake(m.tag, dc, secret); err != nil {
return nil, errors.Wrap(err, "MTProxy handshake")
}
transportConn, err := m.protocol.Handshake(obsConn)
if err != nil {
return nil, errors.Wrap(err, "transport handshake")
}
return transportConn, nil
}
// MTProxyOptions is MTProxy resolver creation options.
type MTProxyOptions struct {
// Dial specifies the dial function for creating unencrypted TCP connections.
// If Dial is nil, then the resolver dials using package net.
Dial DialFunc
// Network to use. Defaults to "tcp"
Network string
// Random source for MTProxy obfuscator.
Rand io.Reader
}
func (m *MTProxyOptions) setDefaults() {
if m.Dial == nil {
var d net.Dialer
m.Dial = d.DialContext
}
if m.Network == "" {
m.Network = "tcp"
}
if m.Rand == nil {
m.Rand = crypto.DefaultRand()
}
}
// MTProxy creates MTProxy obfuscated DC resolver.
//
// See https://core.telegram.org/mtproto/mtproto-transports#transport-obfuscation.
func MTProxy(addr string, secret []byte, opts MTProxyOptions) (Resolver, error) {
s, err := mtproxy.ParseSecret(secret)
if err != nil {
return nil, err
}
var cdc codec.Codec = codec.PaddedIntermediate{}
tag := codec.PaddedIntermediateClientStart
// FIXME(tdakkota): some proxies forces to use Padded (Secure) Intermediate
// even if secret denotes to use another transport type.
if s.Type != mtproxy.TLS {
if c, ok := s.ExpectedCodec(); ok {
cdc = c
tag = [4]byte{s.Tag, s.Tag, s.Tag, s.Tag}
}
}
opts.setDefaults()
return mtProxy{
dial: opts.Dial,
addr: addr,
network: opts.Network,
protocol: transport.NewProtocol(func() transport.Codec { return codec.NoHeader{Codec: cdc} }),
secret: s,
tag: tag,
rand: opts.Rand,
}, nil
}
+226
View File
@@ -0,0 +1,226 @@
package dcs
import (
"context"
"io"
"net"
"strconv"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy/obfuscator"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
var _ Resolver = plain{}
type plain struct {
dial DialFunc
protocol Protocol
rand io.Reader
network string
noObfuscated bool
preferIPv6 bool
}
func (p plain) Primary(ctx context.Context, dc int, list List) (transport.Conn, error) {
candidates := FindPrimaryDCs(list.Options, dc, p.preferIPv6)
if p.noObfuscated {
n := 0
for _, x := range candidates {
if !x.TCPObfuscatedOnly {
candidates[n] = x
n++
}
}
candidates = candidates[:n]
}
return p.connect(ctx, dc, list.Test, candidates)
}
func (p plain) MediaOnly(ctx context.Context, dc int, list List) (transport.Conn, error) {
candidates := FindDCs(list.Options, dc, p.preferIPv6)
// Filter (in place) from SliceTricks.
n := 0
for _, x := range candidates {
if x.MediaOnly {
candidates[n] = x
n++
}
}
return p.connect(ctx, dc, list.Test, candidates[:n])
}
func (p plain) CDN(ctx context.Context, dc int, list List) (transport.Conn, error) {
return nil, errors.Errorf("can't resolve %d: CDN is unsupported", dc)
}
func (p plain) dialTransport(ctx context.Context, test bool, dc tg.DCOption) (_ transport.Conn, rerr error) {
addr := net.JoinHostPort(dc.IPAddress, strconv.Itoa(dc.Port))
conn, err := p.dial(ctx, p.network, addr)
if err != nil {
return nil, err
}
defer func() {
if rerr != nil {
multierr.AppendInto(&rerr, conn.Close())
}
}()
proto := p.protocol
if dc.TCPObfuscatedOnly {
dcID := dc.ID
if test {
if dcID < 0 {
dcID -= 10000
} else {
dcID += 10000
}
}
var (
cdc codec.Codec = codec.Intermediate{}
tag = codec.IntermediateClientStart
obfs = obfuscator.Obfuscated2
)
secret, err := mtproxy.ParseSecret(dc.Secret)
if err != nil {
return nil, errors.Wrap(err, "check DC secret")
}
if secret.Type == mtproxy.TLS {
obfs = obfuscator.FakeTLS
} else if c, ok := secret.ExpectedCodec(); ok {
tag = [4]byte{secret.Tag, secret.Tag, secret.Tag, secret.Tag}
cdc = c
}
obfsConn := obfs(p.rand, conn)
if err := obfsConn.Handshake(tag, dcID, secret); err != nil {
return nil, err
}
conn = obfsConn
proto = transport.NewProtocol(func() transport.Codec {
return codec.NoHeader{Codec: cdc}
})
}
transportConn, err := proto.Handshake(conn)
if err != nil {
return nil, errors.Wrap(err, "transport handshake")
}
return transportConn, nil
}
func (p plain) connect(ctx context.Context, dc int, test bool, dcOptions []tg.DCOption) (transport.Conn, error) {
switch len(dcOptions) {
case 0:
return nil, errors.Errorf("no addresses for DC %d", dc)
case 1:
return p.dialTransport(ctx, test, dcOptions[0])
}
type dialResult struct {
conn transport.Conn
err error
}
// We use unbuffered channel to ensure that only one connection will be returned
// and all other will be closed.
results := make(chan dialResult)
tryDial := func(ctx context.Context, option tg.DCOption) {
conn, err := p.dialTransport(ctx, test, option)
select {
case results <- dialResult{
conn: conn,
err: err,
}:
case <-ctx.Done():
if conn != nil {
_ = conn.Close()
}
}
}
dialCtx, dialCancel := context.WithCancel(ctx)
defer dialCancel()
for _, dcOption := range dcOptions {
go tryDial(dialCtx, dcOption)
}
remain := len(dcOptions)
var rErr error
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case result := <-results:
remain--
if result.err != nil {
rErr = multierr.Append(rErr, result.err)
if remain == 0 {
return nil, rErr
}
continue
}
return result.conn, nil
}
}
}
// PlainOptions is plain resolver creation options.
type PlainOptions struct {
// Protocol is the transport protocol to use. Defaults to intermediate.
Protocol Protocol
// Dial specifies the dial function for creating unencrypted TCP connections.
// If Dial is nil, then the resolver dials using package net.
Dial DialFunc
// Random source for TCPObfuscated DCs.
Rand io.Reader
// Network to use. Defaults to "tcp".
Network string
// NoObfuscated denotes to filter out TCP Obfuscated Only DCs.
NoObfuscated bool
// PreferIPv6 gives IPv6 DCs higher precedence.
// Default is to prefer IPv4 DCs over IPv6.
PreferIPv6 bool
}
func (m *PlainOptions) setDefaults() {
if m.Protocol == nil {
m.Protocol = transport.Intermediate
}
if m.Dial == nil {
var d net.Dialer
m.Dial = d.DialContext
}
if m.Rand == nil {
m.Rand = crypto.DefaultRand()
}
if m.Network == "" {
m.Network = "tcp"
}
}
// Plain creates plain DC resolver.
func Plain(opts PlainOptions) Resolver {
opts.setDefaults()
return plain{
dial: opts.Dial,
protocol: opts.Protocol,
rand: opts.Rand,
network: opts.Network,
noObfuscated: opts.NoObfuscated,
preferIPv6: opts.PreferIPv6,
}
}
+134
View File
@@ -0,0 +1,134 @@
package dcs
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Prod returns production DC list.
func Prod() List {
// https://github.com/telegramdesktop/tdesktop/blob/dev/Telegram/SourceFiles/mtproto/mtproto_dc_options.cpp
// Also available with client.API().HelpGetConfig(ctx) [tg.DCOption].
// TODO(ernado): automate update from HelpGetConfig.
return List{
Options: []tg.DCOption{
{
ID: 1,
IPAddress: "149.154.175.52",
Port: 443,
},
{
Static: true,
ID: 1,
IPAddress: "149.154.175.53",
Port: 443,
},
{
Ipv6: true,
ID: 1,
IPAddress: "2001:0b28:f23d:f001:0000:0000:0000:000a",
Port: 443,
},
{
ID: 2,
IPAddress: "149.154.167.41",
Port: 443,
},
{
Static: true,
ID: 2,
IPAddress: "149.154.167.41",
Port: 443,
},
{
MediaOnly: true,
ID: 2,
IPAddress: "149.154.167.222",
Port: 443,
},
{
Ipv6: true,
ID: 2,
IPAddress: "2001:067c:04e8:f002:0000:0000:0000:000a",
Port: 443,
},
{
Ipv6: true,
MediaOnly: true,
ID: 2,
IPAddress: "2001:067c:04e8:f002:0000:0000:0000:000b",
Port: 443,
},
{
ID: 3,
IPAddress: "149.154.175.100",
Port: 443,
},
{
Static: true,
ID: 3,
IPAddress: "149.154.175.100",
Port: 443,
},
{
Ipv6: true,
ID: 3,
IPAddress: "2001:0b28:f23d:f003:0000:0000:0000:000a",
Port: 443,
},
{
ID: 4,
IPAddress: "149.154.167.91",
Port: 443,
},
{
Static: true,
ID: 4,
IPAddress: "149.154.167.91",
Port: 443,
},
{
Ipv6: true,
ID: 4,
IPAddress: "2001:067c:04e8:f004:0000:0000:0000:000a",
Port: 443,
},
{
MediaOnly: true,
ID: 4,
IPAddress: "149.154.166.120",
Port: 443,
},
{
Ipv6: true,
MediaOnly: true,
ID: 4,
IPAddress: "2001:067c:04e8:f004:0000:0000:0000:000b",
Port: 443,
},
{
Ipv6: true,
ID: 5,
IPAddress: "2001:0b28:f23f:f005:0000:0000:0000:000a",
Port: 443,
},
{
ID: 5,
IPAddress: "91.108.56.191",
Port: 443,
},
{
Static: true,
ID: 5,
IPAddress: "91.108.56.191",
Port: 443,
},
},
Domains: map[int]string{
1: "wss://pluto.web.telegram.org/apiws",
2: "wss://venus.web.telegram.org/apiws",
3: "wss://aurora.web.telegram.org/apiws",
4: "wss://vesta.web.telegram.org/apiws",
5: "wss://flora.web.telegram.org/apiws",
},
}
}
+17
View File
@@ -0,0 +1,17 @@
package dcs
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestProd(t *testing.T) {
require.NotEmpty(t, Prod())
// Check copying.
a := Prod().Options
a[0].IPAddress = "10"
b := Prod().Options
require.NotEqual(t, "10", b[0].IPAddress)
}
+11
View File
@@ -0,0 +1,11 @@
package dcs
import (
"net"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
type protocol interface {
Handshake(conn net.Conn) (transport.Conn, error)
}
+28
View File
@@ -0,0 +1,28 @@
package dcs
import (
"context"
"net"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
)
var _ Resolver = DefaultResolver()
// Resolver resolves DC and creates transport MTProto connection.
type Resolver interface {
Primary(ctx context.Context, dc int, list List) (transport.Conn, error)
MediaOnly(ctx context.Context, dc int, list List) (transport.Conn, error)
CDN(ctx context.Context, dc int, list List) (transport.Conn, error)
}
// DialFunc connects to the address on the named network.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// Protocol is MTProto transport protocol.
//
// See https://core.telegram.org/mtproto/mtproto-transports
type Protocol interface {
Codec() transport.Codec
Handshake(conn net.Conn) (transport.Conn, error)
}
+9
View File
@@ -0,0 +1,9 @@
//go:build js && wasm
// +build js,wasm
package dcs
// DefaultResolver returns default DC resolver for current platform.
func DefaultResolver() Resolver {
return Websocket(WebsocketOptions{})
}
+9
View File
@@ -0,0 +1,9 @@
//go:build !js || !wasm
// +build !js !wasm
package dcs
// DefaultResolver returns default DC resolver for current platform.
func DefaultResolver() Resolver {
return Plain(PlainOptions{})
}
+59
View File
@@ -0,0 +1,59 @@
package dcs
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
// Staging returns staging DC list.
//
// Deprecated: Use Test().
func Staging() List {
return Test()
}
// Test returns test DC list.
func Test() List {
return List{
Options: []tg.DCOption{
{
ID: 1,
IPAddress: "149.154.175.10",
Port: 443,
},
{
ID: 1,
Ipv6: true,
IPAddress: "2001:0b28:f23d:f001:0000:0000:0000:000e",
Port: 443,
},
{
ID: 2,
IPAddress: "149.154.167.40",
Port: 443,
},
{
ID: 2,
Ipv6: true,
IPAddress: "2001:067c:04e8:f002:0000:0000:0000:000e",
Port: 443,
},
{
ID: 3,
IPAddress: "149.154.175.117",
Port: 443,
},
{
ID: 3,
Ipv6: true,
IPAddress: "2001:0b28:f23d:f003:0000:0000:0000:000e",
Port: 443,
},
},
Domains: map[int]string{
1: "wss://pluto.web.telegram.org/apiws_test",
2: "wss://venus.web.telegram.org/apiws_test",
3: "wss://aurora.web.telegram.org/apiws_test",
4: "wss://vesta.web.telegram.org/apiws_test",
5: "wss://flora.web.telegram.org/apiws_test",
},
Test: true,
}
}
+17
View File
@@ -0,0 +1,17 @@
package dcs
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestTestDCs(t *testing.T) {
require.NotEmpty(t, Prod())
// Check copying.
a := Test().Options
a[0].IPAddress = "10"
b := Test().Options
require.NotEqual(t, "10", b[0].IPAddress)
}
+103
View File
@@ -0,0 +1,103 @@
package dcs
import (
"context"
"io"
"github.com/coder/websocket"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproxy/obfuscator"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto/codec"
"go.mau.fi/mautrix-telegram/pkg/gotd/transport"
"go.mau.fi/mautrix-telegram/pkg/gotd/wsutil"
)
var _ Resolver = ws{}
type ws struct {
dialOptions *websocket.DialOptions
protocol protocol
tag [4]byte
rand io.Reader
}
func (w ws) connect(ctx context.Context, dc int, domains map[int]string) (transport.Conn, error) {
addr, ok := domains[dc]
if !ok {
return nil, errors.Errorf("domain for %d not found", dc)
}
conn, resp, err := websocket.Dial(ctx, addr, w.dialOptions)
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
if err != nil {
return nil, errors.Wrap(err, "dial ws")
}
obsConn := obfuscator.Obfuscated2(w.rand, wsutil.NetConn(conn))
if err := obsConn.Handshake(w.tag, dc, mtproxy.Secret{
Secret: nil,
Type: mtproxy.Simple,
}); err != nil {
return nil, errors.Wrap(err, "handshake")
}
transportConn, err := w.protocol.Handshake(obsConn)
if err != nil {
return nil, errors.Wrap(err, "transport handshake")
}
return transportConn, nil
}
func (w ws) Primary(ctx context.Context, dc int, list List) (transport.Conn, error) {
return w.connect(ctx, dc, list.Domains)
}
func (w ws) MediaOnly(ctx context.Context, dc int, list List) (transport.Conn, error) {
return nil, errors.Errorf("can't resolve %d: MediaOnly is unsupported", dc)
}
func (w ws) CDN(ctx context.Context, dc int, list List) (transport.Conn, error) {
return nil, errors.Errorf("can't resolve %d: CDN is unsupported", dc)
}
// WebsocketOptions is Websocket resolver creation options.
type WebsocketOptions struct {
// Dialer specifies the websocket dialer.
// If Dialer is nil, then the resolver dials using websocket.DefaultDialer.
DialOptions *websocket.DialOptions
// Random source for MTProxy obfuscator.
Rand io.Reader
}
func (m *WebsocketOptions) setDefaults() {
if m.DialOptions == nil {
m.DialOptions = &websocket.DialOptions{Subprotocols: []string{
"binary",
}}
}
if m.Rand == nil {
m.Rand = crypto.DefaultRand()
}
}
// Websocket creates Websocket DC resolver.
//
// See https://core.telegram.org/mtproto/transports#websocket.
func Websocket(opts WebsocketOptions) Resolver {
cdc := codec.Intermediate{}
opts.setDefaults()
return ws{
dialOptions: opts.DialOptions,
protocol: transport.NewProtocol(func() transport.Codec { return codec.NoHeader{Codec: cdc} }),
tag: cdc.ObfuscatedTag(),
rand: opts.Rand,
}
}
+6
View File
@@ -0,0 +1,6 @@
package telegram
import "go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/manager"
// DeviceConfig is config which send when Telegram connection session created.
type DeviceConfig = manager.DeviceConfig
+2
View File
@@ -0,0 +1,2 @@
// Package telegram implements Telegram client.
package telegram
+87
View File
@@ -0,0 +1,87 @@
package downloader
import (
"context"
"io"
"os"
"path/filepath"
"github.com/go-faster/errors"
"go.uber.org/multierr"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Builder is a download builder.
type Builder struct {
downloader *Downloader
schema schema
hashes []tg.FileHash
verify bool
threads int
}
func newBuilder(downloader *Downloader, schema schema) *Builder {
return &Builder{
schema: schema,
threads: 1,
downloader: downloader,
}
}
// WithThreads sets downloading goroutines limit.
func (b *Builder) WithThreads(threads int) *Builder {
if threads > 0 {
b.threads = threads
}
return b
}
// WithVerify sets verify parameter.
// If verify is true, file hashes will be checked
// Verify is true by default for CDN downloads.
func (b *Builder) WithVerify(verify bool) *Builder {
b.verify = verify
return b
}
func (b *Builder) reader() *reader {
if b.verify {
return verifiedReader(b.schema, newVerifier(b.schema, b.hashes...))
}
return plainReader(b.schema, b.downloader.partSize)
}
// Stream downloads file to given io.Writer.
// NB: in this mode download can't be parallel.
func (b *Builder) Stream(ctx context.Context, output io.Writer) (tg.StorageFileTypeClass, error) {
return b.downloader.stream(ctx, b.reader(), output)
}
// StreamToReader streams a file to the returned [io.Reader].
// NB: in this mode download can't be parallel.
func (b *Builder) StreamToReader(ctx context.Context) (tg.StorageFileTypeClass, io.Reader, error) {
var tgDC int
ctx = context.WithValue(ctx, "tg_dc", &tgDC)
return b.downloader.streamToReader(ctx, b.reader())
}
// Parallel downloads file to given io.WriterAt.
func (b *Builder) Parallel(ctx context.Context, output io.WriterAt) (tg.StorageFileTypeClass, error) {
return b.downloader.parallel(ctx, b.reader(), b.threads, output)
}
// ToPath downloads file to given path.
func (b *Builder) ToPath(ctx context.Context, path string) (_ tg.StorageFileTypeClass, err error) {
f, err := os.Create(filepath.Clean(path))
if err != nil {
return nil, errors.Wrap(err, "create output file")
}
defer func() {
multierr.AppendInto(&err, f.Close())
}()
return b.Parallel(ctx, f)
}
+98
View File
@@ -0,0 +1,98 @@
package downloader
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// ExpiredTokenError error is returned when Downloader get expired file token for CDN.
// See https://core.telegram.org/constructor/upload.fileCdnRedirect.
type ExpiredTokenError struct {
*tg.UploadCDNFileReuploadNeeded
}
// Error implements error interface.
func (r *ExpiredTokenError) Error() string {
return "redirect to master DC for requesting new file token"
}
// cdn is a CDN DC download schema.
// See https://core.telegram.org/cdn#getting-files-from-a-cdn.
type cdn struct {
cdn CDN
client Client
pool *bin.Pool
redirect *tg.UploadFileCDNRedirect
}
var _ schema = cdn{}
// decrypt decrypts file chunk from Telegram CDN.
// See https://core.telegram.org/cdn#decrypting-files.
func (c cdn) decrypt(src []byte, offset int64) ([]byte, error) {
block, err := aes.NewCipher(c.redirect.EncryptionKey)
if err != nil {
return nil, errors.Wrap(err, "create cipher")
}
if block.BlockSize() != len(c.redirect.EncryptionIv) {
return nil, errors.Errorf(
"invalid IV or key length, block size %d != IV %d",
block.BlockSize(), len(c.redirect.EncryptionIv),
)
}
// Copy IV to buffer from Pool.
iv := c.pool.GetSize(len(c.redirect.EncryptionIv))
defer c.pool.Put(iv)
copy(iv.Buf, c.redirect.EncryptionIv)
// For IV, it should use the value of encryption_iv, modified in the following manner:
// for each offset replace the last 4 bytes of the encryption_iv with offset / 16 in big-endian.
binary.BigEndian.PutUint32(iv.Buf[iv.Len()-4:], uint32(offset/16))
dst := make([]byte, len(src))
cipher.NewCTR(block, iv.Buf).XORKeyStream(dst, src)
return dst, nil
}
func (c cdn) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
r, err := c.cdn.UploadGetCDNFile(ctx, &tg.UploadGetCDNFileRequest{
Offset: offset,
Limit: limit,
FileToken: c.redirect.FileToken,
})
if err != nil {
return chunk{}, err
}
switch result := r.(type) {
case *tg.UploadCDNFile:
data, err := c.decrypt(result.Bytes, offset)
if err != nil {
return chunk{}, err
}
return chunk{
data: data,
}, nil
case *tg.UploadCDNFileReuploadNeeded:
return chunk{}, &ExpiredTokenError{UploadCDNFileReuploadNeeded: result}
default:
return chunk{}, errors.Errorf("unexpected type %T", r)
}
}
func (c cdn) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
return c.client.UploadGetCDNFileHashes(ctx, &tg.UploadGetCDNFileHashesRequest{
FileToken: c.redirect.FileToken,
Offset: offset,
})
}
+36
View File
@@ -0,0 +1,36 @@
package downloader
import (
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func Test_cdn_decrypt(t *testing.T) {
testdata := make([]byte, 32)
tests := []struct {
name string
key, iv []byte
err bool
}{
{"Bad key", []byte{10}, nil, true},
{"Bad IV", make([]byte, 32), nil, true},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
c := &cdn{
redirect: &tg.UploadFileCDNRedirect{
EncryptionKey: test.key,
EncryptionIv: test.iv,
},
}
_, err := c.decrypt(testdata, 0)
if test.err {
require.Error(t, err)
}
})
}
}
+34
View File
@@ -0,0 +1,34 @@
package downloader
import (
"context"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// CDN represents Telegram RPC client to CDN server.
type CDN interface {
UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error)
}
// Client represents Telegram RPC client.
type Client interface {
UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error)
UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error)
UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error)
UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error)
UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error)
}
type chunk struct {
data []byte
tag tg.StorageFileTypeClass
}
// schema is simple interface for different download schemas.
type schema interface {
Chunk(ctx context.Context, offset int64, limit int) (chunk, error)
Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error)
}
@@ -0,0 +1,48 @@
// Package downloader contains downloading files helpers.
package downloader
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Downloader is Telegram file downloader.
type Downloader struct {
partSize int
pool *bin.Pool
}
const defaultPartSize = 512 * 1024 // 512 kb
// NewDownloader creates new Downloader.
func NewDownloader() *Downloader {
return new(Downloader).WithPartSize(defaultPartSize)
}
// WithPartSize sets chunk size.
// Must be divisible by 4KB.
//
// See https://core.telegram.org/api/files#downloading-files.
func (d *Downloader) WithPartSize(partSize int) *Downloader {
d.partSize = partSize
d.pool = bin.NewPool(partSize)
return d
}
// Download creates Builder for plain downloads.
func (d *Downloader) Download(rpc Client, location tg.InputFileLocationClass) *Builder {
return newBuilder(d, master{
client: rpc,
precise: true,
allowCDN: false,
location: location,
})
}
// Web creates Builder for web files downloads.
func (d *Downloader) Web(rpc Client, location tg.InputWebFileLocationClass) *Builder {
return newBuilder(d, web{
client: rpc,
location: location,
})
}
@@ -0,0 +1,300 @@
package downloader
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"io"
"runtime"
"strconv"
"testing"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
"go.mau.fi/mautrix-telegram/pkg/gotd/testutil"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type mock struct {
data []byte
hashes mockHashes
migrate bool
err bool
hashesErr bool
redirect *tg.UploadFileCDNRedirect
}
var testErr = testutil.TestError()
func (m mock) getPart(offset int64, limit int) []byte {
length := len(m.data)
if offset >= int64(length) {
return []byte{}
}
size := length - int(offset)
if size > limit {
size = limit
}
r := make([]byte, size)
copy(r, m.data[offset:])
return r
}
func (m mock) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) {
if m.err {
return nil, testErr
}
if m.migrate {
return m.redirect, nil
}
return &tg.UploadFile{
Bytes: m.getPart(request.Offset, request.Limit),
}, nil
}
func (m mock) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error) {
if m.hashesErr {
return nil, testErr
}
return m.hashes.Hashes(ctx, request.Offset)
}
func (m mock) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) {
panic("implement me")
}
func (m mock) UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error) {
if m.err {
return nil, testErr
}
if m.migrate {
return &tg.UploadCDNFileReuploadNeeded{
RequestToken: []byte{1, 2, 3},
}, nil
}
block, err := aes.NewCipher(m.redirect.EncryptionKey)
if err != nil {
return nil, errors.Wrap(err, "CDN mock cipher creation")
}
iv := make([]byte, len(m.redirect.EncryptionIv))
copy(iv, m.redirect.EncryptionIv)
binary.BigEndian.PutUint32(iv[len(iv)-4:], uint32(request.Offset/16))
part := m.getPart(request.Offset, request.Limit)
r := make([]byte, len(part))
cipher.NewCTR(block, iv).XORKeyStream(r, part)
return &tg.UploadCDNFile{
Bytes: r,
}, nil
}
func (m mock) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error) {
if m.hashesErr {
return nil, testErr
}
return m.hashes.Hashes(ctx, request.Offset)
}
func (m mock) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) {
if m.err {
return nil, testErr
}
return &tg.UploadWebFile{
Bytes: m.getPart(int64(request.Offset), request.Limit),
}, nil
}
func countHashes(data []byte, partSize int) (r [][]tg.FileHash) {
actions := data
batchSize := partSize
batches := make([][]byte, 0, (len(actions)+batchSize-1)/batchSize)
for batchSize < len(actions) {
actions, batches = actions[batchSize:], append(batches, actions[0:batchSize:batchSize])
}
batches = append(batches, actions)
currentRange := make([]tg.FileHash, 0, 10)
offset := 0
for _, batch := range batches {
if len(currentRange) >= 10 {
r = append(r, currentRange)
currentRange = make([]tg.FileHash, 0, 10)
}
currentRange = append(currentRange, tg.FileHash{
Offset: int64(offset),
Limit: partSize,
Hash: crypto.SHA256(batch),
})
offset += len(batch)
if len(batch) < partSize {
break
}
}
r = append(r, currentRange)
return
}
func Test_countHashes(t *testing.T) {
a := require.New(t)
data := bytes.Repeat([]byte{1, 2, 3, 4, 5}, 10)
hashes := countHashes(data, 4)
a.NotEmpty(hashes)
for _, hashRange := range hashes {
for _, hash := range hashRange {
from := hash.Offset
to := int(hash.Offset) + hash.Limit
if to > len(data) {
to = len(data)
}
a.Equal(crypto.SHA256(data[from:to]), hash.Hash)
}
}
}
func TestDownloader(t *testing.T) {
ctx := context.Background()
key := make([]byte, 32)
iv := make([]byte, aes.BlockSize)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
t.Fatal(err)
}
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
t.Fatal(err)
}
redirect := &tg.UploadFileCDNRedirect{
DCID: 1,
FileToken: []byte{10},
EncryptionKey: key,
EncryptionIv: iv,
}
testData := make([]byte, defaultPartSize*2)
if _, err := io.ReadFull(rand.Reader, testData); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
data []byte
migrate bool
err bool
hashesErr bool
}{
{"5b", []byte{1, 2, 3, 4, 5}, false, false, false},
{strconv.Itoa(len(testData)) + "b", testData, false, false, false},
{"Error", []byte{}, false, true, false},
{"HashesError", []byte{}, false, true, true},
{"Migrate", []byte{}, true, false, false},
}
schemas := []struct {
name string
creator func(c Client, cdn CDN) *Builder
}{
{"Master", func(c Client, cdn CDN) *Builder {
return NewDownloader().Download(c, nil)
}},
{"Web", func(c Client, cdn CDN) *Builder {
return NewDownloader().Web(c, nil)
}},
}
ways := []struct {
name string
action func(b *Builder) ([]byte, error)
}{
{"Stream", func(b *Builder) ([]byte, error) {
output := new(bytes.Buffer)
_, err := b.Stream(ctx, output)
return output.Bytes(), err
}},
{"Parallel", func(b *Builder) ([]byte, error) {
output := new(syncio.BufWriterAt)
_, err := b.WithThreads(runtime.GOMAXPROCS(0)).Parallel(ctx, output)
return output.Bytes(), err
}},
{"Parallel-OneThread", func(b *Builder) ([]byte, error) {
output := new(syncio.BufWriterAt)
_, err := b.WithThreads(1).Parallel(ctx, output)
return output.Bytes(), err
}},
}
options := []struct {
name string
action func(b *Builder) *Builder
}{
{"NoVerify", func(b *Builder) *Builder {
return b.WithVerify(false)
}},
{"Verify", func(b *Builder) *Builder {
return b.WithVerify(true)
}},
}
for _, schema := range schemas {
t.Run(schema.name, func(t *testing.T) {
for _, test := range tests {
// Telegram can't redirect web file downloads.
if schema.name == "Web" && test.migrate {
continue
}
t.Run(test.name, func(t *testing.T) {
for _, option := range options {
// Telegram can't return hashes for web files.
if schema.name == "Web" && option.name == "Verify" {
continue
}
t.Run(option.name, func(t *testing.T) {
for _, way := range ways {
t.Run(way.name, func(t *testing.T) {
a := require.New(t)
client := &mock{
data: test.data,
hashes: mockHashes{
ranges: countHashes(test.data, 128*1024),
},
migrate: test.migrate,
err: test.err,
redirect: redirect,
}
b := schema.creator(client, client)
b = option.action(b)
data, err := way.action(b)
switch {
case test.migrate:
a.Error(err)
case test.err:
a.Error(err)
default:
a.NoError(err)
a.True(bytes.Equal(test.data, data))
}
})
}
})
}
})
}
})
}
}
+64
View File
@@ -0,0 +1,64 @@
package downloader
import (
"context"
"strconv"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// RedirectError error is returned when Downloader get CDN redirect.
// See https://core.telegram.org/constructor/upload.fileCdnRedirect.
type RedirectError struct {
Redirect *tg.UploadFileCDNRedirect
}
// Error implements error interface.
func (r *RedirectError) Error() string {
return "redirect to CDN DC " + strconv.Itoa(r.Redirect.DCID)
}
// master is a master DC download schema.
// See https://core.telegram.org/api/files#downloading-files.
type master struct {
client Client
precise bool
allowCDN bool
location tg.InputFileLocationClass
}
var _ schema = master{}
func (c master) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: limit,
Location: c.location,
}
req.SetCDNSupported(c.allowCDN)
req.SetPrecise(c.precise)
r, err := c.client.UploadGetFile(ctx, req)
if err != nil {
return chunk{}, err
}
switch result := r.(type) {
case *tg.UploadFile:
return chunk{data: result.Bytes, tag: result.Type}, nil
case *tg.UploadFileCDNRedirect:
return chunk{}, &RedirectError{Redirect: result}
default:
return chunk{}, errors.Errorf("unexpected type %T", r)
}
}
func (c master) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
return c.client.UploadGetFileHashes(ctx, &tg.UploadGetFileHashesRequest{
Location: c.location,
Offset: offset,
})
}
+83
View File
@@ -0,0 +1,83 @@
package downloader
import (
"context"
"io"
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/syncio"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// nolint:gocognit
func (d *Downloader) parallel(
ctx context.Context, r *reader,
threads int, w io.WriterAt,
) (tg.StorageFileTypeClass, error) {
var typ tg.StorageFileTypeClass
typOnce := &sync.Once{}
ready := tdsync.NewReady()
g := tdsync.NewCancellableGroup(ctx)
toWrite := make(chan block, threads)
stop := func(t tg.StorageFileTypeClass) {
typOnce.Do(func() {
typ = t
})
ready.Signal()
}
// Download loop
g.Go(func(ctx context.Context) error {
downloads := tdsync.NewCancellableGroup(ctx)
defer close(toWrite)
for i := 0; i < threads; i++ {
downloads.Go(func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ready.Ready():
return nil
default:
}
b, err := r.Next(ctx)
if err != nil {
return errors.Wrap(err, "get file")
}
// If returned chunk is zero, that means we read all file.
n := len(b.data)
if n < 1 {
stop(b.tag)
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case toWrite <- b:
}
if b.last() {
stop(b.tag)
return nil
}
}
})
}
return downloads.Wait()
})
// Write loop
g.Go(writeAtLoop(syncio.NewWriterAt(w), toWrite))
return typ, g.Wait()
}
+107
View File
@@ -0,0 +1,107 @@
package downloader
import (
"context"
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
type block struct {
chunk
offset int64
partSize int
}
// last compares partSize and chunk length to determine last part.
func (b block) last() bool {
// If returned chunk is smaller than requested part, it seems
// it is last part.
return len(b.data) < b.partSize
}
type reader struct {
sch schema // immutable
verifier *verifier // immutable
partSize int // immutable
offset int64
offsetMux sync.Mutex
}
func verifiedReader(sch schema, verifier *verifier) *reader {
return &reader{
sch: sch,
verifier: verifier,
}
}
func plainReader(sch schema, partSize int) *reader {
return &reader{
sch: sch,
partSize: partSize,
}
}
func (r *reader) Next(ctx context.Context) (block, error) {
if r.verifier != nil {
return r.nextHashed(ctx)
}
return r.nextPlain(ctx)
}
func (r *reader) nextHashed(ctx context.Context) (block, error) {
// Fetch next hashes.
hash, ok, err := r.verifier.next(ctx)
if err != nil {
return block{}, err
}
if !ok {
return block{}, nil
}
// Get next chunk.
b, err := r.next(ctx, hash.Offset, hash.Limit)
if err != nil {
return block{}, err
}
// Verify chunk.
if !r.verifier.verify(hash, b.data) {
return block{}, ErrHashMismatch
}
return b, nil
}
func (r *reader) nextPlain(ctx context.Context) (block, error) {
r.offsetMux.Lock()
offset := r.offset
r.offset += int64(r.partSize)
r.offsetMux.Unlock()
return r.next(ctx, offset, r.partSize)
}
func (r *reader) next(ctx context.Context, offset int64, limit int) (block, error) {
for {
ch, err := r.sch.Chunk(ctx, offset, limit)
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
if flood || tgerr.Is(err, tg.ErrTimeout) {
continue
}
return block{}, errors.Wrap(err, "get next chunk")
}
return block{
chunk: ch,
offset: offset,
partSize: r.partSize,
}, nil
}
}
+48
View File
@@ -0,0 +1,48 @@
package downloader
import (
"context"
"io"
"github.com/go-faster/errors"
)
func writeAtLoop(w io.WriterAt, toWrite <-chan block) func(context.Context) error {
return func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case part, ok := <-toWrite:
if !ok {
return nil
}
_, err := w.WriteAt(part.data, part.offset)
if err != nil {
return errors.Wrap(err, "write output")
}
}
}
}
}
func writeLoop(w io.Writer, toWrite <-chan block) func(context.Context) error {
return func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case part, ok := <-toWrite:
if !ok {
return nil
}
_, err := w.Write(part.data)
if err != nil {
return errors.Wrap(err, "write output")
}
}
}
}
}
+54
View File
@@ -0,0 +1,54 @@
package downloader
import (
"context"
"io"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func (d *Downloader) stream(ctx context.Context, r *reader, w io.Writer) (tg.StorageFileTypeClass, error) {
var typ tg.StorageFileTypeClass
g := tdsync.NewCancellableGroup(ctx)
toWrite := make(chan block, 1)
stop := func(t tg.StorageFileTypeClass) {
typ = t
close(toWrite)
}
// Download loop
g.Go(func(ctx context.Context) error {
for {
b, err := r.Next(ctx)
if err != nil {
return errors.Wrap(err, "get file")
}
n := len(b.data)
if n < 1 {
stop(b.tag)
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case toWrite <- b:
}
if b.last() {
stop(b.tag)
return nil
}
}
})
// Write loop
g.Go(writeLoop(w, toWrite))
return typ, g.Wait()
}
@@ -0,0 +1,49 @@
package downloader
import (
"context"
"io"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type streamReader struct {
ctx context.Context
reader *reader
curBlock block
last bool
}
var _ io.Reader = (*streamReader)(nil)
func (s *streamReader) Read(p []byte) (n int, err error) {
select {
case <-s.ctx.Done():
return 0, s.ctx.Err()
default:
}
if len(s.curBlock.data) == 0 {
if s.last {
return 0, io.EOF
} else {
s.curBlock, err = s.reader.Next(s.ctx)
if err != nil {
return 0, err
}
s.last = s.curBlock.last()
}
}
n = copy(p, s.curBlock.data)
s.curBlock.data = s.curBlock.data[n:]
return
}
func (d *Downloader) streamToReader(ctx context.Context, r *reader) (tg.StorageFileTypeClass, io.Reader, error) {
first, err := r.Next(ctx)
if err != nil {
return nil, nil, err
}
return first.tag, &streamReader{ctx, r, first, first.last()}, nil
}
+104
View File
@@ -0,0 +1,104 @@
package downloader
import (
"bytes"
"context"
"sort"
"sync"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
// ErrHashMismatch means that download hash verification was failed.
var ErrHashMismatch = errors.New("file hash mismatch")
type verifier struct {
client schema
hashes []tg.FileHash
offset int64
mux sync.Mutex
}
func newVerifier(client schema, hashes ...tg.FileHash) *verifier {
r := make([]tg.FileHash, len(hashes))
copy(r, hashes)
sort.SliceStable(r, func(i, j int) bool {
return r[i].Offset < r[j].Offset
})
return &verifier{client: client, hashes: r}
}
func (v *verifier) pop() (tg.FileHash, bool) {
if len(v.hashes) < 1 {
return tg.FileHash{}, false
}
// Pop and move.
hash := v.hashes[0]
copy(v.hashes, v.hashes[1:])
v.hashes[len(v.hashes)-1] = tg.FileHash{}
v.hashes = v.hashes[:len(v.hashes)-1]
return hash, true
}
func (v *verifier) update(hashes ...tg.FileHash) (tg.FileHash, bool) {
// If result is empty and queue is empty, so we can't return next hash.
if len(hashes) < 1 {
return tg.FileHash{}, false
}
// Sort hashes by offset.
// Usually Telegram server returns sorted parts, but...
// you never known what can they do.
sort.SliceStable(hashes, func(i, j int) bool {
return hashes[i].Offset < hashes[j].Offset
})
last := hashes[len(hashes)-1]
// Check if we have reached the end.
// If current state offset is equal the last offset + limit (right border)
// then we got all hashes.
if last.Offset == v.offset-int64(last.Limit) {
return tg.FileHash{}, false
}
// Otherwise, we update current offset and add hashes to the end of queue.
v.offset = last.Offset + int64(last.Limit)
v.hashes = append(v.hashes, hashes...)
return v.pop()
}
func (v *verifier) next(ctx context.Context) (tg.FileHash, bool, error) {
v.mux.Lock()
defer v.mux.Unlock()
hash, ok := v.pop()
if ok {
return hash, ok, nil
}
for {
hashes, err := v.client.Hashes(ctx, v.offset)
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
if flood || tgerr.Is(err, tg.ErrTimeout) {
continue
}
return tg.FileHash{}, false, errors.Wrap(err, "get hashes")
}
hash, ok = v.update(hashes...)
return hash, ok, nil
}
}
func (v *verifier) verify(hash tg.FileHash, data []byte) bool {
return bytes.Equal(crypto.SHA256(data), hash.Hash)
}
@@ -0,0 +1,104 @@
package downloader
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
var hashRanges = [][]tg.FileHash{
{
tg.FileHash{Offset: 0, Limit: 131072},
tg.FileHash{Offset: 131072, Limit: 131072},
tg.FileHash{Offset: 262144, Limit: 131072},
tg.FileHash{Offset: 393216, Limit: 131072},
tg.FileHash{Offset: 524288, Limit: 131072},
tg.FileHash{Offset: 655360, Limit: 131072},
tg.FileHash{Offset: 786432, Limit: 131072},
tg.FileHash{Offset: 917504, Limit: 131072},
}, {
tg.FileHash{Offset: 1048576, Limit: 131072},
tg.FileHash{Offset: 1179648, Limit: 131072},
tg.FileHash{Offset: 1310720, Limit: 131072},
tg.FileHash{Offset: 1441792, Limit: 131072},
tg.FileHash{Offset: 1572864, Limit: 131072},
tg.FileHash{Offset: 1703936, Limit: 131072},
tg.FileHash{Offset: 1835008, Limit: 131072},
tg.FileHash{Offset: 1966080, Limit: 131072},
}, {
tg.FileHash{Offset: 2097152, Limit: 131072},
tg.FileHash{Offset: 2228224, Limit: 131072},
tg.FileHash{Offset: 2359296, Limit: 131072},
tg.FileHash{Offset: 2490368, Limit: 131072},
tg.FileHash{Offset: 2621440, Limit: 131072},
tg.FileHash{Offset: 2752512, Limit: 131072},
tg.FileHash{Offset: 2883584, Limit: 131072},
tg.FileHash{Offset: 3014656, Limit: 131072},
},
}
type mockHashes struct {
ranges [][]tg.FileHash
}
func (m mockHashes) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
panic("implement me")
}
func (m mockHashes) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
for _, r := range m.ranges {
last := r[len(r)-1]
if last.Offset+int64(last.Limit) <= offset {
continue
}
return r, nil
}
return m.ranges[len(m.ranges)-1], nil
}
func TestVerifier(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
ranges [][]tg.FileHash
// Hashes returned from CDN redirect, for example.
predefined []tg.FileHash
expected [][]tg.FileHash
}{
{"NoPredefined", hashRanges, nil, hashRanges},
{"Predefined", hashRanges[1:], hashRanges[0], hashRanges},
{"OnlyPredefined", hashRanges[:1], hashRanges[0], hashRanges[:1]},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
a := require.New(t)
m := mockHashes{ranges: test.ranges}
v := newVerifier(m, test.predefined...)
hashes := make([]tg.FileHash, 0, len(test.predefined))
for {
hash, ok, err := v.next(ctx)
a.NoError(err)
if !ok {
break
}
hashes = append(hashes, hash)
}
i := 0
for _, hashRange := range test.expected {
for _, expected := range hashRange {
a.Equal(expected, hashes[i])
i++
}
}
})
}
}
+38
View File
@@ -0,0 +1,38 @@
package downloader
import (
"context"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
var errHashesNotSupported = errors.New("this schema does not support hashes fetch")
// web is a web file download schema.
// See https://core.telegram.org/api/files#downloading-webfiles.
type web struct {
client Client
location tg.InputWebFileLocationClass
}
var _ schema = web{}
func (w web) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) {
file, err := w.client.UploadGetWebFile(ctx, &tg.UploadGetWebFileRequest{
Location: w.location,
Offset: int(offset),
Limit: limit,
})
if err != nil {
return chunk{}, err
}
return chunk{data: file.Bytes, tag: file.FileType}, nil
}
func (w web) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) {
return nil, errHashesNotSupported
}
+15
View File
@@ -0,0 +1,15 @@
package telegram
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
// ErrFloodWait is error type of "FLOOD_WAIT" error.
const ErrFloodWait = tgerr.ErrFloodWait
// AsFloodWait returns wait duration and true boolean if err is
// the "FLOOD_WAIT" error.
//
// Client should wait for that duration before issuing new requests with
// same method.
var AsFloodWait = tgerr.AsFloodWait
+22
View File
@@ -0,0 +1,22 @@
package telegram_test
import (
"testing"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/assert"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
func TestAsFloodWait(t *testing.T) {
err := func() error {
return errors.Wrap(tgerr.New(400, "FLOOD_WAIT_10"), "perform operation")
}()
d, ok := telegram.AsFloodWait(err)
assert.True(t, ok)
assert.Equal(t, time.Second*10, d)
}
+47
View File
@@ -0,0 +1,47 @@
package telegram
import (
"fmt"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func (c *Client) updateInterceptor(updates ...tg.UpdateClass) {
for _, update := range updates {
switch update.(type) {
case *tg.UpdateConfig, *tg.UpdateDCOptions:
c.fetchConfig(c.ctx)
}
}
}
func (c *Client) processUpdates(updates tg.UpdatesClass) error {
switch u := updates.(type) {
case *tg.Updates:
c.updateInterceptor(u.Updates...)
return c.updateHandler.Handle(c.ctx, u)
case *tg.UpdatesCombined:
c.updateInterceptor(u.Updates...)
return c.updateHandler.Handle(c.ctx, u)
case *tg.UpdateShort:
c.updateInterceptor(u.Update)
return c.updateHandler.Handle(c.ctx, u)
case *tg.UpdateShortMessage, *tg.UpdateShortChatMessage, *tg.UpdateShortSentMessage, *tg.UpdatesTooLong:
return c.updateHandler.Handle(c.ctx, u)
default:
c.log.Warn("Ignoring update", zap.String("update_type", fmt.Sprintf("%T", u)))
}
return nil
}
func (c *Client) handleUpdates(b *bin.Buffer) error {
updates, err := tg.DecodeUpdates(b)
if err != nil {
return errors.Wrap(err, "decode updates")
}
return c.processUpdates(updates)
}
+40
View File
@@ -0,0 +1,40 @@
package telegram
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type mockHandler struct {
LastUpdate tg.UpdatesClass
}
func (m *mockHandler) Handle(ctx context.Context, u tg.UpdatesClass) error {
m.LastUpdate = u
return nil
}
func TestClient_processUpdates(t *testing.T) {
msg := &tg.Message{
ID: 1,
}
upd := &tg.Updates{
Updates: []tg.UpdateClass{&tg.UpdateNewMessage{
Message: msg,
}},
}
t.Run("Handle", func(t *testing.T) {
mock := &mockHandler{}
c := new(Client)
c.updateHandler = mock
err := c.processUpdates(upd)
require.NoError(t, err)
require.Equal(t, upd, mock.LastUpdate)
})
}
@@ -0,0 +1,191 @@
// Package deeplink contains deeplink parsing helpers.
package deeplink
import (
"net/url"
"strings"
"github.com/go-faster/errors"
)
// Type is an enum type of Telegram deeplinks types.
type Type string
const (
// Resolve is deeplink like
//
// tg:resolve?domain={domain}
// tg://resolve?domain={domain}
// https://t.me/{domain}
// https://telegram.me/{domain}
//
Resolve Type = "resolve"
// Join is deeplink like
//
// tg:join?invite={hash}
// tg://join?invite={hash}
// https://t.me/joinchat/{hash}
// https://telegram.me/joinchat/{hash}
// t.me/+{hash}
//
Join Type = "join"
)
// DeepLink represents Telegram deeplink.
type DeepLink struct {
Type Type
Args url.Values
}
func ensureParam(query url.Values, key string) error {
if query.Get(key) == "" {
return errors.Errorf("should have %q query parameter", key)
}
return nil
}
func (d DeepLink) validate() error {
switch d.Type {
case Resolve:
return ensureParam(d.Args, "domain")
case Join:
return ensureParam(d.Args, "invite")
default:
return errors.Errorf("unsupported deeplink %q", d.Type)
}
}
func parseTg(u *url.URL) (DeepLink, error) {
query := u.Query()
switch Type(u.Hostname()) {
case Resolve:
return DeepLink{
Type: Resolve,
Args: query,
}, nil
case Join:
return DeepLink{
Type: Join,
Args: query,
}, nil
}
return DeepLink{}, errors.Errorf("unsupported deeplink %q", u.String())
}
func parseHTTPS(u *url.URL) (DeepLink, error) {
cleanInviteHash := func(root string) string {
hash := strings.Trim(root, "+ ")
if u.RawPath == "" {
hash = url.PathEscape(hash)
}
return hash
}
query := url.Values{}
p := strings.TrimPrefix(u.Path, "/")
p = strings.TrimSuffix(p, "/")
split := strings.Split(p, "/")
var (
root = split[0]
base string
)
if len(split) > 1 {
base = split[1]
}
switch root {
case "joinchat":
query.Set("invite", cleanInviteHash(base))
return DeepLink{
Type: Join,
Args: query,
}, nil
case "":
return DeepLink{}, errors.Errorf("unsupported deeplink %q", u.String())
}
switch root[0] {
case ' ', '+':
query.Set("invite", cleanInviteHash(root))
return DeepLink{
Type: Join,
Args: query,
}, nil
default:
if err := ValidateDomain(root); err != nil {
return DeepLink{}, err
}
query.Set("domain", root)
return DeepLink{
Type: Resolve,
Args: query,
}, nil
}
}
func hasTelegramPrefix(link string) bool {
return strings.HasPrefix(link, "t.me") ||
strings.HasPrefix(link, "telegram.me") ||
strings.HasPrefix(link, "telegram.dog")
}
// IsDeeplinkLike returns true if string may be a valid deeplink.
func IsDeeplinkLike(link string) bool {
return strings.HasPrefix(link, "tg:") ||
hasTelegramPrefix(link) ||
strings.HasPrefix(link, "https://")
}
// Parse parses and returns deeplink.
func Parse(link string) (DeepLink, error) {
switch {
// Normalize case like t.me/gotd.
case hasTelegramPrefix(link):
link = strings.TrimSuffix("https://"+link, "/")
// Normalize case like tg:resolve?domain=gotd.
case !strings.HasPrefix(link, "tg://") && strings.HasPrefix(link, "tg:"):
link = "tg://" + strings.TrimPrefix(link, "tg:")
}
u, err := url.Parse(link)
if err != nil {
return DeepLink{}, errors.Wrapf(err, "invalid URL %q", link)
}
var d DeepLink
switch {
case u.Scheme == "https":
switch strings.TrimPrefix(u.Hostname(), "www.") {
case "t.me", "telegram.me", "telegram.dog":
d, err = parseHTTPS(u)
default:
return DeepLink{}, errors.Errorf("invalid domain %q", link)
}
case u.Scheme == "tg":
d, err = parseTg(u)
default:
return DeepLink{}, errors.Errorf("invalid deeplink %q", link)
}
if err != nil {
return DeepLink{}, err
}
if err := d.validate(); err != nil {
return DeepLink{}, err
}
return d, nil
}
// Expect parses deeplink and check type its type.
func Expect(link string, typ Type) (DeepLink, error) {
l, err := Parse(link)
if err != nil {
return l, err
}
if l.Type != typ {
return l, errors.Errorf("unexpected deeplink type %q", l.Type)
}
return l, nil
}
@@ -0,0 +1,28 @@
//go:build go1.18
package deeplink
import (
"testing"
)
func addSuites(f *testing.F, suites map[string][]testCase) {
for _, suite := range suites {
for _, test := range suite {
f.Add(test.input)
}
}
}
func FuzzParse(f *testing.F) {
for _, typeSuite := range typeSuites {
addSuites(f, typeSuite)
}
f.Fuzz(func(t *testing.T, link string) {
_, err := Parse(link)
if err != nil {
t.Skip(err)
}
})
}
@@ -0,0 +1,158 @@
package deeplink
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
type testCase struct {
link DeepLink
input string
wantErr bool
}
func join(arg string) DeepLink {
return DeepLink{
Type: Join,
Args: map[string][]string{
"invite": {arg},
},
}
}
func resolve(arg string) DeepLink {
return DeepLink{
Type: Resolve,
Args: map[string][]string{
"domain": {arg},
},
}
}
func joinSuite() map[string][]testCase {
expect := join("AAAAAAAAAAAAAAAAAA")
return map[string][]testCase{
"Test": {
{expect, `t.me/joinchat/AAAAAAAAAAAAAAAAAA`, false},
{expect, `t.me/joinchat/AAAAAAAAAAAAAAAAAA/`, false},
{expect, `t.me/+AAAAAAAAAAAAAAAAAA`, false},
{expect, `t.me/+AAAAAAAAAAAAAAAAAA/`, false},
{expect, `t.me/ +AAAAAAAAAAAAAAAAAA/`, false},
{expect, `https://t.me/joinchat/AAAAAAAAAAAAAAAAAA`, false},
{expect, `https://t.me/joinchat/AAAAAAAAAAAAAAAAAA/`, false},
{expect, `tg:join?invite=AAAAAAAAAAAAAAAAAA`, false},
{expect, `tg://join?invite=AAAAAAAAAAAAAAAAAA`, false},
{DeepLink{}, `https://t.co/joinchat/AAAAAAAAAAAAAAAAAA`, true},
{DeepLink{}, `rt://join?invite=AAAAAAAAAAAAAAAAAA`, true},
},
"TDLib": {
// t.me/+<hash>
// Positive
{join("aba%20aba"), "t.me/+aba%20aba", false},
{join("aba0aba"), "t.me/+aba%30aba", false},
{join("123456a"), "t.me/+123456a", false},
{join("12345678901"), "t.me/%2012345678901", false},
// Negative
{DeepLink{}, "t.me/+?invite=abcdef", true},
{DeepLink{}, "t.me/+", true},
{DeepLink{}, "t.me/+/abcdef", true},
{DeepLink{}, "t.me/ ?/abcdef", true},
{DeepLink{}, "t.me/+?abcdef", true},
{DeepLink{}, "t.me/+#abcdef", true},
{DeepLink{}, "t.me/ /123456/123123/12/31/a/s//21w/?asdas#test", true},
// t.me/joinchat/<hash>
// Positive
{join("abacaba"), "t.me/joinchat/abacaba", false},
{join("aba%20aba"), "t.me/joinchat/aba%20aba", false},
{join("aba0aba"), "t.me/joinchat/aba%30aba", false},
{join("123456a"), "t.me/joinchat/123456a", false},
{join("12345678901"), "t.me/joinchat/12345678901", false},
{join("123456"), "t.me/joinchat/123456", false},
{join("123456"), "t.me/joinchat/123456/123123/12/31/a/s//21w/?asdas#test", false},
// Negative
{DeepLink{}, "t.me/joinchat?invite=abcdef", true},
{DeepLink{}, "t.me/joinchat", true},
{DeepLink{}, "t.me/joinchat/", true},
{DeepLink{}, "t.me/joinchat//abcdef", true},
{DeepLink{}, "t.me/joinchat?/abcdef", true},
{DeepLink{}, "t.me/joinchat/?abcdef", true},
{DeepLink{}, "t.me/joinchat/#abcdef", true},
},
}
}
func resolveSuite() map[string][]testCase {
expect := resolve("gotd_ru")
return map[string][]testCase{
"Test": {
{expect, `t.me/gotd_ru`, false},
{expect, `t.me/gotd_ru/`, false},
{expect, `https://t.me/gotd_ru`, false},
{expect, `https://t.me/gotd_ru/`, false},
{expect, `tg:resolve?domain=gotd_ru`, false},
{expect, `tg:resolve?&&&&&&&domain=gotd_ru`, false},
{expect, `tg://resolve?domain=gotd_ru`, false},
{DeepLink{}, `https://t.co/gotd_ru`, true},
{DeepLink{}, `rt://join?invite=AAAAAAAAAAAAAAAAAA`, true},
},
"TDLib": {
// t.me/<domain>
// Positive
{resolve("a"), "t.me/a", false},
{resolve("abcdefghijklmnopqrstuvwxyz123456"), "t.me/abcdefghijklmnopqrstuvwxyz123456", false},
{resolve("Aasdf"), "t.me/Aasdf", false},
{resolve("asdf0"), "t.me/asdf0", false},
{resolve("username"), "t.me/username/0/a//s/as?gam=asd", false},
{resolve("username"), "t.me/username/aasdas?test=1", false},
{resolve("username"), "t.me/username/0", false},
{resolve("telecram"), "https://telegram.dog/tele%63ram", false},
// Negative
{DeepLink{}, "t.me/abcdefghijklmnopqrstuvwxyz1234567", true},
{DeepLink{}, "t.me/abcdefghijklmnop-qrstuvwxyz", true},
{DeepLink{}, "t.me/abcdefghijklmnop~qrstuvwxyz", true},
{DeepLink{}, "t.me/_asdf", true},
{DeepLink{}, "t.me/0asdf", true},
{DeepLink{}, "t.me/9asdf", true},
{DeepLink{}, "t.me/asdf_", true},
{DeepLink{}, "t.me/asd__fg", true},
{DeepLink{}, "t.me//username", true},
},
}
}
var typeSuites = map[string]map[string][]testCase{
"Join": joinSuite(),
"Resolve": resolveSuite(),
}
func TestParseDeeplink(t *testing.T) {
runSuite := func(suite []testCase) func(t *testing.T) {
return func(t *testing.T) {
for i, test := range suite {
t.Run(fmt.Sprintf("Test%d (%s)", i, test.input), func(t *testing.T) {
a := require.New(t)
d, err := Parse(test.input)
if test.wantErr {
a.Error(err, test.input)
} else {
a.NoError(err, test.input)
a.Equal(test.link, d, test.input)
}
})
}
}
}
for typeName, typeSuite := range typeSuites {
t.Run(typeName, func(t *testing.T) {
for suiteName, suite := range typeSuite {
t.Run(suiteName, runSuite(suite))
}
})
}
}
@@ -0,0 +1,40 @@
package deeplink
import (
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/ascii"
)
// ValidateDomain validate given domain (user) name
func ValidateDomain(domain string) error {
return checkDomainSymbols(domain)
}
// checkDomainSymbols check that domain contains only a-z, A-Z, 0-9 and '_'
// symbols.
func checkDomainSymbols(domain string) error {
switch {
case domain == "":
return errors.New("is empty")
case len(domain) > 32:
return errors.New("is too big")
case !ascii.IsLatinLower(rune(domain[0])):
return errors.New("must start with lower letter")
case domain[len(domain)-1] == '_':
return errors.New("must not end with '_'")
}
for i, r := range domain {
switch {
case !ascii.IsLatinLetter(r) && !ascii.IsDigit(r) && r != '_':
case i > 0 && domain[i] == '_' && domain[i] == domain[i-1]:
default:
continue
}
return errors.Errorf("unexpected %c at %d", r, i)
}
return nil
}
@@ -0,0 +1,38 @@
package deeplink
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestValidateDomain(t *testing.T) {
tests := []struct {
domain string
wantErr bool
}{
{"a", false},
{"abcdefghijklmnopqrstuvwxyz123456", false},
{"Aasdf", false},
{"asdf0", false},
{"", true},
{"asdf_", true},
{"asd__fg", true},
{"_asdf", true},
{"0asdf", true},
{"9asdf", true},
{"abcdefghijklmnopqrstuvwxyz1234567", true},
{"abcdefghijklmnop-qrstuvwxyz", true},
{"abcdefghijklmnop~qrstuvwxyz", true},
}
for _, tt := range tests {
t.Run(tt.domain, func(t *testing.T) {
err := ValidateDomain(tt.domain)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}
@@ -0,0 +1,58 @@
package e2etest
import (
"context"
"github.com/cenkalti/backoff/v4"
"github.com/go-faster/errors"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
)
func (s *Suite) createFlow(ctx context.Context) (auth.Flow, error) {
var ua auth.UserAuthenticator
for {
ua = auth.Test(s.rand, s.dc)
phone, err := ua.Phone(ctx)
if err != nil {
return auth.Flow{}, err
}
s.usedMux.Lock()
if _, ok := s.used[phone]; !ok {
s.used[phone] = struct{}{}
s.usedMux.Unlock()
break
}
s.usedMux.Unlock()
}
return auth.NewFlow(ua, auth.SendCodeOptions{}), nil
}
// Authenticate authenticates client on test server.
func (s *Suite) Authenticate(ctx context.Context, client auth.FlowClient) error {
for {
flow, err := s.createFlow(ctx)
if err != nil {
return errors.Wrap(err, "create flow")
}
if err := flow.Run(ctx, client); err != nil {
if errors.Is(err, auth.ErrPasswordNotProvided) {
continue
}
return errors.Wrap(err, "run flow")
}
return nil
}
}
// RetryAuthenticate authenticates client on test server.
func (s *Suite) RetryAuthenticate(ctx context.Context, client auth.FlowClient) error {
bck := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
return backoff.Retry(func() error {
return s.Authenticate(ctx, client)
}, bck)
}
@@ -0,0 +1,61 @@
package e2etest
import (
"context"
"testing"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
type mockFlow struct {
flag bool
}
var _ auth.FlowClient = &mockFlow{}
func (m *mockFlow) SignIn(context.Context, string, string, string) (*tg.AuthAuthorization, error) {
// Ensure retry.
if !m.flag {
m.flag = true
return nil, auth.ErrPasswordAuthNeeded
}
return m.Password(context.Background(), "")
}
func (m *mockFlow) SendCode(context.Context, string, auth.SendCodeOptions) (tg.AuthSentCodeClass, error) {
return &tg.AuthSentCode{
PhoneCodeHash: "hash",
Type: &tg.AuthSentCodeTypeApp{},
Timeout: 10,
}, nil
}
func (m *mockFlow) Password(context.Context, string) (*tg.AuthAuthorization, error) {
return &tg.AuthAuthorization{
User: &tg.User{
ID: 10,
Username: "aboba",
},
}, nil
}
func (m *mockFlow) SignUp(context.Context, auth.SignUp) (*tg.AuthAuthorization, error) {
return nil, errors.New("must not be called")
}
func TestSuite_Authenticate(t *testing.T) {
ctx := context.Background()
logger := zaptest.NewLogger(t)
s := NewSuite(t, TestOptions{
Logger: logger,
})
flow := &mockFlow{}
require.NoError(t, s.Authenticate(ctx, flow))
}
@@ -0,0 +1,3 @@
// Package e2etest contains some helpers to make external E2E tests
// using Telegram test server.
package e2etest
@@ -0,0 +1,190 @@
package e2etest
import (
"context"
"strconv"
"sync"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
// EchoBot is a simple echo message bot.
type EchoBot struct {
suite *Suite
logger *zap.Logger
auth chan<- *tg.User
}
// NewEchoBot creates new echo bot.
func NewEchoBot(suite *Suite, auth chan<- *tg.User) EchoBot {
return EchoBot{
suite: suite,
logger: suite.logger.Named("echobot"),
auth: auth,
}
}
type users struct {
users map[int64]*tg.User
lock sync.RWMutex
}
func newUsers() *users {
return &users{
users: map[int64]*tg.User{},
}
}
func (m *users) empty() (r bool) {
m.lock.RLock()
r = len(m.users) < 1
m.lock.RUnlock()
return
}
func (m *users) add(list ...tg.UserClass) {
m.lock.Lock()
defer m.lock.Unlock()
tg.UserClassArray(list).FillNotEmptyMap(m.users)
}
func (m *users) get(id int64) (r *tg.User) {
m.lock.RLock()
r = m.users[id]
m.lock.RUnlock()
return
}
func (b EchoBot) login(ctx context.Context, client *telegram.Client) (*tg.User, error) {
if err := b.suite.RetryAuthenticate(ctx, client.Auth()); err != nil {
return nil, errors.Wrap(err, "authenticate")
}
var me *tg.User
if err := retry(ctx, func() (err error) {
me, err = client.Self(ctx)
return err
}); err != nil {
return nil, err
}
expectedUsername := "echobot" + strconv.FormatInt(me.ID, 10)
raw := tg.NewClient(retryInvoker{prev: client})
_, err := raw.AccountUpdateUsername(ctx, expectedUsername)
if err != nil {
if !tgerr.Is(err, tg.ErrUsernameNotModified) {
return nil, errors.Wrap(err, "update username")
}
}
me, err = retryResult(ctx, func() (*tg.User, error) {
return client.Self(ctx)
})
if me.Username != expectedUsername {
return nil, errors.Errorf("expected username %q, got %q", expectedUsername, me.Username)
}
return me, nil
}
func (b EchoBot) handler(client *telegram.Client) tg.NewMessageHandler {
dialogsUsers := newUsers()
raw := tg.NewClient(client)
sender := message.NewSender(raw)
return func(ctx context.Context, entities tg.Entities, update *tg.UpdateNewMessage) error {
if filterMessage(update) {
return nil
}
if m, ok := update.Message.(interface{ GetMessage() string }); ok {
b.logger.Named("dispatcher").
Info("Got new message update", zap.String("message", m.GetMessage()))
}
if dialogsUsers.empty() {
dialogs, err := retryResult(ctx, func() (tg.MessagesDialogsClass, error) {
dialogs, err := raw.MessagesGetDialogs(ctx, &tg.MessagesGetDialogsRequest{
Limit: 100,
OffsetPeer: &tg.InputPeerEmpty{},
})
if err != nil {
return nil, errors.Wrap(err, "get dialogs")
}
return dialogs, nil
})
if err != nil {
return errors.Wrap(err, "get dialogs")
}
if dlg, ok := dialogs.AsModified(); ok {
dialogsUsers.add(dlg.GetUsers()...)
}
}
switch m := update.Message.(type) {
case *tg.Message:
switch peer := m.PeerID.(type) {
case *tg.PeerUser:
user := entities.Users[peer.UserID]
if user == nil {
user = dialogsUsers.get(peer.UserID)
}
b.logger.Info("Got message",
zap.String("text", m.Message),
zap.Int64("user_id", user.ID),
zap.String("user_first_name", user.FirstName),
zap.String("username", user.Username),
)
if err := retry(ctx, func() error {
_, err := sender.To(user.AsInputPeer()).Text(ctx, m.Message)
return err
}); err != nil {
return errors.Wrap(err, "send message")
}
return nil
}
}
return nil
}
}
// Run setups and starts echo bot.
func (b EchoBot) Run(ctx context.Context) error {
dispatcher := tg.NewUpdateDispatcher()
client := b.suite.Client(b.logger, dispatcher)
dispatcher.OnNewMessage(b.handler(client))
return client.Run(ctx, func(ctx context.Context) error {
defer close(b.auth)
me, err := b.login(ctx, client)
if err != nil {
return errors.Wrap(err, "login")
}
b.logger.Info("Logged in",
zap.String("user", me.Username),
zap.Int64("id", me.ID),
)
select {
case b.auth <- me:
case <-ctx.Done():
return ctx.Err()
}
<-ctx.Done()
return nil
})
}
@@ -0,0 +1,25 @@
package e2etest
import (
"strings"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
func filterMessage(update *tg.UpdateNewMessage) bool {
if v, ok := update.Message.(interface{ GetOut() bool }); ok && v.GetOut() {
return true
}
if v, ok := update.Message.(interface{ GetPeerID() tg.PeerClass }); ok && v.GetPeerID() == nil {
return true
}
if _, ok := update.Message.(*tg.MessageService); ok {
return true
}
if v, ok := update.Message.(interface{ GetMessage() string }); ok && strings.HasPrefix(v.GetMessage(), "Login code:") {
return true
}
return false
}
@@ -0,0 +1,37 @@
package e2etest
import (
"io"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/constant"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
)
// TestOptions contains some common test server settings.
type TestOptions struct {
AppID int
AppHash string
DC int
Random io.Reader
Logger *zap.Logger
}
func (opt *TestOptions) setDefaults() {
if opt.AppID == 0 {
opt.AppID = constant.TestAppID
}
if opt.AppHash == "" {
opt.AppHash = constant.TestAppHash
}
if opt.DC == 0 {
opt.DC = 2
}
if opt.Random == nil {
opt.Random = crypto.DefaultRand()
}
if opt.Logger == nil {
opt.Logger = zap.NewNop()
}
}
@@ -0,0 +1,61 @@
package e2etest
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
type retryInvoker struct {
prev tg.Invoker
}
func retryResult[T any](ctx context.Context, cb func() (T, error)) (T, error) {
var zero T
return backoff.RetryWithData[T](func() (T, error) {
res, err := cb()
if err != nil {
if tgerr.IsCode(err, -500) {
return zero, err
}
if tgerr.Is(err, "CONNECTION_NOT_INITED") {
return zero, err
}
if ok, err := tgerr.FloodWait(ctx, err); ok {
return zero, err
}
return zero, backoff.Permanent(err)
}
return res, nil
}, backoff.WithContext(backoff.NewConstantBackOff(time.Millisecond*500), ctx))
}
func retry(ctx context.Context, cb func() error) error {
return backoff.Retry(func() error {
if err := cb(); err != nil {
if tgerr.IsCode(err, -500) {
return err
}
if tgerr.Is(err, "CONNECTION_NOT_INITED") {
return err
}
if ok, err := tgerr.FloodWait(ctx, err); ok {
return err
}
return backoff.Permanent(err)
}
return nil
}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
}
func (w retryInvoker) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
return retry(ctx, func() error {
return w.prev.Invoke(ctx, input, output)
})
}
@@ -0,0 +1,50 @@
package e2etest
import (
"io"
"sync"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
)
// Suite is struct which contains external E2E test parameters.
type Suite struct {
TB require.TestingT
appID int
appHash string
dc int
logger *zap.Logger
rand io.Reader
// already used phone numbers
used map[string]struct{}
usedMux sync.Mutex
}
// NewSuite creates new Suite.
func NewSuite(tb require.TestingT, config TestOptions) *Suite {
config.setDefaults()
return &Suite{
TB: tb,
appID: config.AppID,
appHash: config.AppHash,
dc: config.DC,
logger: config.Logger,
rand: config.Random,
used: map[string]struct{}{},
}
}
// Client creates new *telegram.Client using this suite.
func (s *Suite) Client(logger *zap.Logger, handler telegram.UpdateHandler) *telegram.Client {
return telegram.NewClient(s.appID, s.appHash, telegram.Options{
DC: s.dc,
DCList: dcs.Test(),
Logger: logger,
UpdateHandler: handler,
})
}
@@ -0,0 +1,98 @@
package e2etest
import (
"context"
"time"
"github.com/go-faster/errors"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/message"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
// User is a simple user bot.
type User struct {
suite *Suite
text []string
username string
logger *zap.Logger
message chan string
}
// NewUser creates new User bot.
func NewUser(suite *Suite, text []string, username string) User {
return User{
suite: suite,
text: text,
username: username,
logger: suite.logger.Named("terentyev"),
message: make(chan string, 1),
}
}
func (u User) messageHandler(ctx context.Context, entities tg.Entities, update *tg.UpdateNewMessage) error {
if filterMessage(update) {
return nil
}
if m, ok := update.Message.(interface{ GetMessage() string }); ok {
u.logger.Named("dispatcher").
Info("Got new message update", zap.String("message", m.GetMessage()))
}
msg, ok := update.Message.(*tg.Message)
if !ok {
return errors.Errorf("unexpected type %T", update.Message)
}
select {
case u.message <- msg.Message:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Run setups and starts user bot.
func (u User) Run(ctx context.Context) error {
dispatcher := tg.NewUpdateDispatcher()
dispatcher.OnNewMessage(u.messageHandler)
client := u.suite.Client(u.logger, dispatcher)
sender := message.NewSender(tg.NewClient(retryInvoker{prev: client}))
return client.Run(ctx, func(ctx context.Context) error {
if err := u.suite.RetryAuthenticate(ctx, client.Auth()); err != nil {
return errors.Wrap(err, "authenticate")
}
peer, err := sender.Resolve(u.username).AsInputPeer(ctx)
if err != nil {
return errors.Wrapf(err, "resolve bot username %q", u.username)
}
for _, line := range u.text {
time.Sleep(2 * time.Second)
_, err = sender.To(peer).Text(ctx, line)
if flood, err := tgerr.FloodWait(ctx, err); err != nil {
if flood {
continue
}
return err
}
select {
case gotMessage := <-u.message:
require.Equal(u.suite.TB, line, gotMessage)
case <-ctx.Done():
return ctx.Err()
}
}
return nil
})
}
+257
View File
@@ -0,0 +1,257 @@
package manager
import (
"context"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/go-faster/errors"
"go.uber.org/zap"
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
"go.mau.fi/mautrix-telegram/pkg/gotd/pool"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
type protoConn interface {
Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error
Run(ctx context.Context, f func(ctx context.Context) error) error
Ping(ctx context.Context) error
}
//go:generate go run -modfile=../../../_tools/go.mod golang.org/x/tools/cmd/stringer -type=ConnMode
// ConnMode represents connection mode.
type ConnMode byte
const (
// ConnModeUpdates is update connection mode.
ConnModeUpdates ConnMode = iota
// ConnModeData is data connection mode.
ConnModeData
// ConnModeCDN is CDN connection mode.
ConnModeCDN
)
// Conn is a Telegram client connection.
type Conn struct {
// Connection parameters.
mode ConnMode // immutable
// MTProto connection.
proto protoConn // immutable
// InitConnection parameters.
appID int // immutable
device DeviceConfig // immutable
// setup is callback which called after initConnection, but before ready signaling.
// This is necessary to transfer auth from previous connection to another DC.
setup SetupCallback // nilable
// onDead is called on connection death.
onDead func()
// Wrappers for external world, like logs or PRNG.
// Should be immutable.
clock clock.Clock // immutable
log *zap.Logger // immutable
// Handler passed by client.
handler Handler // immutable
// State fields.
cfg tg.Config
ongoing int
latest time.Time
mux sync.Mutex
sessionInit *tdsync.Ready // immutable
gotConfig *tdsync.Ready // immutable
dead *tdsync.Ready // immutable
connBackoff func(ctx context.Context) backoff.BackOff // immutable
}
// OnSession implements mtproto.Handler.
func (c *Conn) OnSession(session mtproto.Session) error {
c.log.Info("SessionInit")
c.sessionInit.Signal()
// Waiting for config, because OnSession can occur before we set config.
select {
case <-c.gotConfig.Ready():
case <-c.dead.Ready():
return nil
}
c.mux.Lock()
cfg := c.cfg
c.mux.Unlock()
return c.handler.OnSession(cfg, session)
}
func (c *Conn) trackInvoke() func() {
start := c.clock.Now()
c.mux.Lock()
defer c.mux.Unlock()
c.ongoing++
c.latest = start
return func() {
c.mux.Lock()
defer c.mux.Unlock()
c.ongoing--
end := c.clock.Now()
c.latest = end
c.log.Debug("Invoke",
zap.Duration("duration", end.Sub(start)),
zap.Int("ongoing", c.ongoing),
)
}
}
// Run initialize connection.
func (c *Conn) Run(ctx context.Context) (err error) {
defer c.dead.Signal()
defer func() {
if err != nil && ctx.Err() == nil {
c.log.Debug("Connection dead", zap.Error(err))
if c.onDead != nil {
c.onDead()
}
}
}()
return c.proto.Run(ctx, func(ctx context.Context) error {
// Signal death on init error. Otherwise connection shutdown
// deadlocks in OnSession that occurs before init fails.
err := c.init(ctx)
if err != nil {
c.dead.Signal()
}
return err
})
}
func (c *Conn) waitSession(ctx context.Context) error {
select {
case <-c.sessionInit.Ready():
return nil
case <-c.dead.Ready():
return pool.ErrConnDead
case <-ctx.Done():
return ctx.Err()
}
}
// Ready returns channel to determine connection readiness.
// Useful for pooling.
func (c *Conn) Ready() <-chan struct{} {
return c.sessionInit.Ready()
}
// Invoke implements Invoker.
func (c *Conn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error {
// Tracking ongoing invokes.
defer c.trackInvoke()()
if err := c.waitSession(ctx); err != nil {
return errors.Wrap(err, "waitSession")
}
return c.proto.Invoke(ctx, c.wrapRequest(noopDecoder{input}), output)
}
// OnMessage implements mtproto.Handler.
func (c *Conn) OnMessage(b *bin.Buffer) error {
return c.handler.OnMessage(b)
}
type noopDecoder struct {
bin.Encoder
}
func (n noopDecoder) Decode(b *bin.Buffer) error {
return errors.New("not implemented")
}
func (c *Conn) wrapRequest(req bin.Object) bin.Object {
if c.mode != ConnModeUpdates {
return &tg.InvokeWithoutUpdatesRequest{
Query: req,
}
}
return req
}
func (c *Conn) init(ctx context.Context) error {
c.log.Debug("Initializing")
q := c.wrapRequest(&tg.InitConnectionRequest{
APIID: c.appID,
DeviceModel: c.device.DeviceModel,
SystemVersion: c.device.SystemVersion,
AppVersion: c.device.AppVersion,
SystemLangCode: c.device.SystemLangCode,
LangPack: c.device.LangPack,
LangCode: c.device.LangCode,
Proxy: c.device.Proxy,
Params: c.device.Params,
Query: c.wrapRequest(&tg.HelpGetConfigRequest{}),
})
req := c.wrapRequest(&tg.InvokeWithLayerRequest{
Layer: tg.Layer,
Query: q,
})
var cfg tg.Config
if err := backoff.RetryNotify(func() error {
if err := c.proto.Invoke(ctx, req, &cfg); err != nil {
if tgerr.Is(err, tgerr.ErrFloodWait) {
// Server sometimes returns FLOOD_WAIT(0) if you create
// multiple connections in short period of time.
//
// See https://github.com/gotd/td/issues/388.
return errors.Wrap(err, "flood wait")
}
// Not retrying other errors.
return backoff.Permanent(errors.Wrap(err, "invoke"))
}
return nil
}, c.connBackoff(ctx), func(err error, duration time.Duration) {
c.log.Debug("Retrying connection initialization",
zap.Error(err), zap.Duration("duration", duration),
)
}); err != nil {
return errors.Wrap(err, "initConnection")
}
if c.setup != nil {
if err := c.setup(ctx, c); err != nil {
return errors.Wrap(err, "setup")
}
}
c.mux.Lock()
c.latest = c.clock.Now()
c.cfg = cfg
c.mux.Unlock()
c.gotConfig.Signal()
return nil
}
// Ping calls ping for underlying protocol connection.
func (c *Conn) Ping(ctx context.Context) error {
return c.proto.Ping(ctx)
}
@@ -0,0 +1,25 @@
// Code generated by "stringer -type=ConnMode"; DO NOT EDIT.
package manager
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ConnModeUpdates-0]
_ = x[ConnModeData-1]
_ = x[ConnModeCDN-2]
}
const _ConnMode_name = "ConnModeUpdatesConnModeDataConnModeCDN"
var _ConnMode_index = [...]uint8{0, 15, 27, 38}
func (i ConnMode) String() string {
if i >= ConnMode(len(_ConnMode_index)-1) {
return "ConnMode(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _ConnMode_name[_ConnMode_index[i]:_ConnMode_index[i+1]]
}
@@ -0,0 +1,96 @@
package manager
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tdsync"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/auth"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// SetupCallback is an optional setup connection callback.
type SetupCallback = func(ctx context.Context, invoker tg.Invoker) error
// ConnOptions is a Telegram client connection options.
type ConnOptions struct {
DC int
Test bool
Device DeviceConfig
Handler Handler
Setup SetupCallback
OnDead func()
OnAuthError func(error)
Backoff func(ctx context.Context) backoff.BackOff
}
func defaultBackoff(c clock.Clock) func(ctx context.Context) backoff.BackOff {
return func(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.Clock = c
b.MaxElapsedTime = time.Second * 30
b.MaxInterval = time.Second * 5
return backoff.WithContext(b, ctx)
}
}
// setDefaults sets default values.
func (c *ConnOptions) setDefaults(connClock clock.Clock) {
if c.DC == 0 {
c.DC = 2
}
// It's okay to use zero value Test.
c.Device.SetDefaults()
if c.Handler == nil {
c.Handler = NoopHandler{}
}
if c.Backoff == nil {
c.Backoff = defaultBackoff(connClock)
}
}
// CreateConn creates new connection.
func CreateConn(
create mtproto.Dialer,
mode ConnMode,
appID int,
opts mtproto.Options,
connOpts ConnOptions,
) *Conn {
connOpts.setDefaults(opts.Clock)
conn := &Conn{
mode: mode,
appID: appID,
device: connOpts.Device,
clock: opts.Clock,
handler: connOpts.Handler,
sessionInit: tdsync.NewReady(),
gotConfig: tdsync.NewReady(),
dead: tdsync.NewReady(),
setup: connOpts.Setup,
onDead: connOpts.OnDead,
connBackoff: connOpts.Backoff,
}
conn.log = opts.Logger
opts.DC = connOpts.DC
if connOpts.Test {
// New key exchange algorithm requires DC ID and uses mapping like MTProxy.
// +10000 for test DC, *-1 for media-only.
opts.DC += 10000
}
opts.Handler = conn
opts.Logger = conn.log.Named("mtproto")
opts.OnError = func(err error) {
if auth.IsUnauthorized(err) && connOpts.OnAuthError != nil {
connOpts.OnAuthError(err)
}
}
conn.proto = mtproto.New(create, opts)
return conn
}
@@ -0,0 +1,61 @@
package manager
import (
"runtime"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/internal/version"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// DeviceConfig is config which send when Telegram connection session created.
type DeviceConfig struct {
// Device model.
DeviceModel string
// Operating system version.
SystemVersion string
// Application version.
AppVersion string
// Code for the language used on the device's OS, ISO 639-1 standard.
SystemLangCode string
// Language pack to use.
LangPack string
// Code for the language used on the client, ISO 639-1 standard.
LangCode string
// Info about an MTProto proxy.
Proxy tg.InputClientProxy
// Additional initConnection parameters. For now, only the tz_offset field is supported,
// for specifying timezone offset in seconds.
Params tg.JSONValueClass
}
// SetDefaults sets default values.
func (c *DeviceConfig) SetDefaults() {
const notAvailable = "n/a"
// Strings must be non-empty, so set notAvailable if default value is empty.
set := func(to *string, value string) {
if value != "" {
*to = value
} else {
*to = notAvailable
}
}
if c.DeviceModel == "" {
set(&c.DeviceModel, runtime.Version())
}
if c.SystemVersion == "" {
set(&c.SystemVersion, runtime.GOOS)
}
if c.AppVersion == "" {
set(&c.AppVersion, version.GetVersion())
}
if c.SystemLangCode == "" {
c.SystemLangCode = "en"
}
if c.LangCode == "" {
c.LangCode = "en"
}
// It's okay to use zero value Proxy.
// It's okay to use zero value Params.
}
@@ -0,0 +1,2 @@
// Package manager contains connection management utilities.
package manager
@@ -0,0 +1,26 @@
package manager
import (
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
// Handler abstracts updates and session handler.
type Handler interface {
OnSession(cfg tg.Config, s mtproto.Session) error
OnMessage(b *bin.Buffer) error
}
// NoopHandler is a noop handler.
type NoopHandler struct{}
// OnSession implements Handler.
func (n NoopHandler) OnSession(cfg tg.Config, s mtproto.Session) error {
return nil
}
// OnMessage implements Handler
func (n NoopHandler) OnMessage(b *bin.Buffer) error {
return nil
}

Some files were not shown because too many files have changed in this diff Show More