Files
coder/coderd/x/chatd/configcache.go
2026-05-22 09:50:01 +02:00

520 lines
14 KiB
Go

package chatd
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"slices"
"sync"
"time"
"github.com/ammario/tlru"
"github.com/google/uuid"
"tailscale.com/util/singleflight"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
// Lifetimes for cached chat configuration and the size bound for the
// per-user prompt cache.
const (
// chatConfigProvidersTTL is how long a stored provider list is served
// from the cache before it is treated as expired.
chatConfigProvidersTTL = 10 * time.Second
// chatConfigModelConfigTTL is the lifetime of cached model configs. It
// is applied to per-ID entries and to the default model config entry.
chatConfigModelConfigTTL = 10 * time.Second
// chatConfigUserPromptTTL is the lifetime given to a user-prompt entry
// when it is stored.
chatConfigUserPromptTTL = 5 * time.Second
// chatConfigAdvisorConfigTTL is how long a stored advisor config is
// served from the cache before it is treated as expired.
chatConfigAdvisorConfigTTL = 10 * time.Second
// Bound user-prompt cache cardinality so one-shot users do not
// accumulate forever in long-lived chatd processes.
chatConfigUserPromptEntryLimit = 64 * 1024
)
// cachedProviders is the cache entry for the provider list: the stored
// providers plus the instant after which the entry is no longer served.
type cachedProviders struct {
providers []database.AIProvider
expiresAt time.Time
}
// cachedAdvisorConfig is the cache entry for the advisor configuration:
// the stored config plus the instant after which the entry is no longer
// served.
type cachedAdvisorConfig struct {
config codersdk.AdvisorConfig
expiresAt time.Time
}
// cachedModelConfig is the cache entry for one model config (per-ID or
// default): the stored config plus the instant after which the entry is
// no longer served.
type cachedModelConfig struct {
config database.ChatModelConfig
expiresAt time.Time
}
// modelConfigSnapshot records the invalidation counters observed before
// a model-config fill starts. The store functions compare it against the
// current counters and drop the result when they no longer match.
type modelConfigSnapshot struct {
// epoch is the modelTopologyEpoch value at snapshot time.
epoch uint64
// generation is the defaultModelConfigGeneration value at snapshot
// time. It is only populated for default-model-config fills and is
// left at zero for per-ID fills.
generation uint64
}
// cloneModelConfig copies cfg by value and replaces Options with a
// freshly allocated duplicate, so the returned config does not share the
// Options backing array with the input. All other fields are copied
// shallowly.
func cloneModelConfig(cfg database.ChatModelConfig) database.ChatModelConfig {
	out := cfg
	out.Options = slices.Clone(cfg.Options)
	return out
}
// chatConfigCache is a short-TTL, in-memory cache in front of the
// database for chat configuration: the provider list, per-ID and default
// model configs, per-user custom prompts, and the advisor config.
//
// Each cached kind pairs its storage with a generation or epoch counter
// and a singleflight group. Invalidation bumps the counter; a fill
// records the counter before reading the database and its result is only
// stored if the counter is unchanged, so a read that raced with an
// invalidation is not re-cached.
type chatConfigCache struct {
db database.Store
clock quartz.Clock
// ctx is the server-scoped context used for all DB fills.
// Cache fills run inside singleflight.Do where one caller
// becomes the leader for all coalesced waiters. Using a
// per-request context would mean the leader's cancellation
// (timeout, user disconnect) fans the error to every waiter.
// Storing the server context here makes that impossible by
// construction — callers cannot pass a request context into
// the shared fill path.
ctx context.Context
// mu guards the singleton entries, the modelConfigs map, and every
// generation/epoch counter below. userPrompts is read without holding
// mu; writes and deletes to it are performed while mu is held so they
// are ordered against userPromptEpoch.
mu sync.RWMutex
// Providers (singleton).
providers *cachedProviders
providerGeneration uint64
providerFetches singleflight.Group[string, []database.AIProvider]
// Model configs (keyed by ID). modelTopologyEpoch is bumped by both
// provider and model-config invalidation.
modelTopologyEpoch uint64
modelConfigs map[uuid.UUID]cachedModelConfig
modelConfigFetches singleflight.Group[string, database.ChatModelConfig]
// Default model config (singleton).
defaultModelConfig *cachedModelConfig
defaultModelConfigGeneration uint64
defaultModelConfigFetches singleflight.Group[string, database.ChatModelConfig]
// User custom prompts (keyed by user ID). userPromptEpoch is a single
// counter shared by all users.
userPromptEpoch uint64
userPrompts *tlru.Cache[uuid.UUID, string]
userPromptFetches singleflight.Group[string, string]
// Advisor configuration (singleton).
advisorConfig *cachedAdvisorConfig
advisorConfigGeneration uint64
advisorConfigFetches singleflight.Group[string, codersdk.AdvisorConfig]
}
// newChatConfigCache builds an empty cache backed by db and timed by
// clock. ctx is retained on the cache and used for every database fill,
// so it should be a server-scoped context rather than a request context.
// The user-prompt cache is created with a constant per-entry cost and a
// limit of chatConfigUserPromptEntryLimit.
func newChatConfigCache(ctx context.Context, db database.Store, clock quartz.Clock) *chatConfigCache {
	prompts := tlru.New[uuid.UUID](tlru.ConstantCost[string], chatConfigUserPromptEntryLimit)

	cache := &chatConfigCache{
		ctx:          ctx,
		db:           db,
		clock:        clock,
		modelConfigs: map[uuid.UUID]cachedModelConfig{},
		userPrompts:  prompts,
	}
	return cache
}
// singleflightDoChan starts (or joins) the singleflight call for key via
// the group's DoChan method and then waits for whichever comes first:
// the call's result or cancellation of ctx.
//
// The two lifetimes are deliberately separate. fn runs to completion
// regardless of ctx; ctx only limits how long this caller waits. When
// ctx ends first, the zero value of V and ctx.Err() are returned and the
// shared call is left running for any other waiters.
func singleflightDoChan[K comparable, V any](
	ctx context.Context,
	group *singleflight.Group[K, V],
	key K,
	fn func() (V, error),
) (V, error) {
	results := group.DoChan(key, fn)

	select {
	case out := <-results:
		return out.Val, out.Err
	case <-ctx.Done():
		var none V
		return none, ctx.Err()
	}
}
// EnabledProviders returns the provider list, from the cache when an
// unexpired entry exists and otherwise from the database.
//
// Concurrent misses observed under the same provider generation share a
// single database read. That read uses the cache's server-scoped
// context; ctx only bounds how long this caller waits. The returned
// slice is always a copy the caller may modify freely.
func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.AIProvider, error) {
	if hit, found := c.cachedProviders(); found {
		return hit, nil
	}

	gen := c.providersGeneration()

	fill := func() ([]database.AIProvider, error) {
		// Another fill may have completed between the miss above and
		// this call becoming the singleflight leader.
		if hit, found := c.cachedProviders(); found {
			return hit, nil
		}
		rows, err := c.db.GetAIProviders(c.ctx, database.GetAIProvidersParams{})
		if err != nil {
			return nil, err
		}
		c.storeProviders(gen, rows)
		return slices.Clone(rows), nil
	}

	// The generation is part of the key so callers arriving after an
	// invalidation do not join a fill that started before it.
	key := fmt.Sprintf("%d:providers", gen)
	shared, err := singleflightDoChan(ctx, &c.providerFetches, key, fill)
	if err != nil {
		return nil, err
	}
	// The singleflight result is delivered to every waiter, so hand
	// each caller its own copy.
	return slices.Clone(shared), nil
}
// cachedProviders returns a copy of the cached provider list and true
// when an unexpired entry exists. On a miss it returns nil and false;
// if the miss was caused by expiry, the expired entry is also cleared.
func (c *chatConfigCache) cachedProviders() ([]database.AIProvider, bool) {
	c.mu.RLock()
	seen := c.providers
	c.mu.RUnlock()

	switch {
	case seen == nil:
		return nil, false
	case c.clock.Now().Before(seen.expiresAt):
		return slices.Clone(seen.providers), true
	}

	// The entry we saw has expired. Re-read under the write lock before
	// clearing, since a fresh entry may have been stored in the meantime.
	c.mu.Lock()
	defer c.mu.Unlock()
	if latest := c.providers; latest != nil && !c.clock.Now().Before(latest.expiresAt) {
		c.providers = nil
	}
	return nil, false
}
// providersGeneration reads the current provider generation counter
// under the read lock.
func (c *chatConfigCache) providersGeneration() uint64 {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.providerGeneration
}
// storeProviders caches a copy of providers with a lifetime of
// chatConfigProvidersTTL, but only if the provider generation still
// equals generation. A mismatch means an invalidation happened after the
// fill started, so the result is dropped.
func (c *chatConfigCache) storeProviders(generation uint64, providers []database.AIProvider) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if generation != c.providerGeneration {
		return
	}

	deadline := c.clock.Now().Add(chatConfigProvidersTTL)
	c.providers = &cachedProviders{
		providers: slices.Clone(providers),
		expiresAt: deadline,
	}
}
// InvalidateProviders discards the cached provider list and all cached
// model-config state (per-ID entries and the default entry). The related
// counters are bumped so that in-flight fills started before this call
// cannot store their results.
func (c *chatConfigCache) InvalidateProviders() {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.providerGeneration++
	c.providers = nil

	// Model selections depend on which providers exist, so a provider
	// change flushes every model-config entry as well.
	c.modelTopologyEpoch++
	clear(c.modelConfigs)

	c.defaultModelConfigGeneration++
	c.defaultModelConfig = nil
}
// ModelConfigByID returns the chat model config with the given ID, from
// the cache when an unexpired entry exists and otherwise from the
// database.
//
// Concurrent misses for the same ID under the same topology epoch share
// a single database read. That read uses the cache's server-scoped
// context; ctx only bounds how long this caller waits, so a canceled
// caller receives ctx.Err() while the shared fill keeps running.
//
// The returned config never shares its Options backing array with the
// cache or with another caller.
func (c *chatConfigCache) ModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
	if config, ok := c.cachedModelConfig(id); ok {
		return config, nil
	}

	snap := c.modelConfigSnapshot()
	// The epoch is part of the key so callers arriving after an
	// invalidation do not join a fill that started before it.
	key := fmt.Sprintf("%d:%s", snap.epoch, id)
	config, err := singleflightDoChan(ctx, &c.modelConfigFetches, key, func() (database.ChatModelConfig, error) {
		// Another fill may have completed between the miss above and
		// this call becoming the singleflight leader.
		if cached, ok := c.cachedModelConfig(id); ok {
			return cached, nil
		}
		fetched, err := c.db.GetChatModelConfigByID(c.ctx, id)
		if err != nil {
			return database.ChatModelConfig{}, err
		}
		c.storeModelConfig(snap, fetched)
		return cloneModelConfig(fetched), nil
	})
	if err != nil {
		return database.ChatModelConfig{}, err
	}
	// The singleflight result is delivered to every coalesced waiter.
	// Clone it so concurrent callers do not alias one Options slice.
	return cloneModelConfig(config), nil
}
// cachedModelConfig returns a clone of the cached config for id and true
// when an unexpired entry exists. On a miss it returns the zero config
// and false; if the miss was caused by expiry, the expired entry is also
// removed from the map.
func (c *chatConfigCache) cachedModelConfig(id uuid.UUID) (database.ChatModelConfig, bool) {
	c.mu.RLock()
	seen, present := c.modelConfigs[id]
	c.mu.RUnlock()

	switch {
	case !present:
		return database.ChatModelConfig{}, false
	case c.clock.Now().Before(seen.expiresAt):
		return cloneModelConfig(seen.config), true
	}

	// The entry we saw has expired. Re-read under the write lock before
	// deleting, since a fresh entry may have been stored in the meantime.
	c.mu.Lock()
	defer c.mu.Unlock()
	if latest, stillThere := c.modelConfigs[id]; stillThere && !c.clock.Now().Before(latest.expiresAt) {
		delete(c.modelConfigs, id)
	}
	return database.ChatModelConfig{}, false
}
// modelConfigSnapshot captures the current model topology epoch under
// the read lock. The generation field is left at its zero value.
func (c *chatConfigCache) modelConfigSnapshot() modelConfigSnapshot {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return modelConfigSnapshot{epoch: c.modelTopologyEpoch}
}
// storeModelConfig caches a clone of config under config.ID with a
// lifetime of chatConfigModelConfigTTL, but only if the model topology
// epoch still equals snap.epoch. A mismatch means an invalidation
// happened after the fill started, so the result is dropped.
func (c *chatConfigCache) storeModelConfig(snap modelConfigSnapshot, config database.ChatModelConfig) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if snap.epoch != c.modelTopologyEpoch {
		return
	}

	deadline := c.clock.Now().Add(chatConfigModelConfigTTL)
	c.modelConfigs[config.ID] = cachedModelConfig{
		config:    cloneModelConfig(config),
		expiresAt: deadline,
	}
}
// DefaultModelConfig returns the default chat model config, from the
// cache when an unexpired entry exists and otherwise from the database.
//
// Concurrent misses under the same topology epoch share a single
// database read. That read uses the cache's server-scoped context; ctx
// only bounds how long this caller waits, so a canceled caller receives
// ctx.Err() while the shared fill keeps running.
//
// The returned config never shares its Options backing array with the
// cache or with another caller.
func (c *chatConfigCache) DefaultModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
	if config, ok := c.cachedDefaultModelConfig(); ok {
		return config, nil
	}

	snap := c.defaultModelConfigSnapshot()
	// The epoch is part of the key so callers arriving after an
	// invalidation do not join a fill that started before it.
	key := fmt.Sprintf("%d:default", snap.epoch)
	config, err := singleflightDoChan(ctx, &c.defaultModelConfigFetches, key, func() (database.ChatModelConfig, error) {
		// Another fill may have completed between the miss above and
		// this call becoming the singleflight leader.
		if cached, ok := c.cachedDefaultModelConfig(); ok {
			return cached, nil
		}
		fetched, err := c.db.GetDefaultChatModelConfig(c.ctx)
		if err != nil {
			return database.ChatModelConfig{}, err
		}
		c.storeDefaultModelConfig(snap, fetched)
		return cloneModelConfig(fetched), nil
	})
	if err != nil {
		return database.ChatModelConfig{}, err
	}
	// The singleflight result is delivered to every coalesced waiter.
	// Clone it so concurrent callers do not alias one Options slice.
	return cloneModelConfig(config), nil
}
// cachedDefaultModelConfig returns a clone of the cached default model
// config and true when an unexpired entry exists. On a miss it returns
// the zero config and false; if the miss was caused by expiry, the
// expired entry is also cleared.
func (c *chatConfigCache) cachedDefaultModelConfig() (database.ChatModelConfig, bool) {
	c.mu.RLock()
	seen := c.defaultModelConfig
	c.mu.RUnlock()

	switch {
	case seen == nil:
		return database.ChatModelConfig{}, false
	case c.clock.Now().Before(seen.expiresAt):
		return cloneModelConfig(seen.config), true
	}

	// The entry we saw has expired. Re-read under the write lock before
	// clearing, since a fresh entry may have been stored in the meantime.
	c.mu.Lock()
	defer c.mu.Unlock()
	if latest := c.defaultModelConfig; latest != nil && !c.clock.Now().Before(latest.expiresAt) {
		c.defaultModelConfig = nil
	}
	return database.ChatModelConfig{}, false
}
// defaultModelConfigSnapshot captures, under the read lock, both
// counters that gate storing a default model config: the model topology
// epoch and the default-model-config generation.
func (c *chatConfigCache) defaultModelConfigSnapshot() modelConfigSnapshot {
	c.mu.RLock()
	defer c.mu.RUnlock()

	var snap modelConfigSnapshot
	snap.epoch = c.modelTopologyEpoch
	snap.generation = c.defaultModelConfigGeneration
	return snap
}
// storeDefaultModelConfig caches a clone of config as the default model
// config with a lifetime of chatConfigModelConfigTTL. The result is
// dropped unless both the model topology epoch and the default-config
// generation still match snap, i.e. no invalidation happened after the
// fill started.
func (c *chatConfigCache) storeDefaultModelConfig(snap modelConfigSnapshot, config database.ChatModelConfig) {
	c.mu.Lock()
	defer c.mu.Unlock()

	stale := snap.epoch != c.modelTopologyEpoch || snap.generation != c.defaultModelConfigGeneration
	if stale {
		return
	}

	deadline := c.clock.Now().Add(chatConfigModelConfigTTL)
	c.defaultModelConfig = &cachedModelConfig{
		config:    cloneModelConfig(config),
		expiresAt: deadline,
	}
}
// UserPrompt returns the custom chat prompt for userID, from the cache
// when an entry exists and otherwise from the database. A user with no
// stored prompt (sql.ErrNoRows) yields the empty string and a nil error,
// and that empty result is cached like any other.
//
// Concurrent misses for the same user under the same prompt epoch share
// a single database read. That read uses the cache's server-scoped
// context; ctx only bounds how long this caller waits.
func (c *chatConfigCache) UserPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
	if hit, found := c.cachedUserPrompt(userID); found {
		return hit, nil
	}

	epoch := c.currentUserPromptEpoch()

	load := func() (string, error) {
		// Another fill may have completed between the miss above and
		// this call becoming the singleflight leader.
		if hit, found := c.cachedUserPrompt(userID); found {
			return hit, nil
		}

		text, err := c.db.GetUserChatCustomPrompt(c.ctx, userID)
		switch {
		case err == nil:
			c.storeUserPrompt(epoch, userID, text)
			return text, nil
		case errors.Is(err, sql.ErrNoRows):
			// No row means no custom prompt; remember that too.
			c.storeUserPrompt(epoch, userID, "")
			return "", nil
		default:
			return "", err
		}
	}

	// The epoch is part of the key so callers arriving after an
	// invalidation do not join a fill that started before it.
	key := fmt.Sprintf("%d:%s", epoch, userID)
	result, err := singleflightDoChan(ctx, &c.userPromptFetches, key, load)
	if err != nil {
		return "", err
	}
	return result, nil
}
// cachedUserPrompt looks userID up in the user-prompt cache. It returns
// the stored prompt and true on a hit, or the empty string and false on
// a miss. The lookup does not take c.mu.
func (c *chatConfigCache) cachedUserPrompt(userID uuid.UUID) (string, bool) {
	if text, _, found := c.userPrompts.Get(userID); found {
		return text, true
	}
	return "", false
}
// currentUserPromptEpoch reads the user-prompt epoch counter under the
// read lock.
func (c *chatConfigCache) currentUserPromptEpoch() uint64 {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.userPromptEpoch
}
// storeUserPrompt caches prompt for userID with a lifetime of
// chatConfigUserPromptTTL, but only if the user-prompt epoch still
// equals epoch. A mismatch means an invalidation happened after the fill
// started, so the result is dropped.
func (c *chatConfigCache) storeUserPrompt(epoch uint64, userID uuid.UUID, prompt string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if epoch == c.userPromptEpoch {
		c.userPrompts.Set(userID, prompt, chatConfigUserPromptTTL)
	}
}
// InvalidateModelConfig removes the cached entry for id and also clears
// the cached default model config. Both the model topology epoch and the
// default-config generation are bumped, so in-flight model-config fills
// started before this call cannot store their results.
func (c *chatConfigCache) InvalidateModelConfig(id uuid.UUID) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.modelTopologyEpoch++
	delete(c.modelConfigs, id)

	c.defaultModelConfigGeneration++
	c.defaultModelConfig = nil
}
// InvalidateUserPrompt removes the cached prompt for userID and bumps
// the user-prompt epoch. The epoch is shared by all users, so in-flight
// prompt fills started before this call cannot store their results.
func (c *chatConfigCache) InvalidateUserPrompt(userID uuid.UUID) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.userPrompts.Delete(userID)
	c.userPromptEpoch++
}
// InvalidateAdvisorConfig clears the cached advisor configuration, so
// the next AdvisorConfig call reads the database again. It is called
// from the ChatConfigEvent subscriber after an admin writes
// PUT /api/experimental/chats/config/advisor; otherwise stale
// enabled/model/limits values could be served for up to
// chatConfigAdvisorConfigTTL.
//
// The generation counter is bumped as well, which makes any fill that
// started before this call fail its store check. A database read that
// raced with the update therefore cannot re-cache the old value.
func (c *chatConfigCache) InvalidateAdvisorConfig() {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.advisorConfigGeneration++
	c.advisorConfig = nil
}
// AdvisorConfig returns the deployment-wide advisor configuration, from
// the cache when an unexpired entry exists and otherwise by reading and
// JSON-decoding the stored value from the database. The underlying
// site-config row changes on the order of hours or days, so caching it
// saves a per-turn DB round trip on chats that reference the advisor.
//
// Lookup errors and JSON parse errors are returned to the caller
// unchanged and are not cached; callers that want a silent fallback
// handle that at the call site. Concurrent misses under the same
// generation share a single database read, which uses the cache's
// server-scoped context; ctx only bounds how long this caller waits.
func (c *chatConfigCache) AdvisorConfig(ctx context.Context) (codersdk.AdvisorConfig, error) {
	if hit, found := c.cachedAdvisorConfig(); found {
		return hit, nil
	}

	gen := c.advisorConfigGenerationSnapshot()

	load := func() (codersdk.AdvisorConfig, error) {
		// Another fill may have completed between the miss above and
		// this call becoming the singleflight leader.
		if hit, found := c.cachedAdvisorConfig(); found {
			return hit, nil
		}

		raw, err := c.db.GetChatAdvisorConfig(c.ctx)
		if err != nil {
			return codersdk.AdvisorConfig{}, err
		}

		var parsed codersdk.AdvisorConfig
		if err = json.Unmarshal([]byte(raw), &parsed); err != nil {
			return codersdk.AdvisorConfig{}, err
		}

		c.storeAdvisorConfig(gen, parsed)
		return parsed, nil
	}

	// The generation is part of the key so callers arriving after an
	// invalidation do not join a fill that started before it.
	key := fmt.Sprintf("%d:advisor", gen)
	shared, err := singleflightDoChan(ctx, &c.advisorConfigFetches, key, load)
	if err != nil {
		return codersdk.AdvisorConfig{}, err
	}
	return shared, nil
}
// cachedAdvisorConfig returns the cached advisor config and true when an
// unexpired entry exists. On a miss it returns the zero config and
// false; if the miss was caused by expiry, the expired entry is also
// cleared.
func (c *chatConfigCache) cachedAdvisorConfig() (codersdk.AdvisorConfig, bool) {
	c.mu.RLock()
	seen := c.advisorConfig
	c.mu.RUnlock()

	switch {
	case seen == nil:
		return codersdk.AdvisorConfig{}, false
	case c.clock.Now().Before(seen.expiresAt):
		return seen.config, true
	}

	// The entry we saw has expired. Re-read under the write lock before
	// clearing, since a fresh entry may have been stored in the meantime.
	c.mu.Lock()
	defer c.mu.Unlock()
	if latest := c.advisorConfig; latest != nil && !c.clock.Now().Before(latest.expiresAt) {
		c.advisorConfig = nil
	}
	return codersdk.AdvisorConfig{}, false
}
// advisorConfigGenerationSnapshot reads the advisor-config generation
// counter under the read lock.
func (c *chatConfigCache) advisorConfigGenerationSnapshot() uint64 {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.advisorConfigGeneration
}
// storeAdvisorConfig caches config with a lifetime of
// chatConfigAdvisorConfigTTL, but only if the advisor-config generation
// still equals generation. A mismatch means an invalidation happened
// after the fill started, so the result is dropped.
func (c *chatConfigCache) storeAdvisorConfig(generation uint64, config codersdk.AdvisorConfig) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if generation != c.advisorConfigGeneration {
		return
	}

	deadline := c.clock.Now().Add(chatConfigAdvisorConfigTTL)
	c.advisorConfig = &cachedAdvisorConfig{
		config:    config,
		expiresAt: deadline,
	}
}