Files
coder/coderd/x/chatd/configcache_internal_test.go
T
Ethan c650aabbef chore: standardize on *_internal_test.go for white-box tests (#25601)
My agent added `//nolint:testpackage` to a test file on one of my PRs.
Again. This PR cleans it up across the entire repo and updates the
in-repo conventions so future agents stop doing it.

The repo already has a precedent for white-box tests that need to touch
unexported symbols: `*_internal_test.go` (145+ existing files). The
`testpackage` linter's default `skip-regexp` exempts that filename
suffix, so the `//nolint:testpackage` directive is unnecessary in every
case where someone reached for it. This PR renames 51 such files to
`*_internal_test.go` via `git mv` so blame and history follow, and
strips the dead directive from 2 files that were already correctly named
(`coderd/oauth2provider/authorize_internal_test.go`,
`coderd/x/chatd/advisor_internal_test.go`).

`.claude/docs/TESTING.md` now documents the rule explicitly under *Test
Package Naming*, which is imported into the root `AGENTS.md` via
`@.claude/docs/TESTING.md`. The rule: prefer `package foo_test`; if you
need internal access, rename the file to `*_internal_test.go` rather
than adding a nolint directive.
2026-05-22 20:24:38 +10:00

1204 lines
36 KiB
Go

package chatd
import (
"context"
"database/sql"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// stubChatConfigStore is a minimal database.Store fake for the chat
// config cache. It overrides the five queries the cache issues and
// leaves everything else to the embedded database.Store (left nil by
// the tests in this file).
//
// Each overridden query first bumps its own call counter and then
// delegates to the matching func field, so tests can assert how many
// times the cache really reached the store. A nil func field means
// the test did not expect that query, and the stub panics.
type stubChatConfigStore struct {
	database.Store

	// Per-query behavior supplied by each test.
	getAIProviders            func(context.Context) ([]database.AIProvider, error)
	getChatModelConfigByID    func(context.Context, uuid.UUID) (database.ChatModelConfig, error)
	getDefaultChatModelConfig func(context.Context) (database.ChatModelConfig, error)
	getUserChatCustomPrompt   func(context.Context, uuid.UUID) (string, error)
	getChatAdvisorConfig      func(context.Context) (string, error)

	// Per-query invocation counters. They are incremented before the
	// func field runs, so a fetch can read its own 1-based call number.
	enabledProvidersCalls  atomic.Int32
	modelConfigByIDCalls   atomic.Int32
	defaultModelConfigCall atomic.Int32
	userPromptCalls        atomic.Int32
	advisorConfigCalls     atomic.Int32
}

// GetAIProviders counts the call and delegates to getAIProviders; the
// query parameters are ignored.
func (s *stubChatConfigStore) GetAIProviders(ctx context.Context, _ database.GetAIProvidersParams) ([]database.AIProvider, error) {
	s.enabledProvidersCalls.Add(1)
	if fetch := s.getAIProviders; fetch != nil {
		return fetch(ctx)
	}
	panic("unexpected GetAIProviders call")
}

// GetChatModelConfigByID counts the call and delegates to
// getChatModelConfigByID.
func (s *stubChatConfigStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
	s.modelConfigByIDCalls.Add(1)
	if fetch := s.getChatModelConfigByID; fetch != nil {
		return fetch(ctx, id)
	}
	panic("unexpected GetChatModelConfigByID call")
}

// GetDefaultChatModelConfig counts the call and delegates to
// getDefaultChatModelConfig.
func (s *stubChatConfigStore) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
	s.defaultModelConfigCall.Add(1)
	if fetch := s.getDefaultChatModelConfig; fetch != nil {
		return fetch(ctx)
	}
	panic("unexpected GetDefaultChatModelConfig call")
}

// GetUserChatCustomPrompt counts the call and delegates to
// getUserChatCustomPrompt.
func (s *stubChatConfigStore) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
	s.userPromptCalls.Add(1)
	if fetch := s.getUserChatCustomPrompt; fetch != nil {
		return fetch(ctx, userID)
	}
	panic("unexpected GetUserChatCustomPrompt call")
}

