From 4c98decfb729a783c765213e8098bee13fa93a0a Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:05:14 +0200 Subject: [PATCH] chore: add backed reader, writer and pipe implementation (#19147) Relates to: https://github.com/coder/coder/issues/18101 This PR introduces a new `backedpipe` package that provides reliable bidirectional byte streams over unreliable network connections. The implementation includes: - `BackedPipe`: Orchestrates a reader and writer to provide transparent reconnection and data replay - `BackedReader`: Handles reading with automatic reconnection, blocking reads when disconnected - `BackedWriter`: Maintains a ring buffer of recent writes for replay during reconnection - `RingBuffer`: Efficient circular buffer implementation for storing data The package enables resilient connections by tracking sequence numbers and replaying missed data after reconnection. It handles connection failures gracefully, automatically reconnecting and resuming data transfer from the appropriate point. --- .../immortalstreams/backedpipe/backed_pipe.go | 350 ++++++ .../backedpipe/backed_pipe_test.go | 989 +++++++++++++++++ .../backedpipe/backed_reader.go | 166 +++ .../backedpipe/backed_reader_test.go | 603 +++++++++++ .../backedpipe/backed_writer.go | 243 +++++ .../backedpipe/backed_writer_test.go | 996 ++++++++++++++++++ .../immortalstreams/backedpipe/ring_buffer.go | 129 +++ .../backedpipe/ring_buffer_internal_test.go | 261 +++++ 8 files changed, 3737 insertions(+) create mode 100644 agent/immortalstreams/backedpipe/backed_pipe.go create mode 100644 agent/immortalstreams/backedpipe/backed_pipe_test.go create mode 100644 agent/immortalstreams/backedpipe/backed_reader.go create mode 100644 agent/immortalstreams/backedpipe/backed_reader_test.go create mode 100644 agent/immortalstreams/backedpipe/backed_writer.go create mode 100644 agent/immortalstreams/backedpipe/backed_writer_test.go create mode 100644 agent/immortalstreams/backedpipe/ring_buffer.go create mode 100644 agent/immortalstreams/backedpipe/ring_buffer_internal_test.go diff --git a/agent/immortalstreams/backedpipe/backed_pipe.go b/agent/immortalstreams/backedpipe/backed_pipe.go new file mode 100644 index 0000000000..4b7a9f0300 --- /dev/null +++ b/agent/immortalstreams/backedpipe/backed_pipe.go @@ -0,0 +1,350 @@ +package backedpipe + +import ( + "context" + "io" + "sync" + + "golang.org/x/sync/errgroup" + "golang.org/x/sync/singleflight" + "golang.org/x/xerrors" +) + +var ( + ErrPipeClosed = xerrors.New("pipe is closed") + ErrPipeAlreadyConnected = xerrors.New("pipe is already connected") + ErrReconnectionInProgress = xerrors.New("reconnection already in progress") + ErrReconnectFailed = xerrors.New("reconnect failed") + ErrInvalidSequenceNumber = xerrors.New("remote sequence number exceeds local sequence") + ErrReconnectWriterFailed = xerrors.New("reconnect writer failed") +) + +// connectionState represents the current state of the BackedPipe connection. +type connectionState int + +const ( + // connected indicates the pipe is connected and operational. + connected connectionState = iota + // disconnected indicates the pipe is not connected but not closed. + disconnected + // reconnecting indicates a reconnection attempt is in progress. + reconnecting + // closed indicates the pipe is permanently closed. + closed +) + +// ErrorEvent represents an error from a reader or writer with connection generation info. +type ErrorEvent struct { + Err error + Component string // "reader" or "writer" + Generation uint64 // connection generation when error occurred +} + +const ( + // Default buffer capacity used by the writer - 64MB + DefaultBufferSize = 64 * 1024 * 1024 +) + +// Reconnector is an interface for establishing connections when the BackedPipe needs to reconnect. +// Implementations should: +// 1. Establish a new connection to the remote side +// 2. Exchange sequence numbers with the remote side +// 3. Return the new connection and the remote's reader sequence number +// +// The readerSeqNum parameter is the local reader's current sequence number +// (total bytes successfully read from the remote). This must be sent to the +// remote so it can replay its data to us starting from this number. +// +// The returned remoteReaderSeqNum should be the remote side's reader sequence +// number (how many bytes of our outbound data it has successfully read). This +// informs our writer where to resume (i.e., which bytes to replay to the remote). +type Reconnector interface { + Reconnect(ctx context.Context, readerSeqNum uint64) (conn io.ReadWriteCloser, remoteReaderSeqNum uint64, err error) +} + +// BackedPipe provides a reliable bidirectional byte stream over unreliable network connections. +// It orchestrates a BackedReader and BackedWriter to provide transparent reconnection +// and data replay capabilities. +type BackedPipe struct { + ctx context.Context + cancel context.CancelFunc + mu sync.RWMutex + reader *BackedReader + writer *BackedWriter + reconnector Reconnector + conn io.ReadWriteCloser + + // State machine + state connectionState + connGen uint64 // Increments on each successful reconnection + + // Unified error handling with generation filtering + errChan chan ErrorEvent + + // singleflight group to dedupe concurrent ForceReconnect calls + sf singleflight.Group + + // Track first error per generation to avoid duplicate reconnections + lastErrorGen uint64 +} + +// NewBackedPipe creates a new BackedPipe with default options and the specified reconnector. +// The pipe starts disconnected and must be connected using Connect(). +func NewBackedPipe(ctx context.Context, reconnector Reconnector) *BackedPipe { + pipeCtx, cancel := context.WithCancel(ctx) + + errChan := make(chan ErrorEvent, 1) + + bp := &BackedPipe{ + ctx: pipeCtx, + cancel: cancel, + reconnector: reconnector, + state: disconnected, + connGen: 0, // Start with generation 0 + errChan: errChan, + } + + // Create reader and writer with typed error channel for generation-aware error reporting + bp.reader = NewBackedReader(errChan) + bp.writer = NewBackedWriter(DefaultBufferSize, errChan) + + // Start error handler goroutine + go bp.handleErrors() + + return bp +} + +// Connect establishes the initial connection using the reconnect function. +func (bp *BackedPipe) Connect() error { + bp.mu.Lock() + defer bp.mu.Unlock() + + if bp.state == closed { + return ErrPipeClosed + } + + if bp.state == connected { + return ErrPipeAlreadyConnected + } + + // Use internal context for the actual reconnect operation to ensure + // Close() reliably cancels any in-flight attempt. + return bp.reconnectLocked() +} + +// Read implements io.Reader by delegating to the BackedReader. +func (bp *BackedPipe) Read(p []byte) (int, error) { + return bp.reader.Read(p) +} + +// Write implements io.Writer by delegating to the BackedWriter. +func (bp *BackedPipe) Write(p []byte) (int, error) { + bp.mu.RLock() + writer := bp.writer + state := bp.state + bp.mu.RUnlock() + + if state == closed { + return 0, io.EOF + } + + return writer.Write(p) +} + +// Close closes the pipe and all underlying connections. +func (bp *BackedPipe) Close() error { + bp.mu.Lock() + defer bp.mu.Unlock() + + if bp.state == closed { + return nil + } + + bp.state = closed + bp.cancel() // Cancel main context + + // Close all components in parallel to avoid deadlocks + // + // IMPORTANT: The connection must be closed first to unblock any + // readers or writers that might be holding the mutex on Read/Write + var g errgroup.Group + + if bp.conn != nil { + conn := bp.conn + g.Go(func() error { + return conn.Close() + }) + bp.conn = nil + } + + if bp.reader != nil { + reader := bp.reader + g.Go(func() error { + return reader.Close() + }) + } + + if bp.writer != nil { + writer := bp.writer + g.Go(func() error { + return writer.Close() + }) + } + + // Wait for all close operations to complete and return any error + return g.Wait() +} + +// Connected returns whether the pipe is currently connected. +func (bp *BackedPipe) Connected() bool { + bp.mu.RLock() + defer bp.mu.RUnlock() + return bp.state == connected && bp.reader.Connected() && bp.writer.Connected() +} + +// reconnectLocked handles the reconnection logic. Must be called with write lock held. +func (bp *BackedPipe) reconnectLocked() error { + if bp.state == reconnecting { + return ErrReconnectionInProgress + } + + bp.state = reconnecting + defer func() { + // Only reset to disconnected if we're still in reconnecting state + // (successful reconnection will set state to connected) + if bp.state == reconnecting { + bp.state = disconnected + } + }() + + // Close existing connection if any + if bp.conn != nil { + _ = bp.conn.Close() + bp.conn = nil + } + + // Increment the generation and update both reader and writer. + // We do it now to track even the connections that fail during + // Reconnect. + bp.connGen++ + bp.reader.SetGeneration(bp.connGen) + bp.writer.SetGeneration(bp.connGen) + + // Reconnect reader and writer + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go bp.reader.Reconnect(seqNum, newR) + + // Get the precise reader sequence number from the reader while it holds its lock + readerSeqNum, ok := <-seqNum + if !ok { + // Reader was closed during reconnection + return ErrReconnectFailed + } + + // Perform reconnect using the exact sequence number we just received + conn, remoteReaderSeqNum, err := bp.reconnector.Reconnect(bp.ctx, readerSeqNum) + if err != nil { + // Unblock reader reconnect + newR <- nil + return ErrReconnectFailed + } + + // Provide the new connection to the reader (reader still holds its lock) + newR <- conn + + // Replay our outbound data from the remote's reader sequence number + writerReconnectErr := bp.writer.Reconnect(remoteReaderSeqNum, conn) + if writerReconnectErr != nil { + return ErrReconnectWriterFailed + } + + // Success - update state + bp.conn = conn + bp.state = connected + + return nil +} + +// handleErrors listens for connection errors from reader/writer and triggers reconnection. +// It filters errors from old connections and ensures only the first error per generation +// triggers reconnection. +func (bp *BackedPipe) handleErrors() { + for { + select { + case <-bp.ctx.Done(): + return + case errorEvt := <-bp.errChan: + bp.handleConnectionError(errorEvt) + } + } +} + +// handleConnectionError handles errors from either reader or writer components. +// It filters errors from old connections and ensures only one reconnection per generation. +func (bp *BackedPipe) handleConnectionError(errorEvt ErrorEvent) { + bp.mu.Lock() + defer bp.mu.Unlock() + + // Skip if already closed + if bp.state == closed { + return + } + + // Filter errors from old connections (lower generation) + if errorEvt.Generation < bp.connGen { + return + } + + // Skip if not connected (already disconnected or reconnecting) + if bp.state != connected { + return + } + + // Skip if we've already seen an error for this generation + if bp.lastErrorGen >= errorEvt.Generation { + return + } + + // This is the first error for this generation + bp.lastErrorGen = errorEvt.Generation + + // Mark as disconnected + bp.state = disconnected + + // Try to reconnect using internal context + reconnectErr := bp.reconnectLocked() + + if reconnectErr != nil { + // Reconnection failed - log or handle as needed + // For now, we'll just continue and wait for manual reconnection + _ = errorEvt.Err // Use the original error from the component + _ = errorEvt.Component // Component info available for potential logging by higher layers + } +} + +// ForceReconnect forces a reconnection attempt immediately. +// This can be used to force a reconnection if a new connection is established. +// It prevents duplicate reconnections when called concurrently. +func (bp *BackedPipe) ForceReconnect() error { + // Deduplicate concurrent ForceReconnect calls so only one reconnection + // attempt runs at a time from this API. Use the pipe's internal context + // to ensure Close() cancels any in-flight attempt. + _, err, _ := bp.sf.Do("force-reconnect", func() (interface{}, error) { + bp.mu.Lock() + defer bp.mu.Unlock() + + if bp.state == closed { + return nil, io.EOF + } + + // Don't force reconnect if already reconnecting + if bp.state == reconnecting { + return nil, ErrReconnectionInProgress + } + + return nil, bp.reconnectLocked() + }) + return err +} diff --git a/agent/immortalstreams/backedpipe/backed_pipe_test.go b/agent/immortalstreams/backedpipe/backed_pipe_test.go new file mode 100644 index 0000000000..57d5a4724d --- /dev/null +++ b/agent/immortalstreams/backedpipe/backed_pipe_test.go @@ -0,0 +1,989 @@ +package backedpipe_test + +import ( + "bytes" + "context" + "io" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/agent/immortalstreams/backedpipe" + "github.com/coder/coder/v2/testutil" +) + +// mockConnection implements io.ReadWriteCloser for testing +type mockConnection struct { + mu sync.Mutex + readBuffer bytes.Buffer + writeBuffer bytes.Buffer + closed bool + readError error + writeError error + closeError error + readFunc func([]byte) (int, error) + writeFunc func([]byte) (int, error) + seqNum uint64 +} + +func newMockConnection() *mockConnection { + return &mockConnection{} +} + +func (mc *mockConnection) Read(p []byte) (int, error) { + mc.mu.Lock() + defer mc.mu.Unlock() + + if mc.readFunc != nil { + return mc.readFunc(p) + } + + if mc.readError != nil { + return 0, mc.readError + } + + return mc.readBuffer.Read(p) +} + +func (mc *mockConnection) Write(p []byte) (int, error) { + mc.mu.Lock() + defer mc.mu.Unlock() + + if mc.writeFunc != nil { + return mc.writeFunc(p) + } + + if mc.writeError != nil { + return 0, mc.writeError + } + + return mc.writeBuffer.Write(p) +} + +func (mc *mockConnection) Close() error { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.closed = true + return mc.closeError +} + +func (mc *mockConnection) WriteString(s string) { + mc.mu.Lock() + defer mc.mu.Unlock() + _, _ = mc.readBuffer.WriteString(s) +} + +func (mc *mockConnection) ReadString() string { + mc.mu.Lock() + defer mc.mu.Unlock() + return mc.writeBuffer.String() +} + +func (mc *mockConnection) SetReadError(err error) { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.readError = err +} + +func (mc *mockConnection) SetWriteError(err error) { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.writeError = err +} + +func (mc *mockConnection) Reset() { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.readBuffer.Reset() + mc.writeBuffer.Reset() + mc.readError = nil + mc.writeError = nil + mc.closed = false +} + +// mockReconnector implements the Reconnector interface for testing +type mockReconnector struct { + mu sync.Mutex + connections []*mockConnection + connectionIndex int + callCount int + signalChan chan struct{} +} + +// Reconnect implements the Reconnector interface +func (m *mockReconnector) Reconnect(ctx context.Context, readerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.callCount++ + + if m.connectionIndex >= len(m.connections) { + return nil, 0, xerrors.New("no more connections available") + } + + conn := m.connections[m.connectionIndex] + m.connectionIndex++ + + // Signal when reconnection happens + if m.connectionIndex > 1 { + select { + case m.signalChan <- struct{}{}: + default: + } + } + + // Determine remoteReaderSeqNum (how many bytes of our outbound data the remote has read) + var remoteReaderSeqNum uint64 + switch { + case m.callCount == 1: + remoteReaderSeqNum = 0 + case conn.seqNum != 0: + remoteReaderSeqNum = conn.seqNum + default: + // Default to 0 if unspecified + remoteReaderSeqNum = 0 + } + + return conn, remoteReaderSeqNum, nil +} + +// GetCallCount returns the current call count in a thread-safe manner +func (m *mockReconnector) GetCallCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.callCount +} + +// mockReconnectFunc creates a unified reconnector with all behaviors enabled +func mockReconnectFunc(connections ...*mockConnection) (*mockReconnector, chan struct{}) { + signalChan := make(chan struct{}, 1) + + reconnector := &mockReconnector{ + connections: connections, + signalChan: signalChan, + } + + return reconnector, signalChan +} + +// blockingReconnector is a reconnector that blocks on a channel for deterministic testing +type blockingReconnector struct { + conn1 *mockConnection + conn2 *mockConnection + callCount int + blockChan <-chan struct{} + blockedChan chan struct{} + mu sync.Mutex + signalOnce sync.Once // Ensure we only signal once for the first actual reconnect +} + +func (b *blockingReconnector) Reconnect(ctx context.Context, readerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + b.mu.Lock() + b.callCount++ + currentCall := b.callCount + b.mu.Unlock() + + if currentCall == 1 { + // Initial connect + return b.conn1, 0, nil + } + + // Signal that we're about to block, but only once for the first reconnect attempt + // This ensures we properly test singleflight deduplication + b.signalOnce.Do(func() { + select { + case b.blockedChan <- struct{}{}: + default: + // If channel is full, don't block + } + }) + + // For subsequent calls, block until channel is closed + select { + case <-b.blockChan: + // Channel closed, proceed with reconnection + case <-ctx.Done(): + return nil, 0, ctx.Err() + } + + return b.conn2, 0, nil +} + +// GetCallCount returns the current call count in a thread-safe manner +func (b *blockingReconnector) GetCallCount() int { + b.mu.Lock() + defer b.mu.Unlock() + return b.callCount +} + +func mockBlockingReconnectFunc(conn1, conn2 *mockConnection, blockChan <-chan struct{}) (*blockingReconnector, chan struct{}) { + blockedChan := make(chan struct{}, 1) + reconnector := &blockingReconnector{ + conn1: conn1, + conn2: conn2, + blockChan: blockChan, + blockedChan: blockedChan, + } + + return reconnector, blockedChan +} + +// eofTestReconnector is a custom reconnector for the EOF test case +type eofTestReconnector struct { + mu sync.Mutex + conn1 io.ReadWriteCloser + conn2 io.ReadWriteCloser + callCount int +} + +func (e *eofTestReconnector) Reconnect(ctx context.Context, readerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + e.mu.Lock() + defer e.mu.Unlock() + + e.callCount++ + + if e.callCount == 1 { + return e.conn1, 0, nil + } + if e.callCount == 2 { + // Second call is the reconnection after EOF + // Return 5 to indicate remote has read all 5 bytes of "hello" + return e.conn2, 5, nil + } + + return nil, 0, xerrors.New("no more connections") +} + +// GetCallCount returns the current call count in a thread-safe manner +func (e *eofTestReconnector) GetCallCount() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.callCount +} + +func TestBackedPipe_NewBackedPipe(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reconnectFn, _ := mockReconnectFunc(newMockConnection()) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + require.NotNil(t, bp) + require.False(t, bp.Connected()) +} + +func TestBackedPipe_Connect(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnector, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnector) + defer bp.Close() + + err := bp.Connect() + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, reconnector.GetCallCount()) +} + +func TestBackedPipe_ConnectAlreadyConnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + err := bp.Connect() + require.NoError(t, err) + + // Second connect should fail + err = bp.Connect() + require.Error(t, err) + require.ErrorIs(t, err, backedpipe.ErrPipeAlreadyConnected) +} + +func TestBackedPipe_ConnectAfterClose(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Close() + require.NoError(t, err) + + err = bp.Connect() + require.Error(t, err) + require.ErrorIs(t, err, backedpipe.ErrPipeClosed) +} + +func TestBackedPipe_BasicReadWrite(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + err := bp.Connect() + require.NoError(t, err) + + // Write data + n, err := bp.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Simulate data coming back + conn.WriteString("world") + + // Read data + buf := make([]byte, 10) + n, err = bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, "world", string(buf[:n])) +} + +func TestBackedPipe_WriteBeforeConnect(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + conn := newMockConnection() + reconnectFn, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + // Write before connecting should block + writeComplete := make(chan error, 1) + go func() { + _, err := bp.Write([]byte("hello")) + writeComplete <- err + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when disconnected") + case <-time.After(100 * time.Millisecond): + // Expected - write is blocked + } + + // Connect should unblock the write + err := bp.Connect() + require.NoError(t, err) + + // Write should now complete + err = testutil.RequireReceive(ctx, t, writeComplete) + require.NoError(t, err) + + // Check that data was replayed to connection + require.Equal(t, "hello", conn.ReadString()) +} + +func TestBackedPipe_ReadBlocksWhenDisconnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testCtx := testutil.Context(t, testutil.WaitShort) + reconnectFn, _ := mockReconnectFunc(newMockConnection()) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + // Start a read that should block + readDone := make(chan struct{}) + readStarted := make(chan struct{}, 1) + var readErr error + + go func() { + defer close(readDone) + readStarted <- struct{}{} // Signal that we're about to start the read + buf := make([]byte, 10) + _, readErr = bp.Read(buf) + }() + + // Wait for the goroutine to start + testutil.TryReceive(testCtx, t, readStarted) + + // Ensure the read is actually blocked by verifying it hasn't completed + require.Eventually(t, func() bool { + select { + case <-readDone: + t.Fatal("Read should be blocked when disconnected") + return false + default: + // Good, still blocked + return true + } + }, testutil.WaitShort, testutil.IntervalMedium) + + // Close should unblock the read + bp.Close() + + testutil.TryReceive(testCtx, t, readDone) + require.Equal(t, io.EOF, readErr) +} + +func TestBackedPipe_Reconnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testCtx := testutil.Context(t, testutil.WaitShort) + conn1 := newMockConnection() + conn2 := newMockConnection() + conn2.seqNum = 17 // Remote has received 17 bytes, so replay from sequence 17 + reconnectFn, signalChan := mockReconnectFunc(conn1, conn2) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + // Initial connect + err := bp.Connect() + require.NoError(t, err) + + // Write some data before failure + bp.Write([]byte("before disconnect***")) + + // Simulate connection failure + conn1.SetReadError(xerrors.New("connection lost")) + conn1.SetWriteError(xerrors.New("connection lost")) + + // Trigger a write to cause the pipe to notice the failure + _, _ = bp.Write([]byte("trigger failure ")) + + testutil.RequireReceive(testCtx, t, signalChan) + + // Wait for reconnection to complete + require.Eventually(t, func() bool { + return bp.Connected() + }, testutil.WaitShort, testutil.IntervalFast, "pipe should reconnect") + + replayedData := conn2.ReadString() + require.Equal(t, "***trigger failure ", replayedData, "Should replay exactly the data written after sequence 17") + + // Verify that new writes work with the reconnected pipe + _, err = bp.Write([]byte("new data after reconnect")) + require.NoError(t, err) + + // Read all data from the connection (replayed + new data) + allData := conn2.ReadString() + require.Equal(t, "***trigger failure new data after reconnect", allData, "Should have replayed data plus new data") +} + +func TestBackedPipe_Close(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect() + require.NoError(t, err) + + err = bp.Close() + require.NoError(t, err) + require.True(t, conn.closed) + + // Operations after close should fail + _, err = bp.Read(make([]byte, 10)) + require.Equal(t, io.EOF, err) + + _, err = bp.Write([]byte("test")) + require.Equal(t, io.EOF, err) +} + +func TestBackedPipe_CloseIdempotent(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Close() + require.NoError(t, err) + + // Second close should be no-op + err = bp.Close() + require.NoError(t, err) +} + +func TestBackedPipe_ReconnectFunctionFailure(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + failingReconnector := &mockReconnector{ + connections: nil, // No connections available + } + + bp := backedpipe.NewBackedPipe(ctx, failingReconnector) + defer bp.Close() + + err := bp.Connect() + require.Error(t, err) + require.ErrorIs(t, err, backedpipe.ErrReconnectFailed) + require.False(t, bp.Connected()) +} + +func TestBackedPipe_ForceReconnect(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn1 := newMockConnection() + conn2 := newMockConnection() + // Set conn2 sequence number to 9 to indicate remote has read all 9 bytes of "test data" + conn2.seqNum = 9 + reconnector, _ := mockReconnectFunc(conn1, conn2) + + bp := backedpipe.NewBackedPipe(ctx, reconnector) + defer bp.Close() + + // Initial connect + err := bp.Connect() + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, reconnector.GetCallCount()) + + // Write some data to the first connection + _, err = bp.Write([]byte("test data")) + require.NoError(t, err) + require.Equal(t, "test data", conn1.ReadString()) + + // Force a reconnection + err = bp.ForceReconnect() + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 2, reconnector.GetCallCount()) + + // Since the mock returns the proper sequence number, no data should be replayed + // The new connection should be empty + require.Equal(t, "", conn2.ReadString()) + + // Verify that data can still be written and read after forced reconnection + _, err = bp.Write([]byte("new data")) + require.NoError(t, err) + require.Equal(t, "new data", conn2.ReadString()) + + // Verify that reads work with the new connection + conn2.WriteString("response data") + buf := make([]byte, 20) + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 13, n) + require.Equal(t, "response data", string(buf[:n])) +} + +func TestBackedPipe_ForceReconnectWhenClosed(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Close the pipe first + err := bp.Close() + require.NoError(t, err) + + // Try to force reconnect when closed + err = bp.ForceReconnect() + require.Error(t, err) + require.Equal(t, io.EOF, err) +} + +func TestBackedPipe_StateTransitionsAndGenerationTracking(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn1 := newMockConnection() + conn2 := newMockConnection() + conn3 := newMockConnection() + reconnector, signalChan := mockReconnectFunc(conn1, conn2, conn3) + + bp := backedpipe.NewBackedPipe(ctx, reconnector) + defer bp.Close() + + // Initial state should be disconnected + require.False(t, bp.Connected()) + + // Connect should transition to connected + err := bp.Connect() + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, reconnector.GetCallCount()) + + // Write some data + _, err = bp.Write([]byte("test data gen 1")) + require.NoError(t, err) + + // Simulate connection failure by setting errors on connection + conn1.SetReadError(xerrors.New("connection lost")) + conn1.SetWriteError(xerrors.New("connection lost")) + + // Trigger a write to cause the pipe to notice the failure + _, _ = bp.Write([]byte("trigger failure")) + + // Wait for reconnection signal + testutil.RequireReceive(testutil.Context(t, testutil.WaitShort), t, signalChan) + + // Wait for reconnection to complete + require.Eventually(t, func() bool { + return bp.Connected() + }, testutil.WaitShort, testutil.IntervalFast, "should reconnect") + require.Equal(t, 2, reconnector.GetCallCount()) + + // Force another reconnection + err = bp.ForceReconnect() + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 3, reconnector.GetCallCount()) + + // Close should transition to closed state + err = bp.Close() + require.NoError(t, err) + require.False(t, bp.Connected()) + + // Operations on closed pipe should fail + err = bp.Connect() + require.Equal(t, backedpipe.ErrPipeClosed, err) + + err = bp.ForceReconnect() + require.Equal(t, io.EOF, err) +} + +func TestBackedPipe_GenerationFiltering(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn1 := newMockConnection() + conn2 := newMockConnection() + reconnector, _ := mockReconnectFunc(conn1, conn2) + + bp := backedpipe.NewBackedPipe(ctx, reconnector) + defer bp.Close() + + // Connect + err := bp.Connect() + require.NoError(t, err) + require.True(t, bp.Connected()) + + // Simulate multiple rapid errors from the same connection generation + // Only the first one should trigger reconnection + conn1.SetReadError(xerrors.New("error 1")) + conn1.SetWriteError(xerrors.New("error 2")) + + // Trigger multiple errors quickly + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _, _ = bp.Write([]byte("trigger error 1")) + }() + go func() { + defer wg.Done() + _, _ = bp.Write([]byte("trigger error 2")) + }() + + // Wait for both writes to complete + wg.Wait() + + // Wait for reconnection to complete + require.Eventually(t, func() bool { + return bp.Connected() + }, testutil.WaitShort, testutil.IntervalFast, "should reconnect once") + + // Should have only reconnected once despite multiple errors + require.Equal(t, 2, reconnector.GetCallCount()) // Initial connect + 1 reconnect +} + +func TestBackedPipe_DuplicateReconnectionPrevention(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testCtx := testutil.Context(t, testutil.WaitShort) + + // Create a blocking reconnector for deterministic testing + conn1 := newMockConnection() + conn2 := newMockConnection() + blockChan := make(chan struct{}) + reconnector, blockedChan := mockBlockingReconnectFunc(conn1, conn2, blockChan) + + bp := backedpipe.NewBackedPipe(ctx, reconnector) + defer bp.Close() + + // Initial connect + err := bp.Connect() + require.NoError(t, err) + require.Equal(t, 1, reconnector.GetCallCount(), "should have exactly 1 call after initial connect") + + // We'll use channels to coordinate the test execution: + // 1. Start all goroutines but have them wait + // 2. Release the first one and wait for it to block + // 3. Release the others while the first is still blocked + + const numConcurrent = 3 + startSignals := make([]chan struct{}, numConcurrent) + startedSignals := make([]chan struct{}, numConcurrent) + for i := range startSignals { + startSignals[i] = make(chan struct{}) + startedSignals[i] = make(chan struct{}) + } + + errors := make([]error, numConcurrent) + var wg sync.WaitGroup + + // Start all goroutines + for i := 0; i < numConcurrent; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + // Wait for the signal to start + <-startSignals[idx] + // Signal that we're about to call ForceReconnect + close(startedSignals[idx]) + errors[idx] = bp.ForceReconnect() + }(i) + } + + // Start the first ForceReconnect and wait for it to block + close(startSignals[0]) + <-startedSignals[0] + + // Wait for the first reconnect to actually start and block + testutil.RequireReceive(testCtx, t, blockedChan) + + // Now start all the other ForceReconnect calls + // They should all join the same singleflight operation + for i := 1; i < numConcurrent; i++ { + close(startSignals[i]) + } + + // Wait for all additional goroutines to have started their calls + for i := 1; i < numConcurrent; i++ { + <-startedSignals[i] + } + + // At this point, one reconnect has started and is blocked, + // and all other goroutines have called ForceReconnect and should be + // waiting on the same singleflight operation. + // Due to singleflight, only one reconnect should have been attempted. + require.Equal(t, 2, reconnector.GetCallCount(), "should have exactly 2 calls: initial connect + 1 reconnect due to singleflight") + + // Release the blocking reconnect function + close(blockChan) + + // Wait for all ForceReconnect calls to complete + wg.Wait() + + // All calls should succeed (they share the same result from singleflight) + for i, err := range errors { + require.NoError(t, err, "ForceReconnect %d should succeed", i, err) + } + + // Final verification: call count should still be exactly 2 + require.Equal(t, 2, reconnector.GetCallCount(), "final call count should be exactly 2: initial connect + 1 singleflight reconnect") +} + +func TestBackedPipe_SingleReconnectionOnMultipleErrors(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testCtx := testutil.Context(t, testutil.WaitShort) + + // Create connections for initial connect and reconnection + conn1 := newMockConnection() + conn2 := newMockConnection() + reconnector, signalChan := mockReconnectFunc(conn1, conn2) + + bp := backedpipe.NewBackedPipe(ctx, reconnector) + defer bp.Close() + + // Initial connect + err := bp.Connect() + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, reconnector.GetCallCount()) + + // Write some initial data to establish the connection + _, err = bp.Write([]byte("initial data")) + require.NoError(t, err) + + // Set up both read and write errors on the connection + conn1.SetReadError(xerrors.New("read connection lost")) + conn1.SetWriteError(xerrors.New("write connection lost")) + + // Trigger write error (this will trigger reconnection) + go func() { + _, _ = bp.Write([]byte("trigger write error")) + }() + + // Wait for reconnection to start + testutil.RequireReceive(testCtx, t, signalChan) + + // Wait for reconnection to complete + require.Eventually(t, func() bool { + return bp.Connected() + }, testutil.WaitShort, testutil.IntervalFast, "should reconnect after write error") + + // Verify that only one reconnection occurred + require.Equal(t, 2, reconnector.GetCallCount(), "should have exactly 2 calls: initial connect + 1 reconnection") + require.True(t, bp.Connected(), "should be connected after reconnection") +} + +func TestBackedPipe_ForceReconnectWhenDisconnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnector, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnector) + defer bp.Close() + + // Don't connect initially, just force reconnect + err := bp.ForceReconnect() + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, reconnector.GetCallCount()) + + // Verify we can write and read + _, err = bp.Write([]byte("test")) + require.NoError(t, err) + require.Equal(t, "test", conn.ReadString()) + + conn.WriteString("response") + buf := make([]byte, 10) + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 8, n) + require.Equal(t, "response", string(buf[:n])) +} + +func TestBackedPipe_EOFTriggersReconnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // Create connections where we can control when EOF occurs + conn1 := newMockConnection() + conn2 := newMockConnection() + conn2.WriteString("newdata") // Pre-populate conn2 with data + + // Make conn1 return EOF after reading "world" + hasReadData := false + conn1.readFunc = func(p []byte) (int, error) { + // Don't lock here - the Read method already holds the lock + + // First time: return "world" + if !hasReadData && conn1.readBuffer.Len() > 0 { + n, _ := conn1.readBuffer.Read(p) + hasReadData = true + return n, nil + } + // After that: return EOF + return 0, io.EOF + } + conn1.WriteString("world") + + reconnector := &eofTestReconnector{ + conn1: conn1, + conn2: conn2, + } + + bp := backedpipe.NewBackedPipe(ctx, reconnector) + defer bp.Close() + + // Initial connect + err := bp.Connect() + require.NoError(t, err) + require.Equal(t, 1, reconnector.GetCallCount()) + + // Write some data + _, err = bp.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + + // First read should succeed + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, "world", string(buf[:n])) + + // Next read will encounter EOF and should trigger reconnection + // After reconnection, it should read from conn2 + n, err = bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, "newdata", string(buf[:n])) + + // Verify reconnection happened + require.Equal(t, 2, reconnector.GetCallCount()) + + // Verify the pipe is still connected and functional + require.True(t, bp.Connected()) + + // Further writes should go to the new connection + _, err = bp.Write([]byte("aftereof")) + require.NoError(t, err) + require.Equal(t, "aftereof", conn2.ReadString()) +} + +func BenchmarkBackedPipe_Write(b *testing.B) { + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + bp.Connect() + b.Cleanup(func() { + _ = bp.Close() + }) + + data := make([]byte, 1024) // 1KB writes + + b.ResetTimer() + for i := 0; i < b.N; i++ { + bp.Write(data) + } +} + +func BenchmarkBackedPipe_Read(b *testing.B) { + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + bp.Connect() + b.Cleanup(func() { + _ = bp.Close() + }) + + buf := make([]byte, 1024) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Fill connection with fresh data for each iteration + conn.WriteString(string(buf)) + bp.Read(buf) + } +} diff --git a/agent/immortalstreams/backedpipe/backed_reader.go b/agent/immortalstreams/backedpipe/backed_reader.go new file mode 100644 index 0000000000..a8e24ad446 --- /dev/null +++ b/agent/immortalstreams/backedpipe/backed_reader.go @@ -0,0 +1,166 @@ +package backedpipe + +import ( + "io" + "sync" +) + +// BackedReader wraps an unreliable io.Reader and makes it resilient to disconnections. +// It tracks sequence numbers for all bytes read and can handle reconnection, +// blocking reads when disconnected instead of erroring. +type BackedReader struct { + mu sync.Mutex + cond *sync.Cond + reader io.Reader + sequenceNum uint64 + closed bool + + // Error channel for generation-aware error reporting + errorEventChan chan<- ErrorEvent + + // Current connection generation for error reporting + currentGen uint64 +} + +// NewBackedReader creates a new BackedReader with generation-aware error reporting. +// The reader is initially disconnected and must be connected using Reconnect before +// reads will succeed. The errorEventChan will receive ErrorEvent structs containing +// error details, component info, and connection generation. +func NewBackedReader(errorEventChan chan<- ErrorEvent) *BackedReader { + if errorEventChan == nil { + panic("error event channel cannot be nil") + } + br := &BackedReader{ + errorEventChan: errorEventChan, + } + br.cond = sync.NewCond(&br.mu) + return br +} + +// Read implements io.Reader. It blocks when disconnected until either: +// 1. A reconnection is established +// 2. The reader is closed +// +// When connected, it reads from the underlying reader and updates sequence numbers. +// Connection failures are automatically detected and reported to the higher layer via callback. +func (br *BackedReader) Read(p []byte) (int, error) { + br.mu.Lock() + defer br.mu.Unlock() + + for { + // Step 1: Wait until we have a reader or are closed + for br.reader == nil && !br.closed { + br.cond.Wait() + } + + if br.closed { + return 0, io.EOF + } + + // Step 2: Perform the read while holding the mutex + // This ensures proper synchronization with Reconnect and Close operations + n, err := br.reader.Read(p) + br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract + + if err == nil { + return n, nil + } + + // Mark reader as disconnected so future reads will wait for reconnection + br.reader = nil + + // Notify parent of error with generation information + select { + case br.errorEventChan <- ErrorEvent{ + Err: err, + Component: "reader", + Generation: br.currentGen, + }: + default: + // Channel is full, drop the error. + // This is not a problem, because we set the reader to nil + // and block until reconnected so no new errors will be sent + // until pipe processes the error and reconnects. + } + + // If we got some data before the error, return it now + if n > 0 { + return n, nil + } + } +} + +// Reconnect coordinates reconnection using channels for better synchronization. +// The seqNum channel is used to send the current sequence number to the caller. +// The newR channel is used to receive the new reader from the caller. +// This allows for better coordination during the reconnection process. +func (br *BackedReader) Reconnect(seqNum chan<- uint64, newR <-chan io.Reader) { + // Grab the lock + br.mu.Lock() + defer br.mu.Unlock() + + if br.closed { + // Close the channel to indicate closed state + close(seqNum) + return + } + + // Get the sequence number to send to the other side via seqNum channel + seqNum <- br.sequenceNum + close(seqNum) + + // Wait for the reconnect to complete, via newR channel, and give us a new io.Reader + newReader := <-newR + + // If reconnection fails while we are starting it, the caller sends nil on newR + if newReader == nil { + // Reconnection failed, keep current state + return + } + + // Reconnection successful + br.reader = newReader + + // Notify any waiting reads via the cond + br.cond.Broadcast() +} + +// Close the reader and wake up any blocked reads. +// After closing, all Read calls will return io.EOF. +func (br *BackedReader) Close() error { + br.mu.Lock() + defer br.mu.Unlock() + + if br.closed { + return nil + } + + br.closed = true + br.reader = nil + + // Wake up any blocked reads + br.cond.Broadcast() + + return nil +} + +// SequenceNum returns the current sequence number (total bytes read). +func (br *BackedReader) SequenceNum() uint64 { + br.mu.Lock() + defer br.mu.Unlock() + return br.sequenceNum +} + +// Connected returns whether the reader is currently connected. +func (br *BackedReader) Connected() bool { + br.mu.Lock() + defer br.mu.Unlock() + return br.reader != nil +} + +// SetGeneration sets the current connection generation for error reporting. +func (br *BackedReader) SetGeneration(generation uint64) { + br.mu.Lock() + defer br.mu.Unlock() + br.currentGen = generation +} diff --git a/agent/immortalstreams/backedpipe/backed_reader_test.go b/agent/immortalstreams/backedpipe/backed_reader_test.go new file mode 100644 index 0000000000..a1a8de1590 --- /dev/null +++ b/agent/immortalstreams/backedpipe/backed_reader_test.go @@ -0,0 +1,603 @@ +package backedpipe_test + +import ( + "context" + "io" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/agent/immortalstreams/backedpipe" + "github.com/coder/coder/v2/testutil" +) + +// mockReader implements io.Reader with controllable behavior for testing +type mockReader struct { + mu sync.Mutex + data []byte + pos int + err error + readFunc func([]byte) (int, error) +} + +func newMockReader(data string) *mockReader { + return &mockReader{data: []byte(data)} +} + +func (mr *mockReader) Read(p []byte) (int, error) { + mr.mu.Lock() + defer mr.mu.Unlock() + + if mr.readFunc != nil { + return mr.readFunc(p) + } + + if mr.err != nil { + return 0, mr.err + } + + if mr.pos >= len(mr.data) { + return 0, io.EOF + } + + n := copy(p, mr.data[mr.pos:]) + mr.pos += n + return n, nil +} + +func (mr *mockReader) setError(err error) { + mr.mu.Lock() + defer mr.mu.Unlock() + mr.err = err +} + +func TestBackedReader_NewBackedReader(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + require.NotNil(t, br) + require.Equal(t, uint64(0), br.SequenceNum()) + require.False(t, br.Connected()) +} + +func TestBackedReader_BasicReadOperation(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + reader := newMockReader("hello world") + + // Connect the reader + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number from reader + seq := testutil.RequireReceive(ctx, t, seqNum) + require.Equal(t, uint64(0), seq) + + // Send new reader + testutil.RequireSend(ctx, t, newR, io.Reader(reader)) + + // Read data + buf := make([]byte, 5) + n, err := br.Read(buf) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, "hello", string(buf)) + require.Equal(t, uint64(5), br.SequenceNum()) + + // Read more data + n, err = br.Read(buf) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, " worl", string(buf)) + require.Equal(t, uint64(10), br.SequenceNum()) +} + +func TestBackedReader_ReadBlocksWhenDisconnected(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + + // Start a read operation that should block + readDone := make(chan struct{}) + var readErr error + var readBuf []byte + var readN int + + go func() { + defer close(readDone) + buf := make([]byte, 10) + readN, readErr = br.Read(buf) + readBuf = buf[:readN] + }() + + // Ensure the read is actually blocked by verifying it hasn't completed + // and that the reader is not connected + select { + case <-readDone: + t.Fatal("Read should be blocked when disconnected") + default: + // Read is still blocked, which is what we want + } + require.False(t, br.Connected(), "Reader should not be connected") + + // Connect and the read should unblock + reader := newMockReader("test") + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(reader)) + + // Wait for read to complete + testutil.TryReceive(ctx, t, readDone) + require.NoError(t, readErr) + require.Equal(t, "test", string(readBuf)) +} + +func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + reader1 := newMockReader("first") + + // Initial connection + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(reader1)) + + // Read some data + buf := make([]byte, 5) + n, err := br.Read(buf) + require.NoError(t, err) + require.Equal(t, "first", string(buf[:n])) + require.Equal(t, uint64(5), br.SequenceNum()) + + // Simulate connection failure + reader1.setError(xerrors.New("connection lost")) + + // Start a read that will block due to connection failure + readDone := make(chan error, 1) + go func() { + _, err := br.Read(buf) + readDone <- err + }() + + // Wait for the error to be reported via error channel + receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan) + require.Error(t, receivedErrorEvent.Err) + require.Equal(t, "reader", receivedErrorEvent.Component) + require.Contains(t, receivedErrorEvent.Err.Error(), "connection lost") + + // Verify read is still blocked + select { + case err := <-readDone: + t.Fatalf("Read should still be blocked, but completed with: %v", err) + default: + // Good, still blocked + } + + // Verify disconnection + require.False(t, br.Connected()) + + // Reconnect with new reader + reader2 := newMockReader("second") + seqNum2 := make(chan uint64, 1) + newR2 := make(chan io.Reader, 1) + + go br.Reconnect(seqNum2, newR2) + + // Get sequence number and send new reader + seq := testutil.RequireReceive(ctx, t, seqNum2) + require.Equal(t, uint64(5), seq) // Should return current sequence number + testutil.RequireSend(ctx, t, newR2, io.Reader(reader2)) + + // Wait for read to unblock and succeed with new data + readErr := testutil.RequireReceive(ctx, t, readDone) + require.NoError(t, readErr) // Should succeed with new reader + require.True(t, br.Connected()) +} + +func TestBackedReader_Close(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + reader := newMockReader("test") + + // Connect + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(reader)) + + // First, read all available data + buf := make([]byte, 10) + n, err := br.Read(buf) + require.NoError(t, err) + require.Equal(t, 4, n) // "test" is 4 bytes + + // Close the reader before EOF triggers reconnection + err = br.Close() + require.NoError(t, err) + + // After close, reads should return EOF + n, err = br.Read(buf) + require.Equal(t, 0, n) + require.Equal(t, io.EOF, err) + + // Subsequent reads should return EOF + _, err = br.Read(buf) + require.Equal(t, io.EOF, err) +} + +func TestBackedReader_CloseIdempotent(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + + err := br.Close() + require.NoError(t, err) + + // Second close should be no-op + err = br.Close() + require.NoError(t, err) +} + +func TestBackedReader_ReconnectAfterClose(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + + err := br.Close() + require.NoError(t, err) + + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Should get 0 sequence number for closed reader + seq := testutil.TryReceive(ctx, t, seqNum) + require.Equal(t, uint64(0), seq) +} + +// Helper function to reconnect a reader using channels +func reconnectReader(ctx context.Context, t testing.TB, br *backedpipe.BackedReader, reader io.Reader) { + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, reader) +} + +func TestBackedReader_SequenceNumberTracking(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + reader := newMockReader("0123456789") + + reconnectReader(ctx, t, br, reader) + + // Read in chunks and verify sequence number + buf := make([]byte, 3) + + n, err := br.Read(buf) + require.NoError(t, err) + require.Equal(t, 3, n) + require.Equal(t, uint64(3), br.SequenceNum()) + + n, err = br.Read(buf) + require.NoError(t, err) + require.Equal(t, 3, n) + require.Equal(t, uint64(6), br.SequenceNum()) + + n, err = br.Read(buf) + require.NoError(t, err) + require.Equal(t, 3, n) + require.Equal(t, uint64(9), br.SequenceNum()) +} + +func TestBackedReader_EOFHandling(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + reader := newMockReader("test") + + reconnectReader(ctx, t, br, reader) + + // Read all data + buf := make([]byte, 10) + n, err := br.Read(buf) + require.NoError(t, err) + require.Equal(t, 4, n) + require.Equal(t, "test", string(buf[:n])) + + // Next read should encounter EOF, which triggers disconnection + // The read should block waiting for reconnection + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + defer close(readDone) + readN, readErr = br.Read(buf) + }() + + // Wait for EOF to be reported via error channel + receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan) + require.Equal(t, io.EOF, receivedErrorEvent.Err) + require.Equal(t, "reader", receivedErrorEvent.Component) + + // Reader should be disconnected after EOF + require.False(t, br.Connected()) + + // Read should still be blocked + select { + case <-readDone: + t.Fatal("Read should be blocked waiting for reconnection after EOF") + default: + // Good, still blocked + } + + // Reconnect with new data + reader2 := newMockReader("more") + reconnectReader(ctx, t, br, reader2) + + // Wait for the blocked read to complete with new data + testutil.TryReceive(ctx, t, readDone) + require.NoError(t, readErr) + require.Equal(t, 4, readN) + require.Equal(t, "more", string(buf[:readN])) +} + +func BenchmarkBackedReader_Read(b *testing.B) { + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + buf := make([]byte, 1024) + + // Create a reader that never returns EOF by cycling through data + reader := &mockReader{ + readFunc: func(p []byte) (int, error) { + // Fill buffer with 'x' characters - never EOF + for i := range p { + p[i] = 'x' + } + return len(p), nil + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + reconnectReader(ctx, b, br, reader) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + br.Read(buf) + } +} + +func TestBackedReader_PartialReads(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + + // Create a reader that returns partial reads + reader := &mockReader{ + readFunc: func(p []byte) (int, error) { + // Always return just 1 byte at a time + if len(p) == 0 { + return 0, nil + } + p[0] = 'A' + return 1, nil + }, + } + + reconnectReader(ctx, t, br, reader) + + // Read multiple times + buf := make([]byte, 10) + for i := 0; i < 5; i++ { + n, err := br.Read(buf) + require.NoError(t, err) + require.Equal(t, 1, n) + require.Equal(t, byte('A'), buf[0]) + } + + require.Equal(t, uint64(5), br.SequenceNum()) +} + +func TestBackedReader_CloseWhileBlockedOnUnderlyingReader(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + + // Create a reader that blocks on Read calls but can be unblocked + readStarted := make(chan struct{}, 1) + readUnblocked := make(chan struct{}) + blockingReader := &mockReader{ + readFunc: func(p []byte) (int, error) { + select { + case readStarted <- struct{}{}: + default: + } + <-readUnblocked // Block until signaled + // After unblocking, return an error to simulate connection failure + return 0, xerrors.New("connection interrupted") + }, + } + + // Connect the blocking reader + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send blocking reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(blockingReader)) + + // Start a read that will block on the underlying reader + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + defer close(readDone) + buf := make([]byte, 10) + readN, readErr = br.Read(buf) + }() + + // Wait for the read to start and block on the underlying reader + testutil.RequireReceive(ctx, t, readStarted) + + // Verify read is blocked by checking that it hasn't completed + // and ensuring we have adequate time for it to reach the blocking state + require.Eventually(t, func() bool { + select { + case <-readDone: + t.Fatal("Read should be blocked on underlying reader") + return false + default: + // Good, still blocked + return true + } + }, testutil.WaitShort, testutil.IntervalMedium) + + // Start Close() in a goroutine since it will block until the underlying read completes + closeDone := make(chan error, 1) + go func() { + closeDone <- br.Close() + }() + + // Verify Close() is also blocked waiting for the underlying read + select { + case <-closeDone: + t.Fatal("Close should be blocked until underlying read completes") + case <-time.After(10 * time.Millisecond): + // Good, Close is blocked + } + + // Unblock the underlying reader, which will cause both the read and close to complete + close(readUnblocked) + + // Wait for both the read and close to complete + testutil.TryReceive(ctx, t, readDone) + closeErr := testutil.RequireReceive(ctx, t, closeDone) + require.NoError(t, closeErr) + + // The read should return EOF because Close() was called while it was blocked, + // even though the underlying reader returned an error + require.Equal(t, 0, readN) + require.Equal(t, io.EOF, readErr) + + // Subsequent reads should return EOF since the reader is now closed + buf := make([]byte, 10) + n, err := br.Read(buf) + require.Equal(t, 0, n) + require.Equal(t, io.EOF, err) +} + +func TestBackedReader_CloseWhileBlockedWaitingForReconnect(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + br := backedpipe.NewBackedReader(errChan) + reader1 := newMockReader("initial") + + // Initial connection + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send initial reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(reader1)) + + // Read initial data + buf := make([]byte, 10) + n, err := br.Read(buf) + require.NoError(t, err) + require.Equal(t, "initial", string(buf[:n])) + + // Simulate connection failure + reader1.setError(xerrors.New("connection lost")) + + // Start a read that will block waiting for reconnection + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + defer close(readDone) + readN, readErr = br.Read(buf) + }() + + // Wait for the error to be reported (indicating disconnection) + receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan) + require.Error(t, receivedErrorEvent.Err) + require.Equal(t, "reader", receivedErrorEvent.Component) + require.Contains(t, receivedErrorEvent.Err.Error(), "connection lost") + + // Verify read is blocked waiting for reconnection + select { + case <-readDone: + t.Fatal("Read should be blocked waiting for reconnection") + default: + // Good, still blocked + } + + // Verify reader is disconnected + require.False(t, br.Connected()) + + // Close the BackedReader while read is blocked waiting for reconnection + err = br.Close() + require.NoError(t, err) + + // The read should unblock and return EOF + testutil.TryReceive(ctx, t, readDone) + require.Equal(t, 0, readN) + require.Equal(t, io.EOF, readErr) +} diff --git a/agent/immortalstreams/backedpipe/backed_writer.go b/agent/immortalstreams/backedpipe/backed_writer.go new file mode 100644 index 0000000000..e4093e48f2 --- /dev/null +++ b/agent/immortalstreams/backedpipe/backed_writer.go @@ -0,0 +1,243 @@ +package backedpipe + +import ( + "io" + "os" + "sync" + + "golang.org/x/xerrors" +) + +var ( + ErrWriterClosed = xerrors.New("cannot reconnect closed writer") + ErrNilWriter = xerrors.New("new writer cannot be nil") + ErrFutureSequence = xerrors.New("cannot replay from future sequence") + ErrReplayDataUnavailable = xerrors.New("failed to read replay data") + ErrReplayFailed = xerrors.New("replay failed") + ErrPartialReplay = xerrors.New("partial replay") +) + +// BackedWriter wraps an unreliable io.Writer and makes it resilient to disconnections. +// It maintains a ring buffer of recent writes for replay during reconnection. +type BackedWriter struct { + mu sync.Mutex + cond *sync.Cond + writer io.Writer + buffer *ringBuffer + sequenceNum uint64 // total bytes written + closed bool + + // Error channel for generation-aware error reporting + errorEventChan chan<- ErrorEvent + + // Current connection generation for error reporting + currentGen uint64 +} + +// NewBackedWriter creates a new BackedWriter with generation-aware error reporting. +// The writer is initially disconnected and will block writes until connected. +// The errorEventChan will receive ErrorEvent structs containing error details, +// component info, and connection generation. Capacity must be > 0. +func NewBackedWriter(capacity int, errorEventChan chan<- ErrorEvent) *BackedWriter { + if capacity <= 0 { + panic("backed writer capacity must be > 0") + } + if errorEventChan == nil { + panic("error event channel cannot be nil") + } + bw := &BackedWriter{ + buffer: newRingBuffer(capacity), + errorEventChan: errorEventChan, + } + bw.cond = sync.NewCond(&bw.mu) + return bw +} + +// blockUntilConnectedOrClosed blocks until either a writer is available or the BackedWriter is closed. +// Returns os.ErrClosed if closed while waiting, nil if connected. You must hold the mutex to call this. +func (bw *BackedWriter) blockUntilConnectedOrClosed() error { + for bw.writer == nil && !bw.closed { + bw.cond.Wait() + } + if bw.closed { + return os.ErrClosed + } + return nil +} + +// Write implements io.Writer. +// When connected, it writes to both the ring buffer (to preserve data in case we need to replay it) +// and the underlying writer. +// If the underlying write fails, the writer is marked as disconnected and the write blocks +// until reconnection occurs. +func (bw *BackedWriter) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + bw.mu.Lock() + defer bw.mu.Unlock() + + // Block until connected + if err := bw.blockUntilConnectedOrClosed(); err != nil { + return 0, err + } + + // Write to buffer + bw.buffer.Write(p) + bw.sequenceNum += uint64(len(p)) + + // Try to write to underlying writer + n, err := bw.writer.Write(p) + if err == nil && n != len(p) { + err = io.ErrShortWrite + } + + if err != nil { + // Connection failed or partial write, mark as disconnected + bw.writer = nil + + // Notify parent of error with generation information + select { + case bw.errorEventChan <- ErrorEvent{ + Err: err, + Component: "writer", + Generation: bw.currentGen, + }: + default: + // Channel is full, drop the error. + // This is not a problem, because we set the writer to nil + // and block until reconnected so no new errors will be sent + // until pipe processes the error and reconnects. + } + + // Block until reconnected - reconnection will replay this data + if err := bw.blockUntilConnectedOrClosed(); err != nil { + return 0, err + } + + // Don't retry - reconnection replay handled it + return len(p), nil + } + + // Write succeeded + return len(p), nil +} + +// Reconnect replaces the current writer with a new one and replays data from the specified +// sequence number. If the requested sequence number is no longer in the buffer, +// returns an error indicating data loss. +// +// IMPORTANT: You must close the current writer, if any, before calling this method. +// Otherwise, if a Write operation is currently blocked in the underlying writer's +// Write method, this method will deadlock waiting for the mutex that Write holds. +func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) error { + bw.mu.Lock() + defer bw.mu.Unlock() + + if bw.closed { + return ErrWriterClosed + } + + if newWriter == nil { + return ErrNilWriter + } + + // Check if we can replay from the requested sequence number + if replayFromSeq > bw.sequenceNum { + return ErrFutureSequence + } + + // Calculate how many bytes we need to replay + replayBytes := bw.sequenceNum - replayFromSeq + + var replayData []byte + if replayBytes > 0 { + // Get the last replayBytes from buffer + // If the buffer doesn't have enough data (some was evicted), + // ReadLast will return an error + var err error + // Safe conversion: The check above (replayFromSeq > bw.sequenceNum) ensures + // replayBytes = bw.sequenceNum - replayFromSeq is always <= bw.sequenceNum. + // Since sequence numbers are much smaller than maxInt, the uint64->int conversion is safe. + //nolint:gosec // Safe conversion: replayBytes <= sequenceNum, which is much less than maxInt + replayData, err = bw.buffer.ReadLast(int(replayBytes)) + if err != nil { + return ErrReplayDataUnavailable + } + } + + // Clear the current writer first in case replay fails + bw.writer = nil + + // Replay data if needed. We keep the mutex held during replay to ensure + // no concurrent operations can interfere with the reconnection process. + if len(replayData) > 0 { + n, err := newWriter.Write(replayData) + if err != nil { + // Reconnect failed, writer remains nil + return ErrReplayFailed + } + + if n != len(replayData) { + // Reconnect failed, writer remains nil + return ErrPartialReplay + } + } + + // Set new writer only after successful replay. This ensures no concurrent + // writes can interfere with the replay operation. + bw.writer = newWriter + + // Wake up any operations waiting for connection + bw.cond.Broadcast() + + return nil +} + +// Close closes the writer and prevents further writes. +// After closing, all Write calls will return os.ErrClosed. +// This code keeps the Close() signature consistent with io.Closer, +// but it never actually returns an error. +// +// IMPORTANT: You must close the current underlying writer, if any, before calling +// this method. Otherwise, if a Write operation is currently blocked in the +// underlying writer's Write method, this method will deadlock waiting for the +// mutex that Write holds. +func (bw *BackedWriter) Close() error { + bw.mu.Lock() + defer bw.mu.Unlock() + + if bw.closed { + return nil + } + + bw.closed = true + bw.writer = nil + + // Wake up any blocked operations + bw.cond.Broadcast() + + return nil +} + +// SequenceNum returns the current sequence number (total bytes written). +func (bw *BackedWriter) SequenceNum() uint64 { + bw.mu.Lock() + defer bw.mu.Unlock() + return bw.sequenceNum +} + +// Connected returns whether the writer is currently connected. +func (bw *BackedWriter) Connected() bool { + bw.mu.Lock() + defer bw.mu.Unlock() + return bw.writer != nil +} + +// SetGeneration sets the current connection generation for error reporting. +func (bw *BackedWriter) SetGeneration(generation uint64) { + bw.mu.Lock() + defer bw.mu.Unlock() + bw.currentGen = generation +} diff --git a/agent/immortalstreams/backedpipe/backed_writer_test.go b/agent/immortalstreams/backedpipe/backed_writer_test.go new file mode 100644 index 0000000000..a1a77b36bc --- /dev/null +++ b/agent/immortalstreams/backedpipe/backed_writer_test.go @@ -0,0 +1,996 @@ +package backedpipe_test + +import ( + "bytes" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/agent/immortalstreams/backedpipe" + "github.com/coder/coder/v2/testutil" +) + +// mockWriter implements io.Writer with controllable behavior for testing +type mockWriter struct { + mu sync.Mutex + buffer bytes.Buffer + err error + writeFunc func([]byte) (int, error) + writeCalls int +} + +func newMockWriter() *mockWriter { + return &mockWriter{} +} + +// newBackedWriterForTest creates a BackedWriter with a small buffer for testing eviction behavior +func newBackedWriterForTest(bufferSize int) *backedpipe.BackedWriter { + errChan := make(chan backedpipe.ErrorEvent, 1) + return backedpipe.NewBackedWriter(bufferSize, errChan) +} + +func (mw *mockWriter) Write(p []byte) (int, error) { + mw.mu.Lock() + defer mw.mu.Unlock() + + mw.writeCalls++ + + if mw.writeFunc != nil { + return mw.writeFunc(p) + } + + if mw.err != nil { + return 0, mw.err + } + + return mw.buffer.Write(p) +} + +func (mw *mockWriter) Len() int { + mw.mu.Lock() + defer mw.mu.Unlock() + return mw.buffer.Len() +} + +func (mw *mockWriter) Reset() { + mw.mu.Lock() + defer mw.mu.Unlock() + mw.buffer.Reset() + mw.writeCalls = 0 + mw.err = nil + mw.writeFunc = nil +} + +func (mw *mockWriter) setError(err error) { + mw.mu.Lock() + defer mw.mu.Unlock() + mw.err = err +} + +func TestBackedWriter_NewBackedWriter(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + require.NotNil(t, bw) + require.Equal(t, uint64(0), bw.SequenceNum()) + require.False(t, bw.Connected()) +} + +func TestBackedWriter_WriteBlocksWhenDisconnected(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Write should block when disconnected + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + n, writeErr = bw.Write([]byte("hello")) + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when disconnected") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + // Connect and verify write completes + writer := newMockWriter() + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Write should now complete + testutil.TryReceive(ctx, t, writeComplete) + + require.NoError(t, writeErr) + require.Equal(t, 5, n) + require.Equal(t, uint64(5), bw.SequenceNum()) + require.Equal(t, []byte("hello"), writer.buffer.Bytes()) +} + +func TestBackedWriter_WriteToUnderlyingWhenConnected(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + writer := newMockWriter() + + // Connect + err := bw.Reconnect(0, writer) + require.NoError(t, err) + require.True(t, bw.Connected()) + + // Write should go to both buffer and underlying writer + n, err := bw.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Data should be buffered + require.Equal(t, uint64(5), bw.SequenceNum()) + + // Check underlying writer + require.Equal(t, []byte("hello"), writer.buffer.Bytes()) +} + +func TestBackedWriter_BlockOnWriteFailure(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + writer := newMockWriter() + + // Connect + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Cause write to fail + writer.setError(xerrors.New("write failed")) + + // Write should block when underlying writer fails, not succeed immediately + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + n, writeErr = bw.Write([]byte("hello")) + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when underlying writer fails") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + // Should be disconnected + require.False(t, bw.Connected()) + + // Error should be sent to error channel + select { + case receivedErrorEvent := <-errChan: + require.Contains(t, receivedErrorEvent.Err.Error(), "write failed") + require.Equal(t, "writer", receivedErrorEvent.Component) + default: + t.Fatal("Expected error to be sent to error channel") + } + + // Reconnect with working writer and verify write completes + writer2 := newMockWriter() + err = bw.Reconnect(0, writer2) // Replay from beginning + require.NoError(t, err) + + // Write should now complete + testutil.TryReceive(ctx, t, writeComplete) + + require.NoError(t, writeErr) + require.Equal(t, 5, n) + require.Equal(t, uint64(5), bw.SequenceNum()) + require.Equal(t, []byte("hello"), writer2.buffer.Bytes()) +} + +func TestBackedWriter_ReplayOnReconnect(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Write some data while connected + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + _, err = bw.Write([]byte(" world")) + require.NoError(t, err) + + require.Equal(t, uint64(11), bw.SequenceNum()) + + // Disconnect by causing a write failure + writer1.setError(xerrors.New("connection lost")) + + // Write should block when underlying writer fails + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + n, writeErr = bw.Write([]byte("test")) + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when underlying writer fails") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + require.False(t, bw.Connected()) + + // Reconnect with new writer and request replay from beginning + writer2 := newMockWriter() + err = bw.Reconnect(0, writer2) + require.NoError(t, err) + + // Write should now complete + select { + case <-writeComplete: + // Expected - write completed + case <-time.After(100 * time.Millisecond): + t.Fatal("Write should have completed after reconnection") + } + + require.NoError(t, writeErr) + require.Equal(t, 4, n) + + // Should have replayed all data including the failed write that was buffered + require.Equal(t, []byte("hello worldtest"), writer2.buffer.Bytes()) + + // Write new data should go to both + _, err = bw.Write([]byte("!")) + require.NoError(t, err) + require.Equal(t, []byte("hello worldtest!"), writer2.buffer.Bytes()) +} + +func TestBackedWriter_PartialReplay(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Write some data + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + _, err = bw.Write([]byte(" world")) + require.NoError(t, err) + _, err = bw.Write([]byte("!")) + require.NoError(t, err) + + // Reconnect with new writer and request replay from middle + writer2 := newMockWriter() + err = bw.Reconnect(5, writer2) // From " world!" + require.NoError(t, err) + + // Should have replayed only the requested portion + require.Equal(t, []byte(" world!"), writer2.buffer.Bytes()) +} + +func TestBackedWriter_ReplayFromFutureSequence(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + + writer2 := newMockWriter() + err = bw.Reconnect(10, writer2) // Future sequence + require.Error(t, err) + require.ErrorIs(t, err, backedpipe.ErrFutureSequence) +} + +func TestBackedWriter_ReplayDataLoss(t *testing.T) { + t.Parallel() + + bw := newBackedWriterForTest(10) // Small buffer for testing + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Fill buffer beyond capacity to cause eviction + _, err = bw.Write([]byte("0123456789")) // Fills buffer exactly + require.NoError(t, err) + _, err = bw.Write([]byte("abcdef")) // Should evict "012345" + require.NoError(t, err) + + writer2 := newMockWriter() + err = bw.Reconnect(0, writer2) // Try to replay from evicted data + // With the new error handling, this should fail because we can't read all the data + require.Error(t, err) + require.ErrorIs(t, err, backedpipe.ErrReplayDataUnavailable) +} + +func TestBackedWriter_BufferEviction(t *testing.T) { + t.Parallel() + + bw := newBackedWriterForTest(5) // Very small buffer for testing + + // Connect initially + writer := newMockWriter() + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Write data that will cause eviction + n, err := bw.Write([]byte("abcde")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Write more to cause eviction + n, err = bw.Write([]byte("fg")) + require.NoError(t, err) + require.Equal(t, 2, n) + + // Verify that the buffer contains only the latest data after eviction + // Total sequence number should be 7 (5 + 2) + require.Equal(t, uint64(7), bw.SequenceNum()) + + // Try to reconnect from the beginning - this should fail because + // the early data was evicted from the buffer + writer2 := newMockWriter() + err = bw.Reconnect(0, writer2) + require.Error(t, err) + require.ErrorIs(t, err, backedpipe.ErrReplayDataUnavailable) + + // However, reconnecting from a sequence that's still in the buffer should work + // The buffer should contain the last 5 bytes: "cdefg" + writer3 := newMockWriter() + err = bw.Reconnect(2, writer3) // From sequence 2, should replay "cdefg" + require.NoError(t, err) + require.Equal(t, []byte("cdefg"), writer3.buffer.Bytes()) + require.True(t, bw.Connected()) +} + +func TestBackedWriter_Close(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + writer := newMockWriter() + + bw.Reconnect(0, writer) + + err := bw.Close() + require.NoError(t, err) + + // Writes after close should fail + _, err = bw.Write([]byte("test")) + require.Equal(t, os.ErrClosed, err) + + // Reconnect after close should fail + err = bw.Reconnect(0, newMockWriter()) + require.Error(t, err) + require.ErrorIs(t, err, backedpipe.ErrWriterClosed) +} + +func TestBackedWriter_CloseIdempotent(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + err := bw.Close() + require.NoError(t, err) + + // Second close should be no-op + err = bw.Close() + require.NoError(t, err) +} + +func TestBackedWriter_ReconnectDuringReplay(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + _, err = bw.Write([]byte("hello world")) + require.NoError(t, err) + + // Create a writer that fails during replay + writer2 := &mockWriter{ + err: backedpipe.ErrReplayFailed, + } + + err = bw.Reconnect(0, writer2) + require.Error(t, err) + require.ErrorIs(t, err, backedpipe.ErrReplayFailed) + require.False(t, bw.Connected()) +} + +func TestBackedWriter_BlockOnPartialWrite(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Create writer that does partial writes + writer := &mockWriter{ + writeFunc: func(p []byte) (int, error) { + if len(p) > 3 { + return 3, nil // Only write first 3 bytes + } + return len(p), nil + }, + } + + bw.Reconnect(0, writer) + + // Write should block due to partial write + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + n, writeErr = bw.Write([]byte("hello")) + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when underlying writer does partial write") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + // Should be disconnected + require.False(t, bw.Connected()) + + // Error should be sent to error channel + select { + case receivedErrorEvent := <-errChan: + require.Contains(t, receivedErrorEvent.Err.Error(), "short write") + require.Equal(t, "writer", receivedErrorEvent.Component) + default: + t.Fatal("Expected error to be sent to error channel") + } + + // Reconnect with working writer and verify write completes + writer2 := newMockWriter() + err := bw.Reconnect(0, writer2) // Replay from beginning + require.NoError(t, err) + + // Write should now complete + testutil.TryReceive(ctx, t, writeComplete) + + require.NoError(t, writeErr) + require.Equal(t, 5, n) + require.Equal(t, uint64(5), bw.SequenceNum()) + require.Equal(t, []byte("hello"), writer2.buffer.Bytes()) +} + +func TestBackedWriter_WriteUnblocksOnReconnect(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Start a single write that should block + writeResult := make(chan error, 1) + go func() { + _, err := bw.Write([]byte("test")) + writeResult <- err + }() + + // Verify write is blocked + select { + case <-writeResult: + t.Fatal("Write should have blocked when disconnected") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + // Connect and verify write completes + writer := newMockWriter() + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Write should now complete + err = testutil.RequireReceive(ctx, t, writeResult) + require.NoError(t, err) + + // Write should have been written to the underlying writer + require.Equal(t, "test", writer.buffer.String()) +} + +func TestBackedWriter_CloseUnblocksWaitingWrites(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Start a write that should block + writeComplete := make(chan error, 1) + go func() { + _, err := bw.Write([]byte("test")) + writeComplete <- err + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when disconnected") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + // Close the writer + err := bw.Close() + require.NoError(t, err) + + // Write should now complete with error + err = testutil.RequireReceive(ctx, t, writeComplete) + require.Equal(t, os.ErrClosed, err) +} + +func TestBackedWriter_WriteBlocksAfterDisconnection(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + writer := newMockWriter() + + // Connect initially + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Write should succeed when connected + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + + // Cause disconnection - the write should now block instead of returning an error + writer.setError(xerrors.New("connection lost")) + + // This write should block + writeComplete := make(chan error, 1) + go func() { + _, err := bw.Write([]byte("world")) + writeComplete <- err + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked after disconnection") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + // Should be disconnected + require.False(t, bw.Connected()) + + // Reconnect and verify write completes + writer2 := newMockWriter() + err = bw.Reconnect(5, writer2) // Replay from after "hello" + require.NoError(t, err) + + err = testutil.RequireReceive(ctx, t, writeComplete) + require.NoError(t, err) + + // Check that only "world" was written during replay (not duplicated) + require.Equal(t, []byte("world"), writer2.buffer.Bytes()) // Only "world" since we replayed from sequence 5 +} + +func TestBackedWriter_ConcurrentWriteAndClose(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Don't connect initially - this will cause writes to block in blockUntilConnectedOrClosed() + + writeStarted := make(chan struct{}, 1) + + // Start a write operation that will block waiting for connection + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + // Signal that we're about to start the write + writeStarted <- struct{}{} + // This write will block in blockUntilConnectedOrClosed() since no writer is connected + n, writeErr = bw.Write([]byte("hello")) + }() + + // Wait for write goroutine to start + ctx := testutil.Context(t, testutil.WaitShort) + testutil.RequireReceive(ctx, t, writeStarted) + + // Ensure the write is actually blocked by repeatedly checking that: + // 1. The write hasn't completed yet + // 2. The writer is still not connected + // We use require.Eventually to give it a fair chance to reach the blocking state + require.Eventually(t, func() bool { + select { + case <-writeComplete: + t.Fatal("Write should be blocked when no writer is connected") + return false + default: + // Write is still blocked, which is what we want + return !bw.Connected() + } + }, testutil.WaitShort, testutil.IntervalMedium) + + // Close the writer while the write is blocked waiting for connection + closeErr := bw.Close() + require.NoError(t, closeErr) + + // Wait for write to complete + select { + case <-writeComplete: + // Good, write completed + case <-ctx.Done(): + t.Fatal("Write did not complete in time") + } + + // The write should have failed with os.ErrClosed because Close() was called + // while it was waiting for connection + require.ErrorIs(t, writeErr, os.ErrClosed) + require.Equal(t, 0, n) + + // Subsequent writes should also fail + n, err := bw.Write([]byte("world")) + require.Equal(t, 0, n) + require.ErrorIs(t, err, os.ErrClosed) +} + +func TestBackedWriter_ConcurrentWriteAndReconnect(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Initial connection + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Write some initial data + _, err = bw.Write([]byte("initial")) + require.NoError(t, err) + + // Start reconnection which will block new writes + replayStarted := make(chan struct{}, 1) // Buffered to prevent race condition + replayCanComplete := make(chan struct{}) + writer2 := &mockWriter{ + writeFunc: func(p []byte) (int, error) { + // Signal that replay has started + select { + case replayStarted <- struct{}{}: + default: + // Signal already sent, which is fine + } + // Wait for test to allow replay to complete + <-replayCanComplete + return len(p), nil + }, + } + + // Start the reconnection in a goroutine so we can control timing + reconnectComplete := make(chan error, 1) + go func() { + reconnectComplete <- bw.Reconnect(0, writer2) + }() + + ctx := testutil.Context(t, testutil.WaitShort) + // Wait for replay to start + testutil.RequireReceive(ctx, t, replayStarted) + + // Now start a write operation that will be blocked by the ongoing reconnect + writeStarted := make(chan struct{}, 1) + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + // Signal that we're about to start the write + writeStarted <- struct{}{} + // This write should be blocked during reconnect + n, writeErr = bw.Write([]byte("blocked")) + }() + + // Wait for write to start + testutil.RequireReceive(ctx, t, writeStarted) + + // Use a small timeout to ensure the write goroutine has a chance to get blocked + // on the mutex before we check if it's still blocked + writeCheckTimer := time.NewTimer(testutil.IntervalFast) + defer writeCheckTimer.Stop() + + select { + case <-writeComplete: + t.Fatal("Write should be blocked during reconnect") + case <-writeCheckTimer.C: + // Write is still blocked after a reasonable wait + } + + // Allow replay to complete, which will allow reconnect to finish + close(replayCanComplete) + + // Wait for reconnection to complete + select { + case reconnectErr := <-reconnectComplete: + require.NoError(t, reconnectErr) + case <-ctx.Done(): + t.Fatal("Reconnect did not complete in time") + } + + // Wait for write to complete + <-writeComplete + + // Write should succeed after reconnection completes + require.NoError(t, writeErr) + require.Equal(t, 7, n) // "blocked" is 7 bytes + + // Verify the writer is connected + require.True(t, bw.Connected()) +} + +func TestBackedWriter_ConcurrentReconnectAndClose(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Initial connection and write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + _, err = bw.Write([]byte("test data")) + require.NoError(t, err) + + // Start reconnection with slow replay + reconnectStarted := make(chan struct{}, 1) + replayCanComplete := make(chan struct{}) + reconnectComplete := make(chan struct{}) + var reconnectErr error + + go func() { + defer close(reconnectComplete) + writer2 := &mockWriter{ + writeFunc: func(p []byte) (int, error) { + // Signal that replay has started + select { + case reconnectStarted <- struct{}{}: + default: + } + // Wait for test to allow replay to complete + <-replayCanComplete + return len(p), nil + }, + } + reconnectErr = bw.Reconnect(0, writer2) + }() + + // Wait for reconnection to start + ctx := testutil.Context(t, testutil.WaitShort) + testutil.RequireReceive(ctx, t, reconnectStarted) + + // Start Close() in a separate goroutine since it will block until Reconnect() completes + closeStarted := make(chan struct{}, 1) + closeComplete := make(chan error, 1) + go func() { + closeStarted <- struct{}{} // Signal that Close() is starting + closeComplete <- bw.Close() + }() + + // Wait for Close() to start, then give it a moment to attempt to acquire the mutex + testutil.RequireReceive(ctx, t, closeStarted) + closeCheckTimer := time.NewTimer(testutil.IntervalFast) + defer closeCheckTimer.Stop() + + select { + case <-closeComplete: + t.Fatal("Close should be blocked during reconnect") + case <-closeCheckTimer.C: + // Good, Close is still blocked after a reasonable wait + } + + // Allow replay to complete so reconnection can finish + close(replayCanComplete) + + // Wait for reconnect to complete + select { + case <-reconnectComplete: + // Good, reconnect completed + case <-ctx.Done(): + t.Fatal("Reconnect did not complete in time") + } + + // Wait for close to complete + select { + case closeErr := <-closeComplete: + require.NoError(t, closeErr) + case <-ctx.Done(): + t.Fatal("Close did not complete in time") + } + + // With mutex held during replay, Close() waits for Reconnect() to finish. + // So Reconnect() should succeed, then Close() runs and closes the writer. + require.NoError(t, reconnectErr) + + // Verify writer is closed (Close() ran after Reconnect() completed) + require.False(t, bw.Connected()) +} + +func TestBackedWriter_MultipleWritesDuringReconnect(t *testing.T) { + t.Parallel() + + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Initial connection + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Write some initial data + _, err = bw.Write([]byte("initial")) + require.NoError(t, err) + + // Start multiple write operations + numWriters := 5 + var wg sync.WaitGroup + writeResults := make([]error, numWriters) + writesStarted := make(chan struct{}, numWriters) + + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + // Signal that this write is starting + writesStarted <- struct{}{} + data := []byte{byte('A' + id)} + _, writeResults[id] = bw.Write(data) + }(i) + } + + // Wait for all writes to start + ctx := testutil.Context(t, testutil.WaitLong) + for i := 0; i < numWriters; i++ { + testutil.RequireReceive(ctx, t, writesStarted) + } + + // Use a timer to ensure all write goroutines have had a chance to start executing + // and potentially get blocked on the mutex before we start the reconnection + writesReadyTimer := time.NewTimer(testutil.IntervalFast) + defer writesReadyTimer.Stop() + <-writesReadyTimer.C + + // Start reconnection with controlled replay + replayStarted := make(chan struct{}) + replayCanComplete := make(chan struct{}) + writer2 := &mockWriter{ + writeFunc: func(p []byte) (int, error) { + // Signal that replay has started + select { + case replayStarted <- struct{}{}: + default: + } + // Wait for test to allow replay to complete + <-replayCanComplete + return len(p), nil + }, + } + + // Start reconnection in a goroutine so we can control timing + reconnectComplete := make(chan error, 1) + go func() { + reconnectComplete <- bw.Reconnect(0, writer2) + }() + + // Wait for replay to start + testutil.RequireReceive(ctx, t, replayStarted) + + // Allow replay to complete + close(replayCanComplete) + + // Wait for reconnection to complete + select { + case reconnectErr := <-reconnectComplete: + require.NoError(t, reconnectErr) + case <-ctx.Done(): + t.Fatal("Reconnect did not complete in time") + } + + // Wait for all writes to complete + wg.Wait() + + // All writes should succeed + for i, err := range writeResults { + require.NoError(t, err, "Write %d should succeed", i) + } + + // Verify the writer is connected + require.True(t, bw.Connected()) +} + +func BenchmarkBackedWriter_Write(b *testing.B) { + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) // 64KB buffer + writer := newMockWriter() + bw.Reconnect(0, writer) + + data := bytes.Repeat([]byte("x"), 1024) // 1KB writes + + b.ResetTimer() + for i := 0; i < b.N; i++ { + bw.Write(data) + } +} + +func BenchmarkBackedWriter_Reconnect(b *testing.B) { + errChan := make(chan backedpipe.ErrorEvent, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) + + // Connect initially to fill buffer with data + initialWriter := newMockWriter() + err := bw.Reconnect(0, initialWriter) + if err != nil { + b.Fatal(err) + } + + // Fill buffer with data + data := bytes.Repeat([]byte("x"), 1024) + for i := 0; i < 32; i++ { + bw.Write(data) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + writer := newMockWriter() + bw.Reconnect(0, writer) + } +} diff --git a/agent/immortalstreams/backedpipe/ring_buffer.go b/agent/immortalstreams/backedpipe/ring_buffer.go new file mode 100644 index 0000000000..91fde569af --- /dev/null +++ b/agent/immortalstreams/backedpipe/ring_buffer.go @@ -0,0 +1,129 @@ +package backedpipe + +import "golang.org/x/xerrors" + +// ringBuffer implements an efficient circular buffer with a fixed-size allocation. +// This implementation is not thread-safe and relies on external synchronization. +type ringBuffer struct { + buffer []byte + start int // index of first valid byte + end int // index of last valid byte (-1 when empty) +} + +// newRingBuffer creates a new ring buffer with the specified capacity. +// Capacity must be > 0. +func newRingBuffer(capacity int) *ringBuffer { + if capacity <= 0 { + panic("ring buffer capacity must be > 0") + } + return &ringBuffer{ + buffer: make([]byte, capacity), + end: -1, // -1 indicates empty buffer + } +} + +// Size returns the current number of bytes in the buffer. +func (rb *ringBuffer) Size() int { + if rb.end == -1 { + return 0 // Buffer is empty + } + if rb.start <= rb.end { + return rb.end - rb.start + 1 + } + // Buffer wraps around + return len(rb.buffer) - rb.start + rb.end + 1 +} + +// Write writes data to the ring buffer. If the buffer would overflow, +// it evicts the oldest data to make room for new data. +func (rb *ringBuffer) Write(data []byte) { + if len(data) == 0 { + return + } + + capacity := len(rb.buffer) + + // If data is larger than capacity, only keep the last capacity bytes + if len(data) > capacity { + data = data[len(data)-capacity:] + // Clear buffer and write new data + rb.start = 0 + rb.end = -1 // Will be set properly below + } + + // Calculate how much we need to evict to fit new data + spaceNeeded := len(data) + availableSpace := capacity - rb.Size() + + if spaceNeeded > availableSpace { + bytesToEvict := spaceNeeded - availableSpace + rb.evict(bytesToEvict) + } + + // Buffer has data, write after current end + writePos := (rb.end + 1) % capacity + if writePos+len(data) <= capacity { + // No wrap needed - single copy + copy(rb.buffer[writePos:], data) + rb.end = (rb.end + len(data)) % capacity + } else { + // Need to wrap around - two copies + firstChunk := capacity - writePos + copy(rb.buffer[writePos:], data[:firstChunk]) + copy(rb.buffer[0:], data[firstChunk:]) + rb.end = len(data) - firstChunk - 1 + } +} + +// evict removes the specified number of bytes from the beginning of the buffer. +func (rb *ringBuffer) evict(count int) { + if count >= rb.Size() { + // Evict everything + rb.start = 0 + rb.end = -1 + return + } + + rb.start = (rb.start + count) % len(rb.buffer) + // Buffer remains non-empty after partial eviction +} + +// ReadLast returns the last n bytes from the buffer. +// If n is greater than the available data, returns an error. +// If n is negative, returns an error. +func (rb *ringBuffer) ReadLast(n int) ([]byte, error) { + if n < 0 { + return nil, xerrors.New("cannot read negative number of bytes") + } + + if n == 0 { + return nil, nil + } + + size := rb.Size() + + // If requested more than available, return error + if n > size { + return nil, xerrors.Errorf("requested %d bytes but only %d available", n, size) + } + + result := make([]byte, n) + capacity := len(rb.buffer) + + // Calculate where to start reading from (n bytes before the end) + startOffset := size - n + actualStart := (rb.start + startOffset) % capacity + + // Copy the last n bytes + if actualStart+n <= capacity { + // No wrap needed + copy(result, rb.buffer[actualStart:actualStart+n]) + } else { + // Need to wrap around + firstChunk := capacity - actualStart + copy(result[0:firstChunk], rb.buffer[actualStart:capacity]) + copy(result[firstChunk:], rb.buffer[0:n-firstChunk]) + } + + return result, nil +} diff --git a/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go b/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go new file mode 100644 index 0000000000..fee2b00328 --- /dev/null +++ b/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go @@ -0,0 +1,261 @@ +package backedpipe + +import ( + "bytes" + "os" + "runtime" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/coder/coder/v2/testutil" +) + +func TestMain(m *testing.M) { + if runtime.GOOS == "windows" { + // Don't run goleak on windows tests, they're super flaky right now. + // See: https://github.com/coder/coder/issues/8954 + os.Exit(m.Run()) + } + goleak.VerifyTestMain(m, testutil.GoleakOptions...) +} + +func TestRingBuffer_NewRingBuffer(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(100) + // Test that we can write and read from the buffer + rb.Write([]byte("test")) + + data, err := rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, []byte("test"), data) +} + +func TestRingBuffer_WriteAndRead(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(10) + + // Write some data + rb.Write([]byte("hello")) + + // Read last 4 bytes + data, err := rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, "ello", string(data)) + + // Write more data + rb.Write([]byte("world")) + + // Read last 5 bytes + data, err = rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, "world", string(data)) + + // Read last 3 bytes + data, err = rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, "rld", string(data)) + + // Read more than available (should be 10 bytes total) + _, err = rb.ReadLast(15) + require.Error(t, err) + require.Contains(t, err.Error(), "requested 15 bytes but only") +} + +func TestRingBuffer_OverflowEviction(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(5) + + // Fill buffer + rb.Write([]byte("abcde")) + + // Overflow should evict oldest data + rb.Write([]byte("fg")) + + // Should now contain "cdefg" + data, err := rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, []byte("cdefg"), data) +} + +func TestRingBuffer_LargeWrite(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(5) + + // Write data larger than capacity + rb.Write([]byte("abcdefghij")) + + // Should contain last 5 bytes + data, err := rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, []byte("fghij"), data) +} + +func TestRingBuffer_WrapAround(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(5) + + // Fill buffer + rb.Write([]byte("abcde")) + + // Write more to cause wrap-around + rb.Write([]byte("fgh")) + + // Should contain "defgh" + data, err := rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, []byte("defgh"), data) + + // Test reading last 3 bytes after wrap + data, err = rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("fgh"), data) +} + +func TestRingBuffer_ReadLastEdgeCases(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(3) + + // Write some data (5 bytes to a 3-byte buffer, so only last 3 bytes remain) + rb.Write([]byte("hello")) + + // Test reading negative count + data, err := rb.ReadLast(-1) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot read negative number of bytes") + require.Nil(t, data) + + // Test reading zero bytes + data, err = rb.ReadLast(0) + require.NoError(t, err) + require.Nil(t, data) + + // Test reading more than available (buffer has 3 bytes, try to read 10) + _, err = rb.ReadLast(10) + require.Error(t, err) + require.Contains(t, err.Error(), "requested 10 bytes but only 3 available") + + // Test reading exact amount available + data, err = rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("llo"), data) +} + +func TestRingBuffer_EmptyWrite(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(10) + + // Write empty data + rb.Write([]byte{}) + + // Buffer should still be empty + _, err := rb.ReadLast(5) + require.Error(t, err) + require.Contains(t, err.Error(), "requested 5 bytes but only 0 available") +} + +func TestRingBuffer_MultipleWrites(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(10) + + // Write data in chunks + rb.Write([]byte("ab")) + rb.Write([]byte("cd")) + rb.Write([]byte("ef")) + + data, err := rb.ReadLast(6) + require.NoError(t, err) + require.Equal(t, []byte("abcdef"), data) + + // Test partial reads + data, err = rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, []byte("cdef"), data) + + data, err = rb.ReadLast(2) + require.NoError(t, err) + require.Equal(t, []byte("ef"), data) +} + +func TestRingBuffer_EdgeCaseEviction(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(3) + + // Write data that will cause eviction + rb.Write([]byte("abc")) + + // Write more to cause eviction + rb.Write([]byte("d")) + + // Should now contain "bcd" + data, err := rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("bcd"), data) +} + +func TestRingBuffer_ComplexWrapAroundScenario(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(8) + + // Fill buffer + rb.Write([]byte("12345678")) + + // Evict some and add more to create complex wrap scenario + rb.Write([]byte("abcd")) + data, err := rb.ReadLast(8) + require.NoError(t, err) + require.Equal(t, []byte("5678abcd"), data) + + // Add more + rb.Write([]byte("xyz")) + data, err = rb.ReadLast(8) + require.NoError(t, err) + require.Equal(t, []byte("8abcdxyz"), data) + + // Test reading various amounts from the end + data, err = rb.ReadLast(7) + require.NoError(t, err) + require.Equal(t, []byte("abcdxyz"), data) + + data, err = rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, []byte("dxyz"), data) +} + +// Benchmark tests for performance validation +func BenchmarkRingBuffer_Write(b *testing.B) { + rb := newRingBuffer(64 * 1024 * 1024) // 64MB for benchmarks + data := bytes.Repeat([]byte("x"), 1024) // 1KB writes + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rb.Write(data) + } +} + +func BenchmarkRingBuffer_ReadLast(b *testing.B) { + rb := newRingBuffer(64 * 1024 * 1024) // 64MB for benchmarks + // Fill buffer with test data + for i := 0; i < 64; i++ { + rb.Write(bytes.Repeat([]byte("x"), 1024)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := rb.ReadLast((i % 100) + 1) + if err != nil { + b.Fatal(err) + } + } +}