package faketls

import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"io"

	"github.com/go-faster/errors"
)

// peekDump returns up to n bytes from the start of buf as a hex string for diagnostics.
func peekDump(buf []byte, n int) string {
	if len(buf) < n {
		n = len(buf)
	}
	return hex.EncodeToString(buf[:n])
}

// readServerHello reads a faketls ServerHello from r and verifies its digest.
//
// It reads three records in order — handshake, change cipher spec and
// application data — and checks the type of each. Every byte consumed from r
// is also captured, and the 32 bytes at serverRandomOffset of the captured
// packet are treated as the server digest: they must be equal to
// HMAC-SHA256(secret, clientRandom || packet), where packet is the captured
// data with those 32 bytes replaced by zeros.
//
// An error is returned if any record cannot be read, has an unexpected type,
// if the captured packet is too short to contain the digest, or if the digest
// does not match.
func readServerHello(r io.Reader, clientRandom [32]byte, secret []byte) error {
	// Mirror everything read from r into packetBuf so the raw bytes are
	// available for the HMAC computation below.
	packetBuf := bytes.NewBuffer(nil)
	r = io.TeeReader(r, packetBuf)

	handshake, err := readRecord(r)
	if err != nil {
		return errors.Wrapf(err, "handshake record (peek=%s)", peekDump(packetBuf.Bytes(), 32))
	}
	if handshake.Type != RecordTypeHandshake {
		return errors.Errorf("unexpected handshake record type: got 0x%02x, want 0x%02x (peek=%s)",
			byte(handshake.Type), byte(RecordTypeHandshake), peekDump(packetBuf.Bytes(), 32))
	}

	changeCipher, err := readRecord(r)
	if err != nil {
		return errors.Wrap(err, "change cipher record")
	}
	if changeCipher.Type != RecordTypeChangeCipherSpec {
		return errors.Errorf("unexpected change cipher record type: got 0x%02x, want 0x%02x",
			byte(changeCipher.Type), byte(RecordTypeChangeCipherSpec))
	}

	cert, err := readRecord(r)
	if err != nil {
		return errors.Wrap(err, "cert record")
	}
	if cert.Type != RecordTypeApplication {
		return errors.Errorf("unexpected application record type: got 0x%02x, want 0x%02x",
			byte(cert.Type), byte(RecordTypeApplication))
	}

	// `$record_header = type 1 byte + version 2 bytes + payload_length 2 bytes = 5 bytes`
	// `$server_hello_header = type 1 bytes + version 2 bytes + length 3 bytes = 6 bytes`
	// `$offset = $record_header + $server_hello_header = 11 bytes`
	const (
		serverRandomOffset = 11
		serverRandomLen    = 32
	)

	packet := packetBuf.Bytes()
	// The peer controls the record sizes. Without this check a short packet
	// would either panic on the slice expressions below or, because slicing
	// is bounded by capacity rather than length, read spare buffer bytes that
	// were never received.
	if len(packet) < serverRandomOffset+serverRandomLen {
		return errors.Errorf("server hello too short: got %d bytes, want at least %d",
			len(packet), serverRandomOffset+serverRandomLen)
	}
	digestField := packet[serverRandomOffset : serverRandomOffset+serverRandomLen]

	// Copy original digest.
	var originalDigest [serverRandomLen]byte
	copy(originalDigest[:], digestField)
	// Fill original digest by zeros: the digest is computed over the packet
	// with the digest field blanked out.
	var zeros [serverRandomLen]byte
	copy(digestField, zeros[:])

	mac := hmac.New(sha256.New, secret)
	if _, err := mac.Write(clientRandom[:]); err != nil {
		return errors.Wrap(err, "hmac write")
	}
	if _, err := mac.Write(packet); err != nil {
		return errors.Wrap(err, "hmac write")
	}
	// hmac.Equal compares without leaking timing information, as recommended
	// by the crypto/hmac documentation for MAC verification.
	if !hmac.Equal(mac.Sum(nil), originalDigest[:]) {
		return errors.New("hmac digest mismatch")
	}

	return nil
}