diff --git a/pkg/gotd/mtproxy/obfuscated2/obfuscated2.go b/pkg/gotd/mtproxy/obfuscated2/obfuscated2.go index fe642d1f..7aecae38 100644 --- a/pkg/gotd/mtproxy/obfuscated2/obfuscated2.go +++ b/pkg/gotd/mtproxy/obfuscated2/obfuscated2.go @@ -53,7 +53,12 @@ func (o *Obfuscated2) Read(b []byte) (int, error) { return n, err } if n > 0 { - o.decrypt.XORKeyStream(b, b) + // IMPORTANT: only XOR the n bytes that were actually read. + // XOR-ing the full b advances the CTR keystream past where the + // server is and permanently desyncs the stream — every later + // MTProto message decrypts to garbage and the engine fails + // with "msg_key is invalid". + o.decrypt.XORKeyStream(b[:n], b[:n]) } return n, err } diff --git a/pkg/gotd/mtproxy/obfuscated2/short_read_test.go b/pkg/gotd/mtproxy/obfuscated2/short_read_test.go new file mode 100644 index 00000000..bb2b32c8 --- /dev/null +++ b/pkg/gotd/mtproxy/obfuscated2/short_read_test.go @@ -0,0 +1,82 @@ +package obfuscated2 + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +// chunkConn delivers data from buf in chunks of at most chunkSize bytes. +type chunkConn struct { + buf *bytes.Buffer + chunkSize int +} + +func (c *chunkConn) Read(p []byte) (int, error) { + if c.buf.Len() == 0 { + return 0, io.EOF + } + want := len(p) + if want > c.chunkSize { + want = c.chunkSize + } + return c.buf.Read(p[:want]) +} + +func (c *chunkConn) Write(p []byte) (int, error) { return len(p), nil } + +// TestShortReadKeepsKeystreamAligned ensures that when the underlying +// transport returns fewer bytes than the caller asked for, the CTR +// keystream is only advanced by the bytes actually delivered. +// +// The previous implementation called XORKeyStream(b, b) instead of +// XORKeyStream(b[:n], b[:n]); after a single short read the client and +// server keystreams diverged and every subsequent MTProto message +// failed integrity (msg_key invalid). +func TestShortReadKeepsKeystreamAligned(t *testing.T) { + a := require.New(t) + + key := bytes.Repeat([]byte{0x11}, 32) + iv := bytes.Repeat([]byte{0x22}, 16) + + enc, err := aes.NewCipher(key) + a.NoError(err) + dec, err := aes.NewCipher(key) + a.NoError(err) + + encStream := cipher.NewCTR(enc, iv) + decStream := cipher.NewCTR(dec, iv) + + plaintext := bytes.Repeat([]byte("Hello, MTProxy! "), 50) + ciphertext := make([]byte, len(plaintext)) + encStream.XORKeyStream(ciphertext, plaintext) + + wire := &chunkConn{buf: bytes.NewBuffer(append([]byte(nil), ciphertext...)), chunkSize: 7} + o := &Obfuscated2{ + conn: wire, + keys: keys{decrypt: decStream}, + } + + got := make([]byte, len(plaintext)) + off := 0 + for off < len(plaintext) { + end := off + 128 + if end > len(got) { + end = len(got) + } + n, err := o.Read(got[off:end]) + if err != nil && err != io.EOF { + t.Fatalf("read at off %d: %v", off, err) + } + if n == 0 { + t.Fatalf("zero-length read at off %d", off) + } + off += n + } + + a.Equal(plaintext, got, "short reads must not desync the keystream") +}