diff --git a/storage/remote/chunked.go b/storage/remote/chunked.go index 53573d4f7e..39982f6e30 100644 --- a/storage/remote/chunked.go +++ b/storage/remote/chunked.go @@ -15,6 +15,8 @@ package remote import ( "bufio" "encoding/binary" + "hash" + "hash/crc32" "io" "net/http" @@ -26,21 +28,37 @@ import ( // 50MB is the default. This is equivalent to ~100k full XOR chunks and average labelset. const DefaultChunkedReadLimit = 5e+7 +// The table gets initialized with sync.Once but may still cause a race +// with any other use of the crc32 package anywhere. Thus we initialize it +// up front, in init. +var castagnoliTable *crc32.Table + +func init() { + castagnoliTable = crc32.MakeTable(crc32.Castagnoli) +} + // ChunkedWriter is an io.Writer wrapper that allows streaming by adding uvarint delimiter before each write in a form // of length of the corresponded byte array. type ChunkedWriter struct { writer io.Writer flusher http.Flusher + + crc32 hash.Hash32 } // NewChunkedWriter constructs a ChunkedWriter. func NewChunkedWriter(w io.Writer, f http.Flusher) *ChunkedWriter { - return &ChunkedWriter{writer: w, flusher: f} + return &ChunkedWriter{writer: w, flusher: f, crc32: crc32.New(castagnoliTable)} } -// Write writes given bytes to the stream. It adds uvarint delimiter before each message. -// Returned bytes number represents sent bytes for a given buffer. The number does not include delimiter bytes. -// It does the flushing for you. +// Write writes the given bytes to the stream and flushes it. +// Each frame includes: +// +// 1. uvarint for the size of the data frame. +// 2. uvarint for the Castagnoli polynomial CRC-32 checksum of the data frame. +// 3. n bytes where n is given in the first uvarint. +// +// Write returns the number of bytes sent from the given buffer. The number does not include the delimiter and checksum bytes. 
func (w *ChunkedWriter) Write(b []byte) (int, error) { if len(b) == 0 { return 0, nil @@ -48,7 +66,16 @@ func (w *ChunkedWriter) Write(b []byte) (int, error) { var buf [binary.MaxVarintLen64]byte v := binary.PutUvarint(buf[:], uint64(len(b))) + if _, err := w.writer.Write(buf[:v]); err != nil { + return 0, err + } + w.crc32.Reset() + if _, err := w.crc32.Write(b); err != nil { + return 0, err + } + + v = binary.PutUvarint(buf[:], uint64(w.crc32.Sum32())) if _, err := w.writer.Write(buf[:v]); err != nil { return 0, err } @@ -62,22 +89,25 @@ func (w *ChunkedWriter) Write(b []byte) (int, error) { return n, nil } -// ChunkedReader is a buffered reader that expects uvarint delimiter before each message. +// ChunkedReader is a buffered reader that expects a uvarint delimiter and a checksum before each message. // It will allocate as much as the biggest frame defined by delimiter (on top of bufio.Reader allocations). type ChunkedReader struct { b *bufio.Reader data []byte sizeLimit uint64 + + crc32 hash.Hash32 } // NewChunkedReader constructs a ChunkedReader. func NewChunkedReader(r io.Reader, sizeLimit uint64) *ChunkedReader { - return &ChunkedReader{b: bufio.NewReader(r), sizeLimit: sizeLimit} + return &ChunkedReader{b: bufio.NewReader(r), sizeLimit: sizeLimit, crc32: crc32.New(castagnoliTable)} } // Next returns the next length-delimited record from the input, or io.EOF if // there are no more records available. Returns io.ErrUnexpectedEOF if a short // record is found, with a length of n but fewer than n bytes of data. +// Next also verifies the CRC32 checksum of each record and returns an error on mismatch. // // NOTE: The slice returned is valid only until a subsequent call to Next. It's a caller's responsibility to copy the // returned slice if needed. 
@@ -97,9 +127,19 @@ func (r *ChunkedReader) Next() ([]byte, error) { r.data = r.data[:size] } - if _, err := io.ReadFull(r.b, r.data); err != nil { + crc32, err := binary.ReadUvarint(r.b) + if err != nil { return nil, err } + + r.crc32.Reset() + if _, err := io.ReadFull(io.TeeReader(r.b, r.crc32), r.data); err != nil { + return nil, err + } + + if uint64(r.crc32.Sum32()) != crc32 { + return nil, errors.New("chunkedReader: corrupted frame; checksum mismatch") + } return r.data, nil } diff --git a/storage/remote/chunked_test.go b/storage/remote/chunked_test.go index da6e271b4b..bee352c323 100644 --- a/storage/remote/chunked_test.go +++ b/storage/remote/chunked_test.go @@ -88,3 +88,19 @@ func TestChunkedReader_Overflow(t *testing.T) { testutil.NotOk(t, err, "expect exceed limit error") testutil.Equals(t, "chunkedReader: message size exceeded the limit 11 bytes; got: 12 bytes", err.Error()) } + +func TestChunkedReader_CorruptedFrame(t *testing.T) { + b := &bytes.Buffer{} + w := NewChunkedWriter(b, &mockedFlusher{}) + + n, err := w.Write([]byte("test1")) + testutil.Ok(t, err) + testutil.Equals(t, 5, n) + + bs := b.Bytes() + bs[9] = 1 // Corrupt the frame by changing one byte. + + _, err = NewChunkedReader(bytes.NewReader(bs), 20).Next() + testutil.NotOk(t, err, "expected malformed frame") + testutil.Equals(t, "chunkedReader: corrupted frame; checksum mismatch", err.Error()) +}