Files
coder/coderd/workspaceagents_test.go
T
Atif Ali 584c61acb5 fix: mark connecting agents as unhealthy instead of healthy (#24044)
## Problem

Workspaces showed as "Healthy" immediately after creation while the
agent was still downloading, starting, or connecting. If the agent never
connected, the workspace stayed "Healthy" for the entire connection
timeout (~120s), then abruptly flipped to "Unhealthy".

## Root cause

In `db2sdk.WorkspaceAgent`, the health switch had no case for
`WorkspaceAgentConnecting`. Agents in `connecting` status with a
non-`off` lifecycle (e.g. `created` after a fresh build) fell through to
the `default` case and were marked `Healthy = true`.

## Fix

Add an explicit case for `WorkspaceAgentConnecting` that sets `Healthy =
false` with reason `"agent has not yet connected"`. The case is placed
after the existing `!connected + off` case (which correctly catches
stopped agents as "not running") and before the `timeout`/`disconnected`
cases.

```
Status        + Lifecycle       → Health reason
──────────────────────────────────────────────────────
any !connected + off           → "agent is not running"
connecting    + created/starting → "agent has not yet connected"  ← NEW
timeout       + any            → "agent is taking too long to connect"
disconnected  + any            → "agent has lost connection"
connected     + start_error    → "agent startup script exited with an error"
connected     + shutting_down  → "agent is shutting down"
connected     + ready/starting → healthy
```

The frontend already handles this case — `getAgentHealthIssue()` returns
"Workspace agent is still connecting" with `severity: "info"` for
unhealthy workspaces with connecting agents.

## Test changes

- **Healthy test**: now actually connects the agent via `agenttest.New`
before asserting health (previously passed due to the bug).
- **New Connecting test**: verifies a never-connected agent is correctly
marked unhealthy.
- **Mixed health test**: connects a1 and waits for the mixed state
(`a1.Healthy && !workspace.Healthy`) to avoid a race where both agents
are initially connecting.
- **Sub-agent excluded test**: connects the parent agent and waits for
it to be healthy before creating the sub-agent.
- **TestWorkspaceAgent/Connect**: flipped assertion to `Health.Healthy
== false` for a `dbfake` agent that never connects.

<details>
<summary>Review notes</summary>

### Known follow-up

The `healthy:false` workspace search filter maps to `[disconnected,
timeout]` and does not include `connecting`. This is a pre-existing gap
that is now more consequential — a workspace unhealthy solely due to a
connecting agent won't appear in `healthy:false` results. Worth a
follow-up issue.

### Deep review findings addressed

| Finding | Severity | Status |
|---------|----------|--------|
| Mixed health test race (all 3 reviewers) | P2 | Fixed — tightened `Eventually` condition |
| `TestWorkspaceAgent/Connect` assertion break | P1 | Fixed — flipped assertion |
| CLI renders red for connecting agents | Obs | Acknowledged — design trade-off, accurate but visually strong for transient state |
| Switch case ordering overlap | Obs | Documented with inline comment |
| Switch case ordering overlap | Obs | Documented with inline comment |

</details>

> 🤖 This PR was created with the help of Coder Agents, and needs a human
review. 🧑‍💻
2026-04-09 13:21:28 +05:00

3543 lines
115 KiB
Go

package coderd_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"maps"
"net/http"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/tailcfg"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentcontainers/acmock"
"github.com/coder/coder/v2/agent/agentcontainers/watcher"
"github.com/coder/coder/v2/agent/agenttest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
"github.com/coder/websocket"
)
// TestWorkspaceAgent reads workspace agents back through the API for builds
// that were seeded straight into the database with dbfake. No subtest starts
// an agent process.
func TestWorkspaceAgent(t *testing.T) {
	t.Parallel()

	// Connect: a second user fetches their own workspace and agent. The agent
	// is asserted to be reported as not healthy.
	t.Run("Connect", func(t *testing.T) {
		t.Parallel()

		ownerClient, db := coderdtest.NewWithDatabase(t, nil)
		firstUser := coderdtest.CreateFirstUser(t, ownerClient)
		agentDir := t.TempDir()
		memberClient, member := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID)
		build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: firstUser.OrganizationID,
			OwnerID:        member.ID,
		}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
			agents[0].Directory = agentDir
			return agents
		}).Do()

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		ws, err := memberClient.Workspace(ctx, build.Workspace.ID)
		require.NoError(t, err)
		wsAgent := ws.LatestBuild.Resources[0].Agents[0]
		require.Equal(t, agentDir, wsAgent.Directory)

		// Only the success of the direct agent lookup is checked; the health
		// assertion uses the agent embedded in the workspace response.
		_, err = memberClient.WorkspaceAgent(ctx, wsAgent.ID)
		require.NoError(t, err)
		require.False(t, wsAgent.Health.Healthy)
	})

	// HasFallbackTroubleshootingURL: an agent seeded without a troubleshooting
	// URL still reports a non-empty one.
	t.Run("HasFallbackTroubleshootingURL", func(t *testing.T) {
		t.Parallel()

		client, db := coderdtest.NewWithDatabase(t, nil)
		owner := coderdtest.CreateFirstUser(t, client)
		agentDir := t.TempDir()
		build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        owner.UserID,
		}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
			agents[0].Directory = agentDir
			return agents
		}).Do()

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
		defer cancel()

		ws, err := client.Workspace(ctx, build.Workspace.ID)
		require.NoError(t, err)
		troubleshootingURL := ws.LatestBuild.Resources[0].Agents[0].TroubleshootingURL
		require.NotEmpty(t, troubleshootingURL)
		t.Log(troubleshootingURL)
	})

	// Timeout: with a one second connection timeout the agent eventually
	// reports the timeout status, keeps its configured troubleshooting URL and
	// is unhealthy with a reason.
	t.Run("Timeout", func(t *testing.T) {
		t.Parallel()

		// timeouts can cause error logs to be dropped on shutdown
		logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
		client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
			Logger: &logger,
		})
		owner := coderdtest.CreateFirstUser(t, client)
		agentDir := t.TempDir()
		const wantTroubleshootingURL = "https://example.com/troubleshoot"
		build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        owner.UserID,
		}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
			agents[0].Directory = agentDir
			agents[0].ConnectionTimeoutSeconds = 1
			agents[0].TroubleshootingUrl = wantTroubleshootingURL
			return agents
		}).Do()

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
		defer cancel()

		// Declared outside the poll so the last fetched workspace can be
		// inspected once the status condition holds.
		var (
			ws  codersdk.Workspace
			err error
		)
		testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
			ws, err = client.Workspace(ctx, build.Workspace.ID)
			if !assert.NoError(t, err) {
				return false
			}
			return ws.LatestBuild.Resources[0].Agents[0].Status == codersdk.WorkspaceAgentTimeout
		}, testutil.IntervalMedium, "agent status timeout")

		timedOut := ws.LatestBuild.Resources[0].Agents[0]
		require.Equal(t, wantTroubleshootingURL, timedOut.TroubleshootingURL)
		require.False(t, timedOut.Health.Healthy)
		require.NotEmpty(t, timedOut.Health.Reason)
	})

	// DisplayApps: every enabled display app is returned, and a second build
	// with all of them disabled returns none.
	t.Run("DisplayApps", func(t *testing.T) {
		t.Parallel()

		client, db := coderdtest.NewWithDatabase(t, nil)
		owner := coderdtest.CreateFirstUser(t, client)
		agentDir := t.TempDir()
		apps := &proto.DisplayApps{
			Vscode:               true,
			VscodeInsiders:       true,
			WebTerminal:          true,
			PortForwardingHelper: true,
			SshHelper:            true,
		}
		// Both builds share this mutator; it reads apps at build time, so the
		// flags flipped below take effect for the second build.
		withApps := func(agents []*proto.Agent) []*proto.Agent {
			agents[0].Directory = agentDir
			agents[0].DisplayApps = apps
			return agents
		}
		seed := database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        owner.UserID,
		}
		build := dbfake.WorkspaceBuild(t, db, seed).WithAgent(withApps).Do()

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		ws, err := client.Workspace(ctx, build.Workspace.ID)
		require.NoError(t, err)
		gotAgent, err := client.WorkspaceAgent(ctx, ws.LatestBuild.Resources[0].Agents[0].ID)
		require.NoError(t, err)
		wantApps := []codersdk.DisplayApp{
			codersdk.DisplayAppPortForward,
			codersdk.DisplayAppSSH,
			codersdk.DisplayAppVSCodeDesktop,
			codersdk.DisplayAppVSCodeInsiders,
			codersdk.DisplayAppWebTerminal,
		}
		require.ElementsMatch(t, wantApps, gotAgent.DisplayApps)

		// Flips all the apps to false.
		apps.Vscode = false
		apps.VscodeInsiders = false
		apps.WebTerminal = false
		apps.PortForwardingHelper = false
		apps.SshHelper = false

		// Creating another workspace is easier
		build = dbfake.WorkspaceBuild(t, db, seed).WithAgent(withApps).Do()
		ws, err = client.Workspace(ctx, build.Workspace.ID)
		require.NoError(t, err)
		gotAgent, err = client.WorkspaceAgent(ctx, ws.LatestBuild.Resources[0].Agents[0].ID)
		require.NoError(t, err)
		require.Len(t, gotAgent.DisplayApps, 0)
	})
}
// TestWorkspaceAgentLogs covers uploading agent logs with PatchLogs and
// reading them back through the follow-mode logs endpoint and the workspace
// watch stream.
func TestWorkspaceAgentLogs(t *testing.T) {
	t.Parallel()

	// Success: two uploaded log lines arrive together in the first chunk.
	t.Run("Success", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitMedium)
		client, db := coderdtest.NewWithDatabase(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: user.OrganizationID,
			OwnerID:        user.UserID,
		}).WithAgent().Do()
		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
		err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
			Logs: []agentsdk.Log{
				{
					CreatedAt: dbtime.Now(),
					Output:    "testing",
				},
				{
					CreatedAt: dbtime.Now(),
					Output:    "testing2",
				},
			},
		})
		require.NoError(t, err)
		workspace, err := client.Workspace(ctx, r.Workspace.ID)
		require.NoError(t, err)
		logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
		require.NoError(t, err)
		defer func() {
			_ = closer.Close()
		}()
		var logChunk []codersdk.WorkspaceAgentLog
		select {
		case <-ctx.Done():
		case logChunk = <-logs:
		}
		// A context error here means the select gave up before a chunk arrived.
		require.NoError(t, ctx.Err())
		require.Len(t, logChunk, 2) // No EOF.
		require.Equal(t, "testing", logChunk[0].Output)
		require.Equal(t, "testing2", logChunk[1].Output)
	})

	// SanitizesNulBytesAndTracksSanitizedLength: output containing a NUL byte
	// is returned in its sanitized form, and the stored logs length matches
	// the sanitized output rather than the raw one.
	t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitMedium)
		client, db := coderdtest.NewWithDatabase(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: user.OrganizationID,
			OwnerID:        user.UserID,
		}).WithAgent().Do()
		rawOutput := "before\x00after"
		sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
		err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
			Logs: []agentsdk.Log{
				{
					CreatedAt: dbtime.Now(),
					Output:    rawOutput,
				},
			},
		})
		require.NoError(t, err)
		agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID)
		require.NoError(t, err)
		require.EqualValues(t, len(sanitizedOutput), agent.LogsLength)
		workspace, err := client.Workspace(ctx, r.Workspace.ID)
		require.NoError(t, err)
		logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
		require.NoError(t, err)
		defer func() {
			_ = closer.Close()
		}()
		var logChunk []codersdk.WorkspaceAgentLog
		select {
		case <-ctx.Done():
		case logChunk = <-logs:
		}
		require.NoError(t, ctx.Err())
		require.Len(t, logChunk, 1)
		require.Equal(t, sanitizedOutput, logChunk[0].Output)
	})

	// Close logs on outdated build: after a newer build is created for the
	// workspace, the follow stream for the old agent yields again.
	t.Run("Close logs on outdated build", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitMedium)
		client, db := coderdtest.NewWithDatabase(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: user.OrganizationID,
			OwnerID:        user.UserID,
		}).WithAgent().Do()
		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
		err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
			Logs: []agentsdk.Log{
				{
					CreatedAt: dbtime.Now(),
					Output:    "testing",
				},
			},
		})
		require.NoError(t, err)
		workspace, err := client.Workspace(ctx, r.Workspace.ID)
		require.NoError(t, err)
		logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
		require.NoError(t, err)
		defer func() {
			_ = closer.Close()
		}()
		select {
		case <-ctx.Done():
			require.FailNow(t, "context done while waiting for first log")
		case <-logs:
		}
		_ = coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStart)
		select {
		case <-ctx.Done():
			require.FailNow(t, "context done while waiting for logs close")
		case <-logs:
		}
	})

	// PublishesOnOverflow: an upload just over 1 MiB is rejected with 413 and
	// a workspace update flagged LogsOverflowed is published to watchers.
	t.Run("PublishesOnOverflow", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitMedium)
		client, db := coderdtest.NewWithDatabase(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: user.OrganizationID,
			OwnerID:        user.UserID,
		}).WithAgent().Do()
		updates, err := client.WatchWorkspace(ctx, r.Workspace.ID)
		require.NoError(t, err)
		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
		err = agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
			Logs: []agentsdk.Log{{
				CreatedAt: dbtime.Now(),
				Output:    strings.Repeat("a", (1<<20)+1),
			}},
		})
		var apiError *codersdk.Error
		require.ErrorAs(t, err, &apiError)
		require.Equal(t, http.StatusRequestEntityTooLarge, apiError.StatusCode())
		// It's possible we have multiple updates queued, but that's alright, we just
		// wait for the one where it overflows.
		for {
			var update codersdk.Workspace
			select {
			case <-ctx.Done():
				require.FailNow(t, "context done while waiting for logs overflow update")
			case got, ok := <-updates:
				// A closed channel yields zero-value workspaces; indexing into
				// their empty Resources would panic instead of failing the test
				// with a useful message.
				require.True(t, ok, "workspace updates channel closed before logs overflowed")
				update = got
			}
			if update.LatestBuild.Resources[0].Agents[0].LogsOverflowed {
				break
			}
		}
	})
}
// TestWorkspaceAgentLogsFormat requests the agent logs endpoint with different
// "format" query parameters and checks the status code, content type and body
// of each response. One log row is seeded once and shared by all subtests.
func TestWorkspaceAgentLogsFormat(t *testing.T) {
	t.Parallel()

	client, db := coderdtest.NewWithDatabase(t, nil)
	user := coderdtest.CreateFirstUser(t, client)
	r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: user.OrganizationID,
		OwnerID:        user.UserID,
	}).WithAgent().Do()
	workspaceAgent := r.Agents[0]
	logSource := dbgen.WorkspaceAgentLogSource(t, db, database.WorkspaceAgentLogSource{
		WorkspaceAgentID: workspaceAgent.ID,
		DisplayName:      "startup_script",
	})
	agentLog := dbgen.WorkspaceAgentLog(t, db, database.WorkspaceAgentLog{
		AgentID:     workspaceAgent.ID,
		LogSourceID: logSource.ID,
		Output:      "test log output",
		Level:       database.LogLevelInfo,
	})

	tests := []struct {
		name                string
		queryParams         string
		expectedStatus      int
		expectedContentType string
		// checkBody receives the subtest's own t so that a failed assertion is
		// reported against the subtest that produced it, not the parent test.
		checkBody func(t *testing.T, body string)
	}{
		{
			name:                "JSON",
			queryParams:         "",
			expectedStatus:      http.StatusOK,
			expectedContentType: "application/json",
			checkBody: func(t *testing.T, body string) {
				assert.NotEmpty(t, body)
			},
		},
		{
			name:                "Text",
			queryParams:         "?format=text",
			expectedStatus:      http.StatusOK,
			expectedContentType: "text/plain",
			checkBody: func(t *testing.T, body string) {
				expected := db2sdk.WorkspaceAgentLog(agentLog).Text(workspaceAgent.Name, logSource.DisplayName)
				assert.Contains(t, body, expected)
			},
		},
		{
			name:           "InvalidFormat",
			queryParams:    "?format=invalid",
			expectedStatus: http.StatusBadRequest,
			checkBody: func(t *testing.T, body string) {
				assert.Contains(t, body, "Invalid format")
			},
		},
		{
			name:           "TextWithFollowFails",
			queryParams:    "?format=text&follow",
			expectedStatus: http.StatusBadRequest,
			checkBody: func(t *testing.T, body string) {
				assert.Contains(t, body, "not supported with follow mode")
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			ctx := testutil.Context(t, testutil.WaitLong)
			urlPath := fmt.Sprintf("/api/v2/workspaceagents/%s/logs%s", workspaceAgent.ID, tt.queryParams)
			res, err := client.Request(ctx, http.MethodGet, urlPath, nil)
			require.NoError(t, err)
			defer res.Body.Close()
			require.Equal(t, tt.expectedStatus, res.StatusCode)
			// Error responses leave the expected content type empty and skip
			// this check.
			if tt.expectedContentType != "" {
				require.Contains(t, res.Header.Get("Content-Type"), tt.expectedContentType)
			}
			if assert.NotNil(t, tt.checkBody) {
				body, err := io.ReadAll(res.Body)
				require.NoError(t, err)
				tt.checkBody(t, string(body))
			}
		})
	}
}
// TestWorkspaceAgentAppStatus reports app statuses through the agent API for a
// workspace owned by a second user. It checks the accepted case and the
// validation failures, all of which must return 400 Bad Request.
func TestWorkspaceAgentAppStatus(t *testing.T) {
	t.Parallel()

	ownerClient, db := coderdtest.NewWithDatabase(t, nil)
	firstUser := coderdtest.CreateFirstUser(t, ownerClient)
	// Everything below runs with the second user's client.
	client, member := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID)
	build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: firstUser.OrganizationID,
		OwnerID:        member.ID,
	}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
		agents[0].Apps = []*proto.App{
			{
				Slug: "vscode",
			},
		}
		return agents
	}).Do()
	// Shared by every subtest.
	agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(build.AgentToken))

	t.Run("Success", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)

		report := agentsdk.PatchAppStatus{
			AppSlug: "vscode",
			Message: "testing",
			URI:     "https://example.com",
			State:   codersdk.WorkspaceAppStatusStateComplete,
			// Ensure deprecated fields are ignored.
			Icon:               "https://example.com/icon.png",
			NeedsUserAttention: true,
		}
		require.NoError(t, agentClient.PatchAppStatus(ctx, report))

		ws, err := client.Workspace(ctx, build.Workspace.ID)
		require.NoError(t, err)
		gotAgent, err := client.WorkspaceAgent(ctx, ws.LatestBuild.Resources[0].Agents[0].ID)
		require.NoError(t, err)
		require.Len(t, gotAgent.Apps[0].Statuses, 1)

		// Deprecated fields should be ignored.
		stored := gotAgent.Apps[0].Statuses[0]
		require.Empty(t, stored.Icon)
		require.False(t, stored.NeedsUserAttention)
	})

	// Each rejected report is expected to fail with a message containing
	// wantErr and a 400 status code.
	failures := []struct {
		name    string
		report  agentsdk.PatchAppStatus
		wantErr string
	}{
		{
			name: "FailUnknownApp",
			report: agentsdk.PatchAppStatus{
				AppSlug: "unknown",
				Message: "testing",
				URI:     "https://example.com",
				State:   codersdk.WorkspaceAppStatusStateComplete,
			},
			wantErr: "No app found with slug",
		},
		{
			name: "FailUnknownState",
			report: agentsdk.PatchAppStatus{
				AppSlug: "vscode",
				Message: "testing",
				URI:     "https://example.com",
				State:   "unknown",
			},
			wantErr: "Invalid state",
		},
		{
			name: "FailTooLong",
			report: agentsdk.PatchAppStatus{
				AppSlug: "vscode",
				Message: strings.Repeat("a", 161),
				URI:     "https://example.com",
				State:   codersdk.WorkspaceAppStatusStateComplete,
			},
			wantErr: "Message is too long",
		},
	}
	for _, tc := range failures {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			ctx := testutil.Context(t, testutil.WaitShort)

			err := agentClient.PatchAppStatus(ctx, tc.report)
			require.ErrorContains(t, err, tc.wantErr)
			var sdkErr *codersdk.Error
			require.ErrorAs(t, err, &sdkErr)
			require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
		})
	}
}
// TestWorkspaceAgentAppStatus_ActivityBump checks which app status transitions
// reported through PatchAppStatus move the workspace build deadline forward.
// Each case optionally reports a previous state, then reports the new state
// and compares the build deadline read from the database before and after.
func TestWorkspaceAgentAppStatus_ActivityBump(t *testing.T) {
t.Parallel()
// The server, database and first user are shared by every subtest; each
// subtest seeds its own workspace build.
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
tests := []struct {
name string
prevState *codersdk.WorkspaceAppStatusState // nil means no previous state
newState codersdk.WorkspaceAppStatusState
// shouldBump is whether the deadline is expected to be later after newState
// is reported.
shouldBump bool
}{
{
name: "FirstStatusBumps",
prevState: nil,
newState: codersdk.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
{
name: "WorkingToIdleBumps",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateWorking),
newState: codersdk.WorkspaceAppStatusStateIdle,
shouldBump: true,
},
{
name: "WorkingToCompleteBumps",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateWorking),
newState: codersdk.WorkspaceAppStatusStateComplete,
shouldBump: true,
},
{
name: "CompleteToIdleNoBump",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateComplete),
newState: codersdk.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "CompleteToCompleteNoBump",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateComplete),
newState: codersdk.WorkspaceAppStatusStateComplete,
shouldBump: false,
},
{
name: "FailureToIdleNoBump",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateFailure),
newState: codersdk.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "FailureToFailureNoBump",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateFailure),
newState: codersdk.WorkspaceAppStatusStateFailure,
shouldBump: false,
},
{
name: "CompleteToWorkingBumps",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateComplete),
newState: codersdk.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
{
name: "FailureToCompleteNoBump",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateFailure),
newState: codersdk.WorkspaceAppStatusStateComplete,
shouldBump: false,
},
{
name: "WorkingToFailureBumps",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateWorking),
newState: codersdk.WorkspaceAppStatusStateFailure,
shouldBump: true,
},
{
name: "IdleToIdleNoBump",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateIdle),
newState: codersdk.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "IdleToWorkingBumps",
prevState: ptr.Ref(codersdk.WorkspaceAppStatusStateIdle),
newState: codersdk.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Create workspace with agent and app.
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
a[0].Apps = []*proto.App{{Slug: "test-app"}}
return a
}).Do()
ctx := testutil.Context(t, testutil.WaitMedium)
// Configure template with activity_bump to enable deadline bumping.
_, err := client.UpdateTemplateMeta(ctx, r.Template.ID, codersdk.UpdateTemplateMeta{
ActivityBumpMillis: time.Hour.Milliseconds(),
})
require.NoError(t, err)
// Set the workspace build deadline to the past to ensure the 5%
// threshold is met for activity bumping.
pastDeadline := dbtime.Now().Add(-30 * time.Minute)
err = db.UpdateWorkspaceBuildDeadlineByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceBuildDeadlineByIDParams{
ID: r.Build.ID,
UpdatedAt: dbtime.Now(),
Deadline: pastDeadline,
MaxDeadline: time.Time{},
})
require.NoError(t, err)
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
// If there's a previous state, report it first.
if tt.prevState != nil {
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "test-app",
State: *tt.prevState,
Message: "previous state",
})
require.NoError(t, err)
// Reset deadline to past again to meet 5% threshold for next bump.
err = db.UpdateWorkspaceBuildDeadlineByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceBuildDeadlineByIDParams{
ID: r.Build.ID,
UpdatedAt: dbtime.Now(),
Deadline: pastDeadline,
MaxDeadline: time.Time{},
})
require.NoError(t, err)
}
// Get the deadline before the new status report.
beforeBuild, err := db.GetWorkspaceBuildByID(dbauthz.AsSystemRestricted(ctx), r.Build.ID)
require.NoError(t, err)
beforeDeadline := beforeBuild.Deadline
// Report the new state.
err = agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "test-app",
State: tt.newState,
Message: "new state",
})
require.NoError(t, err)
// Check if deadline changed.
afterBuild, err := db.GetWorkspaceBuildByID(dbauthz.AsSystemRestricted(ctx), r.Build.ID)
require.NoError(t, err)
afterDeadline := afterBuild.Deadline
// A bump is any strictly later deadline; the size of the bump is not checked.
didBump := afterDeadline.After(beforeDeadline)
if tt.shouldBump {
require.True(t, didBump, "wanted deadline to bump but it didn't")
} else {
require.False(t, didBump, "wanted deadline not to bump but it did")
}
})
}
}
// TestWorkspaceAgentConnectRPC covers agent connectivity: dialing a connected
// agent for several API key scopes, and rejecting ConnectRPC with 401
// Unauthorized for tokens that belong to a superseded build or to a
// soft-deleted workspace.
func TestWorkspaceAgentConnectRPC(t *testing.T) {
t.Parallel()
// Connect: for each API key scope configured on the agent, start a test
// agent, wait for it, dial it and wait until it is reachable.
t.Run("Connect", func(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
apiKeyScope rbac.ScopeName
}{
{
name: "empty (backwards compat)",
apiKeyScope: "",
},
{
name: "all",
apiKeyScope: rbac.ScopeAll,
},
{
name: "no_user_data",
apiKeyScope: rbac.ScopeNoUserData,
},
{
name: "application_connect",
apiKeyScope: rbac.ScopeApplicationConnect,
},
} {
// NOTE(review): these per-scope subtests do not call t.Parallel, so they
// run one after another — confirm whether that is intentional.
t.Run(tc.name, func(t *testing.T) {
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
for _, agent := range agents {
agent.ApiKeyScope = string(tc.apiKeyScope)
}
return agents
}).Do()
_ = agenttest.New(t, client.URL, r.AgentToken)
resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).AgentNames([]string{}).Wait()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
conn, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer func() {
_ = conn.Close()
}()
conn.AwaitReachable(ctx)
})
}
})
// FailNonLatestBuild: the agent token from the first build must be rejected
// once a later stop build, made from a new template version with a different
// agent token, has completed.
t.Run("FailNonLatestBuild", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
// authToken belongs to the agent of the first template version only.
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: echo.PlanComplete,
ProvisionGraph: echo.ProvisionGraphWithAgent(authToken),
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// The second template version declares an agent with a freshly generated
// ID and token.
version = coderdtest.UpdateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: echo.PlanComplete,
ProvisionGraph: []*proto.Response{{
Type: &proto.Response_Graph{
Graph: &proto.GraphComplete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Name: "dev",
Auth: &proto.Agent_Token{
Token: uuid.NewString(),
},
}},
}},
},
},
}},
}, template.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
stopBuild, err := client.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{
TemplateVersionID: version.ID,
Transition: codersdk.WorkspaceTransitionStop,
})
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, stopBuild.ID)
// Connect with the token from the first build.
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
_, err = agentClient.ConnectRPC(ctx)
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
})
// FailDeleted: an agent token stops working once its workspace is marked
// deleted in the database.
t.Run("FailDeleted", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
// Given: a workspace exists
seed := database.WorkspaceTable{OrganizationID: user.OrganizationID, OwnerID: user.UserID}
wsb := dbfake.WorkspaceBuild(t, db, seed).WithAgent().Do()
// When: the workspace is marked as soft-deleted
err := db.UpdateWorkspaceDeletedByID(
dbauthz.AsProvisionerd(ctx),
database.UpdateWorkspaceDeletedByIDParams{ID: wsb.Workspace.ID, Deleted: true},
)
require.NoError(t, err)
// Then: the agent token should no longer be valid
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken((wsb.AgentToken)))
_, err = agentClient.ConnectRPC(ctx)
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
// Then: we should get a 401 Unauthorized response
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
})
}
// TestWorkspaceAgentTailnet starts a test agent with agenttest.New, dials it
// through workspacesdk and runs "echo test" over SSH, expecting "test" back.
func TestWorkspaceAgentTailnet(t *testing.T) {
	t.Parallel()

	client, db := coderdtest.NewWithDatabase(t, nil)
	owner := coderdtest.CreateFirstUser(t, client)
	build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: owner.OrganizationID,
		OwnerID:        owner.UserID,
	}).WithAgent().Do()

	_ = agenttest.New(t, client.URL, build.AgentToken)
	resources := coderdtest.AwaitWorkspaceAgents(t, client, build.Workspace.ID)

	// The dial context lives only inside this closure and is canceled when it
	// returns, before the connection is used below.
	dial := func() (workspacesdk.AgentConn, error) {
		dialCtx, cancelDial := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancelDial() // Connection should remain open even if the dial context is canceled.
		opts := &workspacesdk.DialAgentOptions{
			Logger: testutil.Logger(t).Named("client"),
		}
		return workspacesdk.New(client).DialAgent(dialCtx, resources[0].Agents[0].ID, opts)
	}
	conn, err := dial()
	require.NoError(t, err)
	defer conn.Close()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	sshClient, err := conn.SSHClient(ctx)
	require.NoError(t, err)
	session, err := sshClient.NewSession()
	require.NoError(t, err)
	out, err := session.CombinedOutput("echo test")
	require.NoError(t, err)

	// Tear everything down explicitly before the final assertion.
	_ = session.Close()
	_ = sshClient.Close()
	_ = conn.Close()

	got := strings.TrimSpace(string(out))
	require.Equal(t, "test", got)
}
// TestWorkspaceAgentClientCoordinate_BadVersion calls the agent coordinate
// endpoint with version=99.99 and expects 400 Bad Request with a single
// validation error on the "version" field.
func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) {
	t.Parallel()

	client, db := coderdtest.NewWithDatabase(t, nil)
	owner := coderdtest.CreateFirstUser(t, client)
	build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: owner.OrganizationID,
		OwnerID:        owner.UserID,
	}).WithAgent().Do()
	ctx := testutil.Context(t, testutil.WaitShort)

	// Look the agent up by its auth token to get the ID for the URL.
	token, err := uuid.Parse(build.AgentToken)
	require.NoError(t, err)
	authed, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), token)
	require.NoError(t, err)

	coordinatePath := fmt.Sprintf("api/v2/workspaceagents/%s/coordinate", authed.WorkspaceAgent.ID)
	//nolint: bodyclose // closed by ReadBodyAsError
	resp, err := client.Request(ctx, http.MethodGet, coordinatePath, nil,
		codersdk.WithQueryParam("version", "99.99"))
	require.NoError(t, err)
	require.Equal(t, http.StatusBadRequest, resp.StatusCode)

	var sdkErr *codersdk.Error
	require.ErrorAs(t, codersdk.ReadBodyAsError(resp), &sdkErr)
	require.Equal(t, "Unknown or unsupported API version", sdkErr.Message)
	require.Len(t, sdkErr.Validations, 1)
	require.Equal(t, "version", sdkErr.Validations[0].Field)
}
// resumeTokenRecordingProvider wraps a tailnet.ResumeTokenProvider and records
// every GenerateResumeToken and VerifyResumeToken call on a channel before
// delegating to the embedded provider.
type resumeTokenRecordingProvider struct {
	tailnet.ResumeTokenProvider

	// t is marked failed when a call arrives and the matching channel has no
	// room for it.
	t testing.TB
	// generateCalls receives the peer ID passed to each GenerateResumeToken call.
	generateCalls chan uuid.UUID
	// verifyCalls receives the token passed to each VerifyResumeToken call.
	verifyCalls chan string
}

