mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
da2fa082bb
The `websocketPair` test helper was not calling `CloseRead` on either side of the connection. Without `CloseRead`, the websocket library does not process control frames (ping/pong), so the heartbeat tests were passing only because no pings had yet failed, not because pings were actually succeeding. Add `CloseRead` on both the client and server connections so that pong frames are delivered in response to pings. Split out from #25012.

> 🤖 Generated with [Coder Agents](https://coder.com)
186 lines
5.3 KiB
Go
186 lines
5.3 KiB
Go
package httpapi
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
"github.com/coder/websocket"
|
|
)
|
|
|
|
// websocketPair sets up an httptest server with a websocket endpoint and
|
|
// returns the server-side conn. The server handler stays alive until ctx
|
|
// is done.
|
|
func websocketPair(ctx context.Context, t *testing.T) *websocket.Conn {
|
|
t.Helper()
|
|
serverConnCh := make(chan *websocket.Conn, 1)
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := websocket.Accept(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
serverConnCh <- conn
|
|
// Keep the handler alive so the HTTP server doesn't close
|
|
// the connection from under us.
|
|
<-ctx.Done()
|
|
}))
|
|
t.Cleanup(srv.Close)
|
|
|
|
//nolint:bodyclose
|
|
clientConn, _, err := websocket.Dial(ctx, srv.URL, nil)
|
|
require.NoError(t, err)
|
|
_ = clientConn.CloseRead(ctx) // Needed to handle pings/pongs.
|
|
t.Cleanup(func() {
|
|
_ = clientConn.Close(websocket.StatusNormalClosure, "test cleanup")
|
|
})
|
|
|
|
select {
|
|
case sc := <-serverConnCh:
|
|
_ = sc.CloseRead(ctx) // Needed to handle pings/pongs.
|
|
return sc
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for server websocket accept")
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func TestHeartbeatClose(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("ServerSideClose", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
sink := testutil.NewFakeSink(t)
|
|
logger := sink.Logger()
|
|
mClock := quartz.NewMock(t)
|
|
|
|
// Trap ticker creation so we can synchronize startup.
|
|
trap := mClock.Trap().NewTicker("HeartbeatClose")
|
|
defer trap.Close()
|
|
|
|
serverConn := websocketPair(ctx, t)
|
|
exitCalled := make(chan struct{})
|
|
|
|
go heartbeatCloseWith(ctx, logger, func() {
|
|
close(exitCalled)
|
|
}, serverConn, mClock, time.Second)
|
|
|
|
// Wait for the ticker to be created, then release.
|
|
trap.MustWait(ctx).MustRelease(ctx)
|
|
|
|
// Close the server-side connection before the tick fires.
|
|
// The next ping will get net.ErrClosed.
|
|
_ = serverConn.Close(websocket.StatusGoingAway, "simulated teardown")
|
|
|
|
// Advance clock to trigger the tick.
|
|
mClock.Advance(time.Second).MustWait(ctx)
|
|
|
|
// Wait for heartbeatClose to call exit.
|
|
select {
|
|
case <-exitCalled:
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for heartbeatClose to call exit")
|
|
}
|
|
|
|
// A closed connection is a normal shutdown condition. The
|
|
// error should be logged at Debug, not Error.
|
|
errorEntries := sink.Entries(func(e slog.SinkEntry) bool { return e.Level == slog.LevelError })
|
|
assert.Empty(t, errorEntries,
|
|
"closed connection should not produce error-level logs, got: %+v", errorEntries)
|
|
debugEntries := sink.Entries(func(e slog.SinkEntry) bool { return e.Level == slog.LevelDebug })
|
|
assert.NotEmpty(t, debugEntries,
|
|
"expected a debug-level log entry for the closed connection")
|
|
})
|
|
|
|
t.Run("ContextCanceled", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
sink := testutil.NewFakeSink(t)
|
|
logger := sink.Logger()
|
|
mClock := quartz.NewMock(t)
|
|
|
|
trap := mClock.Trap().NewTicker("HeartbeatClose")
|
|
defer trap.Close()
|
|
|
|
serverCtx, serverCancel := context.WithCancel(ctx)
|
|
serverConn := websocketPair(ctx, t)
|
|
done := make(chan struct{})
|
|
|
|
go func() {
|
|
defer close(done)
|
|
heartbeatCloseWith(serverCtx, logger, func() {
|
|
t.Error("exit should not be called on context cancel")
|
|
}, serverConn, mClock, time.Second)
|
|
}()
|
|
|
|
trap.MustWait(ctx).MustRelease(ctx)
|
|
|
|
// Cancel the context. HeartbeatClose should return via
|
|
// the <-ctx.Done() branch without calling exit.
|
|
serverCancel()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for heartbeatClose to return")
|
|
}
|
|
|
|
errorEntries := sink.Entries(func(e slog.SinkEntry) bool { return e.Level == slog.LevelError })
|
|
assert.Empty(t, errorEntries,
|
|
"context cancellation should not produce error-level logs, got: %+v", errorEntries)
|
|
})
|
|
|
|
t.Run("PingSucceeds", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
sink := testutil.NewFakeSink(t)
|
|
logger := sink.Logger()
|
|
mClock := quartz.NewMock(t)
|
|
|
|
trap := mClock.Trap().NewTicker("HeartbeatClose")
|
|
defer trap.Close()
|
|
|
|
serverConn := websocketPair(ctx, t)
|
|
exitCalled := make(chan struct{}, 1)
|
|
|
|
go heartbeatCloseWith(ctx, logger, func() {
|
|
exitCalled <- struct{}{}
|
|
}, serverConn, mClock, time.Second)
|
|
|
|
trap.MustWait(ctx).MustRelease(ctx)
|
|
|
|
// Fire several ticks — pings should succeed each time.
|
|
for range 3 {
|
|
mClock.Advance(time.Second).MustWait(ctx)
|
|
|
|
// Give the ping round-trip time to complete.
|
|
// If exit were called, we'd catch it.
|
|
select {
|
|
case <-exitCalled:
|
|
t.Fatal("exit should not be called when pings succeed")
|
|
default:
|
|
}
|
|
}
|
|
|
|
// No logs should be emitted during normal operation.
|
|
errorEntries := sink.Entries(func(e slog.SinkEntry) bool { return e.Level == slog.LevelError })
|
|
assert.Empty(t, errorEntries,
|
|
"successful pings should not produce error-level logs, got: %+v", errorEntries)
|
|
debugEntries := sink.Entries(func(e slog.SinkEntry) bool { return e.Level == slog.LevelDebug })
|
|
assert.Empty(t, debugEntries,
|
|
"successful pings should not produce debug-level logs, got: %+v", debugEntries)
|
|
})
|
|
}
|