move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,81 @@
|
||||
package e2e
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
|
||||
// Entities contains update entities.
|
||||
type Entities struct {
|
||||
Users map[int64]*tg.User
|
||||
Chats map[int64]*tg.Chat
|
||||
Channels map[int64]*tg.Channel
|
||||
ChannelsForbidden map[int64]*tg.ChannelForbidden
|
||||
}
|
||||
|
||||
// NewEntities creates new Entities.
|
||||
func NewEntities() *Entities {
|
||||
return &Entities{
|
||||
Users: map[int64]*tg.User{},
|
||||
Chats: map[int64]*tg.Chat{},
|
||||
Channels: map[int64]*tg.Channel{},
|
||||
ChannelsForbidden: map[int64]*tg.ChannelForbidden{},
|
||||
}
|
||||
}
|
||||
|
||||
// Merge merges entities.
|
||||
func (e *Entities) Merge(from *Entities) {
|
||||
if from == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for userID, user := range from.Users {
|
||||
e.Users[userID] = user
|
||||
}
|
||||
|
||||
for chanID, chat := range from.Chats {
|
||||
e.Chats[chanID] = chat
|
||||
}
|
||||
|
||||
for channelID, channel := range from.Channels {
|
||||
e.Channels[channelID] = channel
|
||||
}
|
||||
|
||||
for channelID, channel := range from.ChannelsForbidden {
|
||||
e.ChannelsForbidden[channelID] = channel
|
||||
}
|
||||
}
|
||||
|
||||
// FromUpdates method.
|
||||
func (e *Entities) FromUpdates(u interface {
|
||||
tg.UpdatesClass
|
||||
MapUsers() tg.UserClassArray
|
||||
MapChats() tg.ChatClassArray
|
||||
}) *Entities {
|
||||
u.MapChats().FillChatMap(e.Chats)
|
||||
u.MapChats().FillChannelMap(e.Channels)
|
||||
u.MapChats().FillChannelForbiddenMap(e.ChannelsForbidden)
|
||||
u.MapUsers().FillUserMap(e.Users)
|
||||
return e
|
||||
}
|
||||
|
||||
// AsUsers returns users as tg.UserClass slice.
|
||||
func (e *Entities) AsUsers() []tg.UserClass {
|
||||
var users []tg.UserClass
|
||||
for _, u := range e.Users {
|
||||
users = append(users, u)
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
// AsChats returns chats as tg.ChatClass slice.
|
||||
func (e *Entities) AsChats() []tg.ChatClass {
|
||||
var chats []tg.ChatClass
|
||||
for _, c := range e.Chats {
|
||||
chats = append(chats, c)
|
||||
}
|
||||
for _, c := range e.Channels {
|
||||
chats = append(chats, c)
|
||||
}
|
||||
for _, c := range e.ChannelsForbidden {
|
||||
chats = append(chats, c)
|
||||
}
|
||||
return chats
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Handler handles updates.
|
||||
type handler struct {
|
||||
messages *messageDatabase
|
||||
ents *Entities
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newHandler() *handler {
|
||||
return &handler{
|
||||
messages: &messageDatabase{
|
||||
channels: make(map[int64][]tg.MessageClass),
|
||||
},
|
||||
ents: NewEntities(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) Handle(ctx context.Context, u tg.UpdatesClass) error {
|
||||
switch u := u.(type) {
|
||||
case *tg.Updates:
|
||||
return h.handleUpdates(NewEntities().FromUpdates(u), u.Updates)
|
||||
case *tg.UpdatesCombined:
|
||||
return h.handleUpdates(NewEntities().FromUpdates(u), u.Updates)
|
||||
default:
|
||||
panic(u)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleUpdates handler.
|
||||
func (h *handler) handleUpdates(ents *Entities, upds []tg.UpdateClass) error {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
|
||||
h.ents.Merge(ents)
|
||||
for _, u := range upds {
|
||||
switch u := u.(type) {
|
||||
case *tg.UpdateNewMessage:
|
||||
h.messages.common = append(h.messages.common, u.Message)
|
||||
case *tg.UpdateNewEncryptedMessage:
|
||||
h.messages.secret = append(h.messages.secret, u.Message)
|
||||
case *tg.UpdateNewChannelMessage:
|
||||
channelID := u.Message.(*tg.Message).PeerID.(*tg.PeerChannel).ChannelID
|
||||
msgs := h.messages.channels[channelID]
|
||||
msgs = append(msgs, u.Message)
|
||||
h.messages.channels[channelID] = msgs
|
||||
default:
|
||||
panic("unexpected update type")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
//go:build linux
|
||||
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/updates"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
func TestE2E(t *testing.T) {
|
||||
testManager(t, func(s *server, storage updates.StateStorage) chan *tg.Updates {
|
||||
t.Helper()
|
||||
|
||||
c := make(chan *tg.Updates, 10)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
biba = s.peers.createUser("biba")
|
||||
boba = s.peers.createUser("boba")
|
||||
chat = s.peers.createChat("chat")
|
||||
)
|
||||
|
||||
var channels []*tg.PeerChannel
|
||||
require.NoError(t, storage.ForEachChannels(ctx, 123, func(ctx context.Context, channelID int64, pts int) error {
|
||||
channels = append(channels, &tg.PeerChannel{
|
||||
ChannelID: channelID,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Biba.
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 3; i++ {
|
||||
c <- s.CreateEvent(func(ev *EventBuilder) {
|
||||
ev.SendMessage(biba, chat, fmt.Sprintf("biba-%d", i))
|
||||
|
||||
for mi, c := range channels {
|
||||
ev.SendMessage(biba, c, fmt.Sprintf("biba-channel-%d-%d", i, mi))
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
// Boba.
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 3; i++ {
|
||||
c <- s.CreateEvent(func(ev *EventBuilder) {
|
||||
ev.SendMessage(boba, chat, fmt.Sprintf("boba-%d", i))
|
||||
|
||||
for _, c := range channels {
|
||||
ev.SendMessage(boba, c, fmt.Sprintf("boba-channel-%d", i))
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(c)
|
||||
}()
|
||||
return c
|
||||
})
|
||||
}
|
||||
|
||||
func testManager(t *testing.T, f func(s *server, storage updates.StateStorage) chan *tg.Updates) {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var (
|
||||
log = zaptest.NewLogger(t)
|
||||
s = newServer()
|
||||
h = newHandler()
|
||||
storage = newMemStorage()
|
||||
hasher = newMemAccessHasher()
|
||||
)
|
||||
|
||||
const uid = 123
|
||||
|
||||
require.NoError(t, storage.SetState(ctx, uid, updates.State{
|
||||
Pts: 0,
|
||||
Qts: 0,
|
||||
Date: 0,
|
||||
Seq: 0,
|
||||
}))
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
c := s.peers.createChannel(fmt.Sprintf("channel-%d", i))
|
||||
require.NoError(t, storage.SetChannelPts(ctx, uid, c.ChannelID, 0))
|
||||
require.NoError(t, hasher.SetChannelAccessHash(ctx, uid, c.ChannelID, c.ChannelID*2))
|
||||
}
|
||||
|
||||
e := updates.New(updates.Config{
|
||||
Handler: h,
|
||||
Logger: log.Named("gaps"),
|
||||
Storage: storage,
|
||||
AccessHasher: hasher,
|
||||
})
|
||||
|
||||
uchan := loss(f(s, storage))
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
ready := make(chan struct{})
|
||||
opts := updates.AuthOptions{
|
||||
OnStart: func(ctx context.Context) {
|
||||
t.Log("OnStart")
|
||||
close(ready)
|
||||
},
|
||||
}
|
||||
g.Go(func() error {
|
||||
t.Log("Starting manager")
|
||||
defer t.Log("Manager stopped")
|
||||
return e.Run(ctx, s, uid, opts)
|
||||
})
|
||||
g.Go(func() error {
|
||||
t.Log("Starting updates generator")
|
||||
defer t.Log("Updates generator stopped")
|
||||
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-ready:
|
||||
t.Log("Ready")
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var g errgroup.Group
|
||||
for i := 0; i < 2; i++ {
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case u, ok := <-uchan:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.Handle(ctx, u); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Log("Waiting")
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Log("Sending pts changed")
|
||||
|
||||
ups := []tg.UpdateClass{&tg.UpdatePtsChanged{}}
|
||||
if err := storage.ForEachChannels(ctx, uid, func(ctx context.Context, channelID int64, pts int) error {
|
||||
ups = append(ups, &tg.UpdateChannelTooLong{ChannelID: channelID})
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Log("Handle")
|
||||
|
||||
return e.Handle(ctx, &tg.Updates{
|
||||
Updates: ups,
|
||||
})
|
||||
})
|
||||
|
||||
t.Log("Waiting for shutdown")
|
||||
require.ErrorIs(t, g.Wait(), context.Canceled)
|
||||
|
||||
t.Log("Checking")
|
||||
require.Equal(t, s.messages, h.messages)
|
||||
require.Equal(t, s.peers.channels, h.ents.Channels)
|
||||
require.Equal(t, s.peers.chats, h.ents.Chats)
|
||||
require.Equal(t, s.peers.users, h.ents.Users)
|
||||
}
|
||||
|
||||
func loss(in chan *tg.Updates) chan *tg.Updates {
|
||||
out := make(chan *tg.Updates)
|
||||
|
||||
go func() {
|
||||
defer close(out)
|
||||
|
||||
for u := range in {
|
||||
if rand.Intn(5) == 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
out <- u
|
||||
}
|
||||
}()
|
||||
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package e2e
|
||||
|
||||
import "go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
|
||||
type messageDatabase struct {
|
||||
common []tg.MessageClass
|
||||
secret []tg.EncryptedMessageClass
|
||||
channels map[int64][]tg.MessageClass
|
||||
}
|
||||
|
||||
type peerDatabase struct {
|
||||
users map[int64]*tg.User
|
||||
chats map[int64]*tg.Chat
|
||||
channels map[int64]*tg.Channel
|
||||
|
||||
id int64
|
||||
}
|
||||
|
||||
func (p *peerDatabase) createUser(username string) *tg.PeerUser {
|
||||
p.users[p.id] = &tg.User{
|
||||
ID: p.id,
|
||||
Username: username,
|
||||
}
|
||||
|
||||
defer func() { p.id++ }()
|
||||
return &tg.PeerUser{UserID: p.id}
|
||||
}
|
||||
|
||||
func (p *peerDatabase) createChat(title string) *tg.PeerChat {
|
||||
p.chats[p.id] = &tg.Chat{
|
||||
ID: p.id,
|
||||
Title: title,
|
||||
}
|
||||
|
||||
defer func() { p.id++ }()
|
||||
return &tg.PeerChat{ChatID: p.id}
|
||||
}
|
||||
|
||||
func (p *peerDatabase) createChannel(username string) *tg.PeerChannel {
|
||||
p.channels[p.id] = &tg.Channel{
|
||||
ID: p.id,
|
||||
Username: username,
|
||||
}
|
||||
p.channels[p.id].SetAccessHash(p.id * 2)
|
||||
|
||||
defer func() { p.id++ }()
|
||||
return &tg.PeerChannel{ChannelID: p.id}
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
// Package e2e contains end-to-end updates processing test.
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// Server for testing gaps.
|
||||
type server struct {
|
||||
date int
|
||||
peers *peerDatabase
|
||||
messages *messageDatabase
|
||||
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// NewServer creates new test server.
|
||||
func newServer() *server {
|
||||
return &server{
|
||||
date: 1,
|
||||
peers: &peerDatabase{
|
||||
users: make(map[int64]*tg.User),
|
||||
chats: make(map[int64]*tg.Chat),
|
||||
channels: make(map[int64]*tg.Channel),
|
||||
},
|
||||
messages: &messageDatabase{
|
||||
channels: make(map[int64][]tg.MessageClass),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UpdatesGetState returns current remote state.
|
||||
func (s *server) UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
return &tg.UpdatesState{
|
||||
Pts: len(s.messages.common),
|
||||
Qts: len(s.messages.secret),
|
||||
Date: s.date,
|
||||
Seq: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdatesGetDifference returns difference between local and remote states.
|
||||
func (s *server) UpdatesGetDifference(ctx context.Context, request *tg.UpdatesGetDifferenceRequest) (tg.UpdatesDifferenceClass, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
ents := NewEntities()
|
||||
var common []tg.MessageClass
|
||||
for i := request.Pts + 1; i <= len(s.messages.common); i++ {
|
||||
common = append(common, s.messages.common[i-1])
|
||||
s.fillMessageEnts(s.messages.common[i-1], ents)
|
||||
}
|
||||
|
||||
var secret []tg.EncryptedMessageClass
|
||||
for i := request.Qts + 1; i <= len(s.messages.secret); i++ {
|
||||
secret = append(secret, s.messages.secret[i-1])
|
||||
}
|
||||
|
||||
var others []tg.UpdateClass
|
||||
for _, msgs := range s.messages.channels {
|
||||
for i, msg := range msgs {
|
||||
if msg.(*tg.Message).Date > request.Date {
|
||||
others = append(others, &tg.UpdateNewChannelMessage{
|
||||
Message: msg,
|
||||
Pts: i + 1,
|
||||
PtsCount: 1,
|
||||
})
|
||||
s.fillMessageEnts(msg, ents)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(common) == 0 && len(secret) == 0 && len(others) == 0 {
|
||||
return &tg.UpdatesDifferenceEmpty{
|
||||
Date: s.date,
|
||||
Seq: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &tg.UpdatesDifference{
|
||||
NewMessages: common,
|
||||
NewEncryptedMessages: secret,
|
||||
OtherUpdates: others,
|
||||
Users: ents.AsUsers(),
|
||||
Chats: ents.AsChats(),
|
||||
State: tg.UpdatesState{
|
||||
Pts: len(s.messages.common),
|
||||
Qts: len(s.messages.secret),
|
||||
Date: s.date,
|
||||
Seq: 0,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdatesGetChannelDifference returns difference between local and remote channel states.
|
||||
func (s *server) UpdatesGetChannelDifference(
|
||||
ctx context.Context, request *tg.UpdatesGetChannelDifferenceRequest,
|
||||
) (tg.UpdatesChannelDifferenceClass, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
channel, ok := request.Channel.(*tg.InputChannel)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("bad InputChannelClass type: %T", request.Channel)
|
||||
}
|
||||
|
||||
if peer, ok := s.peers.channels[channel.ChannelID]; true {
|
||||
if !ok {
|
||||
return nil, errors.Errorf("channel %d not found", channel.ChannelID)
|
||||
}
|
||||
|
||||
if peer.AccessHash != channel.AccessHash {
|
||||
return nil, errors.New("invalid access hash")
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
channelMsgs = s.messages.channels[channel.ChannelID]
|
||||
ents = NewEntities()
|
||||
prepared []tg.MessageClass
|
||||
)
|
||||
|
||||
for i := request.Pts + 1; i <= len(channelMsgs); i++ {
|
||||
prepared = append(prepared, channelMsgs[i-1])
|
||||
s.fillMessageEnts(channelMsgs[i-1], ents)
|
||||
}
|
||||
|
||||
if len(prepared) == 0 {
|
||||
return &tg.UpdatesChannelDifferenceEmpty{
|
||||
Pts: len(channelMsgs),
|
||||
Final: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &tg.UpdatesChannelDifference{
|
||||
NewMessages: prepared,
|
||||
Users: ents.AsUsers(),
|
||||
Chats: ents.AsChats(),
|
||||
Pts: len(channelMsgs),
|
||||
Final: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *server) fillMessageEnts(msg tg.MessageClass, ents *Entities) {
|
||||
switch peer := msg.(*tg.Message).PeerID.(type) {
|
||||
case *tg.PeerUser:
|
||||
user, ok := s.peers.users[peer.UserID]
|
||||
if !ok {
|
||||
panic("bad user")
|
||||
}
|
||||
|
||||
ents.Users[user.ID] = user
|
||||
case *tg.PeerChat:
|
||||
chat, ok := s.peers.chats[peer.ChatID]
|
||||
if !ok {
|
||||
panic("bad chat")
|
||||
}
|
||||
|
||||
ents.Chats[chat.ID] = chat
|
||||
case *tg.PeerChannel:
|
||||
channel, ok := s.peers.channels[peer.ChannelID]
|
||||
if !ok {
|
||||
panic("bad channel")
|
||||
}
|
||||
|
||||
ents.Channels[channel.ID] = channel
|
||||
default:
|
||||
panic("unexpected peer type")
|
||||
}
|
||||
|
||||
peerUser, ok := msg.(*tg.Message).FromID.(*tg.PeerUser)
|
||||
if !ok {
|
||||
panic("bad fromID")
|
||||
}
|
||||
|
||||
user, ok := s.peers.users[peerUser.UserID]
|
||||
if !ok {
|
||||
panic("bad user")
|
||||
}
|
||||
|
||||
ents.Users[user.ID] = user
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
|
||||
// EventBuilder struct.
|
||||
type EventBuilder struct {
|
||||
updates []tg.UpdateClass
|
||||
ents *Entities
|
||||
s *server
|
||||
date int
|
||||
}
|
||||
|
||||
// SendMessage send a new message.
|
||||
func (e *EventBuilder) SendMessage(from *tg.PeerUser, peer tg.PeerClass, text string) {
|
||||
msg := &tg.Message{
|
||||
Message: text,
|
||||
PeerID: peer,
|
||||
FromID: from,
|
||||
Date: e.date,
|
||||
}
|
||||
|
||||
fromUser, ok := e.s.peers.users[from.UserID]
|
||||
if !ok {
|
||||
panic("bad fromID")
|
||||
}
|
||||
e.ents.Users[from.UserID] = fromUser
|
||||
|
||||
switch peer := peer.(type) {
|
||||
case *tg.PeerUser:
|
||||
user, ok := e.s.peers.users[peer.UserID]
|
||||
if !ok {
|
||||
panic("peer not found")
|
||||
}
|
||||
|
||||
e.ents.Users[user.ID] = user
|
||||
e.s.messages.common = append(e.s.messages.common, msg)
|
||||
e.updates = append(e.updates, &tg.UpdateNewMessage{
|
||||
Message: msg,
|
||||
Pts: len(e.s.messages.common),
|
||||
PtsCount: 1,
|
||||
})
|
||||
case *tg.PeerChat:
|
||||
chat, ok := e.s.peers.chats[peer.ChatID]
|
||||
if !ok {
|
||||
panic("peer not found")
|
||||
}
|
||||
|
||||
e.ents.Chats[chat.ID] = chat
|
||||
e.s.messages.common = append(e.s.messages.common, msg)
|
||||
e.updates = append(e.updates, &tg.UpdateNewMessage{
|
||||
Message: msg,
|
||||
Pts: len(e.s.messages.common),
|
||||
PtsCount: 1,
|
||||
})
|
||||
case *tg.PeerChannel:
|
||||
channel, ok := e.s.peers.channels[peer.ChannelID]
|
||||
if !ok {
|
||||
panic("peer not found")
|
||||
}
|
||||
|
||||
e.ents.Channels[channel.ID] = channel
|
||||
msgs := append(e.s.messages.channels[peer.ChannelID], msg)
|
||||
e.s.messages.channels[peer.ChannelID] = msgs
|
||||
e.updates = append(e.updates, &tg.UpdateNewChannelMessage{
|
||||
Message: msg,
|
||||
Pts: len(msgs),
|
||||
PtsCount: 1,
|
||||
})
|
||||
default:
|
||||
panic("unexpected peer type")
|
||||
}
|
||||
}
|
||||
|
||||
// CreateEvent creates new event.
|
||||
func (s *server) CreateEvent(f func(ev *EventBuilder)) *tg.Updates {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.date++
|
||||
ev := &EventBuilder{
|
||||
ents: NewEntities(),
|
||||
s: s,
|
||||
date: s.date,
|
||||
}
|
||||
f(ev)
|
||||
|
||||
return &tg.Updates{
|
||||
Updates: ev.updates,
|
||||
Users: ev.ents.AsUsers(),
|
||||
Chats: ev.ents.AsChats(),
|
||||
Date: s.date,
|
||||
Seq: 0,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/updates"
|
||||
)
|
||||
|
||||
var _ updates.StateStorage = (*memStorage)(nil)
|
||||
|
||||
type memStorage struct {
|
||||
states map[int64]updates.State
|
||||
channels map[int64]map[int64]int
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newMemStorage() *memStorage {
|
||||
return &memStorage{
|
||||
states: map[int64]updates.State{},
|
||||
channels: map[int64]map[int64]int{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *memStorage) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, found = s.states[userID]
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) SetState(ctx context.Context, userID int64, state updates.State) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.states[userID] = state
|
||||
s.channels[userID] = map[int64]int{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetPts(ctx context.Context, userID int64, pts int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
|
||||
state.Pts = pts
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetQts(ctx context.Context, userID int64, qts int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
|
||||
state.Qts = qts
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetDate(ctx context.Context, userID int64, date int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
|
||||
state.Date = date
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetSeq(ctx context.Context, userID int64, seq int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
|
||||
state.Seq = seq
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetDateSeq(ctx context.Context, userID int64, date, seq int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
state, ok := s.states[userID]
|
||||
if !ok {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
|
||||
state.Date = date
|
||||
state.Seq = seq
|
||||
s.states[userID] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
channels, ok := s.channels[userID]
|
||||
if !ok {
|
||||
return errors.New("user state does not exist")
|
||||
}
|
||||
|
||||
channels[channelID] = pts
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memStorage) GetChannelPts(ctx context.Context, userID, channelID int64) (pts int, found bool, err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
channels, ok := s.channels[userID]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
pts, found = channels[channelID]
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
cmap, ok := s.channels[userID]
|
||||
if !ok {
|
||||
return errors.New("channels map does not exist")
|
||||
}
|
||||
|
||||
for id, pts := range cmap {
|
||||
if err := f(ctx, id, pts); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ updates.AccessHasher = (*memAccessHasher)(nil)
|
||||
|
||||
type memAccessHasher struct {
|
||||
channelHashes map[int64]map[int64]int64
|
||||
userHashes map[int64]map[int64]int64
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newMemAccessHasher() *memAccessHasher {
|
||||
return &memAccessHasher{
|
||||
channelHashes: map[int64]map[int64]int64{},
|
||||
userHashes: map[int64]map[int64]int64{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) GetChannelAccessHash(ctx context.Context, forUserID, channelID int64) (accessHash int64, found bool, err error) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.channelHashes[forUserID]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
accessHash, found = accessHashes[channelID]
|
||||
return
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) SetChannelAccessHash(ctx context.Context, forUserID, channelID, accessHash int64) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.channelHashes[forUserID]
|
||||
if !ok {
|
||||
accessHashes = map[int64]int64{}
|
||||
m.channelHashes[forUserID] = accessHashes
|
||||
}
|
||||
|
||||
accessHashes[channelID] = accessHash
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) GetUserAccessHash(ctx context.Context, forUserID, userID int64) (accessHash int64, found bool, err error) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.userHashes[forUserID]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
accessHash, found = accessHashes[userID]
|
||||
return
|
||||
}
|
||||
|
||||
func (m *memAccessHasher) SetUserAccessHash(ctx context.Context, forUserID, userID, accessHash int64) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
accessHashes, ok := m.userHashes[forUserID]
|
||||
if !ok {
|
||||
accessHashes = map[int64]int64{}
|
||||
m.channelHashes[forUserID] = accessHashes
|
||||
}
|
||||
|
||||
accessHashes[userID] = accessHash
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user