mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
c933ddcffd
## Problem The Admin → Agents → System Prompt textarea saved only to the browser's `localStorage`. The value was never sent to the backend, never stored in the database, and never injected into chats. Entering text, clicking Save, and refreshing the page showed no changes — the prompt was effectively a no-op. ## Root Cause Three disconnected layers: 1. **Frontend** wrote to `localStorage`, never called an API. 2. **`handleCreateChat`** never read `savedSystemPrompt`. 3. **Backend** hardcoded `chatd.DefaultSystemPrompt` on every chat creation — no field in `CreateChatRequest` accepted a custom prompt. ## Changes ### Database - Added `GetChatSystemPrompt` / `UpsertChatSystemPrompt` queries on the existing `site_configs` table (no migration needed). ### API - `GET /api/experimental/chats/system-prompt` — returns the configured prompt (any authenticated user). - `PUT /api/experimental/chats/system-prompt` — sets the prompt (admin-only, `rbac: deployment_config update`). - Input validation: max 32 KiB prompt length. ### Backend - `resolvedChatSystemPrompt(ctx)` checks for a custom prompt in the DB, falls back to `chatd.DefaultSystemPrompt` when empty/unset. - Logs a warning on DB errors instead of silently swallowing them. - Replaced the hardcoded `defaultChatSystemPrompt()` call in chat creation. ### Frontend - Replaced `localStorage` read/write with React Query `useQuery`/`useMutation` backed by the new endpoints. - Fixed `useEffect` draft sync to avoid clobbering in-progress user edits on refetch. - Added `try/catch` error handling on save (draft stays dirty for retry). - Save button disabled during mutation (`isSavingSystemPrompt`). - Query key follows kebab-case convention (`chat-system-prompt`). ### UX - Added hint: "When empty, the built-in default prompt is used." ### Tests - `TestChatSystemPrompt`: GET returns empty when unset, admin can set, non-admin gets 403. - dbauthz `TestMethodTestSuite` coverage for both new querier methods.
3755 lines
103 KiB
Go
3755 lines
103 KiB
Go
package coderd
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/coderd/audit"
|
|
"github.com/coder/coder/v2/coderd/chatd"
|
|
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/externalauth"
|
|
"github.com/coder/coder/v2/coderd/httpapi"
|
|
"github.com/coder/coder/v2/coderd/httpapi/httperror"
|
|
"github.com/coder/coder/v2/coderd/httpmw"
|
|
"github.com/coder/coder/v2/coderd/pubsub"
|
|
"github.com/coder/coder/v2/coderd/rbac"
|
|
"github.com/coder/coder/v2/coderd/rbac/policy"
|
|
"github.com/coder/coder/v2/coderd/tracing"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/wsjson"
|
|
"github.com/coder/websocket"
|
|
)
|
|
|
|
// Chat diff status timing and the GitHub API endpoint.
// NOTE(review): the consumers of these values are outside this excerpt; the
// descriptions below are based on the identifiers.
const (
	// chatDiffStatusTTL is the lifetime of chat diff status data (2 minutes).
	chatDiffStatusTTL = 2 * time.Minute
	// chatDiffBackgroundRefreshTimeout caps a single background diff refresh.
	chatDiffBackgroundRefreshTimeout = 20 * time.Second
	// githubAPIBaseURL is the root of the GitHub REST API.
	githubAPIBaseURL = "https://api.github.com"
)

// chatStreamBatchSize is the largest number of events streamChat packs into
// one data frame, both for the initial snapshot and for live events.
const chatStreamBatchSize = 256

// Model-config keys for chat context handling, and the default and bounds
// for the context compression threshold (presumably a percentage, given the
// 0–100 range — confirm against the model config consumer).
const (
	chatContextLimitModelConfigKey                = "context_limit"
	chatContextCompressionThresholdModelConfigKey = "context_compression_threshold"

	defaultChatContextCompressionThreshold = int32(70)
	minChatContextCompressionThreshold     = int32(0)
	maxChatContextCompressionThreshold     = int32(100)
)

// maxSystemPromptLenBytes is the byte-length limit for a custom chat system
// prompt (128 KiB).
const maxSystemPromptLenBytes = 128 << 10
|
|
|
|
// chatDiffRefreshBackoffSchedule lists the waits between consecutive
// background diff refresh attempts, shortest first. A refresh is triggered
// when the agent obtains a GitHub token, which usually happens just before a
// git push or PR creation. Each step therefore allows progressively more time
// for the push and any PR workflow to finish before the GitHub API is queried.
var chatDiffRefreshBackoffSchedule = []time.Duration{
	time.Second,
	3 * time.Second,
	5 * time.Second,
	10 * time.Second,
	20 * time.Second,
}
|
|
|
|
// chatGitRef is the pair of values a workspace agent reports during a git
// operation: the branch being worked on and the repository's remote origin.
type chatGitRef struct {
	Branch, RemoteOrigin string
}
|
|
|
|
// githubPullRequestPathPattern matches a github.com pull request URL and
// captures, in order, the owner, the repository, and the PR number. Anything
// after the number must start with '/', '?', or '#'.
var githubPullRequestPathPattern = regexp.MustCompile(
	`^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)/pull/([0-9]+)(?:[/?#].*)?$`,
)

// githubRepositoryHTTPSPattern matches an HTTPS github.com repository URL
// and captures the owner and repository. The lazy repository capture leaves
// an optional ".git" suffix and a trailing slash out of the second group.
var githubRepositoryHTTPSPattern = regexp.MustCompile(
	`^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
)

// githubRepositorySSHPathPattern matches SSH-style github.com remotes, such
// as git@github.com:owner/repo or ssh://git@github.com/owner/repo, and
// captures the owner and repository. As with the HTTPS pattern, ".git" and a
// trailing slash are not captured.
var githubRepositorySSHPathPattern = regexp.MustCompile(
	`^(?:ssh://)?git@github\.com[:/]([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
)
|
|
|
|
type (
	// githubPullRequestRef identifies one GitHub pull request by owner,
	// repository, and PR number.
	githubPullRequestRef struct {
		Owner  string
		Repo   string
		Number int
	}

	// githubPullRequestStatus summarizes a pull request: its state, whether
	// changes were requested, and the size of its diff.
	githubPullRequestStatus struct {
		PullRequestState string
		ChangesRequested bool

		Additions, Deletions, ChangedFiles int32
	}

	// chatRepositoryRef describes the git repository (hosting provider,
	// remote origin, owner, and name) and the branch associated with a chat.
	chatRepositoryRef struct {
		Provider     string
		RemoteOrigin string
		Branch       string
		Owner        string
		Repo         string
	}

	// chatDiffReference holds what a chat's diff refers to: a pull request
	// URL and a repository reference. RepositoryRef is a pointer and may be
	// nil.
	chatDiffReference struct {
		PullRequestURL string
		RepositoryRef  *chatRepositoryRef
	}
)
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to open chat watch stream.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
defer func() {
|
|
<-senderClosed
|
|
}()
|
|
|
|
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
|
|
pubsub.HandleChatEvent(
|
|
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
|
|
if err != nil {
|
|
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
|
|
return
|
|
}
|
|
_ = sendEvent(codersdk.ServerSentEvent{
|
|
Type: codersdk.ServerSentEventTypeData,
|
|
Data: payload,
|
|
})
|
|
},
|
|
))
|
|
if err != nil {
|
|
_ = sendEvent(codersdk.ServerSentEvent{
|
|
Type: codersdk.ServerSentEventTypeError,
|
|
Data: codersdk.Response{
|
|
Message: "Internal error subscribing to chat events.",
|
|
Detail: err.Error(),
|
|
},
|
|
})
|
|
return
|
|
}
|
|
defer cancelSubscribe()
|
|
|
|
// Send initial ping to signal the connection is ready.
|
|
_ = sendEvent(codersdk.ServerSentEvent{
|
|
Type: codersdk.ServerSentEventTypePing,
|
|
})
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-senderClosed:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
params := database.GetChatsByOwnerIDParams{
|
|
OwnerID: apiKey.UserID,
|
|
}
|
|
if v := r.URL.Query().Get("archived"); v != "" {
|
|
b, err := strconv.ParseBool(v)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid query parameter.",
|
|
Validations: []codersdk.ValidationError{
|
|
{Field: "archived", Detail: "Must be a valid boolean"},
|
|
},
|
|
})
|
|
return
|
|
}
|
|
params.Archived = sql.NullBool{Bool: b, Valid: true}
|
|
}
|
|
|
|
chats, err := api.Database.GetChatsByOwnerID(ctx, params)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chats.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
diffStatusesByChatID, err := api.getChatDiffStatusesByChatID(ctx, chats)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chats.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChats(chats, diffStatusesByChatID))
|
|
}
|
|
|
|
func (api *API) getChatDiffStatusesByChatID(
|
|
ctx context.Context,
|
|
chats []database.Chat,
|
|
) (map[uuid.UUID]database.ChatDiffStatus, error) {
|
|
if len(chats) == 0 {
|
|
return map[uuid.UUID]database.ChatDiffStatus{}, nil
|
|
}
|
|
|
|
chatIDs := make([]uuid.UUID, 0, len(chats))
|
|
for _, chat := range chats {
|
|
chatIDs = append(chatIDs, chat.ID)
|
|
}
|
|
|
|
statuses, err := api.Database.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("get chat diff statuses: %w", err)
|
|
}
|
|
|
|
statusesByChatID := make(map[uuid.UUID]database.ChatDiffStatus, len(statuses))
|
|
for _, status := range statuses {
|
|
statusesByChatID[status.ChatID] = status
|
|
}
|
|
return statusesByChatID, nil
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
var req codersdk.CreateChatRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
contentBlocks, contentFileIDs, titleSource, inputError := createChatInputFromRequest(ctx, api.Database, req)
|
|
if inputError != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError)
|
|
return
|
|
}
|
|
|
|
workspaceSelection, validationStatus, validationError := api.validateCreateChatWorkspaceSelection(ctx, req)
|
|
if validationError != nil {
|
|
httpapi.Write(ctx, rw, validationStatus, *validationError)
|
|
return
|
|
}
|
|
|
|
title := chatTitleFromMessage(titleSource)
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
modelConfigID, modelConfigStatus, modelConfigError := api.resolveCreateChatModelConfigID(ctx, req)
|
|
if modelConfigError != nil {
|
|
httpapi.Write(ctx, rw, modelConfigStatus, *modelConfigError)
|
|
return
|
|
}
|
|
|
|
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: apiKey.UserID,
|
|
WorkspaceID: workspaceSelection.WorkspaceID,
|
|
Title: title,
|
|
ModelConfigID: modelConfigID,
|
|
SystemPrompt: api.resolvedChatSystemPrompt(ctx),
|
|
InitialUserContent: contentBlocks,
|
|
ContentFileIDs: contentFileIDs,
|
|
})
|
|
if err != nil {
|
|
if database.IsForeignKeyViolation(
|
|
err,
|
|
database.ForeignKeyChatsLastModelConfigID,
|
|
database.ForeignKeyChatMessagesModelConfigID,
|
|
) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusCreated, convertChat(chat, nil))
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
//nolint:gocritic // System context required to read enabled chat models.
|
|
systemCtx := dbauthz.AsSystemRestricted(ctx)
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
enabledProviders, err := api.Database.GetEnabledChatProviders(
|
|
systemCtx,
|
|
)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to load chat model configuration.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
enabledModels, err := api.Database.GetEnabledChatModelConfigs(
|
|
systemCtx,
|
|
)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to load chat model configuration.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
configuredProviders := make(
|
|
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
|
)
|
|
for _, provider := range enabledProviders {
|
|
configuredProviders = append(
|
|
configuredProviders, chatprovider.ConfiguredProvider{
|
|
Provider: provider.Provider,
|
|
APIKey: provider.APIKey,
|
|
BaseURL: provider.BaseUrl,
|
|
},
|
|
)
|
|
}
|
|
configuredModels := make(
|
|
[]chatprovider.ConfiguredModel, 0, len(enabledModels),
|
|
)
|
|
for _, model := range enabledModels {
|
|
configuredModels = append(configuredModels, chatprovider.ConfiguredModel{
|
|
Provider: model.Provider,
|
|
Model: model.Model,
|
|
DisplayName: model.DisplayName,
|
|
})
|
|
}
|
|
|
|
keys := chatprovider.MergeProviderAPIKeys(
|
|
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
|
configuredProviders,
|
|
)
|
|
catalog := chatprovider.NewModelCatalog(keys)
|
|
var response codersdk.ChatModelsResponse
|
|
if configured, ok := catalog.ListConfiguredModels(
|
|
configuredProviders, configuredModels,
|
|
); ok {
|
|
response = configured
|
|
} else {
|
|
response = catalog.ListConfiguredProviderAvailability(configuredProviders)
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
messages, err := api.Database.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat messages.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
queuedMessages, err := api.Database.GetChatQueuedMessages(ctx, chatID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get queued messages.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatWithMessages{
|
|
Chat: convertChat(chat, nil),
|
|
Messages: convertChatMessages(messages),
|
|
QueuedMessages: convertChatQueuedMessages(queuedMessages),
|
|
})
|
|
}
|
|
|
|
// @Summary Watch git changes for a chat.
// @ID watch-chat-git
// @Security CoderSessionToken
// @Tags Chats
// @Param chat path string true "Chat ID" format(uuid)
// @Success 101
// @Router /chats/{chat}/git/watch [get]
//
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
// The handler takes the first agent of the latest build of the chat's
// workspace and requires it to be connected. It opens a git watch stream on
// that agent, upgrades the client connection to a WebSocket, and proxies
// messages in both directions. Proxying stops when either side ends, a
// forward fails, the request context ends, or the API context is canceled.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) watchChatGit(rw http.ResponseWriter, r *http.Request) {
	var (
		ctx    = r.Context()
		chat   = httpmw.ChatParam(r)
		logger = api.Logger.Named("chat_git_watcher").With(slog.F("chat_id", chat.ID))
	)

	// Without a workspace there is no agent to watch.
	if !chat.WorkspaceID.Valid {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Chat has no workspace to watch.",
		})
		return
	}

	// Only the first agent of the latest build is used below.
	agents, err := api.Database.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, chat.WorkspaceID.UUID)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching workspace agents.",
			Detail:  err.Error(),
		})
		return
	}
	if len(agents) == 0 {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Chat workspace has no agents.",
		})
		return
	}

	// Convert to the SDK representation to obtain the agent's connection
	// status.
	apiAgent, err := db2sdk.WorkspaceAgent(
		api.DERPMap(),
		*api.TailnetCoordinator.Load(),
		agents[0],
		nil,
		nil,
		nil,
		api.AgentInactiveDisconnectTimeout,
		api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
	)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error reading workspace agent.",
			Detail:  err.Error(),
		})
		return
	}
	if apiAgent.Status != codersdk.WorkspaceAgentConnected {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: fmt.Sprintf("Agent state is %q, it must be in the %q state.", apiAgent.Status, codersdk.WorkspaceAgentConnected),
		})
		return
	}

	// The 30s timeout applies only to dialing; dialCtx is not used after
	// AgentConn returns.
	dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
	defer dialCancel()

	agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agents[0].ID)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error dialing workspace agent.",
			Detail:  err.Error(),
		})
		return
	}
	defer release()

	// Everything up to here happens before the WebSocket upgrade, so
	// failures are still reported as ordinary HTTP error responses.
	agentStream, err := agentConn.WatchGit(ctx, logger, chat.ID)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error watching agent's git state.",
			Detail:  err.Error(),
		})
		return
	}
	defer agentStream.Close(websocket.StatusGoingAway)

	// NOTE(review): on failure websocket.Accept is expected to write its own
	// HTTP error response, which is why only a log line is emitted here.
	clientConn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
		CompressionMode: websocket.CompressionNoContextTakeover,
	})
	if err != nil {
		logger.Error(ctx, "failed to accept websocket", slog.Error(err))
		return
	}

	clientStream := wsjson.NewStream[
		codersdk.WorkspaceAgentGitClientMessage,
		codersdk.WorkspaceAgentGitServerMessage,
	](clientConn, websocket.MessageText, websocket.MessageText, logger)

	// ctx is reassigned to a cancellable child of the request context; cancel
	// is shared by the heartbeat, both proxy directions, and the cleanup below.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	// NOTE(review): HeartbeatClose presumably pings clientConn and calls
	// cancel when the client stops responding — confirm in httpapi.
	go httpapi.HeartbeatClose(ctx, logger, cancel, clientConn)

	// Proxy agent → client.
	agentCh := agentStream.Chan()
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-api.ctx.Done():
				return
			case <-ctx.Done():
				return
			case msg, ok := <-agentCh:
				if !ok {
					// The agent stream ended; stop the client side as well.
					cancel()
					return
				}
				if err := clientStream.Send(msg); err != nil {
					logger.Debug(ctx, "failed to forward agent message to client", slog.Error(err))
					cancel()
					return
				}
			}
		}
	}()

	// Proxy client → agent.
	clientCh := clientStream.Chan()
