mirror of
https://github.com/coder/coder.git
synced 2026-06-05 05:58:20 +00:00
61e31ec5cc
## Summary

This change removes the steady-state "resolve the latest workspace agent" query from chat execution. Instead of asking the database for the latest build's agent on every turn, a chat now persists the workspace/build/agent binding it actually uses and reuses that binding across subsequent turns. The common path becomes "load the bound agent by ID and dial it", with fallback paths to repair the binding when it is missing, stale, or intentionally changed.

## What changes

- add `workspace_id`, `build_id`, and `agent_id` binding fields to `chats`
- expose those fields through the chat API / SDK so the execution context is explicit
- load the persisted binding first in chatd, instead of always resolving the latest build's agent
- persist a refreshed binding when chatd has to re-resolve the workspace agent
- keep child / subagent chats on the same bound workspace context by inheriting the parent binding
- leave `build_id` / `agent_id` unset for flows like `create_workspace`, then bind them lazily on the next agent-backed turn

## Runtime behavior

The binding is treated as an optimistic cache of the agent a chat should use:

- if the bound agent still exists and dials successfully, we use it without a latest-build lookup
- if the bound agent is missing or no longer reachable, chatd re-resolves against the latest build and persists the new binding
- if a workspace mutation changes the chat's target workspace, the binding is updated as part of that mutation

To avoid reintroducing a hot-path query, dialing uses lazy validation:

- start dialing the cached agent immediately
- only validate against the latest build if the dial is still pending after a short delay
- if validation finds a different agent, cancel the stale dial, switch to the current agent, and persist the repaired binding

## Result