// GetChatAdvisorConfig counts the call and delegates to
// getChatAdvisorConfig.
func (s *stubChatConfigStore) GetChatAdvisorConfig(ctx context.Context) (string, error) {
	s.advisorConfigCalls.Add(1)
	if fetch := s.getChatAdvisorConfig; fetch != nil {
		return fetch(ctx)
	}
	panic("unexpected GetChatAdvisorConfig call")
}
// TestConfigCache_EnabledProviders_CacheHit checks that a repeated
// EnabledProviders lookup returns the same providers without querying
// the store a second time.
func TestConfigCache_EnabledProviders_CacheHit(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	want := []database.AIProvider{testAIProvider("provider-a")}
	stub := &stubChatConfigStore{
		getAIProviders: func(context.Context) ([]database.AIProvider, error) {
			return want, nil
		},
	}
	c := newChatConfigCache(ctx, stub, mClock)

	got1, err := c.EnabledProviders(ctx)
	require.NoError(t, err)
	got2, err := c.EnabledProviders(ctx)
	require.NoError(t, err)

	require.Equal(t, want, got1)
	require.Equal(t, want, got2)
	require.Equal(t, int32(1), stub.enabledProvidersCalls.Load())
}
// TestConfigCache_EnabledProviders_TTLExpiry checks that advancing the
// mock clock by chatConfigProvidersTTL forces a fresh provider fetch.
func TestConfigCache_EnabledProviders_TTLExpiry(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	stub := &stubChatConfigStore{}
	// The stub bumps the counter before calling this, so the Nth fetch
	// produces "provider-N" and successive fetches differ.
	stub.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
		n := stub.enabledProvidersCalls.Load()
		name := fmt.Sprintf("provider-%d", n)
		return []database.AIProvider{testAIProvider(name)}, nil
	}
	c := newChatConfigCache(ctx, stub, mClock)

	before, err := c.EnabledProviders(ctx)
	require.NoError(t, err)

	mClock.Advance(chatConfigProvidersTTL).MustWait(ctx)

	after, err := c.EnabledProviders(ctx)
	require.NoError(t, err)
	require.NotEqual(t, before, after)
	require.Equal(t, int32(2), stub.enabledProvidersCalls.Load())
}
// TestConfigCache_EnabledProviders_Invalidation checks that
// InvalidateProviders discards the cached providers so the following
// lookup goes back to the store.
func TestConfigCache_EnabledProviders_Invalidation(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	stub := &stubChatConfigStore{}
	stub.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
		n := stub.enabledProvidersCalls.Load()
		return []database.AIProvider{
			testAIProvider(fmt.Sprintf("provider-%d", n)),
		}, nil
	}
	c := newChatConfigCache(ctx, stub, mClock)

	before, err := c.EnabledProviders(ctx)
	require.NoError(t, err)

	c.InvalidateProviders()

	after, err := c.EnabledProviders(ctx)
	require.NoError(t, err)
	require.NotEqual(t, before, after)
	require.Equal(t, int32(2), stub.enabledProvidersCalls.Load())
}
// TestConfigCache_ModelConfigByID_CacheHit checks that a repeated
// ModelConfigByID lookup for the same ID is answered from cache.
func TestConfigCache_ModelConfigByID_CacheHit(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	id := uuid.New()
	want := testChatModelConfig(id, "model-a")
	stub := &stubChatConfigStore{
		getChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
			return want, nil
		},
	}
	c := newChatConfigCache(ctx, stub, mClock)

	got1, err := c.ModelConfigByID(ctx, id)
	require.NoError(t, err)
	got2, err := c.ModelConfigByID(ctx, id)
	require.NoError(t, err)

	require.Equal(t, want, got1)
	require.Equal(t, want, got2)
	require.Equal(t, int32(1), stub.modelConfigByIDCalls.Load())
}
// TestConfigCache_ModelConfigByID_ClonesOptionsForCache checks that
// callers get private copies of ChatModelConfig.Options: mutating the
// slice returned by the initial fill, or by a later cache hit, must
// not change what subsequent callers see.
func TestConfigCache_ModelConfigByID_ClonesOptionsForCache(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	configID := uuid.New()
	const options = `{"temperature":0.1}`
	config := testChatModelConfig(configID, "model-a")
	config.Options = []byte(options)
	store := &stubChatConfigStore{
		getChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
			return config, nil
		},
	}
	cache := newChatConfigCache(ctx, store, clock)
	// First call populates cache via singleflight.
	first, err := cache.ModelConfigByID(ctx, configID)
	require.NoError(t, err)
	first.Options[0] = 'x' // mutate singleflight return
	// Second call is a cache hit.
	second, err := cache.ModelConfigByID(ctx, configID)
	require.NoError(t, err)
	require.Equal(t, options, string(second.Options))
	second.Options[0] = 'y' // mutate cache-hit return
	// Third call is another cache hit — must be unaffected.
	third, err := cache.ModelConfigByID(ctx, configID)
	require.NoError(t, err)
	require.Equal(t, options, string(third.Options))
	// Pin the store call count so the second and third lookups are
	// verified to be cache hits, as the comments above claim. Otherwise
	// the cache-hit clone path might not be exercised at all.
	require.Equal(t, int32(1), store.modelConfigByIDCalls.Load())
}
// TestConfigCache_ModelConfigByID_NotFound checks that sql.ErrNoRows
// is passed through to the caller and is not cached: every lookup hits
// the store again and no modelConfigs entry is left behind.
func TestConfigCache_ModelConfigByID_NotFound(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	id := uuid.New()
	stub := &stubChatConfigStore{
		getChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
			return database.ChatModelConfig{}, sql.ErrNoRows
		},
	}
	c := newChatConfigCache(ctx, stub, mClock)

	// Both lookups must surface the store error.
	for attempt := 0; attempt < 2; attempt++ {
		_, err := c.ModelConfigByID(ctx, id)
		require.ErrorIs(t, err, sql.ErrNoRows)
	}
	require.Equal(t, int32(2), stub.modelConfigByIDCalls.Load())

	// A not-found result must not populate the per-ID map.
	_, cached := c.modelConfigs[id]
	require.False(t, cached)
}
// TestConfigCache_InvalidateModelConfig_CascadesToDefault checks that
// invalidating a model config by ID also drops the cached default model
// config, so the next DefaultModelConfig call refetches it.
func TestConfigCache_InvalidateModelConfig_CascadesToDefault(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	id := uuid.New()
	byID := testChatModelConfig(id, "model-a")

	stub := &stubChatConfigStore{}
	stub.getChatModelConfigByID = func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
		return byID, nil
	}
	stub.getDefaultChatModelConfig = func(context.Context) (database.ChatModelConfig, error) {
		// Every fetch yields a distinct default (fresh ID, numbered name).
		n := stub.defaultModelConfigCall.Load()
		return testChatModelConfig(uuid.New(), fmt.Sprintf("default-model-%d", n)), nil
	}
	c := newChatConfigCache(ctx, stub, mClock)

	// Warm both the per-ID entry and the default entry.
	_, err := c.ModelConfigByID(ctx, id)
	require.NoError(t, err)
	defaultBefore, err := c.DefaultModelConfig(ctx)
	require.NoError(t, err)

	c.InvalidateModelConfig(id)
	require.Nil(t, c.defaultModelConfig)

	defaultAfter, err := c.DefaultModelConfig(ctx)
	require.NoError(t, err)
	require.NotEqual(t, defaultBefore, defaultAfter)
	require.Equal(t, int32(2), stub.defaultModelConfigCall.Load())
}
// TestConfigCache_UserPrompt_NegativeCaching checks that a missing
// custom prompt (sql.ErrNoRows) is reported as an empty prompt with no
// error, and that this "no prompt" answer is itself cached.
func TestConfigCache_UserPrompt_NegativeCaching(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	uid := uuid.New()
	stub := &stubChatConfigStore{
		getUserChatCustomPrompt: func(context.Context, uuid.UUID) (string, error) {
			return "", sql.ErrNoRows
		},
	}
	c := newChatConfigCache(ctx, stub, mClock)

	p1, err := c.UserPrompt(ctx, uid)
	require.NoError(t, err)
	p2, err := c.UserPrompt(ctx, uid)
	require.NoError(t, err)

	require.Empty(t, p1)
	require.Empty(t, p2)
	require.Equal(t, int32(1), stub.userPromptCalls.Load())
}
// TestConfigCache_UserPrompt_ExpiredEntryRefetches seeds an expired
// user-prompt entry and checks that UserPrompt ignores it. The lookup
// should fetch once from the store and cache the fresh value for the
// following call.
func TestConfigCache_UserPrompt_ExpiredEntryRefetches(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	uid := uuid.New()
	stub := &stubChatConfigStore{}
	stub.getUserChatCustomPrompt = func(context.Context, uuid.UUID) (string, error) {
		return fmt.Sprintf("prompt-%d", stub.userPromptCalls.Load()), nil
	}
	c := newChatConfigCache(ctx, stub, mClock)

	// The negative TTL is meant to seed an already-expired entry.
	c.userPrompts.Set(uid, "stale", -time.Second)

	got1, err := c.UserPrompt(ctx, uid)
	require.NoError(t, err)
	got2, err := c.UserPrompt(ctx, uid)
	require.NoError(t, err)

	require.Equal(t, "prompt-1", got1)
	require.Equal(t, got1, got2)
	require.Equal(t, int32(1), stub.userPromptCalls.Load())
}
// TestConfigCache_InvalidateUserPrompt checks that InvalidateUserPrompt
// evicts the user's cached prompt so the next lookup refetches it.
func TestConfigCache_InvalidateUserPrompt(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	uid := uuid.New()
	stub := &stubChatConfigStore{}
	stub.getUserChatCustomPrompt = func(context.Context, uuid.UUID) (string, error) {
		n := stub.userPromptCalls.Load()
		return fmt.Sprintf("prompt-%d", n), nil
	}
	c := newChatConfigCache(ctx, stub, mClock)

	before, err := c.UserPrompt(ctx, uid)
	require.NoError(t, err)

	c.InvalidateUserPrompt(uid)

	after, err := c.UserPrompt(ctx, uid)
	require.NoError(t, err)
	require.NotEqual(t, before, after)
	require.Equal(t, int32(2), stub.userPromptCalls.Load())
}
// TestConfigCache_InvalidateUserPrompt_BlocksStaleInFlightPrompt checks
// that a prompt fetch already in flight when InvalidateUserPrompt runs
// cannot write its stale result back into the cache.
//
// Sequence:
//   - fetch 1 starts and blocks;
//   - the prompt is invalidated;
//   - fetch 2 starts and blocks;
//   - fetch 1 is released: its caller still receives the stale prompt,
//     but the cache must stay empty;
//   - fetch 2 is released: its result is cached, and a third lookup is
//     served without another store call.
func TestConfigCache_InvalidateUserPrompt_BlocksStaleInFlightPrompt(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitMedium)
	clock := quartz.NewMock(t)
	userID := uuid.New()
	const stalePrompt = "stale prompt"
	const freshPrompt = "fresh prompt"
	firstStarted := make(chan struct{})
	secondStarted := make(chan struct{})
	releaseFirst := make(chan struct{})
	releaseSecond := make(chan struct{})
	store := &stubChatConfigStore{}
	// The stub bumps the counter before calling this, so call 1 is the
	// pre-invalidation fetch and call 2 the post-invalidation one. Each
	// blocks until the test releases it, fixing the interleaving.
	store.getUserChatCustomPrompt = func(context.Context, uuid.UUID) (string, error) {
		switch call := store.userPromptCalls.Load(); call {
		case 1:
			close(firstStarted)
			<-releaseFirst
			return stalePrompt, nil
		case 2:
			close(secondStarted)
			<-releaseSecond
			return freshPrompt, nil
		default:
			return "", xerrors.Errorf("unexpected user prompt call %d", call)
		}
	}
	cache := newChatConfigCache(ctx, store, clock)
	type result struct {
		prompt string
		err    error
	}
	// await receives one result. If UserPrompt never returns, it fails
	// the test when ctx expires instead of hanging until the global
	// go test timeout.
	await := func(ch <-chan result) result {
		t.Helper()
		select {
		case r := <-ch:
			return r
		case <-ctx.Done():
			t.Fatal("timed out waiting for UserPrompt result")
			return result{}
		}
	}
	firstResult := make(chan result, 1)
	go func() {
		prompt, err := cache.UserPrompt(ctx, userID)
		firstResult <- result{prompt: prompt, err: err}
	}()
	waitForSignal(t, firstStarted)
	cache.InvalidateUserPrompt(userID)
	secondResult := make(chan result, 1)
	go func() {
		prompt, err := cache.UserPrompt(ctx, userID)
		secondResult <- result{prompt: prompt, err: err}
	}()
	// Both fetches are now blocked inside the store.
	waitForSignal(t, secondStarted)
	close(releaseFirst)
	first := await(firstResult)
	require.NoError(t, first.err)
	require.Equal(t, stalePrompt, first.prompt)
	// The stale in-flight result must not have been cached.
	_, _, ok := cache.userPrompts.Get(userID)
	require.False(t, ok)
	close(releaseSecond)
	second := await(secondResult)
	require.NoError(t, second.err)
	require.Equal(t, freshPrompt, second.prompt)
	require.Equal(t, int32(2), store.userPromptCalls.Load())
	// The fresh result is cached: no third store call.
	third, err := cache.UserPrompt(ctx, userID)
	require.NoError(t, err)
	require.Equal(t, freshPrompt, third)
	require.Equal(t, int32(2), store.userPromptCalls.Load())
}
// TestConfigCache_Singleflight checks that concurrent EnabledProviders
// callers share a single store fetch. Eight goroutines are released at
// once while the fetch is held open. The store must be queried exactly
// once, and every caller must receive the same providers.
//
// NOTE(review): waitForSignal only proves that some caller entered the
// store before releaseFetch is closed. Callers that arrive after the
// fetch completes are presumably served from the cache rather than
// joining the in-flight call. Either way the store-call count should
// stay at 1.
func TestConfigCache_Singleflight(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitMedium)
	clock := quartz.NewMock(t)
	providers := []database.AIProvider{testAIProvider("provider-a")}
	fetchStarted := make(chan struct{})
	releaseFetch := make(chan struct{})
	var startedOnce sync.Once
	store := &stubChatConfigStore{}
	// Each fetch signals entry (only the first close happens, via
	// startedOnce) and then blocks until the test releases it.
	store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
		startedOnce.Do(func() { close(fetchStarted) })
		<-releaseFetch
		return providers, nil
	}
	cache := newChatConfigCache(ctx, store, clock)
	const callers = 8
	// Each goroutine writes only its own index. wg.Wait() below orders
	// those writes before the reads in the assertion loop.
	results := make([][]database.AIProvider, callers)
	errs := make([]error, callers)
	var wg sync.WaitGroup
	// start acts as a barrier so the callers begin at roughly the same
	// time.
	start := make(chan struct{})
	for i := 0; i < callers; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			<-start
			results[i], errs[i] = cache.EnabledProviders(ctx)
		}(i)
	}
	close(start)
	// Hold the fetch open until at least one caller is inside the store.
	waitForSignal(t, fetchStarted)
	close(releaseFetch)
	wg.Wait()
	for i := 0; i < callers; i++ {
		require.NoError(t, errs[i])
		require.Equal(t, providers, results[i])
	}
	require.Equal(t, int32(1), store.enabledProvidersCalls.Load())
}
// TestConfigCache_GenerationPreventsStaleWrite checks that a provider
// fetch which is in flight when InvalidateProviders runs does not write
// its result into the cache. The in-flight caller still receives the
// first providers, but cache.providers must be nil afterwards, and the
// next lookup must fetch again (store call 2) and get the second
// providers.
func TestConfigCache_GenerationPreventsStaleWrite(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitMedium)
	clock := quartz.NewMock(t)
	firstProviders := []database.AIProvider{testAIProvider("provider-a")}
	secondProviders := []database.AIProvider{testAIProvider("provider-b")}
	fetchStarted := make(chan struct{})
	releaseFetch := make(chan struct{})
	var startedOnce sync.Once
	store := &stubChatConfigStore{}
	// Store call 1 blocks until released and returns firstProviders;
	// any later call returns secondProviders immediately.
	store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
		call := store.enabledProvidersCalls.Load()
		if call == 1 {
			startedOnce.Do(func() { close(fetchStarted) })
			<-releaseFetch
			return firstProviders, nil
		}
		return secondProviders, nil
	}
	cache := newChatConfigCache(ctx, store, clock)
	resultCh := make(chan []database.AIProvider, 1)
	errCh := make(chan error, 1)
	go func() {
		providers, err := cache.EnabledProviders(ctx)
		if err != nil {
			errCh <- err
			return
		}
		resultCh <- providers
	}()
	// Invalidate while fetch 1 is blocked inside the store, then let it
	// finish.
	waitForSignal(t, fetchStarted)
	cache.InvalidateProviders()
	close(releaseFetch)
	select {
	case err := <-errCh:
		// Only non-nil errors are sent, so this always fails the test.
		require.NoError(t, err)
	case providers := <-resultCh:
		require.Equal(t, firstProviders, providers)
	case <-time.After(testutil.WaitShort):
		t.Fatal("timed out waiting for in-flight fetch")
	}
	// The stale fetch must not have populated the cache.
	require.Nil(t, cache.providers)
	second, err := cache.EnabledProviders(ctx)
	require.NoError(t, err)
	require.Equal(t, secondProviders, second)
	require.Equal(t, int32(2), store.enabledProvidersCalls.Load())
}
// TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders
// checks that a provider fetch already in flight when
// InvalidateProviders runs cannot write its stale result back into the
// cache, while a fetch started after the invalidation is cached
// normally.
//
// Sequence:
//   - fetch 1 starts and blocks;
//   - providers are invalidated;
//   - fetch 2 starts and blocks;
//   - fetch 1 is released: its caller gets the stale providers, but
//     cache.providers must remain nil;
//   - fetch 2 is released: its result is cached, and a third lookup
//     makes no new store call.
func TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitMedium)
	clock := quartz.NewMock(t)
	staleProviders := []database.AIProvider{testAIProvider("provider-stale")}
	freshProviders := []database.AIProvider{testAIProvider("provider-fresh")}
	firstStarted := make(chan struct{})
	secondStarted := make(chan struct{})
	releaseFirst := make(chan struct{})
	releaseSecond := make(chan struct{})
	store := &stubChatConfigStore{}
	// The stub bumps the counter before calling this, so call 1 is the
	// pre-invalidation fetch and call 2 the post-invalidation one.
	store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
		switch call := store.enabledProvidersCalls.Load(); call {
		case 1:
			close(firstStarted)
			<-releaseFirst
			return staleProviders, nil
		case 2:
			close(secondStarted)
			<-releaseSecond
			return freshProviders, nil
		default:
			return nil, xerrors.Errorf("unexpected provider call %d", call)
		}
	}
	cache := newChatConfigCache(ctx, store, clock)
	type result struct {
		providers []database.AIProvider
		err       error
	}
	// await receives one result. If EnabledProviders never returns, it
	// fails the test when ctx expires instead of hanging until the
	// global go test timeout.
	await := func(ch <-chan result) result {
		t.Helper()
		select {
		case r := <-ch:
			return r
		case <-ctx.Done():
			t.Fatal("timed out waiting for EnabledProviders result")
			return result{}
		}
	}
	firstResult := make(chan result, 1)
	go func() {
		providers, err := cache.EnabledProviders(ctx)
		firstResult <- result{providers: providers, err: err}
	}()
	waitForSignal(t, firstStarted)
	cache.InvalidateProviders()
	secondResult := make(chan result, 1)
	go func() {
		providers, err := cache.EnabledProviders(ctx)
		secondResult <- result{providers: providers, err: err}
	}()
	// Both fetches are now blocked inside the store.
	waitForSignal(t, secondStarted)
	close(releaseFirst)
	first := await(firstResult)
	require.NoError(t, first.err)
	require.Equal(t, staleProviders, first.providers)
	// The stale in-flight result must not have been cached.
	require.Nil(t, cache.providers)
	close(releaseSecond)
	second := await(secondResult)
	require.NoError(t, second.err)
	require.Equal(t, freshProviders, second.providers)
	require.Equal(t, int32(2), store.enabledProvidersCalls.Load())
	// The fresh result is cached: no third store call.
	third, err := cache.EnabledProviders(ctx)
	require.NoError(t, err)
	require.Equal(t, freshProviders, third)
	require.Equal(t, int32(2), store.enabledProvidersCalls.Load())
}
// TestConfigCache_InvalidateProviders_CascadesToModelConfigs checks
// that InvalidateProviders also evicts cached per-ID model configs, so
// the next ModelConfigByID lookup refetches.
func TestConfigCache_InvalidateProviders_CascadesToModelConfigs(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	id := uuid.New()
	stub := &stubChatConfigStore{}
	stub.getChatModelConfigByID = func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
		n := stub.modelConfigByIDCalls.Load()
		modelName := fmt.Sprintf("model-%d", n)
		return testChatModelConfig(id, modelName), nil
	}
	c := newChatConfigCache(ctx, stub, mClock)

	before, err := c.ModelConfigByID(ctx, id)
	require.NoError(t, err)

	c.InvalidateProviders()

	after, err := c.ModelConfigByID(ctx, id)
	require.NoError(t, err)
	require.NotEqual(t, before, after)
	require.Equal(t, int32(2), stub.modelConfigByIDCalls.Load())
}
// TestConfigCache_InvalidateProviders_CascadesToDefaultModelConfig
// checks that InvalidateProviders also evicts the cached default model
// config, so the next DefaultModelConfig call refetches.
func TestConfigCache_InvalidateProviders_CascadesToDefaultModelConfig(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	stub := &stubChatConfigStore{}
	stub.getDefaultChatModelConfig = func(context.Context) (database.ChatModelConfig, error) {
		n := stub.defaultModelConfigCall.Load()
		modelName := fmt.Sprintf("default-model-%d", n)
		return testChatModelConfig(uuid.New(), modelName), nil
	}
	c := newChatConfigCache(ctx, stub, mClock)

	before, err := c.DefaultModelConfig(ctx)
	require.NoError(t, err)

	c.InvalidateProviders()

	after, err := c.DefaultModelConfig(ctx)
	require.NoError(t, err)
	require.NotEqual(t, before, after)
	require.Equal(t, int32(2), stub.defaultModelConfigCall.Load())
}
// TestConfigCache_InvalidateProviders_BlocksStaleInFlightModelConfig
// checks that a ModelConfigByID fetch already in flight when
// InvalidateProviders runs cannot write its stale result back into
// cache.modelConfigs, while a fetch started after the invalidation is
// cached normally.
//
// Sequence:
//   - fetch 1 starts and blocks;
//   - providers are invalidated;
//   - fetch 2 starts and blocks;
//   - fetch 1 is released: its caller gets the stale config, but no map
//     entry may exist;
//   - fetch 2 is released: its result is cached, and a third lookup
//     makes no new store call.
func TestConfigCache_InvalidateProviders_BlocksStaleInFlightModelConfig(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitMedium)
	clock := quartz.NewMock(t)
	configID := uuid.New()
	staleConfig := testChatModelConfig(configID, "stale-model")
	freshConfig := testChatModelConfig(configID, "fresh-model")
	firstStarted := make(chan struct{})
	secondStarted := make(chan struct{})
	releaseFirst := make(chan struct{})
	releaseSecond := make(chan struct{})
	store := &stubChatConfigStore{}
	// The stub bumps the counter before calling this, so call 1 is the
	// pre-invalidation fetch and call 2 the post-invalidation one.
	store.getChatModelConfigByID = func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
		switch call := store.modelConfigByIDCalls.Load(); call {
		case 1:
			close(firstStarted)
			<-releaseFirst
			return staleConfig, nil
		case 2:
			close(secondStarted)
			<-releaseSecond
			return freshConfig, nil
		default:
			return database.ChatModelConfig{}, xerrors.Errorf("unexpected model config call %d", call)
		}
	}
	cache := newChatConfigCache(ctx, store, clock)
	type result struct {
		config database.ChatModelConfig
		err    error
	}
	// await receives one result. If ModelConfigByID never returns, it
	// fails the test when ctx expires instead of hanging until the
	// global go test timeout.
	await := func(ch <-chan result) result {
		t.Helper()
		select {
		case r := <-ch:
			return r
		case <-ctx.Done():
			t.Fatal("timed out waiting for ModelConfigByID result")
			return result{}
		}
	}
	firstResult := make(chan result, 1)
	go func() {
		config, err := cache.ModelConfigByID(ctx, configID)
		firstResult <- result{config: config, err: err}
	}()
	waitForSignal(t, firstStarted)
	cache.InvalidateProviders()
	secondResult := make(chan result, 1)
	go func() {
		config, err := cache.ModelConfigByID(ctx, configID)
		secondResult <- result{config: config, err: err}
	}()
	// Both fetches are now blocked inside the store.
	waitForSignal(t, secondStarted)
	close(releaseFirst)
	first := await(firstResult)
	require.NoError(t, first.err)
	require.Equal(t, staleConfig, first.config)
	// The stale in-flight result must not have been cached.
	_, ok := cache.modelConfigs[configID]
	require.False(t, ok)
	close(releaseSecond)
	second := await(secondResult)
	require.NoError(t, second.err)
	require.Equal(t, freshConfig, second.config)
	require.Equal(t, int32(2), store.modelConfigByIDCalls.Load())
	// The fresh result is cached: no third store call.
	third, err := cache.ModelConfigByID(ctx, configID)
	require.NoError(t, err)
	require.Equal(t, freshConfig, third)
	require.Equal(t, int32(2), store.modelConfigByIDCalls.Load())
}
// testAIProvider builds an enabled database.AIProvider fixture.
// Type, Name and DisplayName all equal name. Every call gets a fresh
// random ID, and both timestamps are the Unix epoch in UTC.
func testAIProvider(name string) database.AIProvider {
	epoch := time.Unix(0, 0).UTC()
	p := database.AIProvider{
		ID:      uuid.New(),
		Type:    database.AIProviderType(name),
		Name:    name,
		Enabled: true,
	}
	p.DisplayName = sql.NullString{Valid: true, String: name}
	p.CreatedAt, p.UpdatedAt = epoch, epoch
	return p
}
// testChatModelConfig builds an enabled ChatModelConfig fixture with the
// given ID and provider "openai". model is used for both Model and
// DisplayName. The context limit and compression threshold are fixed,
// and both timestamps are the Unix epoch in UTC.
func testChatModelConfig(id uuid.UUID, model string) database.ChatModelConfig {
	epoch := time.Unix(0, 0).UTC()
	cfg := database.ChatModelConfig{
		ID:          id,
		Provider:    "openai",
		Model:       model,
		DisplayName: model,
		Enabled:     true,
	}
	cfg.ContextLimit = 128000
	cfg.CompressionThreshold = 64000
	cfg.CreatedAt = epoch
	cfg.UpdatedAt = epoch
	return cfg
}
// waitForSignal blocks until ch is closed (or yields a value). It fails
// the test if that takes longer than testutil.WaitShort.
func waitForSignal(t *testing.T, ch <-chan struct{}) {
	t.Helper()
	deadline := time.After(testutil.WaitShort)
	select {
	case <-deadline:
		t.Fatal("timed out waiting for signal")
	case <-ch:
	}
}
// TestConfigCache_CallerCancellation verifies the DoChan-based
// cancellation semantics of the four cache methods in the methods table
// below: EnabledProviders, ModelConfigByID, DefaultModelConfig and
// UserPrompt. AdvisorConfig is not covered here.
//   - A canceled caller returns immediately without waiting for the
//     shared fill to complete.
//   - One canceled waiter does not poison other coalesced waiters.
//   - Server context cancellation propagates through the fill.
func TestConfigCache_CallerCancellation(t *testing.T) {
	t.Parallel()
	// cacheMethod adapts one cache method and its backing store query to
	// the shared scenarios below.
	type cacheMethod struct {
		name string
		// setupBlocked configures the store to block on release.
		// The started channel is closed when the fill enters the
		// store. The release channel unblocks the store.
		setupBlocked func(store *stubChatConfigStore, started, release chan struct{})
		// setupCtxSensitive configures the store to block until
		// its context is canceled (for server-shutdown testing).
		setupCtxSensitive func(store *stubChatConfigStore, started chan struct{})
		// call invokes the cache method under test.
		call func(ctx context.Context, cache *chatConfigCache) error
		// storeCalls returns the number of underlying store calls.
		storeCalls func(store *stubChatConfigStore) int32
	}
	// Shared keys for the keyed methods. Every subtest builds its own
	// store and cache, so sharing these across parallel subtests is safe.
	configID := uuid.New()
	userID := uuid.New()
	methods := []cacheMethod{
		{
			name: "EnabledProviders",
			setupBlocked: func(store *stubChatConfigStore, started, release chan struct{}) {
				var once sync.Once
				store.getAIProviders = func(ctx context.Context) ([]database.AIProvider, error) {
					once.Do(func() { close(started) })
					select {
					case <-ctx.Done():
						return nil, ctx.Err()
					case <-release:
						return []database.AIProvider{testAIProvider("p")}, nil
					}
				}
			},
			setupCtxSensitive: func(store *stubChatConfigStore, started chan struct{}) {
				var once sync.Once
				store.getAIProviders = func(ctx context.Context) ([]database.AIProvider, error) {
					once.Do(func() { close(started) })
					<-ctx.Done()
					return nil, ctx.Err()
				}
			},
			call: func(ctx context.Context, cache *chatConfigCache) error {
				_, err := cache.EnabledProviders(ctx)
				return err
			},
			storeCalls: func(store *stubChatConfigStore) int32 {
				return store.enabledProvidersCalls.Load()
			},
		},
		{
			name: "ModelConfigByID",
			setupBlocked: func(store *stubChatConfigStore, started, release chan struct{}) {
				var once sync.Once
				store.getChatModelConfigByID = func(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
					once.Do(func() { close(started) })
					select {
					case <-ctx.Done():
						return database.ChatModelConfig{}, ctx.Err()
					case <-release:
						return testChatModelConfig(id, "model"), nil
					}
				}
			},
			setupCtxSensitive: func(store *stubChatConfigStore, started chan struct{}) {
				var once sync.Once
				store.getChatModelConfigByID = func(ctx context.Context, _ uuid.UUID) (database.ChatModelConfig, error) {
					once.Do(func() { close(started) })
					<-ctx.Done()
					return database.ChatModelConfig{}, ctx.Err()
				}
			},
			call: func(ctx context.Context, cache *chatConfigCache) error {
				_, err := cache.ModelConfigByID(ctx, configID)
				return err
			},
			storeCalls: func(store *stubChatConfigStore) int32 {
				return store.modelConfigByIDCalls.Load()
			},
		},
		{
			name: "DefaultModelConfig",
			setupBlocked: func(store *stubChatConfigStore, started, release chan struct{}) {
				var once sync.Once
				store.getDefaultChatModelConfig = func(ctx context.Context) (database.ChatModelConfig, error) {
					once.Do(func() { close(started) })
					select {
					case <-ctx.Done():
						return database.ChatModelConfig{}, ctx.Err()
					case <-release:
						return testChatModelConfig(uuid.New(), "default"), nil
					}
				}
			},
			setupCtxSensitive: func(store *stubChatConfigStore, started chan struct{}) {
				var once sync.Once
				store.getDefaultChatModelConfig = func(ctx context.Context) (database.ChatModelConfig, error) {
					once.Do(func() { close(started) })
					<-ctx.Done()
					return database.ChatModelConfig{}, ctx.Err()
				}
			},
			call: func(ctx context.Context, cache *chatConfigCache) error {
				_, err := cache.DefaultModelConfig(ctx)
				return err
			},
			storeCalls: func(store *stubChatConfigStore) int32 {
				return store.defaultModelConfigCall.Load()
			},
		},
		{
			name: "UserPrompt",
			setupBlocked: func(store *stubChatConfigStore, started, release chan struct{}) {
				var once sync.Once
				store.getUserChatCustomPrompt = func(ctx context.Context, _ uuid.UUID) (string, error) {
					once.Do(func() { close(started) })
					select {
					case <-ctx.Done():
						return "", ctx.Err()
					case <-release:
						return "custom prompt", nil
					}
				}
			},
			setupCtxSensitive: func(store *stubChatConfigStore, started chan struct{}) {
				var once sync.Once
				store.getUserChatCustomPrompt = func(ctx context.Context, _ uuid.UUID) (string, error) {
					once.Do(func() { close(started) })
					<-ctx.Done()
					return "", ctx.Err()
				}
			},
			call: func(ctx context.Context, cache *chatConfigCache) error {
				_, err := cache.UserPrompt(ctx, userID)
				return err
			},
			storeCalls: func(store *stubChatConfigStore) int32 {
				return store.userPromptCalls.Load()
			},
		},
	}
	// Test A: A canceled caller stops waiting immediately; the
	// shared fill still completes and populates the cache.
	t.Run("CanceledCallerStopsWaiting", func(t *testing.T) {
		t.Parallel()
		for _, m := range methods {
			t.Run(m.name, func(t *testing.T) {
				t.Parallel()
				ctx := testutil.Context(t, testutil.WaitMedium)
				clock := quartz.NewMock(t)
				store := &stubChatConfigStore{}
				started := make(chan struct{})
				release := make(chan struct{})
				m.setupBlocked(store, started, release)
				cache := newChatConfigCache(ctx, store, clock)
				// Only the caller's context is canceled; the cache was
				// built with ctx, which stays live.
				callerCtx, callerCancel := context.WithCancel(ctx)
				errCh := make(chan error, 1)
				go func() {
					errCh <- m.call(callerCtx, cache)
				}()
				// Wait for the fill to enter the store, then
				// cancel the caller's context.
				waitForSignal(t, started)
				callerCancel()
				select {
				case err := <-errCh:
					require.ErrorIs(t, err, context.Canceled)
				case <-time.After(testutil.WaitShort):
					t.Fatal("canceled caller did not return promptly")
				}
				// Release the store so the fill can complete.
				close(release)
				// A fresh call must succeed — either a cache
				// hit or by joining the still-in-flight fill.
				// Only one store call should have occurred.
				require.NoError(t, m.call(ctx, cache))
				require.Equal(t, int32(1), m.storeCalls(store))
			})
		}
	})
	// Test B: One canceled waiter does not poison other coalesced
	// waiters sharing the same singleflight entry.
	t.Run("CanceledWaiterDoesNotPoisonOthers", func(t *testing.T) {
		t.Parallel()
		for _, m := range methods {
			t.Run(m.name, func(t *testing.T) {
				t.Parallel()
				ctx := testutil.Context(t, testutil.WaitMedium)
				clock := quartz.NewMock(t)
				store := &stubChatConfigStore{}
				started := make(chan struct{})
				release := make(chan struct{})
				m.setupBlocked(store, started, release)
				cache := newChatConfigCache(ctx, store, clock)
				cancelCtx, cancel := context.WithCancel(ctx)
				cancelErrCh := make(chan error, 1)
				survivorErrCh := make(chan error, 1)
				// Two concurrent callers of the same method: one that
				// will be canceled, and one that should survive.
				go func() {
					cancelErrCh <- m.call(cancelCtx, cache)
				}()
				go func() {
					survivorErrCh <- m.call(ctx, cache)
				}()
				waitForSignal(t, started)
				cancel()
				select {
				case err := <-cancelErrCh:
					require.ErrorIs(t, err, context.Canceled)
				case <-time.After(testutil.WaitShort):
					t.Fatal("canceled caller did not return promptly")
				}
				// Release the store; the surviving waiter
				// must receive the successful result.
				close(release)
				select {
				case err := <-survivorErrCh:
					require.NoError(t, err)
				case <-time.After(testutil.WaitShort):
					t.Fatal("survivor caller did not return")
				}
				require.Equal(t, int32(1), m.storeCalls(store))
			})
		}
	})
	// Test C: Server context cancellation propagates through the
	// fill, ensuring graceful shutdown behavior is preserved.
	t.Run("ServerCancellation", func(t *testing.T) {
		t.Parallel()
		for _, m := range methods {
			t.Run(m.name, func(t *testing.T) {
				t.Parallel()
				clock := quartz.NewMock(t)
				store := &stubChatConfigStore{}
				started := make(chan struct{})
				m.setupCtxSensitive(store, started)
				// The cache gets its own cancellable "server" context,
				// independent of the caller's context.
				serverCtx, serverCancel := context.WithCancel(context.Background())
				defer serverCancel()
				cache := newChatConfigCache(serverCtx, store, clock)
				callerCtx := testutil.Context(t, testutil.WaitMedium)
				errCh := make(chan error, 1)
				go func() {
					errCh <- m.call(callerCtx, cache)
				}()
				waitForSignal(t, started)
				// The caller's context stays live; only the server
				// context is canceled.
				serverCancel()
				select {
				case err := <-errCh:
					require.ErrorIs(t, err, context.Canceled)
				case <-time.After(testutil.WaitShort):
					t.Fatal("caller did not return after server cancel")
				}
			})
		}
	})
}
// TestConfigCache_AdvisorConfig_CacheHit checks two things: the stored
// JSON is decoded into the typed advisor config, and a repeated lookup
// is answered from cache.
func TestConfigCache_AdvisorConfig_CacheHit(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	const stored = `{"enabled":true,"max_uses_per_run":3,"max_output_tokens":16384}`
	stub := &stubChatConfigStore{
		getChatAdvisorConfig: func(context.Context) (string, error) {
			return stored, nil
		},
	}
	c := newChatConfigCache(ctx, stub, mClock)

	cfg1, err := c.AdvisorConfig(ctx)
	require.NoError(t, err)
	cfg2, err := c.AdvisorConfig(ctx)
	require.NoError(t, err)

	require.True(t, cfg1.Enabled)
	require.Equal(t, 3, cfg1.MaxUsesPerRun)
	require.Equal(t, int64(16384), cfg1.MaxOutputTokens)
	require.Equal(t, cfg1, cfg2)
	require.Equal(t, int32(1), stub.advisorConfigCalls.Load(),
		"second lookup must be served from cache")
}
// TestConfigCache_AdvisorConfig_TTLExpiry checks that advancing the mock
// clock by chatConfigAdvisorConfigTTL forces the advisor config to be
// refetched.
func TestConfigCache_AdvisorConfig_TTLExpiry(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	stub := &stubChatConfigStore{}
	// The Nth fetch reports max_uses_per_run = N, so consecutive
	// fetches decode to different configs.
	stub.getChatAdvisorConfig = func(context.Context) (string, error) {
		n := stub.advisorConfigCalls.Load()
		return fmt.Sprintf(`{"max_uses_per_run":%d}`, n), nil
	}
	c := newChatConfigCache(ctx, stub, mClock)

	before, err := c.AdvisorConfig(ctx)
	require.NoError(t, err)

	mClock.Advance(chatConfigAdvisorConfigTTL).MustWait(ctx)

	after, err := c.AdvisorConfig(ctx)
	require.NoError(t, err)
	require.NotEqual(t, before.MaxUsesPerRun, after.MaxUsesPerRun,
		"TTL expiry must trigger a refetch")
	require.Equal(t, int32(2), stub.advisorConfigCalls.Load())
}
// TestConfigCache_AdvisorConfig_DBErrorNotCached checks that a store
// error is returned to the caller (matchable with errors.Is) and is not
// cached, so every lookup retries the store.
func TestConfigCache_AdvisorConfig_DBErrorNotCached(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	expected := xerrors.New("boom")
	stub := &stubChatConfigStore{
		getChatAdvisorConfig: func(context.Context) (string, error) {
			return "", expected
		},
	}
	c := newChatConfigCache(ctx, stub, mClock)

	for attempt := 0; attempt < 2; attempt++ {
		_, err := c.AdvisorConfig(ctx)
		require.ErrorIs(t, err, expected)
	}
	require.Equal(t, int32(2), stub.advisorConfigCalls.Load(),
		"errors must not populate the cache; every call retries")
}
// TestConfigCache_AdvisorConfig_InvalidJSONNotCached verifies that a
// stored advisor config which fails to parse is reported as an error and
// is not memoized, so every lookup reaches the store again.
func TestConfigCache_AdvisorConfig_InvalidJSONNotCached(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	db := &stubChatConfigStore{
		getChatAdvisorConfig: func(context.Context) (string, error) {
			const garbage = "not valid json"
			return garbage, nil
		},
	}
	cc := newChatConfigCache(ctx, db, mClock)
	_, firstErr := cc.AdvisorConfig(ctx)
	require.Error(t, firstErr, "malformed JSON must surface as an error")
	_, secondErr := cc.AdvisorConfig(ctx)
	require.Error(t, secondErr)
	require.Equal(t, int32(2), db.advisorConfigCalls.Load(),
		"parse errors must not populate the cache; every call retries")
}
// TestConfigCache_AdvisorConfig_EmptyJSONYieldsZeroValue covers the
// missing-row case. GetChatAdvisorConfig returns "{}" when the site-config
// row is absent, and that payload must decode to a zero-value
// AdvisorConfig rather than produce a parse error.
func TestConfigCache_AdvisorConfig_EmptyJSONYieldsZeroValue(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	db := &stubChatConfigStore{
		getChatAdvisorConfig: func(context.Context) (string, error) {
			return "{}", nil
		},
	}
	got, err := newChatConfigCache(ctx, db, mClock).AdvisorConfig(ctx)
	require.NoError(t, err)
	var want codersdk.AdvisorConfig
	require.Equal(t, want, got)
}
// TestConfigCache_InvalidateAdvisorConfig guards the pubsub-driven
// invalidation path. Without it, an admin writing
// PUT /api/experimental/chats/config/advisor could leave every replica
// serving stale enabled/model/limits for up to chatConfigAdvisorConfigTTL,
// which would defeat the subscriber in chatd.go.
func TestConfigCache_InvalidateAdvisorConfig(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)
	db := &stubChatConfigStore{}
	// Each fetch reports the store's call counter, so a refetch is
	// observable as a changed MaxUsesPerRun.
	db.getChatAdvisorConfig = func(context.Context) (string, error) {
		n := db.advisorConfigCalls.Load()
		return fmt.Sprintf(`{"max_uses_per_run":%d}`, n), nil
	}
	cc := newChatConfigCache(ctx, db, mClock)
	before, err := cc.AdvisorConfig(ctx)
	require.NoError(t, err)
	// No clock advance here: only the explicit invalidation may cause the
	// second lookup to reach the store.
	cc.InvalidateAdvisorConfig()
	after, err := cc.AdvisorConfig(ctx)
	require.NoError(t, err)
	require.NotEqual(t, before.MaxUsesPerRun, after.MaxUsesPerRun,
		"invalidation must force a refetch without waiting for TTL expiry")
	require.Equal(t, int32(2), db.advisorConfigCalls.Load())
}
// Guards against the invalidation-during-singleflight race. A stale
// in-flight fill started before InvalidateAdvisorConfig must not
// re-cache its pre-update value, which would defeat the pubsub
// invalidation path for up to chatConfigAdvisorConfigTTL.
//
// Every blocking step is bounded by the test ctx. The stub store stops
// waiting when ctx is canceled, and results are received with a
// ctx-aware select. A regression therefore fails the test promptly
// instead of hanging until the global test timeout, and no cache
// goroutine stays parked inside the stub after a failed assertion.
func TestConfigCache_InvalidateAdvisorConfig_BlocksStaleInFlight(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitMedium)
	clock := quartz.NewMock(t)
	staleConfig := `{"max_uses_per_run":1}`
	freshConfig := `{"max_uses_per_run":2}`
	firstStarted := make(chan struct{})
	secondStarted := make(chan struct{})
	releaseFirst := make(chan struct{})
	releaseSecond := make(chan struct{})
	// awaitRelease blocks until release is closed or the test ctx ends.
	// It deliberately watches the outer test ctx, not the context passed
	// to the stub. The cache may hand the store a different context
	// (NOTE(review): not visible here), but the test ctx is always
	// canceled during test cleanup, so a parked fetch cannot leak.
	awaitRelease := func(release <-chan struct{}) error {
		select {
		case <-release:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	store := &stubChatConfigStore{}
	store.getChatAdvisorConfig = func(context.Context) (string, error) {
		switch call := store.advisorConfigCalls.Load(); call {
		case 1:
			close(firstStarted)
			if err := awaitRelease(releaseFirst); err != nil {
				return "", err
			}
			return staleConfig, nil
		case 2:
			close(secondStarted)
			if err := awaitRelease(releaseSecond); err != nil {
				return "", err
			}
			return freshConfig, nil
		default:
			return "", xerrors.Errorf("unexpected advisor config call %d", call)
		}
	}
	cache := newChatConfigCache(ctx, store, clock)
	type result struct {
		config codersdk.AdvisorConfig
		err    error
	}
	// receive waits for a background lookup's result. If ctx expires
	// first, it fails the test instead of blocking forever.
	receive := func(ch <-chan result) result {
		t.Helper()
		select {
		case r := <-ch:
			return r
		case <-ctx.Done():
			t.Fatal("timed out waiting for advisor config result")
			return result{}
		}
	}
	firstResult := make(chan result, 1)
	go func() {
		config, err := cache.AdvisorConfig(ctx)
		firstResult <- result{config: config, err: err}
	}()
	waitForSignal(t, firstStarted)
	cache.InvalidateAdvisorConfig()
	secondResult := make(chan result, 1)
	go func() {
		config, err := cache.AdvisorConfig(ctx)
		secondResult <- result{config: config, err: err}
	}()
	waitForSignal(t, secondStarted)
	close(releaseFirst)
	first := receive(firstResult)
	require.NoError(t, first.err)
	require.EqualValues(t, 1, first.config.MaxUsesPerRun)
	require.Nil(t, cache.advisorConfig,
		"stale fill must not re-cache after invalidation")
	close(releaseSecond)
	second := receive(secondResult)
	require.NoError(t, second.err)
	require.EqualValues(t, 2, second.config.MaxUsesPerRun)
	require.Equal(t, int32(2), store.advisorConfigCalls.Load())
	// A follow-up lookup must be served from the fresh cached value
	// without another store call.
	third, err := cache.AdvisorConfig(ctx)
	require.NoError(t, err)
	require.EqualValues(t, 2, third.MaxUsesPerRun)
	require.Equal(t, int32(2), store.advisorConfigCalls.Load())
}