mirror of
https://github.com/coder/coder.git
synced 2026-06-07 06:58:17 +00:00
4b5ec8a9a4
## Summary Adds a new `diff_status_change` event kind to the `/chats/watch` pubsub stream so the sidebar can update diff status (PR created, files changed, branch info) without a full page reload. ### Problem When a chat's diff status changes (e.g. PR created via GitHub, git branch pushed), the sidebar didn't update because: 1. The backend `publishChatPubsubEvent` didn't include diff status data 2. The frontend watch handler only merged `status`, `title`, and `updated_at` from events ### Solution A **notify-only** approach: a new `ChatEventKindDiffStatusChange` event kind tells the frontend "diff status changed for chat X" — the frontend then invalidates the relevant React Query cache entries to re-fetch. ### Backend changes - **`coderd/pubsub/chatevent.go`**: New `ChatEventKindDiffStatusChange = "diff_status_change"` constant - **`coderd/chatd/chatd.go`**: New `PublishDiffStatusChange(ctx, chatID)` method on `Server` - **`coderd/chats.go`**: New `publishChatDiffStatusEvent` helper. Published from: - `refreshWorkspaceChatDiffStatuses` — after each chat's diff status is refreshed via GitHub API - `storeChatGitRef` — after persisting git branch/origin info from workspace agent ### Frontend changes - **`AgentsPage.tsx`**: Handle `diff_status_change` event by invalidating `chatDiffStatusKey` and `chatDiffContentsKey` queries - **`ChatContext.ts`**: Remove redundant diff status invalidation that fired on every chat status change (the new event kind handles this properly)
3152 lines
84 KiB
Go
3152 lines
84 KiB
Go
package coderd
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/coderd/audit"
|
|
"github.com/coder/coder/v2/coderd/chatd"
|
|
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/externalauth"
|
|
"github.com/coder/coder/v2/coderd/httpapi"
|
|
"github.com/coder/coder/v2/coderd/httpapi/httperror"
|
|
"github.com/coder/coder/v2/coderd/httpmw"
|
|
"github.com/coder/coder/v2/coderd/pubsub"
|
|
"github.com/coder/coder/v2/coderd/rbac"
|
|
"github.com/coder/coder/v2/coderd/rbac/policy"
|
|
"github.com/coder/coder/v2/coderd/tracing"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
)
|
|
|
|
// Timings and limits for chat diff-status refreshes and chat streaming.
const (
	// chatDiffStatusTTL is how far into the future a diff status's stale
	// timestamp is pushed after a failed refresh in this file.
	// NOTE(review): presumably also the regular cache lifetime — confirm
	// against the code that stores refreshed statuses.
	chatDiffStatusTTL = 2 * time.Minute
	// chatDiffBackgroundRefreshTimeout is not referenced in the visible part
	// of this file; presumably it bounds one background refresh — confirm.
	chatDiffBackgroundRefreshTimeout = 20 * time.Second
	// githubAPIBaseURL is the root URL of the GitHub REST API.
	githubAPIBaseURL = "https://api.github.com"
	// chatStreamBatchSize caps how many chat stream events are sent in a
	// single stream message.
	chatStreamBatchSize = 256
)

// Model-config keys and bounds for the chat context settings.
// NOTE(review): the percent-like range (0-100, default 70) is inferred from
// the values only — confirm against the code that reads these keys.
const (
	chatContextLimitModelConfigKey                = "context_limit"
	chatContextCompressionThresholdModelConfigKey = "context_compression_threshold"

	defaultChatContextCompressionThreshold int32 = 70
	minChatContextCompressionThreshold     int32 = 0
	maxChatContextCompressionThreshold     int32 = 100
)
|
|
|
|
// chatDiffRefreshBackoffSchedule lists, in order, how long the background
// diff refresh waits before each attempt. The refresh is triggered when the
// agent obtains a GitHub token, which usually happens just before a git push
// or PR creation, so the growing delays leave time for the push and any PR
// workflow to finish before the GitHub API is queried.
var chatDiffRefreshBackoffSchedule = []time.Duration{
	time.Second,
	3 * time.Second,
	5 * time.Second,
	10 * time.Second,
	20 * time.Second,
}
|
|
|
|
// chatGitRef carries the git branch and remote origin that the workspace
// agent reported while performing a git operation.
type chatGitRef struct {
	Branch, RemoteOrigin string
}
|
|
|
|
// githubPullRequestPathPattern matches a github.com pull request URL and
// captures owner, repository, and pull request number. Anything after the
// number is accepted only when it starts with '/', '?', or '#'.
var githubPullRequestPathPattern = regexp.MustCompile(
	`^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)/pull/([0-9]+)(?:[/?#].*)?$`,
)