// Compile-time check that the recorder satisfies tailnet.ResumeTokenProvider.
var _ tailnet.ResumeTokenProvider = (*resumeTokenRecordingProvider)(nil)
// newResumeTokenRecordingProvider returns a recorder that delegates to
// underlying. Both call channels have a buffer of one, so each holds at most
// one unread call.
func newResumeTokenRecordingProvider(t testing.TB, underlying tailnet.ResumeTokenProvider) *resumeTokenRecordingProvider {
	recorder := new(resumeTokenRecordingProvider)
	recorder.ResumeTokenProvider = underlying
	recorder.t = t
	recorder.generateCalls = make(chan uuid.UUID, 1)
	recorder.verifyCalls = make(chan string, 1)
	return recorder
}
// GenerateResumeToken records peerID on generateCalls and then delegates to
// the wrapped provider. If the recording channel is already full, the test is
// marked as failed and an error is returned without calling the wrapped
// provider.
func (r *resumeTokenRecordingProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) {
	select {
	case r.generateCalls <- peerID:
		// Recorded; fall through to the real provider below.
	default:
		r.t.Error("generateCalls full")
		return nil, xerrors.New("generateCalls full")
	}
	return r.ResumeTokenProvider.GenerateResumeToken(ctx, peerID)
}
// VerifyResumeToken records token on verifyCalls and then delegates to the
// wrapped provider. If the recording channel is already full, the test is
// marked as failed and an error is returned without calling the wrapped
// provider.
func (r *resumeTokenRecordingProvider) VerifyResumeToken(ctx context.Context, token string) (uuid.UUID, error) {
	select {
	case r.verifyCalls <- token:
		// Recorded; fall through to the real provider below.
	default:
		r.t.Error("verifyCalls full")
		return uuid.Nil, xerrors.New("verifyCalls full")
	}
	return r.ResumeTokenProvider.VerifyResumeToken(ctx, token)
}
// TestWorkspaceAgentClientCoordinate_ResumeToken exercises resume token
// handling on the client coordinate endpoint using a recording
// ResumeTokenProvider:
//
//   - "OK": a connection without a token gets a fresh peer ID, a connection
//     with a valid token reuses the original peer ID, and a syntactically
//     invalid token is rejected with 401.
//   - "BadJWT": a token that fails verification but was produced by
//     generateBadJWT is tolerated; the client is given a new peer ID rather
//     than being rejected.
func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
	t.Parallel()

	t.Run("OK", func(t *testing.T) {
		t.Parallel()

		logger := testutil.Logger(t)
		clock := quartz.NewMock(t)
		resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
		// Check the error before the key is used to build the key manager.
		require.NoError(t, err)
		mgr := jwtutils.StaticKey{
			ID:  uuid.New().String(),
			Key: resumeTokenSigningKey[:],
		}
		resumeTokenProvider := newResumeTokenRecordingProvider(
			t,
			tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour),
		)
		client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
			Coordinator:                    tailnet.NewCoordinator(logger),
			CoordinatorResumeTokenProvider: resumeTokenProvider,
		})
		defer closer.Close()
		user := coderdtest.CreateFirstUser(t, client)

		// Create a workspace with an agent. No need to connect it since clients can
		// still connect to the coordinator while the agent isn't connected.
		r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
			OrganizationID: user.OrganizationID,
			OwnerID:        user.UserID,
		}).WithAgent().Do()
		agentTokenUUID, err := uuid.Parse(r.AgentToken)
		require.NoError(t, err)
		ctx := testutil.Context(t, testutil.WaitLong)
		agentAndBuild, err := api.Database.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID)
		require.NoError(t, err)

		// Connect with no resume token, and ensure that the peer ID is set to a
		// random value.
		originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
		require.NoError(t, err)
		originalPeerID := testutil.TryReceive(ctx, t, resumeTokenProvider.generateCalls)
		require.NotEqual(t, uuid.Nil, originalPeerID)

		// Connect with a valid resume token, and ensure that the peer ID is set to
		// the stored value.
		clock.Advance(time.Second)
		newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
		require.NoError(t, err)
		verifiedToken := testutil.TryReceive(ctx, t, resumeTokenProvider.verifyCalls)
		require.Equal(t, originalResumeToken, verifiedToken)
		newPeerID := testutil.TryReceive(ctx, t, resumeTokenProvider.generateCalls)
		require.Equal(t, originalPeerID, newPeerID)
		require.NotEqual(t, originalResumeToken, newResumeToken)

		// Connect with an invalid resume token, and ensure that the request is
		// rejected.
		clock.Advance(time.Second)
		_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
		require.Error(t, err)
		var sdkErr *codersdk.Error
		require.ErrorAs(t, err, &sdkErr)
		require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
		require.Len(t, sdkErr.Validations, 1)
		require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
		verifiedToken = testutil.TryReceive(ctx, t, resumeTokenProvider.verifyCalls)
		require.Equal(t, "invalid", verifiedToken)

		// A rejected connection must not have generated a new resume token.
		select {
		case <-resumeTokenProvider.generateCalls:
			t.Fatal("unexpected peer ID in channel")
		default:
		}
	})

	t.Run("BadJWT", func(t *testing.T) {
		t.Parallel()

		logger := testutil.Logger(t)
		clock := quartz.NewMock(t)
		resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
		// Check the error before the key is used to build the key manager.
		require.NoError(t, err)
		mgr := jwtutils.StaticKey{
			ID:  uuid.New().String(),
			Key: resumeTokenSigningKey[:],
		}
		resumeTokenProvider := newResumeTokenRecordingProvider(
			t,
			tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour),
		)
		client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
			Coordinator:                    tailnet.NewCoordinator(logger),
			CoordinatorResumeTokenProvider: resumeTokenProvider,
		})
		defer closer.Close()
		user := coderdtest.CreateFirstUser(t, client)

		// Create a workspace with an agent. No need to connect it since clients can
		// still connect to the coordinator while the agent isn't connected.
		r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
			OrganizationID: user.OrganizationID,
			OwnerID:        user.UserID,
		}).WithAgent().Do()
		agentTokenUUID, err := uuid.Parse(r.AgentToken)
		require.NoError(t, err)
		ctx := testutil.Context(t, testutil.WaitLong)
		agentAndBuild, err := api.Database.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID)
		require.NoError(t, err)

		// Connect with no resume token, and ensure that the peer ID is set to a
		// random value.
		originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
		require.NoError(t, err)
		originalPeerID := testutil.TryReceive(ctx, t, resumeTokenProvider.generateCalls)
		require.NotEqual(t, uuid.Nil, originalPeerID)

		// Connect with an outdated token, and ensure that the peer ID is set to a
		// random value. We don't want to fail requests just because
		// a user got unlucky during a deployment upgrade.
		outdatedToken := generateBadJWT(t, jwtutils.RegisteredClaims{
			Subject: originalPeerID.String(),
			Expiry:  jwt.NewNumericDate(clock.Now().Add(time.Minute)),
		})

		clock.Advance(time.Second)
		newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, outdatedToken)
		require.NoError(t, err)
		verifiedToken := testutil.TryReceive(ctx, t, resumeTokenProvider.verifyCalls)
		require.Equal(t, outdatedToken, verifiedToken)
		newPeerID := testutil.TryReceive(ctx, t, resumeTokenProvider.generateCalls)
		require.NotEqual(t, originalPeerID, newPeerID)
		require.NotEqual(t, originalResumeToken, newResumeToken)
	})
}
// connectToCoordinatorAndFetchResumeToken connects to the tailnet coordinator
// for agentID over a websocket, optionally presenting resumeToken, and asks
// the server for a fresh resume token via the RefreshResumeToken RPC.
//
// It returns the new token on success. If the websocket handshake is rejected
// with an HTTP response, the response body is decoded as an SDK error and
// returned (wrapped). If the dial fails before any response is received, the
// dial error itself is returned (wrapped). The websocket is closed before the
// function returns.
func connectToCoordinatorAndFetchResumeToken(ctx context.Context, logger slog.Logger, sdkClient *codersdk.Client, agentID uuid.UUID, resumeToken string) (string, error) {
	u, err := sdkClient.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", agentID))
	if err != nil {
		return "", xerrors.Errorf("parse URL: %w", err)
	}
	q := u.Query()
	q.Set("version", "2.0")
	if resumeToken != "" {
		q.Set("resume_token", resumeToken)
	}
	u.RawQuery = q.Encode()

	//nolint:bodyclose
	wsConn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
		HTTPHeader: http.Header{
			"Coder-Session-Token": []string{sdkClient.SessionToken()},
		},
	})
	if err != nil {
		// resp is nil when the dial fails before an HTTP response is received
		// (e.g. a transport error), so guard the dereference to avoid a panic
		// that would mask the real failure.
		if resp != nil && resp.StatusCode != http.StatusSwitchingProtocols {
			err = codersdk.ReadBodyAsError(resp)
		}
		return "", xerrors.Errorf("websocket dial: %w", err)
	}
	defer wsConn.Close(websocket.StatusNormalClosure, "done")

	// Send a request to the server to ensure that we're plumbed all the way
	// through.
	rpcClient, err := tailnet.NewDRPCClient(
		websocket.NetConn(ctx, wsConn, websocket.MessageBinary),
		logger,
	)
	if err != nil {
		return "", xerrors.Errorf("new dRPC client: %w", err)
	}

	// Fetch a resume token.
	newResumeToken, err := rpcClient.RefreshResumeToken(ctx, &tailnetproto.RefreshResumeTokenRequest{})
	if err != nil {
		return "", xerrors.Errorf("fetch resume token: %w", err)
	}
	return newResumeToken.Token, nil
}
// TestWorkspaceAgentTailnetDirectDisabled enables the DERP BlockDirect
// deployment option and verifies its effect end to end: the agent manifest and
// the client connection info both report DisableDirectConnections, the DERP
// map served to clients contains no STUN nodes, and a ping to the agent is not
// peer-to-peer.
func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
	t.Parallel()

	deployment := coderdtest.DeploymentValues(t)
	require.NoError(t, deployment.DERP.Config.BlockDirect.Set("true"))
	require.True(t, deployment.DERP.Config.BlockDirect.Value())

	client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
		DeploymentValues: deployment,
	})
	owner := coderdtest.CreateFirstUser(t, client)
	build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: owner.OrganizationID,
		OwnerID:        owner.UserID,
	}).WithAgent().Do()
	ctx := testutil.Context(t, testutil.WaitLong)

	// Verify that the manifest has DisableDirectConnections set to true.
	agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(build.AgentToken))
	rpcConn, err := agentClient.ConnectRPC(ctx)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, rpcConn.Close())
	}()
	manifest := requireGetManifest(ctx, t, agentproto.NewDRPCAgentClient(rpcConn))
	require.True(t, manifest.DisableDirectConnections)

	// Start a real agent and wait for it so that it can be dialed below.
	_ = agenttest.New(t, client.URL, build.AgentToken)
	resources := coderdtest.AwaitWorkspaceAgents(t, client, build.Workspace.ID)
	agentID := resources[0].Agents[0].ID

	// Verify that the connection data has no STUN ports and
	// DisableDirectConnections set to true.
	connectionPath := fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID)
	res, err := client.Request(ctx, http.MethodGet, connectionPath, nil)
	require.NoError(t, err)
	defer res.Body.Close()
	require.Equal(t, http.StatusOK, res.StatusCode)

	var connInfo workspacesdk.AgentConnectionInfo
	require.NoError(t, json.NewDecoder(res.Body).Decode(&connInfo))
	require.True(t, connInfo.DisableDirectConnections)
	for _, region := range connInfo.DERPMap.Regions {
		t.Logf("region %s (%v)", region.RegionCode, region.EmbeddedRelay)
		for _, node := range region.Nodes {
			t.Logf(" node %s (stun %d)", node.Name, node.STUNPort)
			require.EqualValues(t, -1, node.STUNPort)
			// tailnet.NewDERPMap() will create nodes with "stun" in the name,
			// but not if direct is disabled.
			require.NotContains(t, node.Name, "stun")
			require.False(t, node.STUNOnly)
		}
	}

	// Dial the agent and confirm traffic is relayed rather than direct.
	dialOpts := &workspacesdk.DialAgentOptions{
		Logger: testutil.Logger(t).Named("client"),
	}
	conn, err := workspacesdk.New(client).DialAgent(ctx, agentID, dialOpts)
	require.NoError(t, err)
	defer conn.Close()

	require.True(t, conn.AwaitReachable(ctx))
	_, p2p, _, err := conn.Ping(ctx)
	require.NoError(t, err)
	require.False(t, p2p)
}
// fakeListeningPortsGetter is a test double that returns a configurable list
// of listening ports. The embedded mutex guards ports.
type fakeListeningPortsGetter struct {
sync.Mutex
// ports is the list returned (as a copy) by GetListeningPorts.
ports []codersdk.WorkspaceAgentListeningPort
}
// GetListeningPorts returns a copy of the currently configured ports. The
// error is always nil.
func (g *fakeListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
	g.Lock()
	snapshot := slices.Clone(g.ports)
	g.Unlock()
	return snapshot, nil
}
// setPorts replaces the configured ports with a copy of ports. Calling it with
// no arguments clears the list.
func (g *fakeListeningPortsGetter) setPorts(ports ...codersdk.WorkspaceAgentListeningPort) {
	replacement := slices.Clone(ports)
	g.Lock()
	g.ports = replacement
	g.Unlock()
}
// TestWorkspaceAgentListeningPorts checks the listening-ports endpoint against
// an agent whose port list is supplied by fakeListeningPortsGetter.
//
// The "OK_*" subtests verify that a reported port shows up in the response and
// disappears once removed, with and without BlockDirect enabled. The "Filter"
// subtest verifies that nothing is returned when one port belongs to a
// workspace app and the other is filteredPort (presumably excluded by the
// server's ignore list; confirm against the handler).
func TestWorkspaceAgentListeningPorts(t *testing.T) {
	t.Parallel()

	var (
		testPort = codersdk.WorkspaceAgentListeningPort{
			Network:     "tcp",
			ProcessName: "test-app",
			Port:        44762,
		}
		filteredPort = codersdk.WorkspaceAgentListeningPort{
			Network:     "tcp",
			ProcessName: "postgres",
			Port:        5432,
		}
	)

	// setup starts a server and a connected agent that serves its listening
	// ports from the returned fake. apps are attached to the agent.
	setup := func(t *testing.T, apps []*proto.App, dv *codersdk.DeploymentValues) (*codersdk.Client, uuid.UUID, *fakeListeningPortsGetter) {
		t.Helper()

		client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
			DeploymentValues: dv,
		})
		portsGetter := &fakeListeningPortsGetter{}
		owner := coderdtest.CreateFirstUser(t, client)

		withApps := func(agents []*proto.Agent) []*proto.Agent {
			agents[0].Apps = apps
			return agents
		}
		build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        owner.UserID,
		}).WithAgent(withApps).Do()

		_ = agenttest.New(t, client.URL, build.AgentToken, func(o *agent.Options) {
			o.ListeningPortsGetter = portsGetter
		})
		resources := coderdtest.NewWorkspaceAgentWaiter(t, client, build.Workspace.ID).Wait()
		// #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535)
		return client, resources[0].Agents[0].ID, portsGetter
	}

	type deploymentCase struct {
		name  string
		setDV func(t *testing.T, dv *codersdk.DeploymentValues)
	}
	okCases := []deploymentCase{
		{
			name:  "Mainline",
			setDV: func(*testing.T, *codersdk.DeploymentValues) {},
		},
		{
			name: "BlockDirect",
			setDV: func(t *testing.T, dv *codersdk.DeploymentValues) {
				require.NoError(t, dv.DERP.Config.BlockDirect.Set("true"))
				require.True(t, dv.DERP.Config.BlockDirect.Value())
			},
		},
	}
	for _, tc := range okCases {
		t.Run("OK_"+tc.name, func(t *testing.T) {
			t.Parallel()

			dv := coderdtest.DeploymentValues(t)
			tc.setDV(t, dv)
			client, agentID, portsGetter := setup(t, nil, dv)

			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
			defer cancel()

			// List ports and ensure that the port we expect to see is there.
			portsGetter.setPorts(testPort)
			listed, err := client.WorkspaceAgentListeningPorts(ctx, agentID)
			require.NoError(t, err)
			require.Equal(t, []codersdk.WorkspaceAgentListeningPort{testPort}, listed.Ports)

			// Remove the port and check that the port is no longer in the response.
			portsGetter.setPorts()
			listed, err = client.WorkspaceAgentListeningPorts(ctx, agentID)
			require.NoError(t, err)
			require.Empty(t, listed.Ports)
		})
	}

	t.Run("Filter", func(t *testing.T) {
		t.Parallel()

		// Register testPort as a workspace app on the agent.
		appForTestPort := &proto.App{
			Slug: testPort.ProcessName,
			Url:  fmt.Sprintf("http://localhost:%d", testPort.Port),
		}
		client, agentID, portsGetter := setup(t, []*proto.App{appForTestPort}, nil)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		portsGetter.setPorts(testPort, filteredPort)
		listed, err := client.WorkspaceAgentListeningPorts(ctx, agentID)
		require.NoError(t, err)
		require.Empty(t, listed.Ports)
	})
}
// TestWorkspaceAgentContainers verifies listing containers through a workspace
// agent.
//
// The "Docker" subtest runs real containers and is skipped unless
// CODER_TEST_USE_DOCKER=1 is set. The "Mock" subtest replaces the container
// CLI with a gomock double and covers both a successful listing and an error
// from the CLI.
func TestWorkspaceAgentContainers(t *testing.T) {
	t.Parallel()

	// This test will not normally run in CI, but is kept here as a semi-manual
	// test for local development. Run it as follows:
	// CODER_TEST_USE_DOCKER=1 go test -run TestWorkspaceAgentContainers/Docker ./coderd
	t.Run("Docker", func(t *testing.T) {
		t.Parallel()
		if ctud, ok := os.LookupEnv("CODER_TEST_USE_DOCKER"); !ok || ctud != "1" {
			t.Skip("Set CODER_TEST_USE_DOCKER=1 to run this test")
		}

		pool, err := dockertest.NewPool("")
		require.NoError(t, err, "Could not connect to docker")
		// A unique label value lets the filtered listing select only our container.
		testLabels := map[string]string{
			"com.coder.test":  uuid.New().String(),
			"com.coder.empty": "",
		}
		ct, err := pool.RunWithOptions(&dockertest.RunOptions{
			Repository: "busybox",
			Tag:        "latest",
			Cmd:        []string{"sleep", "infinity"},
			Labels:     testLabels,
		}, func(config *docker.HostConfig) {
			config.AutoRemove = true
			config.RestartPolicy = docker.RestartPolicy{Name: "no"}
		})
		require.NoError(t, err, "Could not start test docker container")
		t.Cleanup(func() {
			assert.NoError(t, pool.Purge(ct), "Could not purge resource %q", ct.Container.Name)
		})

		// Start another container which we will expect to ignore.
		ct2, err := pool.RunWithOptions(&dockertest.RunOptions{
			Repository: "busybox",
			Tag:        "latest",
			Cmd:        []string{"sleep", "infinity"},
			Labels: map[string]string{
				"com.coder.test":  "ignoreme",
				"com.coder.empty": "",
			},
		}, func(config *docker.HostConfig) {
			config.AutoRemove = true
			config.RestartPolicy = docker.RestartPolicy{Name: "no"}
		})
		require.NoError(t, err, "Could not start second test docker container")
		t.Cleanup(func() {
			assert.NoError(t, pool.Purge(ct2), "Could not purge resource %q", ct2.Container.Name)
		})

		client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{})

		user := coderdtest.CreateFirstUser(t, client)
		r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: user.OrganizationID,
			OwnerID:        user.UserID,
		}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
			return agents
		}).Do()
		_ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) {
			o.Devcontainers = true
			o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions,
				agentcontainers.WithContainerLabelIncludeFilter("this.label.does.not.exist.ignore.devcontainers", "true"),
			)
		})
		resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait()
		require.Len(t, resources, 1, "expected one resource")
		require.Len(t, resources[0].Agents, 1, "expected one agent")
		agentID := resources[0].Agents[0].ID

		ctx := testutil.Context(t, testutil.WaitLong)

		// If we filter by testLabels, we should only get one container back.
		res, err := client.WorkspaceAgentListContainers(ctx, agentID, testLabels)
		require.NoError(t, err, "failed to list containers filtered by test label")
		require.Len(t, res.Containers, 1, "expected exactly one container")
		assert.Equal(t, ct.Container.ID, res.Containers[0].ID, "expected container ID to match")
		assert.Equal(t, "busybox:latest", res.Containers[0].Image, "expected container image to match")
		assert.Equal(t, ct.Container.Config.Labels, res.Containers[0].Labels, "expected container labels to match")
		assert.Equal(t, strings.TrimPrefix(ct.Container.Name, "/"), res.Containers[0].FriendlyName, "expected container name to match")
		assert.True(t, res.Containers[0].Running, "expected container to be running")
		assert.Equal(t, "running", res.Containers[0].Status, "expected container status to be running")

		// List all containers and ensure we get at least both (there may be more).
		res, err = client.WorkspaceAgentListContainers(ctx, agentID, nil)
		require.NoError(t, err, "failed to list all containers")
		require.NotEmpty(t, res.Containers, "expected to find containers")
		var found []string
		for _, c := range res.Containers {
			found = append(found, c.ID)
		}
		require.Contains(t, found, ct.Container.ID, "expected to find first container without label filter")
		require.Contains(t, found, ct2.Container.ID, "expected to find second container without label filter")
	})

	// This test will normally run in CI. It uses a mock implementation of
	// agentcontainers.Lister instead of introducing a hard dependency on Docker.
	t.Run("Mock", func(t *testing.T) {
		t.Parallel()

		// begin test fixtures
		testLabels := map[string]string{
			"com.coder.test": uuid.New().String(),
		}
		testResponse := codersdk.WorkspaceAgentListContainersResponse{
			Containers: []codersdk.WorkspaceAgentContainer{
				{
					ID:           uuid.NewString(),
					CreatedAt:    dbtime.Now(),
					FriendlyName: testutil.GetRandomName(t),
					Image:        "busybox:latest",
					Labels:       testLabels,
					Running:      true,
					Status:       "running",
					Ports: []codersdk.WorkspaceAgentContainerPort{
						{
							Network:  "tcp",
							Port:     80,
							HostIP:   "0.0.0.0",
							HostPort: 8000,
						},
					},
					Volumes: map[string]string{
						"/host": "/container",
					},
				},
			},
		}
		// end test fixtures

		for _, tc := range []struct {
			name string
			// setupMock registers expectations on the mock CLI and returns the
			// response and error the test expects from the API.
			setupMock func(*acmock.MockContainerCLI) (codersdk.WorkspaceAgentListContainersResponse, error)
		}{
			{
				name: "test response",
				setupMock: func(mcl *acmock.MockContainerCLI) (codersdk.WorkspaceAgentListContainersResponse, error) {
					mcl.EXPECT().List(gomock.Any()).Return(testResponse, nil).AnyTimes()
					return testResponse, nil
				},
			},
			{
				name: "error response",
				setupMock: func(mcl *acmock.MockContainerCLI) (codersdk.WorkspaceAgentListContainersResponse, error) {
					mcl.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{}, assert.AnError).AnyTimes()
					return codersdk.WorkspaceAgentListContainersResponse{}, assert.AnError
				},
			},
		} {
			t.Run(tc.name, func(t *testing.T) {
				t.Parallel()

				ctrl := gomock.NewController(t)
				mcl := acmock.NewMockContainerCLI(ctrl)
				expected, expectedErr := tc.setupMock(mcl)
				logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
				client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
					Logger: &logger,
				})
				user := coderdtest.CreateFirstUser(t, client)
				r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
					OrganizationID: user.OrganizationID,
					OwnerID:        user.UserID,
				}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
					return agents
				}).Do()
				_ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) {
					o.Logger = logger.Named("agent")
					o.Devcontainers = true
					o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions,
						agentcontainers.WithContainerCLI(mcl),
						agentcontainers.WithContainerLabelIncludeFilter("this.label.does.not.exist.ignore.devcontainers", "true"),
					)
				})
				resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait()
				require.Len(t, resources, 1, "expected one resource")
				require.Len(t, resources[0].Agents, 1, "expected one agent")
				agentID := resources[0].Agents[0].ID

				ctx := testutil.Context(t, testutil.WaitLong)

				// List containers and ensure we get the expected mocked response.
				res, err := client.WorkspaceAgentListContainers(ctx, agentID, nil)
				if expectedErr != nil {
					// Assert the error is present first so that a missing error
					// fails the test cleanly instead of panicking on err.Error().
					require.Error(t, err, "expected an error")
					require.Contains(t, err.Error(), expectedErr.Error(), "unexpected error")
					require.Empty(t, res, "expected empty response")
				} else {
					require.NoError(t, err, "failed to list all containers")
					if diff := cmp.Diff(expected, res); diff != "" {
						t.Fatalf("unexpected response (-want +got):\n%s", diff)
					}
				}
			})
		}
	})
}
// TestWatchWorkspaceAgentDevcontainers verifies that watching an agent's
// containers yields the initial state on connection and then one update per
// stage as the mocked container CLI reports a changing set of containers.
// A mock clock drives the agent's updater loop so that each stage is
// triggered deterministically.
func TestWatchWorkspaceAgentDevcontainers(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
mClock = quartz.NewMock(t)
// Trap the ticker named "updaterLoop" so the test can wait for it to be
// created before driving the clock.
updaterTickerTrap = mClock.Trap().TickerFunc("updaterLoop")
mCtrl = gomock.NewController(t)
mCCLI = acmock.NewMockContainerCLI(mCtrl)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{Logger: &logger})
user = coderdtest.CreateFirstUser(t, client)
r = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
return agents
}).Do()
fakeContainer1 = codersdk.WorkspaceAgentContainer{
ID: "container1",
CreatedAt: dbtime.Now(),
FriendlyName: "container1",
Image: "busybox:latest",
Labels: map[string]string{
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project1",
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project1/.devcontainer/devcontainer.json",
},
Running: true,
Status: "running",
}
// NOTE(review): fakeContainer2 reuses the ID "container1" while its
// FriendlyName is "container2". This looks like a copy-paste slip;
// confirm whether the ID should be "container2".
fakeContainer2 = codersdk.WorkspaceAgentContainer{
ID: "container1",
CreatedAt: dbtime.Now(),
FriendlyName: "container2",
Image: "busybox:latest",
Labels: map[string]string{
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project2",
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project2/.devcontainer/devcontainer.json",
},
Running: true,
Status: "running",
}
)
// Each stage pairs the containers the mocked CLI will list with the
// response the watcher is expected to deliver for that listing.
stages := []struct {
containers []codersdk.WorkspaceAgentContainer
expected codersdk.WorkspaceAgentListContainersResponse
}{
{
containers: []codersdk.WorkspaceAgentContainer{fakeContainer1},
expected: codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{fakeContainer1},
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
{
Name: "project1",
WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel],
ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel],
Status: "running",
Container: &fakeContainer1,
},
},
},
},
{
containers: []codersdk.WorkspaceAgentContainer{fakeContainer1, fakeContainer2},
expected: codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{fakeContainer1, fakeContainer2},
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
{
Name: "project1",
WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel],
ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel],
Status: "running",
Container: &fakeContainer1,
},
{
Name: "project2",
WorkspaceFolder: fakeContainer2.Labels[agentcontainers.DevcontainerLocalFolderLabel],
ConfigPath: fakeContainer2.Labels[agentcontainers.DevcontainerConfigFileLabel],
Status: "running",
Container: &fakeContainer2,
},
},
},
},
// fakeContainer1 is no longer listed: its devcontainer is still expected
// in the response, but with an empty name, status "stopped" and no
// container attached.
{
containers: []codersdk.WorkspaceAgentContainer{fakeContainer2},
expected: codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{fakeContainer2},
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
{
Name: "",
WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel],
ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel],
Status: "stopped",
Container: nil,
},
{
Name: "project2",
WorkspaceFolder: fakeContainer2.Labels[agentcontainers.DevcontainerLocalFolderLabel],
ConfigPath: fakeContainer2.Labels[agentcontainers.DevcontainerConfigFileLabel],
Status: "running",
Container: &fakeContainer2,
},
},
},
},
}
// Set up initial state for immediate send on connection
mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: stages[0].containers}, nil)
mCCLI.EXPECT().DetectArchitecture(gomock.Any(), gomock.Any()).Return("<none>", nil).AnyTimes()
_ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) {
o.Logger = logger.Named("agent")
o.Devcontainers = true
o.DevcontainerAPIOptions = []agentcontainers.Option{
agentcontainers.WithClock(mClock),
agentcontainers.WithContainerCLI(mCCLI),
agentcontainers.WithWatcher(watcher.NewNoop()),
}
})
resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait()
require.Len(t, resources, 1, "expected one resource")
require.Len(t, resources[0].Agents, 1, "expected one agent")
agentID := resources[0].Agents[0].ID
// Wait until the trapped "updaterLoop" ticker call is reached, then let it
// proceed.
updaterTickerTrap.MustWait(ctx).MustRelease(ctx)
defer updaterTickerTrap.Close()
containers, closer, err := client.WatchWorkspaceAgentContainers(ctx, agentID)
require.NoError(t, err)
defer func() {
closer.Close()
}()
// Read initial state sent immediately on connection
var got codersdk.WorkspaceAgentListContainersResponse
select {
case <-ctx.Done():
case got = <-containers:
}
// A cancelled context here means no update arrived in time.
require.NoError(t, ctx.Err())
require.Equal(t, stages[0].expected.Containers, got.Containers)
require.Len(t, got.Devcontainers, len(stages[0].expected.Devcontainers))
// Devcontainers are compared field by field; only the fields listed
// below are asserted.
for j, expectedDev := range stages[0].expected.Devcontainers {
gotDev := got.Devcontainers[j]
require.Equal(t, expectedDev.Name, gotDev.Name)
require.Equal(t, expectedDev.WorkspaceFolder, gotDev.WorkspaceFolder)
require.Equal(t, expectedDev.ConfigPath, gotDev.ConfigPath)
require.Equal(t, expectedDev.Status, gotDev.Status)
require.Equal(t, expectedDev.Container, gotDev.Container)
}
// Process remaining stages through updater loop
for i, stage := range stages[1:] {
// Register the listing for this stage before advancing the clock so the
// next updater tick observes it.
mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: stage.containers}, nil)
_, aw := mClock.AdvanceNext()
aw.MustWait(ctx)
var got codersdk.WorkspaceAgentListContainersResponse
select {
case <-ctx.Done():
case got = <-containers:
}
require.NoError(t, ctx.Err())
// i indexes stages[1:], so stages[i+1] is the same element as stage.
require.Equal(t, stages[i+1].expected.Containers, got.Containers)
require.Len(t, got.Devcontainers, len(stages[i+1].expected.Devcontainers))
for j, expectedDev := range stages[i+1].expected.Devcontainers {
gotDev := got.Devcontainers[j]
require.Equal(t, expectedDev.Name, gotDev.Name)
require.Equal(t, expectedDev.WorkspaceFolder, gotDev.WorkspaceFolder)
require.Equal(t, expectedDev.ConfigPath, gotDev.ConfigPath)
require.Equal(t, expectedDev.Status, gotDev.Status)
require.Equal(t, expectedDev.Container, gotDev.Container)
}
}
})
}
// TestWorkspaceAgentRecreateDevcontainer drives the recreate-devcontainer
// endpoint against an agent whose container and devcontainer CLIs are gomock
// doubles. It covers a successful recreate, which must invoke the devcontainer
// CLI's Up exactly once, and a request for an unknown devcontainer ID, which
// must return 404.
func TestWorkspaceAgentRecreateDevcontainer(t *testing.T) {
	t.Parallel()

	t.Run("Mock", func(t *testing.T) {
		t.Parallel()

		workspaceFolder := t.TempDir()
		configFile := filepath.Join(workspaceFolder, ".devcontainer", "devcontainer.json")
		devcontainerID := uuid.New()

		// Create a container that would be associated with the devcontainer
		devContainer := codersdk.WorkspaceAgentContainer{
			ID:           uuid.NewString(),
			CreatedAt:    dbtime.Now(),
			FriendlyName: testutil.GetRandomName(t),
			Image:        "busybox:latest",
			Labels: map[string]string{
				agentcontainers.DevcontainerLocalFolderLabel: workspaceFolder,
				agentcontainers.DevcontainerConfigFileLabel:  configFile,
			},
			Running: true,
			Status:  "running",
		}
		devcontainer := codersdk.WorkspaceAgentDevcontainer{
			ID:              devcontainerID,
			Name:            "test-devcontainer",
			WorkspaceFolder: workspaceFolder,
			ConfigPath:      configFile,
			Status:          codersdk.WorkspaceAgentDevcontainerStatusRunning,
			Container:       &devContainer,
		}

		type recreateCase struct {
			name            string
			devcontainerID  string
			devcontainers   []codersdk.WorkspaceAgentDevcontainer
			containers      []codersdk.WorkspaceAgentContainer
			expectRecreate  bool
			expectErrorCode int
		}
		cases := []recreateCase{
			{
				name:           "Recreate",
				devcontainerID: devcontainerID.String(),
				devcontainers:  []codersdk.WorkspaceAgentDevcontainer{devcontainer},
				containers:     []codersdk.WorkspaceAgentContainer{devContainer},
				expectRecreate: true,
			},
			{
				name:            "Devcontainer does not exist",
				devcontainerID:  uuid.NewString(),
				devcontainers:   nil,
				containers:      []codersdk.WorkspaceAgentContainer{},
				expectErrorCode: http.StatusNotFound,
			},
		}

		for _, tc := range cases {
			t.Run(tc.name, func(t *testing.T) {
				t.Parallel()

				ctx := testutil.Context(t, testutil.WaitLong)
				mockCtrl := gomock.NewController(t)
				containerCLI := acmock.NewMockContainerCLI(mockCtrl)
				devcontainerCLI := acmock.NewMockDevcontainerCLI(mockCtrl)
				logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
				client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
					Logger: &logger,
				})
				owner := coderdtest.CreateFirstUser(t, client)
				build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
					OrganizationID: owner.OrganizationID,
					OwnerID:        owner.UserID,
				}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
					return agents
				}).Do()

				listing := codersdk.WorkspaceAgentListContainersResponse{
					Containers: tc.containers,
				}
				containerCLI.EXPECT().List(gomock.Any()).Return(listing, nil).AnyTimes()

				// upCalled is closed by the mocked Up call; it stays nil when no
				// recreate is expected.
				var upCalled chan struct{}
				if tc.expectRecreate {
					upCalled = make(chan struct{})

					// DetectArchitecture always returns "<none>" for this test to disable agent injection.
					containerCLI.EXPECT().DetectArchitecture(gomock.Any(), devContainer.ID).Return("<none>", nil).AnyTimes()
					devcontainerCLI.EXPECT().ReadConfig(gomock.Any(), workspaceFolder, configFile, gomock.Any()).Return(agentcontainers.DevcontainerConfig{}, nil).AnyTimes()
					devcontainerCLI.EXPECT().Up(gomock.Any(), workspaceFolder, configFile, gomock.Any()).
						DoAndReturn(func(_ context.Context, _, _ string, _ ...agentcontainers.DevcontainerCLIUpOptions) (string, error) {
							close(upCalled)
							return "someid", nil
						}).Times(1)
				}

				apiOpts := []agentcontainers.Option{
					agentcontainers.WithContainerCLI(containerCLI),
					agentcontainers.WithDevcontainerCLI(devcontainerCLI),
					agentcontainers.WithWatcher(watcher.NewNoop()),
				}
				if tc.devcontainers != nil {
					apiOpts = append(apiOpts, agentcontainers.WithDevcontainers(tc.devcontainers, nil))
				}

				_ = agenttest.New(t, client.URL, build.AgentToken, func(o *agent.Options) {
					o.Logger = logger.Named("agent")
					o.Devcontainers = true
					o.DevcontainerAPIOptions = apiOpts
				})
				resources := coderdtest.NewWorkspaceAgentWaiter(t, client, build.Workspace.ID).Wait()
				require.Len(t, resources, 1, "expected one resource")
				require.Len(t, resources[0].Agents, 1, "expected one agent")
				agentID := resources[0].Agents[0].ID

				_, err := client.WorkspaceAgentRecreateDevcontainer(ctx, agentID, tc.devcontainerID)
				if tc.expectErrorCode <= 0 {
					// Success path: the request succeeds and Up must be observed.
					require.NoError(t, err, "failed to recreate devcontainer")
					testutil.TryReceive(ctx, t, upCalled)
					return
				}

				// Failure path: the error must be an SDK error with the expected status.
				cerr, ok := codersdk.AsError(err)
				require.True(t, ok, "expected error to be a coder error")
				assert.Equal(t, tc.expectErrorCode, cerr.StatusCode())
			})
		}
	})
}
// TestWorkspaceAgentDeleteDevcontainer verifies the devcontainer delete
// endpoint for a workspace agent backed by mocked container and devcontainer
// CLIs. Cases: a successful delete by the owner, a delete attempted by a
// different user, and a delete attempted while no agent has been started.
func TestWorkspaceAgentDeleteDevcontainer(t *testing.T) {
	t.Parallel()

	const workspaceFolder = "/home/coder/coder"
	configFile := filepath.Join(workspaceFolder, ".devcontainer", "devcontainer.json")

	// newMocks returns the gomock controller, both mocked CLIs, a single
	// running container, the devcontainer that points at it, and the
	// agentcontainers options that wire all of them into an agent.
	newMocks := func(t *testing.T) (
		*gomock.Controller,
		*acmock.MockContainerCLI,
		*acmock.MockDevcontainerCLI,
		codersdk.WorkspaceAgentContainer,
		codersdk.WorkspaceAgentDevcontainer,
		[]agentcontainers.Option,
	) {
		container := codersdk.WorkspaceAgentContainer{
			ID:           uuid.NewString(),
			CreatedAt:    dbtime.Now(),
			FriendlyName: testutil.GetRandomName(t),
			Image:        "busybox:latest",
			Labels: map[string]string{
				agentcontainers.DevcontainerLocalFolderLabel: workspaceFolder,
				agentcontainers.DevcontainerConfigFileLabel:  configFile,
			},
			Running: true,
			Status:  "running",
		}
		dc := codersdk.WorkspaceAgentDevcontainer{
			ID:              uuid.New(),
			Name:            "test-devcontainer",
			WorkspaceFolder: workspaceFolder,
			ConfigPath:      configFile,
			Status:          codersdk.WorkspaceAgentDevcontainerStatusRunning,
			Container:       &container,
		}

		ctrl := gomock.NewController(t)
		containerCLI := acmock.NewMockContainerCLI(ctrl)
		devcontainerCLI := acmock.NewMockDevcontainerCLI(ctrl)

		// These calls are allowed any number of times; only Stop and Remove
		// (registered by the caller) are counted.
		containerCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{
			Containers: []codersdk.WorkspaceAgentContainer{container},
		}, nil).AnyTimes()
		containerCLI.EXPECT().DetectArchitecture(gomock.Any(), container.ID).Return("<none>", nil).AnyTimes()
		devcontainerCLI.EXPECT().ReadConfig(gomock.Any(), workspaceFolder, configFile, gomock.Any()).Return(agentcontainers.DevcontainerConfig{}, nil).AnyTimes()

		opts := []agentcontainers.Option{
			agentcontainers.WithContainerCLI(containerCLI),
			agentcontainers.WithDevcontainerCLI(devcontainerCLI),
			agentcontainers.WithWatcher(watcher.NewNoop()),
			agentcontainers.WithDevcontainers([]codersdk.WorkspaceAgentDevcontainer{dc}, nil),
		}
		return ctrl, containerCLI, devcontainerCLI, container, dc, opts
	}

	cases := []struct {
		name           string
		startAgent     bool
		useAnotherUser bool
		expectError    bool
		expectedStatus int
	}{
		{
			name:       "OK",
			startAgent: true,
		},
		{
			name:           "Forbidden",
			startAgent:     true,
			useAnotherUser: true,
			expectError:    true,
			expectedStatus: http.StatusNotFound,
		},
		{
			name:           "AgentNotConnected",
			expectError:    true,
			expectedStatus: http.StatusBadRequest,
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ctx := testutil.Context(t, testutil.WaitLong)
			logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
			client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
				Logger: &logger,
			})
			owner := coderdtest.CreateFirstUser(t, client)
			build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
				OrganizationID: owner.OrganizationID,
				OwnerID:        owner.UserID,
			}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
				return agents
			}).Do()

			_, containerCLI, _, container, dc, apiOptions := newMocks(t)

			var agentID uuid.UUID
			if !tt.startAgent {
				// No agent is started, so read the agent ID from the
				// workspace's latest build instead of waiting for it.
				ws, err := client.Workspace(ctx, build.Workspace.ID)
				require.NoError(t, err, "failed to get workspace")
				require.Len(t, ws.LatestBuild.Resources, 1, "expected one resource")
				require.Len(t, ws.LatestBuild.Resources[0].Agents, 1, "expected one agent")
				agentID = ws.LatestBuild.Resources[0].Agents[0].ID
			} else {
				_ = agenttest.New(t, client.URL, build.AgentToken, func(o *agent.Options) {
					o.Logger = logger.Named("agent")
					o.Devcontainers = true
					o.DevcontainerAPIOptions = apiOptions
				})
				resources := coderdtest.NewWorkspaceAgentWaiter(t, client, build.Workspace.ID).Wait()
				require.Len(t, resources, 1, "expected one resource")
				require.Len(t, resources[0].Agents, 1, "expected one agent")
				agentID = resources[0].Agents[0].ID

				// A successful delete must stop and remove the container
				// exactly once each.
				if !tt.expectError {
					containerCLI.EXPECT().Stop(gomock.Any(), container.ID).Return(nil).Times(1)
					containerCLI.EXPECT().Remove(gomock.Any(), container.ID).Return(nil).Times(1)
				}
			}

			caller := client
			if tt.useAnotherUser {
				caller, _ = coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
			}

			err := caller.WorkspaceAgentDeleteDevcontainer(ctx, agentID, dc.ID.String())
			if !tt.expectError {
				require.NoError(t, err, "failed to delete devcontainer")
				return
			}

			require.Error(t, err)
			var sdkErr *codersdk.Error
			require.ErrorAs(t, err, &sdkErr)
			require.Equal(t, tt.expectedStatus, sdkErr.StatusCode())
		})
	}
}
// TestWorkspaceAgentAppHealth drives BatchUpdateAppHealths over the agent dRPC
// connection for two apps (one without and one with a healthcheck) and checks
// which updates are rejected and how accepted updates show up in the manifest.
func TestWorkspaceAgentAppHealth(t *testing.T) {
	t.Parallel()

	client, db := coderdtest.NewWithDatabase(t, nil)
	owner := coderdtest.CreateFirstUser(t, client)

	build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: owner.OrganizationID,
		OwnerID:        owner.UserID,
	}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
		agents[0].Apps = []*proto.App{
			{
				// App 0: no healthcheck configured.
				Slug:    "code-server",
				Command: "some-command",
				Url:     "http://localhost:3000",
				Icon:    "/code.svg",
			},
			{
				// App 1: healthcheck configured.
				Slug:        "code-server-2",
				DisplayName: "code-server-2",
				Command:     "some-command",
				Url:         "http://localhost:3000",
				Icon:        "/code.svg",
				Healthcheck: &proto.Healthcheck{
					Url:       "http://localhost:3000",
					Interval:  5,
					Threshold: 6,
				},
			},
		}
		return agents
	}).Do()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	conn, err := agentsdk.New(client.URL, agentsdk.WithFixedToken(build.AgentToken)).ConnectRPC(ctx)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	aAPI := agentproto.NewDRPCAgentClient(conn)

	manifest := requireGetManifest(ctx, t, aAPI)
	require.EqualValues(t, codersdk.WorkspaceAppHealthDisabled, manifest.Apps[0].Health)
	require.EqualValues(t, codersdk.WorkspaceAppHealthInitializing, manifest.Apps[1].Health)

	// sendUpdates issues one batch request carrying the given updates; called
	// without arguments it sends a request with no updates at all.
	sendUpdates := func(updates ...*agentproto.BatchUpdateAppHealthRequest_HealthUpdate) error {
		_, err := aAPI.BatchUpdateAppHealths(ctx, &agentproto.BatchUpdateAppHealthRequest{
			Updates: updates,
		})
		return err
	}
	// setHealth reports a single health value for the app at appIndex in the
	// most recently fetched manifest.
	setHealth := func(appIndex int, health agentproto.AppHealth) error {
		return sendUpdates(&agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
			Id:     manifest.Apps[appIndex].ID[:],
			Health: health,
		})
	}

	// An empty batch is accepted.
	require.NoError(t, sendUpdates())

	// Updating an app whose healthcheck is disabled is rejected.
	require.Error(t, setHealth(0, agentproto.AppHealth_INITIALIZING))

	// An out-of-range health value is rejected.
	require.Error(t, setHealth(1, agentproto.AppHealth(99)))

	// Healthy is accepted and visible in the manifest.
	require.NoError(t, setHealth(1, agentproto.AppHealth_HEALTHY))
	manifest = requireGetManifest(ctx, t, aAPI)
	require.EqualValues(t, codersdk.WorkspaceAppHealthHealthy, manifest.Apps[1].Health)

	// Unhealthy is accepted and visible in the manifest.
	require.NoError(t, setHealth(1, agentproto.AppHealth_UNHEALTHY))
	manifest = requireGetManifest(ctx, t, aAPI)
	require.EqualValues(t, codersdk.WorkspaceAppHealthUnhealthy, manifest.Apps[1].Health)
}
// TestWorkspaceAgentPostLogSource posts the same log source twice through the
// agent client and checks that both responses echo the request and carry a
// non-zero agent ID and creation time.
func TestWorkspaceAgentPostLogSource(t *testing.T) {
	t.Parallel()

	t.Run("OK", func(t *testing.T) {
		t.Parallel()

		client, db := coderdtest.NewWithDatabase(t, nil)
		owner := coderdtest.CreateFirstUser(t, client)
		ctx := testutil.Context(t, testutil.WaitShort)

		build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        owner.UserID,
		}).WithAgent().Do()

		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(build.AgentToken))

		source := agentsdk.PostLogSourceRequest{
			ID:          uuid.New(),
			DisplayName: "colin logs",
			Icon:        "/emojis/1f42e.png",
		}

		// The second iteration repeats the identical request: the endpoint
		// should be idempotent.
		for attempt := 0; attempt < 2; attempt++ {
			created, err := agentClient.PostLogSource(ctx, source)
			require.NoError(t, err)
			assert.Equal(t, source.ID, created.ID)
			assert.Equal(t, source.DisplayName, created.DisplayName)
			assert.Equal(t, source.Icon, created.Icon)
			assert.NotZero(t, created.WorkspaceAgentID)
			assert.NotZero(t, created.CreatedAt)
		}
	})
}
// TestWorkspaceAgent_LifecycleState reports every lifecycle state through the
// agent dRPC API and checks that each valid state is reflected on the
// workspace's agents, while unknown or empty states fail conversion to their
// proto form.
func TestWorkspaceAgent_LifecycleState(t *testing.T) {
	t.Parallel()

	t.Run("Set", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		client, db := coderdtest.NewWithDatabase(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: user.OrganizationID,
			OwnerID:        user.UserID,
		}).WithAgent().Do()

		// Use the bounded test context rather than context.Background() so a
		// stuck request fails the test at the deadline.
		workspace, err := client.Workspace(ctx, r.Workspace.ID)
		require.NoError(t, err)
		// Before the agent reports anything, every agent is in "created".
		for _, res := range workspace.LatestBuild.Resources {
			for _, a := range res.Agents {
				require.Equal(t, codersdk.WorkspaceAgentLifecycleCreated, a.LifecycleState)
			}
		}

		ac := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
		conn, err := ac.ConnectRPC(ctx)
		require.NoError(t, err)
		defer func() {
			cErr := conn.Close()
			require.NoError(t, cErr)
		}()
		agentAPI := agentproto.NewDRPCAgentClient(conn)

		tests := []struct {
			state   codersdk.WorkspaceAgentLifecycle
			wantErr bool
		}{
			{codersdk.WorkspaceAgentLifecycleCreated, false},
			{codersdk.WorkspaceAgentLifecycleStarting, false},
			{codersdk.WorkspaceAgentLifecycleStartTimeout, false},
			{codersdk.WorkspaceAgentLifecycleStartError, false},
			{codersdk.WorkspaceAgentLifecycleReady, false},
			{codersdk.WorkspaceAgentLifecycleShuttingDown, false},
			{codersdk.WorkspaceAgentLifecycleShutdownTimeout, false},
			{codersdk.WorkspaceAgentLifecycleShutdownError, false},
			{codersdk.WorkspaceAgentLifecycleOff, false},
			{codersdk.WorkspaceAgentLifecycle("nonexistent_state"), true},
			{codersdk.WorkspaceAgentLifecycle(""), true},
		}
		//nolint:paralleltest // No race between setting the state and getting the workspace.
		for _, tt := range tests {
			t.Run(string(tt.state), func(t *testing.T) {
				state, err := agentsdk.ProtoFromLifecycleState(tt.state)
				if tt.wantErr {
					require.Error(t, err)
					return
				}

				_, err = agentAPI.UpdateLifecycle(ctx, &agentproto.UpdateLifecycleRequest{
					Lifecycle: &agentproto.Lifecycle{
						State:     state,
						ChangedAt: timestamppb.Now(),
					},
				})
				require.NoError(t, err, "post lifecycle state %q", tt.state)

				// Fetch into a subtest-local variable instead of overwriting
				// the parent's workspace, and avoid naming the loop variable
				// after the imported agent package.
				updated, err := client.Workspace(ctx, workspace.ID)
				require.NoError(t, err, "get workspace")
				for _, res := range updated.LatestBuild.Resources {
					for _, a := range res.Agents {
						require.Equal(t, tt.state, a.LifecycleState)
					}
				}
			})
		}
	})
}
// TestWorkspaceAgent_Metadata posts metadata results through the agent dRPC API
// and checks that a client watching the agent's metadata receives them. It
// covers a plain value, a second key, a result carrying an error, a value
// longer than maxValueLen, and a key that is not declared on the agent.
func TestWorkspaceAgent_Metadata(t *testing.T) {
t.Parallel()
// Configure the metadata batcher with a 100ms interval for this test.
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
MetadataBatcherOptions: []metadatabatcher.Option{
metadatabatcher.WithInterval(100 * time.Millisecond),
},
})
user := coderdtest.CreateFirstUser(t, client)
// Declare three metadata items (foo1, foo2, foo3) on the agent.
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Metadata = []*proto.Agent_Metadata{
{
DisplayName: "First Meta",
Key: "foo1",
Script: "echo hi",
Interval: 10,
Timeout: 3,
},
{
DisplayName: "Second Meta",
Key: "foo2",
Script: "echo howdy",
Interval: 10,
Timeout: 3,
},
{
DisplayName: "TooLong",
Key: "foo3",
Script: "echo howdy",
Interval: 10,
Timeout: 3,
},
}
return agents
}).Do()
workspace, err := client.Workspace(context.Background(), r.Workspace.ID)
require.NoError(t, err)
// Every agent starts in the "created" lifecycle state.
for _, res := range workspace.LatestBuild.Resources {
for _, a := range res.Agents {
require.Equal(t, codersdk.WorkspaceAgentLifecycleCreated, a.LifecycleState)
}
}
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
ctx := testutil.Context(t, testutil.WaitMedium)
conn, err := agentClient.ConnectRPC(ctx)
require.NoError(t, err)
defer func() {
cErr := conn.Close()
require.NoError(t, cErr)
}()
aAPI := agentproto.NewDRPCAgentClient(conn)
manifest := requireGetManifest(ctx, t, aAPI)
// Verify manifest API response.
require.Equal(t, workspace.ID, manifest.WorkspaceID)
require.Equal(t, workspace.OwnerName, manifest.OwnerName)
require.Equal(t, "First Meta", manifest.Metadata[0].DisplayName)
require.Equal(t, "foo1", manifest.Metadata[0].Key)
require.Equal(t, "echo hi", manifest.Metadata[0].Script)
require.EqualValues(t, 10, manifest.Metadata[0].Interval)
require.EqualValues(t, 3, manifest.Metadata[0].Timeout)
// post sends a single metadata result for key and fails the test if the
// request returns an error.
post := func(ctx context.Context, key string, mr codersdk.WorkspaceAgentMetadataResult) {
_, err := aAPI.BatchUpdateMetadata(ctx, &agentproto.BatchUpdateMetadataRequest{
Metadata: []*agentproto.Metadata{
{
Key: key,
Result: agentsdk.ProtoFromMetadataResult(mr),
},
},
})
require.NoError(t, err, "post metadata: %s, %#v", key, mr)
}
workspace, err = client.Workspace(ctx, workspace.ID)
require.NoError(t, err, "get workspace")
agentID := workspace.LatestBuild.Resources[0].Agents[0].ID
var update []codersdk.WorkspaceAgentMetadata
wantMetadata1 := codersdk.WorkspaceAgentMetadataResult{
CollectedAt: time.Now(),
Value: "bar",
}
// Setup is complete, reset the context.
ctx = testutil.Context(t, testutil.WaitMedium)
// Initial post must come before the Watch is established.
post(ctx, "foo1", wantMetadata1)
updates, errors := client.WatchWorkspaceAgentMetadata(ctx, agentID)
// recvUpdate blocks for the next metadata update and fails the test if the
// context ends or the watch reports an error first.
recvUpdate := func() []codersdk.WorkspaceAgentMetadata {
select {
case <-ctx.Done():
t.Fatalf("context done: %v", ctx.Err())
case err := <-errors:
t.Fatalf("error watching metadata: %v", err)
case update := <-updates:
return update
}
return nil
}
// check asserts that got carries the wanted value and error. With retry set,
// it first reads up to two further updates while they do not match, replacing
// got with the entry that has the same key.
check := func(want codersdk.WorkspaceAgentMetadataResult, got codersdk.WorkspaceAgentMetadata, retry bool) {
// We can't trust the order of the updates due to timers and debounces,
// so let's check a few times more.
for i := 0; retry && i < 2 && (want.Value != got.Result.Value || want.Error != got.Result.Error); i++ {
update = recvUpdate()
for _, m := range update {
if m.Description.Key == got.Description.Key {
got = m
break
}
}
}
ok1 := assert.Equal(t, want.Value, got.Result.Value)
ok2 := assert.Equal(t, want.Error, got.Result.Error)
if !ok1 || !ok2 {
require.FailNow(t, "check failed")
}
}
update = recvUpdate()
require.Len(t, update, 3)
check(wantMetadata1, update[0], true)
// The second metadata result is not yet posted.
require.Zero(t, update[1].Result.CollectedAt)
wantMetadata2 := wantMetadata1
post(ctx, "foo2", wantMetadata2)
update = recvUpdate()
require.Len(t, update, 3)
check(wantMetadata1, update[0], true)
check(wantMetadata2, update[1], true)
// Re-post foo1 with an error set and expect the error to be delivered.
wantMetadata1.Error = "error"
post(ctx, "foo1", wantMetadata1)
update = recvUpdate()
require.Len(t, update, 3)
check(wantMetadata1, update[0], true)
// A value twice as long as maxValueLen is expected to come back with
// exactly maxValueLen characters and a non-empty error.
const maxValueLen = 2048
tooLongValueMetadata := wantMetadata1
tooLongValueMetadata.Value = strings.Repeat("a", maxValueLen*2)
tooLongValueMetadata.Error = ""
tooLongValueMetadata.CollectedAt = time.Now()
post(ctx, "foo3", tooLongValueMetadata)
got := recvUpdate()[2]
// Allow up to two more updates for the foo3 result to arrive.
for i := 0; i < 2 && len(got.Result.Value) != maxValueLen; i++ {
got = recvUpdate()[2]
}
require.Len(t, got.Result.Value, maxValueLen)
require.NotEmpty(t, got.Result.Error)
// Posting a key that was not declared on the agent must not return an error
// (post asserts this).
unknownKeyMetadata := wantMetadata1
post(ctx, "unknown", unknownKeyMetadata)
}
// TestWorkspaceAgent_Metadata_DisplayOrder declares four metadata items with
// explicit Order values (two of them equal) and checks the order in which the
// first watched update lists them.
func TestWorkspaceAgent_Metadata_DisplayOrder(t *testing.T) {
	t.Parallel()

	client, db := coderdtest.NewWithDatabase(t, nil)
	owner := coderdtest.CreateFirstUser(t, client)

	// newMeta fills in the fields every item in this test shares.
	newMeta := func(displayName, key, script string, order int64) *proto.Agent_Metadata {
		return &proto.Agent_Metadata{
			DisplayName: displayName,
			Key:         key,
			Script:      script,
			Interval:    10,
			Timeout:     3,
			Order:       order,
		}
	}

	build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: owner.OrganizationID,
		OwnerID:        owner.UserID,
	}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
		agents[0].Metadata = []*proto.Agent_Metadata{
			newMeta("First Meta", "foo1", "echo hi", 2),
			newMeta("Second Meta", "foo2", "echo howdy", 1),
			newMeta("Third Meta", "foo3", "echo howdy", 2),
			newMeta("Fourth Meta", "foo4", "echo howdy", 3),
		}
		return agents
	}).Do()

	workspace, err := client.Workspace(context.Background(), build.Workspace.ID)
	require.NoError(t, err)
	for _, resource := range workspace.LatestBuild.Resources {
		for _, a := range resource.Agents {
			require.Equal(t, codersdk.WorkspaceAgentLifecycleCreated, a.LifecycleState)
		}
	}

	setupCtx := testutil.Context(t, testutil.WaitMedium)
	workspace, err = client.Workspace(setupCtx, workspace.ID)
	require.NoError(t, err, "get workspace")
	agentID := workspace.LatestBuild.Resources[0].Agents[0].ID

	// Setup is complete; watch with a fresh context.
	ctx := testutil.Context(t, testutil.WaitMedium)
	updates, errors := client.WatchWorkspaceAgentMetadata(ctx, agentID)

	// recvUpdate returns the next update, failing the test if the context
	// ends or the watch reports an error first.
	recvUpdate := func() []codersdk.WorkspaceAgentMetadata {
		select {
		case got := <-updates:
			return got
		case err := <-errors:
			t.Fatalf("error watching metadata: %v", err)
		case <-ctx.Done():
			t.Fatalf("context done: %v", ctx.Err())
		}
		return nil
	}

	wantOrder := []string{"Second Meta", "First Meta", "Third Meta", "Fourth Meta"}
	update := recvUpdate()
	require.Len(t, update, 4)
	for i, displayName := range wantOrder {
		require.Equal(t, displayName, update[i].Description.DisplayName)
	}
}
// testWAMErrorStore wraps a database.Store so that a test can make
// GetWorkspaceAgentMetadata fail on demand.
type testWAMErrorStore struct {
database.Store
// err, when set, is returned by GetWorkspaceAgentMetadata instead of
// delegating to the embedded Store.
err atomic.Pointer[error]
}
// GetWorkspaceAgentMetadata returns the stored error if one has been set and
// otherwise forwards the call to the embedded Store.
func (s *testWAMErrorStore) GetWorkspaceAgentMetadata(ctx context.Context, arg database.GetWorkspaceAgentMetadataParams) ([]database.WorkspaceAgentMetadatum, error) {
	if injected := s.err.Load(); injected != nil {
		return nil, *injected
	}
	return s.Store.GetWorkspaceAgentMetadata(ctx, arg)
}
// TestWorkspaceAgent_Metadata_CatchMemoryLeak watches an agent's metadata over
// SSE while a second goroutine keeps posting updates, then injects a database
// error into GetWorkspaceAgentMetadata and expects both goroutines to finish,
// i.e. the watch to shut down rather than linger.
func TestWorkspaceAgent_Metadata_CatchMemoryLeak(t *testing.T) {
	t.Parallel()

	// Wrap the real store so the metadata query can be made to fail later.
	store, psub := dbtestutil.NewDB(t)
	db := &testWAMErrorStore{Store: store}
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("coderd").Leveled(slog.LevelDebug)
	client := coderdtest.New(t, &coderdtest.Options{
		Database:                 db,
		Pubsub:                   psub,
		IncludeProvisionerDaemon: true,
		Logger:                   &logger,
	})
	user := coderdtest.CreateFirstUser(t, client)
	r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: user.OrganizationID,
		OwnerID:        user.UserID,
	}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
		agents[0].Metadata = []*proto.Agent_Metadata{
			{
				DisplayName: "First Meta",
				Key:         "foo1",
				Script:      "echo hi",
				Interval:    10,
				Timeout:     3,
			},
			{
				DisplayName: "Second Meta",
				Key:         "foo2",
				Script:      "echo bye",
				Interval:    10,
				Timeout:     3,
			},
		}
		return agents
	}).Do()
	workspace, err := client.Workspace(context.Background(), r.Workspace.ID)
	require.NoError(t, err)
	for _, res := range workspace.LatestBuild.Resources {
		for _, a := range res.Agents {
			require.Equal(t, codersdk.WorkspaceAgentLifecycleCreated, a.LifecycleState)
		}
	}

	agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))

	ctx := testutil.Context(t, testutil.WaitSuperLong)
	conn, err := agentClient.ConnectRPC(ctx)
	require.NoError(t, err)
	defer func() {
		cErr := conn.Close()
		require.NoError(t, cErr)
	}()
	aAPI := agentproto.NewDRPCAgentClient(conn)

	manifest := requireGetManifest(ctx, t, aAPI)

	// post sends one metadata result for key and returns the request error to
	// the caller instead of failing the test itself.
	post := func(ctx context.Context, key, value string) error {
		_, err := aAPI.BatchUpdateMetadata(ctx, &agentproto.BatchUpdateMetadataRequest{
			Metadata: []*agentproto.Metadata{
				{
					Key: key,
					Result: agentsdk.ProtoFromMetadataResult(codersdk.WorkspaceAgentMetadataResult{
						CollectedAt: time.Now(),
						Value:       value,
					}),
				},
			},
		})
		return err
	}

	workspace, err = client.Workspace(ctx, workspace.ID)
	require.NoError(t, err, "get workspace")

	// Start the SSE connection.
	metadata, errors := client.WatchWorkspaceAgentMetadata(ctx, manifest.AgentID)

	// Discard the output, pretending to be a client consuming it.
	wantErr := xerrors.New("test error")
	metadataDone := testutil.Go(t, func() {
		for {
			select {
			case <-ctx.Done():
				return
			case _, ok := <-metadata:
				if !ok {
					return
				}
			case err := <-errors:
				// Only an error other than the injected one is a failure.
				if err != nil && !strings.Contains(err.Error(), wantErr.Error()) {
					assert.NoError(t, err, "watch metadata")
				}
				return
			}
		}
	})

	postDone := testutil.Go(t, func() {
		for {
			// Stop posting once the watcher has exited.
			select {
			case <-metadataDone:
				return
			default:
			}
			// We need to send two separate metadata updates to trigger the
			// memory leak. foo2 will cause the number of foo1 to be doubled, etc.
			err := post(ctx, "foo1", "hi")
			if err != nil {
				assert.NoError(t, err, "post metadata foo1")
				return
			}
			err = post(ctx, "foo2", "bye")
			if err != nil {
				assert.NoError(t, err, "post metadata foo2")
				return
			}
		}
	})

	// In a previously faulty implementation, this database error will trigger
	// a close of the goroutine that consumes metadata updates for refreshing
	// the metadata sent over SSE. As it was, the exit of the consumer was not
	// detected as a trigger to close down the connection.
	//
	// Further, there was a memory leak in the pubsub subscription that caused
	// ballooning of memory (almost double in size every received metadata).
	//
	// This db error should trigger a close of the SSE connection in the fixed
	// implementation. The memory leak should not happen in either case, but
	// testing it is not straightforward.
	db.err.Store(&wantErr)

	testutil.TryReceive(ctx, t, metadataDone)
	testutil.TryReceive(ctx, t, postDone)
}
// TestWorkspaceAgent_Startup posts agent startup information and checks that
// it is stored on the workspace agent ("OK"), and that a version string the
// server does not accept as semver is rejected ("InvalidSemver").
func TestWorkspaceAgent_Startup(t *testing.T) {
	t.Parallel()

	t.Run("OK", func(t *testing.T) {
		t.Parallel()

		client, db := coderdtest.NewWithDatabase(t, nil)
		owner := coderdtest.CreateFirstUser(t, client)
		build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        owner.UserID,
		}).WithAgent().Do()

		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(build.AgentToken))
		ctx := testutil.Context(t, testutil.WaitMedium)

		const (
			wantVersion = "v1.2.3"
			wantDir     = "/home/coder"
		)
		// The subsystems are posted in the opposite order below and are
		// expected back in this (sorted) order.
		wantSubsystems := []codersdk.AgentSubsystem{
			codersdk.AgentSubsystemEnvbox,
			codersdk.AgentSubsystemExectrace,
		}

		startup := &agentproto.Startup{
			Version:           wantVersion,
			ExpandedDirectory: wantDir,
			Subsystems: []agentproto.Startup_Subsystem{
				// Not sorted.
				agentproto.Startup_EXECTRACE,
				agentproto.Startup_ENVBOX,
			},
		}
		require.NoError(t, postStartup(ctx, t, agentClient, startup))

		workspace, err := client.Workspace(ctx, build.Workspace.ID)
		require.NoError(t, err)

		agentID := workspace.LatestBuild.Resources[0].Agents[0].ID
		stored, err := client.WorkspaceAgent(ctx, agentID)
		require.NoError(t, err)

		require.Equal(t, wantVersion, stored.Version)
		require.Equal(t, wantDir, stored.ExpandedDirectory)
		require.Equal(t, wantSubsystems, stored.Subsystems)
		require.Equal(t, agentproto.CurrentVersion.String(), stored.APIVersion)
	})

	t.Run("InvalidSemver", func(t *testing.T) {
		t.Parallel()

		client, db := coderdtest.NewWithDatabase(t, nil)
		owner := coderdtest.CreateFirstUser(t, client)
		build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        owner.UserID,
		}).WithAgent().Do()

		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(build.AgentToken))
		ctx := testutil.Context(t, testutil.WaitMedium)

		// Same version number as the OK case, but without the leading "v".
		startup := &agentproto.Startup{
			Version: "1.2.3",
		}
		err := postStartup(ctx, t, agentClient, startup)
		require.ErrorContains(t, err, "invalid agent semver version")
	})
}
// TestWorkspaceAgent_UpdatedDERP runs a real coderd server, with a real agent
// and a real client, and updates the DERP map live to ensure connections still
// work.
func TestWorkspaceAgent_UpdatedDERP(t *testing.T) {
t.Parallel()
logger := testutil.Logger(t)
// Set the deployment's DERP BlockDirect option to true.
dv := coderdtest.DeploymentValues(t)
err := dv.DERP.Config.BlockDirect.Set("true")
require.NoError(t, err)
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
DeploymentValues: dv,
})
defer closer.Close()
user := coderdtest.CreateFirstUser(t, client)
// Change the DERP mapper to our custom one.
var currentDerpMap atomic.Pointer[tailcfg.DERPMap]
originalDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
currentDerpMap.Store(originalDerpMap)
// The mapper ignores its input and returns a clone of whichever map is
// currently stored, so the test can swap the map at runtime.
derpMapFn := func(_ *tailcfg.DERPMap) *tailcfg.DERPMap {
return currentDerpMap.Load().Clone()
}
api.DERPMapper.Store(&derpMapFn)
// Start a workspace agent.
r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent().Do()
agentCloser := agenttest.New(t, client.URL, r.AgentToken)
resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID)
agentID := resources[0].Agents[0].ID
// Connect from a client.
conn1, err := func() (workspacesdk.AgentConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() // Connection should remain open even if the dial context is canceled.
return workspacesdk.New(client).
DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{
Logger: logger.Named("client1"),
})
}()
require.NoError(t, err)
defer conn1.Close()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
ok := conn1.AwaitReachable(ctx)
require.True(t, ok)
// Change the DERP map and change the region ID.
newDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
require.NotNil(t, newDerpMap)
// Move region 1 to region 2 and renumber the region and its nodes to match.
newDerpMap.Regions[2] = newDerpMap.Regions[1]
delete(newDerpMap.Regions, 1)
newDerpMap.Regions[2].RegionID = 2
for _, node := range newDerpMap.Regions[2].Nodes {
node.RegionID = 2
}
currentDerpMap.Store(newDerpMap)
// Wait for the agent's DERP map to be updated.
require.Eventually(t, func() bool {
conn := agentCloser.TailnetConn()
if conn == nil {
return false
}
regionIDs := conn.DERPMap().RegionIDs()
return len(regionIDs) == 1 && regionIDs[0] == 2 && conn.Node().PreferredDERP == 2
}, testutil.WaitLong, testutil.IntervalFast)
// Wait for the DERP map to be updated on the existing client.
require.Eventually(t, func() bool {
regionIDs := conn1.TailnetConn().DERPMap().RegionIDs()
return len(regionIDs) == 1 && regionIDs[0] == 2
}, testutil.WaitLong, testutil.IntervalFast)
// The first client should still be able to reach the agent.
ok = conn1.AwaitReachable(ctx)
require.True(t, ok)
// Connect from a second client.
conn2, err := workspacesdk.New(client).
DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{
Logger: logger.Named("client2"),
})
require.NoError(t, err)
defer conn2.Close()
ok = conn2.AwaitReachable(ctx)
require.True(t, ok)
// A client that dials after the change sees only region 2.
require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs())
}
// TestWorkspaceAgentExternalAuthListen covers the agent external-auth endpoint
// when called with Listen set, i.e. when the request blocks until the external
// auth token becomes valid.
func TestWorkspaceAgentExternalAuthListen(t *testing.T) {
	t.Parallel()

	// ValidateURLSpam acts as a workspace calling GIT_ASK_PASS which
	// will wait until the external auth token is valid. The issue is we spam
	// the validate endpoint with requests until the token is valid. We do this
	// even if the token has not changed. We are calling validate with the
	// same inputs expecting a different result (insanity?). To reduce our
	// api rate limit usage, we should do nothing if the inputs have not
	// changed.
	//
	// Note that an expired oauth token is already skipped, so this really
	// only covers the case of a revoked token.
	t.Run("ValidateURLSpam", func(t *testing.T) {
		t.Parallel()

		const providerID = "fake-idp"

		// Count all the times we call validate
		var validateCalls atomic.Int32
		fake := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithMiddlewares(func(handler http.Handler) http.Handler {
			return http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// Count all the validate calls
				if strings.Contains(r.URL.Path, "/external-auth-validate/") {
					validateCalls.Add(1)
				}
				handler.ServeHTTP(w, r)
			}))
		}))

		// The server's ticker is replaced by this channel so the test decides
		// when each tick fires.
		ticks := make(chan time.Time)
		// setup
		ownerClient, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
			NewTicker: func(duration time.Duration) (<-chan time.Time, func()) {
				return ticks, func() {}
			},
			ExternalAuthConfigs: []*externalauth.Config{
				fake.ExternalAuthConfig(t, providerID, nil, func(cfg *externalauth.Config) {
					cfg.Type = codersdk.EnhancedExternalAuthProviderGitLab.String()
				}),
			},
		})
		first := coderdtest.CreateFirstUser(t, ownerClient)
		tmpDir := t.TempDir()
		client, user := coderdtest.CreateAnotherUser(t, ownerClient, first.OrganizationID)

		r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: first.OrganizationID,
			OwnerID:        user.ID,
		}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
			agents[0].Directory = tmpDir
			return agents
		}).Do()

		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))

		// We need to include an invalid oauth token that is not expired.
		dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{
			ProviderID:        providerID,
			UserID:            user.ID,
			CreatedAt:         dbtime.Now(),
			UpdatedAt:         dbtime.Now(),
			OAuthAccessToken:  "invalid",
			OAuthRefreshToken: "bad",
			OAuthExpiry:       dbtime.Now().Add(time.Hour),
		})

		ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
		defer cancel()

		// Run the blocking request through testutil.Go so the goroutine's
		// lifetime is tracked: the assertion inside it must never run after
		// the test has returned.
		listenDone := testutil.Go(t, func() {
			// The request that will block and fire off validate calls.
			_, err := agentClient.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
				ID:     providerID,
				Match:  "",
				Listen: true,
			})
			assert.Error(t, err, "this should fail")
		})

		// Send off 10 ticks to cause 10 validate calls. The channel is
		// unbuffered, so give up when the context ends instead of blocking
		// forever if nothing is receiving.
		for i := 0; i < 10; i++ {
			select {
			case ticks <- time.Now():
			case <-ctx.Done():
				t.Fatalf("context done before tick %d was received: %v", i, ctx.Err())
			}
		}
		cancel()

		// ctx is canceled at this point, so wait for the listener on a fresh
		// context.
		testutil.TryReceive(testutil.Context(t, testutil.WaitShort), t, listenDone)

		// We expect only 1. One from the initial "Refresh" attempt, and the
		// other should be skipped.
		// In a failed test, you will likely see 9, as the last one
		// gets canceled.
		require.EqualValues(t, 1, validateCalls.Load(), "validate calls duplicated on same token")
	})
}
// TestOwnedWorkspacesCoordinate subscribes to the tailnet WorkspaceUpdates
// stream for a member's own workspaces and checks the initial snapshot, the
// update produced by building a second workspace, and the update produced by
// deleting it.
func TestOwnedWorkspacesCoordinate(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitLong)
	logger := testutil.Logger(t)

	ownerClient, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
		Coordinator: tailnet.NewCoordinator(logger),
	})
	owner := coderdtest.CreateFirstUser(t, ownerClient)
	member, memberUser := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin())

	// Create a workspace with an agent
	existing := buildWorkspaceWithAgent(t, member, owner.OrganizationID, memberUser.ID, api.Database, api.Pubsub)

	tailnetURL, err := member.URL.Parse("/api/v2/tailnet")
	require.NoError(t, err)
	query := tailnetURL.Query()
	query.Set("version", "2.0")
	tailnetURL.RawQuery = query.Encode()

	//nolint:bodyclose // websocket package closes this for you
	wsConn, resp, err := websocket.Dial(ctx, tailnetURL.String(), &websocket.DialOptions{
		HTTPHeader: http.Header{
			"Coder-Session-Token": []string{member.SessionToken()},
		},
	})
	// If the dial failed with an HTTP response other than a protocol switch,
	// report the error decoded from the response body instead.
	if err != nil && resp != nil && resp.StatusCode != http.StatusSwitchingProtocols {
		err = codersdk.ReadBodyAsError(resp)
	}
	require.NoError(t, err)
	defer wsConn.Close(websocket.StatusNormalClosure, "done")

	rpcClient, err := tailnet.NewDRPCClient(
		websocket.NetConn(ctx, wsConn, websocket.MessageBinary),
		logger,
	)
	require.NoError(t, err)

	stream, err := rpcClient.WorkspaceUpdates(ctx, &tailnetproto.WorkspaceUpdatesRequest{
		WorkspaceOwnerId: tailnet.UUIDToByteSlice(memberUser.ID),
	})
	require.NoError(t, err)

	// First update will contain the existing workspace and agent
	initial, err := stream.Recv()
	require.NoError(t, err)
	require.Len(t, initial.UpsertedWorkspaces, 1)
	require.EqualValues(t, initial.UpsertedWorkspaces[0].Id, existing.ID)
	require.Len(t, initial.UpsertedAgents, 1)
	require.EqualValues(t, initial.UpsertedAgents[0].WorkspaceId, existing.ID)
	require.Len(t, initial.DeletedWorkspaces, 0)
	require.Len(t, initial.DeletedAgents, 0)

	// Build a second workspace
	second := buildWorkspaceWithAgent(t, member, owner.OrganizationID, memberUser.ID, api.Database, api.Pubsub)

	// Wait for the second workspace to be running with an agent. The same map
	// is handed back as the starting state of the next wait below.
	runningState := map[uuid.UUID]workspace{
		second.ID: {
			Status:    tailnetproto.Workspace_RUNNING,
			NumAgents: 1,
		},
	}
	waitForUpdates(t, ctx, stream, map[uuid.UUID]workspace{}, runningState)

	// Wait for the workspace and agent to be deleted
	second.Deleted = true
	dbfake.WorkspaceBuild(t, api.Database, second).
		Seed(database.WorkspaceBuild{
			Transition:  database.WorkspaceTransitionDelete,
			BuildNumber: 2,
		}).Do()

	deletedState := map[uuid.UUID]workspace{
		second.ID: {
			Status:    tailnetproto.Workspace_DELETED,
			NumAgents: 0,
		},
	}
	waitForUpdates(t, ctx, stream, runningState, deletedState)
}
// TestUserTailnetTelemetry dials the user tailnet endpoint with different
// Coder Desktop telemetry headers and checks the snapshot reported when the
// connection is established as well as the one reported after it is closed.
func TestUserTailnetTelemetry(t *testing.T) {
	t.Parallel()

	telemetryData := &codersdk.CoderDesktopTelemetry{
		DeviceOS:            "Windows",
		DeviceID:            "device001",
		CoderDesktopVersion: "0.22.1",
	}
	fullHeader, err := json.Marshal(telemetryData)
	require.NoError(t, err)

	testCases := []struct {
		name    string
		headers map[string]string
		// only used for DeviceID, DeviceOS, CoderDesktopVersion
		expected telemetry.UserTailnetConnection
	}{
		{
			name:     "no header",
			headers:  map[string]string{},
			expected: telemetry.UserTailnetConnection{},
		},
		{
			name: "full header",
			headers: map[string]string{
				codersdk.CoderDesktopTelemetryHeader: string(fullHeader),
			},
			expected: telemetry.UserTailnetConnection{
				DeviceOS:            ptr.Ref("Windows"),
				DeviceID:            ptr.Ref("device001"),
				CoderDesktopVersion: ptr.Ref("0.22.1"),
			},
		},
		{
			name: "empty header",
			headers: map[string]string{
				codersdk.CoderDesktopTelemetryHeader: "",
			},
			expected: telemetry.UserTailnetConnection{},
		},
		{
			// Truncated JSON: expected to be treated like a missing header.
			name: "invalid header",
			headers: map[string]string{
				codersdk.CoderDesktopTelemetryHeader: "{\"device_os",
			},
			expected: telemetry.UserTailnetConnection{},
		},
	}

	// nolint: paralleltest // no longer need to reinitialize loop vars in go 1.22
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ctx := testutil.Context(t, testutil.WaitLong)
			logger := testutil.Logger(t)

			// Keep the reporter disabled during setup so that only the
			// tailnet connection below produces snapshots.
			fTelemetry := newFakeTelemetryReporter(ctx, t, 200)
			fTelemetry.enabled = false
			firstClient := coderdtest.New(t, &coderdtest.Options{
				Logger:            &logger,
				TelemetryReporter: fTelemetry,
			})
			firstUser := coderdtest.CreateFirstUser(t, firstClient)
			member, memberUser := coderdtest.CreateAnotherUser(t, firstClient, firstUser.OrganizationID, rbac.RoleTemplateAdmin())

			headers := http.Header{
				"Coder-Session-Token": []string{member.SessionToken()},
			}
			for k, v := range tc.headers {
				headers.Add(k, v)
			}

			// enable telemetry now that user is created.
			fTelemetry.enabled = true

			u, err := member.URL.Parse("/api/v2/tailnet")
			require.NoError(t, err)
			q := u.Query()
			q.Set("version", "2.0")
			u.RawQuery = q.Encode()

			predialTime := dbtime.Now()

			//nolint:bodyclose // websocket package closes this for you
			wsConn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
				HTTPHeader: headers,
			})
			if err != nil {
				if resp != nil && resp.StatusCode != http.StatusSwitchingProtocols {
					err = codersdk.ReadBodyAsError(resp)
				}
				require.NoError(t, err)
			}
			// Safety net for early failures; on the happy path the connection
			// is closed explicitly below and this call's result is ignored.
			defer wsConn.Close(websocket.StatusNormalClosure, "done")

			// Check telemetry
			snapshot := testutil.TryReceive(ctx, t, fTelemetry.snapshots)
			require.Len(t, snapshot.UserTailnetConnections, 1)
			telemetryConnection := snapshot.UserTailnetConnections[0]
			require.Equal(t, memberUser.ID.String(), telemetryConnection.UserID)
			require.GreaterOrEqual(t, telemetryConnection.ConnectedAt, predialTime)
			require.LessOrEqual(t, telemetryConnection.ConnectedAt, dbtime.Now())
			require.NotEmpty(t, telemetryConnection.PeerID)
			requireEqualOrBothNil(t, telemetryConnection.DeviceID, tc.expected.DeviceID)
			requireEqualOrBothNil(t, telemetryConnection.DeviceOS, tc.expected.DeviceOS)
			requireEqualOrBothNil(t, telemetryConnection.CoderDesktopVersion, tc.expected.CoderDesktopVersion)

			beforeDisconnectTime := dbtime.Now()
			err = wsConn.Close(websocket.StatusNormalClosure, "done")
			require.NoError(t, err)

			snapshot = testutil.TryReceive(ctx, t, fTelemetry.snapshots)
			require.Len(t, snapshot.UserTailnetConnections, 1)
			telemetryDisconnection := snapshot.UserTailnetConnections[0]
			require.Equal(t, memberUser.ID.String(), telemetryDisconnection.UserID)
			require.Equal(t, telemetryConnection.ConnectedAt, telemetryDisconnection.ConnectedAt)
			require.Equal(t, telemetryConnection.UserID, telemetryDisconnection.UserID)
			require.Equal(t, telemetryConnection.PeerID, telemetryDisconnection.PeerID)
			require.NotNil(t, telemetryDisconnection.DisconnectedAt)
			require.GreaterOrEqual(t, *telemetryDisconnection.DisconnectedAt, beforeDisconnectTime)
			require.LessOrEqual(t, *telemetryDisconnection.DisconnectedAt, dbtime.Now())
			// The device metadata must be carried by the disconnect event
			// too, so check the disconnect snapshot rather than re-checking
			// the connect one.
			requireEqualOrBothNil(t, telemetryDisconnection.DeviceID, tc.expected.DeviceID)
			requireEqualOrBothNil(t, telemetryDisconnection.DeviceOS, tc.expected.DeviceOS)
			requireEqualOrBothNil(t, telemetryDisconnection.CoderDesktopVersion, tc.expected.CoderDesktopVersion)
		})
	}
}
// buildWorkspaceWithAgent seeds a workspace build with an agent for the
// given organization and owner through dbfake (wired to ps), starts a test
// agent with the resulting agent token against client's URL, and blocks on
// the workspace agent waiter before returning the seeded workspace.
func buildWorkspaceWithAgent(
	t *testing.T,
	client *codersdk.Client,
	orgID uuid.UUID,
	ownerID uuid.UUID,
	db database.Store,
	ps pubsub.Pubsub,
) database.WorkspaceTable {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: orgID,
		OwnerID:        ownerID,
	}).WithAgent().Pubsub(ps).Do()
	// The agent handle is intentionally discarded; the test only needs the
	// agent to be running.
	_ = agenttest.New(t, client.URL, r.AgentToken)
	coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait()
	return r.Workspace
}
// requireGetManifest fetches the agent manifest over aAPI and converts it
// from its protobuf form to agentsdk.Manifest. Either step failing fails
// the test immediately.
func requireGetManifest(ctx context.Context, t testing.TB, aAPI agentproto.DRPCAgentClient) agentsdk.Manifest {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	mp, err := aAPI.GetManifest(ctx, &agentproto.GetManifestRequest{})
	require.NoError(t, err)
	manifest, err := agentsdk.ManifestFromProto(mp)
	require.NoError(t, err)
	return manifest
}
// postStartup opens an agent RPC connection with client.ConnectRPC28, sends
// startup through UpdateStartup, and returns the error from that call so the
// caller can assert on it. Failing to connect, or to close the connection
// afterwards, fails the test.
func postStartup(ctx context.Context, t testing.TB, client agent.Client, startup *agentproto.Startup) error {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	aAPI, _, err := client.ConnectRPC28(ctx)
	require.NoError(t, err)
	defer func() {
		cErr := aAPI.DRPCConn().Close()
		require.NoError(t, cErr)
	}()
	_, err = aAPI.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{Startup: startup})
	return err
}
// workspace is the per-workspace state that waitForUpdates accumulates from
// the workspace updates stream and compares against the expected state.
type workspace struct {
	// Status is the status from the most recent upsert for the workspace,
	// or Workspace_DELETED once the workspace shows up as deleted.
	Status tailnetproto.Workspace_Status
	// NumAgents is a running count: incremented for each upserted agent
	// and decremented for each deleted agent of the workspace.
	NumAgents int
}
// waitForUpdates reads workspace updates from stream and folds them into
// currentState until it equals expectedState. currentState is mutated in
// place, so callers can feed the reached state into a follow-up call.
//
// The test is failed fatally when the stream returns an error, when an ID
// in an update is not a valid UUID, or when ctx is done before the expected
// state is reached.
func waitForUpdates(
	t *testing.T,
	//nolint:revive // t takes precedence
	ctx context.Context,
	stream tailnetproto.DRPCTailnet_WorkspaceUpdatesClient,
	currentState map[uuid.UUID]workspace,
	expectedState map[uuid.UUID]workspace,
) {
	t.Helper()

	// track parses rawID as a workspace UUID and replaces the entry held
	// for it in currentState with whatever next derives from the previous
	// entry (the zero workspace when the ID has not been seen yet).
	track := func(rawID []byte, next func(prev workspace) workspace) error {
		id, err := uuid.FromBytes(rawID)
		if err != nil {
			return err
		}
		currentState[id] = next(currentState[id])
		return nil
	}

	// drain keeps receiving updates until the tracked state matches the
	// expected one (nil) or something goes wrong (non-nil).
	drain := func() error {
		for {
			if err := ctx.Err(); err != nil {
				return err
			}
			update, err := stream.Recv()
			if err != nil {
				return err
			}

			for _, upserted := range update.UpsertedWorkspaces {
				status := upserted.Status
				if err := track(upserted.Id, func(prev workspace) workspace {
					prev.Status = status
					return prev
				}); err != nil {
					return err
				}
			}
			for _, deleted := range update.DeletedWorkspaces {
				if err := track(deleted.Id, func(prev workspace) workspace {
					prev.Status = tailnetproto.Workspace_DELETED
					return prev
				}); err != nil {
					return err
				}
			}
			for _, agent := range update.UpsertedAgents {
				if err := track(agent.WorkspaceId, func(prev workspace) workspace {
					prev.NumAgents++
					return prev
				}); err != nil {
					return err
				}
			}
			for _, agent := range update.DeletedAgents {
				if err := track(agent.WorkspaceId, func(prev workspace) workspace {
					prev.NumAgents--
					return prev
				}); err != nil {
					return err
				}
			}

			if maps.Equal(currentState, expectedState) {
				return nil
			}
		}
	}

	// Buffered so the receiver goroutine can always deliver its result and
	// exit, even if this function has already given up waiting.
	done := make(chan error, 1)
	go func() {
		done <- drain()
	}()

	select {
	case <-ctx.Done():
		t.Fatal("Timeout waiting for desired state", currentState)
	case err := <-done:
		if err != nil {
			t.Fatal(err)
		}
	}
}
// fakeTelemetryReporter is a fake implementation of telemetry.Reporter
// that sends snapshots on a buffered channel, useful for testing.
type fakeTelemetryReporter struct {
	// enabled gates Report: while false, reported snapshots are dropped.
	// It is also the value returned by Enabled.
	enabled bool
	// snapshots receives every snapshot passed to Report while enabled.
	snapshots chan *telemetry.Snapshot
	// t is used to flag an error when a snapshot cannot be delivered
	// because ctx finished first.
	t testing.TB
	// ctx bounds how long Report may block sending on snapshots.
	ctx context.Context
}
// newFakeTelemetryReporter returns an enabled fakeTelemetryReporter bound to
// ctx and t. Its snapshots channel is buffered with room for bufferSize
// snapshots; Report blocks once that many are waiting to be received.
func newFakeTelemetryReporter(ctx context.Context, t testing.TB, bufferSize int) *fakeTelemetryReporter {
	reporter := new(fakeTelemetryReporter)
	reporter.ctx = ctx
	reporter.t = t
	reporter.snapshots = make(chan *telemetry.Snapshot, bufferSize)
	reporter.enabled = true
	return reporter
}
// Report implements the telemetry.Reporter interface. While the reporter is
// enabled, the snapshot is pushed onto the snapshots channel; if the
// reporter's context finishes before the send goes through, an error is
// recorded on the test instead. A disabled reporter drops the snapshot.
func (f *fakeTelemetryReporter) Report(snapshot *telemetry.Snapshot) {
	if f.enabled {
		select {
		case <-f.ctx.Done():
			f.t.Error("context closed while writing snapshot")
		case f.snapshots <- snapshot:
			// Delivered.
		}
	}
}
// Enabled implements the telemetry.Reporter interface. It returns the
// current value of the enabled field, the same flag Report consults before
// forwarding a snapshot.
func (f *fakeTelemetryReporter) Enabled() bool {
	return f.enabled
}
// Close implements the telemetry.Reporter interface. It is a no-op; in
// particular the snapshots channel is left open.
func (*fakeTelemetryReporter) Close() {}
// requireEqualOrBothNil asserts that a and b agree: when both pointers are
// set the pointed-to values must be equal, otherwise the pointers themselves
// are compared, which only succeeds when both are nil.
func requireEqualOrBothNil[T any](t testing.TB, a, b *T) {
	t.Helper()
	if a == nil || b == nil {
		// At least one side is nil: compare the pointers directly.
		require.Equal(t, a, b)
		return
	}
	require.Equal(t, *a, *b)
}
// TestAgentConnectionInfo checks that the hostname suffix and the DERP
// settings configured in the deployment values are reported by both the
// generic and the agent-specific connection info endpoints.
func TestAgentConnectionInfo(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)

	// Deployment configuration the endpoints are expected to echo back.
	dv := coderdtest.DeploymentValues(t)
	dv.WorkspaceHostnameSuffix = "yallah"
	dv.DERP.Config.BlockDirect = true
	dv.DERP.Config.ForceWebSockets = true

	client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{DeploymentValues: dv})
	user := coderdtest.CreateFirstUser(t, client)
	build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: user.OrganizationID,
		OwnerID:        user.UserID,
	}).WithAgent().Do()

	sdkClient := workspacesdk.New(client)

	// Generic endpoint: not tied to any particular agent.
	genericInfo, err := sdkClient.AgentConnectionInfoGeneric(ctx)
	require.NoError(t, err)
	require.Equal(t, "yallah", genericInfo.HostnameSuffix)
	require.True(t, genericInfo.DisableDirectConnections)
	require.True(t, genericInfo.DERPForceWebSockets)

	// Agent-specific endpoint, addressed by the ID of the seeded agent.
	ws, err := client.Workspace(ctx, build.Workspace.ID)
	require.NoError(t, err)
	seededAgent := ws.LatestBuild.Resources[0].Agents[0]

	agentInfo, err := sdkClient.AgentConnectionInfo(ctx, seededAgent.ID)
	require.NoError(t, err)
	require.Equal(t, "yallah", agentInfo.HostnameSuffix)
	require.True(t, agentInfo.DisableDirectConnections)
	require.True(t, agentInfo.DERPForceWebSockets)
}
// TestReinit exercises agent reinitialization as seen through
// agentsdk's WaitForReinit for four situations: an unclaimed prebuild that is
// notified over pubsub, a claimed prebuild that is notified immediately on
// reconnect, a claim build that failed, and a workspace that never was a
// prebuild.
func TestReinit(t *testing.T) {
	t.Parallel()

	// Helper to create the prebuilds system user's workspace (an
	// unclaimed prebuild) and return the build result. The first
	// build's InitiatorID defaults to PrebuildsSystemUserID via
	// dbfake.
	setupPrebuildWorkspace := func(t *testing.T, db database.Store, orgID uuid.UUID) dbfake.WorkspaceResponse {
		t.Helper()
		return dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: orgID,
			OwnerID:        database.PrebuildsSystemUserID,
		}).WithAgent().Do()
	}

	// Helper to simulate claiming a prebuild: change the workspace
	// owner to the real user and create a second (claim) build.
	claimPrebuild := func(t *testing.T, db database.Store, sqlDB *sql.DB, ws database.WorkspaceTable, claimerID uuid.UUID, templateVersionID uuid.UUID, complete bool) dbfake.WorkspaceResponse {
		t.Helper()

		// Change the workspace owner to the claiming user.
		_, err := sqlDB.Exec("UPDATE workspaces SET owner_id = $1 WHERE id = $2", claimerID, ws.ID)
		require.NoError(t, err)

		// Update the in-memory workspace to reflect the new owner
		// so that dbfake uses it for the second build.
		ws.OwnerID = claimerID

		builder := dbfake.WorkspaceBuild(t, db, ws).
			Seed(database.WorkspaceBuild{
				TemplateVersionID: templateVersionID,
				BuildNumber:       2,
				InitiatorID:       claimerID,
				Transition:        database.WorkspaceTransitionStart,
			}).
			WithAgent()
		if !complete {
			builder = builder.Starting()
		}
		return builder.Do()
	}

	t.Run("unclaimed prebuild receives reinit via pubsub", func(t *testing.T) {
		t.Parallel()

		db, ps := dbtestutil.NewDB(t)
		pubsubSpy := pubsubReinitSpy{
			Pubsub:           ps,
			triedToSubscribe: make(chan string),
		}
		client := coderdtest.New(t, &coderdtest.Options{
			Database: db,
			Pubsub:   &pubsubSpy,
		})
		user := coderdtest.CreateFirstUser(t, client)

		r := setupPrebuildWorkspace(t, db, user.OrganizationID)

		pubsubSpy.Lock()
		pubsubSpy.expectedEvent = agentsdk.PrebuildClaimedChannel(r.Workspace.ID)
		pubsubSpy.Unlock()

		agentCtx := testutil.Context(t, testutil.WaitShort)
		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))

		// Buffered so the goroutine can always hand over its result and
		// exit, even if the test has already stopped receiving.
		agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent, 1)
		go func() {
			reinitEvent, err := agentClient.WaitForReinit(agentCtx)
			assert.NoError(t, err)
			agentReinitializedCh <- reinitEvent
		}()

		// We need to subscribe before we publish, lest we miss the
		// event.
		ctx := testutil.Context(t, testutil.WaitShort)
		testutil.TryReceive(ctx, t, pubsubSpy.triedToSubscribe)

		// Now that we're subscribed, publish the event.
		err := prebuilds.NewPubsubWorkspaceClaimPublisher(ps).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
			WorkspaceID: r.Workspace.ID,
			Reason:      agentsdk.ReinitializeReasonPrebuildClaimed,
		})
		require.NoError(t, err)

		ctx = testutil.Context(t, testutil.WaitShort)
		reinitEvent := testutil.TryReceive(ctx, t, agentReinitializedCh)
		require.NotNil(t, reinitEvent)
		require.Equal(t, r.Workspace.ID, reinitEvent.WorkspaceID)
	})

	// Verifies the durable claim check: when an agent reconnects
	// after missing the pubsub event, the handler detects that the
	// workspace was originally a prebuild (first build initiated by
	// PrebuildsSystemUserID), is now claimed (owner changed), and
	// the claim build completed, so it sends a one-shot reinit
	// event immediately.
	t.Run("claimed prebuild receives one-shot reinit on reconnect", func(t *testing.T) {
		t.Parallel()

		db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t)
		client := coderdtest.New(t, &coderdtest.Options{
			Database: db,
			Pubsub:   ps,
		})
		user := coderdtest.CreateFirstUser(t, client)

		// Create an unclaimed prebuild (build 1, completed).
		r := setupPrebuildWorkspace(t, db, user.OrganizationID)

		// Claim it: change owner + create build 2 (completed).
		claimR := claimPrebuild(t, db, sqlDB, r.Workspace, user.UserID, r.TemplateVersion.ID, true)

		agentCtx := testutil.Context(t, testutil.WaitShort)
		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(claimR.AgentToken))

		// Buffered so the goroutine can always hand over its result and
		// exit, even if the test has already stopped receiving.
		agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent, 1)
		go func() {
			reinitEvent, err := agentClient.WaitForReinit(agentCtx)
			assert.NoError(t, err)
			agentReinitializedCh <- reinitEvent
		}()

		// The agent should receive a reinit event immediately from
		// the durable claim check — no pubsub publish needed.
		ctx := testutil.Context(t, testutil.WaitShort)
		reinitEvent := testutil.TryReceive(ctx, t, agentReinitializedCh)
		require.NotNil(t, reinitEvent)
		require.Equal(t, r.Workspace.ID, reinitEvent.WorkspaceID)
		require.Equal(t, agentsdk.ReinitializeReasonPrebuildClaimed, reinitEvent.Reason)
		require.Equal(t, user.UserID, reinitEvent.OwnerID)
	})

	// Verifies that when the claim build completed with an error,
	// the handler returns 409 so the agent treats it as terminal
	// and stops retrying (WaitForReinitLoop exits on any 409).
	t.Run("failed claim build returns terminal 409", func(t *testing.T) {
		t.Parallel()

		db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t)
		client := coderdtest.New(t, &coderdtest.Options{
			Database: db,
			Pubsub:   ps,
		})
		user := coderdtest.CreateFirstUser(t, client)

		// Create an unclaimed prebuild (build 1, completed).
		r := setupPrebuildWorkspace(t, db, user.OrganizationID)

		// Claim it: create build 2 as completed (so agent rows
		// exist and the token is valid for auth).
		claimR := claimPrebuild(t, db, sqlDB, r.Workspace, user.UserID, r.TemplateVersion.ID, true)

		// Simulate a claim build failure: set an error on the
		// provisioner job. This models the case where terraform
		// apply partially succeeded (creating resources/agents)
		// but ultimately errored.
		_, err := sqlDB.Exec(
			"UPDATE provisioner_jobs SET error = 'simulated claim failure' WHERE id = $1",
			claimR.Build.JobID,
		)
		require.NoError(t, err)

		agentCtx := testutil.Context(t, testutil.WaitShort)
		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(claimR.AgentToken))

		_, err = agentClient.WaitForReinit(agentCtx)
		require.Error(t, err)
		var sdkErr *codersdk.Error
		require.ErrorAs(t, err, &sdkErr)
		require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
	})

	// Verifies that a regular workspace (never a prebuild) gets a
	// 409 Conflict response, causing the agent's reinit loop to
	// close the channel gracefully.
	t.Run("regular workspace gets 409", func(t *testing.T) {
		t.Parallel()

		db, ps := dbtestutil.NewDB(t)
		client := coderdtest.New(t, &coderdtest.Options{
			Database: db,
			Pubsub:   ps,
		})
		user := coderdtest.CreateFirstUser(t, client)

		// Create a regular workspace (not a prebuild). The first
		// build's initiator will be the user, not the prebuilds
		// system user.
		r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
			OrganizationID: user.OrganizationID,
			OwnerID:        user.UserID,
		}).WithAgent().Do()

		agentCtx := testutil.Context(t, testutil.WaitShort)
		agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))

		// WaitForReinit should return an error wrapping a 409.
		_, err := agentClient.WaitForReinit(agentCtx)
		require.Error(t, err)
		var sdkErr *codersdk.Error
		require.ErrorAs(t, err, &sdkErr)
		require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
	})
}
// pubsubReinitSpy wraps a pubsub.Pubsub so a test can observe when a
// subscription to one particular event is attempted. All other Pubsub
// methods come from the embedded implementation unchanged.
type pubsubReinitSpy struct {
	pubsub.Pubsub
	// Mutex guards expectedEvent, which the test sets after the spy has
	// already been handed to the server.
	sync.Mutex
	// triedToSubscribe is closed by Subscribe once a subscription for
	// expectedEvent has been attempted.
	triedToSubscribe chan string
	// expectedEvent is the event name to watch for; while it is empty
	// Subscribe never signals.
	expectedEvent string
}
// Subscribe forwards to the wrapped Pubsub and returns its result untouched.
// In addition, when event matches the non-empty expectedEvent, it closes
// triedToSubscribe to signal that the subscription was attempted. The signal
// is given whether or not the underlying subscription succeeded, and only
// once: later matching subscriptions leave the already-closed channel alone
// instead of panicking on a second close.
func (p *pubsubReinitSpy) Subscribe(event string, listener pubsub.Listener) (cancel func(), err error) {
	cancel, err = p.Pubsub.Subscribe(event, listener)

	p.Lock()
	defer p.Unlock()
	if p.expectedEvent == "" || event != p.expectedEvent {
		return cancel, err
	}

	// Nothing is ever sent on triedToSubscribe, so a successful receive can
	// only mean the channel was closed by an earlier matching call.
	select {
	case <-p.triedToSubscribe:
		// Already signaled.
	default:
		close(p.triedToSubscribe)
	}
	return cancel, err
}