chore: add support for one-way websockets to backend (#16853)

Closes https://github.com/coder/coder/issues/16775

## Changes made
- Added `OneWayWebSocket` function that establishes WebSocket
connections that don't allow client-to-server communication
- Added tests for the new function
- Updated API endpoints to make new WS-based endpoints, and mark
previous SSE-based endpoints as deprecated
- Updated existing SSE handlers to use the same core logic as the new WS
handlers

## Notes
- Frontend changes handled via #16855
This commit is contained in:
Michael Smith
2025-03-28 17:13:20 -04:00
committed by GitHub
parent d3050a7e77
commit 9bc727e977
21 changed files with 1720 additions and 190 deletions
+97
View File
@@ -8618,6 +8618,7 @@ const docTemplate = `{
],
"summary": "Watch for workspace agent metadata updates",
"operationId": "watch-for-workspace-agent-metadata-updates",
"deprecated": true,
"parameters": [
{
"type": "string",
@@ -8638,6 +8639,44 @@ const docTemplate = `{
}
}
},
"/workspaceagents/{workspaceagent}/watch-metadata-ws": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"Agents"
],
"summary": "Watch for workspace agent metadata updates via WebSockets",
"operationId": "watch-for-workspace-agent-metadata-updates-via-websockets",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace agent ID",
"name": "workspaceagent",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.ServerSentEvent"
}
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/workspacebuilds/{workspacebuild}": {
"get": {
"security": [
@@ -10049,6 +10088,7 @@ const docTemplate = `{
],
"summary": "Watch workspace by ID",
"operationId": "watch-workspace-by-id",
"deprecated": true,
"parameters": [
{
"type": "string",
@@ -10068,6 +10108,41 @@ const docTemplate = `{
}
}
}
},
"/workspaces/{workspace}/watch-ws": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"Workspaces"
],
"summary": "Watch workspace by ID via WebSockets",
"operationId": "watch-workspace-by-id-via-websockets",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace ID",
"name": "workspace",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.ServerSentEvent"
}
}
}
}
}
},
"definitions": {
@@ -14621,6 +14696,28 @@ const docTemplate = `{
}
}
},
"codersdk.ServerSentEvent": {
"type": "object",
"properties": {
"data": {},
"type": {
"$ref": "#/definitions/codersdk.ServerSentEventType"
}
}
},
"codersdk.ServerSentEventType": {
"type": "string",
"enum": [
"ping",
"data",
"error"
],
"x-enum-varnames": [
"ServerSentEventTypePing",
"ServerSentEventTypeData",
"ServerSentEventTypeError"
]
},
"codersdk.SessionCountDeploymentStats": {
"type": "object",
"properties": {
+85
View File
@@ -7627,6 +7627,7 @@
"tags": ["Agents"],
"summary": "Watch for workspace agent metadata updates",
"operationId": "watch-for-workspace-agent-metadata-updates",
"deprecated": true,
"parameters": [
{
"type": "string",
@@ -7647,6 +7648,40 @@
}
}
},
"/workspaceagents/{workspaceagent}/watch-metadata-ws": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["Agents"],
"summary": "Watch for workspace agent metadata updates via WebSockets",
"operationId": "watch-for-workspace-agent-metadata-updates-via-websockets",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace agent ID",
"name": "workspaceagent",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.ServerSentEvent"
}
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/workspacebuilds/{workspacebuild}": {
"get": {
"security": [
@@ -8900,6 +8935,7 @@
"tags": ["Workspaces"],
"summary": "Watch workspace by ID",
"operationId": "watch-workspace-by-id",
"deprecated": true,
"parameters": [
{
"type": "string",
@@ -8919,6 +8955,37 @@
}
}
}
},
"/workspaces/{workspace}/watch-ws": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["Workspaces"],
"summary": "Watch workspace by ID via WebSockets",
"operationId": "watch-workspace-by-id-via-websockets",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace ID",
"name": "workspace",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.ServerSentEvent"
}
}
}
}
}
},
"definitions": {
@@ -13265,6 +13332,24 @@
}
}
},
"codersdk.ServerSentEvent": {
"type": "object",
"properties": {
"data": {},
"type": {
"$ref": "#/definitions/codersdk.ServerSentEventType"
}
}
},
"codersdk.ServerSentEventType": {
"type": "string",
"enum": ["ping", "data", "error"],
"x-enum-varnames": [
"ServerSentEventTypePing",
"ServerSentEventTypeData",
"ServerSentEventTypeError"
]
},
"codersdk.SessionCountDeploymentStats": {
"type": "object",
"properties": {
+4 -2
View File
@@ -1248,7 +1248,8 @@ func New(options *Options) *API {
httpmw.ExtractWorkspaceParam(options.Database),
)
r.Get("/", api.workspaceAgent)
r.Get("/watch-metadata", api.watchWorkspaceAgentMetadata)
r.Get("/watch-metadata", api.watchWorkspaceAgentMetadataSSE)
r.Get("/watch-metadata-ws", api.watchWorkspaceAgentMetadataWS)
r.Get("/startup-logs", api.workspaceAgentLogsDeprecated)
r.Get("/logs", api.workspaceAgentLogs)
r.Get("/listening-ports", api.workspaceAgentListeningPorts)
@@ -1280,7 +1281,8 @@ func New(options *Options) *API {
r.Route("/ttl", func(r chi.Router) {
r.Put("/", api.putWorkspaceTTL)
})
r.Get("/watch", api.watchWorkspace)
r.Get("/watch", api.watchWorkspaceSSE)
r.Get("/watch-ws", api.watchWorkspaceWS)
r.Put("/extend", api.putExtendWorkspace)
r.Post("/usage", api.postWorkspaceUsage)
r.Put("/dormant", api.putWorkspaceDormant)
+119 -17
View File
@@ -16,6 +16,9 @@ import (
"github.com/go-playground/validator/v10"
"golang.org/x/xerrors"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/coder/coder/v2/coderd/httpapi/httpapiconstraints"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
@@ -282,7 +285,25 @@ func WebsocketCloseSprintf(format string, vars ...any) string {
return msg
}
func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent func(ctx context.Context, sse codersdk.ServerSentEvent) error, closed chan struct{}, err error) {
type EventSender func(rw http.ResponseWriter, r *http.Request) (
sendEvent func(sse codersdk.ServerSentEvent) error,
done <-chan struct{},
err error,
)
// ServerSentEventSender establishes a Server-Sent Event connection and allows
// the consumer to send messages to the client.
//
// The function returned allows you to send a single message to the client,
// while the channel lets you listen for when the connection closes.
//
// As much as possible, this function should be avoided in favor of using the
// OneWayWebSocket function. See OneWayWebSocket for more context.
func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
func(sse codersdk.ServerSentEvent) error,
<-chan struct{},
error,
) {
h := rw.Header()
h.Set("Content-Type", "text/event-stream")
h.Set("Cache-Control", "no-cache")
@@ -294,7 +315,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f
panic("http.ResponseWriter is not http.Flusher")
}
closed = make(chan struct{})
ctx := r.Context()
closed := make(chan struct{})
type sseEvent struct {
payload []byte
errC chan error
@@ -304,16 +326,13 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f
// Synchronized handling of events (no guarantee of order).
go func() {
defer close(closed)
// Send a heartbeat every 15 seconds to avoid the connection being killed.
ticker := time.NewTicker(time.Second * 15)
ticker := time.NewTicker(HeartbeatInterval)
defer ticker.Stop()
for {
var event sseEvent
select {
case <-r.Context().Done():
case <-ctx.Done():
return
case event = <-eventC:
case <-ticker.C:
@@ -333,21 +352,21 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f
}
}()
sendEvent = func(ctx context.Context, sse codersdk.ServerSentEvent) error {
sendEvent := func(newEvent codersdk.ServerSentEvent) error {
buf := &bytes.Buffer{}
enc := json.NewEncoder(buf)
_, err := buf.WriteString(fmt.Sprintf("event: %s\n", sse.Type))
_, err := buf.WriteString(fmt.Sprintf("event: %s\n", newEvent.Type))
if err != nil {
return err
}
if sse.Data != nil {
if newEvent.Data != nil {
_, err = buf.WriteString("data: ")
if err != nil {
return err
}
err = enc.Encode(sse.Data)
enc := json.NewEncoder(buf)
err = enc.Encode(newEvent.Data)
if err != nil {
return err
}
@@ -364,8 +383,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f
}
select {
case <-r.Context().Done():
return r.Context().Err()
case <-ctx.Done():
return ctx.Err()
case <-closed:
@@ -375,8 +392,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f
// for early exit. We don't check closed here because it
// can't happen while processing the event.
select {
case <-r.Context().Done():
return r.Context().Err()
case <-ctx.Done():
return ctx.Err()
case err := <-event.errC:
@@ -387,3 +402,90 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f
return sendEvent, closed, nil
}
// OneWayWebSocketEventSender establishes a new WebSocket connection that
// enforces one-way communication from the server to the client.
//
// The function returned allows you to send a single message to the client,
// while the channel lets you listen for when the connection closes.
//
// We must use an approach like this instead of Server-Sent Events for the
// browser, because on HTTP/1.1 connections, browsers are locked to no more than
// six HTTP connections for a domain total, across all tabs. If a user were to
// open a workspace in multiple tabs, the entire UI can start to lock up.
// WebSockets have no such limitation, no matter what HTTP protocol was used to
// establish the connection.
func OneWayWebSocketEventSender(rw http.ResponseWriter, r *http.Request) (
func(event codersdk.ServerSentEvent) error,
<-chan struct{},
error,
) {
ctx, cancel := context.WithCancel(r.Context())
r = r.WithContext(ctx)
socket, err := websocket.Accept(rw, r, nil)
if err != nil {
cancel()
return nil, nil, xerrors.Errorf("cannot establish connection: %w", err)
}
go Heartbeat(ctx, socket)
eventC := make(chan codersdk.ServerSentEvent)
socketErrC := make(chan websocket.CloseError, 1)
closed := make(chan struct{})
go func() {
defer cancel()
defer close(closed)
for {
select {
case event := <-eventC:
writeCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
err := wsjson.Write(writeCtx, socket, event)
cancel()
if err == nil {
continue
}
_ = socket.Close(websocket.StatusInternalError, "Unable to send newest message")
case err := <-socketErrC:
_ = socket.Close(err.Code, err.Reason)
case <-ctx.Done():
_ = socket.Close(websocket.StatusNormalClosure, "Connection closed")
}
return
}
}()
// We have some tools in the UI code to help enforce one-way WebSocket
// connections, but there's still the possibility that the client could send
// a message when it's not supposed to. If that happens, the client likely
// forgot to use those tools, and communication probably can't be trusted.
// Better to just close the socket and force the UI to fix its mess
go func() {
_, _, err := socket.Read(ctx)
if errors.Is(err, context.Canceled) {
return
}
if err != nil {
socketErrC <- websocket.CloseError{
Code: websocket.StatusInternalError,
Reason: "Unable to process invalid message from client",
}
return
}
socketErrC <- websocket.CloseError{
Code: websocket.StatusProtocolError,
Reason: "Clients cannot send messages for one-way WebSockets",
}
}()
sendEvent := func(event codersdk.ServerSentEvent) error {
select {
case eventC <- event:
case <-ctx.Done():
return ctx.Err()
}
return nil
}
return sendEvent, closed, nil
}
+438
View File
@@ -1,14 +1,18 @@
package httpapi_test
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -16,6 +20,7 @@ import (
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestInternalServerError(t *testing.T) {
@@ -155,3 +160,436 @@ func TestWebsocketCloseMsg(t *testing.T) {
assert.Equal(t, len(trunc), 123)
})
}
// Our WebSocket library accepts any arbitrary ResponseWriter at the type level,
// but the writer must also implement http.Hijacker for long-lived connections.
type mockOneWaySocketWriter struct {
serverRecorder *httptest.ResponseRecorder
serverConn net.Conn
clientConn net.Conn
serverReadWriter *bufio.ReadWriter
testContext *testing.T
}
func (m mockOneWaySocketWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return m.serverConn, m.serverReadWriter, nil
}
func (m mockOneWaySocketWriter) Flush() {
err := m.serverReadWriter.Flush()
require.NoError(m.testContext, err)
}
func (m mockOneWaySocketWriter) Header() http.Header {
return m.serverRecorder.Header()
}
func (m mockOneWaySocketWriter) Write(b []byte) (int, error) {
return m.serverReadWriter.Write(b)
}
func (m mockOneWaySocketWriter) WriteHeader(code int) {
m.serverRecorder.WriteHeader(code)
}
type mockEventSenderWrite func(b []byte) (int, error)
func (w mockEventSenderWrite) Write(b []byte) (int, error) {
return w(b)
}
func TestOneWayWebSocketEventSender(t *testing.T) {
t.Parallel()
newBaseRequest := func(ctx context.Context) *http.Request {
url := "ws://www.fake-website.com/logs"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
require.NoError(t, err)
h := req.Header
h.Add("Connection", "Upgrade")
h.Add("Upgrade", "websocket")
h.Add("Sec-WebSocket-Version", "13")
h.Add("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") // Just need any string
return req
}
newOneWayWriter := func(t *testing.T) mockOneWaySocketWriter {
mockServer, mockClient := net.Pipe()
recorder := httptest.NewRecorder()
var write mockEventSenderWrite = func(b []byte) (int, error) {
serverCount, err := mockServer.Write(b)
if err != nil {
return 0, err
}
recorderCount, err := recorder.Write(b)
if err != nil {
return 0, err
}
return min(serverCount, recorderCount), nil
}
return mockOneWaySocketWriter{
testContext: t,
serverConn: mockServer,
clientConn: mockClient,
serverRecorder: recorder,
serverReadWriter: bufio.NewReadWriter(
bufio.NewReader(mockServer),
bufio.NewWriter(write),
),
}
}
t.Run("Produces error if the socket connection could not be established", func(t *testing.T) {
t.Parallel()
incorrectProtocols := []struct {
major int
minor int
proto string
}{
{0, 9, "HTTP/0.9"},
{1, 0, "HTTP/1.0"},
}
for _, p := range incorrectProtocols {
ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
req.ProtoMajor = p.major
req.ProtoMinor = p.minor
req.Proto = p.proto
writer := newOneWayWriter(t)
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.ErrorContains(t, err, p.proto)
}
})
t.Run("Returned callback can publish new event to WebSocket connection", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newOneWayWriter(t)
send, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)
serverPayload := codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: "Blah",
}
err = send(serverPayload)
require.NoError(t, err)
// The client connection will receive a little bit of additional data on
// top of the main payload. Have to make sure check has tolerance for
// extra data being present
serverBytes, err := json.Marshal(serverPayload)
require.NoError(t, err)
clientBytes, err := io.ReadAll(writer.clientConn)
require.NoError(t, err)
require.True(t, bytes.Contains(clientBytes, serverBytes))
})
t.Run("Signals to outside consumer when socket has been closed", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
req := newBaseRequest(ctx)
writer := newOneWayWriter(t)
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)
successC := make(chan bool)
ticker := time.NewTicker(testutil.WaitShort)
go func() {
select {
case <-done:
successC <- true
case <-ticker.C:
successC <- false
}
}()
cancel()
require.True(t, <-successC)
})
t.Run("Socket will immediately close if client sends any message", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newOneWayWriter(t)
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)
successC := make(chan bool)
ticker := time.NewTicker(testutil.WaitShort)
go func() {
select {
case <-done:
successC <- true
case <-ticker.C:
successC <- false
}
}()
type JunkClientEvent struct {
Value string
}
b, err := json.Marshal(JunkClientEvent{"Hi :)"})
require.NoError(t, err)
_, err = writer.clientConn.Write(b)
require.NoError(t, err)
require.True(t, <-successC)
})
t.Run("Renders the socket inert if the request context cancels", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
req := newBaseRequest(ctx)
writer := newOneWayWriter(t)
send, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)
successC := make(chan bool)
ticker := time.NewTicker(testutil.WaitShort)
go func() {
select {
case <-done:
successC <- true
case <-ticker.C:
successC <- false
}
}()
cancel()
require.True(t, <-successC)
err = send(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: "Didn't realize you were closed - sorry! I'll try coming back tomorrow.",
})
require.Equal(t, err, ctx.Err())
_, open := <-done
require.False(t, open)
_, err = writer.serverConn.Write([]byte{})
require.Equal(t, err, io.ErrClosedPipe)
_, err = writer.clientConn.Read([]byte{})
require.Equal(t, err, io.EOF)
})
t.Run("Sends a heartbeat to the socket on a fixed internal of time to keep connections alive", func(t *testing.T) {
t.Parallel()
// Need add at least three heartbeats for something to be reliably
// counted as an interval, but also need some wiggle room
heartbeatCount := 3
hbDuration := time.Duration(heartbeatCount) * httpapi.HeartbeatInterval
timeout := hbDuration + (5 * time.Second)
ctx := testutil.Context(t, timeout)
req := newBaseRequest(ctx)
writer := newOneWayWriter(t)
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)
type Result struct {
Err error
Success bool
}
resultC := make(chan Result)
go func() {
err := writer.
clientConn.
SetReadDeadline(time.Now().Add(timeout))
if err != nil {
resultC <- Result{err, false}
return
}
for range heartbeatCount {
pingBuffer := make([]byte, 1)
pingSize, err := writer.clientConn.Read(pingBuffer)
if err != nil || pingSize != 1 {
resultC <- Result{err, false}
return
}
}
resultC <- Result{nil, true}
}()
result := <-resultC
require.NoError(t, result.Err)
require.True(t, result.Success)
})
}
// ServerSentEventSender accepts any arbitrary ResponseWriter at the type level,
// but the writer must also implement http.Flusher for long-lived connections
type mockServerSentWriter struct {
serverRecorder *httptest.ResponseRecorder
serverConn net.Conn
clientConn net.Conn
buffer *bytes.Buffer
testContext *testing.T
}
func (m mockServerSentWriter) Flush() {
b := m.buffer.Bytes()
_, err := m.serverConn.Write(b)
require.NoError(m.testContext, err)
m.buffer.Reset()
// Must close server connection to indicate EOF for any reads from the
// client connection; otherwise reads block forever. This is a testing
// limitation compared to the one-way websockets, since we have no way to
// frame the data and auto-indicate EOF for each message
err = m.serverConn.Close()
require.NoError(m.testContext, err)
}
func (m mockServerSentWriter) Header() http.Header {
return m.serverRecorder.Header()
}
func (m mockServerSentWriter) Write(b []byte) (int, error) {
return m.buffer.Write(b)
}
func (m mockServerSentWriter) WriteHeader(code int) {
m.serverRecorder.WriteHeader(code)
}
func TestServerSentEventSender(t *testing.T) {
t.Parallel()
newBaseRequest := func(ctx context.Context) *http.Request {
url := "ws://www.fake-website.com/logs"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
require.NoError(t, err)
return req
}
newServerSentWriter := func(t *testing.T) mockServerSentWriter {
mockServer, mockClient := net.Pipe()
return mockServerSentWriter{
testContext: t,
serverRecorder: httptest.NewRecorder(),
clientConn: mockClient,
serverConn: mockServer,
buffer: &bytes.Buffer{},
}
}
t.Run("Mutates response headers to support SSE connections", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newServerSentWriter(t)
_, _, err := httpapi.ServerSentEventSender(writer, req)
require.NoError(t, err)
h := writer.Header()
require.Equal(t, h.Get("Content-Type"), "text/event-stream")
require.Equal(t, h.Get("Cache-Control"), "no-cache")
require.Equal(t, h.Get("Connection"), "keep-alive")
require.Equal(t, h.Get("X-Accel-Buffering"), "no")
})
t.Run("Returned callback can publish new event to SSE connection", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newServerSentWriter(t)
send, _, err := httpapi.ServerSentEventSender(writer, req)
require.NoError(t, err)
serverPayload := codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: "Blah",
}
err = send(serverPayload)
require.NoError(t, err)
clientBytes, err := io.ReadAll(writer.clientConn)
require.NoError(t, err)
require.Equal(
t,
string(clientBytes),
"event: data\ndata: \"Blah\"\n\n",
)
})
t.Run("Signals to outside consumer when connection has been closed", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
req := newBaseRequest(ctx)
writer := newServerSentWriter(t)
_, done, err := httpapi.ServerSentEventSender(writer, req)
require.NoError(t, err)
successC := make(chan bool)
ticker := time.NewTicker(testutil.WaitShort)
go func() {
select {
case <-done:
successC <- true
case <-ticker.C:
successC <- false
}
}()
cancel()
require.True(t, <-successC)
})
t.Run("Sends a heartbeat to the client on a fixed internal of time to keep connections alive", func(t *testing.T) {
t.Parallel()
// Need add at least three heartbeats for something to be reliably
// counted as an interval, but also need some wiggle room
heartbeatCount := 3
hbDuration := time.Duration(heartbeatCount) * httpapi.HeartbeatInterval
timeout := hbDuration + (5 * time.Second)
ctx := testutil.Context(t, timeout)
req := newBaseRequest(ctx)
writer := newServerSentWriter(t)
_, _, err := httpapi.ServerSentEventSender(writer, req)
require.NoError(t, err)
type Result struct {
Err error
Success bool
}
resultC := make(chan Result)
go func() {
err := writer.
clientConn.
SetReadDeadline(time.Now().Add(timeout))
if err != nil {
resultC <- Result{err, false}
return
}
for range heartbeatCount {
pingBuffer := make([]byte, 1)
pingSize, err := writer.clientConn.Read(pingBuffer)
if err != nil || pingSize != 1 {
resultC <- Result{err, false}
return
}
}
resultC <- Result{nil, true}
}()
result := <-resultC
require.NoError(t, result.Err)
require.True(t, result.Success)
})
}
+5 -4
View File
@@ -11,11 +11,13 @@ import (
"github.com/coder/websocket"
)
const HeartbeatInterval time.Duration = 15 * time.Second
// Heartbeat loops to ping a WebSocket to keep it alive.
// Default idle connection timeouts are typically 60 seconds.
// See: https://docs.aws.amazon.com/elasticloadbalancing/latest/application/application-load-balancers.html#connection-idle-timeout
func Heartbeat(ctx context.Context, conn *websocket.Conn) {
ticker := time.NewTicker(15 * time.Second)
ticker := time.NewTicker(HeartbeatInterval)
defer ticker.Stop()
for {
select {
@@ -33,8 +35,7 @@ func Heartbeat(ctx context.Context, conn *websocket.Conn) {
// Heartbeat loops to ping a WebSocket to keep it alive. It calls `exit` on ping
// failure.
func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn *websocket.Conn) {
interval := 15 * time.Second
ticker := time.NewTicker(interval)
ticker := time.NewTicker(HeartbeatInterval)
defer ticker.Stop()
for {
@@ -43,7 +44,7 @@ func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn *
return
case <-ticker.C:
}
err := pingWithTimeout(ctx, conn, interval)
err := pingWithTimeout(ctx, conn, HeartbeatInterval)
if err != nil {
// context.DeadlineExceeded is expected when the client disconnects without sending a close frame
if !errors.Is(err, context.DeadlineExceeded) {
+28 -6
View File
@@ -1098,7 +1098,29 @@ func convertScripts(dbScripts []database.WorkspaceAgentScript) []codersdk.Worksp
// @Param workspaceagent path string true "Workspace agent ID" format(uuid)
// @Router /workspaceagents/{workspaceagent}/watch-metadata [get]
// @x-apidocgen {"skip": true}
func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
// @Deprecated Use /workspaceagents/{workspaceagent}/watch-metadata-ws instead
func (api *API) watchWorkspaceAgentMetadataSSE(rw http.ResponseWriter, r *http.Request) {
api.watchWorkspaceAgentMetadata(rw, r, httpapi.ServerSentEventSender)
}
// @Summary Watch for workspace agent metadata updates via WebSockets
// @ID watch-for-workspace-agent-metadata-updates-via-websockets
// @Security CoderSessionToken
// @Produce json
// @Tags Agents
// @Success 200 {object} codersdk.ServerSentEvent
// @Param workspaceagent path string true "Workspace agent ID" format(uuid)
// @Router /workspaceagents/{workspaceagent}/watch-metadata-ws [get]
// @x-apidocgen {"skip": true}
func (api *API) watchWorkspaceAgentMetadataWS(rw http.ResponseWriter, r *http.Request) {
api.watchWorkspaceAgentMetadata(rw, r, httpapi.OneWayWebSocketEventSender)
}
func (api *API) watchWorkspaceAgentMetadata(
rw http.ResponseWriter,
r *http.Request,
connect httpapi.EventSender,
) {
// Allow us to interrupt watch via cancel.
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
@@ -1163,7 +1185,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
//nolint:ineffassign // Release memory.
initialMD = nil
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r)
sendEvent, senderClosed, err := connect(rw, r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
@@ -1174,14 +1196,14 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
// Prevent handler from returning until the sender is closed.
defer func() {
cancel()
<-sseSenderClosed
<-senderClosed
}()
// Synchronize cancellation from SSE -> context, this lets us simplify the
// cancellation logic.
go func() {
select {
case <-ctx.Done():
case <-sseSenderClosed:
case <-senderClosed:
cancel()
}
}()
@@ -1193,7 +1215,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
log.Debug(ctx, "sending metadata", "num", len(values))
_ = sseSendEvent(ctx, codersdk.ServerSentEvent{
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: convertWorkspaceAgentMetadata(values),
})
@@ -1225,7 +1247,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
if err != nil {
if !database.IsQueryCanceledError(err) {
log.Error(ctx, "failed to get metadata", slog.Error(err))
_ = sseSendEvent(ctx, codersdk.ServerSentEvent{
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Failed to get metadata.",
+31 -10
View File
@@ -1719,12 +1719,33 @@ func (api *API) resolveAutostart(rw http.ResponseWriter, r *http.Request) {
// @Param workspace path string true "Workspace ID" format(uuid)
// @Success 200 {object} codersdk.Response
// @Router /workspaces/{workspace}/watch [get]
func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) {
// @Deprecated Use /workspaces/{workspace}/watch-ws instead
func (api *API) watchWorkspaceSSE(rw http.ResponseWriter, r *http.Request) {
api.watchWorkspace(rw, r, httpapi.ServerSentEventSender)
}
// @Summary Watch workspace by ID via WebSockets
// @ID watch-workspace-by-id-via-websockets
// @Security CoderSessionToken
// @Produce json
// @Tags Workspaces
// @Param workspace path string true "Workspace ID" format(uuid)
// @Success 200 {object} codersdk.ServerSentEvent
// @Router /workspaces/{workspace}/watch-ws [get]
func (api *API) watchWorkspaceWS(rw http.ResponseWriter, r *http.Request) {
api.watchWorkspace(rw, r, httpapi.OneWayWebSocketEventSender)
}
func (api *API) watchWorkspace(
rw http.ResponseWriter,
r *http.Request,
connect httpapi.EventSender,
) {
ctx := r.Context()
workspace := httpmw.WorkspaceParam(r)
apiKey := httpmw.APIKey(r)
sendEvent, senderClosed, err := httpapi.ServerSentEventSender(rw, r)
sendEvent, senderClosed, err := connect(rw, r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
@@ -1740,7 +1761,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) {
sendUpdate := func(_ context.Context, _ []byte) {
workspace, err := api.Database.GetWorkspaceByID(ctx, workspace.ID)
if err != nil {
_ = sendEvent(ctx, codersdk.ServerSentEvent{
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Internal error fetching workspace.",
@@ -1752,7 +1773,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) {
data, err := api.workspaceData(ctx, []database.Workspace{workspace})
if err != nil {
_ = sendEvent(ctx, codersdk.ServerSentEvent{
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Internal error fetching workspace data.",
@@ -1762,7 +1783,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) {
return
}
if len(data.templates) == 0 {
_ = sendEvent(ctx, codersdk.ServerSentEvent{
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Forbidden reading template of selected workspace.",
@@ -1779,7 +1800,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) {
api.Options.AllowWorkspaceRenames,
)
if err != nil {
_ = sendEvent(ctx, codersdk.ServerSentEvent{
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Internal error converting workspace.",
@@ -1787,7 +1808,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) {
},
})
}
_ = sendEvent(ctx, codersdk.ServerSentEvent{
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: w,
})
@@ -1805,7 +1826,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) {
sendUpdate(ctx, nil)
}))
if err != nil {
_ = sendEvent(ctx, codersdk.ServerSentEvent{
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Internal error subscribing to workspace events.",
@@ -1819,7 +1840,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) {
// This is required to show whether the workspace is up-to-date.
cancelTemplateSubscribe, err := api.Pubsub.Subscribe(watchTemplateChannel(workspace.TemplateID), sendUpdate)
if err != nil {
_ = sendEvent(ctx, codersdk.ServerSentEvent{
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Internal error subscribing to template events.",
@@ -1832,7 +1853,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) {
// An initial ping signals to the request that the server is now ready
// and the client can begin servicing a channel with data.
_ = sendEvent(ctx, codersdk.ServerSentEvent{
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypePing,
})
// Send updated workspace info after connection is established. This avoids
+32
View File
@@ -5735,6 +5735,38 @@ Git clone makes use of this by parsing the URL from: 'Username for "https://gith
| `ssh_config_options` | object | false | | |
| » `[any property]` | string | false | | |
## codersdk.ServerSentEvent
```json
{
"data": null,
"type": "ping"
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|--------|--------------------------------------------------------------|----------|--------------|-------------|
| `data` | any | false | | |
| `type` | [codersdk.ServerSentEventType](#codersdkserversenteventtype) | false | | |
## codersdk.ServerSentEventType
```json
"ping"
```
### Properties
#### Enumerated Values
| Value |
|---------|
| `ping` |
| `data` |
| `error` |
## codersdk.SessionCountDeploymentStats
```json
+38
View File
@@ -1979,3 +1979,41 @@ curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace}/watch \
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.Response](schemas.md#codersdkresponse) |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
## Watch workspace by ID via WebSockets
### Code samples
```shell
# Example request using curl
curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace}/watch-ws \
-H 'Accept: application/json' \
-H 'Coder-Session-Token: API_KEY'
```
`GET /workspaces/{workspace}/watch-ws`
### Parameters
| Name | In | Type | Required | Description |
|-------------|------|--------------|----------|--------------|
| `workspace` | path | string(uuid) | true | Workspace ID |
### Example responses
> 200 Response
```json
{
"data": null,
"type": "ping"
}
```
### Responses
| Status | Meaning | Description | Schema |
|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------|
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.ServerSentEvent](schemas.md#codersdkserversentevent) |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
-1
View File
@@ -166,7 +166,6 @@
"@vitejs/plugin-react": "4.3.4",
"autoprefixer": "10.4.20",
"chromatic": "11.25.2",
"eventsourcemock": "2.0.0",
"express": "4.21.2",
"jest": "29.7.0",
"jest-canvas-mock": "2.5.2",
-8
View File
@@ -403,9 +403,6 @@ importers:
chromatic:
specifier: 11.25.2
version: 11.25.2
eventsourcemock:
specifier: 2.0.0
version: 2.0.0
express:
specifier: 4.21.2
version: 4.21.2
@@ -3796,9 +3793,6 @@ packages:
eventemitter3@4.0.7:
resolution: {integrity: sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==, tarball: https://registry.npmjs.org/eventemitter3/-/eventemitter3-4.0.7.tgz}
eventsourcemock@2.0.0:
resolution: {integrity: sha512-tSmJnuE+h6A8/hLRg0usf1yL+Q8w01RQtmg0Uzgoxk/HIPZrIUeAr/A4es/8h1wNsoG8RdiESNQLTKiNwbSC3Q==, tarball: https://registry.npmjs.org/eventsourcemock/-/eventsourcemock-2.0.0.tgz}
execa@5.1.1:
resolution: {integrity: sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==, tarball: https://registry.npmjs.org/execa/-/execa-5.1.1.tgz}
engines: {node: '>=10'}
@@ -10017,8 +10011,6 @@ snapshots:
eventemitter3@4.0.7: {}
eventsourcemock@2.0.0: {}
execa@5.1.1:
dependencies:
cross-spawn: 7.0.6
-1
View File
@@ -1 +0,0 @@
declare module "eventsourcemock";
+28 -48
View File
@@ -22,9 +22,10 @@
import globalAxios, { type AxiosInstance, isAxiosError } from "axios";
import type dayjs from "dayjs";
import userAgentParser from "ua-parser-js";
import { OneWayWebSocket } from "utils/OneWayWebSocket";
import { delay } from "../utils/delay";
import * as TypesGen from "./typesGenerated";
import type { PostWorkspaceUsageRequest } from "./typesGenerated";
import * as TypesGen from "./typesGenerated";
const getMissingParameters = (
oldBuildParameters: TypesGen.WorkspaceBuildParameter[],
@@ -101,61 +102,40 @@ const getMissingParameters = (
};
/**
*
* @param agentId
* @returns An EventSource that emits agent metadata event objects
* (ServerSentEvent)
* @returns {OneWayWebSocket} A OneWayWebSocket that emits Server-Sent Events.
*/
export const watchAgentMetadata = (agentId: string): EventSource => {
return new EventSource(
`${location.protocol}//${location.host}/api/v2/workspaceagents/${agentId}/watch-metadata`,
{ withCredentials: true },
);
export const watchAgentMetadata = (
agentId: string,
): OneWayWebSocket<TypesGen.ServerSentEvent> => {
return new OneWayWebSocket({
apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`,
});
};
/**
* @returns {EventSource} An EventSource that emits workspace event objects
* (ServerSentEvent)
* @returns {OneWayWebSocket} A OneWayWebSocket that emits Server-Sent Events.
*/
export const watchWorkspace = (workspaceId: string): EventSource => {
return new EventSource(
`${location.protocol}//${location.host}/api/v2/workspaces/${workspaceId}/watch`,
{ withCredentials: true },
);
export const watchWorkspace = (
workspaceId: string,
): OneWayWebSocket<TypesGen.ServerSentEvent> => {
return new OneWayWebSocket({
apiRoute: `/api/v2/workspaces/${workspaceId}/watch-ws`,
});
};
type WatchInboxNotificationsParams = {
type WatchInboxNotificationsParams = Readonly<{
read_status?: "read" | "unread" | "all";
};
}>;
export const watchInboxNotifications = (
onNewNotification: (res: TypesGen.GetInboxNotificationResponse) => void,
export function watchInboxNotifications(
params?: WatchInboxNotificationsParams,
) => {
const searchParams = new URLSearchParams(params);
const socket = createWebSocket(
"/api/v2/notifications/inbox/watch",
searchParams,
);
socket.addEventListener("message", (event) => {
try {
const res = JSON.parse(
event.data,
) as TypesGen.GetInboxNotificationResponse;
onNewNotification(res);
} catch (error) {
console.warn("Error parsing inbox notification: ", error);
}
): OneWayWebSocket<TypesGen.GetInboxNotificationResponse> {
return new OneWayWebSocket({
apiRoute: "/api/v2/notifications/inbox/watch",
searchParams: params,
});
socket.addEventListener("error", (event) => {
console.warn("Watch inbox notifications error: ", event);
socket.close();
});
return socket;
};
}
export const getURLWithSearchParams = (
basePath: string,
@@ -1125,7 +1105,7 @@ class ApiMethods {
};
getWorkspaceByOwnerAndName = async (
username = "me",
username: string,
workspaceName: string,
params?: TypesGen.WorkspaceOptions,
): Promise<TypesGen.Workspace> => {
@@ -1138,7 +1118,7 @@ class ApiMethods {
};
getWorkspaceBuildByNumber = async (
username = "me",
username: string,
workspaceName: string,
buildNumber: number,
): Promise<TypesGen.WorkspaceBuild> => {
@@ -1324,7 +1304,7 @@ class ApiMethods {
};
createWorkspace = async (
userId = "me",
userId: string,
workspace: TypesGen.CreateWorkspaceRequest,
): Promise<TypesGen.Workspace> => {
const response = await this.axios.post<TypesGen.Workspace>(
@@ -2542,7 +2522,7 @@ function createWebSocket(
) {
const protocol = location.protocol === "https:" ? "wss:" : "ws:";
const socket = new WebSocket(
`${protocol}//${location.host}${path}?${params.toString()}`,
`${protocol}//${location.host}${path}?${params}`,
);
socket.binaryType = "blob";
return socket;
@@ -61,21 +61,31 @@ export const NotificationsInbox: FC<NotificationsInboxProps> = ({
);
useEffect(() => {
const socket = watchInboxNotifications(
(res) => {
updateNotificationsCache((prev) => {
return {
unread_count: res.unread_count,
notifications: [res.notification, ...prev.notifications],
};
});
},
{ read_status: "unread" },
);
const socket = watchInboxNotifications({ read_status: "unread" });
return () => {
socket.addEventListener("message", (e) => {
if (e.parseError) {
console.warn("Error parsing inbox notification: ", e.parseError);
return;
}
const msg = e.parsedMessage;
updateNotificationsCache((current) => {
return {
unread_count: msg.unread_count,
notifications: [msg.notification, ...current.notifications],
};
});
});
socket.addEventListener("error", () => {
displayError(
"Unable to retrieve latest inbox notifications. Please try refreshing the browser.",
);
socket.close();
};
});
return () => socket.close();
}, [updateNotificationsCache]);
const {
+63 -28
View File
@@ -3,9 +3,11 @@ import Skeleton from "@mui/material/Skeleton";
import Tooltip from "@mui/material/Tooltip";
import { watchAgentMetadata } from "api/api";
import type {
ServerSentEvent,
WorkspaceAgent,
WorkspaceAgentMetadata,
} from "api/typesGenerated";
import { displayError } from "components/GlobalSnackbar/utils";
import { Stack } from "components/Stack/Stack";
import dayjs from "dayjs";
import {
@@ -17,6 +19,7 @@ import {
useState,
} from "react";
import { MONOSPACE_FONT_FAMILY } from "theme/constants";
import type { OneWayWebSocket } from "utils/OneWayWebSocket";
type ItemStatus = "stale" | "valid" | "loading";
@@ -42,50 +45,82 @@ interface AgentMetadataProps {
storybookMetadata?: WorkspaceAgentMetadata[];
}
const maxSocketErrorRetryCount = 3;
export const AgentMetadata: FC<AgentMetadataProps> = ({
agent,
storybookMetadata,
}) => {
const [metadata, setMetadata] = useState<
WorkspaceAgentMetadata[] | undefined
>(undefined);
const [activeMetadata, setActiveMetadata] = useState(storybookMetadata);
useEffect(() => {
// This is an unfortunate pitfall with this component's testing setup,
// but even though we use the value of storybookMetadata as the initial
// value of the activeMetadata, we cannot put activeMetadata itself into
// the dependency array. If we did, we would destroy and rebuild each
// connection every single time a new message comes in from the socket,
// because the socket has to be wired up to the state setter
if (storybookMetadata !== undefined) {
setMetadata(storybookMetadata);
return;
}
let timeout: ReturnType<typeof setTimeout> | undefined = undefined;
let timeoutId: number | undefined = undefined;
let activeSocket: OneWayWebSocket<ServerSentEvent> | null = null;
let retries = 0;
const connect = (): (() => void) => {
const source = watchAgentMetadata(agent.id);
const createNewConnection = () => {
const socket = watchAgentMetadata(agent.id);
activeSocket = socket;
source.onerror = (e) => {
console.error("received error in watch stream", e);
setMetadata(undefined);
source.close();
socket.addEventListener("error", () => {
setActiveMetadata(undefined);
window.clearTimeout(timeoutId);
timeout = setTimeout(() => {
connect();
}, 3000);
};
// The error event is supposed to fire when an error happens
// with the connection itself, which implies that the connection
// would auto-close. Couldn't find a definitive answer on MDN,
// though, so closing it manually just to be safe
socket.close();
activeSocket = null;
source.addEventListener("data", (e) => {
const data = JSON.parse(e.data);
setMetadata(data);
});
return () => {
if (timeout !== undefined) {
clearTimeout(timeout);
retries++;
if (retries >= maxSocketErrorRetryCount) {
displayError(
"Unexpected disconnect while watching Metadata changes. Please try refreshing the page.",
);
return;
}
source.close();
};
displayError(
"Unexpected disconnect while watching Metadata changes. Creating new connection...",
);
timeoutId = window.setTimeout(() => {
createNewConnection();
}, 3_000);
});
socket.addEventListener("message", (e) => {
if (e.parseError) {
displayError(
"Unable to process newest response from server. Please try refreshing the page.",
);
return;
}
const msg = e.parsedMessage;
if (msg.type === "data") {
setActiveMetadata(msg.data as WorkspaceAgentMetadata[]);
}
});
};
createNewConnection();
return () => {
window.clearTimeout(timeoutId);
activeSocket?.close();
};
return connect();
}, [agent.id, storybookMetadata]);
if (metadata === undefined) {
if (activeMetadata === undefined) {
return (
<section css={styles.root}>
<AgentMetadataSkeleton />
@@ -93,7 +128,7 @@ export const AgentMetadata: FC<AgentMetadataProps> = ({
);
}
return <AgentMetadataView metadata={metadata} />;
return <AgentMetadataView metadata={activeMetadata} />;
};
export const AgentMetadataSkeleton: FC = () => {
@@ -1,46 +1,38 @@
import { watchBuildLogsByTemplateVersionId } from "api/api";
import type { ProvisionerJobLog, TemplateVersion } from "api/typesGenerated";
import { useEffectEvent } from "hooks/hookPolyfills";
import { useEffect, useState } from "react";
export const useWatchVersionLogs = (
templateVersion: TemplateVersion | undefined,
options?: { onDone: () => Promise<unknown> },
) => {
const [logs, setLogs] = useState<ProvisionerJobLog[] | undefined>();
const [logs, setLogs] = useState<ProvisionerJobLog[]>();
const templateVersionId = templateVersion?.id;
const templateVersionStatus = templateVersion?.job.status;
const [cachedVersionId, setCachedVersionId] = useState(templateVersionId);
if (cachedVersionId !== templateVersionId) {
setCachedVersionId(templateVersionId);
setLogs([]);
}
// biome-ignore lint/correctness/useExhaustiveDependencies: consider refactoring
const stableOnDone = useEffectEvent(() => options?.onDone());
const status = templateVersion?.job.status;
const canWatch = status === "running" || status === "pending";
useEffect(() => {
setLogs(undefined);
}, [templateVersionId]);
useEffect(() => {
if (!templateVersionId || !templateVersionStatus) {
return;
}
if (
templateVersionStatus !== "running" &&
templateVersionStatus !== "pending"
) {
if (!templateVersionId || !canWatch) {
return;
}
const socket = watchBuildLogsByTemplateVersionId(templateVersionId, {
onMessage: (log) => {
setLogs((logs) => (logs ? [...logs, log] : [log]));
},
onDone: options?.onDone,
onError: (error) => {
console.error(error);
onError: (error) => console.error(error),
onDone: stableOnDone,
onMessage: (newLog) => {
setLogs((current) => [...(current ?? []), newLog]);
},
});
return () => {
socket.close();
};
}, [options?.onDone, templateVersionId, templateVersionStatus]);
return () => socket.close();
}, [stableOnDone, canWatch, templateVersionId]);
return logs;
};
@@ -2,7 +2,7 @@ import { screen, waitFor, within } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import * as apiModule from "api/api";
import type { TemplateVersionParameter, Workspace } from "api/typesGenerated";
import EventSourceMock from "eventsourcemock";
import MockServerSocket from "jest-websocket-mock";
import {
DashboardContext,
type DashboardProvider,
@@ -84,23 +84,11 @@ const testButton = async (
const user = userEvent.setup();
await user.click(button);
expect(actionMock).toBeCalled();
expect(actionMock).toHaveBeenCalled();
};
let originalEventSource: typeof window.EventSource;
beforeAll(() => {
originalEventSource = window.EventSource;
// mocking out EventSource for SSE
window.EventSource = EventSourceMock;
});
beforeEach(() => {
jest.resetAllMocks();
});
afterAll(() => {
window.EventSource = originalEventSource;
afterEach(() => {
MockServerSocket.clean();
});
describe("WorkspacePage", () => {
+18 -11
View File
@@ -5,6 +5,7 @@ import { workspaceBuildsKey } from "api/queries/workspaceBuilds";
import { workspaceByOwnerAndName } from "api/queries/workspaces";
import type { Workspace } from "api/typesGenerated";
import { ErrorAlert } from "components/Alert/ErrorAlert";
import { displayError } from "components/GlobalSnackbar/utils";
import { Loader } from "components/Loader/Loader";
import { Margins } from "components/Margins/Margins";
import { useEffectEvent } from "hooks/hookPolyfills";
@@ -82,20 +83,26 @@ export const WorkspacePage: FC = () => {
return;
}
const eventSource = watchWorkspace(workspaceId);
const socket = watchWorkspace(workspaceId);
socket.addEventListener("message", (event) => {
if (event.parseError) {
displayError(
"Unable to process latest data from the server. Please try refreshing the page.",
);
return;
}
eventSource.addEventListener("data", async (event) => {
const newWorkspaceData = JSON.parse(event.data) as Workspace;
await updateWorkspaceData(newWorkspaceData);
if (event.parsedMessage.type === "data") {
updateWorkspaceData(event.parsedMessage.data as Workspace);
}
});
socket.addEventListener("error", () => {
displayError(
"Unable to get workspace changes. Connection has been closed.",
);
});
eventSource.addEventListener("error", (event) => {
console.error("Error on getting workspace changes.", event);
});
return () => {
eventSource.close();
};
return () => socket.close();
}, [updateWorkspaceData, workspaceId]);
// Page statuses
+492
View File
@@ -0,0 +1,492 @@
/**
* @file Sets up unit tests for OneWayWebSocket.
*
* 2025-03-18 - Really wanted to define these as integration tests with MSW, but
* getting it set up correctly for Jest and JSDOM got a little screwy. That can
* be revisited in the future, but in the meantime, we're assuming that the base
* WebSocket class doesn't have any bugs, and can safely be mocked out.
*/
import {
type OneWayMessageEvent,
OneWayWebSocket,
type WebSocketEventType,
} from "./OneWayWebSocket";
type MockPublisher = Readonly<{
publishMessage: (event: MessageEvent<string>) => void;
publishError: (event: ErrorEvent) => void;
publishClose: (event: CloseEvent) => void;
publishOpen: (event: Event) => void;
}>;
function createMockWebSocket(
url: string,
protocols?: string | string[],
): readonly [WebSocket, MockPublisher] {
type EventMap = {
message: MessageEvent<string>;
error: ErrorEvent;
close: CloseEvent;
open: Event;
};
type CallbackStore = {
[K in keyof EventMap]: ((event: EventMap[K]) => void)[];
};
let activeProtocol: string;
if (Array.isArray(protocols)) {
activeProtocol = protocols[0] ?? "";
} else if (typeof protocols === "string") {
activeProtocol = protocols;
} else {
activeProtocol = "";
}
let closed = false;
const store: CallbackStore = {
message: [],
error: [],
close: [],
open: [],
};
const mockSocket: WebSocket = {
CONNECTING: 0,
OPEN: 1,
CLOSING: 2,
CLOSED: 3,
url,
protocol: activeProtocol,
readyState: 1,
binaryType: "blob",
bufferedAmount: 0,
extensions: "",
onclose: null,
onerror: null,
onmessage: null,
onopen: null,
send: jest.fn(),
dispatchEvent: jest.fn(),
addEventListener: <E extends WebSocketEventType>(
eventType: E,
callback: WebSocketEventMap[E],
) => {
if (closed) {
return;
}
const subscribers = store[eventType];
const cb = callback as unknown as CallbackStore[E][0];
if (!subscribers.includes(cb)) {
subscribers.push(cb);
}
},
removeEventListener: <E extends WebSocketEventType>(
eventType: E,
callback: WebSocketEventMap[E],
) => {
if (closed) {
return;
}
const subscribers = store[eventType];
const cb = callback as unknown as CallbackStore[E][0];
if (subscribers.includes(cb)) {
const updated = store[eventType].filter((c) => c !== cb);
store[eventType] = updated as unknown as CallbackStore[E];
}
},
close: () => {
closed = true;
},
};
const publisher: MockPublisher = {
publishOpen: (event) => {
if (closed) {
return;
}
for (const sub of store.open) {
sub(event);
}
},
publishError: (event) => {
if (closed) {
return;
}
for (const sub of store.error) {
sub(event);
}
},
publishMessage: (event) => {
if (closed) {
return;
}
for (const sub of store.message) {
sub(event);
}
},
publishClose: (event) => {
if (closed) {
return;
}
for (const sub of store.close) {
sub(event);
}
},
};
return [mockSocket, publisher] as const;
}
describe(OneWayWebSocket.name, () => {
const dummyRoute = "/api/v2/blah";
it("Errors out if API route does not start with '/api/v2/'", () => {
const testRoutes: string[] = ["blah", "", "/", "/api", "/api/v225"];
for (const r of testRoutes) {
expect(() => {
new OneWayWebSocket({
apiRoute: r,
websocketInit: (url, protocols) => {
const [socket] = createMockWebSocket(url, protocols);
return socket;
},
});
}).toThrow(Error);
}
});
it("Lets a consumer add an event listener of each type", () => {
let publisher!: MockPublisher;
const oneWay = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket, pub] = createMockWebSocket(url, protocols);
publisher = pub;
return socket;
},
});
const onOpen = jest.fn();
const onClose = jest.fn();
const onError = jest.fn();
const onMessage = jest.fn();
oneWay.addEventListener("open", onOpen);
oneWay.addEventListener("close", onClose);
oneWay.addEventListener("error", onError);
oneWay.addEventListener("message", onMessage);
publisher.publishOpen(new Event("open"));
publisher.publishClose(new CloseEvent("close"));
publisher.publishError(
new ErrorEvent("error", {
error: new Error("Whoops - connection broke"),
}),
);
publisher.publishMessage(
new MessageEvent("message", {
data: "null",
}),
);
expect(onOpen).toHaveBeenCalledTimes(1);
expect(onClose).toHaveBeenCalledTimes(1);
expect(onError).toHaveBeenCalledTimes(1);
expect(onMessage).toHaveBeenCalledTimes(1);
});
it("Lets a consumer remove an event listener of each type", () => {
let publisher!: MockPublisher;
const oneWay = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket, pub] = createMockWebSocket(url, protocols);
publisher = pub;
return socket;
},
});
const onOpen = jest.fn();
const onClose = jest.fn();
const onError = jest.fn();
const onMessage = jest.fn();
oneWay.addEventListener("open", onOpen);
oneWay.addEventListener("close", onClose);
oneWay.addEventListener("error", onError);
oneWay.addEventListener("message", onMessage);
oneWay.removeEventListener("open", onOpen);
oneWay.removeEventListener("close", onClose);
oneWay.removeEventListener("error", onError);
oneWay.removeEventListener("message", onMessage);
publisher.publishOpen(new Event("open"));
publisher.publishClose(new CloseEvent("close"));
publisher.publishError(
new ErrorEvent("error", {
error: new Error("Whoops - connection broke"),
}),
);
publisher.publishMessage(
new MessageEvent("message", {
data: "null",
}),
);
expect(onOpen).toHaveBeenCalledTimes(0);
expect(onClose).toHaveBeenCalledTimes(0);
expect(onError).toHaveBeenCalledTimes(0);
expect(onMessage).toHaveBeenCalledTimes(0);
});
it("Only calls each callback once if callback is added multiple times", () => {
let publisher!: MockPublisher;
const oneWay = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket, pub] = createMockWebSocket(url, protocols);
publisher = pub;
return socket;
},
});
const onOpen = jest.fn();
const onClose = jest.fn();
const onError = jest.fn();
const onMessage = jest.fn();
for (let i = 0; i < 10; i++) {
oneWay.addEventListener("open", onOpen);
oneWay.addEventListener("close", onClose);
oneWay.addEventListener("error", onError);
oneWay.addEventListener("message", onMessage);
}
publisher.publishOpen(new Event("open"));
publisher.publishClose(new CloseEvent("close"));
publisher.publishError(
new ErrorEvent("error", {
error: new Error("Whoops - connection broke"),
}),
);
publisher.publishMessage(
new MessageEvent("message", {
data: "null",
}),
);
expect(onOpen).toHaveBeenCalledTimes(1);
expect(onClose).toHaveBeenCalledTimes(1);
expect(onError).toHaveBeenCalledTimes(1);
expect(onMessage).toHaveBeenCalledTimes(1);
});
it("Lets consumers register multiple callbacks for each event type", () => {
let publisher!: MockPublisher;
const oneWay = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket, pub] = createMockWebSocket(url, protocols);
publisher = pub;
return socket;
},
});
const onOpen1 = jest.fn();
const onClose1 = jest.fn();
const onError1 = jest.fn();
const onMessage1 = jest.fn();
oneWay.addEventListener("open", onOpen1);
oneWay.addEventListener("close", onClose1);
oneWay.addEventListener("error", onError1);
oneWay.addEventListener("message", onMessage1);
const onOpen2 = jest.fn();
const onClose2 = jest.fn();
const onError2 = jest.fn();
const onMessage2 = jest.fn();
oneWay.addEventListener("open", onOpen2);
oneWay.addEventListener("close", onClose2);
oneWay.addEventListener("error", onError2);
oneWay.addEventListener("message", onMessage2);
publisher.publishOpen(new Event("open"));
publisher.publishClose(new CloseEvent("close"));
publisher.publishError(
new ErrorEvent("error", {
error: new Error("Whoops - connection broke"),
}),
);
publisher.publishMessage(
new MessageEvent("message", {
data: "null",
}),
);
expect(onOpen1).toHaveBeenCalledTimes(1);
expect(onClose1).toHaveBeenCalledTimes(1);
expect(onError1).toHaveBeenCalledTimes(1);
expect(onMessage1).toHaveBeenCalledTimes(1);
expect(onOpen2).toHaveBeenCalledTimes(1);
expect(onClose2).toHaveBeenCalledTimes(1);
expect(onError2).toHaveBeenCalledTimes(1);
expect(onMessage2).toHaveBeenCalledTimes(1);
});
it("Computes the socket protocol based on the browser location protocol", () => {
const oneWay1 = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket] = createMockWebSocket(url, protocols);
return socket;
},
location: {
protocol: "https:",
host: "www.cool.com",
},
});
const oneWay2 = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket] = createMockWebSocket(url, protocols);
return socket;
},
location: {
protocol: "http:",
host: "www.cool.com",
},
});
expect(oneWay1.url).toMatch(/^wss:\/\//);
expect(oneWay2.url).toMatch(/^ws:\/\//);
});
it("Gives consumers pre-parsed versions of message events", () => {
let publisher!: MockPublisher;
const oneWay = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket, pub] = createMockWebSocket(url, protocols);
publisher = pub;
return socket;
},
});
const onMessage = jest.fn();
oneWay.addEventListener("message", onMessage);
const payload = {
value: 5,
cool: "yes",
};
const event = new MessageEvent("message", {
data: JSON.stringify(payload),
});
publisher.publishMessage(event);
expect(onMessage).toHaveBeenCalledWith({
sourceEvent: event,
parsedMessage: payload,
parseError: undefined,
});
});
it("Exposes parsing error if message payload could not be parsed as JSON", () => {
let publisher!: MockPublisher;
const oneWay = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket, pub] = createMockWebSocket(url, protocols);
publisher = pub;
return socket;
},
});
const onMessage = jest.fn();
oneWay.addEventListener("message", onMessage);
const payload = "definitely not valid JSON";
const event = new MessageEvent("message", {
data: payload,
});
publisher.publishMessage(event);
const arg: OneWayMessageEvent<never> = onMessage.mock.lastCall[0];
expect(arg.sourceEvent).toEqual(event);
expect(arg.parsedMessage).toEqual(undefined);
expect(arg.parseError).toBeInstanceOf(Error);
});
it("Passes all search param values through Websocket URL", () => {
const input1: Record<string, string> = {
cool: "yeah",
yeah: "cool",
blah: "5",
};
const oneWay1 = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket] = createMockWebSocket(url, protocols);
return socket;
},
searchParams: input1,
location: {
protocol: "https:",
host: "www.blah.com",
},
});
let [base, params] = oneWay1.url.split("?");
expect(base).toBe("wss://www.blah.com/api/v2/blah");
for (const [key, value] of Object.entries(input1)) {
expect(params).toContain(`${key}=${value}`);
}
const input2 = new URLSearchParams(input1);
const oneWay2 = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket] = createMockWebSocket(url, protocols);
return socket;
},
searchParams: input2,
location: {
protocol: "https:",
host: "www.blah.com",
},
});
[base, params] = oneWay2.url.split("?");
expect(base).toBe("wss://www.blah.com/api/v2/blah");
for (const [key, value] of Object.entries(input2)) {
expect(params).toContain(`${key}=${value}`);
}
const oneWay3 = new OneWayWebSocket({
apiRoute: dummyRoute,
websocketInit: (url, protocols) => {
const [socket] = createMockWebSocket(url, protocols);
return socket;
},
searchParams: undefined,
location: {
protocol: "https:",
host: "www.blah.com",
},
});
[base, params] = oneWay3.url.split("?");
expect(base).toBe("wss://www.blah.com/api/v2/blah");
expect(params).toBe(undefined);
});
});
+198
View File
@@ -0,0 +1,198 @@
/**
* @file A wrapper over WebSockets that (1) enforces one-way communication, and
* (2) supports automatically parsing JSON messages as they come in.
*
* This should ALWAYS be favored in favor of using Server-Sent Events and the
* built-in EventSource class for doing one-way communication. SSEs have a hard
* limitation on HTTP/1.1 and below where there is a maximum number of 6 ports
* that can ever be used for a domain (sometimes less depending on the browser).
* Not only is this limit shared with short-lived REST requests, but it also
* applies across tabs and windows. So if a user opens Coder in multiple tabs,
* there is a very real possibility that parts of the app will start to lock up
* without it being clear why.
*
* WebSockets do not have this limitation, even on HTTP/1.1 all modern
* browsers implement at least some degree of multiplexing for them.
*/
// Not bothering with trying to borrow methods from the base WebSocket type
// because it's already a mess of inheritance and generics, and we're going to
// have to add a few more
export type WebSocketEventType = "close" | "error" | "message" | "open";
export type OneWayMessageEvent<TData> = Readonly<
| {
sourceEvent: MessageEvent<string>;
parsedMessage: TData;
parseError: undefined;
}
| {
sourceEvent: MessageEvent<string>;
parsedMessage: undefined;
parseError: Error;
}
>;
type OneWayEventPayloadMap<TData> = {
close: CloseEvent;
error: Event;
message: OneWayMessageEvent<TData>;
open: Event;
};
type WebSocketMessageCallback = (payload: MessageEvent<string>) => void;
type OneWayEventCallback<TData, TEvent extends WebSocketEventType> = (
payload: OneWayEventPayloadMap<TData>[TEvent],
) => void;
interface OneWayWebSocketApi<TData> {
get url(): string;
addEventListener: <TEvent extends WebSocketEventType>(
eventType: TEvent,
callback: OneWayEventCallback<TData, TEvent>,
) => void;
removeEventListener: <TEvent extends WebSocketEventType>(
eventType: TEvent,
callback: OneWayEventCallback<TData, TEvent>,
) => void;
close: (closeCode?: number, reason?: string) => void;
}
type OneWayWebSocketInit = Readonly<{
apiRoute: string;
serverProtocols?: string | string[];
searchParams?: Record<string, string> | URLSearchParams;
binaryType?: BinaryType;
websocketInit?: (url: string, protocols?: string | string[]) => WebSocket;
location?: Readonly<{
protocol: string;
host: string;
}>;
}>;
function defaultInit(url: string, protocols?: string | string[]): WebSocket {
return new WebSocket(url, protocols);
}
export class OneWayWebSocket<TData = unknown>
implements OneWayWebSocketApi<TData>
{
readonly #socket: WebSocket;
readonly #messageCallbackWrappers = new Map<
OneWayEventCallback<TData, "message">,
WebSocketMessageCallback
>();
constructor(init: OneWayWebSocketInit) {
const {
apiRoute,
searchParams,
serverProtocols,
binaryType = "blob",
location = window.location,
websocketInit = defaultInit,
} = init;
if (!apiRoute.startsWith("/api/v2/")) {
throw new Error(`API route '${apiRoute}' does not begin with a slash`);
}
const formattedParams =
searchParams instanceof URLSearchParams
? searchParams
: new URLSearchParams(searchParams);
const paramsString = formattedParams.toString();
const paramsSuffix = paramsString ? `?${paramsString}` : "";
const wsProtocol = location.protocol === "https:" ? "wss:" : "ws:";
const url = `${wsProtocol}//${location.host}${apiRoute}${paramsSuffix}`;
this.#socket = websocketInit(url, serverProtocols);
this.#socket.binaryType = binaryType;
}
get url(): string {
return this.#socket.url;
}
addEventListener<TEvent extends WebSocketEventType>(
event: TEvent,
callback: OneWayEventCallback<TData, TEvent>,
): void {
// Not happy about all the type assertions, but there are some nasty
// type contravariance issues if you try to resolve the function types
// properly. This is actually the lesser of two evils
const looseCallback = callback as OneWayEventCallback<
TData,
WebSocketEventType
>;
if (this.#messageCallbackWrappers.has(looseCallback)) {
return;
}
if (event !== "message") {
this.#socket.addEventListener(event, looseCallback);
return;
}
const wrapped = (event: MessageEvent<string>): void => {
const messageCallback = looseCallback as OneWayEventCallback<
TData,
"message"
>;
try {
const message = JSON.parse(event.data) as TData;
messageCallback({
sourceEvent: event,
parseError: undefined,
parsedMessage: message,
});
} catch (err) {
messageCallback({
sourceEvent: event,
parseError: err as Error,
parsedMessage: undefined,
});
}
};
this.#socket.addEventListener(event as "message", wrapped);
this.#messageCallbackWrappers.set(looseCallback, wrapped);
}
removeEventListener<TEvent extends WebSocketEventType>(
event: TEvent,
callback: OneWayEventCallback<TData, TEvent>,
): void {
const looseCallback = callback as OneWayEventCallback<
TData,
WebSocketEventType
>;
if (event !== "message") {
this.#socket.removeEventListener(event, looseCallback);
return;
}
if (!this.#messageCallbackWrappers.has(looseCallback)) {
return;
}
const wrapper = this.#messageCallbackWrappers.get(looseCallback);
if (wrapper === undefined) {
throw new Error(
`Cannot unregister callback for event ${event}. This is likely an issue with the browser itself.`,
);
}
this.#socket.removeEventListener(event as "message", wrapper);
this.#messageCallbackWrappers.delete(looseCallback);
}
close(closeCode?: number, reason?: string): void {
this.#socket.close(closeCode, reason);
}
}