// githubRepositoryHTTPSPattern matches an HTTPS github.com repository URL,
// with an optional ".git" suffix and trailing slash, and captures owner and
// repository name.
var githubRepositoryHTTPSPattern = regexp.MustCompile(
	`^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
)

// githubRepositorySSHPathPattern matches an SSH github.com remote, either
// "git@github.com:owner/repo" or "ssh://git@github.com/owner/repo", with an
// optional ".git" suffix and trailing slash, and captures owner and
// repository name.
var githubRepositorySSHPathPattern = regexp.MustCompile(
	`^(?:ssh://)?git@github\.com[:/]([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
)
|
|
|
|
// githubPullRequestRef names a single GitHub pull request by repository
// owner, repository name, and pull request number.
type githubPullRequestRef struct {
	Owner, Repo string
	Number      int
}
|
|
|
|
// githubPullRequestStatus is a summary of a pull request: its state, whether
// changes were requested, and its line/file change counts.
// NOTE(review): the set of valid PullRequestState values is defined by the
// code that fills this struct — not visible here.
type githubPullRequestStatus struct {
	PullRequestState string
	ChangesRequested bool

	Additions, Deletions, ChangedFiles int32
}
|
|
|
|
// chatRepositoryRef describes the repository and branch a chat refers to:
// the hosting provider, the raw remote origin, the branch, and the
// owner/repository pair.
// NOTE(review): how Owner and Repo are derived from RemoteOrigin is not
// visible in this part of the file.
type chatRepositoryRef struct {
	Provider             string
	RemoteOrigin, Branch string
	Owner, Repo          string
}
|
|
|
|
type chatDiffReference struct {
|
|
PullRequestURL string
|
|
RepositoryRef *chatRepositoryRef
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to open chat watch stream.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
defer func() {
|
|
<-senderClosed
|
|
}()
|
|
|
|
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
|
|
pubsub.HandleChatEvent(
|
|
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
|
|
if err != nil {
|
|
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
|
|
return
|
|
}
|
|
_ = sendEvent(codersdk.ServerSentEvent{
|
|
Type: codersdk.ServerSentEventTypeData,
|
|
Data: payload,
|
|
})
|
|
},
|
|
))
|
|
if err != nil {
|
|
_ = sendEvent(codersdk.ServerSentEvent{
|
|
Type: codersdk.ServerSentEventTypeError,
|
|
Data: codersdk.Response{
|
|
Message: "Internal error subscribing to chat events.",
|
|
Detail: err.Error(),
|
|
},
|
|
})
|
|
return
|
|
}
|
|
defer cancelSubscribe()
|
|
|
|
// Send initial ping to signal the connection is ready.
|
|
_ = sendEvent(codersdk.ServerSentEvent{
|
|
Type: codersdk.ServerSentEventTypePing,
|
|
})
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-senderClosed:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
chats, err := api.Database.GetChatsByOwnerID(ctx, apiKey.UserID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chats.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
diffStatusesByChatID, err := api.getChatDiffStatusesByChatID(ctx, chats)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chats.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChats(chats, diffStatusesByChatID))
|
|
}
|
|
|
|
func (api *API) getChatDiffStatusesByChatID(
|
|
ctx context.Context,
|
|
chats []database.Chat,
|
|
) (map[uuid.UUID]database.ChatDiffStatus, error) {
|
|
if len(chats) == 0 {
|
|
return map[uuid.UUID]database.ChatDiffStatus{}, nil
|
|
}
|
|
|
|
chatIDs := make([]uuid.UUID, 0, len(chats))
|
|
for _, chat := range chats {
|
|
chatIDs = append(chatIDs, chat.ID)
|
|
}
|
|
|
|
statuses, err := api.Database.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("get chat diff statuses: %w", err)
|
|
}
|
|
|
|
statusesByChatID := make(map[uuid.UUID]database.ChatDiffStatus, len(statuses))
|
|
for _, status := range statuses {
|
|
statusesByChatID[status.ChatID] = status
|
|
}
|
|
return statusesByChatID, nil
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
var req codersdk.CreateChatRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
contentBlocks, titleSource, inputError := createChatInputFromRequest(req)
|
|
if inputError != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError)
|
|
return
|
|
}
|
|
|
|
workspaceSelection, validationStatus, validationError := api.validateCreateChatWorkspaceSelection(ctx, req)
|
|
if validationError != nil {
|
|
httpapi.Write(ctx, rw, validationStatus, *validationError)
|
|
return
|
|
}
|
|
|
|
title := chatTitleFromMessage(titleSource)
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
modelConfigID, modelConfigStatus, modelConfigError := api.resolveCreateChatModelConfigID(ctx, req)
|
|
if modelConfigError != nil {
|
|
httpapi.Write(ctx, rw, modelConfigStatus, *modelConfigError)
|
|
return
|
|
}
|
|
|
|
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: apiKey.UserID,
|
|
WorkspaceID: workspaceSelection.WorkspaceID,
|
|
WorkspaceAgentID: workspaceSelection.WorkspaceAgentID,
|
|
Title: title,
|
|
ModelConfigID: modelConfigID,
|
|
SystemPrompt: defaultChatSystemPrompt(),
|
|
InitialUserContent: contentBlocks,
|
|
})
|
|
if err != nil {
|
|
if database.IsForeignKeyViolation(
|
|
err,
|
|
database.ForeignKeyChatsLastModelConfigID,
|
|
database.ForeignKeyChatMessagesModelConfigID,
|
|
) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusCreated, convertChat(chat, nil))
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
//nolint:gocritic // System context required to read enabled chat models.
|
|
systemCtx := dbauthz.AsSystemRestricted(ctx)
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
enabledProviders, err := api.Database.GetEnabledChatProviders(
|
|
systemCtx,
|
|
)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to load chat model configuration.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
enabledModels, err := api.Database.GetEnabledChatModelConfigs(
|
|
systemCtx,
|
|
)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to load chat model configuration.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
configuredProviders := make(
|
|
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
|
)
|
|
for _, provider := range enabledProviders {
|
|
configuredProviders = append(
|
|
configuredProviders, chatprovider.ConfiguredProvider{
|
|
Provider: provider.Provider,
|
|
APIKey: provider.APIKey,
|
|
BaseURL: provider.BaseUrl,
|
|
},
|
|
)
|
|
}
|
|
configuredModels := make(
|
|
[]chatprovider.ConfiguredModel, 0, len(enabledModels),
|
|
)
|
|
for _, model := range enabledModels {
|
|
configuredModels = append(configuredModels, chatprovider.ConfiguredModel{
|
|
Provider: model.Provider,
|
|
Model: model.Model,
|
|
DisplayName: model.DisplayName,
|
|
})
|
|
}
|
|
|
|
keys := chatprovider.MergeProviderAPIKeys(
|
|
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
|
configuredProviders,
|
|
)
|
|
catalog := chatprovider.NewModelCatalog(keys)
|
|
var response codersdk.ChatModelsResponse
|
|
if configured, ok := catalog.ListConfiguredModels(
|
|
configuredProviders, configuredModels,
|
|
); ok {
|
|
response = configured
|
|
} else {
|
|
response = catalog.ListConfiguredProviderAvailability(configuredProviders)
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
messages, err := api.Database.GetChatMessagesByChatID(ctx, chatID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat messages.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
queuedMessages, err := api.Database.GetChatQueuedMessages(ctx, chatID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get queued messages.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatWithMessages{
|
|
Chat: convertChat(chat, nil),
|
|
Messages: convertChatMessages(messages),
|
|
QueuedMessages: convertChatQueuedMessages(queuedMessages),
|
|
})
|
|
}
|
|
|
|
// @Summary Archive a chat
|
|
// @ID archive-chat
|
|
// @Tags Chats
|
|
// @Success 204
|
|
// @Router /chats/{chat}/archive [post]
|
|
func (api *API) archiveChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
if chat.Archived {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat is already archived.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var err error
|
|
if api.chatDaemon != nil {
|
|
err = api.chatDaemon.ArchiveChat(ctx, chat.ID)
|
|
} else {
|
|
err = archiveChatTree(ctx, api.Database, chat.ID)
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to archive chat.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// @Summary Unarchive a chat
|
|
// @ID unarchive-chat
|
|
// @Tags Chats
|
|
// @Success 204
|
|
// @Router /chats/{chat}/unarchive [post]
|
|
func (api *API) unarchiveChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
if !chat.Archived {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat is not archived.",
|
|
})
|
|
return
|
|
}
|
|
|
|
err := api.Database.UnarchiveChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to unarchive chat.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func archiveChatTree(ctx context.Context, store database.Store, chatID uuid.UUID) error {
|
|
children, err := store.ListChildChatsByParentID(ctx, chatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("list child chats: %w", err)
|
|
}
|
|
for _, child := range children {
|
|
if err := archiveChatTree(ctx, store, child.ID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return store.ArchiveChatByID(ctx, chatID)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.CreateChatMessageRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content")
|
|
if inputError != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: inputError.Message,
|
|
Detail: inputError.Detail,
|
|
})
|
|
return
|
|
}
|
|
|
|
sendResult, sendErr := api.chatDaemon.SendMessage(
|
|
ctx,
|
|
chatd.SendMessageOptions{
|
|
ChatID: chatID,
|
|
Content: contentBlocks,
|
|
ModelConfigID: req.ModelConfigID,
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
},
|
|
)
|
|
if sendErr != nil {
|
|
if xerrors.Is(sendErr, chatd.ErrMessageQueueFull) {
|
|
httpapi.Write(ctx, rw, http.StatusTooManyRequests, codersdk.Response{
|
|
Message: "Message queue is full.",
|
|
Detail: fmt.Sprintf("Maximum %d messages can be queued.", chatd.MaxQueueSize),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat message.",
|
|
Detail: sendErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
response := codersdk.CreateChatMessageResponse{Queued: sendResult.Queued}
|
|
if sendResult.Queued {
|
|
if sendResult.QueuedMessage != nil {
|
|
response.QueuedMessage = convertChatQueuedMessagePtr(*sendResult.QueuedMessage)
|
|
}
|
|
} else {
|
|
message := convertChatMessage(sendResult.Message)
|
|
response.Message = &message
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
messageIDStr := chi.URLParam(r, "message")
|
|
messageID, err := strconv.ParseInt(messageIDStr, 10, 64)
|
|
if err != nil || messageID <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat message ID.",
|
|
Detail: "Message ID must be a positive integer.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.EditChatMessageRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content")
|
|
if inputError != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: inputError.Message,
|
|
Detail: inputError.Detail,
|
|
})
|
|
return
|
|
}
|
|
|
|
editResult, editErr := api.chatDaemon.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: messageID,
|
|
Content: contentBlocks,
|
|
})
|
|
if editErr != nil {
|
|
switch {
|
|
case xerrors.Is(editErr, chatd.ErrEditedMessageNotFound):
|
|
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
|
Message: "Chat message not found.",
|
|
Detail: "Message does not belong to this chat.",
|
|
})
|
|
case xerrors.Is(editErr, chatd.ErrEditedMessageNotUser):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Only user messages can be edited.",
|
|
})
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to edit chat message.",
|
|
Detail: editErr.Error(),
|
|
})
|
|
}
|
|
return
|
|
}
|
|
|
|
message := convertChatMessage(editResult.Message)
|
|
httpapi.Write(ctx, rw, http.StatusOK, message)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) deleteChatQueuedMessage(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
queuedMessageIDStr := chi.URLParam(r, "queuedMessage")
|
|
queuedMessageID, err := strconv.ParseInt(queuedMessageIDStr, 10, 64)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid queued message ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon != nil {
|
|
err = api.chatDaemon.DeleteQueued(ctx, chatID, queuedMessageID)
|
|
} else {
|
|
err = api.Database.DeleteChatQueuedMessage(ctx, database.DeleteChatQueuedMessageParams{
|
|
ID: queuedMessageID,
|
|
ChatID: chatID,
|
|
})
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to delete queued message.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) promoteChatQueuedMessage(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
queuedMessageIDStr := chi.URLParam(r, "queuedMessage")
|
|
queuedMessageID, err := strconv.ParseInt(queuedMessageIDStr, 10, 64)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid queued message ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
promoteResult, txErr := api.chatDaemon.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chatID,
|
|
QueuedMessageID: queuedMessageID,
|
|
})
|
|
|
|
if txErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to promote queued message.",
|
|
Detail: txErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChatMessage(promoteResult.PromotedMessage))
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat streaming is not available.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to open chat stream.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
defer func() {
|
|
<-senderClosed
|
|
}()
|
|
|
|
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header)
|
|
if !ok {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat streaming is not available.",
|
|
Detail: "Chat stream state is not configured.",
|
|
})
|
|
return
|
|
}
|
|
defer cancel()
|
|
|
|
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
|
if len(batch) == 0 {
|
|
return nil
|
|
}
|
|
return sendEvent(codersdk.ServerSentEvent{
|
|
Type: codersdk.ServerSentEventTypeData,
|
|
Data: batch,
|
|
})
|
|
}
|
|
|
|
drainChatStreamBatch := func(
|
|
first codersdk.ChatStreamEvent,
|
|
maxBatchSize int,
|
|
) ([]codersdk.ChatStreamEvent, bool) {
|
|
batch := []codersdk.ChatStreamEvent{first}
|
|
if maxBatchSize <= 1 {
|
|
return batch, false
|
|
}
|
|
|
|
for len(batch) < maxBatchSize {
|
|
select {
|
|
case event, ok := <-events:
|
|
if !ok {
|
|
return batch, true
|
|
}
|
|
batch = append(batch, event)
|
|
default:
|
|
return batch, false
|
|
}
|
|
}
|
|
|
|
return batch, false
|
|
}
|
|
|
|
for start := 0; start < len(snapshot); start += chatStreamBatchSize {
|
|
end := start + chatStreamBatchSize
|
|
if end > len(snapshot) {
|
|
end = len(snapshot)
|
|
}
|
|
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
|
|
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
|
return
|
|
}
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-senderClosed:
|
|
return
|
|
case firstEvent, ok := <-events:
|
|
if !ok {
|
|
return
|
|
}
|
|
batch, streamClosed := drainChatStreamBatch(
|
|
firstEvent,
|
|
chatStreamBatchSize,
|
|
)
|
|
if err := sendChatStreamBatch(batch); err != nil {
|
|
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
|
return
|
|
}
|
|
if streamClosed {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
if api.chatDaemon != nil {
|
|
chat = api.chatDaemon.InterruptChat(ctx, chat)
|
|
} else {
|
|
updatedChat, updateErr := api.Database.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chatID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
})
|
|
if updateErr != nil {
|
|
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
|
slog.F("chat_id", chatID), slog.Error(updateErr))
|
|
} else {
|
|
chat = updatedChat
|
|
}
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChat(chat, nil))
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChatDiffStatus(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
status, err := api.resolveChatDiffStatus(ctx, chat)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat diff status.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChatDiffStatus(chatID, status))
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChatDiffContents(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
diff, err := api.resolveChatDiffContents(ctx, chat)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat diff.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, diff)
|
|
}
|
|
|
|
// chatCreateWorkspace provides workspace creation for the chat
|
|
// processor. RBAC authorization uses context-based checks via
|
|
// dbauthz.As rather than fake *http.Request objects.
|
|
func (api *API) chatCreateWorkspace(
|
|
ctx context.Context,
|
|
ownerID uuid.UUID,
|
|
req codersdk.CreateWorkspaceRequest,
|
|
) (codersdk.Workspace, error) {
|
|
actor, _, err := httpmw.UserRBACSubject(ctx, api.Database, ownerID, rbac.ScopeAll)
|
|
if err != nil {
|
|
return codersdk.Workspace{}, xerrors.Errorf("load user authorization: %w", err)
|
|
}
|
|
ctx = dbauthz.As(ctx, actor)
|
|
|
|
ownerUser, err := api.Database.GetUserByID(ctx, ownerID)
|
|
if err != nil {
|
|
return codersdk.Workspace{}, xerrors.Errorf("get workspace owner: %w", err)
|
|
}
|
|
owner := workspaceOwner{
|
|
ID: ownerUser.ID,
|
|
Username: ownerUser.Username,
|
|
AvatarURL: ownerUser.AvatarURL,
|
|
}
|
|
|
|
auditor := api.Auditor.Load()
|
|
if auditor == nil {
|
|
return codersdk.Workspace{}, xerrors.New("auditor is not configured")
|
|
}
|
|
|
|
// The audit system requires a ResponseWriter to capture the
|
|
// HTTP status code. Since this is a programmatic call, we use
|
|
// a recorder. The audit entry still captures the owner, action,
|
|
// and resource correctly.
|
|
rw := httptest.NewRecorder()
|
|
sw := &tracing.StatusWriter{ResponseWriter: rw}
|
|
|
|
// Build a minimal synthetic request so the audit commit
|
|
// closure can extract a request ID and user agent. The RBAC
|
|
// subject is already on the context via dbauthz.As above.
|
|
auditReq, err := http.NewRequestWithContext(
|
|
httpmw.WithRequestID(ctx, uuid.New()),
|
|
http.MethodPost,
|
|
"http://localhost/internal/chat/workspace",
|
|
nil,
|
|
)
|
|
if err != nil {
|
|
return codersdk.Workspace{}, xerrors.Errorf("create audit request: %w", err)
|
|
}
|
|
|
|
aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](sw, &audit.RequestParams{
|
|
Audit: *auditor,
|
|
Log: api.Logger,
|
|
Request: auditReq,
|
|
Action: database.AuditActionCreate,
|
|
AdditionalFields: audit.AdditionalFields{
|
|
WorkspaceOwner: owner.Username,
|
|
},
|
|
})
|
|
aReq.UserID = ownerID
|
|
defer commitAudit()
|
|
|
|
workspace, err := createWorkspace(ctx, aReq, ownerID, api, owner, req, nil)
|
|
if err != nil {
|
|
sw.WriteHeader(chatWorkspaceAuditStatus(err))
|
|
return codersdk.Workspace{}, err
|
|
}
|
|
|
|
sw.WriteHeader(http.StatusCreated)
|
|
return workspace, nil
|
|
}
|
|
|
|
func chatWorkspaceAuditStatus(err error) int {
|
|
if responder, ok := httperror.IsResponder(err); ok {
|
|
status, _ := responder.Response()
|
|
return status
|
|
}
|
|
return http.StatusInternalServerError
|
|
}
|
|
|
|
func (api *API) resolveChatDiffStatus(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (*database.ChatDiffStatus, error) {
|
|
return api.resolveChatDiffStatusWithOptions(ctx, chat, false)
|
|
}
|
|
|
|
//nolint:revive // Boolean forces cache refresh bypass.
|
|
func (api *API) resolveChatDiffStatusWithOptions(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
forceRefresh bool,
|
|
) (*database.ChatDiffStatus, error) {
|
|
status, found, err := api.getCachedChatDiffStatus(ctx, chat.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
|
|
reference, err := api.resolveChatDiffReference(ctx, chat, found, status)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if reference.PullRequestURL != "" {
|
|
if !found || !strings.EqualFold(strings.TrimSpace(status.Url.String), reference.PullRequestURL) {
|
|
status, err = api.upsertChatDiffStatusReference(ctx, chat.ID, reference.PullRequestURL, now.Add(-time.Second))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
found = true
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
return nil, nil //nolint:nilnil // Callers handle nil status explicitly.
|
|
}
|
|
if reference.PullRequestURL == "" {
|
|
return &status, nil
|
|
}
|
|
if !shouldRefreshChatDiffStatus(status, now, forceRefresh) {
|
|
return &status, nil
|
|
}
|
|
|
|
refreshed, err := api.refreshChatDiffStatus(
|
|
ctx,
|
|
chat.OwnerID,
|
|
chat.ID,
|
|
reference.PullRequestURL,
|
|
)
|
|
if err == nil {
|
|
return &refreshed, nil
|
|
}
|
|
|
|
api.Logger.Warn(ctx, "failed to refresh chat diff status",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("pull_request_url", reference.PullRequestURL),
|
|
slog.Error(err),
|
|
)
|
|
|
|
backoffStatus, backoffErr := api.upsertChatDiffStatusReference(ctx, chat.ID, reference.PullRequestURL, now.Add(chatDiffStatusTTL))
|
|
if backoffErr != nil {
|
|
api.Logger.Warn(ctx, "failed to extend chat diff status stale timestamp",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(backoffErr),
|
|
)
|
|
return &status, nil
|
|
}
|
|
|
|
return &backoffStatus, nil
|
|
}
|
|
|
|
//nolint:revive // Boolean forces cache refresh bypass.
|
|
func shouldRefreshChatDiffStatus(status database.ChatDiffStatus, now time.Time, forceRefresh bool) bool {
|
|
if forceRefresh {
|
|
return true
|
|
}
|
|
return chatDiffStatusIsStale(status, now)
|
|
}
|
|
|
|
func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspace, gitRef chatGitRef) {
|
|
if workspace.ID == uuid.Nil || workspace.OwnerID == uuid.Nil {
|
|
return
|
|
}
|
|
|
|
go func(workspaceID, workspaceOwnerID uuid.UUID, gitRef chatGitRef) {
|
|
ctx := api.ctx
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
//nolint:gocritic // Background goroutine for diff status refresh has no user context.
|
|
ctx = dbauthz.AsSystemRestricted(ctx)
|
|
|
|
// Always store the git ref so the data is persisted even
|
|
// before a PR exists. The frontend can show branch info
|
|
// and the refresh loop can resolve a PR later.
|
|
api.storeChatGitRef(ctx, workspaceID, workspaceOwnerID, gitRef)
|
|
|
|
for _, delay := range chatDiffRefreshBackoffSchedule {
|
|
t := api.Clock.NewTimer(delay, "chat_diff_refresh")
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Stop()
|
|
return
|
|
case <-t.C:
|
|
}
|
|
|
|
// Refresh and publish status on every iteration.
|
|
// Stop the loop once a PR is discovered — there's
|
|
// nothing more to wait for after that.
|
|
if api.refreshWorkspaceChatDiffStatuses(ctx, workspaceID, workspaceOwnerID) {
|
|
return
|
|
}
|
|
}
|
|
}(workspace.ID, workspace.OwnerID, gitRef)
|
|
}
|
|
|
|
// storeChatGitRef persists the git branch and remote origin reported
|
|
// by the workspace agent on all chats associated with the workspace.
|
|
func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, gitRef chatGitRef) {
|
|
chats, err := api.Database.GetChatsByOwnerID(ctx, workspaceOwnerID)
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to list chats for git ref storage",
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
|
|
for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) {
|
|
_, err := api.Database.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{
|
|
ChatID: chat.ID,
|
|
GitBranch: gitRef.Branch,
|
|
GitRemoteOrigin: gitRef.RemoteOrigin,
|
|
StaleAt: time.Now().UTC().Add(-time.Second),
|
|
Url: sql.NullString{},
|
|
})
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to store git ref on chat diff status",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.Error(err),
|
|
)
|
|
continue
|
|
}
|
|
api.publishChatDiffStatusEvent(ctx, chat.ID)
|
|
}
|
|
}
|
|
|
|
// refreshWorkspaceChatDiffStatuses refreshes the diff status for all
|
|
// chats associated with the given workspace. It returns true when
|
|
// every chat has a PR URL resolved, signaling that the caller can
|
|
// stop polling.
|
|
func (api *API) refreshWorkspaceChatDiffStatuses(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID) bool {
|
|
chats, err := api.Database.GetChatsByOwnerID(ctx, workspaceOwnerID)
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to list workspace owner chats for diff refresh",
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.F("workspace_owner_id", workspaceOwnerID),
|
|
slog.Error(err),
|
|
)
|
|
return false
|
|
}
|
|
|
|
filtered := filterChatsByWorkspaceID(chats, workspaceID)
|
|
if len(filtered) == 0 {
|
|
return false
|
|
}
|
|
|
|
allHavePR := true
|
|
for _, chat := range filtered {
|
|
refreshCtx, cancel := context.WithTimeout(ctx, chatDiffBackgroundRefreshTimeout)
|
|
status, err := api.resolveChatDiffStatusWithOptions(refreshCtx, chat, true)
|
|
cancel()
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to refresh chat diff status after workspace external auth",
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
allHavePR = false
|
|
} else if status == nil || !status.Url.Valid || strings.TrimSpace(status.Url.String) == "" {
|
|
allHavePR = false
|
|
}
|
|
|
|
api.publishChatStatusEvent(ctx, chat.ID)
|
|
api.publishChatDiffStatusEvent(ctx, chat.ID)
|
|
}
|
|
|
|
return allHavePR
|
|
}
|
|
|
|
func filterChatsByWorkspaceID(chats []database.Chat, workspaceID uuid.UUID) []database.Chat {
|
|
filteredChats := make([]database.Chat, 0, len(chats))
|
|
for _, chat := range chats {
|
|
if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID {
|
|
continue
|
|
}
|
|
filteredChats = append(filteredChats, chat)
|
|
}
|
|
return filteredChats
|
|
}
|
|
|
|
func (api *API) publishChatStatusEvent(ctx context.Context, chatID uuid.UUID) {
|
|
if api.chatDaemon == nil {
|
|
return
|
|
}
|
|
|
|
if err := api.chatDaemon.RefreshStatus(ctx, chatID); err != nil {
|
|
api.Logger.Debug(ctx, "failed to refresh published chat status",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
func (api *API) publishChatDiffStatusEvent(ctx context.Context, chatID uuid.UUID) {
|
|
if api.chatDaemon == nil {
|
|
return
|
|
}
|
|
|
|
if err := api.chatDaemon.PublishDiffStatusChange(ctx, chatID); err != nil {
|
|
api.Logger.Debug(ctx, "failed to publish chat diff status change",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
func (api *API) resolveChatDiffContents(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (codersdk.ChatDiffContents, error) {
|
|
result := codersdk.ChatDiffContents{ChatID: chat.ID}
|
|
|
|
status, found, err := api.getCachedChatDiffStatus(ctx, chat.ID)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
reference, err := api.resolveChatDiffReference(ctx, chat, found, status)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
if reference.RepositoryRef != nil {
|
|
provider := strings.TrimSpace(reference.RepositoryRef.Provider)
|
|
if provider != "" {
|
|
result.Provider = &provider
|
|
}
|
|
|
|
origin := strings.TrimSpace(reference.RepositoryRef.RemoteOrigin)
|
|
if origin != "" {
|
|
result.RemoteOrigin = &origin
|
|
}
|
|
|
|
branch := strings.TrimSpace(reference.RepositoryRef.Branch)
|
|
if branch != "" {
|
|
result.Branch = &branch
|
|
}
|
|
}
|
|
|
|
if reference.PullRequestURL != "" {
|
|
pullRequestURL := strings.TrimSpace(reference.PullRequestURL)
|
|
result.PullRequestURL = &pullRequestURL
|
|
if !found || !strings.EqualFold(strings.TrimSpace(status.Url.String), pullRequestURL) {
|
|
_, err := api.upsertChatDiffStatusReference(ctx, chat.ID, pullRequestURL, time.Now().UTC().Add(-time.Second))
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
}
|
|
}
|
|
|
|
if reference.RepositoryRef == nil {
|
|
return result, nil
|
|
}
|
|
if !strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
|
|
return result, nil
|
|
}
|
|
|
|
token := api.resolveChatGitHubAccessToken(ctx, chat.OwnerID)
|
|
|
|
if reference.PullRequestURL != "" {
|
|
diff, err := api.fetchGitHubPullRequestDiff(ctx, reference.PullRequestURL, token)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
result.Diff = diff
|
|
return result, nil
|
|
}
|
|
|
|
diff, err := api.fetchGitHubCompareDiff(ctx, *reference.RepositoryRef, token)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
result.Diff = diff
|
|
return result, nil
|
|
}
|
|
|
|
// resolveChatDiffReference builds the diff reference from the cached
|
|
// status stored in the database. The git branch and remote origin are
|
|
// populated by the workspace agent during git operations (via the
|
|
// gitaskpass flow), so no SSH into the workspace is needed here.
|
|
//
|
|
//nolint:revive // Boolean indicates whether diff status was found.
|
|
func (api *API) resolveChatDiffReference(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
found bool,
|
|
status database.ChatDiffStatus,
|
|
) (chatDiffReference, error) {
|
|
reference := chatDiffReference{}
|
|
if !found {
|
|
return reference, nil
|
|
}
|
|
|
|
reference.PullRequestURL = strings.TrimSpace(status.Url.String)
|
|
|
|
// Build the repository ref from the stored git branch/origin
|
|
// that the agent reported.
|
|
reference.RepositoryRef = api.buildChatRepositoryRefFromStatus(status)
|
|
|
|
// If we have a repo ref with a branch, try to resolve the
|
|
// current open PR. This picks up new PRs after the previous
|
|
// one was closed.
|
|
if reference.RepositoryRef != nil &&
|
|
strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
|
|
pullRequestURL, lookupErr := api.resolveGitHubPullRequestURLFromRepositoryRef(ctx, chat.OwnerID, *reference.RepositoryRef)
|
|
if lookupErr != nil {
|
|
api.Logger.Debug(ctx, "failed to resolve pull request from repository reference",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("provider", reference.RepositoryRef.Provider),
|
|
slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin),
|
|
slog.F("branch", reference.RepositoryRef.Branch),
|
|
slog.Error(lookupErr),
|
|
)
|
|
} else if pullRequestURL != "" {
|
|
reference.PullRequestURL = pullRequestURL
|
|
}
|
|
}
|
|
|
|
reference.PullRequestURL = normalizeGitHubPullRequestURL(reference.PullRequestURL)
|
|
|
|
// If we have a PR URL but no repo ref (e.g. the agent hasn't
|
|
// reported branch/origin yet), derive a partial ref from the
|
|
// PR URL so the caller can still show provider/owner/repo.
|
|
if reference.RepositoryRef == nil && reference.PullRequestURL != "" {
|
|
if parsed, ok := parseGitHubPullRequestURL(reference.PullRequestURL); ok {
|
|
reference.RepositoryRef = &chatRepositoryRef{
|
|
Provider: string(codersdk.EnhancedExternalAuthProviderGitHub),
|
|
RemoteOrigin: fmt.Sprintf("https://github.com/%s/%s", parsed.Owner, parsed.Repo),
|
|
Owner: parsed.Owner,
|
|
Repo: parsed.Repo,
|
|
}
|
|
}
|
|
}
|
|
|
|
return reference, nil
|
|
}
|
|
|
|
// buildChatRepositoryRefFromStatus constructs a chatRepositoryRef
|
|
// from the git branch and remote origin stored in the cached status.
|
|
// Returns nil if no ref data is available.
|
|
func (api *API) buildChatRepositoryRefFromStatus(status database.ChatDiffStatus) *chatRepositoryRef {
|
|
branch := strings.TrimSpace(status.GitBranch)
|
|
origin := strings.TrimSpace(status.GitRemoteOrigin)
|
|
if branch == "" || origin == "" {
|
|
return nil
|
|
}
|
|
|
|
repoRef := &chatRepositoryRef{
|
|
Provider: strings.TrimSpace(api.resolveExternalAuthProviderType(origin)),
|
|
RemoteOrigin: origin,
|
|
Branch: branch,
|
|
}
|
|
|
|
if owner, repo, normalizedOrigin, ok := parseGitHubRepositoryOrigin(repoRef.RemoteOrigin); ok {
|
|
if repoRef.Provider == "" {
|
|
repoRef.Provider = string(codersdk.EnhancedExternalAuthProviderGitHub)
|
|
}
|
|
repoRef.RemoteOrigin = normalizedOrigin
|
|
repoRef.Owner = owner
|
|
repoRef.Repo = repo
|
|
}
|
|
|
|
if repoRef.Provider == "" {
|
|
return nil
|
|
}
|
|
|
|
return repoRef
|
|
}
|
|
|
|
func (api *API) upsertChatDiffStatusReference(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
pullRequestURL string,
|
|
staleAt time.Time,
|
|
) (database.ChatDiffStatus, error) {
|
|
status, err := api.Database.UpsertChatDiffStatusReference(
|
|
ctx,
|
|
database.UpsertChatDiffStatusReferenceParams{
|
|
ChatID: chatID,
|
|
Url: sql.NullString{
|
|
String: pullRequestURL,
|
|
Valid: strings.TrimSpace(pullRequestURL) != "",
|
|
},
|
|
// Empty strings preserve existing values via the
|
|
// CASE expression in the SQL query.
|
|
GitBranch: "",
|
|
GitRemoteOrigin: "",
|
|
StaleAt: staleAt,
|
|
},
|
|
)
|
|
if err != nil {
|
|
return database.ChatDiffStatus{}, xerrors.Errorf("upsert chat diff status reference: %w", err)
|
|
}
|
|
return status, nil
|
|
}
|
|
|
|
func (api *API) getCachedChatDiffStatus(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
) (database.ChatDiffStatus, bool, error) {
|
|
status, err := api.Database.GetChatDiffStatusByChatID(ctx, chatID)
|
|
if err == nil {
|
|
return status, true, nil
|
|
}
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return database.ChatDiffStatus{}, false, nil
|
|
}
|
|
return database.ChatDiffStatus{}, false, xerrors.Errorf(
|
|
"get chat diff status: %w",
|
|
err,
|
|
)
|
|
}
|
|
|
|
func (api *API) resolveExternalAuthProviderType(match string) string {
|
|
match = strings.TrimSpace(match)
|
|
if match == "" {
|
|
return ""
|
|
}
|
|
|
|
for _, extAuth := range api.ExternalAuthConfigs {
|
|
if extAuth.Regex == nil || !extAuth.Regex.MatchString(match) {
|
|
continue
|
|
}
|
|
return strings.ToLower(strings.TrimSpace(extAuth.Type))
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func parseGitHubRepositoryOrigin(raw string) (owner string, repo string, normalizedOrigin string, ok bool) {
|
|
raw = strings.TrimSpace(raw)
|
|
if raw == "" {
|
|
return "", "", "", false
|
|
}
|
|
|
|
matches := githubRepositoryHTTPSPattern.FindStringSubmatch(raw)
|
|
if len(matches) != 3 {
|
|
matches = githubRepositorySSHPathPattern.FindStringSubmatch(raw)
|
|
}
|
|
if len(matches) != 3 {
|
|
return "", "", "", false
|
|
}
|
|
|
|
owner = strings.TrimSpace(matches[1])
|
|
repo = strings.TrimSpace(matches[2])
|
|
repo = strings.TrimSuffix(repo, ".git")
|
|
if owner == "" || repo == "" {
|
|
return "", "", "", false
|
|
}
|
|
|
|
return owner, repo, fmt.Sprintf("https://github.com/%s/%s", owner, repo), true
|
|
}
|
|
|
|
// buildGitHubBranchURL returns the github.com "tree" URL for a branch
// of owner/repo. All three inputs are trimmed first; if any of them
// is blank the result is "". The branch is path-escaped, so a "/" in
// a branch name is emitted as "%2F".
func buildGitHubBranchURL(owner string, repo string, branch string) string {
	trimmedOwner := strings.TrimSpace(owner)
	trimmedRepo := strings.TrimSpace(repo)
	trimmedBranch := strings.TrimSpace(branch)

	switch "" {
	case trimmedOwner, trimmedRepo, trimmedBranch:
		return ""
	}

	return "https://github.com/" + trimmedOwner + "/" + trimmedRepo + "/tree/" + url.PathEscape(trimmedBranch)
}
|
|
|
|
func chatDiffStatusIsStale(status database.ChatDiffStatus, now time.Time) bool {
|
|
if !status.RefreshedAt.Valid {
|
|
return true
|
|
}
|
|
return !status.StaleAt.After(now)
|
|
}
|
|
|
|
func (api *API) refreshChatDiffStatus(
|
|
ctx context.Context,
|
|
chatOwnerID uuid.UUID,
|
|
chatID uuid.UUID,
|
|
pullRequestURL string,
|
|
) (database.ChatDiffStatus, error) {
|
|
status, err := api.fetchGitHubPullRequestStatus(
|
|
ctx,
|
|
pullRequestURL,
|
|
api.resolveChatGitHubAccessToken(ctx, chatOwnerID),
|
|
)
|
|
if err != nil {
|
|
return database.ChatDiffStatus{}, err
|
|
}
|
|
|
|
refreshedAt := time.Now().UTC()
|
|
refreshedStatus, err := api.Database.UpsertChatDiffStatus(
|
|
ctx,
|
|
database.UpsertChatDiffStatusParams{
|
|
ChatID: chatID,
|
|
Url: sql.NullString{String: pullRequestURL, Valid: true},
|
|
PullRequestState: sql.NullString{
|
|
String: status.PullRequestState,
|
|
Valid: status.PullRequestState != "",
|
|
},
|
|
ChangesRequested: status.ChangesRequested,
|
|
Additions: status.Additions,
|
|
Deletions: status.Deletions,
|
|
ChangedFiles: status.ChangedFiles,
|
|
RefreshedAt: refreshedAt,
|
|
StaleAt: refreshedAt.Add(chatDiffStatusTTL),
|
|
},
|
|
)
|
|
if err != nil {
|
|
return database.ChatDiffStatus{}, xerrors.Errorf("upsert chat diff status: %w", err)
|
|
}
|
|
return refreshedStatus, nil
|
|
}
|
|
|
|
func (api *API) resolveChatGitHubAccessToken(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
) string {
|
|
// Build a map of provider ID -> config so we can refresh tokens
|
|
// using the same code path as provisionerdserver.
|
|
ghConfigs := make(map[string]*externalauth.Config)
|
|
providerIDs := []string{"github"}
|
|
for _, config := range api.ExternalAuthConfigs {
|
|
if !strings.EqualFold(
|
|
config.Type,
|
|
string(codersdk.EnhancedExternalAuthProviderGitHub),
|
|
) {
|
|
continue
|
|
}
|
|
providerIDs = append(providerIDs, config.ID)
|
|
ghConfigs[config.ID] = config
|
|
}
|
|
|
|
seen := map[string]struct{}{}
|
|
for _, providerID := range providerIDs {
|
|
if _, ok := seen[providerID]; ok {
|
|
continue
|
|
}
|
|
seen[providerID] = struct{}{}
|
|
|
|
link, err := api.Database.GetExternalAuthLink(
|
|
ctx,
|
|
database.GetExternalAuthLinkParams{
|
|
ProviderID: providerID,
|
|
UserID: userID,
|
|
},
|
|
)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Refresh the token if there is a matching config, mirroring
|
|
// the same code path used by provisionerdserver when handing
|
|
// tokens to provisioners.
|
|
if cfg, ok := ghConfigs[providerID]; ok {
|
|
refreshed, refreshErr := cfg.RefreshToken(ctx, api.Database, link)
|
|
if refreshErr != nil {
|
|
api.Logger.Debug(ctx, "failed to refresh external auth token for chat diff",
|
|
slog.F("provider_id", providerID),
|
|
slog.F("user_id", userID),
|
|
slog.Error(refreshErr),
|
|
)
|
|
// Fall through — the existing token may still work
|
|
// (e.g. GitHub tokens with no expiry).
|
|
} else {
|
|
link = refreshed
|
|
}
|
|
}
|
|
|
|
token := strings.TrimSpace(link.OAuthAccessToken)
|
|
if token != "" {
|
|
return token
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func (api *API) resolveGitHubPullRequestURLFromRepositoryRef(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
repositoryRef chatRepositoryRef,
|
|
) (string, error) {
|
|
if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" {
|
|
return "", nil
|
|
}
|
|
|
|
query := url.Values{}
|
|
query.Set("state", "open")
|
|
query.Set("head", fmt.Sprintf("%s:%s", repositoryRef.Owner, repositoryRef.Branch))
|
|
query.Set("sort", "updated")
|
|
query.Set("direction", "desc")
|
|
query.Set("per_page", "1")
|
|
|
|
requestURL := fmt.Sprintf(
|
|
"%s/repos/%s/%s/pulls?%s",
|
|
githubAPIBaseURL,
|
|
repositoryRef.Owner,
|
|
repositoryRef.Repo,
|
|
query.Encode(),
|
|
)
|
|
|
|
var pulls []struct {
|
|
HTMLURL string `json:"html_url"`
|
|
}
|
|
|
|
token := api.resolveChatGitHubAccessToken(ctx, userID)
|
|
if err := api.decodeGitHubJSON(ctx, requestURL, token, &pulls); err != nil {
|
|
return "", err
|
|
}
|
|
if len(pulls) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
return normalizeGitHubPullRequestURL(pulls[0].HTMLURL), nil
|
|
}
|
|
|
|
func (api *API) fetchGitHubPullRequestDiff(
|
|
ctx context.Context,
|
|
pullRequestURL string,
|
|
token string,
|
|
) (string, error) {
|
|
ref, ok := parseGitHubPullRequestURL(pullRequestURL)
|
|
if !ok {
|
|
return "", xerrors.Errorf("invalid GitHub pull request URL %q", pullRequestURL)
|
|
}
|
|
|
|
requestURL := fmt.Sprintf(
|
|
"%s/repos/%s/%s/pulls/%d",
|
|
githubAPIBaseURL,
|
|
ref.Owner,
|
|
ref.Repo,
|
|
ref.Number,
|
|
)
|
|
|
|
return api.fetchGitHubDiff(ctx, requestURL, token)
|
|
}
|
|
|
|
func (api *API) fetchGitHubCompareDiff(
|
|
ctx context.Context,
|
|
repositoryRef chatRepositoryRef,
|
|
token string,
|
|
) (string, error) {
|
|
if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" {
|
|
return "", nil
|
|
}
|
|
|
|
var repository struct {
|
|
DefaultBranch string `json:"default_branch"`
|
|
}
|
|
|
|
repositoryURL := fmt.Sprintf(
|
|
"%s/repos/%s/%s",
|
|
githubAPIBaseURL,
|
|
repositoryRef.Owner,
|
|
repositoryRef.Repo,
|
|
)
|
|
if err := api.decodeGitHubJSON(ctx, repositoryURL, token, &repository); err != nil {
|
|
return "", err
|
|
}
|
|
defaultBranch := strings.TrimSpace(repository.DefaultBranch)
|
|
if defaultBranch == "" {
|
|
return "", xerrors.New("github repository default branch is empty")
|
|
}
|
|
|
|
requestURL := fmt.Sprintf(
|
|
"%s/repos/%s/%s/compare/%s...%s",
|
|
githubAPIBaseURL,
|
|
repositoryRef.Owner,
|
|
repositoryRef.Repo,
|
|
url.PathEscape(defaultBranch),
|
|
url.PathEscape(repositoryRef.Branch),
|
|
)
|
|
|
|
return api.fetchGitHubDiff(ctx, requestURL, token)
|
|
}
|
|
|
|
func (api *API) fetchGitHubDiff(
|
|
ctx context.Context,
|
|
requestURL string,
|
|
token string,
|
|
) (string, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
|
if err != nil {
|
|
return "", xerrors.Errorf("create github diff request: %w", err)
|
|
}
|
|
req.Header.Set("Accept", "application/vnd.github.diff")
|
|
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
|
req.Header.Set("User-Agent", "coder-chat-diff")
|
|
if token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
httpClient := api.HTTPClient
|
|
if httpClient == nil {
|
|
httpClient = http.DefaultClient
|
|
}
|
|
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return "", xerrors.Errorf("execute github diff request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
|
if readErr != nil {
|
|
return "", xerrors.Errorf("github diff request failed with status %d", resp.StatusCode)
|
|
}
|
|
return "", xerrors.Errorf(
|
|
"github diff request failed with status %d: %s",
|
|
resp.StatusCode,
|
|
strings.TrimSpace(string(body)),
|
|
)
|
|
}
|
|
|
|
diff, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
|
if err != nil {
|
|
return "", xerrors.Errorf("read github diff response: %w", err)
|
|
}
|
|
return string(diff), nil
|
|
}
|
|
|
|
func (api *API) fetchGitHubPullRequestStatus(
|
|
ctx context.Context,
|
|
pullRequestURL string,
|
|
token string,
|
|
) (githubPullRequestStatus, error) {
|
|
ref, ok := parseGitHubPullRequestURL(pullRequestURL)
|
|
if !ok {
|
|
return githubPullRequestStatus{}, xerrors.Errorf(
|
|
"invalid GitHub pull request URL %q",
|
|
pullRequestURL,
|
|
)
|
|
}
|
|
|
|
pullEndpoint := fmt.Sprintf(
|
|
"%s/repos/%s/%s/pulls/%d",
|
|
githubAPIBaseURL,
|
|
ref.Owner,
|
|
ref.Repo,
|
|
ref.Number,
|
|
)
|
|
|
|
var pull struct {
|
|
State string `json:"state"`
|
|
Additions int32 `json:"additions"`
|
|
Deletions int32 `json:"deletions"`
|
|
ChangedFiles int32 `json:"changed_files"`
|
|
}
|
|
if err := api.decodeGitHubJSON(ctx, pullEndpoint, token, &pull); err != nil {
|
|
return githubPullRequestStatus{}, err
|
|
}
|
|
|
|
var reviews []struct {
|
|
ID int64 `json:"id"`
|
|
State string `json:"state"`
|
|
User struct {
|
|
Login string `json:"login"`
|
|
} `json:"user"`
|
|
}
|
|
if err := api.decodeGitHubJSON(
|
|
ctx,
|
|
pullEndpoint+"/reviews?per_page=100",
|
|
token,
|
|
&reviews,
|
|
); err != nil {
|
|
return githubPullRequestStatus{}, err
|
|
}
|
|
|
|
return githubPullRequestStatus{
|
|
PullRequestState: strings.ToLower(strings.TrimSpace(pull.State)),
|
|
ChangesRequested: hasOutstandingGitHubChangesRequested(reviews),
|
|
Additions: pull.Additions,
|
|
Deletions: pull.Deletions,
|
|
ChangedFiles: pull.ChangedFiles,
|
|
}, nil
|
|
}
|
|
|
|
func (api *API) decodeGitHubJSON(
|
|
ctx context.Context,
|
|
requestURL string,
|
|
token string,
|
|
dest any,
|
|
) error {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
|
if err != nil {
|
|
return xerrors.Errorf("create github request: %w", err)
|
|
}
|
|
req.Header.Set("Accept", "application/vnd.github+json")
|
|
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
|
req.Header.Set("User-Agent", "coder-chat-diff-status")
|
|
if token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
httpClient := api.HTTPClient
|
|
if httpClient == nil {
|
|
httpClient = http.DefaultClient
|
|
}
|
|
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return xerrors.Errorf("execute github request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
|
if readErr != nil {
|
|
return xerrors.Errorf(
|
|
"github request failed with status %d",
|
|
resp.StatusCode,
|
|
)
|
|
}
|
|
return xerrors.Errorf(
|
|
"github request failed with status %d: %s",
|
|
resp.StatusCode,
|
|
strings.TrimSpace(string(body)),
|
|
)
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(dest); err != nil {
|
|
return xerrors.Errorf("decode github response: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// hasOutstandingGitHubChangesRequested reports whether any reviewer's
// latest decisive review is CHANGES_REQUESTED.
//
// Only reviews in state CHANGES_REQUESTED, APPROVED or DISMISSED
// (compared case-insensitively after trimming) are considered; others
// are ignored, as are reviews with a blank login. Reviewers are keyed
// by lower-cased login, and for each one the review with the highest
// ID wins, whatever order the input is in.
func hasOutstandingGitHubChangesRequested(
	reviews []struct {
		ID    int64  `json:"id"`
		State string `json:"state"`
		User  struct {
			Login string `json:"login"`
		} `json:"user"`
	},
) bool {
	const changesRequested = "CHANGES_REQUESTED"

	type latestReview struct {
		id      int64
		verdict string
	}

	latestByLogin := make(map[string]latestReview)
	for i := range reviews {
		rv := &reviews[i]

		login := strings.ToLower(strings.TrimSpace(rv.User.Login))
		if login == "" {
			continue
		}

		verdict := strings.ToUpper(strings.TrimSpace(rv.State))
		if verdict != changesRequested && verdict != "APPROVED" && verdict != "DISMISSED" {
			continue
		}

		// Keep the entry already recorded when it has a higher ID
		// than this review.
		if prev, seen := latestByLogin[login]; seen && prev.id > rv.ID {
			continue
		}
		latestByLogin[login] = latestReview{id: rv.ID, verdict: verdict}
	}

	for _, latest := range latestByLogin {
		if latest.verdict == changesRequested {
			return true
		}
	}
	return false
}
|
|
|
|
func normalizeGitHubPullRequestURL(raw string) string {
|
|
ref, ok := parseGitHubPullRequestURL(strings.TrimRight(
|
|
strings.TrimSpace(raw),
|
|
"),.;",
|
|
))
|
|
if !ok {
|
|
return ""
|
|
}
|
|
return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number)
|
|
}
|
|
|
|
func parseGitHubPullRequestURL(raw string) (githubPullRequestRef, bool) {
|
|
matches := githubPullRequestPathPattern.FindStringSubmatch(strings.TrimSpace(raw))
|
|
if len(matches) != 4 {
|
|
return githubPullRequestRef{}, false
|
|
}
|
|
|
|
number, err := strconv.Atoi(matches[3])
|
|
if err != nil {
|
|
return githubPullRequestRef{}, false
|
|
}
|
|
|
|
return githubPullRequestRef{
|
|
Owner: matches[1],
|
|
Repo: matches[2],
|
|
Number: number,
|
|
}, true
|
|
}
|
|
|
|
type createChatWorkspaceSelection struct {
|
|
WorkspaceID uuid.NullUUID
|
|
WorkspaceAgentID uuid.NullUUID
|
|
}
|
|
|
|
func (api *API) validateCreateChatWorkspaceSelection(
|
|
ctx context.Context,
|
|
req codersdk.CreateChatRequest,
|
|
) (
|
|
createChatWorkspaceSelection,
|
|
int,
|
|
*codersdk.Response,
|
|
) {
|
|
selection := createChatWorkspaceSelection{}
|
|
if req.WorkspaceID == nil {
|
|
return selection, 0, nil
|
|
}
|
|
|
|
workspace, err := api.Database.GetWorkspaceByID(ctx, *req.WorkspaceID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
return selection, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Workspace not found or you do not have access to this resource",
|
|
}
|
|
}
|
|
return selection, http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Failed to get workspace.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
selection.WorkspaceID = uuid.NullUUID{
|
|
UUID: workspace.ID,
|
|
Valid: true,
|
|
}
|
|
|
|
workspaceAgents, err := api.Database.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
|
ctx,
|
|
workspace.ID,
|
|
)
|
|
if err != nil {
|
|
return selection, http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Failed to get workspace agents.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
if len(workspaceAgents) > 0 {
|
|
selection.WorkspaceAgentID = uuid.NullUUID{
|
|
UUID: workspaceAgents[0].ID,
|
|
Valid: true,
|
|
}
|
|
}
|
|
|
|
return selection, 0, nil
|
|
}
|
|
|
|
func (api *API) resolveCreateChatModelConfigID(
|
|
ctx context.Context,
|
|
req codersdk.CreateChatRequest,
|
|
) (uuid.UUID, int, *codersdk.Response) {
|
|
if req.ModelConfigID != nil {
|
|
if *req.ModelConfigID == uuid.Nil {
|
|
return uuid.Nil, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Invalid model config ID.",
|
|
}
|
|
}
|
|
return *req.ModelConfigID, 0, nil
|
|
}
|
|
|
|
defaultModelConfig, err := api.Database.GetDefaultChatModelConfig(ctx)
|
|
if err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return uuid.Nil, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "No default chat model config is configured.",
|
|
}
|
|
}
|
|
return uuid.Nil, http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Failed to resolve chat model config.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
|
|
return defaultModelConfig.ID, 0, nil
|
|
}
|
|
|
|
func normalizeChatCompressionThreshold(
|
|
requested *int32,
|
|
fallback int32,
|
|
) (int32, error) {
|
|
threshold := fallback
|
|
if requested != nil {
|
|
threshold = *requested
|
|
}
|
|
|
|
if threshold < minChatContextCompressionThreshold ||
|
|
threshold > maxChatContextCompressionThreshold {
|
|
return 0, xerrors.Errorf(
|
|
"context_compression_threshold must be between %d and %d",
|
|
minChatContextCompressionThreshold,
|
|
maxChatContextCompressionThreshold,
|
|
)
|
|
}
|
|
|
|
return threshold, nil
|
|
}
|
|
|
|
func defaultChatSystemPrompt() string {
|
|
return chatd.DefaultSystemPrompt
|
|
}
|
|
|
|
func createChatInputFromRequest(req codersdk.CreateChatRequest) (
|
|
[]fantasy.Content,
|
|
string,
|
|
*codersdk.Response,
|
|
) {
|
|
return createChatInputFromParts(req.Content, "content")
|
|
}
|
|
|
|
func createChatInputFromParts(
|
|
parts []codersdk.ChatInputPart,
|
|
fieldName string,
|
|
) ([]fantasy.Content, string, *codersdk.Response) {
|
|
if len(parts) == 0 {
|
|
return nil, "", &codersdk.Response{
|
|
Message: "Content is required.",
|
|
Detail: "Content cannot be empty.",
|
|
}
|
|
}
|
|
|
|
content := make([]fantasy.Content, 0, len(parts))
|
|
textParts := make([]string, 0, len(parts))
|
|
for i, part := range parts {
|
|
switch strings.ToLower(strings.TrimSpace(string(part.Type))) {
|
|
case string(codersdk.ChatInputPartTypeText):
|
|
text := strings.TrimSpace(part.Text)
|
|
if text == "" {
|
|
return nil, "", &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i),
|
|
}
|
|
}
|
|
content = append(content, fantasy.TextContent{Text: text})
|
|
textParts = append(textParts, text)
|
|
default:
|
|
return nil, "", &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf(
|
|
"%s[%d].type %q is not supported.",
|
|
fieldName,
|
|
i,
|
|
part.Type,
|
|
),
|
|
}
|
|
}
|
|
}
|
|
|
|
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
|
|
if titleSource == "" {
|
|
return nil, "", &codersdk.Response{
|
|
Message: "Content is required.",
|
|
Detail: "Content must include at least one text part.",
|
|
}
|
|
}
|
|
return content, titleSource, nil
|
|
}
|
|
|
|
func chatTitleFromMessage(message string) string {
|
|
const maxWords = 6
|
|
const maxRunes = 80
|
|
words := strings.Fields(message)
|
|
if len(words) == 0 {
|
|
return "New Chat"
|
|
}
|
|
truncated := false
|
|
if len(words) > maxWords {
|
|
words = words[:maxWords]
|
|
truncated = true
|
|
}
|
|
title := strings.Join(words, " ")
|
|
if truncated {
|
|
title += "…"
|
|
}
|
|
return truncateRunes(title, maxRunes)
|
|
}
|
|
|
|
// truncateRunes returns value cut to at most maxLen runes (not
// bytes). It returns "" when maxLen is zero or negative, and value
// unchanged when it already fits.
func truncateRunes(value string, maxLen int) string {
	if maxLen <= 0 {
		return ""
	}
	if asRunes := []rune(value); len(asRunes) > maxLen {
		return string(asRunes[:maxLen])
	}
	return value
}
|
|
|
|
func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
|
chat := codersdk.Chat{
|
|
ID: c.ID,
|
|
OwnerID: c.OwnerID,
|
|
LastModelConfigID: c.LastModelConfigID,
|
|
Title: c.Title,
|
|
Status: codersdk.ChatStatus(c.Status),
|
|
CreatedAt: c.CreatedAt,
|
|
UpdatedAt: c.UpdatedAt,
|
|
}
|
|
if c.ParentChatID.Valid {
|
|
parentChatID := c.ParentChatID.UUID
|
|
chat.ParentChatID = &parentChatID
|
|
}
|
|
switch {
|
|
case c.RootChatID.Valid:
|
|
rootChatID := c.RootChatID.UUID
|
|
chat.RootChatID = &rootChatID
|
|
case c.ParentChatID.Valid:
|
|
rootChatID := c.ParentChatID.UUID
|
|
chat.RootChatID = &rootChatID
|
|
default:
|
|
rootChatID := c.ID
|
|
chat.RootChatID = &rootChatID
|
|
}
|
|
if c.WorkspaceID.Valid {
|
|
chat.WorkspaceID = &c.WorkspaceID.UUID
|
|
}
|
|
if c.WorkspaceAgentID.Valid {
|
|
chat.WorkspaceAgentID = &c.WorkspaceAgentID.UUID
|
|
}
|
|
if diffStatus != nil {
|
|
convertedDiffStatus := convertChatDiffStatus(c.ID, diffStatus)
|
|
chat.DiffStatus = &convertedDiffStatus
|
|
}
|
|
return chat
|
|
}
|
|
|
|
func convertChats(chats []database.Chat, diffStatusesByChatID map[uuid.UUID]database.ChatDiffStatus) []codersdk.Chat {
|
|
result := make([]codersdk.Chat, len(chats))
|
|
for i, c := range chats {
|
|
diffStatus, ok := diffStatusesByChatID[c.ID]
|
|
if ok {
|
|
result[i] = convertChat(c, &diffStatus)
|
|
continue
|
|
}
|
|
|
|
result[i] = convertChat(c, nil)
|
|
if diffStatusesByChatID != nil {
|
|
emptyDiffStatus := convertChatDiffStatus(c.ID, nil)
|
|
result[i].DiffStatus = &emptyDiffStatus
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func convertChatQueuedMessage(m database.ChatQueuedMessage) codersdk.ChatQueuedMessage {
|
|
return db2sdk.ChatQueuedMessage(m)
|
|
}
|
|
|
|
func convertChatQueuedMessagePtr(m database.ChatQueuedMessage) *codersdk.ChatQueuedMessage {
|
|
qm := convertChatQueuedMessage(m)
|
|
return &qm
|
|
}
|
|
|
|
func convertChatQueuedMessages(msgs []database.ChatQueuedMessage) []codersdk.ChatQueuedMessage {
|
|
result := make([]codersdk.ChatQueuedMessage, 0, len(msgs))
|
|
for _, m := range msgs {
|
|
result = append(result, convertChatQueuedMessage(m))
|
|
}
|
|
return result
|
|
}
|
|
|
|
func convertChatMessage(m database.ChatMessage) codersdk.ChatMessage {
|
|
return db2sdk.ChatMessage(m)
|
|
}
|
|
|
|
func convertChatMessages(messages []database.ChatMessage) []codersdk.ChatMessage {
|
|
result := make([]codersdk.ChatMessage, 0, len(messages))
|
|
for _, m := range messages {
|
|
result = append(result, convertChatMessage(m))
|
|
}
|
|
return result
|
|
}
|
|
|
|
func convertChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.ChatDiffStatus {
|
|
result := codersdk.ChatDiffStatus{
|
|
ChatID: chatID,
|
|
}
|
|
if status == nil {
|
|
return result
|
|
}
|
|
|
|
result.ChatID = status.ChatID
|
|
if status.Url.Valid {
|
|
u := strings.TrimSpace(status.Url.String)
|
|
if u != "" {
|
|
result.URL = &u
|
|
}
|
|
}
|
|
if result.URL == nil {
|
|
owner, repo, _, ok := parseGitHubRepositoryOrigin(status.GitRemoteOrigin)
|
|
if ok {
|
|
branchURL := buildGitHubBranchURL(owner, repo, status.GitBranch)
|
|
if branchURL != "" {
|
|
result.URL = &branchURL
|
|
}
|
|
}
|
|
}
|
|
if status.PullRequestState.Valid {
|
|
pullRequestState := strings.TrimSpace(status.PullRequestState.String)
|
|
if pullRequestState != "" {
|
|
result.PullRequestState = &pullRequestState
|
|
}
|
|
}
|
|
result.ChangesRequested = status.ChangesRequested
|
|
result.Additions = status.Additions
|
|
result.Deletions = status.Deletions
|
|
result.ChangedFiles = status.ChangedFiles
|
|
if status.RefreshedAt.Valid {
|
|
refreshedAt := status.RefreshedAt.Time
|
|
result.RefreshedAt = &refreshedAt
|
|
}
|
|
staleAt := status.StaleAt
|
|
result.StaleAt = &staleAt
|
|
|
|
return result
|
|
}
|
|
|
|
func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
//nolint:gocritic // System context required to read enabled chat providers.
|
|
systemCtx := dbauthz.AsSystemRestricted(ctx)
|
|
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
providers, err := api.Database.GetChatProviders(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chat providers.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
providersByName := make(map[string]database.ChatProvider, len(providers))
|
|
configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers))
|
|
for _, provider := range providers {
|
|
normalizedProvider := normalizeChatProvider(provider.Provider)
|
|
if normalizedProvider == "" {
|
|
continue
|
|
}
|
|
provider.Provider = normalizedProvider
|
|
providersByName[normalizedProvider] = provider
|
|
configuredProviders = append(configuredProviders, chatprovider.ConfiguredProvider{
|
|
Provider: normalizedProvider,
|
|
APIKey: provider.APIKey,
|
|
BaseURL: provider.BaseUrl,
|
|
})
|
|
}
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
enabledProviders, err := api.Database.GetEnabledChatProviders(
|
|
systemCtx,
|
|
)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to resolve provider API keys.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
enabledConfiguredProviders := make(
|
|
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
|
)
|
|
for _, provider := range enabledProviders {
|
|
enabledConfiguredProviders = append(
|
|
enabledConfiguredProviders, chatprovider.ConfiguredProvider{
|
|
Provider: provider.Provider,
|
|
APIKey: provider.APIKey,
|
|
BaseURL: provider.BaseUrl,
|
|
},
|
|
)
|
|
}
|
|
|
|
effectiveKeys := chatprovider.MergeProviderAPIKeys(
|
|
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
|
enabledConfiguredProviders,
|
|
)
|
|
effectiveKeys = chatprovider.MergeProviderAPIKeys(
|
|
effectiveKeys, configuredProviders,
|
|
)
|
|
|
|
supportedProviders := chatprovider.SupportedProviders()
|
|
resp := make([]codersdk.ChatProviderConfig, 0, len(supportedProviders))
|
|
for _, provider := range supportedProviders {
|
|
configured, ok := providersByName[provider]
|
|
if ok {
|
|
resp = append(
|
|
resp,
|
|
convertChatProviderConfig(
|
|
configured,
|
|
effectiveKeys.APIKey(provider) != "",
|
|
codersdk.ChatProviderConfigSourceDatabase,
|
|
),
|
|
)
|
|
continue
|
|
}
|
|
|
|
source := codersdk.ChatProviderConfigSourceSupported
|
|
hasAPIKey := effectiveKeys.APIKey(provider) != ""
|
|
enabled := false
|
|
if chatprovider.IsEnvPresetProvider(provider) && hasAPIKey {
|
|
source = codersdk.ChatProviderConfigSourceEnvPreset
|
|
enabled = true
|
|
}
|
|
|
|
resp = append(resp, codersdk.ChatProviderConfig{
|
|
ID: uuid.Nil,
|
|
Provider: provider,
|
|
DisplayName: chatprovider.ProviderDisplayName(provider),
|
|
Enabled: enabled,
|
|
HasAPIKey: hasAPIKey,
|
|
BaseURL: effectiveKeys.BaseURL(provider),
|
|
Source: source,
|
|
})
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
|
}
|
|
|
|
func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.CreateChatProviderConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
provider := normalizeChatProvider(req.Provider)
|
|
if provider == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider.",
|
|
Detail: chatProviderValidationDetail(),
|
|
})
|
|
return
|
|
}
|
|
|
|
enabled := true
|
|
if req.Enabled != nil {
|
|
enabled = *req.Enabled
|
|
}
|
|
baseURL, err := normalizeChatProviderBaseURL(req.BaseURL)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider base URL.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
inserted, err := api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: provider,
|
|
DisplayName: strings.TrimSpace(req.DisplayName),
|
|
APIKey: strings.TrimSpace(req.APIKey),
|
|
BaseUrl: baseURL,
|
|
ApiKeyKeyID: sql.NullString{},
|
|
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
Enabled: enabled,
|
|
})
|
|
if err != nil {
|
|
switch {
|
|
case database.IsUniqueViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat provider already exists.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
case database.IsCheckViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
httpapi.Write(
|
|
ctx,
|
|
rw,
|
|
http.StatusCreated,
|
|
convertChatProviderConfig(
|
|
inserted,
|
|
api.hasEffectiveProviderAPIKey(ctx, inserted),
|
|
codersdk.ChatProviderConfigSourceDatabase,
|
|
),
|
|
)
|
|
}
|
|
|
|
func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
providerID, ok := parseChatProviderID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
existing, err := api.Database.GetChatProviderByID(ctx, providerID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatProviderConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
displayName := existing.DisplayName
|
|
if trimmed := strings.TrimSpace(req.DisplayName); trimmed != "" {
|
|
displayName = trimmed
|
|
}
|
|
|
|
enabled := existing.Enabled
|
|
if req.Enabled != nil {
|
|
enabled = *req.Enabled
|
|
}
|
|
|
|
apiKey := existing.APIKey
|
|
apiKeyKeyID := existing.ApiKeyKeyID
|
|
if req.APIKey != nil {
|
|
apiKey = strings.TrimSpace(*req.APIKey)
|
|
apiKeyKeyID = sql.NullString{}
|
|
}
|
|
baseURL := existing.BaseUrl
|
|
if req.BaseURL != nil {
|
|
baseURL, err = normalizeChatProviderBaseURL(*req.BaseURL)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider base URL.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
updated, err := api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
|
DisplayName: displayName,
|
|
APIKey: apiKey,
|
|
BaseUrl: baseURL,
|
|
ApiKeyKeyID: apiKeyKeyID,
|
|
Enabled: enabled,
|
|
ID: existing.ID,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(
|
|
ctx,
|
|
rw,
|
|
http.StatusOK,
|
|
convertChatProviderConfig(
|
|
updated,
|
|
api.hasEffectiveProviderAPIKey(ctx, updated),
|
|
codersdk.ChatProviderConfigSourceDatabase,
|
|
),
|
|
)
|
|
}
|
|
|
|
func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
providerID, ok := parseChatProviderID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if _, err := api.Database.GetChatProviderByID(ctx, providerID); err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if err := api.Database.DeleteChatProviderByID(ctx, providerID); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to delete chat provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
// Admin users can see all model configs (including disabled ones)
|
|
// for management purposes. Non-admin users see only enabled
|
|
// configs, which is sufficient for using the chat feature.
|
|
isAdmin := api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig)
|
|
|
|
var configs []database.ChatModelConfig
|
|
var err error
|
|
if isAdmin {
|
|
configs, err = api.Database.GetChatModelConfigs(ctx)
|
|
} else {
|
|
//nolint:gocritic // All authenticated users need to read enabled model configs to use the chat feature.
|
|
configs, err = api.Database.GetEnabledChatModelConfigs(dbauthz.AsSystemRestricted(ctx))
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chat model configs.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
resp := make([]codersdk.ChatModelConfig, 0, len(configs))
|
|
for _, config := range configs {
|
|
resp = append(resp, convertChatModelConfig(config))
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
|
}
|
|
|
|
func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.CreateChatModelConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
provider := normalizeChatProvider(req.Provider)
|
|
if provider == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider.",
|
|
Detail: chatProviderValidationDetail(),
|
|
})
|
|
return
|
|
}
|
|
|
|
model := strings.TrimSpace(req.Model)
|
|
if model == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Model is required.",
|
|
})
|
|
return
|
|
}
|
|
|
|
enabled := true
|
|
if req.Enabled != nil {
|
|
enabled = *req.Enabled
|
|
}
|
|
isDefault := false
|
|
if req.IsDefault != nil {
|
|
isDefault = *req.IsDefault
|
|
}
|
|
|
|
if req.ContextLimit == nil || *req.ContextLimit <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Context limit is required.",
|
|
Detail: "context_limit must be greater than zero.",
|
|
})
|
|
return
|
|
}
|
|
contextLimit := *req.ContextLimit
|
|
|
|
compressionThreshold, thresholdErr := normalizeChatCompressionThreshold(
|
|
req.CompressionThreshold,
|
|
defaultChatContextCompressionThreshold,
|
|
)
|
|
if thresholdErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid compression threshold.",
|
|
Detail: thresholdErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
modelConfigRaw, modelConfigErr := marshalChatModelCallConfig(req.ModelConfig)
|
|
if modelConfigErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config.",
|
|
Detail: modelConfigErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
insertParams := database.InsertChatModelConfigParams{
|
|
Provider: provider,
|
|
Model: model,
|
|
DisplayName: strings.TrimSpace(req.DisplayName),
|
|
Enabled: enabled,
|
|
IsDefault: isDefault,
|
|
ContextLimit: contextLimit,
|
|
CompressionThreshold: compressionThreshold,
|
|
Options: modelConfigRaw,
|
|
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
UpdatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
}
|
|
|
|
var inserted database.ChatModelConfig
|
|
err := api.Database.InTx(func(tx database.Store) error {
|
|
insertAsDefault := isDefault
|
|
if !insertAsDefault {
|
|
_, err := tx.GetDefaultChatModelConfig(ctx)
|
|
switch {
|
|
case err == nil:
|
|
// A default already exists.
|
|
case xerrors.Is(err, sql.ErrNoRows):
|
|
insertAsDefault = true
|
|
default:
|
|
return xerrors.Errorf("get default model config: %w", err)
|
|
}
|
|
}
|
|
|
|
if insertAsDefault {
|
|
if err := tx.UnsetDefaultChatModelConfigs(ctx); err != nil {
|
|
return xerrors.Errorf("unset default model configs: %w", err)
|
|
}
|
|
}
|
|
insertParams.IsDefault = insertAsDefault
|
|
|
|
config, err := tx.InsertChatModelConfig(ctx, insertParams)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
inserted = config
|
|
|
|
if err := ensureDefaultChatModelConfig(ctx, tx); err != nil {
|
|
return err
|
|
}
|
|
|
|
refreshedConfig, err := tx.GetChatModelConfigByID(ctx, inserted.ID)
|
|
if err != nil {
|
|
return xerrors.Errorf("refresh inserted chat model config: %w", err)
|
|
}
|
|
inserted = refreshedConfig
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
switch {
|
|
case database.IsUniqueViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat model config already exists.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
case database.IsForeignKeyViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat provider is not configured.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusCreated, convertChatModelConfig(inserted))
|
|
}
|
|
|
|
func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
existing, err := api.Database.GetChatModelConfigByID(ctx, modelConfigID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatModelConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
provider := existing.Provider
|
|
if strings.TrimSpace(req.Provider) != "" {
|
|
provider = normalizeChatProvider(req.Provider)
|
|
if provider == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid provider.",
|
|
Detail: chatProviderValidationDetail(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
model := existing.Model
|
|
if trimmed := strings.TrimSpace(req.Model); trimmed != "" {
|
|
model = trimmed
|
|
}
|
|
|
|
displayName := existing.DisplayName
|
|
if trimmed := strings.TrimSpace(req.DisplayName); trimmed != "" {
|
|
displayName = trimmed
|
|
}
|
|
|
|
enabled := existing.Enabled
|
|
if req.Enabled != nil {
|
|
enabled = *req.Enabled
|
|
}
|
|
isDefault := existing.IsDefault
|
|
if req.IsDefault != nil {
|
|
isDefault = *req.IsDefault
|
|
}
|
|
|
|
contextLimit := existing.ContextLimit
|
|
if req.ContextLimit != nil {
|
|
if *req.ContextLimit <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Context limit must be greater than zero.",
|
|
})
|
|
return
|
|
}
|
|
contextLimit = *req.ContextLimit
|
|
}
|
|
|
|
compressionThreshold, thresholdErr := normalizeChatCompressionThreshold(
|
|
req.CompressionThreshold,
|
|
existing.CompressionThreshold,
|
|
)
|
|
if thresholdErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid compression threshold.",
|
|
Detail: thresholdErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
modelConfigRaw := existing.Options
|
|
if req.ModelConfig != nil {
|
|
encodedModelConfig, modelConfigErr := marshalChatModelCallConfig(req.ModelConfig)
|
|
if modelConfigErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config.",
|
|
Detail: modelConfigErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
modelConfigRaw = encodedModelConfig
|
|
}
|
|
|
|
updateParams := database.UpdateChatModelConfigParams{
|
|
Provider: provider,
|
|
Model: model,
|
|
DisplayName: displayName,
|
|
Enabled: enabled,
|
|
IsDefault: isDefault,
|
|
ContextLimit: contextLimit,
|
|
CompressionThreshold: compressionThreshold,
|
|
Options: modelConfigRaw,
|
|
UpdatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
ID: existing.ID,
|
|
}
|
|
|
|
var updated database.ChatModelConfig
|
|
err = api.Database.InTx(func(tx database.Store) error {
|
|
setAsDefault := updateParams.IsDefault && !existing.IsDefault
|
|
if setAsDefault {
|
|
if err := tx.UnsetDefaultChatModelConfigs(ctx); err != nil {
|
|
return xerrors.Errorf("unset default model configs: %w", err)
|
|
}
|
|
}
|
|
|
|
_, err := tx.UpdateChatModelConfig(ctx, updateParams)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
excludeConfigID := uuid.Nil
|
|
if existing.IsDefault && req.IsDefault != nil && !*req.IsDefault {
|
|
excludeConfigID = existing.ID
|
|
}
|
|
|
|
if err := ensureDefaultChatModelConfig(
|
|
ctx,
|
|
tx,
|
|
excludeConfigID,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
refreshedConfig, err := tx.GetChatModelConfigByID(ctx, existing.ID)
|
|
if err != nil {
|
|
return xerrors.Errorf("refresh updated chat model config: %w", err)
|
|
}
|
|
updated = refreshedConfig
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
switch {
|
|
case database.IsUniqueViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat model config already exists.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
case database.IsForeignKeyViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat provider is not configured.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChatModelConfig(updated))
|
|
}
|
|
|
|
func (api *API) deleteChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if _, err := api.Database.GetChatModelConfigByID(ctx, modelConfigID); err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if err := api.Database.InTx(func(tx database.Store) error {
|
|
if err := tx.DeleteChatModelConfigByID(ctx, modelConfigID); err != nil {
|
|
return err
|
|
}
|
|
return ensureDefaultChatModelConfig(ctx, tx)
|
|
}, nil); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to delete chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func ensureDefaultChatModelConfig(
|
|
ctx context.Context,
|
|
tx database.Store,
|
|
excludedConfigIDs ...uuid.UUID,
|
|
) error {
|
|
_, err := tx.GetDefaultChatModelConfig(ctx)
|
|
switch {
|
|
case err == nil:
|
|
return nil
|
|
case !xerrors.Is(err, sql.ErrNoRows):
|
|
return xerrors.Errorf("get default model config: %w", err)
|
|
}
|
|
|
|
modelConfigs, err := tx.GetChatModelConfigs(ctx)
|
|
if err != nil {
|
|
return xerrors.Errorf("list chat model configs: %w", err)
|
|
}
|
|
if len(modelConfigs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
candidateConfig := modelConfigs[0]
|
|
excluded := make(map[uuid.UUID]struct{}, len(excludedConfigIDs))
|
|
for _, configID := range excludedConfigIDs {
|
|
if configID == uuid.Nil {
|
|
continue
|
|
}
|
|
excluded[configID] = struct{}{}
|
|
}
|
|
for _, config := range modelConfigs {
|
|
if _, skip := excluded[config.ID]; skip {
|
|
continue
|
|
}
|
|
candidateConfig = config
|
|
break
|
|
}
|
|
|
|
if err := tx.UnsetDefaultChatModelConfigs(ctx); err != nil {
|
|
return xerrors.Errorf("unset default model configs: %w", err)
|
|
}
|
|
|
|
params := chatModelConfigToUpdateParams(candidateConfig)
|
|
params.IsDefault = true
|
|
if _, err := tx.UpdateChatModelConfig(ctx, params); err != nil {
|
|
return xerrors.Errorf("set default model config: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func chatModelConfigToUpdateParams(
|
|
config database.ChatModelConfig,
|
|
) database.UpdateChatModelConfigParams {
|
|
return database.UpdateChatModelConfigParams{
|
|
Provider: config.Provider,
|
|
Model: config.Model,
|
|
DisplayName: config.DisplayName,
|
|
Enabled: config.Enabled,
|
|
IsDefault: config.IsDefault,
|
|
ContextLimit: config.ContextLimit,
|
|
CompressionThreshold: config.CompressionThreshold,
|
|
Options: config.Options,
|
|
UpdatedBy: uuid.NullUUID{},
|
|
ID: config.ID,
|
|
}
|
|
}
|
|
|
|
func parseChatProviderID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
|
providerID, err := uuid.Parse(chi.URLParam(r, "providerConfig"))
|
|
if err != nil {
|
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat provider ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return uuid.Nil, false
|
|
}
|
|
return providerID, true
|
|
}
|
|
|
|
func parseChatModelConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
|
modelConfigID, err := uuid.Parse(chi.URLParam(r, "modelConfig"))
|
|
if err != nil {
|
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat model config ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return uuid.Nil, false
|
|
}
|
|
return modelConfigID, true
|
|
}
|
|
|
|
func convertChatProviderConfig(
|
|
provider database.ChatProvider,
|
|
hasAPIKey bool,
|
|
source codersdk.ChatProviderConfigSource,
|
|
) codersdk.ChatProviderConfig {
|
|
displayName := strings.TrimSpace(provider.DisplayName)
|
|
if displayName == "" {
|
|
displayName = chatprovider.ProviderDisplayName(provider.Provider)
|
|
}
|
|
|
|
return codersdk.ChatProviderConfig{
|
|
ID: provider.ID,
|
|
Provider: provider.Provider,
|
|
DisplayName: displayName,
|
|
Enabled: provider.Enabled,
|
|
HasAPIKey: hasAPIKey,
|
|
BaseURL: strings.TrimSpace(provider.BaseUrl),
|
|
Source: source,
|
|
CreatedAt: provider.CreatedAt,
|
|
UpdatedAt: provider.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
func convertChatModelConfig(config database.ChatModelConfig) codersdk.ChatModelConfig {
|
|
return codersdk.ChatModelConfig{
|
|
ID: config.ID,
|
|
Provider: config.Provider,
|
|
Model: config.Model,
|
|
DisplayName: config.DisplayName,
|
|
Enabled: config.Enabled,
|
|
IsDefault: config.IsDefault,
|
|
ContextLimit: config.ContextLimit,
|
|
CompressionThreshold: config.CompressionThreshold,
|
|
ModelConfig: unmarshalChatModelCallConfig(config.Options),
|
|
CreatedAt: config.CreatedAt,
|
|
UpdatedAt: config.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
func marshalChatModelCallConfig(
|
|
modelConfig *codersdk.ChatModelCallConfig,
|
|
) (json.RawMessage, error) {
|
|
if modelConfig == nil {
|
|
return json.RawMessage("{}"), nil
|
|
}
|
|
|
|
encoded, err := json.Marshal(modelConfig)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("encode model config: %w", err)
|
|
}
|
|
return encoded, nil
|
|
}
|
|
|
|
func unmarshalChatModelCallConfig(
|
|
raw json.RawMessage,
|
|
) *codersdk.ChatModelCallConfig {
|
|
if len(raw) == 0 {
|
|
return nil
|
|
}
|
|
|
|
decoded := &codersdk.ChatModelCallConfig{}
|
|
if err := json.Unmarshal(raw, decoded); err != nil {
|
|
return nil
|
|
}
|
|
if isZeroChatModelCallConfig(decoded) {
|
|
return nil
|
|
}
|
|
return decoded
|
|
}
|
|
|
|
func isZeroChatModelCallConfig(config *codersdk.ChatModelCallConfig) bool {
|
|
if config == nil {
|
|
return true
|
|
}
|
|
|
|
return config.MaxOutputTokens == nil &&
|
|
config.Temperature == nil &&
|
|
config.TopP == nil &&
|
|
config.TopK == nil &&
|
|
config.PresencePenalty == nil &&
|
|
config.FrequencyPenalty == nil &&
|
|
isZeroChatModelProviderOptions(config.ProviderOptions)
|
|
}
|
|
|
|
func isZeroChatModelProviderOptions(options *codersdk.ChatModelProviderOptions) bool {
|
|
if options == nil {
|
|
return true
|
|
}
|
|
|
|
return options.OpenAI == nil &&
|
|
options.Anthropic == nil &&
|
|
options.Google == nil &&
|
|
options.OpenAICompat == nil &&
|
|
options.OpenRouter == nil &&
|
|
options.Vercel == nil
|
|
}
|
|
|
|
func normalizeChatProvider(provider string) string {
|
|
return chatprovider.NormalizeProvider(provider)
|
|
}
|
|
|
|
func normalizeChatProviderBaseURL(raw string) (string, error) {
|
|
trimmed := strings.TrimSpace(raw)
|
|
if trimmed == "" {
|
|
return "", nil
|
|
}
|
|
|
|
parsed, err := url.Parse(trimmed)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if parsed.Scheme == "" || parsed.Host == "" {
|
|
return "", xerrors.New("Base URL must be an absolute URL with scheme and host.")
|
|
}
|
|
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
|
return "", xerrors.New("Base URL scheme must be http or https.")
|
|
}
|
|
return parsed.String(), nil
|
|
}
|
|
|
|
func chatProviderValidationDetail() string {
|
|
return "Provider must be one of: " + strings.Join(chatprovider.SupportedProviders(), ", ") + "."
|
|
}
|
|
|
|
func chatProviderAPIKeysFromDeploymentValues(
|
|
deploymentValues *codersdk.DeploymentValues,
|
|
) chatprovider.ProviderAPIKeys {
|
|
_ = deploymentValues
|
|
// For now, we'll just manage configs in the UI.
|
|
// We should probably not be reusing the AI bridge configs anyways.
|
|
return chatprovider.ProviderAPIKeys{
|
|
// OpenAI: deploymentValues.AI.BridgeConfig.OpenAI.Key.Value(),
|
|
// Anthropic: deploymentValues.AI.BridgeConfig.Anthropic.Key.Value(),
|
|
// BaseURLByProvider: map[string]string{
|
|
// "openai": deploymentValues.AI.BridgeConfig.OpenAI.BaseURL.Value(),
|
|
// "anthropic": deploymentValues.AI.BridgeConfig.Anthropic.BaseURL.Value(),
|
|
// },
|
|
}
|
|
}
|
|
|
|
func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider database.ChatProvider) bool {
|
|
if strings.TrimSpace(provider.APIKey) != "" {
|
|
return true
|
|
}
|
|
if api.chatDaemon == nil {
|
|
return false
|
|
}
|
|
//nolint:gocritic // System context required to read enabled chat providers.
|
|
systemCtx := dbauthz.AsSystemRestricted(ctx)
|
|
|
|
enabledProviders, err := api.Database.GetEnabledChatProviders(
|
|
systemCtx,
|
|
)
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to resolve provider API keys",
|
|
slog.F("provider", provider.Provider),
|
|
slog.Error(err),
|
|
)
|
|
return false
|
|
}
|
|
|
|
enabledConfiguredProviders := make(
|
|
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
|
)
|
|
for _, configured := range enabledProviders {
|
|
enabledConfiguredProviders = append(
|
|
enabledConfiguredProviders, chatprovider.ConfiguredProvider{
|
|
Provider: configured.Provider,
|
|
APIKey: configured.APIKey,
|
|
BaseURL: configured.BaseUrl,
|
|
},
|
|
)
|
|
}
|
|
|
|
effectiveKeys := chatprovider.MergeProviderAPIKeys(
|
|
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
|
enabledConfiguredProviders,
|
|
)
|
|
return effectiveKeys.APIKey(provider.Provider) != ""
|
|
}
|