chore: add agent-connection-watch for workspaces (#24507)

<!--

If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting.

-->

relates to GRU-18  
  
Adds basic implementation for Workspace Agent Connection Watch and tests.  
  
Missing are handling of logs.
This commit is contained in:
Spike Curtis
2026-05-20 13:09:11 -04:00
committed by GitHub
parent 05e47b9c0f
commit 8dc4d76890
20 changed files with 1679 additions and 61 deletions
+1
View File
@@ -1293,6 +1293,7 @@ coderd/apidoc/.gen: \
$(wildcard enterprise/coderd/*.go) \
$(wildcard codersdk/*.go) \
$(wildcard enterprise/wsproxy/wsproxysdk/*.go) \
$(wildcard coderd/workspaceconnwatcher/*.go) \
$(DB_GEN_FILES) \
coderd/rbac/object_gen.go \
.swaggo \
+122
View File
@@ -12783,6 +12783,41 @@ const docTemplate = `{
]
}
},
"/api/v2/workspaces/{workspace}/agent-connection-watch": {
"get": {
"produces": [
"application/json"
],
"tags": [
"Workspaces"
],
"summary": "Workspace Agent Connection Watch",
"operationId": "workspace-agent-connection-watch",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace ID",
"name": "workspace",
"in": "path",
"required": true
}
],
"responses": {
"101": {
"description": "Switching Protocols",
"schema": {
"$ref": "#/definitions/workspacesdk.ConnectionWatchEvent"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/api/v2/workspaces/{workspace}/autostart": {
"put": {
"consumes": [
@@ -27964,6 +27999,93 @@ const docTemplate = `{
}
}
},
"workspacesdk.AgentUpdate": {
"type": "object",
"properties": {
"id": {
"type": "string",
"format": "uuid"
},
"lifecycle": {
"$ref": "#/definitions/codersdk.WorkspaceAgentLifecycle"
}
}
},
"workspacesdk.BuildUpdate": {
"type": "object",
"properties": {
"job_status": {
"$ref": "#/definitions/codersdk.ProvisionerJobStatus"
},
"transition": {
"$ref": "#/definitions/codersdk.WorkspaceTransition"
}
}
},
"workspacesdk.ConnectionWatchEvent": {
"type": "object",
"properties": {
"agent_update": {
"$ref": "#/definitions/workspacesdk.AgentUpdate"
},
"build_update": {
"$ref": "#/definitions/workspacesdk.BuildUpdate"
},
"error": {
"$ref": "#/definitions/workspacesdk.WatchError"
}
}
},
"workspacesdk.WatchError": {
"type": "object",
"properties": {
"code": {
"$ref": "#/definitions/workspacesdk.WatchErrorCode"
},
"details": {
"type": "string"
},
"message": {
"type": "string"
},
"retryable": {
"type": "boolean"
}
}
},
"workspacesdk.WatchErrorCode": {
"type": "integer",
"enum": [
0,
1,
2,
3,
4,
5,
6
],
"x-enum-comments": {
"_": "Ensure that zero value is not a valid code"
},
"x-enum-descriptions": [
"Ensure that zero value is not a valid code",
"",
"",
"",
"",
"",
""
],
"x-enum-varnames": [
"_",
"WatchErrorTooManyAgents",
"WatchErrorNameNotFound",
"WatchErrorNoAgents",
"WatchErrorServerShutdown",
"WatchErrorDatabase",
"WatchErrorInternal"
]
},
"wsproxysdk.CryptoKeysResponse": {
"type": "object",
"properties": {
+110
View File
@@ -11343,6 +11343,37 @@
]
}
},
"/api/v2/workspaces/{workspace}/agent-connection-watch": {
"get": {
"produces": ["application/json"],
"tags": ["Workspaces"],
"summary": "Workspace Agent Connection Watch",
"operationId": "workspace-agent-connection-watch",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace ID",
"name": "workspace",
"in": "path",
"required": true
}
],
"responses": {
"101": {
"description": "Switching Protocols",
"schema": {
"$ref": "#/definitions/workspacesdk.ConnectionWatchEvent"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/api/v2/workspaces/{workspace}/autostart": {
"put": {
"consumes": ["application/json"],
@@ -25823,6 +25854,85 @@
}
}
},
"workspacesdk.AgentUpdate": {
"type": "object",
"properties": {
"id": {
"type": "string",
"format": "uuid"
},
"lifecycle": {
"$ref": "#/definitions/codersdk.WorkspaceAgentLifecycle"
}
}
},
"workspacesdk.BuildUpdate": {
"type": "object",
"properties": {
"job_status": {
"$ref": "#/definitions/codersdk.ProvisionerJobStatus"
},
"transition": {
"$ref": "#/definitions/codersdk.WorkspaceTransition"
}
}
},
"workspacesdk.ConnectionWatchEvent": {
"type": "object",
"properties": {
"agent_update": {
"$ref": "#/definitions/workspacesdk.AgentUpdate"
},
"build_update": {
"$ref": "#/definitions/workspacesdk.BuildUpdate"
},
"error": {
"$ref": "#/definitions/workspacesdk.WatchError"
}
}
},
"workspacesdk.WatchError": {
"type": "object",
"properties": {
"code": {
"$ref": "#/definitions/workspacesdk.WatchErrorCode"
},
"details": {
"type": "string"
},
"message": {
"type": "string"
},
"retryable": {
"type": "boolean"
}
}
},
"workspacesdk.WatchErrorCode": {
"type": "integer",
"enum": [0, 1, 2, 3, 4, 5, 6],
"x-enum-comments": {
"_": "Ensure that zero value is not a valid code"
},
"x-enum-descriptions": [
"Ensure that zero value is not a valid code",
"",
"",
"",
"",
"",
""
],
"x-enum-varnames": [
"_",
"WatchErrorTooManyAgents",
"WatchErrorNameNotFound",
"WatchErrorNoAgents",
"WatchErrorServerShutdown",
"WatchErrorDatabase",
"WatchErrorInternal"
]
},
"wsproxysdk.CryptoKeysResponse": {
"type": "object",
"properties": {
+7
View File
@@ -92,6 +92,7 @@ import (
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/coderd/workspaceconnwatcher"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/coder/v2/coderd/x/chatd"
@@ -923,6 +924,8 @@ func New(options *Options) *API {
APIKeyEncryptionKeycache: options.AppEncryptionKeyCache,
})
api.workspaceAgentConnWatcher = workspaceconnwatcher.New(api.ctx, options.Logger, options.Pubsub, options.Database)
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
ActivateDormantUser: ActivateDormantUser(options.Logger, &api.Auditor, options.Database),
@@ -1820,6 +1823,7 @@ func New(options *Options) *API {
r.Patch("/", api.patchWorkspaceACL)
r.Delete("/", api.deleteWorkspaceACL)
})
r.Get("/agent-connection-watch", api.workspaceAgentConnWatcher.WorkspaceAgentConnectionWatch)
})
})
r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) {
@@ -2238,6 +2242,8 @@ type API struct {
// profile collection (via /debug/profile) can run at a time. The CPU
// profiler is process-global, so concurrent collections would fail.
ProfileCollecting atomic.Bool
workspaceAgentConnWatcher *workspaceconnwatcher.Watcher
}
// Close waits for all WebSocket connections to drain before returning.
@@ -2301,6 +2307,7 @@ func (api *API) Close() error {
_ = api.AppSigningKeyCache.Close()
_ = api.AppEncryptionKeyCache.Close()
_ = api.UpdatesProvider.Close()
api.workspaceAgentConnWatcher.Close()
if current := api.PrebuildsReconciler.Load(); current != nil {
ctx, giveUp := context.WithTimeoutCause(context.Background(), time.Second*30, xerrors.New("gave up waiting for reconciler to stop before shutdown"))
+28
View File
@@ -0,0 +1,28 @@
package coderdtest
import (
"sync/atomic"
"testing"
"github.com/prometheus/client_golang/prometheus"
"go.uber.org/mock/gomock"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/rbac"
)
func MockedDatabaseWithAuthz(t testing.TB, logger slog.Logger) (*gomock.Controller, *dbmock.MockStore, database.Store, rbac.Authorizer) {
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
accessControlStore.Store(&acs)
// dbauthz will call Wrappers() to check for wrapped databases
mDB.EXPECT().Wrappers().Return([]string{}).AnyTimes()
authDB := dbauthz.New(mDB, auth, logger, accessControlStore)
return ctrl, mDB, authDB, auth
}
+31
View File
@@ -0,0 +1,31 @@
package coderdtest
import (
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/rolestore"
)
func MemberSubject(userID, orgID uuid.UUID) rbac.Subject {
memberRole, err := rbac.RoleByName(rbac.RoleMember())
if err != nil {
panic(err)
}
orgMember, err := rolestore.TestingGetSystemRole(
rbac.RoleOrgMember(),
orgID,
rbac.OrgSettings{ShareableWorkspaceOwners: rbac.ShareableWorkspaceOwnersNone},
)
if err != nil {
panic(err)
}
return rbac.Subject{
FriendlyName: "coderdtest-member",
Email: "member@coderd.test",
Type: rbac.SubjectTypeUser,
ID: userID.String(),
Roles: rbac.Roles{memberRole, orgMember},
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
}
+1 -1
View File
@@ -16,7 +16,7 @@ import (
func TestEndpointsDocumented(t *testing.T) {
t.Parallel()
swaggerComments, err := coderdtest.ParseSwaggerComments("..")
swaggerComments, err := coderdtest.ParseSwaggerComments("..", "../workspaceconnwatcher")
require.NoError(t, err, "can't parse swagger comments")
require.NotEmpty(t, swaggerComments, "swagger comments must be present")
+4
View File
@@ -54,3 +54,7 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler {
})
}
}
func WithWorkspaceParam(ctx context.Context, workspace database.Workspace) context.Context {
return context.WithValue(ctx, workspaceParamContextKey{}, workspace)
}
+19
View File
@@ -170,6 +170,25 @@ var systemRoles = map[string]permissionsFunc{
rbac.RoleOrgServiceAccount(): rbac.OrgServiceAccountPermissions,
}
func TestingGetSystemRole(name string, orgID uuid.UUID, settings rbac.OrgSettings) (rbac.Role, error) {
f, ok := systemRoles[name]
if !ok {
return rbac.Role{}, xerrors.Errorf("role %q not found", name)
}
perms := f(settings)
return rbac.Role{
Identifier: rbac.RoleIdentifier{Name: name, OrganizationID: orgID},
DisplayName: "",
Site: nil,
ByOrgID: map[string]rbac.OrgPermissions{
orgID.String(): {
Org: perms.Org,
Member: perms.Member,
},
},
}, nil
}
// permissionsFunc produces the desired permissions for a system role
// given organization settings.
type permissionsFunc func(rbac.OrgSettings) rbac.OrgRolePermissions
+333
View File
@@ -0,0 +1,333 @@
package workspaceconnwatcher
import (
"context"
"database/sql"
"errors"
"net/http"
"sync"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/wsjson"
"github.com/coder/websocket"
)
type Watcher struct {
logger slog.Logger
sub pubsub.Subscriber
db database.Store
ctx context.Context
cancel context.CancelFunc
mu sync.Mutex
wg sync.WaitGroup
closed bool
}
type event struct {
sync bool
wsEvent *wspubsub.WorkspaceEvent
}
func New(ctx context.Context, logger slog.Logger, sub pubsub.Subscriber, db database.Store) *Watcher {
ctx, cancel := context.WithCancel(ctx)
w := &Watcher{
logger: logger.Named("wsconnwatcher"),
ctx: ctx,
cancel: cancel,
sub: sub,
db: db,
}
go func() {
<-ctx.Done()
w.Close()
}()
return w
}
// @Summary Workspace Agent Connection Watch
// @ID workspace-agent-connection-watch
// @Security CoderSessionToken
// @Produce json
// @Tags Workspaces
// @Param workspace path string true "Workspace ID" format(uuid)
// @Success 101 {object} workspacesdk.ConnectionWatchEvent
// @Router /api/v2/workspaces/{workspace}/agent-connection-watch [get]
func (w *Watcher) WorkspaceAgentConnectionWatch(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspace := httpmw.WorkspaceParam(r)
agentName := r.URL.Query().Get("agent_name")
filteredEvents := make(chan event, 1)
filteredEvents <- event{sync: true} // init sync
cancelWorkspaceSubscribe, err := w.sub.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspace.OwnerID),
wspubsub.HandleWorkspaceEvent(
func(ctx context.Context, payload wspubsub.WorkspaceEvent, err error) {
if err != nil {
// subscription error, resync
select {
case filteredEvents <- event{sync: true}:
case <-ctx.Done():
}
return
}
if payload.WorkspaceID != workspace.ID {
return
}
select {
case filteredEvents <- event{wsEvent: &payload}:
case <-ctx.Done():
}
}))
if err != nil {
w.logger.Error(ctx, "failed to subscribe to workspace events",
slog.Error(err), slog.F("owner_id", workspace.OwnerID))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up workspace event subscription",
// Don't include the error in case it leaks infra details about the pubsub
})
return
}
defer cancelWorkspaceSubscribe()
closed := false
w.mu.Lock()
closed = w.closed
if !closed {
w.wg.Add(1)
}
w.mu.Unlock()
if closed {
w.logger.Debug(ctx, "server is closed, writing error")
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
Message: "Server instance is shutting down",
})
return
}
defer w.wg.Done()
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept WebSocket.",
Detail: err.Error(),
})
return
}
// CloseRead starts a goroutine to read and discard messages from the client,
// including Pong messages sent in response to our Ping heartbeats.
_ = conn.CloseRead(ctx)
ctx, cancel := context.WithCancel(ctx)
go httpapi.HeartbeatClose(ctx, w.logger, cancel, conn)
defer cancel()
u := &updater{
db: w.db,
watcherCtx: w.ctx,
connCtx: ctx,
conn: conn,
workspaceID: workspace.ID,
events: filteredEvents,
agentName: agentName,
}
u.run()
}
func (w *Watcher) Close() {
w.mu.Lock()
w.closed = true
w.mu.Unlock()
w.cancel()
w.wg.Wait()
}
type updater struct {
db database.Store
watcherCtx context.Context
connCtx context.Context
conn *websocket.Conn
enc *wsjson.Encoder[workspacesdk.ConnectionWatchEvent]
workspaceID uuid.UUID
events <-chan event
agentName string
lastBuild database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow
}
func (u *updater) run() {
u.enc = wsjson.NewEncoder[workspacesdk.ConnectionWatchEvent](u.conn, websocket.MessageText)
defer func() {
// this is a no-op if we have already closed for some other reason.
_ = u.enc.Close(websocket.StatusNormalClosure)
}()
for {
select {
case <-u.watcherCtx.Done():
u.errorThenClose(workspacesdk.WatchError{
Code: workspacesdk.WatchErrorServerShutdown,
Retryable: true,
Message: "server is shutting down",
})
return
case <-u.connCtx.Done():
return
case e := <-u.events:
if e.sync {
// zero this out so we'll send a full update
u.lastBuild = database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{}
if !u.buildUpdate() {
return
}
}
if e.wsEvent != nil {
switch e.wsEvent.Kind {
case wspubsub.WorkspaceEventKindStateChange:
if !u.buildUpdate() {
return
}
case wspubsub.WorkspaceEventKindAgentLifecycleUpdate:
if !u.maybeSendAgentUpdate() {
return
}
}
}
}
}
}
func (u *updater) buildUpdate() bool {
build, err := u.db.GetLatestWorkspaceBuildWithStatusByWorkspaceID(u.connCtx, u.workspaceID)
if err != nil {
retryable := true
details := err.Error()
if errors.Is(err, sql.ErrNoRows) {
// There is no build (unlikely), or the workspace was deleted. In both cases, retrying won't help.
retryable = false
}
if dbauthz.IsNotAuthorizedError(err) {
retryable = false
details = "unauthorized" // security: don't leak internal authz details
}
u.errorThenClose(workspacesdk.WatchError{
Code: workspacesdk.WatchErrorDatabase,
Retryable: retryable,
Message: "failed to fetch latest workspace build",
Details: details,
})
return false
}
if build.BuildNumber != u.lastBuild.BuildNumber ||
build.JobStatus != u.lastBuild.JobStatus ||
build.Transition != u.lastBuild.Transition {
u.lastBuild = build
err = u.enc.Encode(workspacesdk.ConnectionWatchEvent{BuildUpdate: &workspacesdk.BuildUpdate{
Transition: codersdk.WorkspaceTransition(build.Transition),
JobStatus: codersdk.ProvisionerJobStatus(build.JobStatus),
}})
if err != nil {
// probably this is just that the connection is closed, but in case there is some actual JSON serialization
// error, send a close frame.
_ = u.conn.Close(websocket.StatusInternalError, "failed to encode build update")
return false
}
return u.maybeSendAgentUpdate()
}
return true
}
func (u *updater) maybeSendAgentUpdate() (ok bool) {
if u.lastBuild.Transition != database.WorkspaceTransitionStart ||
u.lastBuild.JobStatus != database.ProvisionerJobStatusSucceeded {
// only send agent updates for successfully started workspaces
return true
}
agents, err := u.db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(u.connCtx,
database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
WorkspaceID: u.workspaceID,
BuildNumber: u.lastBuild.BuildNumber,
})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
details := err.Error()
retryable := true
if dbauthz.IsNotAuthorizedError(err) {
retryable = false
details = "unauthorized"
}
u.errorThenClose(workspacesdk.WatchError{
Code: workspacesdk.WatchErrorDatabase,
Retryable: retryable,
Message: "failed to fetch workspace agents",
Details: details,
})
return false
}
if len(agents) == 0 {
u.errorThenClose(workspacesdk.WatchError{
Code: workspacesdk.WatchErrorNoAgents,
Retryable: false,
Message: "no agents found for workspace",
})
return false
}
if len(agents) > 1 && u.agentName == "" {
u.errorThenClose(workspacesdk.WatchError{
Code: workspacesdk.WatchErrorTooManyAgents,
Retryable: false,
Message: "more than one agent on workspace and target not specified",
})
return false
}
var agent database.WorkspaceAgent
if u.agentName == "" {
agent = agents[0]
} else {
for _, a := range agents {
if a.Name == u.agentName {
agent = a
break
}
}
if agent.ID == uuid.Nil {
u.errorThenClose(workspacesdk.WatchError{
Code: workspacesdk.WatchErrorNameNotFound,
Retryable: false,
Message: "target agent not found by name",
})
return false
}
}
err = u.enc.Encode(workspacesdk.ConnectionWatchEvent{AgentUpdate: &workspacesdk.AgentUpdate{
Lifecycle: codersdk.WorkspaceAgentLifecycle(agent.LifecycleState),
ID: agent.ID,
}})
if err != nil {
// probably this is just that the connection is closed, but in case there is some actual JSON serialization
// error, send a close frame.
_ = u.conn.Close(websocket.StatusInternalError, "failed to encode agent update")
return false
}
return true
}
func (u *updater) errorThenClose(err workspacesdk.WatchError) {
_ = u.enc.Encode(workspacesdk.ConnectionWatchEvent{Error: &err})
// ignore encoding errors above because in any case, we are going to close the connection.
_ = u.conn.Close(websocket.StatusNormalClosure, "error")
}
+474
View File
@@ -0,0 +1,474 @@
package workspaceconnwatcher_test
import (
"context"
"encoding/json"
"errors"
"net/http"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/workspaceconnwatcher"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/wsjson"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
)
var (
workspaceID = uuid.UUID{1}
userID = uuid.UUID{2}
orgID = uuid.UUID{3}
agentID = uuid.UUID{4}
)
type harness struct {
db *dbmock.MockStore
watcher *workspaceconnwatcher.Watcher
pub pubsub.Publisher
logger slog.Logger
// Initialized, but overridable before Dial()
workspace database.Workspace
userID, orgID uuid.UUID
}
func newHarness(ctx context.Context, t *testing.T, logger slog.Logger) *harness {
h := &harness{
workspace: database.Workspace{
ID: workspaceID,
OrganizationID: orgID,
OwnerID: userID,
},
orgID: orgID,
userID: userID,
logger: logger,
}
ps := pubsub.NewInMemory()
h.pub = ps
var authzDB database.Store
_, h.db, authzDB, _ = coderdtest.MockedDatabaseWithAuthz(t, logger)
h.watcher = workspaceconnwatcher.New(ctx, logger.Named("watcher"), ps, authzDB)
t.Cleanup(h.watcher.Close)
return h
}
func (h *harness) Dial(ctx context.Context, url string) (*wsjson.Decoder[workspacesdk.ConnectionWatchEvent], error) {
rt := testutil.InMemWebsocketRoundTripper{
Handler: http.HandlerFunc(h.watcher.WorkspaceAgentConnectionWatch),
CtxMutator: func(ctx context.Context) context.Context {
ctx = httpmw.WithWorkspaceParam(ctx, h.workspace)
ctx = dbauthz.As(ctx, coderdtest.MemberSubject(userID, orgID))
return ctx
},
Logger: h.logger.Named("roundtripper"),
}
// nolint: bodyclose
clientSock, resp, err := websocket.Dial(ctx, url, &websocket.DialOptions{
HTTPClient: &http.Client{Transport: rt},
})
if err != nil {
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, codersdk.ReadBodyAsError(resp)
}
return nil, err
}
dec := wsjson.NewDecoder[workspacesdk.ConnectionWatchEvent](
clientSock, websocket.MessageText, h.logger.Named("decoder"))
return dec, nil
}
func TestWatcher_Agents(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
agents []database.WorkspaceAgent
agentDBError error
url string
expectedAgentUpdate *workspacesdk.AgentUpdate
expectedErrorCode workspacesdk.WatchErrorCode
expectedErrorRetryable bool
}{
{
name: "noNameSingleAgent",
agents: []database.WorkspaceAgent{
{
Name: "test",
ID: agentID,
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
},
},
url: "wss://local.test/",
expectedAgentUpdate: &workspacesdk.AgentUpdate{
Lifecycle: codersdk.WorkspaceAgentLifecycleCreated,
ID: agentID,
},
},
{
name: "noNameMultiAgent",
agents: []database.WorkspaceAgent{
{
Name: "agent0",
ID: agentID,
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
},
{
Name: "agent1",
ID: uuid.UUID{77},
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
},
},
url: "wss://local.test/",
expectedErrorCode: workspacesdk.WatchErrorTooManyAgents,
expectedErrorRetryable: false,
},
{
name: "namedAgentMultiAgent",
agents: []database.WorkspaceAgent{
{
Name: "agent0",
ID: agentID,
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
},
{
Name: "agent1",
ID: uuid.UUID{77},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
},
},
url: "wss://local.test/?agent_name=agent0",
expectedAgentUpdate: &workspacesdk.AgentUpdate{
Lifecycle: codersdk.WorkspaceAgentLifecycleCreated,
ID: agentID,
},
},
{
name: "namedAgentNonexistent",
agents: []database.WorkspaceAgent{
{
Name: "agent0",
ID: agentID,
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
},
{
Name: "agent1",
ID: uuid.UUID{77},
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
},
},
url: "wss://local.test/?agent_name=agent2",
expectedErrorCode: workspacesdk.WatchErrorNameNotFound,
expectedErrorRetryable: false,
},
{
name: "dbError",
agentDBError: xerrors.New("a bad thing happened"),
url: "wss://local.test/",
expectedErrorCode: workspacesdk.WatchErrorDatabase,
expectedErrorRetryable: true,
},
{
name: "unauthorized",
agentDBError: dbauthz.NotAuthorizedError{Err: xerrors.New("not allowed")},
url: "wss://local.test/",
expectedErrorCode: workspacesdk.WatchErrorDatabase,
expectedErrorRetryable: false,
},
{
name: "noAgents",
agents: []database.WorkspaceAgent{},
url: "wss://local.test/",
expectedErrorCode: workspacesdk.WatchErrorNoAgents,
expectedErrorRetryable: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
h := newHarness(ctx, t, logger)
h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID).
Times(1).
Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{
Transition: database.WorkspaceTransitionStart,
BuildNumber: 1,
JobStatus: database.ProvisionerJobStatusSucceeded,
WorkspaceTable: database.WorkspaceTable{
ID: h.workspace.ID,
OwnerID: userID,
OrganizationID: orgID,
},
}, nil)
// RBAC check for agent query
h.db.EXPECT().GetWorkspaceByID(gomock.Any(), h.workspace.ID).
Times(1).
Return(h.workspace, nil)
h.db.EXPECT().GetWorkspaceAgentsByWorkspaceAndBuildNumber(
gomock.Any(),
database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
WorkspaceID: h.workspace.ID,
BuildNumber: 1,
}).
Times(1).
Return(tc.agents, tc.agentDBError)
dec, err := h.Dial(ctx, tc.url)
require.NoError(t, err)
defer dec.Close()
events := dec.Chan()
e0 := testutil.RequireReceive(ctx, t, events)
require.Equal(t, workspacesdk.ConnectionWatchEvent{
BuildUpdate: &workspacesdk.BuildUpdate{
Transition: codersdk.WorkspaceTransitionStart,
JobStatus: codersdk.ProvisionerJobSucceeded,
},
}, e0)
e1 := testutil.RequireReceive(ctx, t, events)
if tc.expectedAgentUpdate != nil {
require.Equal(t, workspacesdk.ConnectionWatchEvent{AgentUpdate: tc.expectedAgentUpdate}, e1)
} else {
require.NotNil(t, e1.Error)
require.Equal(t, tc.expectedErrorRetryable, e1.Error.Retryable)
require.Equal(t, tc.expectedErrorCode, e1.Error.Code)
}
})
}
}
func TestWatcher_LostAccess(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
h := newHarness(ctx, t, logger)
h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID).
Times(1).
Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{
Transition: database.WorkspaceTransitionStart,
BuildNumber: 1,
JobStatus: database.ProvisionerJobStatusSucceeded,
WorkspaceTable: database.WorkspaceTable{
ID: h.workspace.ID,
OwnerID: uuid.UUID{99}, // workspace gets a new owner, e.g.
OrganizationID: orgID,
},
}, nil)
dec, err := h.Dial(ctx, "wss://local.test/")
require.NoError(t, err)
defer func() {
err := dec.Close()
require.NoError(t, err)
}()
events := dec.Chan()
e0 := testutil.RequireReceive(ctx, t, events)
require.NotNil(t, e0.Error)
require.Equal(t, workspacesdk.WatchErrorDatabase, e0.Error.Code)
require.False(t, e0.Error.Retryable)
require.Equal(t, "unauthorized", e0.Error.Details, "should not leak internal auth details")
}
func TestWatcher_PublishChanges(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
h := newHarness(ctx, t, logger)
// Initial build update, job is running.
build0 := h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID).
Times(1).
Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{
Transition: database.WorkspaceTransitionStart,
BuildNumber: 1,
JobStatus: database.ProvisionerJobStatusRunning,
WorkspaceTable: database.WorkspaceTable{
ID: h.workspace.ID,
OwnerID: userID,
OrganizationID: orgID,
},
}, nil)
dec, err := h.Dial(ctx, "wss://local.test/")
require.NoError(t, err)
defer func() {
err := dec.Close()
require.NoError(t, err)
}()
events := dec.Chan()
e0 := testutil.RequireReceive(ctx, t, events)
require.Equal(t, workspacesdk.ConnectionWatchEvent{
BuildUpdate: &workspacesdk.BuildUpdate{
Transition: codersdk.WorkspaceTransitionStart,
JobStatus: codersdk.ProvisionerJobRunning,
},
}, e0)
// Since job is still running, we don't immediately query for agents. Next we set up the db queries and send in an
// update over the pubsub to kick a new query.
build1 := h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID).
After(build0).
Times(1).
Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{
Transition: database.WorkspaceTransitionStart,
BuildNumber: 1,
JobStatus: database.ProvisionerJobStatusSucceeded,
WorkspaceTable: database.WorkspaceTable{
ID: h.workspace.ID,
OwnerID: userID,
OrganizationID: orgID,
},
}, nil)
// RBAC check for agent query
h.db.EXPECT().GetWorkspaceByID(gomock.Any(), h.workspace.ID).
After(build1).
Times(2). // these queries are identical between the initial and the update below
Return(h.workspace, nil)
agent0 := h.db.EXPECT().GetWorkspaceAgentsByWorkspaceAndBuildNumber(
gomock.Any(),
database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
WorkspaceID: h.workspace.ID,
BuildNumber: 1,
}).
After(build1).
Times(1).
Return([]database.WorkspaceAgent{
{
Name: "test",
ID: agentID,
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
},
}, nil)
changeMsg := wspubsub.WorkspaceEvent{
Kind: wspubsub.WorkspaceEventKindStateChange,
WorkspaceID: h.workspace.ID,
}
changeBytes, err := json.Marshal(changeMsg)
require.NoError(t, err)
err = h.pub.Publish(wspubsub.WorkspaceEventChannel(h.workspace.OwnerID), changeBytes)
require.NoError(t, err)
e1 := testutil.RequireReceive(ctx, t, events)
require.Equal(t, workspacesdk.ConnectionWatchEvent{
BuildUpdate: &workspacesdk.BuildUpdate{
Transition: codersdk.WorkspaceTransitionStart,
JobStatus: codersdk.ProvisionerJobSucceeded,
},
}, e1)
e2 := testutil.RequireReceive(ctx, t, events)
require.Equal(t, workspacesdk.ConnectionWatchEvent{AgentUpdate: &workspacesdk.AgentUpdate{
ID: agentID,
Lifecycle: codersdk.WorkspaceAgentLifecycleCreated,
}}, e2)
// Finally, send in a change event for the agent. But first, program the mock for the expected query.
h.db.EXPECT().GetWorkspaceAgentsByWorkspaceAndBuildNumber(
gomock.Any(),
database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
WorkspaceID: h.workspace.ID,
BuildNumber: 1,
}).
After(agent0).
Times(1).
Return([]database.WorkspaceAgent{
{
Name: "test",
ID: agentID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
},
}, nil)
changeMsg = wspubsub.WorkspaceEvent{
Kind: wspubsub.WorkspaceEventKindAgentLifecycleUpdate,
WorkspaceID: h.workspace.ID,
AgentID: &agentID,
}
changeBytes, err = json.Marshal(changeMsg)
require.NoError(t, err)
err = h.pub.Publish(wspubsub.WorkspaceEventChannel(h.workspace.OwnerID), changeBytes)
require.NoError(t, err)
e3 := testutil.RequireReceive(ctx, t, events)
require.Equal(t, workspacesdk.ConnectionWatchEvent{AgentUpdate: &workspacesdk.AgentUpdate{
ID: agentID,
Lifecycle: codersdk.WorkspaceAgentLifecycleReady,
}}, e3)
}
func TestWatcher_ClosedBeforeDial(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
h := newHarness(ctx, t, logger)
h.watcher.Close()
_, err := h.Dial(ctx, "wss://local.test/")
var sdkError *codersdk.Error
require.True(t, errors.As(err, &sdkError))
require.Equal(t, http.StatusServiceUnavailable, sdkError.StatusCode())
}
func TestWatcher_ClosedAfterDial(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
h := newHarness(ctx, t, logger)
h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID).
Times(1).
Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{
Transition: database.WorkspaceTransitionStop,
BuildNumber: 1,
JobStatus: database.ProvisionerJobStatusSucceeded,
WorkspaceTable: database.WorkspaceTable{
ID: h.workspace.ID,
OwnerID: userID,
OrganizationID: orgID,
},
}, nil)
dec, err := h.Dial(ctx, "wss://local.test/")
require.NoError(t, err)
events := dec.Chan()
_ = testutil.RequireReceive(ctx, t, events)
closed := make(chan struct{})
go func() {
defer close(closed)
h.watcher.Close()
}()
e := testutil.RequireReceive(ctx, t, events)
require.NotNil(t, e.Error)
require.Equal(t, workspacesdk.WatchErrorServerShutdown, e.Error.Code)
require.True(t, e.Error.Retryable)
select {
case <-ctx.Done():
t.Fatal("context timed out")
case _, ok := <-events:
require.False(t, ok, "socket not closed")
}
testutil.TryReceive(ctx, t, closed)
}
@@ -0,0 +1,100 @@
package tunneler_test
import (
"bytes"
"context"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/tunneler"
"github.com/coder/coder/v2/testutil"
)
// TestTunneler_Integration is an integration test using coderdtest. It should be removed when we integrate the Tunneler
// into coder ssh and those integration test cover this functionality.
func TestTunneler_Integration(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, store := coderdtest.NewWithDatabase(t, nil)
logger := testutil.Logger(t)
client.SetLogger(logger.Named("client"))
first := coderdtest.CreateFirstUser(t, client)
userClient, user := coderdtest.CreateAnotherUserMutators(t, client, first.OrganizationID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
r.Username = "myuser"
})
userClient.SetLogger(logger.Named("userclient"))
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
Name: "myworkspace",
OrganizationID: first.OrganizationID,
OwnerID: user.ID,
}).WithAgent().Do()
wsSDKClient := workspacesdk.New(userClient)
logs := &bytes.Buffer{}
app := &sshApplication{
t: t,
ctx: ctx,
done: make(chan struct{}),
}
tun := tunneler.NewTunneler(wsSDKClient, tunneler.Config{
WorkspaceID: r.Workspace.ID,
App: app,
WorkspaceStarter: nil,
AgentName: "",
LogWriter: logs,
DebugLogger: logger.Named("tunneler"),
})
testAgent := agenttest.New(t, client.URL, r.AgentToken)
defer testAgent.Close()
testutil.TryReceive(ctx, t, app.done)
require.Equal(t, app.result, "foo\n")
err := tun.GracefulShutdown(ctx)
require.NoError(t, err)
}
type sshApplication struct {
t *testing.T
ctx context.Context
client *ssh.Client
done chan struct{}
result string
}
func (s *sshApplication) Close() error {
return s.client.Close()
}
func (s *sshApplication) Start(conn workspacesdk.AgentConn) error {
var err error
s.client, err = conn.SSHClient(s.ctx)
if err != nil {
s.t.Error(err)
return err
}
go func() {
defer close(s.done)
sess, err := s.client.NewSession()
if err != nil {
s.t.Error("failed to create session", err)
}
defer sess.Close()
out, err := sess.Output("echo foo")
if err != nil {
s.t.Error("failed to echo", err)
}
s.result = string(out)
}()
return nil
}
+82 -32
View File
@@ -11,6 +11,7 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/wsjson"
)
type state int
@@ -33,6 +34,11 @@ type WorkspaceStarter interface {
type Client interface {
DialAgent(dialCtx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (workspacesdk.AgentConn, error)
WorkspaceAgentConnectionWatch(
dialCtx context.Context, workspaceID uuid.UUID, agentName string,
) (
dec *wsjson.Decoder[workspacesdk.ConnectionWatchEvent], err error,
)
}
const (
@@ -127,9 +133,9 @@ type Config struct {
// ordering.
type tunnelerEvent struct {
shutdownSignal *shutdownSignal
buildUpdate *buildUpdate
buildUpdate *workspacesdk.BuildUpdate
provisionerJobLog *codersdk.ProvisionerJobLog
agentUpdate *agentUpdate
agentUpdate *workspacesdk.AgentUpdate
agentLog *codersdk.WorkspaceAgentLog
appUpdate *networkedApplicationUpdate
tailnetUpdate *tailnetUpdate
@@ -137,16 +143,6 @@ type tunnelerEvent struct {
type shutdownSignal struct{}
type buildUpdate struct {
transition codersdk.WorkspaceTransition
jobStatus codersdk.ProvisionerJobStatus
}
type agentUpdate struct {
lifecycle codersdk.WorkspaceAgentLifecycle
id uuid.UUID
}
type networkedApplicationUpdate struct {
// up is true if the application is up. False if it is down.
up bool
@@ -174,10 +170,64 @@ func NewTunneler(client Client, config Config) *Tunneler {
return t
}
func (t *Tunneler) GracefulShutdown(ctx context.Context) error {
select {
case t.events <- tunnelerEvent{shutdownSignal: &shutdownSignal{}}:
case <-ctx.Done():
return ctx.Err()
case <-t.ctx.Done():
}
done := make(chan struct{})
go func() {
defer close(done)
t.wg.Wait()
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (t *Tunneler) start() {
defer t.wg.Done()
// here we would subscribe to updates.
// t.client.AgentConnectionWatch(t.config.WorkspaceID, t.config.AgentName)
d, err := t.client.WorkspaceAgentConnectionWatch(t.ctx, t.config.WorkspaceID, t.config.AgentName)
// TODO: handle retries
if err != nil {
return
}
defer d.Close()
c := d.Chan()
for {
select {
case <-t.ctx.Done():
return
case event, ok := <-c:
if !ok {
t.config.DebugLogger.Error(t.ctx, "watch closed")
}
if event.Error != nil {
t.config.DebugLogger.Error(t.ctx, "workspace agent connection watch error", slog.Error(event.Error))
}
if !ok || event.Error != nil {
// TODO: handle retries
select {
case t.events <- tunnelerEvent{shutdownSignal: &shutdownSignal{}}:
case <-t.ctx.Done():
}
return
}
select {
case <-t.ctx.Done():
return
case t.events <- tunnelerEvent{
buildUpdate: event.BuildUpdate,
agentUpdate: event.AgentUpdate,
}:
}
}
}
}
func (t *Tunneler) eventLoop() {
@@ -235,13 +285,13 @@ func (t *Tunneler) handleSignal() {
}
}
func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
func (t *Tunneler) handleBuildUpdate(update *workspacesdk.BuildUpdate) {
if t.state == shutdownTailnet || t.state == shutdownApplication || t.state == exit {
return // no-op
}
var canMakeProgress, jobUnhealthy bool
switch update.jobStatus {
switch update.JobStatus {
case codersdk.ProvisionerJobPending, codersdk.ProvisionerJobRunning:
canMakeProgress = true
case codersdk.ProvisionerJobSucceeded:
@@ -249,21 +299,21 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
jobUnhealthy = true
}
if update.transition == codersdk.WorkspaceTransitionDelete {
t.config.DebugLogger.Info(t.ctx, "workspace is being deleted", slog.F("job_status", update.jobStatus))
if update.Transition == codersdk.WorkspaceTransitionDelete {
t.config.DebugLogger.Info(t.ctx, "workspace is being deleted", slog.F("job_status", update.JobStatus))
// treat same as signal
t.handleSignal()
return
}
if jobUnhealthy {
t.config.DebugLogger.Info(t.ctx, "build job is in unhealthy state", slog.F("job_status", update.jobStatus))
t.config.DebugLogger.Info(t.ctx, "build job is in unhealthy state", slog.F("job_status", update.JobStatus))
// treat same as signal
t.handleSignal()
return
}
if update.transition == codersdk.WorkspaceTransitionStart && canMakeProgress {
t.config.DebugLogger.Debug(t.ctx, "workspace is starting", slog.F("job_status", update.jobStatus))
if update.Transition == codersdk.WorkspaceTransitionStart && canMakeProgress {
t.config.DebugLogger.Debug(t.ctx, "workspace is starting", slog.F("job_status", update.JobStatus))
switch t.state {
// new build after we have already connected
case establishTailnet: // we are starting the tailnet
@@ -279,8 +329,8 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
}
return
}
if update.transition == codersdk.WorkspaceTransitionStart && update.jobStatus == codersdk.ProvisionerJobSucceeded {
t.config.DebugLogger.Debug(t.ctx, "workspace is started", slog.F("job_status", update.jobStatus))
if update.Transition == codersdk.WorkspaceTransitionStart && update.JobStatus == codersdk.ProvisionerJobSucceeded {
t.config.DebugLogger.Debug(t.ctx, "workspace is started", slog.F("job_status", update.JobStatus))
switch t.state {
case establishTailnet, applicationUp, tailnetUp:
// no-op. Later agent updates will tell us whether the tailnet connection is current.
@@ -290,7 +340,7 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
return
}
if update.transition == codersdk.WorkspaceTransitionStop {
if update.Transition == codersdk.WorkspaceTransitionStop {
// these cases take effect regardless of whether the transition is complete or not
switch t.state {
// all 3 of these mean a new build after we have already started connecting
@@ -312,7 +362,7 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
t.state = exit
return
}
if update.jobStatus == codersdk.ProvisionerJobSucceeded {
if update.JobStatus == codersdk.ProvisionerJobSucceeded {
switch t.state {
case stateInit, waitToStart, waitForAgent:
t.wg.Add(1)
@@ -335,29 +385,29 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
}
// unhittable
t.config.DebugLogger.Critical(t.ctx, "unhandled build update",
slog.F("job_status", update.jobStatus), slog.F("transition", update.transition), slog.F("state", t.state))
slog.F("job_status", update.JobStatus), slog.F("transition", update.Transition), slog.F("state", t.state))
}
func (*Tunneler) handleProvisionerJobLog(*codersdk.ProvisionerJobLog) {
}
func (t *Tunneler) handleAgentUpdate(update *agentUpdate) {
func (t *Tunneler) handleAgentUpdate(update *workspacesdk.AgentUpdate) {
t.config.DebugLogger.Debug(t.ctx, "handling agent update",
slog.F("state", t.state),
slog.F("lifecycle", update.lifecycle),
slog.F("agent_id", update.id))
slog.F("lifecycle", update.Lifecycle),
slog.F("agent_id", update.ID))
if t.state != waitForAgent {
return
}
doConnect := func() {
t.wg.Add(1)
t.state = establishTailnet
go t.connectTailnet(update.id)
go t.connectTailnet(update.ID)
}
// consequence of ignoring updates if we are not waiting for the agent is that we MUST receive
// the start build succeeded update BEFORE we get the Agent connected / ready update. We should keep this
// in mind when implementing the watch in Coderd.
switch update.lifecycle {
switch update.Lifecycle {
case codersdk.WorkspaceAgentLifecycleReady:
doConnect()
return
@@ -376,7 +426,7 @@ func (t *Tunneler) handleAgentUpdate(update *agentUpdate) {
default:
// unhittable, unless new states are added. We structure this with the switch and all cases covered to ensure
// we cover all cases.
t.config.DebugLogger.Critical(t.ctx, "unhandled agent update", slog.F("lifecycle", update.lifecycle))
t.config.DebugLogger.Critical(t.ctx, "unhandled agent update", slog.F("lifecycle", update.Lifecycle))
}
}
@@ -12,6 +12,7 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/codersdk/wsjson"
"github.com/coder/coder/v2/testutil"
)
@@ -28,7 +29,7 @@ func TestHandleBuildUpdate_Coverage(t *testing.T) {
t.Run(fmt.Sprintf("%d_%s_%s_%t_%t", s, trans, jobStatus, noAutostart, noWaitForScripts), func(t *testing.T) {
t.Parallel()
coverUpdate(t, workspaceID, noAutostart, noWaitForScripts, s, func(uut *Tunneler) {
uut.handleBuildUpdate(&buildUpdate{transition: trans, jobStatus: jobStatus})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: trans, JobStatus: jobStatus})
})
})
}
@@ -105,31 +106,31 @@ func TestBuildUpdatesStoppedWorkspace(t *testing.T) {
state: stateInit,
}
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobPending})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobPending})
require.Equal(t, waitToStart, uut.state)
waitForGoroutines(testCtx, t, uut)
require.False(t, fWorkspaceStarter.started)
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobRunning})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobRunning})
require.Equal(t, waitToStart, uut.state)
waitForGoroutines(testCtx, t, uut)
require.False(t, fWorkspaceStarter.started)
// when stop job succeeds, we start the workspace
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobSucceeded})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobSucceeded})
require.Equal(t, waitForWorkspaceStarted, uut.state)
waitForGoroutines(testCtx, t, uut)
require.True(t, fWorkspaceStarter.started)
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobPending})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobPending})
require.Equal(t, waitForWorkspaceStarted, uut.state)
waitForGoroutines(testCtx, t, uut)
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobRunning})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobRunning})
require.Equal(t, waitForWorkspaceStarted, uut.state)
waitForGoroutines(testCtx, t, uut)
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobSucceeded})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobSucceeded})
require.Equal(t, waitForAgent, uut.state)
waitForGoroutines(testCtx, t, uut)
}
@@ -157,7 +158,7 @@ func TestBuildUpdatesNewBuildWhileWaiting(t *testing.T) {
}
// New build comes in while we are waiting for the agent to start. We roll back to waiting for the workspace to start.
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobRunning})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobRunning})
require.Equal(t, waitForWorkspaceStarted, uut.state)
waitForGoroutines(testCtx, t, uut)
require.False(t, fWorkspaceStarter.started)
@@ -193,12 +194,12 @@ func TestBuildUpdatesBadJobs(t *testing.T) {
state: stateInit,
}
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobRunning})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobRunning})
require.Equal(t, waitForWorkspaceStarted, uut.state)
waitForGoroutines(testCtx, t, uut)
require.False(t, fWorkspaceStarter.started)
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: jobStatus})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: jobStatus})
require.Equal(t, exit, uut.state)
waitForGoroutines(testCtx, t, uut)
require.False(t, fWorkspaceStarter.started)
@@ -233,7 +234,7 @@ func TestBuildUpdatesNoAutostart(t *testing.T) {
}
// when stop job succeeds, we exit because autostart is disabled
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobSucceeded})
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobSucceeded})
require.Equal(t, exit, uut.state)
waitForGoroutines(testCtx, t, uut)
require.False(t, fWorkspaceStarter.started)
@@ -254,7 +255,7 @@ func TestAgentUpdate_Coverage(t *testing.T) {
t.Run(fmt.Sprintf("%d_%s_%t_%t", s, lifecycle, noAutostart, noWaitForScripts), func(t *testing.T) {
t.Parallel()
coverUpdate(t, workspaceID, noAutostart, noWaitForScripts, s, func(uut *Tunneler) {
uut.handleAgentUpdate(&agentUpdate{lifecycle: lifecycle, id: agentID})
uut.handleAgentUpdate(&workspacesdk.AgentUpdate{Lifecycle: lifecycle, ID: agentID})
})
})
}
@@ -288,7 +289,7 @@ func TestAgentUpdateReady(t *testing.T) {
client: fClient,
}
uut.handleAgentUpdate(&agentUpdate{lifecycle: codersdk.WorkspaceAgentLifecycleReady, id: agentID})
uut.handleAgentUpdate(&workspacesdk.AgentUpdate{Lifecycle: codersdk.WorkspaceAgentLifecycleReady, ID: agentID})
require.Equal(t, establishTailnet, uut.state)
event := testutil.RequireReceive(testCtx, t, uut.events)
require.NotNil(t, event.tailnetUpdate)
@@ -323,7 +324,7 @@ func TestAgentUpdateNoWait(t *testing.T) {
client: fClient,
}
uut.handleAgentUpdate(&agentUpdate{lifecycle: codersdk.WorkspaceAgentLifecycleStarting, id: agentID})
uut.handleAgentUpdate(&workspacesdk.AgentUpdate{Lifecycle: codersdk.WorkspaceAgentLifecycleStarting, ID: agentID})
require.Equal(t, establishTailnet, uut.state)
event := testutil.RequireReceive(testCtx, t, uut.events)
require.NotNil(t, event.tailnetUpdate)
@@ -526,27 +527,27 @@ func TestTunneler_EventLoop_Signal(t *testing.T) {
go uut.eventLoop()
testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{
buildUpdate: &buildUpdate{
transition: codersdk.WorkspaceTransitionStart,
jobStatus: codersdk.ProvisionerJobPending,
buildUpdate: &workspacesdk.BuildUpdate{
Transition: codersdk.WorkspaceTransitionStart,
JobStatus: codersdk.ProvisionerJobPending,
},
})
testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{
buildUpdate: &buildUpdate{
transition: codersdk.WorkspaceTransitionStart,
jobStatus: codersdk.ProvisionerJobRunning,
buildUpdate: &workspacesdk.BuildUpdate{
Transition: codersdk.WorkspaceTransitionStart,
JobStatus: codersdk.ProvisionerJobRunning,
},
})
testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{
buildUpdate: &buildUpdate{
transition: codersdk.WorkspaceTransitionStart,
jobStatus: codersdk.ProvisionerJobSucceeded,
buildUpdate: &workspacesdk.BuildUpdate{
Transition: codersdk.WorkspaceTransitionStart,
JobStatus: codersdk.ProvisionerJobSucceeded,
},
})
testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{
agentUpdate: &agentUpdate{
lifecycle: codersdk.WorkspaceAgentLifecycleReady,
id: agentID,
agentUpdate: &workspacesdk.AgentUpdate{
Lifecycle: codersdk.WorkspaceAgentLifecycleReady,
ID: agentID,
},
})
@@ -658,6 +659,11 @@ type fakeClient struct {
dialed bool
}
func (*fakeClient) WorkspaceAgentConnectionWatch(context.Context, uuid.UUID, string) (dec *wsjson.Decoder[workspacesdk.ConnectionWatchEvent], err error) {
// TODO implement me
panic("implement me")
}
func (f *fakeClient) DialAgent(
_ context.Context, id uuid.UUID, _ *workspacesdk.DialAgentOptions,
) (
@@ -0,0 +1,86 @@
package workspacesdk
import (
"context"
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/wsjson"
"github.com/coder/websocket"
)
type WatchErrorCode int
const (
_ WatchErrorCode = iota // Ensure that zero value is not a valid code
WatchErrorTooManyAgents
WatchErrorNameNotFound
WatchErrorNoAgents
WatchErrorServerShutdown
WatchErrorDatabase
WatchErrorInternal
)
type ConnectionWatchEvent struct {
Error *WatchError `json:"error"`
BuildUpdate *BuildUpdate `json:"build_update,omitempty"`
AgentUpdate *AgentUpdate `json:"agent_update,omitempty"`
}
type WatchError struct {
Code WatchErrorCode `json:"code"`
Retryable bool `json:"retryable"`
Message string `json:"message"`
Details string `json:"details,omitempty"`
}
func (e *WatchError) Error() string {
if e.Details != "" {
return fmt.Sprintf("%s: %s", e.Message, e.Details)
}
return e.Message
}
type BuildUpdate struct {
Transition codersdk.WorkspaceTransition `json:"transition"`
JobStatus codersdk.ProvisionerJobStatus `json:"job_status"`
}
type AgentUpdate struct {
Lifecycle codersdk.WorkspaceAgentLifecycle `json:"lifecycle"`
ID uuid.UUID `json:"id" format:"uuid"`
}
func (c *Client) WorkspaceAgentConnectionWatch(
dialCtx context.Context, workspaceID uuid.UUID, agentName string,
) (
dec *wsjson.Decoder[ConnectionWatchEvent], err error,
) {
wsOptions := &websocket.DialOptions{
HTTPClient: c.client.HTTPClient,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
}
c.client.SessionTokenProvider.SetDialOption(wsOptions)
watchURL, err := c.client.URL.Parse(fmt.Sprintf("/api/v2/workspaces/%s/agent-connection-watch", workspaceID))
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
if agentName != "" {
q := watchURL.Query()
q.Set("agent_name", agentName)
watchURL.RawQuery = q.Encode()
}
// nolint:bodyclose
conn, res, err := websocket.Dial(dialCtx, watchURL.String(), wsOptions)
if err != nil {
bodyErr := codersdk.ReadBodyAsError(res)
return nil, bodyErr
}
return wsjson.NewDecoder[ConnectionWatchEvent](conn, websocket.MessageText, c.client.Logger()), nil
}
+95
View File
@@ -18616,6 +18616,101 @@ None
| `disable_direct_connections` | boolean | false | | |
| `hostname_suffix` | string | false | | |
## workspacesdk.AgentUpdate
```json
{
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"lifecycle": "created"
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|-------------|----------------------------------------------------------------------|----------|--------------|-------------|
| `id` | string | false | | |
| `lifecycle` | [codersdk.WorkspaceAgentLifecycle](#codersdkworkspaceagentlifecycle) | false | | |
## workspacesdk.BuildUpdate
```json
{
"job_status": "pending",
"transition": "start"
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|--------------|----------------------------------------------------------------|----------|--------------|-------------|
| `job_status` | [codersdk.ProvisionerJobStatus](#codersdkprovisionerjobstatus) | false | | |
| `transition` | [codersdk.WorkspaceTransition](#codersdkworkspacetransition) | false | | |
## workspacesdk.ConnectionWatchEvent
```json
{
"agent_update": {
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"lifecycle": "created"
},
"build_update": {
"job_status": "pending",
"transition": "start"
},
"error": {
"code": 0,
"details": "string",
"message": "string",
"retryable": true
}
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|----------------|------------------------------------------------------|----------|--------------|-------------|
| `agent_update` | [workspacesdk.AgentUpdate](#workspacesdkagentupdate) | false | | |
| `build_update` | [workspacesdk.BuildUpdate](#workspacesdkbuildupdate) | false | | |
| `error` | [workspacesdk.WatchError](#workspacesdkwatcherror) | false | | |
## workspacesdk.WatchError
```json
{
"code": 0,
"details": "string",
"message": "string",
"retryable": true
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|-------------|------------------------------------------------------------|----------|--------------|-------------|
| `code` | [workspacesdk.WatchErrorCode](#workspacesdkwatcherrorcode) | false | | |
| `details` | string | false | | |
| `message` | string | false | | |
| `retryable` | boolean | false | | |
## workspacesdk.WatchErrorCode
```json
0
```
### Properties
#### Enumerated Values
| Value(s) |
|-----------------------------------|
| `0`, `1`, `2`, `3`, `4`, `5`, `6` |
## wsproxysdk.CryptoKeysResponse
```json
+50
View File
@@ -1817,6 +1817,56 @@ curl -X PATCH http://coder-server:8080/api/v2/workspaces/{workspace}/acl \
To perform this operation, you must be authenticated. [Learn more](authentication.md).
## Workspace Agent Connection Watch
### Code samples
```shell
# Example request using curl
curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace}/agent-connection-watch \
-H 'Accept: application/json' \
-H 'Coder-Session-Token: API_KEY'
```
`GET /api/v2/workspaces/{workspace}/agent-connection-watch`
### Parameters
| Name | In | Type | Required | Description |
|-------------|------|--------------|----------|--------------|
| `workspace` | path | string(uuid) | true | Workspace ID |
### Example responses
> 101 Response
```json
{
"agent_update": {
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"lifecycle": "created"
},
"build_update": {
"job_status": "pending",
"transition": "start"
},
"error": {
"code": 0,
"details": "string",
"message": "string",
"retryable": true
}
}
```
### Responses
| Status | Meaning | Description | Schema |
|--------|--------------------------------------------------------------------------|---------------------|----------------------------------------------------------------------------------|
| 101 | [Switching Protocols](https://tools.ietf.org/html/rfc7231#section-6.2.2) | Switching Protocols | [workspacesdk.ConnectionWatchEvent](schemas.md#workspacesdkconnectionwatchevent) |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
## Update workspace autostart schedule by ID
### Code samples
@@ -12,7 +12,8 @@ import (
func TestEnterpriseEndpointsDocumented(t *testing.T) {
t.Parallel()
swaggerComments, err := coderdtest.ParseSwaggerComments("..", "../../../coderd")
swaggerComments, err := coderdtest.ParseSwaggerComments(
"..", "../../../coderd", "../../../coderd/workspaceconnwatcher")
require.NoError(t, err, "can't parse swagger comments")
require.NotEmpty(t, swaggerComments, "swagger comments must be present")
+1 -1
View File
@@ -22,7 +22,7 @@ func main() {
}
err := gen.New().Build(&gen.Config{
SearchDir: "./coderd,./codersdk,./enterprise/coderd,./enterprise/wsproxy/wsproxysdk",
SearchDir: "./coderd,./coderd/workspaceconnwatcher,./codersdk,./enterprise/coderd,./enterprise/wsproxy/wsproxysdk",
MainAPIFile: "coderd.go",
OutputDir: outputDir,
OutputTypes: []string{"go", "json"},
+101
View File
@@ -0,0 +1,101 @@
package testutil
import (
"bufio"
"context"
"io"
"net"
"net/http"
"cdr.dev/slog/v3"
)
// InMemWebsocketRoundTripper allows you to "dial" an HTTP handler that sets up a websocket using only in-memory
// primitives. No TCP or OS networking needed. CtxMutator gives you explicit control over the context the handler sees.
//
// Example:
//
// rt := testutil.InMemWebsocketRoundTripper{
// Handler: MyHandler,
// CtxMutator: func(ctx context.Context) context.Context {
// ctx = httpmw.WithWorkspaceParam(ctx, ws)
// ctx = dbauthz.As(ctx, coderdtest.MemberSubject(userID, orgID))
// return ctx
// },
// Logger: logger.Named("roundtripper"),
// }
// clientSock, _, err := websocket.Dial(ctx, "wss://local.test/", &websocket.DialOptions{
// HTTPClient: &http.Client{Transport: rt},
// })
type InMemWebsocketRoundTripper struct {
Logger slog.Logger
Handler http.Handler
CtxMutator func(ctx context.Context) context.Context
}
func (i InMemWebsocketRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
i.Logger.Debug(context.Background(), "round trip start")
defer i.Logger.Debug(context.Background(), "round trip end")
newCtx := i.CtxMutator(request.Context())
request = request.WithContext(newCtx)
serverP, clientP := net.Pipe()
var _ io.ReadWriteCloser = clientP // compile time check that response body is OK for websocket
response := &http.Response{
Header: make(http.Header),
Body: clientP,
}
rw := newInMemWebsocketResponseWriter(response, serverP)
go func() {
i.Handler.ServeHTTP(rw, request)
if !rw.hijacked {
i.Logger.Debug(context.Background(), "closing connection after handler did not hijack")
// If the handler didn't hijack the connection, we should close it when the handler finishes.
// This prevents a 3s delay in websocket.Dial() reading the non-upgraded response.
_ = serverP.Close()
}
}()
select {
case <-newCtx.Done():
return nil, newCtx.Err()
case <-rw.gotHeaders:
return response, nil
}
}
func newInMemWebsocketResponseWriter(resp *http.Response, conn net.Conn) *inMemWebsocketResponseWriter {
r := bufio.NewReader(conn)
w := bufio.NewWriter(conn)
return &inMemWebsocketResponseWriter{
r: resp,
b: bufio.NewReadWriter(r, w),
gotHeaders: make(chan struct{}),
conn: conn,
}
}
type inMemWebsocketResponseWriter struct {
r *http.Response
b *bufio.ReadWriter
gotHeaders chan struct{}
hijacked bool
conn net.Conn
}
func (rw *inMemWebsocketResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
rw.hijacked = true
return rw.conn, rw.b, nil
}
func (rw *inMemWebsocketResponseWriter) Header() http.Header {
return rw.r.Header
}
func (rw *inMemWebsocketResponseWriter) Write([]byte) (int, error) {
n, err := rw.b.Write([]byte{})
return n, err
}
func (rw *inMemWebsocketResponseWriter) WriteHeader(statusCode int) {
rw.r.StatusCode = statusCode
close(rw.gotHeaders)
}