mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore: add agent-connection-watch for workspaces (#24507)
<!-- If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting. --> Relates to GRU-18. Adds a basic implementation of the Workspace Agent Connection Watch endpoint, together with tests. Handling of logs is not implemented yet.
This commit is contained in:
@@ -1293,6 +1293,7 @@ coderd/apidoc/.gen: \
|
||||
$(wildcard enterprise/coderd/*.go) \
|
||||
$(wildcard codersdk/*.go) \
|
||||
$(wildcard enterprise/wsproxy/wsproxysdk/*.go) \
|
||||
$(wildcard coderd/workspaceconnwatcher/*.go) \
|
||||
$(DB_GEN_FILES) \
|
||||
coderd/rbac/object_gen.go \
|
||||
.swaggo \
|
||||
|
||||
Generated
+122
@@ -12783,6 +12783,41 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/api/v2/workspaces/{workspace}/agent-connection-watch": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Workspaces"
|
||||
],
|
||||
"summary": "Workspace Agent Connection Watch",
|
||||
"operationId": "workspace-agent-connection-watch",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Workspace ID",
|
||||
"name": "workspace",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"101": {
|
||||
"description": "Switching Protocols",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/workspacesdk.ConnectionWatchEvent"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/api/v2/workspaces/{workspace}/autostart": {
|
||||
"put": {
|
||||
"consumes": [
|
||||
@@ -27964,6 +27999,93 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"workspacesdk.AgentUpdate": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"lifecycle": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceAgentLifecycle"
|
||||
}
|
||||
}
|
||||
},
|
||||
"workspacesdk.BuildUpdate": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_status": {
|
||||
"$ref": "#/definitions/codersdk.ProvisionerJobStatus"
|
||||
},
|
||||
"transition": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceTransition"
|
||||
}
|
||||
}
|
||||
},
|
||||
"workspacesdk.ConnectionWatchEvent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_update": {
|
||||
"$ref": "#/definitions/workspacesdk.AgentUpdate"
|
||||
},
|
||||
"build_update": {
|
||||
"$ref": "#/definitions/workspacesdk.BuildUpdate"
|
||||
},
|
||||
"error": {
|
||||
"$ref": "#/definitions/workspacesdk.WatchError"
|
||||
}
|
||||
}
|
||||
},
|
||||
"workspacesdk.WatchError": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"$ref": "#/definitions/workspacesdk.WatchErrorCode"
|
||||
},
|
||||
"details": {
|
||||
"type": "string"
|
||||
},
|
||||
"message": {
|
||||
"type": "string"
|
||||
},
|
||||
"retryable": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"workspacesdk.WatchErrorCode": {
|
||||
"type": "integer",
|
||||
"enum": [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"_": "Ensure that zero value is not a valid code"
|
||||
},
|
||||
"x-enum-descriptions": [
|
||||
"Ensure that zero value is not a valid code",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
""
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"_",
|
||||
"WatchErrorTooManyAgents",
|
||||
"WatchErrorNameNotFound",
|
||||
"WatchErrorNoAgents",
|
||||
"WatchErrorServerShutdown",
|
||||
"WatchErrorDatabase",
|
||||
"WatchErrorInternal"
|
||||
]
|
||||
},
|
||||
"wsproxysdk.CryptoKeysResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
Generated
+110
@@ -11343,6 +11343,37 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/api/v2/workspaces/{workspace}/agent-connection-watch": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Workspaces"],
|
||||
"summary": "Workspace Agent Connection Watch",
|
||||
"operationId": "workspace-agent-connection-watch",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Workspace ID",
|
||||
"name": "workspace",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"101": {
|
||||
"description": "Switching Protocols",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/workspacesdk.ConnectionWatchEvent"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/api/v2/workspaces/{workspace}/autostart": {
|
||||
"put": {
|
||||
"consumes": ["application/json"],
|
||||
@@ -25823,6 +25854,85 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"workspacesdk.AgentUpdate": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"lifecycle": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceAgentLifecycle"
|
||||
}
|
||||
}
|
||||
},
|
||||
"workspacesdk.BuildUpdate": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_status": {
|
||||
"$ref": "#/definitions/codersdk.ProvisionerJobStatus"
|
||||
},
|
||||
"transition": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceTransition"
|
||||
}
|
||||
}
|
||||
},
|
||||
"workspacesdk.ConnectionWatchEvent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_update": {
|
||||
"$ref": "#/definitions/workspacesdk.AgentUpdate"
|
||||
},
|
||||
"build_update": {
|
||||
"$ref": "#/definitions/workspacesdk.BuildUpdate"
|
||||
},
|
||||
"error": {
|
||||
"$ref": "#/definitions/workspacesdk.WatchError"
|
||||
}
|
||||
}
|
||||
},
|
||||
"workspacesdk.WatchError": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"$ref": "#/definitions/workspacesdk.WatchErrorCode"
|
||||
},
|
||||
"details": {
|
||||
"type": "string"
|
||||
},
|
||||
"message": {
|
||||
"type": "string"
|
||||
},
|
||||
"retryable": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"workspacesdk.WatchErrorCode": {
|
||||
"type": "integer",
|
||||
"enum": [0, 1, 2, 3, 4, 5, 6],
|
||||
"x-enum-comments": {
|
||||
"_": "Ensure that zero value is not a valid code"
|
||||
},
|
||||
"x-enum-descriptions": [
|
||||
"Ensure that zero value is not a valid code",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
""
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"_",
|
||||
"WatchErrorTooManyAgents",
|
||||
"WatchErrorNameNotFound",
|
||||
"WatchErrorNoAgents",
|
||||
"WatchErrorServerShutdown",
|
||||
"WatchErrorDatabase",
|
||||
"WatchErrorInternal"
|
||||
]
|
||||
},
|
||||
"wsproxysdk.CryptoKeysResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -92,6 +92,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
|
||||
"github.com/coder/coder/v2/coderd/workspaceconnwatcher"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/wsbuilder"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
@@ -923,6 +924,8 @@ func New(options *Options) *API {
|
||||
APIKeyEncryptionKeycache: options.AppEncryptionKeyCache,
|
||||
})
|
||||
|
||||
api.workspaceAgentConnWatcher = workspaceconnwatcher.New(api.ctx, options.Logger, options.Pubsub, options.Database)
|
||||
|
||||
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
ActivateDormantUser: ActivateDormantUser(options.Logger, &api.Auditor, options.Database),
|
||||
@@ -1820,6 +1823,7 @@ func New(options *Options) *API {
|
||||
r.Patch("/", api.patchWorkspaceACL)
|
||||
r.Delete("/", api.deleteWorkspaceACL)
|
||||
})
|
||||
r.Get("/agent-connection-watch", api.workspaceAgentConnWatcher.WorkspaceAgentConnectionWatch)
|
||||
})
|
||||
})
|
||||
r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) {
|
||||
@@ -2238,6 +2242,8 @@ type API struct {
|
||||
// profile collection (via /debug/profile) can run at a time. The CPU
|
||||
// profiler is process-global, so concurrent collections would fail.
|
||||
ProfileCollecting atomic.Bool
|
||||
|
||||
workspaceAgentConnWatcher *workspaceconnwatcher.Watcher
|
||||
}
|
||||
|
||||
// Close waits for all WebSocket connections to drain before returning.
|
||||
@@ -2301,6 +2307,7 @@ func (api *API) Close() error {
|
||||
_ = api.AppSigningKeyCache.Close()
|
||||
_ = api.AppEncryptionKeyCache.Close()
|
||||
_ = api.UpdatesProvider.Close()
|
||||
api.workspaceAgentConnWatcher.Close()
|
||||
|
||||
if current := api.PrebuildsReconciler.Load(); current != nil {
|
||||
ctx, giveUp := context.WithTimeoutCause(context.Background(), time.Second*30, xerrors.New("gave up waiting for reconciler to stop before shutdown"))
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
package coderdtest
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
)
|
||||
|
||||
func MockedDatabaseWithAuthz(t testing.TB, logger slog.Logger) (*gomock.Controller, *dbmock.MockStore, database.Store, rbac.Authorizer) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||
accessControlStore.Store(&acs)
|
||||
// dbauthz will call Wrappers() to check for wrapped databases
|
||||
mDB.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||
authDB := dbauthz.New(mDB, auth, logger, accessControlStore)
|
||||
return ctrl, mDB, authDB, auth
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package coderdtest
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/rolestore"
|
||||
)
|
||||
|
||||
func MemberSubject(userID, orgID uuid.UUID) rbac.Subject {
|
||||
memberRole, err := rbac.RoleByName(rbac.RoleMember())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
orgMember, err := rolestore.TestingGetSystemRole(
|
||||
rbac.RoleOrgMember(),
|
||||
orgID,
|
||||
rbac.OrgSettings{ShareableWorkspaceOwners: rbac.ShareableWorkspaceOwnersNone},
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return rbac.Subject{
|
||||
FriendlyName: "coderdtest-member",
|
||||
Email: "member@coderd.test",
|
||||
Type: rbac.SubjectTypeUser,
|
||||
ID: userID.String(),
|
||||
Roles: rbac.Roles{memberRole, orgMember},
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
}
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
func TestEndpointsDocumented(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
swaggerComments, err := coderdtest.ParseSwaggerComments("..")
|
||||
swaggerComments, err := coderdtest.ParseSwaggerComments("..", "../workspaceconnwatcher")
|
||||
require.NoError(t, err, "can't parse swagger comments")
|
||||
require.NotEmpty(t, swaggerComments, "swagger comments must be present")
|
||||
|
||||
|
||||
@@ -54,3 +54,7 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func WithWorkspaceParam(ctx context.Context, workspace database.Workspace) context.Context {
|
||||
return context.WithValue(ctx, workspaceParamContextKey{}, workspace)
|
||||
}
|
||||
|
||||
@@ -170,6 +170,25 @@ var systemRoles = map[string]permissionsFunc{
|
||||
rbac.RoleOrgServiceAccount(): rbac.OrgServiceAccountPermissions,
|
||||
}
|
||||
|
||||
func TestingGetSystemRole(name string, orgID uuid.UUID, settings rbac.OrgSettings) (rbac.Role, error) {
|
||||
f, ok := systemRoles[name]
|
||||
if !ok {
|
||||
return rbac.Role{}, xerrors.Errorf("role %q not found", name)
|
||||
}
|
||||
perms := f(settings)
|
||||
return rbac.Role{
|
||||
Identifier: rbac.RoleIdentifier{Name: name, OrganizationID: orgID},
|
||||
DisplayName: "",
|
||||
Site: nil,
|
||||
ByOrgID: map[string]rbac.OrgPermissions{
|
||||
orgID.String(): {
|
||||
Org: perms.Org,
|
||||
Member: perms.Member,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// permissionsFunc produces the desired permissions for a system role
|
||||
// given organization settings.
|
||||
type permissionsFunc func(rbac.OrgSettings) rbac.OrgRolePermissions
|
||||
|
||||
@@ -0,0 +1,333 @@
|
||||
package workspaceconnwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
type Watcher struct {
|
||||
logger slog.Logger
|
||||
sub pubsub.Subscriber
|
||||
db database.Store
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
mu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
closed bool
|
||||
}
|
||||
|
||||
type event struct {
|
||||
sync bool
|
||||
wsEvent *wspubsub.WorkspaceEvent
|
||||
}
|
||||
|
||||
func New(ctx context.Context, logger slog.Logger, sub pubsub.Subscriber, db database.Store) *Watcher {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
w := &Watcher{
|
||||
logger: logger.Named("wsconnwatcher"),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
sub: sub,
|
||||
db: db,
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
w.Close()
|
||||
}()
|
||||
return w
|
||||
}
|
||||
|
||||
// @Summary Workspace Agent Connection Watch
|
||||
// @ID workspace-agent-connection-watch
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Workspaces
|
||||
// @Param workspace path string true "Workspace ID" format(uuid)
|
||||
// @Success 101 {object} workspacesdk.ConnectionWatchEvent
|
||||
// @Router /api/v2/workspaces/{workspace}/agent-connection-watch [get]
|
||||
func (w *Watcher) WorkspaceAgentConnectionWatch(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspace := httpmw.WorkspaceParam(r)
|
||||
agentName := r.URL.Query().Get("agent_name")
|
||||
|
||||
filteredEvents := make(chan event, 1)
|
||||
filteredEvents <- event{sync: true} // init sync
|
||||
cancelWorkspaceSubscribe, err := w.sub.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspace.OwnerID),
|
||||
wspubsub.HandleWorkspaceEvent(
|
||||
func(ctx context.Context, payload wspubsub.WorkspaceEvent, err error) {
|
||||
if err != nil {
|
||||
// subscription error, resync
|
||||
select {
|
||||
case filteredEvents <- event{sync: true}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
if payload.WorkspaceID != workspace.ID {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case filteredEvents <- event{wsEvent: &payload}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}))
|
||||
if err != nil {
|
||||
w.logger.Error(ctx, "failed to subscribe to workspace events",
|
||||
slog.Error(err), slog.F("owner_id", workspace.OwnerID))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error setting up workspace event subscription",
|
||||
// Don't include the error in case it leaks infra details about the pubsub
|
||||
})
|
||||
return
|
||||
}
|
||||
defer cancelWorkspaceSubscribe()
|
||||
|
||||
closed := false
|
||||
w.mu.Lock()
|
||||
closed = w.closed
|
||||
if !closed {
|
||||
w.wg.Add(1)
|
||||
}
|
||||
w.mu.Unlock()
|
||||
if closed {
|
||||
w.logger.Debug(ctx, "server is closed, writing error")
|
||||
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
|
||||
Message: "Server instance is shutting down",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer w.wg.Done()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to accept WebSocket.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// CloseRead starts a goroutine to read and discard messages from the client,
|
||||
// including Pong messages sent in response to our Ping heartbeats.
|
||||
_ = conn.CloseRead(ctx)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go httpapi.HeartbeatClose(ctx, w.logger, cancel, conn)
|
||||
defer cancel()
|
||||
|
||||
u := &updater{
|
||||
db: w.db,
|
||||
watcherCtx: w.ctx,
|
||||
connCtx: ctx,
|
||||
conn: conn,
|
||||
workspaceID: workspace.ID,
|
||||
events: filteredEvents,
|
||||
agentName: agentName,
|
||||
}
|
||||
u.run()
|
||||
}
|
||||
|
||||
func (w *Watcher) Close() {
|
||||
w.mu.Lock()
|
||||
w.closed = true
|
||||
w.mu.Unlock()
|
||||
|
||||
w.cancel()
|
||||
w.wg.Wait()
|
||||
}
|
||||
|
||||
type updater struct {
|
||||
db database.Store
|
||||
watcherCtx context.Context
|
||||
connCtx context.Context
|
||||
conn *websocket.Conn
|
||||
enc *wsjson.Encoder[workspacesdk.ConnectionWatchEvent]
|
||||
workspaceID uuid.UUID
|
||||
events <-chan event
|
||||
agentName string
|
||||
|
||||
lastBuild database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow
|
||||
}
|
||||
|
||||
func (u *updater) run() {
|
||||
u.enc = wsjson.NewEncoder[workspacesdk.ConnectionWatchEvent](u.conn, websocket.MessageText)
|
||||
defer func() {
|
||||
// this is a no-op if we have already closed for some other reason.
|
||||
_ = u.enc.Close(websocket.StatusNormalClosure)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-u.watcherCtx.Done():
|
||||
u.errorThenClose(workspacesdk.WatchError{
|
||||
Code: workspacesdk.WatchErrorServerShutdown,
|
||||
Retryable: true,
|
||||
Message: "server is shutting down",
|
||||
})
|
||||
return
|
||||
case <-u.connCtx.Done():
|
||||
return
|
||||
case e := <-u.events:
|
||||
if e.sync {
|
||||
// zero this out so we'll send a full update
|
||||
u.lastBuild = database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{}
|
||||
if !u.buildUpdate() {
|
||||
return
|
||||
}
|
||||
}
|
||||
if e.wsEvent != nil {
|
||||
switch e.wsEvent.Kind {
|
||||
case wspubsub.WorkspaceEventKindStateChange:
|
||||
if !u.buildUpdate() {
|
||||
return
|
||||
}
|
||||
case wspubsub.WorkspaceEventKindAgentLifecycleUpdate:
|
||||
if !u.maybeSendAgentUpdate() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *updater) buildUpdate() bool {
|
||||
build, err := u.db.GetLatestWorkspaceBuildWithStatusByWorkspaceID(u.connCtx, u.workspaceID)
|
||||
if err != nil {
|
||||
retryable := true
|
||||
details := err.Error()
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// There is no build (unlikely), or the workspace was deleted. In both cases, retrying won't help.
|
||||
retryable = false
|
||||
}
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
retryable = false
|
||||
details = "unauthorized" // security: don't leak internal authz details
|
||||
}
|
||||
u.errorThenClose(workspacesdk.WatchError{
|
||||
Code: workspacesdk.WatchErrorDatabase,
|
||||
Retryable: retryable,
|
||||
Message: "failed to fetch latest workspace build",
|
||||
Details: details,
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
if build.BuildNumber != u.lastBuild.BuildNumber ||
|
||||
build.JobStatus != u.lastBuild.JobStatus ||
|
||||
build.Transition != u.lastBuild.Transition {
|
||||
u.lastBuild = build
|
||||
err = u.enc.Encode(workspacesdk.ConnectionWatchEvent{BuildUpdate: &workspacesdk.BuildUpdate{
|
||||
Transition: codersdk.WorkspaceTransition(build.Transition),
|
||||
JobStatus: codersdk.ProvisionerJobStatus(build.JobStatus),
|
||||
}})
|
||||
if err != nil {
|
||||
// probably this is just that the connection is closed, but in case there is some actual JSON serialization
|
||||
// error, send a close frame.
|
||||
_ = u.conn.Close(websocket.StatusInternalError, "failed to encode build update")
|
||||
return false
|
||||
}
|
||||
return u.maybeSendAgentUpdate()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *updater) maybeSendAgentUpdate() (ok bool) {
|
||||
if u.lastBuild.Transition != database.WorkspaceTransitionStart ||
|
||||
u.lastBuild.JobStatus != database.ProvisionerJobStatusSucceeded {
|
||||
// only send agent updates for successfully started workspaces
|
||||
return true
|
||||
}
|
||||
|
||||
agents, err := u.db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(u.connCtx,
|
||||
database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
|
||||
WorkspaceID: u.workspaceID,
|
||||
BuildNumber: u.lastBuild.BuildNumber,
|
||||
})
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
details := err.Error()
|
||||
retryable := true
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
retryable = false
|
||||
details = "unauthorized"
|
||||
}
|
||||
u.errorThenClose(workspacesdk.WatchError{
|
||||
Code: workspacesdk.WatchErrorDatabase,
|
||||
Retryable: retryable,
|
||||
Message: "failed to fetch workspace agents",
|
||||
Details: details,
|
||||
})
|
||||
return false
|
||||
}
|
||||
if len(agents) == 0 {
|
||||
u.errorThenClose(workspacesdk.WatchError{
|
||||
Code: workspacesdk.WatchErrorNoAgents,
|
||||
Retryable: false,
|
||||
Message: "no agents found for workspace",
|
||||
})
|
||||
return false
|
||||
}
|
||||
if len(agents) > 1 && u.agentName == "" {
|
||||
u.errorThenClose(workspacesdk.WatchError{
|
||||
Code: workspacesdk.WatchErrorTooManyAgents,
|
||||
Retryable: false,
|
||||
Message: "more than one agent on workspace and target not specified",
|
||||
})
|
||||
return false
|
||||
}
|
||||
var agent database.WorkspaceAgent
|
||||
if u.agentName == "" {
|
||||
agent = agents[0]
|
||||
} else {
|
||||
for _, a := range agents {
|
||||
if a.Name == u.agentName {
|
||||
agent = a
|
||||
break
|
||||
}
|
||||
}
|
||||
if agent.ID == uuid.Nil {
|
||||
u.errorThenClose(workspacesdk.WatchError{
|
||||
Code: workspacesdk.WatchErrorNameNotFound,
|
||||
Retryable: false,
|
||||
Message: "target agent not found by name",
|
||||
})
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
err = u.enc.Encode(workspacesdk.ConnectionWatchEvent{AgentUpdate: &workspacesdk.AgentUpdate{
|
||||
Lifecycle: codersdk.WorkspaceAgentLifecycle(agent.LifecycleState),
|
||||
ID: agent.ID,
|
||||
}})
|
||||
if err != nil {
|
||||
// probably this is just that the connection is closed, but in case there is some actual JSON serialization
|
||||
// error, send a close frame.
|
||||
_ = u.conn.Close(websocket.StatusInternalError, "failed to encode agent update")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *updater) errorThenClose(err workspacesdk.WatchError) {
|
||||
_ = u.enc.Encode(workspacesdk.ConnectionWatchEvent{Error: &err})
|
||||
// ignore encoding errors above because in any case, we are going to close the connection.
|
||||
_ = u.conn.Close(websocket.StatusNormalClosure, "error")
|
||||
}
|
||||
@@ -0,0 +1,474 @@
|
||||
package workspaceconnwatcher_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/workspaceconnwatcher"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
workspaceID = uuid.UUID{1}
|
||||
userID = uuid.UUID{2}
|
||||
orgID = uuid.UUID{3}
|
||||
agentID = uuid.UUID{4}
|
||||
)
|
||||
|
||||
type harness struct {
|
||||
db *dbmock.MockStore
|
||||
watcher *workspaceconnwatcher.Watcher
|
||||
pub pubsub.Publisher
|
||||
logger slog.Logger
|
||||
|
||||
// Initialized, but overridable before Dial()
|
||||
workspace database.Workspace
|
||||
userID, orgID uuid.UUID
|
||||
}
|
||||
|
||||
func newHarness(ctx context.Context, t *testing.T, logger slog.Logger) *harness {
|
||||
h := &harness{
|
||||
workspace: database.Workspace{
|
||||
ID: workspaceID,
|
||||
OrganizationID: orgID,
|
||||
OwnerID: userID,
|
||||
},
|
||||
orgID: orgID,
|
||||
userID: userID,
|
||||
logger: logger,
|
||||
}
|
||||
ps := pubsub.NewInMemory()
|
||||
h.pub = ps
|
||||
|
||||
var authzDB database.Store
|
||||
_, h.db, authzDB, _ = coderdtest.MockedDatabaseWithAuthz(t, logger)
|
||||
h.watcher = workspaceconnwatcher.New(ctx, logger.Named("watcher"), ps, authzDB)
|
||||
t.Cleanup(h.watcher.Close)
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *harness) Dial(ctx context.Context, url string) (*wsjson.Decoder[workspacesdk.ConnectionWatchEvent], error) {
|
||||
rt := testutil.InMemWebsocketRoundTripper{
|
||||
Handler: http.HandlerFunc(h.watcher.WorkspaceAgentConnectionWatch),
|
||||
CtxMutator: func(ctx context.Context) context.Context {
|
||||
ctx = httpmw.WithWorkspaceParam(ctx, h.workspace)
|
||||
ctx = dbauthz.As(ctx, coderdtest.MemberSubject(userID, orgID))
|
||||
return ctx
|
||||
},
|
||||
Logger: h.logger.Named("roundtripper"),
|
||||
}
|
||||
// nolint: bodyclose
|
||||
clientSock, resp, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
HTTPClient: &http.Client{Transport: rt},
|
||||
})
|
||||
if err != nil {
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return nil, codersdk.ReadBodyAsError(resp)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dec := wsjson.NewDecoder[workspacesdk.ConnectionWatchEvent](
|
||||
clientSock, websocket.MessageText, h.logger.Named("decoder"))
|
||||
return dec, nil
|
||||
}
|
||||
|
||||
func TestWatcher_Agents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
agents []database.WorkspaceAgent
|
||||
agentDBError error
|
||||
url string
|
||||
expectedAgentUpdate *workspacesdk.AgentUpdate
|
||||
expectedErrorCode workspacesdk.WatchErrorCode
|
||||
expectedErrorRetryable bool
|
||||
}{
|
||||
{
|
||||
name: "noNameSingleAgent",
|
||||
agents: []database.WorkspaceAgent{
|
||||
{
|
||||
Name: "test",
|
||||
ID: agentID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
|
||||
},
|
||||
},
|
||||
url: "wss://local.test/",
|
||||
expectedAgentUpdate: &workspacesdk.AgentUpdate{
|
||||
Lifecycle: codersdk.WorkspaceAgentLifecycleCreated,
|
||||
ID: agentID,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "noNameMultiAgent",
|
||||
agents: []database.WorkspaceAgent{
|
||||
{
|
||||
Name: "agent0",
|
||||
ID: agentID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
|
||||
},
|
||||
{
|
||||
Name: "agent1",
|
||||
ID: uuid.UUID{77},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
|
||||
},
|
||||
},
|
||||
url: "wss://local.test/",
|
||||
expectedErrorCode: workspacesdk.WatchErrorTooManyAgents,
|
||||
expectedErrorRetryable: false,
|
||||
},
|
||||
{
|
||||
name: "namedAgentMultiAgent",
|
||||
agents: []database.WorkspaceAgent{
|
||||
{
|
||||
Name: "agent0",
|
||||
ID: agentID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
|
||||
},
|
||||
{
|
||||
Name: "agent1",
|
||||
ID: uuid.UUID{77},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
},
|
||||
},
|
||||
url: "wss://local.test/?agent_name=agent0",
|
||||
expectedAgentUpdate: &workspacesdk.AgentUpdate{
|
||||
Lifecycle: codersdk.WorkspaceAgentLifecycleCreated,
|
||||
ID: agentID,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "namedAgentNonexistent",
|
||||
agents: []database.WorkspaceAgent{
|
||||
{
|
||||
Name: "agent0",
|
||||
ID: agentID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
|
||||
},
|
||||
{
|
||||
Name: "agent1",
|
||||
ID: uuid.UUID{77},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
|
||||
},
|
||||
},
|
||||
url: "wss://local.test/?agent_name=agent2",
|
||||
expectedErrorCode: workspacesdk.WatchErrorNameNotFound,
|
||||
expectedErrorRetryable: false,
|
||||
},
|
||||
{
|
||||
name: "dbError",
|
||||
agentDBError: xerrors.New("a bad thing happened"),
|
||||
url: "wss://local.test/",
|
||||
expectedErrorCode: workspacesdk.WatchErrorDatabase,
|
||||
expectedErrorRetryable: true,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
agentDBError: dbauthz.NotAuthorizedError{Err: xerrors.New("not allowed")},
|
||||
url: "wss://local.test/",
|
||||
expectedErrorCode: workspacesdk.WatchErrorDatabase,
|
||||
expectedErrorRetryable: false,
|
||||
},
|
||||
{
|
||||
name: "noAgents",
|
||||
agents: []database.WorkspaceAgent{},
|
||||
url: "wss://local.test/",
|
||||
expectedErrorCode: workspacesdk.WatchErrorNoAgents,
|
||||
expectedErrorRetryable: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := testutil.Logger(t)
|
||||
h := newHarness(ctx, t, logger)
|
||||
|
||||
h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID).
|
||||
Times(1).
|
||||
Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
BuildNumber: 1,
|
||||
JobStatus: database.ProvisionerJobStatusSucceeded,
|
||||
WorkspaceTable: database.WorkspaceTable{
|
||||
ID: h.workspace.ID,
|
||||
OwnerID: userID,
|
||||
OrganizationID: orgID,
|
||||
},
|
||||
}, nil)
|
||||
// RBAC check for agent query
|
||||
h.db.EXPECT().GetWorkspaceByID(gomock.Any(), h.workspace.ID).
|
||||
Times(1).
|
||||
Return(h.workspace, nil)
|
||||
h.db.EXPECT().GetWorkspaceAgentsByWorkspaceAndBuildNumber(
|
||||
gomock.Any(),
|
||||
database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
|
||||
WorkspaceID: h.workspace.ID,
|
||||
BuildNumber: 1,
|
||||
}).
|
||||
Times(1).
|
||||
Return(tc.agents, tc.agentDBError)
|
||||
|
||||
dec, err := h.Dial(ctx, tc.url)
|
||||
require.NoError(t, err)
|
||||
defer dec.Close()
|
||||
events := dec.Chan()
|
||||
e0 := testutil.RequireReceive(ctx, t, events)
|
||||
require.Equal(t, workspacesdk.ConnectionWatchEvent{
|
||||
BuildUpdate: &workspacesdk.BuildUpdate{
|
||||
Transition: codersdk.WorkspaceTransitionStart,
|
||||
JobStatus: codersdk.ProvisionerJobSucceeded,
|
||||
},
|
||||
}, e0)
|
||||
|
||||
e1 := testutil.RequireReceive(ctx, t, events)
|
||||
if tc.expectedAgentUpdate != nil {
|
||||
require.Equal(t, workspacesdk.ConnectionWatchEvent{AgentUpdate: tc.expectedAgentUpdate}, e1)
|
||||
} else {
|
||||
require.NotNil(t, e1.Error)
|
||||
require.Equal(t, tc.expectedErrorRetryable, e1.Error.Retryable)
|
||||
require.Equal(t, tc.expectedErrorCode, e1.Error.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcher_LostAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := testutil.Logger(t)
|
||||
h := newHarness(ctx, t, logger)
|
||||
|
||||
h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID).
|
||||
Times(1).
|
||||
Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
BuildNumber: 1,
|
||||
JobStatus: database.ProvisionerJobStatusSucceeded,
|
||||
WorkspaceTable: database.WorkspaceTable{
|
||||
ID: h.workspace.ID,
|
||||
OwnerID: uuid.UUID{99}, // workspace gets a new owner, e.g.
|
||||
OrganizationID: orgID,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
dec, err := h.Dial(ctx, "wss://local.test/")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := dec.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
events := dec.Chan()
|
||||
e0 := testutil.RequireReceive(ctx, t, events)
|
||||
require.NotNil(t, e0.Error)
|
||||
require.Equal(t, workspacesdk.WatchErrorDatabase, e0.Error.Code)
|
||||
require.False(t, e0.Error.Retryable)
|
||||
require.Equal(t, "unauthorized", e0.Error.Details, "should not leak internal auth details")
|
||||
}
|
||||
|
||||
func TestWatcher_PublishChanges(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := testutil.Logger(t)
|
||||
h := newHarness(ctx, t, logger)
|
||||
|
||||
// Initial build update, job is running.
|
||||
build0 := h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID).
|
||||
Times(1).
|
||||
Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
BuildNumber: 1,
|
||||
JobStatus: database.ProvisionerJobStatusRunning,
|
||||
WorkspaceTable: database.WorkspaceTable{
|
||||
ID: h.workspace.ID,
|
||||
OwnerID: userID,
|
||||
OrganizationID: orgID,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
dec, err := h.Dial(ctx, "wss://local.test/")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := dec.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
events := dec.Chan()
|
||||
|
||||
e0 := testutil.RequireReceive(ctx, t, events)
|
||||
require.Equal(t, workspacesdk.ConnectionWatchEvent{
|
||||
BuildUpdate: &workspacesdk.BuildUpdate{
|
||||
Transition: codersdk.WorkspaceTransitionStart,
|
||||
JobStatus: codersdk.ProvisionerJobRunning,
|
||||
},
|
||||
}, e0)
|
||||
|
||||
// Since job is still running, we don't immediately query for agents. Next we set up the db queries and send in an
|
||||
// update over the pubsub to kick a new query.
|
||||
build1 := h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID).
|
||||
After(build0).
|
||||
Times(1).
|
||||
Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
BuildNumber: 1,
|
||||
JobStatus: database.ProvisionerJobStatusSucceeded,
|
||||
WorkspaceTable: database.WorkspaceTable{
|
||||
ID: h.workspace.ID,
|
||||
OwnerID: userID,
|
||||
OrganizationID: orgID,
|
||||
},
|
||||
}, nil)
|
||||
// RBAC check for agent query
|
||||
h.db.EXPECT().GetWorkspaceByID(gomock.Any(), h.workspace.ID).
|
||||
After(build1).
|
||||
Times(2). // these queries are identical between the initial and the update below
|
||||
Return(h.workspace, nil)
|
||||
agent0 := h.db.EXPECT().GetWorkspaceAgentsByWorkspaceAndBuildNumber(
|
||||
gomock.Any(),
|
||||
database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
|
||||
WorkspaceID: h.workspace.ID,
|
||||
BuildNumber: 1,
|
||||
}).
|
||||
After(build1).
|
||||
Times(1).
|
||||
Return([]database.WorkspaceAgent{
|
||||
{
|
||||
Name: "test",
|
||||
ID: agentID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
|
||||
},
|
||||
}, nil)
|
||||
changeMsg := wspubsub.WorkspaceEvent{
|
||||
Kind: wspubsub.WorkspaceEventKindStateChange,
|
||||
WorkspaceID: h.workspace.ID,
|
||||
}
|
||||
changeBytes, err := json.Marshal(changeMsg)
|
||||
require.NoError(t, err)
|
||||
err = h.pub.Publish(wspubsub.WorkspaceEventChannel(h.workspace.OwnerID), changeBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
e1 := testutil.RequireReceive(ctx, t, events)
|
||||
require.Equal(t, workspacesdk.ConnectionWatchEvent{
|
||||
BuildUpdate: &workspacesdk.BuildUpdate{
|
||||
Transition: codersdk.WorkspaceTransitionStart,
|
||||
JobStatus: codersdk.ProvisionerJobSucceeded,
|
||||
},
|
||||
}, e1)
|
||||
e2 := testutil.RequireReceive(ctx, t, events)
|
||||
require.Equal(t, workspacesdk.ConnectionWatchEvent{AgentUpdate: &workspacesdk.AgentUpdate{
|
||||
ID: agentID,
|
||||
Lifecycle: codersdk.WorkspaceAgentLifecycleCreated,
|
||||
}}, e2)
|
||||
|
||||
// Finally, send in a change event for the agent. But first, program the mock for the expected query.
|
||||
h.db.EXPECT().GetWorkspaceAgentsByWorkspaceAndBuildNumber(
|
||||
gomock.Any(),
|
||||
database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
|
||||
WorkspaceID: h.workspace.ID,
|
||||
BuildNumber: 1,
|
||||
}).
|
||||
After(agent0).
|
||||
Times(1).
|
||||
Return([]database.WorkspaceAgent{
|
||||
{
|
||||
Name: "test",
|
||||
ID: agentID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
},
|
||||
}, nil)
|
||||
changeMsg = wspubsub.WorkspaceEvent{
|
||||
Kind: wspubsub.WorkspaceEventKindAgentLifecycleUpdate,
|
||||
WorkspaceID: h.workspace.ID,
|
||||
AgentID: &agentID,
|
||||
}
|
||||
changeBytes, err = json.Marshal(changeMsg)
|
||||
require.NoError(t, err)
|
||||
err = h.pub.Publish(wspubsub.WorkspaceEventChannel(h.workspace.OwnerID), changeBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
e3 := testutil.RequireReceive(ctx, t, events)
|
||||
require.Equal(t, workspacesdk.ConnectionWatchEvent{AgentUpdate: &workspacesdk.AgentUpdate{
|
||||
ID: agentID,
|
||||
Lifecycle: codersdk.WorkspaceAgentLifecycleReady,
|
||||
}}, e3)
|
||||
}
|
||||
|
||||
func TestWatcher_ClosedBeforeDial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := testutil.Logger(t)
|
||||
h := newHarness(ctx, t, logger)
|
||||
h.watcher.Close()
|
||||
_, err := h.Dial(ctx, "wss://local.test/")
|
||||
var sdkError *codersdk.Error
|
||||
require.True(t, errors.As(err, &sdkError))
|
||||
require.Equal(t, http.StatusServiceUnavailable, sdkError.StatusCode())
|
||||
}
|
||||
|
||||
func TestWatcher_ClosedAfterDial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := testutil.Logger(t)
|
||||
h := newHarness(ctx, t, logger)
|
||||
|
||||
h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID).
|
||||
Times(1).
|
||||
Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
BuildNumber: 1,
|
||||
JobStatus: database.ProvisionerJobStatusSucceeded,
|
||||
WorkspaceTable: database.WorkspaceTable{
|
||||
ID: h.workspace.ID,
|
||||
OwnerID: userID,
|
||||
OrganizationID: orgID,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
dec, err := h.Dial(ctx, "wss://local.test/")
|
||||
require.NoError(t, err)
|
||||
events := dec.Chan()
|
||||
_ = testutil.RequireReceive(ctx, t, events)
|
||||
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
defer close(closed)
|
||||
h.watcher.Close()
|
||||
}()
|
||||
|
||||
e := testutil.RequireReceive(ctx, t, events)
|
||||
require.NotNil(t, e.Error)
|
||||
require.Equal(t, workspacesdk.WatchErrorServerShutdown, e.Error.Code)
|
||||
require.True(t, e.Error.Retryable)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("context timed out")
|
||||
case _, ok := <-events:
|
||||
require.False(t, ok, "socket not closed")
|
||||
}
|
||||
testutil.TryReceive(ctx, t, closed)
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
package tunneler_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/tunneler"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// TestTunneler_Integration is an integration test using coderdtest. It should be removed when we integrate the Tunneler
|
||||
// into coder ssh and those integration test cover this functionality.
|
||||
func TestTunneler_Integration(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, store := coderdtest.NewWithDatabase(t, nil)
|
||||
logger := testutil.Logger(t)
|
||||
client.SetLogger(logger.Named("client"))
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
userClient, user := coderdtest.CreateAnotherUserMutators(t, client, first.OrganizationID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
|
||||
r.Username = "myuser"
|
||||
})
|
||||
userClient.SetLogger(logger.Named("userclient"))
|
||||
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
Name: "myworkspace",
|
||||
OrganizationID: first.OrganizationID,
|
||||
OwnerID: user.ID,
|
||||
}).WithAgent().Do()
|
||||
wsSDKClient := workspacesdk.New(userClient)
|
||||
logs := &bytes.Buffer{}
|
||||
|
||||
app := &sshApplication{
|
||||
t: t,
|
||||
ctx: ctx,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
tun := tunneler.NewTunneler(wsSDKClient, tunneler.Config{
|
||||
WorkspaceID: r.Workspace.ID,
|
||||
App: app,
|
||||
WorkspaceStarter: nil,
|
||||
AgentName: "",
|
||||
LogWriter: logs,
|
||||
DebugLogger: logger.Named("tunneler"),
|
||||
})
|
||||
|
||||
testAgent := agenttest.New(t, client.URL, r.AgentToken)
|
||||
defer testAgent.Close()
|
||||
|
||||
testutil.TryReceive(ctx, t, app.done)
|
||||
require.Equal(t, app.result, "foo\n")
|
||||
|
||||
err := tun.GracefulShutdown(ctx)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
type sshApplication struct {
|
||||
t *testing.T
|
||||
ctx context.Context
|
||||
client *ssh.Client
|
||||
done chan struct{}
|
||||
result string
|
||||
}
|
||||
|
||||
func (s *sshApplication) Close() error {
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
func (s *sshApplication) Start(conn workspacesdk.AgentConn) error {
|
||||
var err error
|
||||
s.client, err = conn.SSHClient(s.ctx)
|
||||
if err != nil {
|
||||
s.t.Error(err)
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
defer close(s.done)
|
||||
sess, err := s.client.NewSession()
|
||||
if err != nil {
|
||||
s.t.Error("failed to create session", err)
|
||||
}
|
||||
defer sess.Close()
|
||||
out, err := sess.Output("echo foo")
|
||||
if err != nil {
|
||||
s.t.Error("failed to echo", err)
|
||||
}
|
||||
s.result = string(out)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
)
|
||||
|
||||
type state int
|
||||
@@ -33,6 +34,11 @@ type WorkspaceStarter interface {
|
||||
|
||||
type Client interface {
|
||||
DialAgent(dialCtx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (workspacesdk.AgentConn, error)
|
||||
WorkspaceAgentConnectionWatch(
|
||||
dialCtx context.Context, workspaceID uuid.UUID, agentName string,
|
||||
) (
|
||||
dec *wsjson.Decoder[workspacesdk.ConnectionWatchEvent], err error,
|
||||
)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -127,9 +133,9 @@ type Config struct {
|
||||
// ordering.
|
||||
type tunnelerEvent struct {
|
||||
shutdownSignal *shutdownSignal
|
||||
buildUpdate *buildUpdate
|
||||
buildUpdate *workspacesdk.BuildUpdate
|
||||
provisionerJobLog *codersdk.ProvisionerJobLog
|
||||
agentUpdate *agentUpdate
|
||||
agentUpdate *workspacesdk.AgentUpdate
|
||||
agentLog *codersdk.WorkspaceAgentLog
|
||||
appUpdate *networkedApplicationUpdate
|
||||
tailnetUpdate *tailnetUpdate
|
||||
@@ -137,16 +143,6 @@ type tunnelerEvent struct {
|
||||
|
||||
type shutdownSignal struct{}
|
||||
|
||||
type buildUpdate struct {
|
||||
transition codersdk.WorkspaceTransition
|
||||
jobStatus codersdk.ProvisionerJobStatus
|
||||
}
|
||||
|
||||
type agentUpdate struct {
|
||||
lifecycle codersdk.WorkspaceAgentLifecycle
|
||||
id uuid.UUID
|
||||
}
|
||||
|
||||
type networkedApplicationUpdate struct {
|
||||
// up is true if the application is up. False if it is down.
|
||||
up bool
|
||||
@@ -174,10 +170,64 @@ func NewTunneler(client Client, config Config) *Tunneler {
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Tunneler) GracefulShutdown(ctx context.Context) error {
|
||||
select {
|
||||
case t.events <- tunnelerEvent{shutdownSignal: &shutdownSignal{}}:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-t.ctx.Done():
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
t.wg.Wait()
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunneler) start() {
|
||||
defer t.wg.Done()
|
||||
// here we would subscribe to updates.
|
||||
// t.client.AgentConnectionWatch(t.config.WorkspaceID, t.config.AgentName)
|
||||
d, err := t.client.WorkspaceAgentConnectionWatch(t.ctx, t.config.WorkspaceID, t.config.AgentName)
|
||||
// TODO: handle retries
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer d.Close()
|
||||
c := d.Chan()
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return
|
||||
case event, ok := <-c:
|
||||
if !ok {
|
||||
t.config.DebugLogger.Error(t.ctx, "watch closed")
|
||||
}
|
||||
if event.Error != nil {
|
||||
t.config.DebugLogger.Error(t.ctx, "workspace agent connection watch error", slog.Error(event.Error))
|
||||
}
|
||||
if !ok || event.Error != nil {
|
||||
// TODO: handle retries
|
||||
select {
|
||||
case t.events <- tunnelerEvent{shutdownSignal: &shutdownSignal{}}:
|
||||
case <-t.ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return
|
||||
case t.events <- tunnelerEvent{
|
||||
buildUpdate: event.BuildUpdate,
|
||||
agentUpdate: event.AgentUpdate,
|
||||
}:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunneler) eventLoop() {
|
||||
@@ -235,13 +285,13 @@ func (t *Tunneler) handleSignal() {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
|
||||
func (t *Tunneler) handleBuildUpdate(update *workspacesdk.BuildUpdate) {
|
||||
if t.state == shutdownTailnet || t.state == shutdownApplication || t.state == exit {
|
||||
return // no-op
|
||||
}
|
||||
|
||||
var canMakeProgress, jobUnhealthy bool
|
||||
switch update.jobStatus {
|
||||
switch update.JobStatus {
|
||||
case codersdk.ProvisionerJobPending, codersdk.ProvisionerJobRunning:
|
||||
canMakeProgress = true
|
||||
case codersdk.ProvisionerJobSucceeded:
|
||||
@@ -249,21 +299,21 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
|
||||
jobUnhealthy = true
|
||||
}
|
||||
|
||||
if update.transition == codersdk.WorkspaceTransitionDelete {
|
||||
t.config.DebugLogger.Info(t.ctx, "workspace is being deleted", slog.F("job_status", update.jobStatus))
|
||||
if update.Transition == codersdk.WorkspaceTransitionDelete {
|
||||
t.config.DebugLogger.Info(t.ctx, "workspace is being deleted", slog.F("job_status", update.JobStatus))
|
||||
// treat same as signal
|
||||
t.handleSignal()
|
||||
return
|
||||
}
|
||||
if jobUnhealthy {
|
||||
t.config.DebugLogger.Info(t.ctx, "build job is in unhealthy state", slog.F("job_status", update.jobStatus))
|
||||
t.config.DebugLogger.Info(t.ctx, "build job is in unhealthy state", slog.F("job_status", update.JobStatus))
|
||||
// treat same as signal
|
||||
t.handleSignal()
|
||||
return
|
||||
}
|
||||
|
||||
if update.transition == codersdk.WorkspaceTransitionStart && canMakeProgress {
|
||||
t.config.DebugLogger.Debug(t.ctx, "workspace is starting", slog.F("job_status", update.jobStatus))
|
||||
if update.Transition == codersdk.WorkspaceTransitionStart && canMakeProgress {
|
||||
t.config.DebugLogger.Debug(t.ctx, "workspace is starting", slog.F("job_status", update.JobStatus))
|
||||
switch t.state {
|
||||
// new build after we have already connected
|
||||
case establishTailnet: // we are starting the tailnet
|
||||
@@ -279,8 +329,8 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
|
||||
}
|
||||
return
|
||||
}
|
||||
if update.transition == codersdk.WorkspaceTransitionStart && update.jobStatus == codersdk.ProvisionerJobSucceeded {
|
||||
t.config.DebugLogger.Debug(t.ctx, "workspace is started", slog.F("job_status", update.jobStatus))
|
||||
if update.Transition == codersdk.WorkspaceTransitionStart && update.JobStatus == codersdk.ProvisionerJobSucceeded {
|
||||
t.config.DebugLogger.Debug(t.ctx, "workspace is started", slog.F("job_status", update.JobStatus))
|
||||
switch t.state {
|
||||
case establishTailnet, applicationUp, tailnetUp:
|
||||
// no-op. Later agent updates will tell us whether the tailnet connection is current.
|
||||
@@ -290,7 +340,7 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
|
||||
return
|
||||
}
|
||||
|
||||
if update.transition == codersdk.WorkspaceTransitionStop {
|
||||
if update.Transition == codersdk.WorkspaceTransitionStop {
|
||||
// these cases take effect regardless of whether the transition is complete or not
|
||||
switch t.state {
|
||||
// all 3 of these mean a new build after we have already started connecting
|
||||
@@ -312,7 +362,7 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
|
||||
t.state = exit
|
||||
return
|
||||
}
|
||||
if update.jobStatus == codersdk.ProvisionerJobSucceeded {
|
||||
if update.JobStatus == codersdk.ProvisionerJobSucceeded {
|
||||
switch t.state {
|
||||
case stateInit, waitToStart, waitForAgent:
|
||||
t.wg.Add(1)
|
||||
@@ -335,29 +385,29 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) {
|
||||
}
|
||||
// unhittable
|
||||
t.config.DebugLogger.Critical(t.ctx, "unhandled build update",
|
||||
slog.F("job_status", update.jobStatus), slog.F("transition", update.transition), slog.F("state", t.state))
|
||||
slog.F("job_status", update.JobStatus), slog.F("transition", update.Transition), slog.F("state", t.state))
|
||||
}
|
||||
|
||||
func (*Tunneler) handleProvisionerJobLog(*codersdk.ProvisionerJobLog) {
|
||||
}
|
||||
|
||||
func (t *Tunneler) handleAgentUpdate(update *agentUpdate) {
|
||||
func (t *Tunneler) handleAgentUpdate(update *workspacesdk.AgentUpdate) {
|
||||
t.config.DebugLogger.Debug(t.ctx, "handling agent update",
|
||||
slog.F("state", t.state),
|
||||
slog.F("lifecycle", update.lifecycle),
|
||||
slog.F("agent_id", update.id))
|
||||
slog.F("lifecycle", update.Lifecycle),
|
||||
slog.F("agent_id", update.ID))
|
||||
if t.state != waitForAgent {
|
||||
return
|
||||
}
|
||||
doConnect := func() {
|
||||
t.wg.Add(1)
|
||||
t.state = establishTailnet
|
||||
go t.connectTailnet(update.id)
|
||||
go t.connectTailnet(update.ID)
|
||||
}
|
||||
// consequence of ignoring updates if we are not waiting for the agent is that we MUST receive
|
||||
// the start build succeeded update BEFORE we get the Agent connected / ready update. We should keep this
|
||||
// in mind when implementing the watch in Coderd.
|
||||
switch update.lifecycle {
|
||||
switch update.Lifecycle {
|
||||
case codersdk.WorkspaceAgentLifecycleReady:
|
||||
doConnect()
|
||||
return
|
||||
@@ -376,7 +426,7 @@ func (t *Tunneler) handleAgentUpdate(update *agentUpdate) {
|
||||
default:
|
||||
// unhittable, unless new states are added. We structure this with the switch and all cases covered to ensure
|
||||
// we cover all cases.
|
||||
t.config.DebugLogger.Critical(t.ctx, "unhandled agent update", slog.F("lifecycle", update.lifecycle))
|
||||
t.config.DebugLogger.Critical(t.ctx, "unhandled agent update", slog.F("lifecycle", update.Lifecycle))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@@ -28,7 +29,7 @@ func TestHandleBuildUpdate_Coverage(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s_%t_%t", s, trans, jobStatus, noAutostart, noWaitForScripts), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
coverUpdate(t, workspaceID, noAutostart, noWaitForScripts, s, func(uut *Tunneler) {
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: trans, jobStatus: jobStatus})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: trans, JobStatus: jobStatus})
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -105,31 +106,31 @@ func TestBuildUpdatesStoppedWorkspace(t *testing.T) {
|
||||
state: stateInit,
|
||||
}
|
||||
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobPending})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobPending})
|
||||
require.Equal(t, waitToStart, uut.state)
|
||||
waitForGoroutines(testCtx, t, uut)
|
||||
require.False(t, fWorkspaceStarter.started)
|
||||
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobRunning})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobRunning})
|
||||
require.Equal(t, waitToStart, uut.state)
|
||||
waitForGoroutines(testCtx, t, uut)
|
||||
require.False(t, fWorkspaceStarter.started)
|
||||
|
||||
// when stop job succeeds, we start the workspace
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobSucceeded})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobSucceeded})
|
||||
require.Equal(t, waitForWorkspaceStarted, uut.state)
|
||||
waitForGoroutines(testCtx, t, uut)
|
||||
require.True(t, fWorkspaceStarter.started)
|
||||
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobPending})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobPending})
|
||||
require.Equal(t, waitForWorkspaceStarted, uut.state)
|
||||
waitForGoroutines(testCtx, t, uut)
|
||||
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobRunning})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobRunning})
|
||||
require.Equal(t, waitForWorkspaceStarted, uut.state)
|
||||
waitForGoroutines(testCtx, t, uut)
|
||||
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobSucceeded})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobSucceeded})
|
||||
require.Equal(t, waitForAgent, uut.state)
|
||||
waitForGoroutines(testCtx, t, uut)
|
||||
}
|
||||
@@ -157,7 +158,7 @@ func TestBuildUpdatesNewBuildWhileWaiting(t *testing.T) {
|
||||
}
|
||||
|
||||
// New build comes in while we are waiting for the agent to start. We roll back to waiting for the workspace to start.
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobRunning})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobRunning})
|
||||
require.Equal(t, waitForWorkspaceStarted, uut.state)
|
||||
waitForGoroutines(testCtx, t, uut)
|
||||
require.False(t, fWorkspaceStarter.started)
|
||||
@@ -193,12 +194,12 @@ func TestBuildUpdatesBadJobs(t *testing.T) {
|
||||
state: stateInit,
|
||||
}
|
||||
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobRunning})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobRunning})
|
||||
require.Equal(t, waitForWorkspaceStarted, uut.state)
|
||||
waitForGoroutines(testCtx, t, uut)
|
||||
require.False(t, fWorkspaceStarter.started)
|
||||
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: jobStatus})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: jobStatus})
|
||||
require.Equal(t, exit, uut.state)
|
||||
waitForGoroutines(testCtx, t, uut)
|
||||
require.False(t, fWorkspaceStarter.started)
|
||||
@@ -233,7 +234,7 @@ func TestBuildUpdatesNoAutostart(t *testing.T) {
|
||||
}
|
||||
|
||||
// when stop job succeeds, we exit because autostart is disabled
|
||||
uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobSucceeded})
|
||||
uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobSucceeded})
|
||||
require.Equal(t, exit, uut.state)
|
||||
waitForGoroutines(testCtx, t, uut)
|
||||
require.False(t, fWorkspaceStarter.started)
|
||||
@@ -254,7 +255,7 @@ func TestAgentUpdate_Coverage(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%d_%s_%t_%t", s, lifecycle, noAutostart, noWaitForScripts), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
coverUpdate(t, workspaceID, noAutostart, noWaitForScripts, s, func(uut *Tunneler) {
|
||||
uut.handleAgentUpdate(&agentUpdate{lifecycle: lifecycle, id: agentID})
|
||||
uut.handleAgentUpdate(&workspacesdk.AgentUpdate{Lifecycle: lifecycle, ID: agentID})
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -288,7 +289,7 @@ func TestAgentUpdateReady(t *testing.T) {
|
||||
client: fClient,
|
||||
}
|
||||
|
||||
uut.handleAgentUpdate(&agentUpdate{lifecycle: codersdk.WorkspaceAgentLifecycleReady, id: agentID})
|
||||
uut.handleAgentUpdate(&workspacesdk.AgentUpdate{Lifecycle: codersdk.WorkspaceAgentLifecycleReady, ID: agentID})
|
||||
require.Equal(t, establishTailnet, uut.state)
|
||||
event := testutil.RequireReceive(testCtx, t, uut.events)
|
||||
require.NotNil(t, event.tailnetUpdate)
|
||||
@@ -323,7 +324,7 @@ func TestAgentUpdateNoWait(t *testing.T) {
|
||||
client: fClient,
|
||||
}
|
||||
|
||||
uut.handleAgentUpdate(&agentUpdate{lifecycle: codersdk.WorkspaceAgentLifecycleStarting, id: agentID})
|
||||
uut.handleAgentUpdate(&workspacesdk.AgentUpdate{Lifecycle: codersdk.WorkspaceAgentLifecycleStarting, ID: agentID})
|
||||
require.Equal(t, establishTailnet, uut.state)
|
||||
event := testutil.RequireReceive(testCtx, t, uut.events)
|
||||
require.NotNil(t, event.tailnetUpdate)
|
||||
@@ -526,27 +527,27 @@ func TestTunneler_EventLoop_Signal(t *testing.T) {
|
||||
go uut.eventLoop()
|
||||
|
||||
testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{
|
||||
buildUpdate: &buildUpdate{
|
||||
transition: codersdk.WorkspaceTransitionStart,
|
||||
jobStatus: codersdk.ProvisionerJobPending,
|
||||
buildUpdate: &workspacesdk.BuildUpdate{
|
||||
Transition: codersdk.WorkspaceTransitionStart,
|
||||
JobStatus: codersdk.ProvisionerJobPending,
|
||||
},
|
||||
})
|
||||
testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{
|
||||
buildUpdate: &buildUpdate{
|
||||
transition: codersdk.WorkspaceTransitionStart,
|
||||
jobStatus: codersdk.ProvisionerJobRunning,
|
||||
buildUpdate: &workspacesdk.BuildUpdate{
|
||||
Transition: codersdk.WorkspaceTransitionStart,
|
||||
JobStatus: codersdk.ProvisionerJobRunning,
|
||||
},
|
||||
})
|
||||
testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{
|
||||
buildUpdate: &buildUpdate{
|
||||
transition: codersdk.WorkspaceTransitionStart,
|
||||
jobStatus: codersdk.ProvisionerJobSucceeded,
|
||||
buildUpdate: &workspacesdk.BuildUpdate{
|
||||
Transition: codersdk.WorkspaceTransitionStart,
|
||||
JobStatus: codersdk.ProvisionerJobSucceeded,
|
||||
},
|
||||
})
|
||||
testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{
|
||||
agentUpdate: &agentUpdate{
|
||||
lifecycle: codersdk.WorkspaceAgentLifecycleReady,
|
||||
id: agentID,
|
||||
agentUpdate: &workspacesdk.AgentUpdate{
|
||||
Lifecycle: codersdk.WorkspaceAgentLifecycleReady,
|
||||
ID: agentID,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -658,6 +659,11 @@ type fakeClient struct {
|
||||
dialed bool
|
||||
}
|
||||
|
||||
func (*fakeClient) WorkspaceAgentConnectionWatch(context.Context, uuid.UUID, string) (dec *wsjson.Decoder[workspacesdk.ConnectionWatchEvent], err error) {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (f *fakeClient) DialAgent(
|
||||
_ context.Context, id uuid.UUID, _ *workspacesdk.DialAgentOptions,
|
||||
) (
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
package workspacesdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// WatchErrorCode classifies the failure reported in a WatchError.
//
// Codes are assigned with iota and serialized as integers, so new codes must
// be appended at the end to keep existing values stable.
type WatchErrorCode int

const (
	// Ensure that zero value is not a valid code
	_ WatchErrorCode = iota
	WatchErrorTooManyAgents
	WatchErrorNameNotFound
	WatchErrorNoAgents
	WatchErrorServerShutdown
	WatchErrorDatabase
	WatchErrorInternal
)
|
||||
|
||||
type ConnectionWatchEvent struct {
|
||||
Error *WatchError `json:"error"`
|
||||
BuildUpdate *BuildUpdate `json:"build_update,omitempty"`
|
||||
AgentUpdate *AgentUpdate `json:"agent_update,omitempty"`
|
||||
}
|
||||
|
||||
type WatchError struct {
|
||||
Code WatchErrorCode `json:"code"`
|
||||
Retryable bool `json:"retryable"`
|
||||
Message string `json:"message"`
|
||||
Details string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
func (e *WatchError) Error() string {
|
||||
if e.Details != "" {
|
||||
return fmt.Sprintf("%s: %s", e.Message, e.Details)
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
type BuildUpdate struct {
|
||||
Transition codersdk.WorkspaceTransition `json:"transition"`
|
||||
JobStatus codersdk.ProvisionerJobStatus `json:"job_status"`
|
||||
}
|
||||
|
||||
type AgentUpdate struct {
|
||||
Lifecycle codersdk.WorkspaceAgentLifecycle `json:"lifecycle"`
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
}
|
||||
|
||||
func (c *Client) WorkspaceAgentConnectionWatch(
|
||||
dialCtx context.Context, workspaceID uuid.UUID, agentName string,
|
||||
) (
|
||||
dec *wsjson.Decoder[ConnectionWatchEvent], err error,
|
||||
) {
|
||||
wsOptions := &websocket.DialOptions{
|
||||
HTTPClient: c.client.HTTPClient,
|
||||
// Need to disable compression to avoid a data-race.
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
}
|
||||
c.client.SessionTokenProvider.SetDialOption(wsOptions)
|
||||
|
||||
watchURL, err := c.client.URL.Parse(fmt.Sprintf("/api/v2/workspaces/%s/agent-connection-watch", workspaceID))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
if agentName != "" {
|
||||
q := watchURL.Query()
|
||||
q.Set("agent_name", agentName)
|
||||
watchURL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
// nolint:bodyclose
|
||||
conn, res, err := websocket.Dial(dialCtx, watchURL.String(), wsOptions)
|
||||
if err != nil {
|
||||
bodyErr := codersdk.ReadBodyAsError(res)
|
||||
return nil, bodyErr
|
||||
}
|
||||
return wsjson.NewDecoder[ConnectionWatchEvent](conn, websocket.MessageText, c.client.Logger()), nil
|
||||
}
|
||||
Generated
+95
@@ -18616,6 +18616,101 @@ None
|
||||
| `disable_direct_connections` | boolean | false | | |
|
||||
| `hostname_suffix` | string | false | | |
|
||||
|
||||
## workspacesdk.AgentUpdate
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
|
||||
"lifecycle": "created"
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|-------------|----------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `id` | string | false | | |
|
||||
| `lifecycle` | [codersdk.WorkspaceAgentLifecycle](#codersdkworkspaceagentlifecycle) | false | | |
|
||||
|
||||
## workspacesdk.BuildUpdate
|
||||
|
||||
```json
|
||||
{
|
||||
"job_status": "pending",
|
||||
"transition": "start"
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|--------------|----------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `job_status` | [codersdk.ProvisionerJobStatus](#codersdkprovisionerjobstatus) | false | | |
|
||||
| `transition` | [codersdk.WorkspaceTransition](#codersdkworkspacetransition) | false | | |
|
||||
|
||||
## workspacesdk.ConnectionWatchEvent
|
||||
|
||||
```json
|
||||
{
|
||||
"agent_update": {
|
||||
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
|
||||
"lifecycle": "created"
|
||||
},
|
||||
"build_update": {
|
||||
"job_status": "pending",
|
||||
"transition": "start"
|
||||
},
|
||||
"error": {
|
||||
"code": 0,
|
||||
"details": "string",
|
||||
"message": "string",
|
||||
"retryable": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|----------------|------------------------------------------------------|----------|--------------|-------------|
|
||||
| `agent_update` | [workspacesdk.AgentUpdate](#workspacesdkagentupdate) | false | | |
|
||||
| `build_update` | [workspacesdk.BuildUpdate](#workspacesdkbuildupdate) | false | | |
|
||||
| `error` | [workspacesdk.WatchError](#workspacesdkwatcherror) | false | | |
|
||||
|
||||
## workspacesdk.WatchError
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"details": "string",
|
||||
"message": "string",
|
||||
"retryable": true
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|-------------|------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `code` | [workspacesdk.WatchErrorCode](#workspacesdkwatcherrorcode) | false | | |
|
||||
| `details` | string | false | | |
|
||||
| `message` | string | false | | |
|
||||
| `retryable` | boolean | false | | |
|
||||
|
||||
## workspacesdk.WatchErrorCode
|
||||
|
||||
```json
|
||||
0
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|-----------------------------------|
|
||||
| `0`, `1`, `2`, `3`, `4`, `5`, `6` |
|
||||
|
||||
## wsproxysdk.CryptoKeysResponse
|
||||
|
||||
```json
|
||||
|
||||
Generated
+50
@@ -1817,6 +1817,56 @@ curl -X PATCH http://coder-server:8080/api/v2/workspaces/{workspace}/acl \
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Workspace Agent Connection Watch
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace}/agent-connection-watch \
|
||||
-H 'Accept: application/json' \
|
||||
-H 'Coder-Session-Token: API_KEY'
|
||||
```
|
||||
|
||||
`GET /api/v2/workspaces/{workspace}/agent-connection-watch`
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|-------------|------|--------------|----------|--------------|
|
||||
| `workspace` | path | string(uuid) | true | Workspace ID |
|
||||
|
||||
### Example responses
|
||||
|
||||
> 101 Response
|
||||
|
||||
```json
|
||||
{
|
||||
"agent_update": {
|
||||
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
|
||||
"lifecycle": "created"
|
||||
},
|
||||
"build_update": {
|
||||
"job_status": "pending",
|
||||
"transition": "start"
|
||||
},
|
||||
"error": {
|
||||
"code": 0,
|
||||
"details": "string",
|
||||
"message": "string",
|
||||
"retryable": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|--------------------------------------------------------------------------|---------------------|----------------------------------------------------------------------------------|
|
||||
| 101 | [Switching Protocols](https://tools.ietf.org/html/rfc7231#section-6.2.2) | Switching Protocols | [workspacesdk.ConnectionWatchEvent](schemas.md#workspacesdkconnectionwatchevent) |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Update workspace autostart schedule by ID
|
||||
|
||||
### Code samples
|
||||
|
||||
@@ -12,7 +12,8 @@ import (
|
||||
func TestEnterpriseEndpointsDocumented(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
swaggerComments, err := coderdtest.ParseSwaggerComments("..", "../../../coderd")
|
||||
swaggerComments, err := coderdtest.ParseSwaggerComments(
|
||||
"..", "../../../coderd", "../../../coderd/workspaceconnwatcher")
|
||||
require.NoError(t, err, "can't parse swagger comments")
|
||||
require.NotEmpty(t, swaggerComments, "swagger comments must be present")
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ func main() {
|
||||
}
|
||||
|
||||
err := gen.New().Build(&gen.Config{
|
||||
SearchDir: "./coderd,./codersdk,./enterprise/coderd,./enterprise/wsproxy/wsproxysdk",
|
||||
SearchDir: "./coderd,./coderd/workspaceconnwatcher,./codersdk,./enterprise/coderd,./enterprise/wsproxy/wsproxysdk",
|
||||
MainAPIFile: "coderd.go",
|
||||
OutputDir: outputDir,
|
||||
OutputTypes: []string{"go", "json"},
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
// InMemWebsocketRoundTripper allows you to "dial" an HTTP handler that sets up a websocket using only in-memory
|
||||
// primitives. No TCP or OS networking needed. CtxMutator gives you explicit control over the context the handler sees.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// rt := testutil.InMemWebsocketRoundTripper{
|
||||
// Handler: MyHandler,
|
||||
// CtxMutator: func(ctx context.Context) context.Context {
|
||||
// ctx = httpmw.WithWorkspaceParam(ctx, ws)
|
||||
// ctx = dbauthz.As(ctx, coderdtest.MemberSubject(userID, orgID))
|
||||
// return ctx
|
||||
// },
|
||||
// Logger: logger.Named("roundtripper"),
|
||||
// }
|
||||
// clientSock, _, err := websocket.Dial(ctx, "wss://local.test/", &websocket.DialOptions{
|
||||
// HTTPClient: &http.Client{Transport: rt},
|
||||
// })
|
||||
type InMemWebsocketRoundTripper struct {
|
||||
Logger slog.Logger
|
||||
Handler http.Handler
|
||||
CtxMutator func(ctx context.Context) context.Context
|
||||
}
|
||||
|
||||
func (i InMemWebsocketRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
i.Logger.Debug(context.Background(), "round trip start")
|
||||
defer i.Logger.Debug(context.Background(), "round trip end")
|
||||
newCtx := i.CtxMutator(request.Context())
|
||||
request = request.WithContext(newCtx)
|
||||
serverP, clientP := net.Pipe()
|
||||
var _ io.ReadWriteCloser = clientP // compile time check that response body is OK for websocket
|
||||
response := &http.Response{
|
||||
Header: make(http.Header),
|
||||
Body: clientP,
|
||||
}
|
||||
rw := newInMemWebsocketResponseWriter(response, serverP)
|
||||
go func() {
|
||||
i.Handler.ServeHTTP(rw, request)
|
||||
if !rw.hijacked {
|
||||
i.Logger.Debug(context.Background(), "closing connection after handler did not hijack")
|
||||
// If the handler didn't hijack the connection, we should close it when the handler finishes.
|
||||
// This prevents a 3s delay in websocket.Dial() reading the non-upgraded response.
|
||||
_ = serverP.Close()
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-newCtx.Done():
|
||||
return nil, newCtx.Err()
|
||||
case <-rw.gotHeaders:
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
|
||||
// newInMemWebsocketResponseWriter returns a response writer that records the
// status code and headers on resp and hands out conn, wrapped in buffered
// reader/writer, when the handler hijacks the connection.
func newInMemWebsocketResponseWriter(resp *http.Response, conn net.Conn) *inMemWebsocketResponseWriter {
	return &inMemWebsocketResponseWriter{
		r:          resp,
		b:          bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
		gotHeaders: make(chan struct{}),
		conn:       conn,
	}
}

// inMemWebsocketResponseWriter is the http.ResponseWriter and http.Hijacker
// given to the handler by InMemWebsocketRoundTripper.
//
// It is not safe for concurrent use by multiple goroutines; it is meant to be
// driven by the single goroutine running the handler.
type inMemWebsocketResponseWriter struct {
	// r is the response returned to the client; status and headers land here.
	r *http.Response
	// b is the buffered reader/writer over conn returned from Hijack.
	b *bufio.ReadWriter
	// gotHeaders is closed once the response headers have been written.
	gotHeaders chan struct{}
	// hijacked records whether the handler took over the connection.
	hijacked bool
	// conn is the handler's end of the in-memory connection.
	conn net.Conn
}

// headersWritten reports whether the response headers were already written,
// i.e. whether gotHeaders has been closed.
func (rw *inMemWebsocketResponseWriter) headersWritten() bool {
	select {
	case <-rw.gotHeaders:
		return true
	default:
		return false
	}
}

// Hijack implements http.Hijacker. It marks the writer as hijacked and returns
// the underlying connection together with its buffered reader/writer.
func (rw *inMemWebsocketResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	rw.hijacked = true
	return rw.conn, rw.b, nil
}

// Header returns the header map of the response handed to the client.
func (rw *inMemWebsocketResponseWriter) Header() http.Header {
	return rw.r.Header
}

// Write implements http.ResponseWriter for non-hijacked responses. The body
// bytes are discarded, as before, but the call now reports len(p) so that it
// honors the io.Writer contract (a short write must come with an error). As
// net/http does, a Write before WriteHeader implies WriteHeader(200), so a
// handler that only writes a body still completes the round trip.
//
// NOTE(review): dropping the body mirrors the previous behavior; confirm no
// test needs to read the body of a non-upgraded response.
func (rw *inMemWebsocketResponseWriter) Write(p []byte) (int, error) {
	if !rw.headersWritten() {
		rw.WriteHeader(http.StatusOK)
	}
	return len(p), nil
}

// WriteHeader records statusCode on the response and signals that headers are
// available by closing gotHeaders. Only the first call has any effect: later
// calls are ignored rather than panicking on a second close or changing the
// status of a response the client may already hold.
func (rw *inMemWebsocketResponseWriter) WriteHeader(statusCode int) {
	if rw.headersWritten() {
		return
	}
	rw.r.StatusCode = statusCode
	close(rw.gotHeaders)
}
|
||||
Reference in New Issue
Block a user