Files
coder/coderd/x/chatd/dialvalidation_test.go
T
Ethan 61e31ec5cc perf(coderd/x/chatd): persist workspace agent binding across chat turns (#23274)
## Summary

This change removes the steady-state "resolve the latest workspace
agent" query from chat execution.

Instead of asking the database for the latest build's agent on every
turn, a chat now persists the workspace/build/agent binding it actually
uses and reuses that binding across subsequent turns. The common path
becomes "load the bound agent by ID and dial it", with fallback paths to
repair the binding when it is missing, stale, or intentionally changed.

## What changes

- add `workspace_id`, `build_id`, and `agent_id` binding fields to
`chats`
- expose those fields through the chat API / SDK so the execution
context is explicit
- load the persisted binding first in chatd, instead of always resolving
the latest build's agent
- persist a refreshed binding when chatd has to re-resolve the workspace
agent
- keep child / subagent chats on the same bound workspace context by
inheriting the parent binding
- leave `build_id` / `agent_id` unset for flows like `create_workspace`,
then bind them lazily on the next agent-backed turn

## Runtime behavior

The binding is treated as an optimistic cache of the agent a chat should
use:

- if the bound agent still exists and dials successfully before the
validation delay elapses, we use it without a latest-build lookup
- if the bound agent is missing or no longer reachable, chatd
re-resolves against the latest build and persists the new binding
- if a workspace mutation changes the chat's target workspace, the
binding is updated as part of that mutation

To avoid reintroducing a hot-path query, dialing uses lazy validation:

- start dialing the cached agent immediately
- only validate against the latest build if the dial is still pending
after a short delay
- if validation finds a different agent, cancel the stale dial, switch
to the current agent, and persist the repaired binding

## Result

The hot path no longer issues
`GetWorkspaceAgentsInLatestBuildByWorkspaceID` on every user message;
that per-message query is the source of the DB pressure this PR
addresses. At the same time, chats still converge to the correct
workspace agent when the binding becomes stale due to rebuilds or
explicit workspace changes.
2026-03-26 17:22:38 +11:00

564 lines
15 KiB
Go

package chatd //nolint:testpackage // Uses internal symbols.
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
)
// TestDialWithLazyValidation_FastDial checks the fast path: the cached agent
// dials without blocking, well inside the (one-minute) validation delay, so
// its connection is returned unswitched, validation is never invoked, and
// the returned release func is the one produced by the dialer.
func TestDialWithLazyValidation_FastDial(t *testing.T) {
	t.Parallel()

	var (
		ctrl        = gomock.NewController(t)
		boundAgent  = uuid.New()
		workspace   = uuid.New()
		agentConn   = agentconnmock.NewMockAgentConn(ctrl)
		released    atomic.Int32
		validations atomic.Int32
	)

	dial := func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		if id == boundAgent {
			return agentConn, func() { released.Add(1) }, nil
		}
		return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
	}
	// Validation must not run on this path; it is counted so the assertion
	// below can detect a call, and it errors if it is ever reached.
	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		validations.Add(1)
		return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
	}

	result, err := dialWithLazyValidation(context.Background(), boundAgent, workspace, dial, validate, time.Minute)
	require.NoError(t, err)
	require.Same(t, agentConn, result.Conn)
	require.Equal(t, boundAgent, result.AgentID)
	require.False(t, result.WasSwitched)
	require.EqualValues(t, 0, validations.Load())

	if release := result.Release; release != nil {
		release()
	}
	require.EqualValues(t, 1, released.Load())
}
func TestDialWithLazyValidation_SlowDialSameAgent(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
agentID := uuid.New()
workspaceID := uuid.New()
conn := agentconnmock.NewMockAgentConn(ctrl)
unblockDial := make(chan struct{})
var releaseCalls atomic.Int32
var validateCalls atomic.Int32
result, err := dialWithLazyValidation(
context.Background(),
agentID,
workspaceID,
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
if id != agentID {
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
}
select {
case <-unblockDial:
return conn, func() {
releaseCalls.Add(1)
}, nil
case <-ctx.Done():
return nil, nil, ctx.Err()
}
},
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
if id != workspaceID {
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
}
validateCalls.Add(1)
close(unblockDial)
return agentID, nil
},
0,
)
require.NoError(t, err)
require.Same(t, conn, result.Conn)
require.Equal(t, agentID, result.AgentID)
require.False(t, result.WasSwitched)
require.EqualValues(t, 1, validateCalls.Load())
if result.Release != nil {
result.Release()
}
require.EqualValues(t, 1, releaseCalls.Load())
}
func TestDialWithLazyValidation_SlowDialStaleAgent(t *testing.T) {
t.Parallel()
t.Run("LateSuccessReleasesStaleConn", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
staleAgentID := uuid.New()
currentAgentID := uuid.New()
workspaceID := uuid.New()
staleConn := agentconnmock.NewMockAgentConn(ctrl)
currentConn := agentconnmock.NewMockAgentConn(ctrl)
var dialCalls atomic.Int32
var validateCalls atomic.Int32
var staleReleaseCalls atomic.Int32
var currentReleaseCalls atomic.Int32
result, err := dialWithLazyValidation(
context.Background(),
staleAgentID,
workspaceID,
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialCalls.Add(1)
switch id {
case staleAgentID:
<-ctx.Done()
return staleConn, func() {
staleReleaseCalls.Add(1)
}, nil
case currentAgentID:
return currentConn, func() {
currentReleaseCalls.Add(1)
}, nil
default:
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
}
},
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
if id != workspaceID {
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
}
validateCalls.Add(1)
return currentAgentID, nil
},
0,
)
require.NoError(t, err)
require.Same(t, currentConn, result.Conn)
require.Equal(t, currentAgentID, result.AgentID)
require.True(t, result.WasSwitched)
require.Eventually(t, func() bool {
return dialCalls.Load() == 2
}, testutil.WaitShort, testutil.IntervalFast)
require.EqualValues(t, 1, validateCalls.Load())
require.Eventually(t, func() bool {
return staleReleaseCalls.Load() == 1
}, testutil.WaitShort, testutil.IntervalFast)
if result.Release != nil {
result.Release()
}
require.EqualValues(t, 1, currentReleaseCalls.Load())
})
t.Run("CanceledFailureDoesNotReleaseStaleConn", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
staleAgentID := uuid.New()
currentAgentID := uuid.New()
workspaceID := uuid.New()
currentConn := agentconnmock.NewMockAgentConn(ctrl)
var dialCalls atomic.Int32
var validateCalls atomic.Int32
var staleReleaseCalls atomic.Int32
var currentReleaseCalls atomic.Int32
result, err := dialWithLazyValidation(
context.Background(),
staleAgentID,
workspaceID,
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialCalls.Add(1)
switch id {
case staleAgentID:
<-ctx.Done()
return nil, func() {
staleReleaseCalls.Add(1)
}, ctx.Err()
case currentAgentID:
return currentConn, func() {
currentReleaseCalls.Add(1)
}, nil
default:
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
}
},
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
if id != workspaceID {
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
}
validateCalls.Add(1)
return currentAgentID, nil
},
0,
)
require.NoError(t, err)
require.Same(t, currentConn, result.Conn)
require.Equal(t, currentAgentID, result.AgentID)
require.True(t, result.WasSwitched)
require.Eventually(t, func() bool {
return dialCalls.Load() == 2
}, testutil.WaitShort, testutil.IntervalFast)
require.EqualValues(t, 1, validateCalls.Load())
require.EqualValues(t, 0, staleReleaseCalls.Load())
if result.Release != nil {
result.Release()
}
require.EqualValues(t, 1, currentReleaseCalls.Load())
})
t.Run("SwitchDoesNotBlock", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
staleAgentID := uuid.New()
currentAgentID := uuid.New()
workspaceID := uuid.New()
staleConn := agentconnmock.NewMockAgentConn(ctrl)
currentConn := agentconnmock.NewMockAgentConn(ctrl)
staleDialStarted := make(chan struct{})
allowStaleReturn := make(chan struct{})
var dialCalls atomic.Int32
var validateCalls atomic.Int32
var staleReleaseCalls atomic.Int32
var currentReleaseCalls atomic.Int32
var staleReturnReleased atomic.Bool
releaseStaleReturn := func() {
if staleReturnReleased.CompareAndSwap(false, true) {
close(allowStaleReturn)
}
}
defer releaseStaleReturn()
resultCh := make(chan DialResult, 1)
errCh := make(chan error, 1)
go func() {
result, err := dialWithLazyValidation(
context.Background(),
staleAgentID,
workspaceID,
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialCalls.Add(1)
switch id {
case staleAgentID:
close(staleDialStarted)
<-allowStaleReturn
return staleConn, func() {
staleReleaseCalls.Add(1)
}, nil
case currentAgentID:
return currentConn, func() {
currentReleaseCalls.Add(1)
}, nil
default:
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
}
},
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
if id != workspaceID {
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
}
<-staleDialStarted
validateCalls.Add(1)
return currentAgentID, nil
},
0,
)
if err != nil {
errCh <- err
return
}
resultCh <- result
}()
var result DialResult
select {
case err := <-errCh:
require.NoError(t, err)
case result = <-resultCh:
require.Same(t, currentConn, result.Conn)
require.Equal(t, currentAgentID, result.AgentID)
require.True(t, result.WasSwitched)
releaseStaleReturn()
case <-time.After(testutil.WaitShort):
t.Fatal("dialWithLazyValidation blocked on stale dial cleanup")
}
require.EqualValues(t, 2, dialCalls.Load())
require.EqualValues(t, 1, validateCalls.Load())
require.Eventually(t, func() bool {
return staleReleaseCalls.Load() == 1
}, testutil.WaitShort, testutil.IntervalFast)
if result.Release != nil {
result.Release()
}
require.EqualValues(t, 1, currentReleaseCalls.Load())
})
}
// TestDialWithLazyValidation_FastFailure verifies that a fast failure of the
// cached agent's dial triggers validation. Validation reports a different
// agent, so that agent is dialed next, and its connection is returned with
// WasSwitched set.
func TestDialWithLazyValidation_FastFailure(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	staleAgent, currentAgent, workspace := uuid.New(), uuid.New(), uuid.New()
	currentConn := agentconnmock.NewMockAgentConn(ctrl)

	var dials, validations, currentReleases atomic.Int32

	// Expected target of each dial attempt, in order: the cached agent
	// first (which fails), then the agent returned by validation.
	plan := []uuid.UUID{staleAgent, currentAgent}
	dial := func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		attempt := int(dials.Add(1))
		if attempt > len(plan) {
			return nil, nil, xerrors.New("unexpected dial call")
		}
		if id != plan[attempt-1] {
			return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
		}
		if attempt == 1 {
			return nil, nil, xerrors.New("dial failed")
		}
		return currentConn, func() { currentReleases.Add(1) }, nil
	}
	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		if id != workspace {
			return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
		}
		validations.Add(1)
		return currentAgent, nil
	}

	result, err := dialWithLazyValidation(context.Background(), staleAgent, workspace, dial, validate, time.Minute)
	require.NoError(t, err)
	require.Same(t, currentConn, result.Conn)
	require.Equal(t, currentAgent, result.AgentID)
	require.True(t, result.WasSwitched)
	require.EqualValues(t, 2, dials.Load())
	require.EqualValues(t, 1, validations.Load())

	if release := result.Release; release != nil {
		release()
	}
	require.EqualValues(t, 1, currentReleases.Load())
}
// TestDialWithLazyValidation_FastFailureSameAgent verifies that when the
// cached agent's first dial fails fast and validation confirms that same
// agent, it is dialed once more and that connection is returned unswitched.
func TestDialWithLazyValidation_FastFailureSameAgent(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	boundAgent, workspace := uuid.New(), uuid.New()
	agentConn := agentconnmock.NewMockAgentConn(ctrl)

	var dials, released, validations atomic.Int32

	dial := func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		// Reject foreign IDs before counting, so only dials of the bound
		// agent advance the attempt counter.
		if id != boundAgent {
			return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
		}
		attempt := dials.Add(1)
		if attempt == 1 {
			return nil, nil, xerrors.New("dial failed")
		}
		if attempt == 2 {
			return agentConn, func() { released.Add(1) }, nil
		}
		return nil, nil, xerrors.New("unexpected dial call")
	}
	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		if id != workspace {
			return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
		}
		validations.Add(1)
		return boundAgent, nil
	}

	result, err := dialWithLazyValidation(context.Background(), boundAgent, workspace, dial, validate, time.Minute)
	require.NoError(t, err)
	require.Same(t, agentConn, result.Conn)
	require.Equal(t, boundAgent, result.AgentID)
	require.False(t, result.WasSwitched)
	require.EqualValues(t, 2, dials.Load())
	require.EqualValues(t, 1, validations.Load())

	if release := result.Release; release != nil {
		release()
	}
	require.EqualValues(t, 1, released.Load())
}
// TestDialWithLazyValidation_FastFailureSameAgentRetryFails verifies that when
// the cached agent fails fast, validation confirms the same agent, and the
// retry dial also fails, the retry's error is surfaced (wrapped with the
// "dial with lazy validation" prefix).
func TestDialWithLazyValidation_FastFailureSameAgentRetryFails(t *testing.T) {
	t.Parallel()

	boundAgent, workspace := uuid.New(), uuid.New()
	var dials, validations atomic.Int32

	// Error message returned by each successive dial attempt.
	failures := []string{"dial failed", "retry failed"}
	dial := func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		if id != boundAgent {
			return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
		}
		attempt := int(dials.Add(1))
		if attempt <= len(failures) {
			return nil, nil, xerrors.New(failures[attempt-1])
		}
		return nil, nil, xerrors.New("unexpected dial call")
	}
	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		if id != workspace {
			return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
		}
		validations.Add(1)
		return boundAgent, nil
	}

	_, err := dialWithLazyValidation(context.Background(), boundAgent, workspace, dial, validate, time.Minute)
	require.EqualError(t, err, "dial with lazy validation: retry failed")
	require.EqualValues(t, 2, dials.Load())
	require.EqualValues(t, 1, validations.Load())
}
func TestDialWithLazyValidation_ValidationError(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
agentID := uuid.New()
workspaceID := uuid.New()
conn := agentconnmock.NewMockAgentConn(ctrl)
unblockDial := make(chan struct{})
var releaseCalls atomic.Int32
var validateCalls atomic.Int32
result, err := dialWithLazyValidation(
context.Background(),
agentID,
workspaceID,
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
if id != agentID {
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
}
select {
case <-unblockDial:
return conn, func() {
releaseCalls.Add(1)
}, nil
case <-ctx.Done():
return nil, nil, ctx.Err()
}
},
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
if id != workspaceID {
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
}
validateCalls.Add(1)
// Validation fails — code should fall back to waiting
// for the original dial.
close(unblockDial)
return uuid.Nil, xerrors.New("db connection reset")
},
0,
)
require.NoError(t, err)
require.Same(t, conn, result.Conn)
require.Equal(t, agentID, result.AgentID)
require.False(t, result.WasSwitched)
require.EqualValues(t, 1, validateCalls.Load())
if result.Release != nil {
result.Release()
}
require.EqualValues(t, 1, releaseCalls.Load())
}
func TestDialWithLazyValidation_ContextCanceled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
agentID := uuid.New()
workspaceID := uuid.New()
var validateCalls atomic.Int32
_, err := dialWithLazyValidation(
ctx,
agentID,
workspaceID,
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
if id != agentID {
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
}
<-ctx.Done()
return nil, nil, ctx.Err()
},
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
if id != workspaceID {
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
}
validateCalls.Add(1)
cancel()
return agentID, nil
},
0,
)
require.ErrorIs(t, err, context.Canceled)
require.EqualValues(t, 1, validateCalls.Load())
}