diff --git a/wal/wal.go b/wal/wal.go index dd0be9cd81..b37032604e 100644 --- a/wal/wal.go +++ b/wal/wal.go @@ -228,19 +228,23 @@ func NewSize(logger log.Logger, reg prometheus.Registerer, dir string, segmentSi } // Fresh dir, no segments yet. if j == -1 { - if w.segment, err = CreateSegment(w.dir, 0); err != nil { - return nil, err - } - } else { - if w.segment, err = OpenWriteSegment(logger, w.dir, j); err != nil { - return nil, err - } - // Correctly initialize donePages. - stat, err := w.segment.Stat() + segment, err := CreateSegment(w.dir, 0) if err != nil { return nil, err } - w.donePages = int(stat.Size() / pageSize) + + if err := w.setSegment(segment); err != nil { + return nil, err + } + } else { + segment, err := OpenWriteSegment(logger, w.dir, j) + if err != nil { + return nil, err + } + + if err := w.setSegment(segment); err != nil { + return nil, err + } } go w.run() @@ -331,7 +335,9 @@ func (w *WAL) Repair(origErr error) error { if err != nil { return err } - w.segment = s + if err := w.setSegment(s); err != nil { + return err + } f, err := os.Open(tmpfn) if err != nil { @@ -382,8 +388,9 @@ func (w *WAL) nextSegment() error { return errors.Wrap(err, "create new segment file") } prev := w.segment - w.segment = next - w.donePages = 0 + if err := w.setSegment(next); err != nil { + return err + } // Don't block further writes by fsyncing the last segment. w.actorc <- func() { @@ -397,6 +404,19 @@ func (w *WAL) nextSegment() error { return nil } +func (w *WAL) setSegment(segment *Segment) error { + w.segment = segment + + // Correctly initialize donePages. + stat, err := segment.Stat() + if err != nil { + return err + } + w.donePages = int(stat.Size() / pageSize) + + return nil +} + // flushPage writes the new contents of the page to disk. If no more records will fit into // the page, the remaining bytes will be set to zero and a new page will be started. // If clear is true, this is enforced regardless of how many bytes are left in the page. diff --git a/wal/wal_test.go b/wal/wal_test.go index 898030addc..16d2775391 100644 --- a/wal/wal_test.go +++ b/wal/wal_test.go @@ -27,7 +27,6 @@ import ( ) func TestWAL_Repair(t *testing.T) { - for name, test := range map[string]struct { corrSgm int // Which segment to corrupt. corrFunc func(f *os.File) // Func that applies the corruption. @@ -115,7 +114,8 @@ func TestWAL_Repair(t *testing.T) { // We create 3 segments with 3 records each and // then corrupt a given record in a given segment. // As a result we want a repaired WAL with given intact records. - w, err := NewSize(nil, nil, dir, 3*pageSize) + segSize := 3 * pageSize + w, err := NewSize(nil, nil, dir, segSize) testutil.Ok(t, err) var records [][]byte @@ -136,7 +136,7 @@ func TestWAL_Repair(t *testing.T) { testutil.Ok(t, f.Close()) - w, err = New(nil, nil, dir) + w, err = NewSize(nil, nil, dir, segSize) testutil.Ok(t, err) sr, err := NewSegmentsReader(dir) @@ -166,6 +166,11 @@ func TestWAL_Repair(t *testing.T) { t.Fatalf("record %d diverges: want %x, got %x", i, records[i][:10], r[:10]) } } + + // Make sure the last segment is the corrupt segment. + _, last, err := w.Segments() + testutil.Ok(t, err) + testutil.Equals(t, test.corrSgm, last) }) } }