The hot path stops issuing `GetWorkspaceAgentsInLatestBuildByWorkspaceID` for every user message, which is the source of the DB pressure this PR is addressing. At the same time, chats still converge to the correct workspace agent when the binding becomes stale due to rebuilds or explicit workspace changes.
564 lines
15 KiB
Go
564 lines
15 KiB
Go
package chatd //nolint:testpackage // Uses internal symbols.
|
|
|
|
import (
|
|
"context"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestDialWithLazyValidation_FastDial(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
agentID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
|
|
var releaseCalls atomic.Int32
|
|
var validateCalls atomic.Int32
|
|
|
|
result, err := dialWithLazyValidation(
|
|
context.Background(),
|
|
agentID,
|
|
workspaceID,
|
|
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
if id != agentID {
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
return conn, func() {
|
|
releaseCalls.Add(1)
|
|
}, nil
|
|
},
|
|
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
|
validateCalls.Add(1)
|
|
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
|
},
|
|
time.Minute,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Same(t, conn, result.Conn)
|
|
require.Equal(t, agentID, result.AgentID)
|
|
require.False(t, result.WasSwitched)
|
|
require.EqualValues(t, 0, validateCalls.Load())
|
|
|
|
if result.Release != nil {
|
|
result.Release()
|
|
}
|
|
require.EqualValues(t, 1, releaseCalls.Load())
|
|
}
|
|
|
|
func TestDialWithLazyValidation_SlowDialSameAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
agentID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
unblockDial := make(chan struct{})
|
|
|
|
var releaseCalls atomic.Int32
|
|
var validateCalls atomic.Int32
|
|
|
|
result, err := dialWithLazyValidation(
|
|
context.Background(),
|
|
agentID,
|
|
workspaceID,
|
|
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
if id != agentID {
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
select {
|
|
case <-unblockDial:
|
|
return conn, func() {
|
|
releaseCalls.Add(1)
|
|
}, nil
|
|
case <-ctx.Done():
|
|
return nil, nil, ctx.Err()
|
|
}
|
|
},
|
|
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
|
if id != workspaceID {
|
|
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
|
}
|
|
validateCalls.Add(1)
|
|
close(unblockDial)
|
|
return agentID, nil
|
|
},
|
|
0,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Same(t, conn, result.Conn)
|
|
require.Equal(t, agentID, result.AgentID)
|
|
require.False(t, result.WasSwitched)
|
|
require.EqualValues(t, 1, validateCalls.Load())
|
|
|
|
if result.Release != nil {
|
|
result.Release()
|
|
}
|
|
require.EqualValues(t, 1, releaseCalls.Load())
|
|
}
|
|
|
|
func TestDialWithLazyValidation_SlowDialStaleAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("LateSuccessReleasesStaleConn", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
staleAgentID := uuid.New()
|
|
currentAgentID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
staleConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
currentConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
|
|
var dialCalls atomic.Int32
|
|
var validateCalls atomic.Int32
|
|
var staleReleaseCalls atomic.Int32
|
|
var currentReleaseCalls atomic.Int32
|
|
|
|
result, err := dialWithLazyValidation(
|
|
context.Background(),
|
|
staleAgentID,
|
|
workspaceID,
|
|
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
dialCalls.Add(1)
|
|
switch id {
|
|
case staleAgentID:
|
|
<-ctx.Done()
|
|
return staleConn, func() {
|
|
staleReleaseCalls.Add(1)
|
|
}, nil
|
|
case currentAgentID:
|
|
return currentConn, func() {
|
|
currentReleaseCalls.Add(1)
|
|
}, nil
|
|
default:
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
},
|
|
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
|
if id != workspaceID {
|
|
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
|
}
|
|
validateCalls.Add(1)
|
|
return currentAgentID, nil
|
|
},
|
|
0,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Same(t, currentConn, result.Conn)
|
|
require.Equal(t, currentAgentID, result.AgentID)
|
|
require.True(t, result.WasSwitched)
|
|
require.Eventually(t, func() bool {
|
|
return dialCalls.Load() == 2
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
require.EqualValues(t, 1, validateCalls.Load())
|
|
require.Eventually(t, func() bool {
|
|
return staleReleaseCalls.Load() == 1
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
|
|
if result.Release != nil {
|
|
result.Release()
|
|
}
|
|
require.EqualValues(t, 1, currentReleaseCalls.Load())
|
|
})
|
|
|
|
t.Run("CanceledFailureDoesNotReleaseStaleConn", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
staleAgentID := uuid.New()
|
|
currentAgentID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
currentConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
|
|
var dialCalls atomic.Int32
|
|
var validateCalls atomic.Int32
|
|
var staleReleaseCalls atomic.Int32
|
|
var currentReleaseCalls atomic.Int32
|
|
|
|
result, err := dialWithLazyValidation(
|
|
context.Background(),
|
|
staleAgentID,
|
|
workspaceID,
|
|
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
dialCalls.Add(1)
|
|
switch id {
|
|
case staleAgentID:
|
|
<-ctx.Done()
|
|
return nil, func() {
|
|
staleReleaseCalls.Add(1)
|
|
}, ctx.Err()
|
|
case currentAgentID:
|
|
return currentConn, func() {
|
|
currentReleaseCalls.Add(1)
|
|
}, nil
|
|
default:
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
},
|
|
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
|
if id != workspaceID {
|
|
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
|
}
|
|
validateCalls.Add(1)
|
|
return currentAgentID, nil
|
|
},
|
|
0,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Same(t, currentConn, result.Conn)
|
|
require.Equal(t, currentAgentID, result.AgentID)
|
|
require.True(t, result.WasSwitched)
|
|
require.Eventually(t, func() bool {
|
|
return dialCalls.Load() == 2
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
require.EqualValues(t, 1, validateCalls.Load())
|
|
require.EqualValues(t, 0, staleReleaseCalls.Load())
|
|
|
|
if result.Release != nil {
|
|
result.Release()
|
|
}
|
|
require.EqualValues(t, 1, currentReleaseCalls.Load())
|
|
})
|
|
|
|
t.Run("SwitchDoesNotBlock", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
staleAgentID := uuid.New()
|
|
currentAgentID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
staleConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
currentConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
staleDialStarted := make(chan struct{})
|
|
allowStaleReturn := make(chan struct{})
|
|
|
|
var dialCalls atomic.Int32
|
|
var validateCalls atomic.Int32
|
|
var staleReleaseCalls atomic.Int32
|
|
var currentReleaseCalls atomic.Int32
|
|
var staleReturnReleased atomic.Bool
|
|
releaseStaleReturn := func() {
|
|
if staleReturnReleased.CompareAndSwap(false, true) {
|
|
close(allowStaleReturn)
|
|
}
|
|
}
|
|
defer releaseStaleReturn()
|
|
|
|
resultCh := make(chan DialResult, 1)
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
result, err := dialWithLazyValidation(
|
|
context.Background(),
|
|
staleAgentID,
|
|
workspaceID,
|
|
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
dialCalls.Add(1)
|
|
switch id {
|
|
case staleAgentID:
|
|
close(staleDialStarted)
|
|
<-allowStaleReturn
|
|
return staleConn, func() {
|
|
staleReleaseCalls.Add(1)
|
|
}, nil
|
|
case currentAgentID:
|
|
return currentConn, func() {
|
|
currentReleaseCalls.Add(1)
|
|
}, nil
|
|
default:
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
},
|
|
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
|
if id != workspaceID {
|
|
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
|
}
|
|
<-staleDialStarted
|
|
validateCalls.Add(1)
|
|
return currentAgentID, nil
|
|
},
|
|
0,
|
|
)
|
|
if err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
resultCh <- result
|
|
}()
|
|
|
|
var result DialResult
|
|
select {
|
|
case err := <-errCh:
|
|
require.NoError(t, err)
|
|
case result = <-resultCh:
|
|
require.Same(t, currentConn, result.Conn)
|
|
require.Equal(t, currentAgentID, result.AgentID)
|
|
require.True(t, result.WasSwitched)
|
|
releaseStaleReturn()
|
|
case <-time.After(testutil.WaitShort):
|
|
t.Fatal("dialWithLazyValidation blocked on stale dial cleanup")
|
|
}
|
|
|
|
require.EqualValues(t, 2, dialCalls.Load())
|
|
require.EqualValues(t, 1, validateCalls.Load())
|
|
require.Eventually(t, func() bool {
|
|
return staleReleaseCalls.Load() == 1
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
|
|
if result.Release != nil {
|
|
result.Release()
|
|
}
|
|
require.EqualValues(t, 1, currentReleaseCalls.Load())
|
|
})
|
|
}
|
|
|
|
func TestDialWithLazyValidation_FastFailure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
staleAgentID := uuid.New()
|
|
currentAgentID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
currentConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
|
|
var dialCalls atomic.Int32
|
|
var validateCalls atomic.Int32
|
|
var currentReleaseCalls atomic.Int32
|
|
|
|
result, err := dialWithLazyValidation(
|
|
context.Background(),
|
|
staleAgentID,
|
|
workspaceID,
|
|
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
switch dialCalls.Add(1) {
|
|
case 1:
|
|
if id != staleAgentID {
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
return nil, nil, xerrors.New("dial failed")
|
|
case 2:
|
|
if id != currentAgentID {
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
return currentConn, func() {
|
|
currentReleaseCalls.Add(1)
|
|
}, nil
|
|
default:
|
|
return nil, nil, xerrors.New("unexpected dial call")
|
|
}
|
|
},
|
|
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
|
if id != workspaceID {
|
|
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
|
}
|
|
validateCalls.Add(1)
|
|
return currentAgentID, nil
|
|
},
|
|
time.Minute,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Same(t, currentConn, result.Conn)
|
|
require.Equal(t, currentAgentID, result.AgentID)
|
|
require.True(t, result.WasSwitched)
|
|
require.EqualValues(t, 2, dialCalls.Load())
|
|
require.EqualValues(t, 1, validateCalls.Load())
|
|
|
|
if result.Release != nil {
|
|
result.Release()
|
|
}
|
|
require.EqualValues(t, 1, currentReleaseCalls.Load())
|
|
}
|
|
|
|
func TestDialWithLazyValidation_FastFailureSameAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
agentID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
|
|
var dialCalls atomic.Int32
|
|
var releaseCalls atomic.Int32
|
|
var validateCalls atomic.Int32
|
|
|
|
result, err := dialWithLazyValidation(
|
|
context.Background(),
|
|
agentID,
|
|
workspaceID,
|
|
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
if id != agentID {
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
switch dialCalls.Add(1) {
|
|
case 1:
|
|
return nil, nil, xerrors.New("dial failed")
|
|
case 2:
|
|
return conn, func() {
|
|
releaseCalls.Add(1)
|
|
}, nil
|
|
default:
|
|
return nil, nil, xerrors.New("unexpected dial call")
|
|
}
|
|
},
|
|
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
|
if id != workspaceID {
|
|
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
|
}
|
|
validateCalls.Add(1)
|
|
return agentID, nil
|
|
},
|
|
time.Minute,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Same(t, conn, result.Conn)
|
|
require.Equal(t, agentID, result.AgentID)
|
|
require.False(t, result.WasSwitched)
|
|
require.EqualValues(t, 2, dialCalls.Load())
|
|
require.EqualValues(t, 1, validateCalls.Load())
|
|
|
|
if result.Release != nil {
|
|
result.Release()
|
|
}
|
|
require.EqualValues(t, 1, releaseCalls.Load())
|
|
}
|
|
|
|
func TestDialWithLazyValidation_FastFailureSameAgentRetryFails(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
agentID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
|
|
var dialCalls atomic.Int32
|
|
var validateCalls atomic.Int32
|
|
|
|
_, err := dialWithLazyValidation(
|
|
context.Background(),
|
|
agentID,
|
|
workspaceID,
|
|
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
if id != agentID {
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
switch dialCalls.Add(1) {
|
|
case 1:
|
|
return nil, nil, xerrors.New("dial failed")
|
|
case 2:
|
|
return nil, nil, xerrors.New("retry failed")
|
|
default:
|
|
return nil, nil, xerrors.New("unexpected dial call")
|
|
}
|
|
},
|
|
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
|
if id != workspaceID {
|
|
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
|
}
|
|
validateCalls.Add(1)
|
|
return agentID, nil
|
|
},
|
|
time.Minute,
|
|
)
|
|
require.EqualError(t, err, "dial with lazy validation: retry failed")
|
|
require.EqualValues(t, 2, dialCalls.Load())
|
|
require.EqualValues(t, 1, validateCalls.Load())
|
|
}
|
|
|
|
func TestDialWithLazyValidation_ValidationError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
agentID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
unblockDial := make(chan struct{})
|
|
|
|
var releaseCalls atomic.Int32
|
|
var validateCalls atomic.Int32
|
|
|
|
result, err := dialWithLazyValidation(
|
|
context.Background(),
|
|
agentID,
|
|
workspaceID,
|
|
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
if id != agentID {
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
select {
|
|
case <-unblockDial:
|
|
return conn, func() {
|
|
releaseCalls.Add(1)
|
|
}, nil
|
|
case <-ctx.Done():
|
|
return nil, nil, ctx.Err()
|
|
}
|
|
},
|
|
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
|
if id != workspaceID {
|
|
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
|
}
|
|
validateCalls.Add(1)
|
|
// Validation fails — code should fall back to waiting
|
|
// for the original dial.
|
|
close(unblockDial)
|
|
return uuid.Nil, xerrors.New("db connection reset")
|
|
},
|
|
0,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Same(t, conn, result.Conn)
|
|
require.Equal(t, agentID, result.AgentID)
|
|
require.False(t, result.WasSwitched)
|
|
require.EqualValues(t, 1, validateCalls.Load())
|
|
|
|
if result.Release != nil {
|
|
result.Release()
|
|
}
|
|
require.EqualValues(t, 1, releaseCalls.Load())
|
|
}
|
|
|
|
func TestDialWithLazyValidation_ContextCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
agentID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
|
|
var validateCalls atomic.Int32
|
|
|
|
_, err := dialWithLazyValidation(
|
|
ctx,
|
|
agentID,
|
|
workspaceID,
|
|
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
if id != agentID {
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
|
}
|
|
<-ctx.Done()
|
|
return nil, nil, ctx.Err()
|
|
},
|
|
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
|
if id != workspaceID {
|
|
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
|
}
|
|
validateCalls.Add(1)
|
|
cancel()
|
|
return agentID, nil
|
|
},
|
|
0,
|
|
)
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
require.EqualValues(t, 1, validateCalls.Load())
|
|
}
|