diff --git a/coderd/httpapi/websocket.go b/coderd/httpapi/websocket.go index b0c7e456b5..c483cf1834 100644 --- a/coderd/httpapi/websocket.go +++ b/coderd/httpapi/websocket.go @@ -3,20 +3,26 @@ package httpapi import ( "context" "errors" + "net" "time" "golang.org/x/xerrors" "cdr.dev/slog/v3" + "github.com/coder/quartz" "github.com/coder/websocket" ) const HeartbeatInterval time.Duration = 15 * time.Second -// HeartbeatClose loops to ping a WebSocket to keep it alive. It calls `exit` on ping -// failure. +// HeartbeatClose loops to ping a WebSocket to keep it alive. +// It calls `exit` on ping failure. func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn *websocket.Conn) { - ticker := time.NewTicker(HeartbeatInterval) + heartbeatCloseWith(ctx, logger, exit, conn, quartz.NewReal(), HeartbeatInterval) +} + +func heartbeatCloseWith(ctx context.Context, logger slog.Logger, exit func(), conn *websocket.Conn, clk quartz.Clock, interval time.Duration) { + ticker := clk.NewTicker(interval, "HeartbeatClose") defer ticker.Stop() for { @@ -25,11 +31,21 @@ func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn * return case <-ticker.C: } - err := pingWithTimeout(ctx, conn, HeartbeatInterval) + err := pingWithTimeout(ctx, conn, interval) if err != nil { - // context.DeadlineExceeded is expected when the client disconnects without sending a close frame. - // context.Canceled is expected when the request context is canceled. - if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + // These errors are all expected during normal connection + // teardown and should not be logged at error level: + // - context.DeadlineExceeded: client disconnected + // without sending a close frame. + // - context.Canceled: request context was canceled. + // - net.ErrClosed: connection was already closed by + // another goroutine (e.g. handler returned). + // - websocket.CloseError: a close frame was + // received or sent. + if errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + errors.Is(err, net.ErrClosed) || + websocket.CloseStatus(err) != -1 { logger.Debug(ctx, "heartbeat ping stopped", slog.Error(err)) } else { logger.Error(ctx, "failed to heartbeat ping", slog.Error(err)) diff --git a/coderd/httpapi/websocket_internal_test.go b/coderd/httpapi/websocket_internal_test.go new file mode 100644 index 0000000000..09d18aa38e --- /dev/null +++ b/coderd/httpapi/websocket_internal_test.go @@ -0,0 +1,210 @@ +package httpapi + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" + "github.com/coder/websocket" +) + +// logSink captures log entries so tests can assert on log levels. +type logSink struct { + mu sync.Mutex + entries []slog.SinkEntry +} + +func (s *logSink) LogEntry(_ context.Context, e slog.SinkEntry) { + s.mu.Lock() + defer s.mu.Unlock() + s.entries = append(s.entries, e) +} + +func (*logSink) Sync() {} + +func (s *logSink) entriesAtLevel(level slog.Level) []slog.SinkEntry { + s.mu.Lock() + defer s.mu.Unlock() + var result []slog.SinkEntry + for _, e := range s.entries { + if e.Level == level { + result = append(result, e) + } + } + return result +} + +// websocketPair sets up an httptest server with a websocket endpoint and +// returns the server-side conn. The server handler stays alive until ctx +// is done. +func websocketPair(ctx context.Context, t *testing.T) *websocket.Conn { + t.Helper() + serverConnCh := make(chan *websocket.Conn, 1) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + serverConnCh <- conn + // Keep the handler alive so the HTTP server doesn't close + // the connection from under us. + <-ctx.Done() + })) + t.Cleanup(srv.Close) + + //nolint:bodyclose + clientConn, _, err := websocket.Dial(ctx, srv.URL, nil) + require.NoError(t, err) + t.Cleanup(func() { + _ = clientConn.Close(websocket.StatusNormalClosure, "test cleanup") + }) + + select { + case sc := <-serverConnCh: + return sc + case <-ctx.Done(): + t.Fatal("timed out waiting for server websocket accept") + return nil + } +} + +func TestHeartbeatClose(t *testing.T) { + t.Parallel() + + t.Run("ServerSideClose", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + sink := &logSink{} + logger := slog.Make(sink).Leveled(slog.LevelDebug) + mClock := quartz.NewMock(t) + + // Trap ticker creation so we can synchronize startup. + trap := mClock.Trap().NewTicker("HeartbeatClose") + defer trap.Close() + + serverConn := websocketPair(ctx, t) + exitCalled := make(chan struct{}) + + go heartbeatCloseWith(ctx, logger, func() { + close(exitCalled) + }, serverConn, mClock, time.Second) + + // Wait for the ticker to be created, then release. + trap.MustWait(ctx).MustRelease(ctx) + + // Close the server-side connection before the tick fires. + // The next ping will get net.ErrClosed. + _ = serverConn.Close(websocket.StatusGoingAway, "simulated teardown") + + // Advance clock to trigger the tick. + mClock.Advance(time.Second).MustWait(ctx) + + // Wait for heartbeatClose to call exit. + select { + case <-exitCalled: + case <-ctx.Done(): + t.Fatal("timed out waiting for heartbeatClose to call exit") + } + + // A closed connection is a normal shutdown condition. The + // error should be logged at Debug, not Error. + errorEntries := sink.entriesAtLevel(slog.LevelError) + assert.Empty(t, errorEntries, + "closed connection should not produce error-level logs, got: %+v", errorEntries) + debugEntries := sink.entriesAtLevel(slog.LevelDebug) + assert.NotEmpty(t, debugEntries, + "expected a debug-level log entry for the closed connection") + }) + + t.Run("ContextCanceled", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + sink := &logSink{} + logger := slog.Make(sink).Leveled(slog.LevelDebug) + mClock := quartz.NewMock(t) + + trap := mClock.Trap().NewTicker("HeartbeatClose") + defer trap.Close() + + serverCtx, serverCancel := context.WithCancel(ctx) + serverConn := websocketPair(ctx, t) + done := make(chan struct{}) + + go func() { + defer close(done) + heartbeatCloseWith(serverCtx, logger, func() { + t.Error("exit should not be called on context cancel") + }, serverConn, mClock, time.Second) + }() + + trap.MustWait(ctx).MustRelease(ctx) + + // Cancel the context. HeartbeatClose should return via + // the <-ctx.Done() branch without calling exit. + serverCancel() + + select { + case <-done: + case <-ctx.Done(): + t.Fatal("timed out waiting for heartbeatClose to return") + } + + errorEntries := sink.entriesAtLevel(slog.LevelError) + assert.Empty(t, errorEntries, + "context cancellation should not produce error-level logs, got: %+v", errorEntries) + }) + + t.Run("PingSucceeds", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + sink := &logSink{} + logger := slog.Make(sink).Leveled(slog.LevelDebug) + mClock := quartz.NewMock(t) + + trap := mClock.Trap().NewTicker("HeartbeatClose") + defer trap.Close() + + serverConn := websocketPair(ctx, t) + exitCalled := make(chan struct{}, 1) + + go heartbeatCloseWith(ctx, logger, func() { + exitCalled <- struct{}{} + }, serverConn, mClock, time.Second) + + trap.MustWait(ctx).MustRelease(ctx) + + // Fire several ticks — pings should succeed each time. + for range 3 { + mClock.Advance(time.Second).MustWait(ctx) + + // Give the ping round-trip time to complete. + // If exit were called, we'd catch it. + select { + case <-exitCalled: + t.Fatal("exit should not be called when pings succeed") + default: + } + } + + // No logs should be emitted during normal operation. + errorEntries := sink.entriesAtLevel(slog.LevelError) + assert.Empty(t, errorEntries, + "successful pings should not produce error-level logs, got: %+v", errorEntries) + debugEntries := sink.entriesAtLevel(slog.LevelDebug) + assert.Empty(t, debugEntries, + "successful pings should not produce debug-level logs, got: %+v", debugEntries) + }) +}