proxyLoop:
	for {
		select {
		case <-api.ctx.Done():
			break proxyLoop
		case <-ctx.Done():
			break proxyLoop
		case msg, ok := <-clientCh:
			if !ok {
				break proxyLoop
			}
			if err := agentStream.Send(msg); err != nil {
				logger.Debug(ctx, "failed to forward client message to agent", slog.Error(err))
				break proxyLoop
			}
		}
	}

	// Stop the agent → client goroutine and wait for it to exit before
	// closing the client stream, so its Send never overlaps Close.
	cancel()
	wg.Wait()
	_ = clientStream.Close(websocket.StatusGoingAway)
}
|
|
|
|
// @Summary Archive a chat
|
|
// @ID archive-chat
|
|
// @Tags Chats
|
|
// @Success 204
|
|
// @Router /chats/{chat}/archive [post]
|
|
func (api *API) archiveChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
if chat.Archived {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat is already archived.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var err error
|
|
// Use chatDaemon when available so it can notify
|
|
// active subscribers. Fall back to direct DB for the
|
|
// simple archive flag — no streaming state is involved.
|
|
if api.chatDaemon != nil {
|
|
err = api.chatDaemon.ArchiveChat(ctx, chat.ID)
|
|
} else {
|
|
err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to archive chat.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// @Summary Unarchive a chat
|
|
// @ID unarchive-chat
|
|
// @Tags Chats
|
|
// @Success 204
|
|
// @Router /chats/{chat}/unarchive [post]
|
|
func (api *API) unarchiveChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
if !chat.Archived {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat is not archived.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var err error
|
|
// Use chatDaemon when available so it can notify
|
|
// active subscribers. Fall back to direct DB for the
|
|
// simple unarchive flag — no streaming state is involved.
|
|
if api.chatDaemon != nil {
|
|
err = api.chatDaemon.UnarchiveChat(ctx, chat.ID)
|
|
} else {
|
|
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to unarchive chat.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.CreateChatMessageRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
|
if inputError != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: inputError.Message,
|
|
Detail: inputError.Detail,
|
|
})
|
|
return
|
|
}
|
|
|
|
sendResult, sendErr := api.chatDaemon.SendMessage(
|
|
ctx,
|
|
chatd.SendMessageOptions{
|
|
ChatID: chatID,
|
|
Content: contentBlocks,
|
|
ContentFileIDs: contentFileIDs,
|
|
ModelConfigID: req.ModelConfigID,
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
},
|
|
)
|
|
if sendErr != nil {
|
|
if xerrors.Is(sendErr, chatd.ErrMessageQueueFull) {
|
|
httpapi.Write(ctx, rw, http.StatusTooManyRequests, codersdk.Response{
|
|
Message: "Message queue is full.",
|
|
Detail: fmt.Sprintf("Maximum %d messages can be queued.", chatd.MaxQueueSize),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat message.",
|
|
Detail: sendErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
response := codersdk.CreateChatMessageResponse{Queued: sendResult.Queued}
|
|
if sendResult.Queued {
|
|
if sendResult.QueuedMessage != nil {
|
|
response.QueuedMessage = convertChatQueuedMessagePtr(*sendResult.QueuedMessage)
|
|
}
|
|
} else {
|
|
message := convertChatMessage(sendResult.Message)
|
|
response.Message = &message
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
messageIDStr := chi.URLParam(r, "message")
|
|
messageID, err := strconv.ParseInt(messageIDStr, 10, 64)
|
|
if err != nil || messageID <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat message ID.",
|
|
Detail: "Message ID must be a positive integer.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.EditChatMessageRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
|
if inputError != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: inputError.Message,
|
|
Detail: inputError.Detail,
|
|
})
|
|
return
|
|
}
|
|
|
|
editResult, editErr := api.chatDaemon.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: messageID,
|
|
Content: contentBlocks,
|
|
ContentFileIDs: contentFileIDs,
|
|
})
|
|
if editErr != nil {
|
|
switch {
|
|
case xerrors.Is(editErr, chatd.ErrEditedMessageNotFound):
|
|
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
|
Message: "Chat message not found.",
|
|
Detail: "Message does not belong to this chat.",
|
|
})
|
|
case xerrors.Is(editErr, chatd.ErrEditedMessageNotUser):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Only user messages can be edited.",
|
|
})
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to edit chat message.",
|
|
Detail: editErr.Error(),
|
|
})
|
|
}
|
|
return
|
|
}
|
|
|
|
message := convertChatMessage(editResult.Message)
|
|
httpapi.Write(ctx, rw, http.StatusOK, message)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) deleteChatQueuedMessage(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
queuedMessageIDStr := chi.URLParam(r, "queuedMessage")
|
|
queuedMessageID, err := strconv.ParseInt(queuedMessageIDStr, 10, 64)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid queued message ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon != nil {
|
|
err = api.chatDaemon.DeleteQueued(ctx, chatID, queuedMessageID)
|
|
} else {
|
|
err = api.Database.DeleteChatQueuedMessage(ctx, database.DeleteChatQueuedMessageParams{
|
|
ID: queuedMessageID,
|
|
ChatID: chatID,
|
|
})
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to delete queued message.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) promoteChatQueuedMessage(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
queuedMessageIDStr := chi.URLParam(r, "queuedMessage")
|
|
queuedMessageID, err := strconv.ParseInt(queuedMessageIDStr, 10, 64)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid queued message ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
promoteResult, txErr := api.chatDaemon.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chatID,
|
|
QueuedMessageID: queuedMessageID,
|
|
})
|
|
|
|
if txErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to promote queued message.",
|
|
Detail: txErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChatMessage(promoteResult.PromotedMessage))
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat streaming is not available.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var afterMessageID int64
|
|
if v := r.URL.Query().Get("after_id"); v != "" {
|
|
var err error
|
|
afterMessageID, err = strconv.ParseInt(v, 10, 64)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid after_id parameter.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to open chat stream.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
|
if !ok {
|
|
_ = sendEvent(codersdk.ServerSentEvent{
|
|
Type: codersdk.ServerSentEventTypeError,
|
|
Data: codersdk.Response{
|
|
Message: "Chat streaming is not available.",
|
|
Detail: "Chat stream state is not configured.",
|
|
},
|
|
})
|
|
// Ensure the WebSocket is closed so senderClosed
|
|
// completes and the handler can return.
|
|
<-senderClosed
|
|
return
|
|
}
|
|
defer func() {
|
|
<-senderClosed
|
|
}()
|
|
defer cancel()
|
|
|
|
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
|
if len(batch) == 0 {
|
|
return nil
|
|
}
|
|
return sendEvent(codersdk.ServerSentEvent{
|
|
Type: codersdk.ServerSentEventTypeData,
|
|
Data: batch,
|
|
})
|
|
}
|
|
|
|
drainChatStreamBatch := func(
|
|
first codersdk.ChatStreamEvent,
|
|
maxBatchSize int,
|
|
) ([]codersdk.ChatStreamEvent, bool) {
|
|
batch := []codersdk.ChatStreamEvent{first}
|
|
if maxBatchSize <= 1 {
|
|
return batch, false
|
|
}
|
|
|
|
for len(batch) < maxBatchSize {
|
|
select {
|
|
case event, ok := <-events:
|
|
if !ok {
|
|
return batch, true
|
|
}
|
|
batch = append(batch, event)
|
|
default:
|
|
return batch, false
|
|
}
|
|
}
|
|
|
|
return batch, false
|
|
}
|
|
|
|
for start := 0; start < len(snapshot); start += chatStreamBatchSize {
|
|
end := start + chatStreamBatchSize
|
|
if end > len(snapshot) {
|
|
end = len(snapshot)
|
|
}
|
|
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
|
|
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
|
return
|
|
}
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-senderClosed:
|
|
return
|
|
case firstEvent, ok := <-events:
|
|
if !ok {
|
|
return
|
|
}
|
|
batch, streamClosed := drainChatStreamBatch(
|
|
firstEvent,
|
|
chatStreamBatchSize,
|
|
)
|
|
if err := sendChatStreamBatch(batch); err != nil {
|
|
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
|
return
|
|
}
|
|
if streamClosed {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
if api.chatDaemon != nil {
|
|
chat = api.chatDaemon.InterruptChat(ctx, chat)
|
|
} else {
|
|
updatedChat, updateErr := api.Database.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chatID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{},
|
|
})
|
|
if updateErr != nil {
|
|
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
|
slog.F("chat_id", chatID), slog.Error(updateErr))
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to interrupt chat.",
|
|
Detail: updateErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
chat = updatedChat
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChat(chat, nil))
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChatDiffStatus(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
status, err := api.resolveChatDiffStatus(ctx, chat)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat diff status.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChatDiffStatus(chatID, status))
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChatDiffContents(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
diff, err := api.resolveChatDiffContents(ctx, chat)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat diff.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, diff)
|
|
}
|
|
|
|
// chatCreateWorkspace provides workspace creation for the chat
|
|
// processor. RBAC authorization uses context-based checks via
|
|
// dbauthz.As rather than fake *http.Request objects.
|
|
func (api *API) chatCreateWorkspace(
|
|
ctx context.Context,
|
|
ownerID uuid.UUID,
|
|
req codersdk.CreateWorkspaceRequest,
|
|
) (codersdk.Workspace, error) {
|
|
actor, _, err := httpmw.UserRBACSubject(ctx, api.Database, ownerID, rbac.ScopeAll)
|
|
if err != nil {
|
|
return codersdk.Workspace{}, xerrors.Errorf("load user authorization: %w", err)
|
|
}
|
|
ctx = dbauthz.As(ctx, actor)
|
|
|
|
ownerUser, err := api.Database.GetUserByID(ctx, ownerID)
|
|
if err != nil {
|
|
return codersdk.Workspace{}, xerrors.Errorf("get workspace owner: %w", err)
|
|
}
|
|
owner := workspaceOwner{
|
|
ID: ownerUser.ID,
|
|
Username: ownerUser.Username,
|
|
AvatarURL: ownerUser.AvatarURL,
|
|
}
|
|
|
|
auditor := api.Auditor.Load()
|
|
if auditor == nil {
|
|
return codersdk.Workspace{}, xerrors.New("auditor is not configured")
|
|
}
|
|
|
|
// The audit system requires a ResponseWriter to capture the
|
|
// HTTP status code. Since this is a programmatic call, we use
|
|
// a recorder. The audit entry still captures the owner, action,
|
|
// and resource correctly.
|
|
rw := httptest.NewRecorder()
|
|
sw := &tracing.StatusWriter{ResponseWriter: rw}
|
|
|
|
// Build a minimal synthetic request so the audit commit
|
|
// closure can extract a request ID and user agent. The RBAC
|
|
// subject is already on the context via dbauthz.As above.
|
|
auditReq, err := http.NewRequestWithContext(
|
|
httpmw.WithRequestID(ctx, uuid.New()),
|
|
http.MethodPost,
|
|
"http://localhost/internal/chat/workspace",
|
|
nil,
|
|
)
|
|
if err != nil {
|
|
return codersdk.Workspace{}, xerrors.Errorf("create audit request: %w", err)
|
|
}
|
|
|
|
aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](sw, &audit.RequestParams{
|
|
Audit: *auditor,
|
|
Log: api.Logger,
|
|
Request: auditReq,
|
|
Action: database.AuditActionCreate,
|
|
AdditionalFields: audit.AdditionalFields{
|
|
WorkspaceOwner: owner.Username,
|
|
},
|
|
})
|
|
aReq.UserID = ownerID
|
|
defer commitAudit()
|
|
|
|
workspace, err := createWorkspace(ctx, aReq, ownerID, api, owner, req, nil)
|
|
if err != nil {
|
|
sw.WriteHeader(chatWorkspaceAuditStatus(err))
|
|
return codersdk.Workspace{}, err
|
|
}
|
|
|
|
sw.WriteHeader(http.StatusCreated)
|
|
return workspace, nil
|
|
}
|
|
|
|
// chatStartWorkspace starts a stopped workspace by creating a new
|
|
// build with the "start" transition. It mirrors chatCreateWorkspace
|
|
// but for the start path.
|
|
func (api *API) chatStartWorkspace(
|
|
ctx context.Context,
|
|
ownerID uuid.UUID,
|
|
workspaceID uuid.UUID,
|
|
req codersdk.CreateWorkspaceBuildRequest,
|
|
) (codersdk.WorkspaceBuild, error) {
|
|
actor, _, err := httpmw.UserRBACSubject(ctx, api.Database, ownerID, rbac.ScopeAll)
|
|
if err != nil {
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("load user authorization: %w", err)
|
|
}
|
|
ctx = dbauthz.As(ctx, actor)
|
|
|
|
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
|
|
if err != nil {
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("get workspace: %w", err)
|
|
}
|
|
|
|
// Build a synthetic API key so postWorkspaceBuildsInternal can
|
|
// record the correct initiator.
|
|
syntheticKey := database.APIKey{
|
|
UserID: ownerID,
|
|
}
|
|
|
|
apiBuild, err := api.postWorkspaceBuildsInternal(
|
|
ctx,
|
|
syntheticKey,
|
|
workspace,
|
|
req,
|
|
func(action policy.Action, object rbac.Objecter) bool {
|
|
// Authorization is handled by dbauthz on the context.
|
|
authErr := api.HTTPAuth.Authorizer.Authorize(ctx, actor, action, object.RBACObject())
|
|
return authErr == nil
|
|
},
|
|
audit.WorkspaceBuildBaggage{},
|
|
)
|
|
if err != nil {
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("create workspace build: %w", err)
|
|
}
|
|
|
|
return apiBuild, nil
|
|
}
|
|
|
|
func chatWorkspaceAuditStatus(err error) int {
|
|
if responder, ok := httperror.IsResponder(err); ok {
|
|
status, _ := responder.Response()
|
|
return status
|
|
}
|
|
return http.StatusInternalServerError
|
|
}
|
|
|
|
func (api *API) resolveChatDiffStatus(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (*database.ChatDiffStatus, error) {
|
|
return api.resolveChatDiffStatusWithOptions(ctx, chat, false)
|
|
}
|
|
|
|
//nolint:revive // Boolean forces cache refresh bypass.
|
|
func (api *API) resolveChatDiffStatusWithOptions(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
forceRefresh bool,
|
|
) (*database.ChatDiffStatus, error) {
|
|
status, found, err := api.getCachedChatDiffStatus(ctx, chat.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
|
|
reference, err := api.resolveChatDiffReference(ctx, chat, found, status)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if reference.PullRequestURL != "" {
|
|
if !found || !strings.EqualFold(strings.TrimSpace(status.Url.String), reference.PullRequestURL) {
|
|
status, err = api.upsertChatDiffStatusReference(ctx, chat.ID, reference.PullRequestURL, now.Add(-time.Second))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
found = true
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
return nil, nil //nolint:nilnil // Callers handle nil status explicitly.
|
|
}
|
|
if reference.PullRequestURL == "" {
|
|
return &status, nil
|
|
}
|
|
if !shouldRefreshChatDiffStatus(status, now, forceRefresh) {
|
|
return &status, nil
|
|
}
|
|
|
|
refreshed, err := api.refreshChatDiffStatus(
|
|
ctx,
|
|
chat.OwnerID,
|
|
chat.ID,
|
|
reference.PullRequestURL,
|
|
)
|
|
if err == nil {
|
|
return &refreshed, nil
|
|
}
|
|
|
|
api.Logger.Warn(ctx, "failed to refresh chat diff status",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("pull_request_url", reference.PullRequestURL),
|
|
slog.Error(err),
|
|
)
|
|
|
|
backoffStatus, backoffErr := api.upsertChatDiffStatusReference(ctx, chat.ID, reference.PullRequestURL, now.Add(chatDiffStatusTTL))
|
|
if backoffErr != nil {
|
|
api.Logger.Warn(ctx, "failed to extend chat diff status stale timestamp",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(backoffErr),
|
|
)
|
|
return &status, nil
|
|
}
|
|
|
|
return &backoffStatus, nil
|
|
}
|
|
|
|
//nolint:revive // Boolean forces cache refresh bypass.
|
|
func shouldRefreshChatDiffStatus(status database.ChatDiffStatus, now time.Time, forceRefresh bool) bool {
|
|
if forceRefresh {
|
|
return true
|
|
}
|
|
return chatDiffStatusIsStale(status, now)
|
|
}
|
|
|
|
func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) {
|
|
if workspace.ID == uuid.Nil || workspace.OwnerID == uuid.Nil {
|
|
return
|
|
}
|
|
|
|
go func(workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
|
|
ctx := api.ctx
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
//nolint:gocritic // Background goroutine for diff status refresh has no user context.
|
|
ctx = dbauthz.AsSystemRestricted(ctx)
|
|
|
|
// Always store the git ref so the data is persisted even
|
|
// before a PR exists. The frontend can show branch info
|
|
// and the refresh loop can resolve a PR later.
|
|
api.storeChatGitRef(ctx, workspaceID, workspaceOwnerID, chatID, gitRef)
|
|
|
|
for _, delay := range chatDiffRefreshBackoffSchedule {
|
|
t := api.Clock.NewTimer(delay, "chat_diff_refresh")
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Stop()
|
|
return
|
|
case <-t.C:
|
|
}
|
|
|
|
// Refresh and publish status on every iteration.
|
|
// Stop the loop once a PR is discovered — there's
|
|
// nothing more to wait for after that.
|
|
if api.refreshWorkspaceChatDiffStatuses(ctx, workspaceID, workspaceOwnerID, chatID) {
|
|
return
|
|
}
|
|
}
|
|
}(workspace.ID, workspace.OwnerID, chatID, gitRef)
|
|
}
|
|
|
|
// storeChatGitRef persists the git branch and remote origin reported
|
|
// by the workspace agent on the chat that initiated the git operation.
|
|
// When chatID is set, only that specific chat is updated; otherwise all
|
|
// chats associated with the workspace are updated (legacy fallback).
|
|
func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
|
|
var chatsToUpdate []database.Chat
|
|
|
|
if chatID.Valid {
|
|
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to get chat for git ref storage",
|
|
slog.F("chat_id", chatID.UUID),
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
chatsToUpdate = []database.Chat{chat}
|
|
} else {
|
|
chats, err := api.Database.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
|
|
OwnerID: workspaceOwnerID,
|
|
})
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to list chats for git ref storage",
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
chatsToUpdate = filterChatsByWorkspaceID(chats, workspaceID)
|
|
}
|
|
|
|
for _, chat := range chatsToUpdate {
|
|
_, err := api.Database.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{
|
|
ChatID: chat.ID,
|
|
GitBranch: gitRef.Branch,
|
|
GitRemoteOrigin: gitRef.RemoteOrigin,
|
|
StaleAt: time.Now().UTC().Add(-time.Second),
|
|
Url: sql.NullString{},
|
|
})
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to store git ref on chat diff status",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.Error(err),
|
|
)
|
|
continue
|
|
}
|
|
api.publishChatDiffStatusEvent(ctx, chat.ID)
|
|
}
|
|
}
|
|
|
|
// refreshWorkspaceChatDiffStatuses refreshes the diff status for chats
|
|
// associated with the given workspace. When chatID is set, only that
|
|
// specific chat is refreshed; otherwise all chats for the workspace
|
|
// are refreshed (legacy fallback). It returns true when every
|
|
// refreshed chat has a PR URL resolved, signaling that the caller
|
|
// can stop polling.
|
|
func (api *API) refreshWorkspaceChatDiffStatuses(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID) bool {
|
|
var filtered []database.Chat
|
|
|
|
if chatID.Valid {
|
|
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to get chat for diff refresh",
|
|
slog.F("chat_id", chatID.UUID),
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.Error(err),
|
|
)
|
|
return false
|
|
}
|
|
filtered = []database.Chat{chat}
|
|
} else {
|
|
chats, err := api.Database.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
|
|
OwnerID: workspaceOwnerID,
|
|
})
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to list workspace owner chats for diff refresh",
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.F("workspace_owner_id", workspaceOwnerID),
|
|
slog.Error(err),
|
|
)
|
|
return false
|
|
}
|
|
filtered = filterChatsByWorkspaceID(chats, workspaceID)
|
|
}
|
|
if len(filtered) == 0 {
|
|
return false
|
|
}
|
|
|
|
allHavePR := true
|
|
for _, chat := range filtered {
|
|
refreshCtx, cancel := context.WithTimeout(ctx, chatDiffBackgroundRefreshTimeout)
|
|
status, err := api.resolveChatDiffStatusWithOptions(refreshCtx, chat, true)
|
|
cancel()
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to refresh chat diff status after workspace external auth",
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
allHavePR = false
|
|
} else if status == nil || !status.Url.Valid || strings.TrimSpace(status.Url.String) == "" {
|
|
allHavePR = false
|
|
}
|
|
|
|
api.publishChatStatusEvent(ctx, chat.ID)
|
|
api.publishChatDiffStatusEvent(ctx, chat.ID)
|
|
}
|
|
|
|
return allHavePR
|
|
}
|
|
|
|
func filterChatsByWorkspaceID(chats []database.Chat, workspaceID uuid.UUID) []database.Chat {
|
|
filteredChats := make([]database.Chat, 0, len(chats))
|
|
for _, chat := range chats {
|
|
if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID {
|
|
continue
|
|
}
|
|
filteredChats = append(filteredChats, chat)
|
|
}
|
|
return filteredChats
|
|
}
|
|
|
|
func (api *API) publishChatStatusEvent(ctx context.Context, chatID uuid.UUID) {
|
|
if api.chatDaemon == nil {
|
|
return
|
|
}
|
|
|
|
if err := api.chatDaemon.RefreshStatus(ctx, chatID); err != nil {
|
|
api.Logger.Debug(ctx, "failed to refresh published chat status",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
func (api *API) publishChatDiffStatusEvent(ctx context.Context, chatID uuid.UUID) {
|
|
if api.chatDaemon == nil {
|
|
return
|
|
}
|
|
|
|
if err := api.chatDaemon.PublishDiffStatusChange(ctx, chatID); err != nil {
|
|
api.Logger.Debug(ctx, "failed to publish chat diff status change",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
func (api *API) resolveChatDiffContents(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (codersdk.ChatDiffContents, error) {
|
|
result := codersdk.ChatDiffContents{ChatID: chat.ID}
|
|
|
|
status, found, err := api.getCachedChatDiffStatus(ctx, chat.ID)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
reference, err := api.resolveChatDiffReference(ctx, chat, found, status)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
if reference.RepositoryRef != nil {
|
|
provider := strings.TrimSpace(reference.RepositoryRef.Provider)
|
|
if provider != "" {
|
|
result.Provider = &provider
|
|
}
|
|
|
|
origin := strings.TrimSpace(reference.RepositoryRef.RemoteOrigin)
|
|
if origin != "" {
|
|
result.RemoteOrigin = &origin
|
|
}
|
|
|
|
branch := strings.TrimSpace(reference.RepositoryRef.Branch)
|
|
if branch != "" {
|
|
result.Branch = &branch
|
|
}
|
|
}
|
|
|
|
if reference.PullRequestURL != "" {
|
|
pullRequestURL := strings.TrimSpace(reference.PullRequestURL)
|
|
result.PullRequestURL = &pullRequestURL
|
|
if !found || !strings.EqualFold(strings.TrimSpace(status.Url.String), pullRequestURL) {
|
|
_, err := api.upsertChatDiffStatusReference(ctx, chat.ID, pullRequestURL, time.Now().UTC().Add(-time.Second))
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
}
|
|
}
|
|
|
|
if reference.RepositoryRef == nil {
|
|
return result, nil
|
|
}
|
|
if !strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
|
|
return result, nil
|
|
}
|
|
|
|
token := api.resolveChatGitHubAccessToken(ctx, chat.OwnerID)
|
|
|
|
if reference.PullRequestURL != "" {
|
|
diff, err := api.fetchGitHubPullRequestDiff(ctx, reference.PullRequestURL, token)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
result.Diff = diff
|
|
return result, nil
|
|
}
|
|
|
|
diff, err := api.fetchGitHubCompareDiff(ctx, *reference.RepositoryRef, token)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
result.Diff = diff
|
|
return result, nil
|
|
}
|
|
|
|
// resolveChatDiffReference builds the diff reference from the cached
|
|
// status stored in the database. The git branch and remote origin are
|
|
// populated by the workspace agent during git operations (via the
|
|
// gitaskpass flow), so no SSH into the workspace is needed here.
|
|
//
|
|
//nolint:revive // Boolean indicates whether diff status was found.
|
|
func (api *API) resolveChatDiffReference(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
found bool,
|
|
status database.ChatDiffStatus,
|
|
) (chatDiffReference, error) {
|
|
reference := chatDiffReference{}
|
|
if !found {
|
|
return reference, nil
|
|
}
|
|
|
|
reference.PullRequestURL = strings.TrimSpace(status.Url.String)
|
|
|
|
// Build the repository ref from the stored git branch/origin
|
|
// that the agent reported.
|
|
reference.RepositoryRef = api.buildChatRepositoryRefFromStatus(status)
|
|
|
|
// If we have a repo ref with a branch, try to resolve the
|
|
// current open PR. This picks up new PRs after the previous
|
|
// one was closed.
|
|
if reference.RepositoryRef != nil &&
|
|
strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
|
|
pullRequestURL, lookupErr := api.resolveGitHubPullRequestURLFromRepositoryRef(ctx, chat.OwnerID, *reference.RepositoryRef)
|
|
if lookupErr != nil {
|
|
api.Logger.Debug(ctx, "failed to resolve pull request from repository reference",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("provider", reference.RepositoryRef.Provider),
|
|
slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin),
|
|
slog.F("branch", reference.RepositoryRef.Branch),
|
|
slog.Error(lookupErr),
|
|
)
|
|
} else if pullRequestURL != "" {
|
|
reference.PullRequestURL = pullRequestURL
|
|
}
|
|
}
|
|
|
|
reference.PullRequestURL = normalizeGitHubPullRequestURL(reference.PullRequestURL)
|
|
|
|
// If we have a PR URL but no repo ref (e.g. the agent hasn't
|
|
// reported branch/origin yet), derive a partial ref from the
|
|
// PR URL so the caller can still show provider/owner/repo.
|
|
if reference.RepositoryRef == nil && reference.PullRequestURL != "" {
|
|
if parsed, ok := parseGitHubPullRequestURL(reference.PullRequestURL); ok {
|
|
reference.RepositoryRef = &chatRepositoryRef{
|
|
Provider: string(codersdk.EnhancedExternalAuthProviderGitHub),
|
|
RemoteOrigin: fmt.Sprintf("https://github.com/%s/%s", parsed.Owner, parsed.Repo),
|
|
Owner: parsed.Owner,
|
|
Repo: parsed.Repo,
|
|
}
|
|
}
|
|
}
|
|
|
|
return reference, nil
|
|
}
|
|
|
|
// buildChatRepositoryRefFromStatus constructs a chatRepositoryRef
|
|
// from the git branch and remote origin stored in the cached status.
|
|
// Returns nil if no ref data is available.
|
|
func (api *API) buildChatRepositoryRefFromStatus(status database.ChatDiffStatus) *chatRepositoryRef {
|
|
branch := strings.TrimSpace(status.GitBranch)
|
|
origin := strings.TrimSpace(status.GitRemoteOrigin)
|
|
if branch == "" || origin == "" {
|
|
return nil
|
|
}
|
|
|
|
repoRef := &chatRepositoryRef{
|
|
Provider: strings.TrimSpace(api.resolveExternalAuthProviderType(origin)),
|
|
RemoteOrigin: origin,
|
|
Branch: branch,
|
|
}
|
|
|
|
if owner, repo, normalizedOrigin, ok := parseGitHubRepositoryOrigin(repoRef.RemoteOrigin); ok {
|
|
if repoRef.Provider == "" {
|
|
repoRef.Provider = string(codersdk.EnhancedExternalAuthProviderGitHub)
|
|
}
|
|
repoRef.RemoteOrigin = normalizedOrigin
|
|
repoRef.Owner = owner
|
|
repoRef.Repo = repo
|
|
}
|
|
|
|
if repoRef.Provider == "" {
|
|
return nil
|
|
}
|
|
|
|
return repoRef
|
|
}
|
|
|
|
func (api *API) upsertChatDiffStatusReference(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
pullRequestURL string,
|
|
staleAt time.Time,
|
|
) (database.ChatDiffStatus, error) {
|
|
status, err := api.Database.UpsertChatDiffStatusReference(
|
|
ctx,
|
|
database.UpsertChatDiffStatusReferenceParams{
|
|
ChatID: chatID,
|
|
Url: sql.NullString{
|
|
String: pullRequestURL,
|
|
Valid: strings.TrimSpace(pullRequestURL) != "",
|
|
},
|
|
// Empty strings preserve existing values via the
|
|
// CASE expression in the SQL query.
|
|
GitBranch: "",
|
|
GitRemoteOrigin: "",
|
|
StaleAt: staleAt,
|
|
},
|
|
)
|
|
if err != nil {
|
|
return database.ChatDiffStatus{}, xerrors.Errorf("upsert chat diff status reference: %w", err)
|
|
}
|
|
return status, nil
|
|
}
|
|
|
|
func (api *API) getCachedChatDiffStatus(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
) (database.ChatDiffStatus, bool, error) {
|
|
status, err := api.Database.GetChatDiffStatusByChatID(ctx, chatID)
|
|
if err == nil {
|
|
return status, true, nil
|
|
}
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return database.ChatDiffStatus{}, false, nil
|
|
}
|
|
return database.ChatDiffStatus{}, false, xerrors.Errorf(
|
|
"get chat diff status: %w",
|
|
err,
|
|
)
|
|
}
|
|
|
|
func (api *API) resolveExternalAuthProviderType(match string) string {
|
|
match = strings.TrimSpace(match)
|
|
if match == "" {
|
|
return ""
|
|
}
|
|
|
|
for _, extAuth := range api.ExternalAuthConfigs {
|
|
if extAuth.Regex == nil || !extAuth.Regex.MatchString(match) {
|
|
continue
|
|
}
|
|
return strings.ToLower(strings.TrimSpace(extAuth.Type))
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func parseGitHubRepositoryOrigin(raw string) (owner string, repo string, normalizedOrigin string, ok bool) {
|
|
raw = strings.TrimSpace(raw)
|
|
if raw == "" {
|
|
return "", "", "", false
|
|
}
|
|
|
|
matches := githubRepositoryHTTPSPattern.FindStringSubmatch(raw)
|
|
if len(matches) != 3 {
|
|
matches = githubRepositorySSHPathPattern.FindStringSubmatch(raw)
|
|
}
|
|
if len(matches) != 3 {
|
|
return "", "", "", false
|
|
}
|
|
|
|
owner = strings.TrimSpace(matches[1])
|
|
repo = strings.TrimSpace(matches[2])
|
|
repo = strings.TrimSuffix(repo, ".git")
|
|
if owner == "" || repo == "" {
|
|
return "", "", "", false
|
|
}
|
|
|
|
return owner, repo, fmt.Sprintf("https://github.com/%s/%s", owner, repo), true
|
|
}
|
|
|
|
// buildGitHubBranchURL returns https://github.com/<owner>/<repo>/tree/<branch>,
// or "" if any argument is blank after trimming whitespace. The branch is
// escaped as a single path segment, so "/" in a branch name becomes "%2F".
func buildGitHubBranchURL(owner string, repo string, branch string) string {
	segments := []string{
		strings.TrimSpace(owner),
		strings.TrimSpace(repo),
		strings.TrimSpace(branch),
	}
	for _, segment := range segments {
		if segment == "" {
			return ""
		}
	}
	return fmt.Sprintf("https://github.com/%s/%s/tree/%s", segments[0], segments[1], url.PathEscape(segments[2]))
}
|
|
|
|
func chatDiffStatusIsStale(status database.ChatDiffStatus, now time.Time) bool {
|
|
if !status.RefreshedAt.Valid {
|
|
return true
|
|
}
|
|
return !status.StaleAt.After(now)
|
|
}
|
|
|
|
func (api *API) refreshChatDiffStatus(
|
|
ctx context.Context,
|
|
chatOwnerID uuid.UUID,
|
|
chatID uuid.UUID,
|
|
pullRequestURL string,
|
|
) (database.ChatDiffStatus, error) {
|
|
status, err := api.fetchGitHubPullRequestStatus(
|
|
ctx,
|
|
pullRequestURL,
|
|
api.resolveChatGitHubAccessToken(ctx, chatOwnerID),
|
|
)
|
|
if err != nil {
|
|
return database.ChatDiffStatus{}, err
|
|
}
|
|
|
|
refreshedAt := time.Now().UTC()
|
|
refreshedStatus, err := api.Database.UpsertChatDiffStatus(
|
|
ctx,
|
|
database.UpsertChatDiffStatusParams{
|
|
ChatID: chatID,
|
|
Url: sql.NullString{String: pullRequestURL, Valid: true},
|
|
PullRequestState: sql.NullString{
|
|
String: status.PullRequestState,
|
|
Valid: status.PullRequestState != "",
|
|
},
|
|
ChangesRequested: status.ChangesRequested,
|
|
Additions: status.Additions,
|
|
Deletions: status.Deletions,
|
|
ChangedFiles: status.ChangedFiles,
|
|
RefreshedAt: refreshedAt,
|
|
StaleAt: refreshedAt.Add(chatDiffStatusTTL),
|
|
},
|
|
)
|
|
if err != nil {
|
|
return database.ChatDiffStatus{}, xerrors.Errorf("upsert chat diff status: %w", err)
|
|
}
|
|
return refreshedStatus, nil
|
|
}
|
|
|
|
func (api *API) resolveChatGitHubAccessToken(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
) string {
|
|
// Build a map of provider ID -> config so we can refresh tokens
|
|
// using the same code path as provisionerdserver.
|
|
ghConfigs := make(map[string]*externalauth.Config)
|
|
providerIDs := []string{"github"}
|
|
for _, config := range api.ExternalAuthConfigs {
|
|
if !strings.EqualFold(
|
|
config.Type,
|
|
string(codersdk.EnhancedExternalAuthProviderGitHub),
|
|
) {
|
|
continue
|
|
}
|
|
providerIDs = append(providerIDs, config.ID)
|
|
ghConfigs[config.ID] = config
|
|
}
|
|
|
|
seen := map[string]struct{}{}
|
|
for _, providerID := range providerIDs {
|
|
if _, ok := seen[providerID]; ok {
|
|
continue
|
|
}
|
|
seen[providerID] = struct{}{}
|
|
|
|
link, err := api.Database.GetExternalAuthLink(
|
|
ctx,
|
|
database.GetExternalAuthLinkParams{
|
|
ProviderID: providerID,
|
|
UserID: userID,
|
|
},
|
|
)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Refresh the token if there is a matching config, mirroring
|
|
// the same code path used by provisionerdserver when handing
|
|
// tokens to provisioners.
|
|
if cfg, ok := ghConfigs[providerID]; ok {
|
|
refreshed, refreshErr := cfg.RefreshToken(ctx, api.Database, link)
|
|
if refreshErr != nil {
|
|
api.Logger.Debug(ctx, "failed to refresh external auth token for chat diff",
|
|
slog.F("provider_id", providerID),
|
|
slog.F("user_id", userID),
|
|
slog.Error(refreshErr),
|
|
)
|
|
// Fall through — the existing token may still work
|
|
// (e.g. GitHub tokens with no expiry).
|
|
} else {
|
|
link = refreshed
|
|
}
|
|
}
|
|
|
|
token := strings.TrimSpace(link.OAuthAccessToken)
|
|
if token != "" {
|
|
return token
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func (api *API) resolveGitHubPullRequestURLFromRepositoryRef(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
repositoryRef chatRepositoryRef,
|
|
) (string, error) {
|
|
if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" {
|
|
return "", nil
|
|
}
|
|
|
|
query := url.Values{}
|
|
query.Set("state", "open")
|
|
query.Set("head", fmt.Sprintf("%s:%s", repositoryRef.Owner, repositoryRef.Branch))
|
|
query.Set("sort", "updated")
|
|
query.Set("direction", "desc")
|
|
query.Set("per_page", "1")
|
|
|
|
requestURL := fmt.Sprintf(
|
|
"%s/repos/%s/%s/pulls?%s",
|
|
githubAPIBaseURL,
|
|
repositoryRef.Owner,
|
|
repositoryRef.Repo,
|
|
query.Encode(),
|
|
)
|
|
|
|
var pulls []struct {
|
|
HTMLURL string `json:"html_url"`
|
|
}
|
|
|
|
token := api.resolveChatGitHubAccessToken(ctx, userID)
|
|
if err := api.decodeGitHubJSON(ctx, requestURL, token, &pulls); err != nil {
|
|
return "", err
|
|
}
|
|
if len(pulls) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
return normalizeGitHubPullRequestURL(pulls[0].HTMLURL), nil
|
|
}
|
|
|
|
func (api *API) fetchGitHubPullRequestDiff(
|
|
ctx context.Context,
|
|
pullRequestURL string,
|
|
token string,
|
|
) (string, error) {
|
|
ref, ok := parseGitHubPullRequestURL(pullRequestURL)
|
|
if !ok {
|
|
return "", xerrors.Errorf("invalid GitHub pull request URL %q", pullRequestURL)
|
|
}
|
|
|
|
requestURL := fmt.Sprintf(
|
|
"%s/repos/%s/%s/pulls/%d",
|
|
githubAPIBaseURL,
|
|
ref.Owner,
|
|
ref.Repo,
|
|
ref.Number,
|
|
)
|
|
|
|
return api.fetchGitHubDiff(ctx, requestURL, token)
|
|
}
|
|
|
|
func (api *API) fetchGitHubCompareDiff(
|
|
ctx context.Context,
|
|
repositoryRef chatRepositoryRef,
|
|
token string,
|
|
) (string, error) {
|
|
if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" {
|
|
return "", nil
|
|
}
|
|
|
|
var repository struct {
|
|
DefaultBranch string `json:"default_branch"`
|
|
}
|
|
|
|
repositoryURL := fmt.Sprintf(
|
|
"%s/repos/%s/%s",
|
|
githubAPIBaseURL,
|
|
repositoryRef.Owner,
|
|
repositoryRef.Repo,
|
|
)
|
|
if err := api.decodeGitHubJSON(ctx, repositoryURL, token, &repository); err != nil {
|
|
return "", err
|
|
}
|
|
defaultBranch := strings.TrimSpace(repository.DefaultBranch)
|
|
if defaultBranch == "" {
|
|
return "", xerrors.New("github repository default branch is empty")
|
|
}
|
|
|
|
requestURL := fmt.Sprintf(
|
|
"%s/repos/%s/%s/compare/%s...%s",
|
|
githubAPIBaseURL,
|
|
repositoryRef.Owner,
|
|
repositoryRef.Repo,
|
|
url.PathEscape(defaultBranch),
|
|
url.PathEscape(repositoryRef.Branch),
|
|
)
|
|
|
|
return api.fetchGitHubDiff(ctx, requestURL, token)
|
|
}
|
|
|
|
func (api *API) fetchGitHubDiff(
|
|
ctx context.Context,
|
|
requestURL string,
|
|
token string,
|
|
) (string, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
|
if err != nil {
|
|
return "", xerrors.Errorf("create github diff request: %w", err)
|
|
}
|
|
req.Header.Set("Accept", "application/vnd.github.diff")
|
|
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
|
req.Header.Set("User-Agent", "coder-chat-diff")
|
|
if token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
httpClient := api.HTTPClient
|
|
if httpClient == nil {
|
|
httpClient = http.DefaultClient
|
|
}
|
|
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return "", xerrors.Errorf("execute github diff request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
|
if readErr != nil {
|
|
return "", xerrors.Errorf("github diff request failed with status %d", resp.StatusCode)
|
|
}
|
|
return "", xerrors.Errorf(
|
|
"github diff request failed with status %d: %s",
|
|
resp.StatusCode,
|
|
strings.TrimSpace(string(body)),
|
|
)
|
|
}
|
|
|
|
diff, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
|
if err != nil {
|
|
return "", xerrors.Errorf("read github diff response: %w", err)
|
|
}
|
|
return string(diff), nil
|
|
}
|
|
|
|
func (api *API) fetchGitHubPullRequestStatus(
|
|
ctx context.Context,
|
|
pullRequestURL string,
|
|
token string,
|
|
) (githubPullRequestStatus, error) {
|
|
ref, ok := parseGitHubPullRequestURL(pullRequestURL)
|
|
if !ok {
|
|
return githubPullRequestStatus{}, xerrors.Errorf(
|
|
"invalid GitHub pull request URL %q",
|
|
pullRequestURL,
|
|
)
|
|
}
|
|
|
|
pullEndpoint := fmt.Sprintf(
|
|
"%s/repos/%s/%s/pulls/%d",
|
|
githubAPIBaseURL,
|
|
ref.Owner,
|
|
ref.Repo,
|
|
ref.Number,
|
|
)
|
|
|
|
var pull struct {
|
|
State string `json:"state"`
|
|
Additions int32 `json:"additions"`
|
|
Deletions int32 `json:"deletions"`
|
|
ChangedFiles int32 `json:"changed_files"`
|
|
}
|
|
if err := api.decodeGitHubJSON(ctx, pullEndpoint, token, &pull); err != nil {
|
|
return githubPullRequestStatus{}, err
|
|
}
|
|
|
|
var reviews []struct {
|
|
ID int64 `json:"id"`
|
|
State string `json:"state"`
|
|
User struct {
|
|
Login string `json:"login"`
|
|
} `json:"user"`
|
|
}
|
|
if err := api.decodeGitHubJSON(
|
|
ctx,
|
|
pullEndpoint+"/reviews?per_page=100",
|
|
token,
|
|
&reviews,
|
|
); err != nil {
|
|
return githubPullRequestStatus{}, err
|
|
}
|
|
|
|
return githubPullRequestStatus{
|
|
PullRequestState: strings.ToLower(strings.TrimSpace(pull.State)),
|
|
ChangesRequested: hasOutstandingGitHubChangesRequested(reviews),
|
|
Additions: pull.Additions,
|
|
Deletions: pull.Deletions,
|
|
ChangedFiles: pull.ChangedFiles,
|
|
}, nil
|
|
}
|
|
|
|
func (api *API) decodeGitHubJSON(
|
|
ctx context.Context,
|
|
requestURL string,
|
|
token string,
|
|
dest any,
|
|
) error {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
|
if err != nil {
|
|
return xerrors.Errorf("create github request: %w", err)
|
|
}
|
|
req.Header.Set("Accept", "application/vnd.github+json")
|
|
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
|
req.Header.Set("User-Agent", "coder-chat-diff-status")
|
|
if token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
httpClient := api.HTTPClient
|
|
if httpClient == nil {
|
|
httpClient = http.DefaultClient
|
|
}
|
|
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return xerrors.Errorf("execute github request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
|
if readErr != nil {
|
|
return xerrors.Errorf(
|
|
"github request failed with status %d",
|
|
resp.StatusCode,
|
|
)
|
|
}
|
|
return xerrors.Errorf(
|
|
"github request failed with status %d: %s",
|
|
resp.StatusCode,
|
|
strings.TrimSpace(string(body)),
|
|
)
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(dest); err != nil {
|
|
return xerrors.Errorf("decode github response: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// hasOutstandingGitHubChangesRequested reports whether any reviewer's most
// recent decisive review is CHANGES_REQUESTED.
//
// Reviews are keyed by the reviewer's login, lower-cased and trimmed.
// Reviews with an empty login are skipped. Only the CHANGES_REQUESTED,
// APPROVED and DISMISSED states count as decisive; they are compared
// case-insensitively after trimming. Any other state (for example COMMENTED)
// is ignored, so it cannot clear an earlier change request.
//
// For each reviewer the review with the highest ID wins. When two reviews
// share an ID, the one appearing later in the slice wins.
func hasOutstandingGitHubChangesRequested(
	reviews []struct {
		ID    int64  `json:"id"`
		State string `json:"state"`
		User  struct {
			Login string `json:"login"`
		} `json:"user"`
	},
) bool {
	type latestReview struct {
		id    int64
		state string
	}

	latest := make(map[string]latestReview)
	for _, rv := range reviews {
		login := strings.ToLower(strings.TrimSpace(rv.User.Login))
		state := strings.ToUpper(strings.TrimSpace(rv.State))
		decisive := state == "CHANGES_REQUESTED" || state == "APPROVED" || state == "DISMISSED"
		if login == "" || !decisive {
			continue
		}
		if prev, seen := latest[login]; seen && prev.id > rv.ID {
			// An even newer review from this reviewer was already recorded.
			continue
		}
		latest[login] = latestReview{id: rv.ID, state: state}
	}

	for _, lr := range latest {
		if lr.state == "CHANGES_REQUESTED" {
			return true
		}
	}
	return false
}
|
|
|
|
func normalizeGitHubPullRequestURL(raw string) string {
|
|
ref, ok := parseGitHubPullRequestURL(strings.TrimRight(
|
|
strings.TrimSpace(raw),
|
|
"),.;",
|
|
))
|
|
if !ok {
|
|
return ""
|
|
}
|
|
return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number)
|
|
}
|
|
|
|
func parseGitHubPullRequestURL(raw string) (githubPullRequestRef, bool) {
|
|
matches := githubPullRequestPathPattern.FindStringSubmatch(strings.TrimSpace(raw))
|
|
if len(matches) != 4 {
|
|
return githubPullRequestRef{}, false
|
|
}
|
|
|
|
number, err := strconv.Atoi(matches[3])
|
|
if err != nil {
|
|
return githubPullRequestRef{}, false
|
|
}
|
|
|
|
return githubPullRequestRef{
|
|
Owner: matches[1],
|
|
Repo: matches[2],
|
|
Number: number,
|
|
}, true
|
|
}
|
|
|
|
type createChatWorkspaceSelection struct {
|
|
WorkspaceID uuid.NullUUID
|
|
}
|
|
|
|
func (api *API) validateCreateChatWorkspaceSelection(
|
|
ctx context.Context,
|
|
req codersdk.CreateChatRequest,
|
|
) (
|
|
createChatWorkspaceSelection,
|
|
int,
|
|
*codersdk.Response,
|
|
) {
|
|
selection := createChatWorkspaceSelection{}
|
|
if req.WorkspaceID == nil {
|
|
return selection, 0, nil
|
|
}
|
|
|
|
workspace, err := api.Database.GetWorkspaceByID(ctx, *req.WorkspaceID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
return selection, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Workspace not found or you do not have access to this resource",
|
|
}
|
|
}
|
|
return selection, http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Failed to get workspace.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
selection.WorkspaceID = uuid.NullUUID{
|
|
UUID: workspace.ID,
|
|
Valid: true,
|
|
}
|
|
|
|
return selection, 0, nil
|
|
}
|
|
|
|
func (api *API) resolveCreateChatModelConfigID(
|
|
ctx context.Context,
|
|
req codersdk.CreateChatRequest,
|
|
) (uuid.UUID, int, *codersdk.Response) {
|
|
if req.ModelConfigID != nil {
|
|
if *req.ModelConfigID == uuid.Nil {
|
|
return uuid.Nil, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Invalid model config ID.",
|
|
}
|
|
}
|
|
return *req.ModelConfigID, 0, nil
|
|
}
|
|
|
|
defaultModelConfig, err := api.Database.GetDefaultChatModelConfig(ctx)
|
|
if err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return uuid.Nil, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "No default chat model config is configured.",
|
|
}
|
|
}
|
|
return uuid.Nil, http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Failed to resolve chat model config.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
|
|
return defaultModelConfig.ID, 0, nil
|
|
}
|
|
|
|
func normalizeChatCompressionThreshold(
|
|
requested *int32,
|
|
fallback int32,
|
|
) (int32, error) {
|
|
threshold := fallback
|
|
if requested != nil {
|
|
threshold = *requested
|
|
}
|
|
|
|
if threshold < minChatContextCompressionThreshold ||
|
|
threshold > maxChatContextCompressionThreshold {
|
|
return 0, xerrors.Errorf(
|
|
"context_compression_threshold must be between %d and %d",
|
|
minChatContextCompressionThreshold,
|
|
maxChatContextCompressionThreshold,
|
|
)
|
|
}
|
|
|
|
return threshold, nil
|
|
}
|
|
|
|
// maxChatFileSize is the largest accepted chat file upload: 10 MiB
// (10 << 20 bytes). Uploads beyond it are rejected with 413.
const maxChatFileSize = 10 << 20

// maxChatFileName is the maximum stored length, in bytes, of an uploaded
// file's name; longer names are truncated on a rune boundary.
const maxChatFileName = 255
|
|
|
|
// allowedChatFileMIMETypes is the allowlist of media types accepted for chat
// file uploads. It applies both to the declared Content-Type and to the
// sniffed type. A type maps to true only if it is accepted; absent types
// read as false.
var allowedChatFileMIMETypes = map[string]bool{
	"image/gif":  true,
	"image/jpeg": true,
	"image/png":  true,
	"image/webp": true,
	// SVG is listed explicitly as rejected because it can contain scripts.
	"image/svg+xml": false,
}
|
|
|
|
// WebP files are RIFF containers. Bytes 0-3 hold "RIFF" and bytes 8-11 hold
// the "WEBP" form type (bytes 4-7 are the chunk size).
var (
	webpMagicRIFF = []byte{'R', 'I', 'F', 'F'}
	webpMagicWEBP = []byte{'W', 'E', 'B', 'P'}
)
|
|
|
|
// detectChatFileType detects the MIME type of the given data.
|
|
// It extends http.DetectContentType with support for WebP, which
|
|
// Go's standard sniffer does not recognize.
|
|
func detectChatFileType(data []byte) string {
|
|
if len(data) >= 12 &&
|
|
bytes.Equal(data[0:4], webpMagicRIFF) &&
|
|
bytes.Equal(data[8:12], webpMagicWEBP) {
|
|
return "image/webp"
|
|
}
|
|
return http.DetectContentType(data)
|
|
}
|
|
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
prompt, err := api.Database.GetChatSystemPrompt(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching chat system prompt.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPromptResponse{
|
|
SystemPrompt: prompt,
|
|
})
|
|
}
|
|
|
|
func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
var req codersdk.UpdateChatSystemPromptRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
trimmedPrompt := strings.TrimSpace(req.SystemPrompt)
|
|
// 128 KiB is generous for a system prompt while still
|
|
// preventing abuse or accidental pastes of large content.
|
|
if len(trimmedPrompt) > maxSystemPromptLenBytes {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "System prompt exceeds maximum length.",
|
|
Detail: fmt.Sprintf("Maximum length is %d bytes, got %d.", maxSystemPromptLenBytes, len(trimmedPrompt)),
|
|
})
|
|
return
|
|
}
|
|
err := api.Database.UpsertChatSystemPrompt(ctx, trimmedPrompt)
|
|
if httpapi.Is404Error(err) { // also catches authz error
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
} else if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating chat system prompt.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (api *API) resolvedChatSystemPrompt(ctx context.Context) string {
|
|
custom, err := api.Database.GetChatSystemPrompt(ctx)
|
|
if err != nil {
|
|
// Log but don't fail chat creation — fall back to the
|
|
// built-in default so the user isn't blocked.
|
|
api.Logger.Error(ctx, "failed to fetch custom chat system prompt, using default", slog.Error(err))
|
|
return chatd.DefaultSystemPrompt
|
|
}
|
|
if strings.TrimSpace(custom) != "" {
|
|
return custom
|
|
}
|
|
return chatd.DefaultSystemPrompt
|
|
}
|
|
|
|
// @Summary Upload a chat file
|
|
// @ID upload-chat-file
|
|
// @Security CoderSessionToken
|
|
// @Accept application/octet-stream
|
|
// @Produce json
|
|
// @Tags Chats
|
|
// @Param Content-Type header string true "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)"
|
|
// @Param organization query string true "Organization ID" format(uuid)
|
|
// @Success 201 {object} codersdk.UploadChatFileResponse
|
|
// @Failure 400 {object} codersdk.Response
|
|
// @Failure 401 {object} codersdk.Response
|
|
// @Failure 413 {object} codersdk.Response
|
|
// @Failure 500 {object} codersdk.Response
|
|
// @Router /chats/files [post]
|
|
func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
if !api.Authorize(r, policy.ActionCreate, rbac.ResourceChat.WithOwner(apiKey.UserID.String())) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
orgIDStr := r.URL.Query().Get("organization")
|
|
if orgIDStr == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Missing organization query parameter.",
|
|
})
|
|
return
|
|
}
|
|
orgID, err := uuid.Parse(orgIDStr)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid organization ID.",
|
|
})
|
|
return
|
|
}
|
|
|
|
contentType := r.Header.Get("Content-Type")
|
|
if contentType == "" {
|
|
contentType = "application/octet-stream"
|
|
}
|
|
// Strip parameters (e.g. "image/png; charset=utf-8" → "image/png")
|
|
// so the allowlist check matches the base media type.
|
|
if mediaType, _, err := mime.ParseMediaType(contentType); err == nil {
|
|
contentType = mediaType
|
|
}
|
|
|
|
if allowed, ok := allowedChatFileMIMETypes[contentType]; !ok || !allowed {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Unsupported file type.",
|
|
Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.",
|
|
})
|
|
return
|
|
}
|
|
|
|
r.Body = http.MaxBytesReader(rw, r.Body, maxChatFileSize)
|
|
br := bufio.NewReader(r.Body)
|
|
|
|
// Peek at the leading bytes to sniff the real content type
|
|
// before reading the entire body.
|
|
peek, peekErr := br.Peek(512)
|
|
if peekErr != nil && !errors.Is(peekErr, io.EOF) && !errors.Is(peekErr, bufio.ErrBufferFull) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Failed to read file from request.",
|
|
Detail: peekErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Verify the actual content matches a safe image type so that
|
|
// a client cannot spoof Content-Type to serve active content.
|
|
detected := detectChatFileType(peek)
|
|
if allowed, ok := allowedChatFileMIMETypes[detected]; !ok || !allowed {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Unsupported file type.",
|
|
Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Read the full body now that we know the type is valid.
|
|
data, err := io.ReadAll(br)
|
|
if err != nil {
|
|
var maxBytesErr *http.MaxBytesError
|
|
if errors.As(err, &maxBytesErr) {
|
|
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
|
|
Message: "File too large.",
|
|
Detail: fmt.Sprintf("Maximum file size is %d bytes.", maxChatFileSize),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Failed to read file from request.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Extract filename from Content-Disposition header if provided.
|
|
var filename string
|
|
if cd := r.Header.Get("Content-Disposition"); cd != "" {
|
|
if _, params, err := mime.ParseMediaType(cd); err == nil {
|
|
filename = params["filename"]
|
|
if len(filename) > maxChatFileName {
|
|
// Truncate at rune boundary to avoid splitting
|
|
// multi-byte UTF-8 characters.
|
|
var truncated []byte
|
|
for _, r := range filename {
|
|
encoded := []byte(string(r))
|
|
if len(truncated)+len(encoded) > maxChatFileName {
|
|
break
|
|
}
|
|
truncated = append(truncated, encoded...)
|
|
}
|
|
filename = string(truncated)
|
|
}
|
|
}
|
|
}
|
|
|
|
chatFile, err := api.Database.InsertChatFile(ctx, database.InsertChatFileParams{
|
|
OwnerID: apiKey.UserID,
|
|
OrganizationID: orgID,
|
|
Name: filename,
|
|
Mimetype: detected,
|
|
Data: data,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to save chat file.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.UploadChatFileResponse{
|
|
ID: chatFile.ID,
|
|
})
|
|
}
|
|
|
|
// @Summary Get a chat file
|
|
// @ID get-chat-file
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Param file path string true "File ID" format(uuid)
|
|
// @Success 200
|
|
// @Failure 400 {object} codersdk.Response
|
|
// @Failure 401 {object} codersdk.Response
|
|
// @Failure 404 {object} codersdk.Response
|
|
// @Failure 500 {object} codersdk.Response
|
|
// @Router /chats/files/{file} [get]
|
|
func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
fileIDStr := chi.URLParam(r, "file")
|
|
fileID, err := uuid.Parse(fileIDStr)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid file ID.",
|
|
})
|
|
return
|
|
}
|
|
|
|
chatFile, err := api.Database.GetChatFileByID(ctx, fileID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat file.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.Header().Set("Content-Type", chatFile.Mimetype)
|
|
if chatFile.Name != "" {
|
|
rw.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": chatFile.Name}))
|
|
} else {
|
|
rw.Header().Set("Content-Disposition", "inline")
|
|
}
|
|
rw.Header().Set("Cache-Control", "private, max-age=31536000, immutable")
|
|
rw.Header().Set("Content-Length", strconv.Itoa(len(chatFile.Data)))
|
|
rw.WriteHeader(http.StatusOK)
|
|
_, _ = rw.Write(chatFile.Data)
|
|
}
|
|
|
|
func createChatInputFromRequest(ctx context.Context, db database.Store, req codersdk.CreateChatRequest) (
|
|
[]fantasy.Content,
|
|
map[int]uuid.UUID,
|
|
string,
|
|
*codersdk.Response,
|
|
) {
|
|
return createChatInputFromParts(ctx, db, req.Content, "content")
|
|
}
|
|
|
|
func createChatInputFromParts(
|
|
ctx context.Context,
|
|
db database.Store,
|
|
parts []codersdk.ChatInputPart,
|
|
fieldName string,
|
|
) ([]fantasy.Content, map[int]uuid.UUID, string, *codersdk.Response) {
|
|
if len(parts) == 0 {
|
|
return nil, nil, "", &codersdk.Response{
|
|
Message: "Content is required.",
|
|
Detail: "Content cannot be empty.",
|
|
}
|
|
}
|
|
|
|
content := make([]fantasy.Content, 0, len(parts))
|
|
fileIDs := make(map[int]uuid.UUID)
|
|
textParts := make([]string, 0, len(parts))
|
|
for i, part := range parts {
|
|
switch strings.ToLower(strings.TrimSpace(string(part.Type))) {
|
|
case string(codersdk.ChatInputPartTypeText):
|
|
text := strings.TrimSpace(part.Text)
|
|
if text == "" {
|
|
return nil, nil, "", &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i),
|
|
}
|
|
}
|
|
content = append(content, fantasy.TextContent{Text: text})
|
|
textParts = append(textParts, text)
|
|
case string(codersdk.ChatInputPartTypeFile):
|
|
if part.FileID == uuid.Nil {
|
|
return nil, nil, "", &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf("%s[%d].file_id is required for file parts.", fieldName, i),
|
|
}
|
|
}
|
|
// Validate that the file exists and get its media type.
|
|
// File data is not loaded here; it's resolved at LLM
|
|
// dispatch time via chatFileResolver.
|
|
chatFile, err := db.GetChatFileByID(ctx, part.FileID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
return nil, nil, "", &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf("%s[%d].file_id references a file that does not exist.", fieldName, i),
|
|
}
|
|
}
|
|
return nil, nil, "", &codersdk.Response{
|
|
Message: "Internal error.",
|
|
Detail: fmt.Sprintf("Failed to retrieve file for %s[%d].", fieldName, i),
|
|
}
|
|
}
|
|
content = append(content, fantasy.FileContent{
|
|
MediaType: chatFile.Mimetype,
|
|
})
|
|
fileIDs[len(content)-1] = part.FileID
|
|
case string(codersdk.ChatInputPartTypeFileReference):
|
|
if part.FileName == "" {
|
|
return nil, nil, "", &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf("%s[%d].file_name cannot be empty for file-reference.", fieldName, i),
|
|
}
|
|
}
|
|
lineRange := fmt.Sprintf("%d", part.StartLine)
|
|
if part.StartLine != part.EndLine {
|
|
lineRange = fmt.Sprintf("%d-%d", part.StartLine, part.EndLine)
|
|
}
|
|
var sb strings.Builder
|
|
_, _ = fmt.Fprintf(&sb, "[file-reference] %s:%s", part.FileName, lineRange)
|
|
if strings.TrimSpace(part.Content) != "" {
|
|
_, _ = fmt.Fprintf(&sb, "\n```%s\n%s\n```", part.FileName, strings.TrimSpace(part.Content))
|
|
}
|
|
text := sb.String()
|
|
content = append(content, fantasy.TextContent{Text: text})
|
|
textParts = append(textParts, text)
|
|
default:
|
|
return nil, nil, "", &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf(
|
|
"%s[%d].type %q is not supported.",
|
|
fieldName,
|
|
i,
|
|
part.Type,
|
|
),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Allow file-only messages. The titleSource may be empty
|
|
// when only file parts are provided, callers handle this.
|
|
if len(content) == 0 {
|
|
return nil, nil, "", &codersdk.Response{
|
|
Message: "Content is required.",
|
|
Detail: fmt.Sprintf("%s must include at least one text or file part.", fieldName),
|
|
}
|
|
}
|
|
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
|
|
return content, fileIDs, titleSource, nil
|
|
}
|
|
|
|
func chatTitleFromMessage(message string) string {
|
|
const maxWords = 6
|
|
const maxRunes = 80
|
|
words := strings.Fields(message)
|
|
if len(words) == 0 {
|
|
return "New Chat"
|
|
}
|
|
truncated := false
|
|
if len(words) > maxWords {
|
|
words = words[:maxWords]
|
|
truncated = true
|
|
}
|
|
title := strings.Join(words, " ")
|
|
if truncated {
|
|
title += "…"
|
|
}
|
|
return truncateRunes(title, maxRunes)
|
|
}
|
|
|
|
// truncateRunes returns at most the first maxLen runes of value. Input that
// already fits is returned unchanged, and a non-positive maxLen yields "".
func truncateRunes(value string, maxLen int) string {
	if maxLen <= 0 {
		return ""
	}
	if runes := []rune(value); len(runes) > maxLen {
		return string(runes[:maxLen])
	}
	return value
}
|
|
|
|
func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
|
chat := codersdk.Chat{
|
|
ID: c.ID,
|
|
OwnerID: c.OwnerID,
|
|
LastModelConfigID: c.LastModelConfigID,
|
|
Title: c.Title,
|
|
Status: codersdk.ChatStatus(c.Status),
|
|
Archived: c.Archived,
|
|
CreatedAt: c.CreatedAt,
|
|
UpdatedAt: c.UpdatedAt,
|
|
}
|
|
if c.LastError.Valid {
|
|
chat.LastError = &c.LastError.String
|
|
}
|
|
if c.ParentChatID.Valid {
|
|
parentChatID := c.ParentChatID.UUID
|
|
chat.ParentChatID = &parentChatID
|
|
}
|
|
switch {
|
|
case c.RootChatID.Valid:
|
|
rootChatID := c.RootChatID.UUID
|
|
chat.RootChatID = &rootChatID
|
|
case c.ParentChatID.Valid:
|
|
rootChatID := c.ParentChatID.UUID
|
|
chat.RootChatID = &rootChatID
|
|
default:
|
|
rootChatID := c.ID
|
|
chat.RootChatID = &rootChatID
|
|
}
|
|
if c.WorkspaceID.Valid {
|
|
chat.WorkspaceID = &c.WorkspaceID.UUID
|
|
}
|
|
if diffStatus != nil {
|
|
convertedDiffStatus := convertChatDiffStatus(c.ID, diffStatus)
|
|
chat.DiffStatus = &convertedDiffStatus
|
|
}
|
|
return chat
|
|
}
|
|
|
|
func convertChats(chats []database.Chat, diffStatusesByChatID map[uuid.UUID]database.ChatDiffStatus) []codersdk.Chat {
|
|
result := make([]codersdk.Chat, len(chats))
|
|
for i, c := range chats {
|
|
diffStatus, ok := diffStatusesByChatID[c.ID]
|
|
if ok {
|
|
result[i] = convertChat(c, &diffStatus)
|
|
continue
|
|
}
|
|
|
|
result[i] = convertChat(c, nil)
|
|
if diffStatusesByChatID != nil {
|
|
emptyDiffStatus := convertChatDiffStatus(c.ID, nil)
|
|
result[i].DiffStatus = &emptyDiffStatus
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func convertChatQueuedMessage(m database.ChatQueuedMessage) codersdk.ChatQueuedMessage {
|
|
return db2sdk.ChatQueuedMessage(m)
|
|
}
|
|
|
|
func convertChatQueuedMessagePtr(m database.ChatQueuedMessage) *codersdk.ChatQueuedMessage {
|
|
qm := convertChatQueuedMessage(m)
|
|
return &qm
|
|
}
|
|
|
|
func convertChatQueuedMessages(msgs []database.ChatQueuedMessage) []codersdk.ChatQueuedMessage {
|
|
result := make([]codersdk.ChatQueuedMessage, 0, len(msgs))
|
|
for _, m := range msgs {
|
|
result = append(result, convertChatQueuedMessage(m))
|
|
}
|
|
return result
|
|
}
|
|
|
|
func convertChatMessage(m database.ChatMessage) codersdk.ChatMessage {
|
|
return db2sdk.ChatMessage(m)
|
|
}
|
|
|
|
func convertChatMessages(messages []database.ChatMessage) []codersdk.ChatMessage {
|
|
result := make([]codersdk.ChatMessage, 0, len(messages))
|
|
for _, m := range messages {
|
|
result = append(result, convertChatMessage(m))
|
|
}
|
|
return result
|
|
}
|
|
|
|
func convertChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.ChatDiffStatus {
|
|
result := codersdk.ChatDiffStatus{
|
|
ChatID: chatID,
|
|
}
|
|
if status == nil {
|
|
return result
|
|
}
|
|
|
|
result.ChatID = status.ChatID
|
|
if status.Url.Valid {
|
|
u := strings.TrimSpace(status.Url.String)
|
|
if u != "" {
|
|
result.URL = &u
|
|
}
|
|
}
|
|
if result.URL == nil {
|
|
owner, repo, _, ok := parseGitHubRepositoryOrigin(status.GitRemoteOrigin)
|
|
if ok {
|
|
branchURL := buildGitHubBranchURL(owner, repo, status.GitBranch)
|
|
if branchURL != "" {
|
|
result.URL = &branchURL
|
|
}
|
|
}
|
|
}
|
|
if status.PullRequestState.Valid {
|
|
pullRequestState := strings.TrimSpace(status.PullRequestState.String)
|
|
if pullRequestState != "" {
|
|
result.PullRequestState = &pullRequestState
|
|
}
|
|
}
|
|
result.ChangesRequested = status.ChangesRequested
|
|
result.Additions = status.Additions
|
|
result.Deletions = status.Deletions
|
|
result.ChangedFiles = status.ChangedFiles
|
|
if status.RefreshedAt.Valid {
|
|
refreshedAt := status.RefreshedAt.Time
|
|
result.RefreshedAt = &refreshedAt
|
|
}
|
|
staleAt := status.StaleAt
|
|
result.StaleAt = &staleAt
|
|
|
|
return result
|
|
}
|
|
|
|
func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
//nolint:gocritic // System context required to read enabled chat providers.
|
|
systemCtx := dbauthz.AsSystemRestricted(ctx)
|
|
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
providers, err := api.Database.GetChatProviders(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chat providers.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
providersByName := make(map[string]database.ChatProvider, len(providers))
|
|
configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers))
|
|
for _, provider := range providers {
|
|
normalizedProvider := normalizeChatProvider(provider.Provider)
|
|
if normalizedProvider == "" {
|
|
continue
|
|
}
|
|
provider.Provider = normalizedProvider
|
|
providersByName[normalizedProvider] = provider
|
|
configuredProviders = append(configuredProviders, chatprovider.ConfiguredProvider{
|
|
Provider: normalizedProvider,
|
|
APIKey: provider.APIKey,
|
|
BaseURL: provider.BaseUrl,
|
|
})
|
|
}
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
enabledProviders, err := api.Database.GetEnabledChatProviders(
|
|
systemCtx,
|
|
)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to resolve provider API keys.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
enabledConfiguredProviders := make(
|
|
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
|
)
|
|
for _, provider := range enabledProviders {
|
|
enabledConfiguredProviders = append(
|
|
enabledConfiguredProviders, chatprovider.ConfiguredProvider{
|
|
Provider: provider.Provider,
|
|
APIKey: provider.APIKey,
|
|
BaseURL: provider.BaseUrl,
|
|
},
|
|
)
|
|
}
|
|
|
|
effectiveKeys := chatprovider.MergeProviderAPIKeys(
|
|
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
|
enabledConfiguredProviders,
|
|
)
|
|
effectiveKeys = chatprovider.MergeProviderAPIKeys(
|
|
effectiveKeys, configuredProviders,
|
|
)
|
|
|
|
supportedProviders := chatprovider.SupportedProviders()
|
|
resp := make([]codersdk.ChatProviderConfig, 0, len(supportedProviders))
|
|
for _, provider := range supportedProviders {
|
|
configured, ok := providersByName[provider]
|
|
if ok {
|
|
resp = append(
|
|
resp,
|
|
convertChatProviderConfig(
|
|
configured,
|
|
effectiveKeys.APIKey(provider) != "",
|
|
codersdk.ChatProviderConfigSourceDatabase,
|
|
),
|
|
)
|
|
continue
|
|
}
|
|
|
|
source := codersdk.ChatProviderConfigSourceSupported
|
|
hasAPIKey := effectiveKeys.APIKey(provider) != ""
|
|
enabled := false
|
|
if chatprovider.IsEnvPresetProvider(provider) && hasAPIKey {
|
|
source = codersdk.ChatProviderConfigSourceEnvPreset
|
|
enabled = true
|
|
}
|
|
|
|
resp = append(resp, codersdk.ChatProviderConfig{
|
|
ID: uuid.Nil,
|
|
Provider: provider,
|
|
DisplayName: chatprovider.ProviderDisplayName(provider),
|
|
Enabled: enabled,
|
|
HasAPIKey: hasAPIKey,
|
|
BaseURL: effectiveKeys.BaseURL(provider),
|
|
Source: source,
|
|
})
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
|
}
|
|
|
|
func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.CreateChatProviderConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
provider := normalizeChatProvider(req.Provider)
|
|
if provider == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider.",
|
|
Detail: chatProviderValidationDetail(),
|
|
})
|
|
return
|
|
}
|
|
|
|
enabled := true
|
|
if req.Enabled != nil {
|
|
enabled = *req.Enabled
|
|
}
|
|
baseURL, err := normalizeChatProviderBaseURL(req.BaseURL)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider base URL.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
inserted, err := api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: provider,
|
|
DisplayName: strings.TrimSpace(req.DisplayName),
|
|
APIKey: strings.TrimSpace(req.APIKey),
|
|
BaseUrl: baseURL,
|
|
ApiKeyKeyID: sql.NullString{},
|
|
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
Enabled: enabled,
|
|
})
|
|
if err != nil {
|
|
switch {
|
|
case database.IsUniqueViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat provider already exists.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
case database.IsCheckViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
httpapi.Write(
|
|
ctx,
|
|
rw,
|
|
http.StatusCreated,
|
|
convertChatProviderConfig(
|
|
inserted,
|
|
api.hasEffectiveProviderAPIKey(ctx, inserted),
|
|
codersdk.ChatProviderConfigSourceDatabase,
|
|
),
|
|
)
|
|
}
|
|
|
|
func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
providerID, ok := parseChatProviderID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
existing, err := api.Database.GetChatProviderByID(ctx, providerID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatProviderConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
displayName := existing.DisplayName
|
|
if trimmed := strings.TrimSpace(req.DisplayName); trimmed != "" {
|
|
displayName = trimmed
|
|
}
|
|
|
|
enabled := existing.Enabled
|
|
if req.Enabled != nil {
|
|
enabled = *req.Enabled
|
|
}
|
|
|
|
apiKey := existing.APIKey
|
|
apiKeyKeyID := existing.ApiKeyKeyID
|
|
if req.APIKey != nil {
|
|
apiKey = strings.TrimSpace(*req.APIKey)
|
|
apiKeyKeyID = sql.NullString{}
|
|
}
|
|
baseURL := existing.BaseUrl
|
|
if req.BaseURL != nil {
|
|
baseURL, err = normalizeChatProviderBaseURL(*req.BaseURL)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider base URL.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
updated, err := api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
|
DisplayName: displayName,
|
|
APIKey: apiKey,
|
|
BaseUrl: baseURL,
|
|
ApiKeyKeyID: apiKeyKeyID,
|
|
Enabled: enabled,
|
|
ID: existing.ID,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(
|
|
ctx,
|
|
rw,
|
|
http.StatusOK,
|
|
convertChatProviderConfig(
|
|
updated,
|
|
api.hasEffectiveProviderAPIKey(ctx, updated),
|
|
codersdk.ChatProviderConfigSourceDatabase,
|
|
),
|
|
)
|
|
}
|
|
|
|
func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
providerID, ok := parseChatProviderID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if _, err := api.Database.GetChatProviderByID(ctx, providerID); err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if err := api.Database.DeleteChatProviderByID(ctx, providerID); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to delete chat provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
// Admin users can see all model configs (including disabled ones)
|
|
// for management purposes. Non-admin users see only enabled
|
|
// configs, which is sufficient for using the chat feature.
|
|
isAdmin := api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig)
|
|
|
|
var configs []database.ChatModelConfig
|
|
var err error
|
|
if isAdmin {
|
|
configs, err = api.Database.GetChatModelConfigs(ctx)
|
|
} else {
|
|
//nolint:gocritic // All authenticated users need to read enabled model configs to use the chat feature.
|
|
configs, err = api.Database.GetEnabledChatModelConfigs(dbauthz.AsSystemRestricted(ctx))
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chat model configs.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
resp := make([]codersdk.ChatModelConfig, 0, len(configs))
|
|
for _, config := range configs {
|
|
resp = append(resp, convertChatModelConfig(config))
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
|
}
|
|
|
|
func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.CreateChatModelConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
provider := normalizeChatProvider(req.Provider)
|
|
if provider == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider.",
|
|
Detail: chatProviderValidationDetail(),
|
|
})
|
|
return
|
|
}
|
|
|
|
model := strings.TrimSpace(req.Model)
|
|
if model == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Model is required.",
|
|
})
|
|
return
|
|
}
|
|
|
|
enabled := true
|
|
if req.Enabled != nil {
|
|
enabled = *req.Enabled
|
|
}
|
|
isDefault := false
|
|
if req.IsDefault != nil {
|
|
isDefault = *req.IsDefault
|
|
}
|
|
|
|
if req.ContextLimit == nil || *req.ContextLimit <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Context limit is required.",
|
|
Detail: "context_limit must be greater than zero.",
|
|
})
|
|
return
|
|
}
|
|
contextLimit := *req.ContextLimit
|
|
|
|
compressionThreshold, thresholdErr := normalizeChatCompressionThreshold(
|
|
req.CompressionThreshold,
|
|
defaultChatContextCompressionThreshold,
|
|
)
|
|
if thresholdErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid compression threshold.",
|
|
Detail: thresholdErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
modelConfigRaw, modelConfigErr := marshalChatModelCallConfig(req.ModelConfig)
|
|
if modelConfigErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config.",
|
|
Detail: modelConfigErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
insertParams := database.InsertChatModelConfigParams{
|
|
Provider: provider,
|
|
Model: model,
|
|
DisplayName: strings.TrimSpace(req.DisplayName),
|
|
Enabled: enabled,
|
|
IsDefault: isDefault,
|
|
ContextLimit: contextLimit,
|
|
CompressionThreshold: compressionThreshold,
|
|
Options: modelConfigRaw,
|
|
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
UpdatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
}
|
|
|
|
var inserted database.ChatModelConfig
|
|
err := api.Database.InTx(func(tx database.Store) error {
|
|
insertAsDefault := isDefault
|
|
if !insertAsDefault {
|
|
_, err := tx.GetDefaultChatModelConfig(ctx)
|
|
switch {
|
|
case err == nil:
|
|
// A default already exists.
|
|
case xerrors.Is(err, sql.ErrNoRows):
|
|
insertAsDefault = true
|
|
default:
|
|
return xerrors.Errorf("get default model config: %w", err)
|
|
}
|
|
}
|
|
|
|
if insertAsDefault {
|
|
if err := tx.UnsetDefaultChatModelConfigs(ctx); err != nil {
|
|
return xerrors.Errorf("unset default model configs: %w", err)
|
|
}
|
|
}
|
|
insertParams.IsDefault = insertAsDefault
|
|
|
|
config, err := tx.InsertChatModelConfig(ctx, insertParams)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
inserted = config
|
|
|
|
if err := ensureDefaultChatModelConfig(ctx, tx); err != nil {
|
|
return err
|
|
}
|
|
|
|
refreshedConfig, err := tx.GetChatModelConfigByID(ctx, inserted.ID)
|
|
if err != nil {
|
|
return xerrors.Errorf("refresh inserted chat model config: %w", err)
|
|
}
|
|
inserted = refreshedConfig
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
switch {
|
|
case database.IsUniqueViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat model config already exists.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
case database.IsForeignKeyViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat provider is not configured.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusCreated, convertChatModelConfig(inserted))
|
|
}
|
|
|
|
func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
existing, err := api.Database.GetChatModelConfigByID(ctx, modelConfigID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatModelConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
provider := existing.Provider
|
|
if strings.TrimSpace(req.Provider) != "" {
|
|
provider = normalizeChatProvider(req.Provider)
|
|
if provider == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider.",
|
|
Detail: chatProviderValidationDetail(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
model := existing.Model
|
|
if trimmed := strings.TrimSpace(req.Model); trimmed != "" {
|
|
model = trimmed
|
|
}
|
|
|
|
displayName := existing.DisplayName
|
|
if trimmed := strings.TrimSpace(req.DisplayName); trimmed != "" {
|
|
displayName = trimmed
|
|
}
|
|
|
|
enabled := existing.Enabled
|
|
if req.Enabled != nil {
|
|
enabled = *req.Enabled
|
|
}
|
|
isDefault := existing.IsDefault
|
|
if req.IsDefault != nil {
|
|
isDefault = *req.IsDefault
|
|
}
|
|
|
|
contextLimit := existing.ContextLimit
|
|
if req.ContextLimit != nil {
|
|
if *req.ContextLimit <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Context limit must be greater than zero.",
|
|
})
|
|
return
|
|
}
|
|
contextLimit = *req.ContextLimit
|
|
}
|
|
|
|
compressionThreshold, thresholdErr := normalizeChatCompressionThreshold(
|
|
req.CompressionThreshold,
|
|
existing.CompressionThreshold,
|
|
)
|
|
if thresholdErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid compression threshold.",
|
|
Detail: thresholdErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
modelConfigRaw := existing.Options
|
|
if req.ModelConfig != nil {
|
|
encodedModelConfig, modelConfigErr := marshalChatModelCallConfig(req.ModelConfig)
|
|
if modelConfigErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config.",
|
|
Detail: modelConfigErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
modelConfigRaw = encodedModelConfig
|
|
}
|
|
|
|
updateParams := database.UpdateChatModelConfigParams{
|
|
Provider: provider,
|
|
Model: model,
|
|
DisplayName: displayName,
|
|
Enabled: enabled,
|
|
IsDefault: isDefault,
|
|
ContextLimit: contextLimit,
|
|
CompressionThreshold: compressionThreshold,
|
|
Options: modelConfigRaw,
|
|
UpdatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
ID: existing.ID,
|
|
}
|
|
|
|
var updated database.ChatModelConfig
|
|
err = api.Database.InTx(func(tx database.Store) error {
|
|
setAsDefault := updateParams.IsDefault && !existing.IsDefault
|
|
if setAsDefault {
|
|
if err := tx.UnsetDefaultChatModelConfigs(ctx); err != nil {
|
|
return xerrors.Errorf("unset default model configs: %w", err)
|
|
}
|
|
}
|
|
|
|
_, err := tx.UpdateChatModelConfig(ctx, updateParams)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
excludeConfigID := uuid.Nil
|
|
if existing.IsDefault && req.IsDefault != nil && !*req.IsDefault {
|
|
excludeConfigID = existing.ID
|
|
}
|
|
|
|
if err := ensureDefaultChatModelConfig(
|
|
ctx,
|
|
tx,
|
|
excludeConfigID,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
refreshedConfig, err := tx.GetChatModelConfigByID(ctx, existing.ID)
|
|
if err != nil {
|
|
return xerrors.Errorf("refresh updated chat model config: %w", err)
|
|
}
|
|
updated = refreshedConfig
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
switch {
|
|
case database.IsUniqueViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat model config already exists.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
case database.IsForeignKeyViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat provider is not configured.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChatModelConfig(updated))
|
|
}
|
|
|
|
func (api *API) deleteChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if _, err := api.Database.GetChatModelConfigByID(ctx, modelConfigID); err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if err := api.Database.InTx(func(tx database.Store) error {
|
|
if err := tx.DeleteChatModelConfigByID(ctx, modelConfigID); err != nil {
|
|
return err
|
|
}
|
|
return ensureDefaultChatModelConfig(ctx, tx)
|
|
}, nil); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to delete chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func ensureDefaultChatModelConfig(
|
|
ctx context.Context,
|
|
tx database.Store,
|
|
excludedConfigIDs ...uuid.UUID,
|
|
) error {
|
|
_, err := tx.GetDefaultChatModelConfig(ctx)
|
|
switch {
|
|
case err == nil:
|
|
return nil
|
|
case !xerrors.Is(err, sql.ErrNoRows):
|
|
return xerrors.Errorf("get default model config: %w", err)
|
|
}
|
|
|
|
modelConfigs, err := tx.GetChatModelConfigs(ctx)
|
|
if err != nil {
|
|
return xerrors.Errorf("list chat model configs: %w", err)
|
|
}
|
|
if len(modelConfigs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
candidateConfig := modelConfigs[0]
|
|
excluded := make(map[uuid.UUID]struct{}, len(excludedConfigIDs))
|
|
for _, configID := range excludedConfigIDs {
|
|
if configID == uuid.Nil {
|
|
continue
|
|
}
|
|
excluded[configID] = struct{}{}
|
|
}
|
|
for _, config := range modelConfigs {
|
|
if _, skip := excluded[config.ID]; skip {
|
|
continue
|
|
}
|
|
candidateConfig = config
|
|
break
|
|
}
|
|
|
|
if err := tx.UnsetDefaultChatModelConfigs(ctx); err != nil {
|
|
return xerrors.Errorf("unset default model configs: %w", err)
|
|
}
|
|
|
|
params := chatModelConfigToUpdateParams(candidateConfig)
|
|
params.IsDefault = true
|
|
if _, err := tx.UpdateChatModelConfig(ctx, params); err != nil {
|
|
return xerrors.Errorf("set default model config: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func chatModelConfigToUpdateParams(
|
|
config database.ChatModelConfig,
|
|
) database.UpdateChatModelConfigParams {
|
|
return database.UpdateChatModelConfigParams{
|
|
Provider: config.Provider,
|
|
Model: config.Model,
|
|
DisplayName: config.DisplayName,
|
|
Enabled: config.Enabled,
|
|
IsDefault: config.IsDefault,
|
|
ContextLimit: config.ContextLimit,
|
|
CompressionThreshold: config.CompressionThreshold,
|
|
Options: config.Options,
|
|
UpdatedBy: uuid.NullUUID{},
|
|
ID: config.ID,
|
|
}
|
|
}
|
|
|
|
func parseChatProviderID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
|
providerID, err := uuid.Parse(chi.URLParam(r, "providerConfig"))
|
|
if err != nil {
|
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat provider ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return uuid.Nil, false
|
|
}
|
|
return providerID, true
|
|
}
|
|
|
|
func parseChatModelConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
|
modelConfigID, err := uuid.Parse(chi.URLParam(r, "modelConfig"))
|
|
if err != nil {
|
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat model config ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return uuid.Nil, false
|
|
}
|
|
return modelConfigID, true
|
|
}
|
|
|
|
func convertChatProviderConfig(
|
|
provider database.ChatProvider,
|
|
hasAPIKey bool,
|
|
source codersdk.ChatProviderConfigSource,
|
|
) codersdk.ChatProviderConfig {
|
|
displayName := strings.TrimSpace(provider.DisplayName)
|
|
if displayName == "" {
|
|
displayName = chatprovider.ProviderDisplayName(provider.Provider)
|
|
}
|
|
|
|
return codersdk.ChatProviderConfig{
|
|
ID: provider.ID,
|
|
Provider: provider.Provider,
|
|
DisplayName: displayName,
|
|
Enabled: provider.Enabled,
|
|
HasAPIKey: hasAPIKey,
|
|
BaseURL: strings.TrimSpace(provider.BaseUrl),
|
|
Source: source,
|
|
CreatedAt: provider.CreatedAt,
|
|
UpdatedAt: provider.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
func convertChatModelConfig(config database.ChatModelConfig) codersdk.ChatModelConfig {
|
|
return codersdk.ChatModelConfig{
|
|
ID: config.ID,
|
|
Provider: config.Provider,
|
|
Model: config.Model,
|
|
DisplayName: config.DisplayName,
|
|
Enabled: config.Enabled,
|
|
IsDefault: config.IsDefault,
|
|
ContextLimit: config.ContextLimit,
|
|
CompressionThreshold: config.CompressionThreshold,
|
|
ModelConfig: unmarshalChatModelCallConfig(config.Options),
|
|
CreatedAt: config.CreatedAt,
|
|
UpdatedAt: config.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
func marshalChatModelCallConfig(
|
|
modelConfig *codersdk.ChatModelCallConfig,
|
|
) (json.RawMessage, error) {
|
|
if modelConfig == nil {
|
|
return json.RawMessage("{}"), nil
|
|
}
|
|
|
|
encoded, err := json.Marshal(modelConfig)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("encode model config: %w", err)
|
|
}
|
|
return encoded, nil
|
|
}
|
|
|
|
func unmarshalChatModelCallConfig(
|
|
raw json.RawMessage,
|
|
) *codersdk.ChatModelCallConfig {
|
|
if len(raw) == 0 {
|
|
return nil
|
|
}
|
|
|
|
decoded := &codersdk.ChatModelCallConfig{}
|
|
if err := json.Unmarshal(raw, decoded); err != nil {
|
|
return nil
|
|
}
|
|
if isZeroChatModelCallConfig(decoded) {
|
|
return nil
|
|
}
|
|
return decoded
|
|
}
|
|
|
|
func isZeroChatModelCallConfig(config *codersdk.ChatModelCallConfig) bool {
|
|
if config == nil {
|
|
return true
|
|
}
|
|
|
|
return config.MaxOutputTokens == nil &&
|
|
config.Temperature == nil &&
|
|
config.TopP == nil &&
|
|
config.TopK == nil &&
|
|
config.PresencePenalty == nil &&
|
|
config.FrequencyPenalty == nil &&
|
|
isZeroChatModelProviderOptions(config.ProviderOptions)
|
|
}
|
|
|
|
func isZeroChatModelProviderOptions(options *codersdk.ChatModelProviderOptions) bool {
|
|
if options == nil {
|
|
return true
|
|
}
|
|
|
|
return options.OpenAI == nil &&
|
|
options.Anthropic == nil &&
|
|
options.Google == nil &&
|
|
options.OpenAICompat == nil &&
|
|
options.OpenRouter == nil &&
|
|
options.Vercel == nil
|
|
}
|
|
|
|
func normalizeChatProvider(provider string) string {
|
|
return chatprovider.NormalizeProvider(provider)
|
|
}
|
|
|
|
func normalizeChatProviderBaseURL(raw string) (string, error) {
|
|
trimmed := strings.TrimSpace(raw)
|
|
if trimmed == "" {
|
|
return "", nil
|
|
}
|
|
|
|
parsed, err := url.Parse(trimmed)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if parsed.Scheme == "" || parsed.Host == "" {
|
|
return "", xerrors.New("Base URL must be an absolute URL with scheme and host.")
|
|
}
|
|
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
|
return "", xerrors.New("Base URL scheme must be http or https.")
|
|
}
|
|
return parsed.String(), nil
|
|
}
|
|
|
|
func chatProviderValidationDetail() string {
|
|
return "Provider must be one of: " + strings.Join(chatprovider.SupportedProviders(), ", ") + "."
|
|
}
|
|
|
|
func chatProviderAPIKeysFromDeploymentValues(
|
|
deploymentValues *codersdk.DeploymentValues,
|
|
) chatprovider.ProviderAPIKeys {
|
|
_ = deploymentValues
|
|
// For now, we'll just manage configs in the UI.
|
|
// We should probably not be reusing the AI bridge configs anyways.
|
|
return chatprovider.ProviderAPIKeys{
|
|
// OpenAI: deploymentValues.AI.BridgeConfig.OpenAI.Key.Value(),
|
|
// Anthropic: deploymentValues.AI.BridgeConfig.Anthropic.Key.Value(),
|
|
// BaseURLByProvider: map[string]string{
|
|
// "openai": deploymentValues.AI.BridgeConfig.OpenAI.BaseURL.Value(),
|
|
// "anthropic": deploymentValues.AI.BridgeConfig.Anthropic.BaseURL.Value(),
|
|
// },
|
|
}
|
|
}
|
|
|
|
func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider database.ChatProvider) bool {
|
|
if strings.TrimSpace(provider.APIKey) != "" {
|
|
return true
|
|
}
|
|
if api.chatDaemon == nil {
|
|
return false
|
|
}
|
|
//nolint:gocritic // System context required to read enabled chat providers.
|
|
systemCtx := dbauthz.AsSystemRestricted(ctx)
|
|
|
|
enabledProviders, err := api.Database.GetEnabledChatProviders(
|
|
systemCtx,
|
|
)
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to resolve provider API keys",
|
|
slog.F("provider", provider.Provider),
|
|
slog.Error(err),
|
|
)
|
|
return false
|
|
}
|
|
|
|
enabledConfiguredProviders := make(
|
|
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
|
)
|
|
for _, configured := range enabledProviders {
|
|
enabledConfiguredProviders = append(
|
|
enabledConfiguredProviders, chatprovider.ConfiguredProvider{
|
|
Provider: configured.Provider,
|
|
APIKey: configured.APIKey,
|
|
BaseURL: configured.BaseUrl,
|
|
},
|
|
)
|
|
}
|
|
|
|
effectiveKeys := chatprovider.MergeProviderAPIKeys(
|
|
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
|
enabledConfiguredProviders,
|
|
)
|
|
return effectiveKeys.APIKey(provider.Provider) != ""
|
|
}
|