Files
coder/coderd/httpapi/websocket_internal_test.go
T
Cian Johnston da2fa082bb fix(coderd/httpapi): CloseRead on test conns to ensure pings pong (#25184)
The `websocketPair` test helper was not calling `CloseRead` on either
side of the connection. Without `CloseRead`, the websocket library does
not process control frames (ping/pong), so the heartbeat tests were
passing only because no pings had yet failed, not because pings were
actually succeeding.

Add `CloseRead` on both the client and server connections so that pong
frames are delivered in response to pings.

Split out from #25012.

> 🤖 Generated with [Coder Agents](https://coder.com)
2026-05-14 13:54:59 +01:00

186 lines
5.3 KiB
Go

package httpapi
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
"github.com/coder/websocket"
)
// websocketPair starts an httptest server exposing a websocket endpoint,
// dials it, and hands back the server side of the resulting connection.
//
// Both ends have CloseRead called on them so that each runs a background
// reader. Without a reader the websocket library never processes control
// frames, so a ping sent from either side would never see its pong. The
// server handler blocks until ctx is done so the HTTP server does not close
// the connection out from under the test. The httptest server and the
// client conn are both torn down via t.Cleanup.
func websocketPair(ctx context.Context, t *testing.T) *websocket.Conn {
	t.Helper()

	accepted := make(chan *websocket.Conn, 1)
	handler := func(w http.ResponseWriter, r *http.Request) {
		conn, err := websocket.Accept(w, r, nil)
		if err == nil {
			accepted <- conn
			// Hold the handler open until ctx ends so the server does
			// not tear down the connection underneath the test.
			<-ctx.Done()
		}
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	t.Cleanup(server.Close)

	//nolint:bodyclose
	client, _, err := websocket.Dial(ctx, server.URL, nil)
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = client.Close(websocket.StatusNormalClosure, "test cleanup")
	})
	// Run the client's background reader so it answers pings with pongs.
	_ = client.CloseRead(ctx)

	var serverConn *websocket.Conn
	select {
	case serverConn = <-accepted:
	case <-ctx.Done():
		t.Fatal("timed out waiting for server websocket accept")
		return nil
	}
	// Run the server's background reader so pongs to its pings are seen.
	_ = serverConn.CloseRead(ctx)
	return serverConn
}
// TestHeartbeatClose exercises heartbeatCloseWith over a real websocket
// connection pair, driving its "HeartbeatClose" ticker with a mock clock.
func TestHeartbeatClose(t *testing.T) {
	t.Parallel()

	t.Run("ServerSideClose", func(t *testing.T) {
		t.Parallel()

		ctx := testutil.Context(t, testutil.WaitShort)
		sink := testutil.NewFakeSink(t)
		logger := sink.Logger()
		mClock := quartz.NewMock(t)

		// Trap ticker creation so we can synchronize startup.
		trap := mClock.Trap().NewTicker("HeartbeatClose")
		defer trap.Close()

		serverConn := websocketPair(ctx, t)

		exitCalled := make(chan struct{})
		go heartbeatCloseWith(ctx, logger, func() {
			close(exitCalled)
		}, serverConn, mClock, time.Second)

		// Wait for the ticker to be created, then release.
		trap.MustWait(ctx).MustRelease(ctx)

		// Close the server-side connection before the tick fires.
		// The next ping will get net.ErrClosed.
		_ = serverConn.Close(websocket.StatusGoingAway, "simulated teardown")

		// Advance clock to trigger the tick.
		mClock.Advance(time.Second).MustWait(ctx)

		// Wait for heartbeatClose to call exit.
		select {
		case <-exitCalled:
		case <-ctx.Done():
			t.Fatal("timed out waiting for heartbeatClose to call exit")
		}

		// A closed connection is a normal shutdown condition. The
		// error should be logged at Debug, not Error.
		errorEntries := sink.Entries(func(e slog.SinkEntry) bool { return e.Level == slog.LevelError })
		assert.Empty(t, errorEntries,
			"closed connection should not produce error-level logs, got: %+v", errorEntries)
		debugEntries := sink.Entries(func(e slog.SinkEntry) bool { return e.Level == slog.LevelDebug })
		assert.NotEmpty(t, debugEntries,
			"expected a debug-level log entry for the closed connection")
	})

	t.Run("ContextCanceled", func(t *testing.T) {
		t.Parallel()

		ctx := testutil.Context(t, testutil.WaitShort)
		sink := testutil.NewFakeSink(t)
		logger := sink.Logger()
		mClock := quartz.NewMock(t)

		trap := mClock.Trap().NewTicker("HeartbeatClose")
		defer trap.Close()

		serverCtx, serverCancel := context.WithCancel(ctx)
		// Release the heartbeat goroutine even if the test fails before
		// the explicit cancel below; calling cancel twice is harmless.
		defer serverCancel()
		serverConn := websocketPair(ctx, t)

		done := make(chan struct{})
		go func() {
			defer close(done)
			heartbeatCloseWith(serverCtx, logger, func() {
				t.Error("exit should not be called on context cancel")
			}, serverConn, mClock, time.Second)
		}()

		trap.MustWait(ctx).MustRelease(ctx)

		// Cancel the context. HeartbeatClose should return via
		// the <-ctx.Done() branch without calling exit.
		serverCancel()

		select {
		case <-done:
		case <-ctx.Done():
			t.Fatal("timed out waiting for heartbeatClose to return")
		}

		errorEntries := sink.Entries(func(e slog.SinkEntry) bool { return e.Level == slog.LevelError })
		assert.Empty(t, errorEntries,
			"context cancellation should not produce error-level logs, got: %+v", errorEntries)
	})

	t.Run("PingSucceeds", func(t *testing.T) {
		t.Parallel()

		ctx := testutil.Context(t, testutil.WaitShort)
		sink := testutil.NewFakeSink(t)
		logger := sink.Logger()
		mClock := quartz.NewMock(t)

		trap := mClock.Trap().NewTicker("HeartbeatClose")
		defer trap.Close()

		serverConn := websocketPair(ctx, t)

		exitCalled := make(chan struct{}, 1)
		go heartbeatCloseWith(ctx, logger, func() {
			exitCalled <- struct{}{}
		}, serverConn, mClock, time.Second)

		trap.MustWait(ctx).MustRelease(ctx)

		// Fire several ticks. The non-blocking receive on exitCalled
		// does not wait for the heartbeat's ping round-trip, so on its
		// own it cannot prove pings succeed. After each tick, also ping
		// synchronously from the test. That ping only returns nil if
		// pongs are actually delivered on this connection pair, which
		// is the property heartbeatCloseWith's own pings rely on.
		for range 3 {
			mClock.Advance(time.Second).MustWait(ctx)
			require.NoError(t, serverConn.Ping(ctx), "ping should receive a pong")
			select {
			case <-exitCalled:
				t.Fatal("exit should not be called when pings succeed")
			default:
			}
		}

		// No logs should be emitted during normal operation.
		errorEntries := sink.Entries(func(e slog.SinkEntry) bool { return e.Level == slog.LevelError })
		assert.Empty(t, errorEntries,
			"successful pings should not produce error-level logs, got: %+v", errorEntries)
		debugEntries := sink.Entries(func(e slog.SinkEntry) bool { return e.Level == slog.LevelDebug })
		assert.Empty(t, debugEntries,
			"successful pings should not produce debug-level logs, got: %+v", debugEntries)
	})
}