// Copyright 2019 The Prometheus Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package wlog import ( "encoding/binary" "errors" "fmt" "hash/crc32" "io" "log/slog" "github.com/golang/snappy" "github.com/klauspost/compress/zstd" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/common/promslog" ) // LiveReaderMetrics holds all metrics exposed by the LiveReader. type LiveReaderMetrics struct { readerCorruptionErrors *prometheus.CounterVec } // NewLiveReaderMetrics instantiates, registers and returns metrics to be injected // at LiveReader instantiation. func NewLiveReaderMetrics(reg prometheus.Registerer) *LiveReaderMetrics { m := &LiveReaderMetrics{ readerCorruptionErrors: prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "prometheus_tsdb_wal_reader_corruption_errors_total", Help: "Errors encountered when reading the WAL.", }, []string{"error"}), } if reg != nil { reg.MustRegister(m.readerCorruptionErrors) } return m } // NewLiveReader returns a new live reader. func NewLiveReader(logger *slog.Logger, metrics *LiveReaderMetrics, r io.Reader) *LiveReader { if logger == nil { logger = promslog.NewNopLogger() } // Calling zstd.NewReader with a nil io.Reader and no options cannot return an error. zstdReader, _ := zstd.NewReader(nil) lr := &LiveReader{ logger: logger, rdr: r, zstdReader: zstdReader, metrics: metrics, // Until we understand how they come about, make readers permissive // to records spanning pages. permissive: true, } return lr } // LiveReader reads WAL records from an io.Reader. It allows reading of WALs // that are still in the process of being written, and returns records as soon // as they can be read. type LiveReader struct { logger *slog.Logger rdr io.Reader err error rec []byte compressBuf []byte zstdReader *zstd.Decoder hdr [recordHeaderSize]byte buf [pageSize]byte readIndex int // Index in buf to start at for next read. writeIndex int // Index in buf to start at for next write. total int64 // Total bytes processed during reading in calls to Next(). index int // Used to track partial records, should be 0 at the start of every new record. // For testing, we can treat EOF as a non-error. eofNonErr bool // We sometime see records span page boundaries. Should never happen, but it // does. Until we track down why, set permissive to true to tolerate it. // NB the non-ive Reader implementation allows for this. permissive bool metrics *LiveReaderMetrics } // Err returns any errors encountered reading the WAL. io.EOFs are not terminal // and Next can be tried again. Non-EOFs are terminal, and the reader should // not be used again. It is up to the user to decide when to stop trying should // io.EOF be returned. func (r *LiveReader) Err() error { if r.eofNonErr && errors.Is(r.err, io.EOF) { return nil } return r.err } // Offset returns the number of bytes consumed from this segment. func (r *LiveReader) Offset() int64 { return r.total } func (r *LiveReader) fillBuffer() (int, error) { n, err := r.rdr.Read(r.buf[r.writeIndex:len(r.buf)]) r.writeIndex += n return n, err } // Next returns true if Record() will contain a full record. // If Next returns false, you should always checked the contents of Error(). // Return false guarantees there are no more records if the segment is closed // and not corrupt, otherwise if Err() == io.EOF you should try again when more // data has been written. func (r *LiveReader) Next() bool { for { // If buildRecord returns a non-EOF error, its game up - the segment is // corrupt. If buildRecord returns an EOF, we try and read more in // fillBuffer later on. If that fails to read anything (n=0 && err=EOF), // we return EOF and the user can try again later. If we have a full // page, buildRecord is guaranteed to return a record or a non-EOF; it // has checks the records fit in pages. switch ok, err := r.buildRecord(); { case ok: return true case err != nil && !errors.Is(err, io.EOF): r.err = err return false } // If we've filled the page and not found a record, this // means records have started to span pages. Shouldn't happen // but does and until we found out why, we need to deal with this. if r.permissive && r.writeIndex == pageSize && r.readIndex > 0 { copy(r.buf[:], r.buf[r.readIndex:]) r.writeIndex -= r.readIndex r.readIndex = 0 continue } if r.readIndex == pageSize { r.writeIndex = 0 r.readIndex = 0 } if r.writeIndex != pageSize { n, err := r.fillBuffer() if n == 0 || (err != nil && !errors.Is(err, io.EOF)) { r.err = err return false } } } } // Record returns the current record. // The returned byte slice is only valid until the next call to Next. func (r *LiveReader) Record() []byte { return r.rec } // Rebuild a full record from potentially partial records. Returns false // if there was an error or if we weren't able to read a record for any reason. // Returns true if we read a full record. Any record data is appended to // LiveReader.rec. func (r *LiveReader) buildRecord() (bool, error) { for { // Check that we have data in the internal buffer to read. if r.writeIndex <= r.readIndex { return false, nil } // Attempt to read a record, partial or otherwise. temp, n, err := r.readRecord() if err != nil { return false, err } r.readIndex += n r.total += int64(n) if temp == nil { return false, nil } rt := recTypeFromHeader(r.hdr[0]) if rt == recFirst || rt == recFull { r.rec = r.rec[:0] r.compressBuf = r.compressBuf[:0] } isSnappyCompressed := r.hdr[0]&snappyMask == snappyMask isZstdCompressed := r.hdr[0]&zstdMask == zstdMask if isSnappyCompressed || isZstdCompressed { r.compressBuf = append(r.compressBuf, temp...) } else { r.rec = append(r.rec, temp...) } if err := validateRecord(rt, r.index); err != nil { r.index = 0 return false, err } if rt == recLast || rt == recFull { r.index = 0 if isSnappyCompressed && len(r.compressBuf) > 0 { // The snappy library uses `len` to calculate if we need a new buffer. // In order to allocate as few buffers as possible make the length // equal to the capacity. r.rec = r.rec[:cap(r.rec)] r.rec, err = snappy.Decode(r.rec, r.compressBuf) if err != nil { return false, err } } else if isZstdCompressed && len(r.compressBuf) > 0 { r.rec, err = r.zstdReader.DecodeAll(r.compressBuf, r.rec[:0]) if err != nil { return false, err } } return true, nil } // Only increment i for non-zero records since we use it // to determine valid content record sequences. r.index++ } } // Returns an error if the recType and i indicate an invalid record sequence. // As an example, if i is > 0 because we've read some amount of a partial record // (recFirst, recMiddle, etc. but not recLast) and then we get another recFirst or recFull // instead of a recLast or recMiddle we would have an invalid record. func validateRecord(typ recType, i int) error { switch typ { case recFull: if i != 0 { return errors.New("unexpected full record") } return nil case recFirst: if i != 0 { return errors.New("unexpected first record, dropping buffer") } return nil case recMiddle: if i == 0 { return errors.New("unexpected middle record, dropping buffer") } return nil case recLast: if i == 0 { return errors.New("unexpected last record, dropping buffer") } return nil default: return fmt.Errorf("unexpected record type %d", typ) } } // Read a sub-record (see recType) from the buffer. It could potentially // be a full record (recFull) if the record fits within the bounds of a single page. // Returns a byte slice of the record data read, the number of bytes read, and an error // if there's a non-zero byte in a page term record or the record checksum fails. // This is a non-method function to make it clear it does not mutate the reader. func (r *LiveReader) readRecord() ([]byte, int, error) { // Special case: for recPageTerm, check that are all zeros to end of page, // consume them but don't return them. if r.buf[r.readIndex] == byte(recPageTerm) { // End of page won't necessarily be end of buffer, as we may have // got misaligned by records spanning page boundaries. // r.total % pageSize is the offset into the current page // that r.readIndex points to in buf. Therefore // pageSize - (r.total % pageSize) is the amount left to read of // the current page. remaining := int(pageSize - (r.total % pageSize)) if r.readIndex+remaining > r.writeIndex { return nil, 0, io.EOF } for i := r.readIndex; i < r.readIndex+remaining; i++ { if r.buf[i] != 0 { return nil, 0, errors.New("unexpected non-zero byte in page term bytes") } } return nil, remaining, nil } // Not a recPageTerm; read the record and check the checksum. if r.writeIndex-r.readIndex < recordHeaderSize { return nil, 0, io.EOF } copy(r.hdr[:], r.buf[r.readIndex:r.readIndex+recordHeaderSize]) length := int(binary.BigEndian.Uint16(r.hdr[1:])) crc := binary.BigEndian.Uint32(r.hdr[3:]) if r.readIndex+recordHeaderSize+length > pageSize { if !r.permissive { return nil, 0, fmt.Errorf("record would overflow current page: %d > %d", r.readIndex+recordHeaderSize+length, pageSize) } r.metrics.readerCorruptionErrors.WithLabelValues("record_span_page").Inc() r.logger.Warn("Record spans page boundaries", "start", r.readIndex, "end", recordHeaderSize+length, "pageSize", pageSize) } if recordHeaderSize+length > pageSize { return nil, 0, fmt.Errorf("record length greater than a single page: %d > %d", recordHeaderSize+length, pageSize) } if r.readIndex+recordHeaderSize+length > r.writeIndex { return nil, 0, io.EOF } rec := r.buf[r.readIndex+recordHeaderSize : r.readIndex+recordHeaderSize+length] if c := crc32.Checksum(rec, castagnoliTable); c != crc { return nil, 0, fmt.Errorf("unexpected checksum %x, expected %x", c, crc) } return rec, length + recordHeaderSize, nil }