mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
89d470c112
The watchChats handler called SubscribeWithErr after websocket.Accept, creating a window where clients could trigger events before the subscription was active. Move the subscription before the accept so events accumulate in the pubsub internal queue and drain naturally once the encoder is ready. Fixes: https://linear.app/codercom/issue/CODAGT-480
7919 lines
246 KiB
Go
7919 lines
246 KiB
Go
package coderd
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"mime"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/shopspring/decimal"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"golang.org/x/sync/errgroup"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/agent/agentssh"
|
|
"github.com/coder/coder/v2/coderd/audit"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
|
"github.com/coder/coder/v2/coderd/dynamicparameters"
|
|
"github.com/coder/coder/v2/coderd/externalauth"
|
|
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
|
"github.com/coder/coder/v2/coderd/httpapi"
|
|
"github.com/coder/coder/v2/coderd/httpapi/httperror"
|
|
"github.com/coder/coder/v2/coderd/httpmw"
|
|
"github.com/coder/coder/v2/coderd/pubsub"
|
|
"github.com/coder/coder/v2/coderd/rbac"
|
|
"github.com/coder/coder/v2/coderd/rbac/policy"
|
|
"github.com/coder/coder/v2/coderd/searchquery"
|
|
"github.com/coder/coder/v2/coderd/tracing"
|
|
"github.com/coder/coder/v2/coderd/util/ptr"
|
|
"github.com/coder/coder/v2/coderd/util/xjson"
|
|
"github.com/coder/coder/v2/coderd/workspaceapps"
|
|
"github.com/coder/coder/v2/coderd/wsbuilder"
|
|
"github.com/coder/coder/v2/coderd/x/chatd"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
|
"github.com/coder/coder/v2/coderd/x/chatfiles"
|
|
"github.com/coder/coder/v2/coderd/x/gitsync"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/wsjson"
|
|
"github.com/coder/websocket"
|
|
)
|
|
|
|
// Limits and model-config keys used by the chat handlers.
const (
	// NOTE(review): presumably the number of items per streamed batch;
	// confirm at the use site.
	chatStreamBatchSize = 256

	// Keys under which the context limit and the context compression
	// threshold are stored. NOTE(review): presumably looked up in a model
	// config's options; confirm at the use site.
	chatContextLimitModelConfigKey                = "context_limit"
	chatContextCompressionThresholdModelConfigKey = "context_compression_threshold"

	// Default and inclusive bounds for the context compression threshold.
	defaultChatContextCompressionThreshold int32 = 70
	minChatContextCompressionThreshold     int32 = 0
	maxChatContextCompressionThreshold     int32 = 100

	// Upper bound on the system prompt length, in bytes.
	maxSystemPromptLenBytes = 128 << 10 // 128 KiB
)
|
|
|
|
// chatGitRef holds the branch, remote origin, and optional chat
|
|
// ID reported by the workspace agent during a git operation.
|
|
type chatGitRef struct {
|
|
Branch string
|
|
RemoteOrigin string
|
|
ChatID uuid.UUID
|
|
}
|
|
|
|
// chatRepositoryRef names a repository branch: the provider it lives on, its
// remote origin, the branch, and the owner/repo pair.
type chatRepositoryRef struct {
	Provider     string
	RemoteOrigin string
	Branch       string
	Owner, Repo  string
}
|
|
|
|
type chatDiffReference struct {
|
|
PullRequestURL string
|
|
RepositoryRef *chatRepositoryRef
|
|
}
|
|
|
|
func writeChatUsageLimitExceeded(
|
|
ctx context.Context,
|
|
rw http.ResponseWriter,
|
|
limitErr *chatd.UsageLimitExceededError,
|
|
) {
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.ChatUsageLimitExceededResponse{
|
|
Response: codersdk.Response{
|
|
Message: "Chat usage limit exceeded.",
|
|
},
|
|
SpentMicros: limitErr.ConsumedMicros,
|
|
LimitMicros: limitErr.LimitMicros,
|
|
ResetsAt: limitErr.PeriodEnd,
|
|
})
|
|
}
|
|
|
|
func maybeWriteLimitErr(ctx context.Context, rw http.ResponseWriter, err error) bool {
|
|
var limitErr *chatd.UsageLimitExceededError
|
|
if errors.As(err, &limitErr) {
|
|
writeChatUsageLimitExceeded(ctx, rw, limitErr)
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func publishChatTitleChange(logger slog.Logger, ps dbpubsub.Pubsub, chat database.Chat) {
|
|
if ps == nil {
|
|
return
|
|
}
|
|
event := codersdk.ChatWatchEvent{
|
|
Kind: codersdk.ChatWatchEventKindTitleChange,
|
|
Chat: db2sdk.Chat(chat, nil, nil),
|
|
}
|
|
payload, err := json.Marshal(event)
|
|
if err != nil {
|
|
logger.Error(context.Background(), "failed to marshal chat title change event",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
if err := ps.Publish(pubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
|
|
logger.Error(context.Background(), "failed to publish chat title change event",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.ChatConfigEventKind, entityID uuid.UUID) {
|
|
payload, err := json.Marshal(pubsub.ChatConfigEvent{
|
|
Kind: kind,
|
|
EntityID: entityID,
|
|
})
|
|
if err != nil {
|
|
logger.Error(context.Background(), "failed to marshal chat config event",
|
|
slog.F("kind", kind),
|
|
slog.F("entity_id", entityID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
if err := ps.Publish(pubsub.ChatConfigEventChannel, payload); err != nil {
|
|
logger.Error(context.Background(), "failed to publish chat config event",
|
|
slog.F("kind", kind),
|
|
slog.F("entity_id", entityID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Watch chat events for a user via WebSockets
|
|
// @ID watch-chat-events-for-a-user-via-websockets
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Success 200 {object} codersdk.ChatWatchEvent
|
|
// @Router /api/experimental/chats/watch [get]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
logger := api.Logger.Named("chat_watcher")
|
|
|
|
// Subscribe before accepting the websocket so the subscription
|
|
// is active when the client's Dial returns.
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
var (
|
|
encoder *json.Encoder
|
|
encoderReady = make(chan struct{})
|
|
// Capture before WebsocketNetConn reassigns ctx (data race).
|
|
ctxDone = ctx.Done()
|
|
)
|
|
|
|
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID),
|
|
pubsub.HandleChatWatchEvent(
|
|
func(cbCtx context.Context, payload codersdk.ChatWatchEvent, err error) {
|
|
if err != nil {
|
|
logger.Error(cbCtx, "chat watch event subscription error", slog.Error(err))
|
|
return
|
|
}
|
|
select {
|
|
case <-encoderReady:
|
|
case <-ctxDone:
|
|
return
|
|
case <-cbCtx.Done():
|
|
return
|
|
}
|
|
|
|
// encoderReady may close with encoder still nil on error paths.
|
|
if encoder == nil {
|
|
return
|
|
}
|
|
// The encoder is only written from the pubsub delivery
|
|
// goroutine, which processes messages serially. Do not
|
|
// add a second write path without synchronization.
|
|
if err := encoder.Encode(payload); err != nil {
|
|
logger.Debug(cbCtx, "failed to send chat watch event", slog.Error(err))
|
|
cancel()
|
|
return
|
|
}
|
|
},
|
|
))
|
|
if err != nil {
|
|
close(encoderReady)
|
|
logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err))
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to subscribe to chat events.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
defer cancelSubscribe()
|
|
|
|
conn, err := websocket.Accept(rw, r, nil)
|
|
if err != nil {
|
|
close(encoderReady)
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to open chat watch stream.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
_ = conn.CloseRead(context.Background())
|
|
|
|
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
|
defer wsNetConn.Close()
|
|
|
|
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
|
|
|
encoder = json.NewEncoder(wsNetConn)
|
|
close(encoderReady)
|
|
|
|
<-ctx.Done()
|
|
}
|
|
|
|
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
|
|
// the latest non-archived chat ID for each requested workspace.
|
|
// The query returns all matching chats and RBAC post-filters them;
|
|
// the handler then picks the latest per workspace in Go. This avoids
|
|
// the DISTINCT ON + post-filter bug where the sole candidate is
|
|
// silently dropped when the caller can't read it.
|
|
//
|
|
// TODO:
|
|
// 1. move aggregation to a SQL view with proper in-query authz so we
|
|
// can return a single row per workspace without this two-pass approach.
|
|
// 2. Restore the below router annotation and un-skip docs gen
|
|
// <at>Router /api/experimental/chats/by-workspace [post]
|
|
//
|
|
// @Summary Get latest chats by workspace IDs
|
|
// @ID get-latest-chats-by-workspace-ids
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Accept json
|
|
// @Produce json
|
|
// @Success 200
|
|
// @x-apidocgen {"skip": true}
|
|
func (api *API) chatsByWorkspace(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
idsParam := r.URL.Query().Get("workspace_ids")
|
|
if idsParam == "" {
|
|
httpapi.Write(ctx, rw, http.StatusOK, map[uuid.UUID]uuid.UUID{})
|
|
return
|
|
}
|
|
|
|
raw := strings.Split(idsParam, ",")
|
|
|
|
// maxWorkspaceIDs is coupled to DEFAULT_RECORDS_PER_PAGE (25) in
|
|
// site/src/components/PaginationWidget/utils.ts.
|
|
// If the page size changes, this limit should too.
|
|
const maxWorkspaceIDs = 25
|
|
if len(raw) > maxWorkspaceIDs {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("Too many workspace IDs, maximum is %d.", maxWorkspaceIDs),
|
|
})
|
|
return
|
|
}
|
|
|
|
workspaceIDs := make([]uuid.UUID, 0, len(raw))
|
|
for _, s := range raw {
|
|
id, err := uuid.Parse(strings.TrimSpace(s))
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("Invalid workspace ID %q: %s", s, err),
|
|
})
|
|
return
|
|
}
|
|
workspaceIDs = append(workspaceIDs, id)
|
|
}
|
|
|
|
chats, err := api.Database.GetChatsByWorkspaceIDs(ctx, workspaceIDs)
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
} else if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chats by workspace.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// The SQL orders by (workspace_id, updated_at DESC), so the first
|
|
// chat seen per workspace after RBAC filtering is the latest
|
|
// readable one.
|
|
result := make(map[uuid.UUID]uuid.UUID, len(chats))
|
|
for _, chat := range chats {
|
|
if chat.WorkspaceID.Valid {
|
|
if _, exists := result[chat.WorkspaceID.UUID]; !exists {
|
|
result[chat.WorkspaceID.UUID] = chat.ID
|
|
}
|
|
}
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, result)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary List chats
|
|
// @ID list-chats
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Param q query string false "Search query. Supports title:<substring> (case-insensitive, quote multi-word values), archived:bool, has_unread:bool, pr_status:<draft\|open\|merged\|closed> as repeated or comma-separated values, diff_url:<url> (quote values containing colons), pr:<number> (exact PR number match), repo:<owner/repo> (case-insensitive substring match against git remote origin or URL), pr_title:<text> (case-insensitive PR title substring). Bare terms are not supported; use title:<value> for title filtering."
|
|
// @Param label query string false "Filter by label as key:value. Repeat for multiple (AND logic)."
|
|
// @Success 200 {array} codersdk.Chat
|
|
// @Router /api/experimental/chats [get]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
paginationParams, ok := ParsePagination(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
queryStr := r.URL.Query().Get("q")
|
|
searchParams, errs := searchquery.Chats(queryStr)
|
|
if len(errs) > 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat search query.",
|
|
Validations: errs,
|
|
})
|
|
return
|
|
}
|
|
|
|
var labelFilter pqtype.NullRawMessage
|
|
if labelParams := r.URL.Query()["label"]; len(labelParams) > 0 {
|
|
labelMap := make(map[string]string, len(labelParams))
|
|
for _, lp := range labelParams {
|
|
key, value, ok := strings.Cut(lp, ":")
|
|
if !ok || key == "" || value == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("Invalid label filter: %q (expected format key:value, both must be non-empty)", lp),
|
|
})
|
|
return
|
|
}
|
|
labelMap[key] = value
|
|
}
|
|
labelsJSON, err := json.Marshal(labelMap)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to marshal label filter.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
labelFilter = pqtype.NullRawMessage{
|
|
RawMessage: labelsJSON,
|
|
Valid: true,
|
|
}
|
|
}
|
|
|
|
params := database.GetChatsParams{
|
|
OwnedOnly: true,
|
|
ViewerID: apiKey.UserID,
|
|
Archived: searchParams.Archived,
|
|
AfterID: paginationParams.AfterID,
|
|
LabelFilter: labelFilter,
|
|
DiffURL: searchParams.DiffURL,
|
|
TitleQuery: searchParams.TitleQuery,
|
|
HasUnread: searchParams.HasUnread,
|
|
PullRequestStatuses: searchParams.PullRequestStatuses,
|
|
PrNumber: searchParams.PrNumber,
|
|
RepoQuery: searchParams.RepoQuery,
|
|
PrTitleQuery: searchParams.PrTitleQuery,
|
|
// #nosec G115 - Pagination offsets are small and fit in int32
|
|
OffsetOpt: int32(paginationParams.Offset),
|
|
// #nosec G115 - Pagination limits are small and fit in int32
|
|
LimitOpt: int32(paginationParams.Limit),
|
|
}
|
|
|
|
chatRows, err := api.Database.GetChats(ctx, params)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chats.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Collect root chat IDs so we can fetch their children.
|
|
rootIDs := make([]uuid.UUID, len(chatRows))
|
|
for i, row := range chatRows {
|
|
rootIDs[i] = row.Chat.ID
|
|
}
|
|
|
|
// Embed children matching the caller's archive filter so
|
|
// sidebar views don't surface state-mismatched rows.
|
|
var childRows []database.GetChildChatsByParentIDsRow
|
|
if len(rootIDs) > 0 {
|
|
childRows, err = api.Database.GetChildChatsByParentIDs(ctx, database.GetChildChatsByParentIDsParams{
|
|
ParentIds: rootIDs,
|
|
Archived: searchParams.Archived,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list child chats.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
// Collect all chat objects (root + child) for diff status lookup.
|
|
allChats := make([]database.Chat, 0, len(chatRows)+len(childRows))
|
|
for _, row := range chatRows {
|
|
allChats = append(allChats, row.Chat)
|
|
}
|
|
for _, row := range childRows {
|
|
allChats = append(allChats, row.Chat)
|
|
}
|
|
|
|
diffStatusesByChatID, err := api.getChatDiffStatusesByChatID(ctx, allChats)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chats.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.ChatRowsWithChildren(chatRows, childRows, diffStatusesByChatID))
|
|
}
|
|
|
|
func (api *API) getChatDiffStatusesByChatID(
|
|
ctx context.Context,
|
|
chats []database.Chat,
|
|
) (map[uuid.UUID]database.ChatDiffStatus, error) {
|
|
if len(chats) == 0 {
|
|
return map[uuid.UUID]database.ChatDiffStatus{}, nil
|
|
}
|
|
|
|
chatIDs := make([]uuid.UUID, 0, len(chats))
|
|
for _, chat := range chats {
|
|
chatIDs = append(chatIDs, chat.ID)
|
|
}
|
|
|
|
statuses, err := api.Database.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("get chat diff statuses: %w", err)
|
|
}
|
|
|
|
statusesByChatID := make(map[uuid.UUID]database.ChatDiffStatus, len(statuses))
|
|
for _, status := range statuses {
|
|
statusesByChatID[status.ChatID] = status
|
|
}
|
|
return statusesByChatID, nil
|
|
}
|
|
|
|
func planModeToNullChatPlanMode(mode codersdk.ChatPlanMode) database.NullChatPlanMode {
|
|
if mode == "" {
|
|
return database.NullChatPlanMode{}
|
|
}
|
|
return database.NullChatPlanMode{
|
|
ChatPlanMode: database.ChatPlanMode(mode),
|
|
Valid: true,
|
|
}
|
|
}
|
|
|
|
func validateChatPlanMode(mode codersdk.ChatPlanMode) bool {
|
|
switch mode {
|
|
case "", codersdk.ChatPlanModePlan:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func parseChatModelOverride(raw string) (*uuid.UUID, error) {
|
|
trimmed := strings.TrimSpace(raw)
|
|
if trimmed == "" {
|
|
//nolint:nilnil // Empty site-config value means the override is unset.
|
|
return nil, nil
|
|
}
|
|
modelConfigID, err := uuid.Parse(trimmed)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse chat model override: %w", err)
|
|
}
|
|
return &modelConfigID, nil
|
|
}
|
|
|
|
func formatChatModelOverride(id *uuid.UUID) string {
|
|
if id == nil {
|
|
return ""
|
|
}
|
|
return id.String()
|
|
}
|
|
|
|
func lookupEnabledChatModelConfigByID(
|
|
ctx context.Context,
|
|
db database.Store,
|
|
id uuid.UUID,
|
|
) (database.ChatModelConfig, error) {
|
|
//nolint:gocritic // Validation lookup uses AsChatd to check model
|
|
// availability independently of the caller's read permissions.
|
|
return db.GetEnabledChatModelConfigByID(dbauthz.AsChatd(ctx), id)
|
|
}
|
|
|
|
func validateChatModelOverrideID(
|
|
ctx context.Context,
|
|
db database.Store,
|
|
id *uuid.UUID,
|
|
) (int, *codersdk.Response) {
|
|
if id == nil {
|
|
return 0, nil
|
|
}
|
|
if *id == uuid.Nil {
|
|
return http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Invalid model_config_id.",
|
|
}
|
|
}
|
|
_, err := lookupEnabledChatModelConfigByID(ctx, db, *id)
|
|
if err == nil {
|
|
return 0, nil
|
|
}
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Invalid model_config_id.",
|
|
}
|
|
}
|
|
return http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Internal error validating model config override.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
|
|
func (api *API) getChatModelOverrideConfig(
|
|
ctx context.Context,
|
|
settingName string,
|
|
getter func(context.Context) (string, error),
|
|
) (*uuid.UUID, bool, error) {
|
|
raw, err := getter(ctx)
|
|
if err != nil {
|
|
return nil, false, xerrors.Errorf("get %s model override: %w", settingName, err)
|
|
}
|
|
id, err := parseChatModelOverride(raw)
|
|
if err != nil {
|
|
// Degrade malformed values to unset so the admin settings page
|
|
// remains accessible and the bad value can be cleared.
|
|
api.Logger.Warn(
|
|
ctx,
|
|
"malformed model override in site config, treating as unset",
|
|
slog.F("setting", settingName),
|
|
slog.F("raw_value", raw),
|
|
slog.Error(err),
|
|
)
|
|
return nil, true, nil
|
|
}
|
|
return id, false, nil
|
|
}
|
|
|
|
func parseChatModelOverrideContext(raw string) (codersdk.ChatModelOverrideContext, error) {
|
|
overrideContext := codersdk.ChatModelOverrideContext(raw)
|
|
if overrideContext.Valid() {
|
|
return overrideContext, nil
|
|
}
|
|
return "", xerrors.Errorf("unknown chat model override context %q", raw)
|
|
}
|
|
|
|
// chatModelOverrideSiteConfig bundles the accessors for one model override
// setting together with a label used in messages.
type chatModelOverrideSiteConfig struct {
	// label names the setting in error and log messages.
	label string
	// getter reads the stored override value.
	getter func(context.Context) (string, error)
	// upsert stores a new override value.
	upsert func(context.Context, string) error
}
|
|
|
|
func (api *API) chatModelOverrideSiteConfig(
|
|
overrideContext codersdk.ChatModelOverrideContext,
|
|
) (chatModelOverrideSiteConfig, error) {
|
|
switch overrideContext {
|
|
case codersdk.ChatModelOverrideContextGeneral:
|
|
return chatModelOverrideSiteConfig{
|
|
label: "general",
|
|
getter: api.Database.GetChatGeneralModelOverride,
|
|
upsert: api.Database.UpsertChatGeneralModelOverride,
|
|
}, nil
|
|
case codersdk.ChatModelOverrideContextExplore:
|
|
return chatModelOverrideSiteConfig{
|
|
label: "explore",
|
|
getter: api.Database.GetChatExploreModelOverride,
|
|
upsert: api.Database.UpsertChatExploreModelOverride,
|
|
}, nil
|
|
case codersdk.ChatModelOverrideContextTitleGeneration:
|
|
return chatModelOverrideSiteConfig{
|
|
label: "title generation",
|
|
getter: api.Database.GetChatTitleGenerationModelOverride,
|
|
upsert: api.Database.UpsertChatTitleGenerationModelOverride,
|
|
}, nil
|
|
default:
|
|
return chatModelOverrideSiteConfig{}, xerrors.Errorf(
|
|
"unknown chat model override context %q",
|
|
overrideContext,
|
|
)
|
|
}
|
|
}
|
|
|
|
func (api *API) readChatModelOverrideConfig(
|
|
ctx context.Context,
|
|
overrideContext codersdk.ChatModelOverrideContext,
|
|
) (*uuid.UUID, bool, string, error) {
|
|
siteConfig, err := api.chatModelOverrideSiteConfig(overrideContext)
|
|
if err != nil {
|
|
return nil, false, "", err
|
|
}
|
|
id, isMalformed, err := api.getChatModelOverrideConfig(ctx, siteConfig.label, siteConfig.getter)
|
|
return id, isMalformed, siteConfig.label, err
|
|
}
|
|
|
|
func (api *API) upsertChatModelOverrideConfig(
|
|
ctx context.Context,
|
|
overrideContext codersdk.ChatModelOverrideContext,
|
|
modelConfigID *uuid.UUID,
|
|
) (string, error) {
|
|
siteConfig, err := api.chatModelOverrideSiteConfig(overrideContext)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return siteConfig.label, siteConfig.upsert(ctx, formatChatModelOverride(modelConfigID))
|
|
}
|
|
|
|
var chatPersonalModelOverrideContexts = []codersdk.ChatPersonalModelOverrideContext{
|
|
codersdk.ChatPersonalModelOverrideContextRoot,
|
|
codersdk.ChatPersonalModelOverrideContextGeneral,
|
|
codersdk.ChatPersonalModelOverrideContextExplore,
|
|
}
|
|
|
|
func parseChatPersonalModelOverrideContext(raw string) (codersdk.ChatPersonalModelOverrideContext, bool) {
|
|
c := codersdk.ChatPersonalModelOverrideContext(raw)
|
|
return c, slices.Contains(chatPersonalModelOverrideContexts, c)
|
|
}
|
|
|
|
func chatPersonalModelOverrideContextsJoined() string {
|
|
values := make([]string, 0, len(chatPersonalModelOverrideContexts))
|
|
for _, overrideContext := range chatPersonalModelOverrideContexts {
|
|
values = append(values, string(overrideContext))
|
|
}
|
|
return strings.Join(values, ", ")
|
|
}
|
|
|
|
func defaultChatPersonalModelOverrideMode(
|
|
overrideContext codersdk.ChatPersonalModelOverrideContext,
|
|
) codersdk.ChatPersonalModelOverrideMode {
|
|
if overrideContext == codersdk.ChatPersonalModelOverrideContextRoot {
|
|
return codersdk.ChatPersonalModelOverrideModeChatDefault
|
|
}
|
|
return codersdk.ChatPersonalModelOverrideModeDeploymentDefault
|
|
}
|
|
|
|
func parseChatPersonalModelOverrideValue(
|
|
raw string,
|
|
overrideContext codersdk.ChatPersonalModelOverrideContext,
|
|
) chatd.ParsedChatPersonalModelOverride {
|
|
defaultMode := defaultChatPersonalModelOverrideMode(overrideContext)
|
|
parsed := chatd.ParseChatPersonalModelOverride(raw, defaultMode)
|
|
if overrideContext == codersdk.ChatPersonalModelOverrideContextRoot &&
|
|
parsed.Mode == codersdk.ChatPersonalModelOverrideModeDeploymentDefault {
|
|
return chatd.ParsedChatPersonalModelOverride{
|
|
Mode: defaultMode,
|
|
Malformed: true,
|
|
}
|
|
}
|
|
return parsed
|
|
}
|
|
|
|
func formatChatPersonalModelOverrideValue(
|
|
mode codersdk.ChatPersonalModelOverrideMode,
|
|
modelConfigID string,
|
|
) string {
|
|
if mode == codersdk.ChatPersonalModelOverrideModeModel {
|
|
return string(mode) + ":" + strings.TrimSpace(modelConfigID)
|
|
}
|
|
return string(mode)
|
|
}
|
|
|
|
func chatPersonalModelOverrideResponse(
|
|
overrideContext codersdk.ChatPersonalModelOverrideContext,
|
|
raw string,
|
|
isSet bool,
|
|
) codersdk.ChatPersonalModelOverride {
|
|
parsed := parseChatPersonalModelOverrideValue(raw, overrideContext)
|
|
modelConfigID := ""
|
|
if parsed.Mode == codersdk.ChatPersonalModelOverrideModeModel {
|
|
modelConfigID = parsed.ModelConfigID.String()
|
|
}
|
|
return codersdk.ChatPersonalModelOverride{
|
|
Context: overrideContext,
|
|
Mode: parsed.Mode,
|
|
ModelConfigID: modelConfigID,
|
|
IsSet: isSet,
|
|
IsMalformed: parsed.Malformed,
|
|
}
|
|
}
|
|
|
|
func (api *API) chatPersonalModelOverrideDeploymentDefaultResponse(
|
|
ctx context.Context,
|
|
overrideContext codersdk.ChatModelOverrideContext,
|
|
) (codersdk.ChatModelOverrideResponse, error) {
|
|
// The deployment defaults are global chat configuration, not user-owned
|
|
// resources. Users may read these values here because the personal settings
|
|
// UI must explain what deployment_default resolves to.
|
|
//nolint:gocritic // System context is required to read deployment config.
|
|
modelConfigID, isMalformed, _, err := api.readChatModelOverrideConfig(
|
|
dbauthz.AsSystemRestricted(ctx),
|
|
overrideContext,
|
|
)
|
|
if err != nil {
|
|
return codersdk.ChatModelOverrideResponse{}, err
|
|
}
|
|
return codersdk.ChatModelOverrideResponse{
|
|
Context: overrideContext,
|
|
ModelConfigID: formatChatModelOverride(modelConfigID),
|
|
IsMalformed: isMalformed,
|
|
}, nil
|
|
}
|
|
|
|
func (api *API) chatPersonalModelOverrideDeploymentDefaults(
|
|
ctx context.Context,
|
|
) (codersdk.ChatPersonalModelOverrideDeploymentDefaults, error) {
|
|
general, err := api.chatPersonalModelOverrideDeploymentDefaultResponse(
|
|
ctx,
|
|
codersdk.ChatModelOverrideContextGeneral,
|
|
)
|
|
if err != nil {
|
|
return codersdk.ChatPersonalModelOverrideDeploymentDefaults{}, err
|
|
}
|
|
explore, err := api.chatPersonalModelOverrideDeploymentDefaultResponse(
|
|
ctx,
|
|
codersdk.ChatModelOverrideContextExplore,
|
|
)
|
|
if err != nil {
|
|
return codersdk.ChatPersonalModelOverrideDeploymentDefaults{}, err
|
|
}
|
|
return codersdk.ChatPersonalModelOverrideDeploymentDefaults{
|
|
General: general,
|
|
Explore: explore,
|
|
}, nil
|
|
}
|
|
|
|
type userChatModelAvailability struct {
|
|
configuredProviders []chatprovider.ConfiguredProvider
|
|
configuredModels []chatprovider.ConfiguredModel
|
|
enabledModels []database.ChatModelConfig
|
|
providerStatus map[string]chatprovider.ProviderAvailability
|
|
providerStatusByID map[uuid.UUID]chatprovider.ProviderAvailability
|
|
enabledProviderNames map[string]struct{}
|
|
enabledProviderIDs map[uuid.UUID]struct{}
|
|
}
|
|
|
|
// chatModelConfigUnavailableReason explains why a model config cannot be
// used. The empty value means the model config is available. Callers must
// check the error returned by userCanUseChatModelConfig before interpreting
// this value.
type chatModelConfigUnavailableReason string

// Possible chatModelConfigUnavailableReason values.
const (
	// chatModelConfigAvailable is the zero value: nothing prevents use.
	chatModelConfigAvailable = chatModelConfigUnavailableReason("")
	// The model config does not exist or is not enabled.
	chatModelConfigUnavailableModelNotFoundOrDisabled = chatModelConfigUnavailableReason("model_not_found_or_disabled")
	// The model's provider is not among the enabled providers.
	chatModelConfigUnavailableProviderDisabled = chatModelConfigUnavailableReason("provider_disabled")
	// The provider is enabled but not available to the user.
	chatModelConfigUnavailableCredentialsMissing = chatModelConfigUnavailableReason("credentials_missing")
)
|
|
|
|
// getUserChatProviderAvailability returns the enabled chat providers and models
|
|
// the user can access. Deployment-level configuration is read as chatd, while
|
|
// user key lookups still use the caller's authorization context.
|
|
func (api *API) getUserChatProviderAvailability(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
) (userChatModelAvailability, error) {
|
|
//nolint:gocritic // Chatd context is required to read enabled chat config.
|
|
chatdCtx := dbauthz.AsChatd(ctx)
|
|
enabledProviders, err := api.Database.GetAIProviders(chatdCtx, database.GetAIProvidersParams{})
|
|
if err != nil {
|
|
return userChatModelAvailability{}, err
|
|
}
|
|
enabledModels, err := api.Database.GetEnabledChatModelConfigs(chatdCtx)
|
|
if err != nil {
|
|
return userChatModelAvailability{}, err
|
|
}
|
|
|
|
configuredProviders, err := api.configuredProvidersFromAIProviders(chatdCtx, enabledProviders)
|
|
if err != nil {
|
|
return userChatModelAvailability{}, err
|
|
}
|
|
availability := userChatModelAvailability{
|
|
configuredProviders: configuredProviders,
|
|
configuredModels: make([]chatprovider.ConfiguredModel, 0, len(enabledModels)),
|
|
enabledModels: enabledModels,
|
|
enabledProviderNames: make(map[string]struct{}, len(enabledProviders)),
|
|
enabledProviderIDs: make(map[uuid.UUID]struct{}, len(enabledProviders)),
|
|
providerStatusByID: make(map[uuid.UUID]chatprovider.ProviderAvailability, len(enabledProviders)),
|
|
}
|
|
for _, configuredProvider := range configuredProviders {
|
|
normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider)
|
|
if normalizedProvider != "" {
|
|
availability.enabledProviderNames[normalizedProvider] = struct{}{}
|
|
}
|
|
if configuredProvider.ProviderID != uuid.Nil {
|
|
availability.enabledProviderIDs[configuredProvider.ProviderID] = struct{}{}
|
|
}
|
|
}
|
|
userKeys := []chatprovider.UserProviderKey{}
|
|
if api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value() {
|
|
userKeyRows, err := api.Database.GetUserAIProviderKeysByUserID(ctx, userID)
|
|
if err != nil {
|
|
return userChatModelAvailability{}, err
|
|
}
|
|
userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows))
|
|
for _, userKey := range userKeyRows {
|
|
userKeys = append(userKeys, chatprovider.UserProviderKey{
|
|
ChatProviderID: userKey.AIProviderID,
|
|
APIKey: userKey.APIKey,
|
|
})
|
|
}
|
|
}
|
|
|
|
fallbackKeys := ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues)
|
|
mergeProviderStatus := func(
|
|
statuses map[string]chatprovider.ProviderAvailability,
|
|
normalizedProvider string,
|
|
status chatprovider.ProviderAvailability,
|
|
) {
|
|
current, ok := statuses[normalizedProvider]
|
|
if !ok || (!current.Available && status.Available) {
|
|
statuses[normalizedProvider] = status
|
|
}
|
|
}
|
|
|
|
providerStatusByType := make(map[string]chatprovider.ProviderAvailability, len(availability.configuredProviders))
|
|
for _, configuredProvider := range availability.configuredProviders {
|
|
normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider)
|
|
if normalizedProvider == "" {
|
|
continue
|
|
}
|
|
_, providerStatus := chatprovider.ResolveUserProviderKeys(
|
|
fallbackKeys,
|
|
[]chatprovider.ConfiguredProvider{configuredProvider},
|
|
userKeys,
|
|
)
|
|
status, ok := providerStatus[normalizedProvider]
|
|
if !ok {
|
|
continue
|
|
}
|
|
if configuredProvider.ProviderID != uuid.Nil {
|
|
availability.providerStatusByID[configuredProvider.ProviderID] = status
|
|
}
|
|
mergeProviderStatus(providerStatusByType, normalizedProvider, status)
|
|
}
|
|
|
|
modelStatusByType := make(map[string]chatprovider.ProviderAvailability, len(enabledModels))
|
|
for _, model := range enabledModels {
|
|
normalizedProvider := chatprovider.NormalizeProvider(model.Provider)
|
|
if normalizedProvider == "" {
|
|
continue
|
|
}
|
|
if model.AIProviderID.Valid {
|
|
status, ok := availability.providerStatusByID[model.AIProviderID.UUID]
|
|
if ok {
|
|
mergeProviderStatus(modelStatusByType, normalizedProvider, status)
|
|
}
|
|
continue
|
|
}
|
|
if status, ok := providerStatusByType[normalizedProvider]; ok {
|
|
mergeProviderStatus(modelStatusByType, normalizedProvider, status)
|
|
}
|
|
}
|
|
availability.providerStatus = providerStatusByType
|
|
for provider, status := range modelStatusByType {
|
|
availability.providerStatus[provider] = status
|
|
}
|
|
|
|
for _, model := range enabledModels {
|
|
normalizedProvider := chatprovider.NormalizeProvider(model.Provider)
|
|
if model.AIProviderID.Valid {
|
|
status, ok := availability.providerStatusByID[model.AIProviderID.UUID]
|
|
if !ok {
|
|
continue
|
|
}
|
|
if aggregateStatus, ok := availability.providerStatus[normalizedProvider]; ok && aggregateStatus.Available && !status.Available {
|
|
continue
|
|
}
|
|
}
|
|
availability.configuredModels = append(availability.configuredModels, chatprovider.ConfiguredModel{
|
|
Provider: model.Provider,
|
|
Model: model.Model,
|
|
DisplayName: model.DisplayName,
|
|
})
|
|
}
|
|
return availability, nil
|
|
}
|
|
|
|
// userCanUseChatModelConfig returns chatModelConfigAvailable when the user can
|
|
// use the model config. If err is non-nil, callers must ignore the returned
|
|
// reason because it may be the zero-value availability sentinel.
|
|
func (api *API) userCanUseChatModelConfig(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
modelConfigID uuid.UUID,
|
|
) (chatModelConfigUnavailableReason, error) {
|
|
if modelConfigID == uuid.Nil {
|
|
return chatModelConfigUnavailableModelNotFoundOrDisabled, nil
|
|
}
|
|
//nolint:gocritic // Non-admin users need deployment config validation.
|
|
model, err := api.Database.GetChatModelConfigByID(
|
|
dbauthz.AsSystemRestricted(ctx),
|
|
modelConfigID,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) || httpapi.Is404Error(err) {
|
|
return chatModelConfigUnavailableModelNotFoundOrDisabled, nil
|
|
}
|
|
return chatModelConfigAvailable, err
|
|
}
|
|
if !model.Enabled {
|
|
return chatModelConfigUnavailableModelNotFoundOrDisabled, nil
|
|
}
|
|
|
|
availability, err := api.getUserChatProviderAvailability(ctx, userID)
|
|
if err != nil {
|
|
return chatModelConfigAvailable, err
|
|
}
|
|
if model.AIProviderID.Valid {
|
|
providerID := model.AIProviderID.UUID
|
|
if _, ok := availability.enabledProviderIDs[providerID]; !ok {
|
|
return chatModelConfigUnavailableProviderDisabled, nil
|
|
}
|
|
providerStatus, ok := availability.providerStatusByID[providerID]
|
|
if !ok {
|
|
return chatModelConfigUnavailableProviderDisabled, nil
|
|
}
|
|
if !providerStatus.Available {
|
|
return chatModelConfigUnavailableCredentialsMissing, nil
|
|
}
|
|
return chatModelConfigAvailable, nil
|
|
}
|
|
provider, _, err := chatprovider.ResolveModelWithProviderHint(model.Model, model.Provider)
|
|
if err != nil {
|
|
return chatModelConfigUnavailableProviderDisabled, nil
|
|
}
|
|
if _, ok := availability.enabledProviderNames[provider]; !ok {
|
|
return chatModelConfigUnavailableProviderDisabled, nil
|
|
}
|
|
providerStatus, ok := availability.providerStatus[provider]
|
|
if !ok {
|
|
return chatModelConfigUnavailableProviderDisabled, nil
|
|
}
|
|
if !providerStatus.Available {
|
|
return chatModelConfigUnavailableCredentialsMissing, nil
|
|
}
|
|
return chatModelConfigAvailable, nil
|
|
}
|
|
|
|
func (api *API) validateUserChatModelConfigAvailable(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
modelConfigID uuid.UUID,
|
|
) (int, *codersdk.Response) {
|
|
reason, err := api.userCanUseChatModelConfig(ctx, userID, modelConfigID)
|
|
if err != nil {
|
|
return http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Internal error validating model config override.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
switch reason {
|
|
case chatModelConfigAvailable:
|
|
return 0, nil
|
|
case chatModelConfigUnavailableModelNotFoundOrDisabled:
|
|
return http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Invalid model_config_id: model config not found or disabled.",
|
|
}
|
|
case chatModelConfigUnavailableCredentialsMissing:
|
|
return http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Invalid model_config_id: provider credentials unavailable for this model.",
|
|
}
|
|
case chatModelConfigUnavailableProviderDisabled:
|
|
return http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Invalid model_config_id: provider is not enabled for this model.",
|
|
}
|
|
default:
|
|
api.Logger.Warn(ctx,
|
|
"unknown chat model config availability reason",
|
|
slog.F("user_id", userID),
|
|
slog.F("model_config_id", modelConfigID),
|
|
slog.F("reason", reason),
|
|
)
|
|
return http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Invalid model_config_id.",
|
|
}
|
|
}
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Create chat
|
|
// @ID create-chat
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Accept json
|
|
// @Produce json
|
|
// @Param request body codersdk.CreateChatRequest true "Create chat request"
|
|
// @Success 201 {object} codersdk.Chat
|
|
// @Router /api/experimental/chats [post]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
// Cap the raw request body to prevent excessive memory use
|
|
// from large dynamic tool schemas.
|
|
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
|
|
|
var req codersdk.CreateChatRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
aReq, commitAudit := audit.InitRequest[database.Chat](rw, &audit.RequestParams{
|
|
Audit: *api.Auditor.Load(),
|
|
Log: api.Logger,
|
|
Request: r,
|
|
Action: database.AuditActionCreate,
|
|
OrganizationID: req.OrganizationID,
|
|
})
|
|
defer commitAudit()
|
|
|
|
// Validate organization membership.
|
|
if req.OrganizationID == uuid.Nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "organization_id is required.",
|
|
})
|
|
return
|
|
}
|
|
isMember, err := httpmw.UserAuthorization(ctx).HasOrganizationMembership(req.OrganizationID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to validate organization membership.",
|
|
Detail: xerrors.Errorf("check organization membership: %w", err).Error(),
|
|
})
|
|
return
|
|
}
|
|
if !isMember {
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
|
Message: "You are not a member of the specified organization.",
|
|
})
|
|
return
|
|
}
|
|
// NOTE: This authorize check is intentionally placed after request
|
|
// parsing because we need req.OrganizationID to scope the RBAC check
|
|
// to the correct org. The request body is bounded by MaxBytesReader
|
|
// above, limiting the cost of parsing before rejection.
|
|
if !api.Authorize(r, policy.ActionCreate, rbac.ResourceChat.WithOwner(apiKey.UserID.String()).InOrg(req.OrganizationID)) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
// Validate per-chat system prompt length.
|
|
const maxSystemPromptLen = 10000
|
|
if len(req.SystemPrompt) > maxSystemPromptLen {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "System prompt exceeds maximum length.",
|
|
Detail: fmt.Sprintf("System prompt must be at most %d characters, got %d.", maxSystemPromptLen, len(req.SystemPrompt)),
|
|
})
|
|
return
|
|
}
|
|
contentBlocks, titleSource, fileIDs, inputError := createChatInputFromRequest(ctx, api.Database, req)
|
|
if inputError != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError)
|
|
return
|
|
}
|
|
|
|
workspaceSelection, validationStatus, validationError := api.validateCreateChatWorkspaceSelection(ctx, r, req)
|
|
if validationError != nil {
|
|
httpapi.Write(ctx, rw, validationStatus, *validationError)
|
|
return
|
|
}
|
|
|
|
title := chatTitleFromMessage(titleSource)
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
modelConfigID, modelConfigStatus, modelConfigError := api.resolveCreateChatModelConfigID(ctx, apiKey.UserID, req)
|
|
if modelConfigError != nil {
|
|
httpapi.Write(ctx, rw, modelConfigStatus, *modelConfigError)
|
|
return
|
|
}
|
|
|
|
if !validateChatPlanMode(req.PlanMode) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid plan_mode value.",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Validate MCP server IDs exist.
|
|
if len(req.MCPServerIDs) > 0 {
|
|
//nolint:gocritic // Need to validate MCP server IDs exist.
|
|
existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), req.MCPServerIDs)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to validate MCP server IDs.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if len(existingConfigs) != len(req.MCPServerIDs) {
|
|
found := make(map[uuid.UUID]struct{}, len(existingConfigs))
|
|
for _, c := range existingConfigs {
|
|
found[c.ID] = struct{}{}
|
|
}
|
|
var missing []string
|
|
for _, id := range req.MCPServerIDs {
|
|
if _, ok := found[id]; !ok {
|
|
missing = append(missing, id.String())
|
|
}
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "One or more MCP server IDs are invalid.",
|
|
Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
mcpServerIDs := req.MCPServerIDs
|
|
if mcpServerIDs == nil {
|
|
mcpServerIDs = []uuid.UUID{}
|
|
}
|
|
|
|
labels := req.Labels
|
|
if labels == nil {
|
|
labels = map[string]string{}
|
|
}
|
|
if errs := httpapi.ValidateChatLabels(labels); len(errs) > 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid labels.",
|
|
Validations: errs,
|
|
})
|
|
return
|
|
}
|
|
|
|
if len(req.UnsafeDynamicTools) > 250 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Too many dynamic tools.",
|
|
Detail: "Maximum 250 dynamic tools per chat.",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Validate that dynamic tool names are non-empty and unique
|
|
// within the list. Name collision with built-in tools is
|
|
// checked at chatloop time when the full tool set is known.
|
|
if len(req.UnsafeDynamicTools) > 0 {
|
|
seenNames := make(map[string]struct{}, len(req.UnsafeDynamicTools))
|
|
for _, dt := range req.UnsafeDynamicTools {
|
|
if dt.Name == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Dynamic tool name must not be empty.",
|
|
})
|
|
return
|
|
}
|
|
if _, exists := seenNames[dt.Name]; exists {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Duplicate dynamic tool name.",
|
|
Detail: fmt.Sprintf("Tool %q appears more than once.", dt.Name),
|
|
})
|
|
return
|
|
}
|
|
seenNames[dt.Name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
var dynamicToolsJSON json.RawMessage
|
|
if len(req.UnsafeDynamicTools) > 0 {
|
|
var err error
|
|
dynamicToolsJSON, err = json.Marshal(req.UnsafeDynamicTools)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to marshal dynamic tools.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
clientType := database.ChatClientTypeApi
|
|
if req.ClientType != "" {
|
|
clientType = database.ChatClientType(req.ClientType)
|
|
if !clientType.Valid() {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid client_type.",
|
|
Detail: fmt.Sprintf("got %q, want one of %v", req.ClientType, database.AllChatClientTypeValues()),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: req.OrganizationID,
|
|
OwnerID: apiKey.UserID,
|
|
WorkspaceID: workspaceSelection.WorkspaceID,
|
|
Title: title,
|
|
ModelConfigID: modelConfigID,
|
|
PlanMode: planModeToNullChatPlanMode(req.PlanMode),
|
|
ClientType: clientType,
|
|
SystemPrompt: req.SystemPrompt,
|
|
InitialUserContent: contentBlocks,
|
|
MCPServerIDs: mcpServerIDs,
|
|
Labels: labels,
|
|
DynamicTools: dynamicToolsJSON,
|
|
// IMPORTANT: users can only create root chats at the time of writing.
|
|
ParentChatID: uuid.NullUUID{},
|
|
})
|
|
if err != nil {
|
|
if maybeWriteLimitErr(ctx, rw, err) {
|
|
return
|
|
}
|
|
if database.IsForeignKeyViolation(
|
|
err,
|
|
database.ForeignKeyChatsLastModelConfigID,
|
|
database.ForeignKeyChatMessagesModelConfigID,
|
|
) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if dbauthz.IsNotAuthorizedError(err) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
aReq.New = chat
|
|
|
|
if chat.ParentChatID.Valid {
|
|
// Should not be possible. If we get here, something is very wrong. Bail.
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Developer error: ParentChatID got set somehow in api.postChats. This should never happen.",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Link any user-uploaded files referenced in the initial
|
|
// message to this newly created chat (best-effort; cap
|
|
// enforced in SQL).
|
|
unlinked, capExceeded := api.linkFilesToChat(ctx, chat.ID, fileIDs)
|
|
|
|
// Re-read the chat so the response reflects the authoritative
|
|
// database state (file links are deduped in the join table).
|
|
chat, err = api.Database.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to read back chat after creation.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
aReq.New = chat
|
|
|
|
chatFiles := api.fetchChatFileMetadata(ctx, chat.ID)
|
|
response := db2sdk.Chat(chat, nil, chatFiles)
|
|
if len(unlinked) > 0 {
|
|
if capExceeded {
|
|
response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked)))
|
|
} else {
|
|
response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked)))
|
|
}
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusCreated, response)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary List chat models
|
|
// @ID list-chat-models
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Success 200 {object} codersdk.ChatModelsResponse
|
|
// @Router /api/experimental/chats/models [get]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
availability, err := api.getUserChatProviderAvailability(ctx, apiKey.UserID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to load chat model configuration.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
catalog := chatprovider.NewModelCatalog()
|
|
var response codersdk.ChatModelsResponse
|
|
if configured, ok := catalog.ListConfiguredModels(
|
|
availability.configuredProviders,
|
|
availability.configuredModels,
|
|
availability.providerStatus,
|
|
availability.enabledProviderNames,
|
|
); ok {
|
|
response = configured
|
|
} else {
|
|
response = catalog.ListConfiguredProviderAvailability(
|
|
availability.providerStatus,
|
|
availability.enabledProviderNames,
|
|
)
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
// Default date range: last 30 days.
|
|
now := time.Now()
|
|
defaultStart := now.AddDate(0, 0, -30)
|
|
|
|
qp := r.URL.Query()
|
|
p := httpapi.NewQueryParamParser()
|
|
startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339)
|
|
endDate := p.Time(qp, now, "end_date", time.RFC3339)
|
|
p.ErrorExcessParams(qp)
|
|
if len(p.Errors) > 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid query parameters.",
|
|
Validations: p.Errors,
|
|
})
|
|
return
|
|
}
|
|
|
|
targetUser := httpmw.UserParam(r)
|
|
if targetUser.ID != apiKey.UserID && !api.Authorize(r, policy.ActionRead, rbac.ResourceChat.WithOwner(targetUser.ID.String())) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
summary, err := api.Database.GetChatCostSummary(ctx, database.GetChatCostSummaryParams{
|
|
OwnerID: targetUser.ID,
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
})
|
|
if err != nil {
|
|
if dbauthz.IsNotAuthorizedError(err) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
httpapi.InternalServerError(rw, err)
|
|
return
|
|
}
|
|
|
|
byModel, err := api.Database.GetChatCostPerModel(ctx, database.GetChatCostPerModelParams{
|
|
OwnerID: targetUser.ID,
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
})
|
|
if err != nil {
|
|
if dbauthz.IsNotAuthorizedError(err) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
httpapi.InternalServerError(rw, err)
|
|
return
|
|
}
|
|
|
|
byChat, err := api.Database.GetChatCostPerChat(ctx, database.GetChatCostPerChatParams{
|
|
OwnerID: targetUser.ID,
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
})
|
|
if err != nil {
|
|
if dbauthz.IsNotAuthorizedError(err) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
httpapi.InternalServerError(rw, err)
|
|
return
|
|
}
|
|
|
|
modelBreakdowns := make([]codersdk.ChatCostModelBreakdown, 0, len(byModel))
|
|
for _, model := range byModel {
|
|
modelBreakdowns = append(modelBreakdowns, convertChatCostModelBreakdown(model))
|
|
}
|
|
|
|
chatBreakdowns := make([]codersdk.ChatCostChatBreakdown, 0, len(byChat))
|
|
for _, chat := range byChat {
|
|
chatBreakdowns = append(chatBreakdowns, convertChatCostChatBreakdown(chat))
|
|
}
|
|
|
|
// TODO(CODAGT-161): pass real organization ID
|
|
// when the HTTP endpoint supports org-scoped queries.
|
|
usageStatus, err := chatd.ResolveUsageLimitStatus(ctx, api.Database, targetUser.ID, uuid.NullUUID{}, time.Now())
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to resolve usage limit status", slog.Error(err))
|
|
}
|
|
|
|
response := codersdk.ChatCostSummary{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
TotalCostMicros: summary.TotalCostMicros,
|
|
PricedMessageCount: summary.PricedMessageCount,
|
|
UnpricedMessageCount: summary.UnpricedMessageCount,
|
|
TotalInputTokens: summary.TotalInputTokens,
|
|
TotalOutputTokens: summary.TotalOutputTokens,
|
|
TotalCacheReadTokens: summary.TotalCacheReadTokens,
|
|
TotalCacheCreationTokens: summary.TotalCacheCreationTokens,
|
|
TotalRuntimeMs: summary.TotalRuntimeMs,
|
|
ByModel: modelBreakdowns,
|
|
ByChat: chatBreakdowns,
|
|
}
|
|
if usageStatus != nil {
|
|
response.UsageLimit = usageStatus
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
func (api *API) chatCostUsers(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionRead, rbac.ResourceChat) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
defaultStart := now.AddDate(0, 0, -30)
|
|
|
|
qp := r.URL.Query()
|
|
p := httpapi.NewQueryParamParser()
|
|
startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339)
|
|
endDate := p.Time(qp, now, "end_date", time.RFC3339)
|
|
username := strings.TrimSpace(p.String(qp, "", "username"))
|
|
limit := p.Int(qp, 10, "limit")
|
|
offset := p.Int(qp, 0, "offset")
|
|
p.ErrorExcessParams(qp)
|
|
if len(p.Errors) > 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid query parameters.",
|
|
Validations: p.Errors,
|
|
})
|
|
return
|
|
}
|
|
if limit <= 0 {
|
|
limit = 10
|
|
}
|
|
if offset < 0 || offset > math.MaxInt32 || limit > math.MaxInt32 {
|
|
validations := make([]codersdk.ValidationError, 0, 2)
|
|
if offset < 0 {
|
|
validations = append(validations, codersdk.ValidationError{
|
|
Field: "offset",
|
|
Detail: "Must be greater than or equal to 0.",
|
|
})
|
|
}
|
|
if offset > math.MaxInt32 {
|
|
validations = append(validations, codersdk.ValidationError{
|
|
Field: "offset",
|
|
Detail: fmt.Sprintf("Must be less than or equal to %d.", math.MaxInt32),
|
|
})
|
|
}
|
|
if limit > math.MaxInt32 {
|
|
validations = append(validations, codersdk.ValidationError{
|
|
Field: "limit",
|
|
Detail: fmt.Sprintf("Must be less than or equal to %d.", math.MaxInt32),
|
|
})
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid query parameters.",
|
|
Validations: validations,
|
|
})
|
|
return
|
|
}
|
|
|
|
users, err := api.Database.GetChatCostPerUser(ctx, database.GetChatCostPerUserParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
Username: username,
|
|
// #nosec G115 - Pagination limits are validated to fit in int32 above.
|
|
PageLimit: int32(limit),
|
|
// #nosec G115 - Pagination offsets are validated to fit in int32 above.
|
|
PageOffset: int32(offset),
|
|
})
|
|
if err != nil {
|
|
httpapi.InternalServerError(rw, err)
|
|
return
|
|
}
|
|
|
|
rollups := make([]codersdk.ChatCostUserRollup, 0, len(users))
|
|
count := int64(0)
|
|
for _, user := range users {
|
|
count = user.TotalCount
|
|
rollups = append(rollups, convertChatCostUserRollup(user))
|
|
}
|
|
|
|
if len(users) == 0 && offset > 0 {
|
|
countUsers, countErr := api.Database.GetChatCostPerUser(ctx, database.GetChatCostPerUserParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
Username: username,
|
|
PageLimit: 1,
|
|
PageOffset: 0,
|
|
})
|
|
if countErr != nil {
|
|
httpapi.InternalServerError(rw, countErr)
|
|
return
|
|
}
|
|
if len(countUsers) > 0 {
|
|
count = countUsers[0].TotalCount
|
|
}
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatCostUsersResponse{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
Count: count,
|
|
Users: rollups,
|
|
})
|
|
}
|
|
|
|
// @Summary Get chat usage limit config
|
|
// @x-apidocgen {"skip": true}
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChatUsageLimitConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
config, configErr := api.Database.GetChatUsageLimitConfig(ctx)
|
|
if configErr != nil && !errors.Is(configErr, sql.ErrNoRows) {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat usage limit config.",
|
|
Detail: configErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
overrideRows, err := api.Database.ListChatUsageLimitOverrides(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chat usage limit overrides.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
groupOverrides, err := api.Database.ListChatUsageLimitGroupOverrides(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list group usage limit overrides.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
unpricedModelCount, err := api.Database.CountEnabledModelsWithoutPricing(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to count unpriced chat models.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
response := codersdk.ChatUsageLimitConfigResponse{
|
|
ChatUsageLimitConfig: codersdk.ChatUsageLimitConfig{},
|
|
UnpricedModelCount: unpricedModelCount,
|
|
Overrides: make([]codersdk.ChatUsageLimitOverride, 0, len(overrideRows)),
|
|
GroupOverrides: make([]codersdk.ChatUsageLimitGroupOverride, 0, len(groupOverrides)),
|
|
}
|
|
if configErr == nil {
|
|
response.Period = codersdk.ChatUsageLimitPeriod(config.Period)
|
|
response.UpdatedAt = config.UpdatedAt
|
|
if config.Enabled {
|
|
response.SpendLimitMicros = ptr.Ref(config.DefaultLimitMicros)
|
|
}
|
|
}
|
|
|
|
for _, row := range overrideRows {
|
|
response.Overrides = append(response.Overrides, codersdk.ChatUsageLimitOverride{
|
|
UserID: row.UserID,
|
|
Username: row.Username,
|
|
Name: row.Name,
|
|
AvatarURL: row.AvatarURL,
|
|
SpendLimitMicros: nullInt64Ptr(row.SpendLimitMicros),
|
|
})
|
|
}
|
|
|
|
for _, glo := range groupOverrides {
|
|
response.GroupOverrides = append(response.GroupOverrides, codersdk.ChatUsageLimitGroupOverride{
|
|
GroupID: glo.GroupID,
|
|
GroupName: glo.GroupName,
|
|
GroupDisplayName: glo.GroupDisplayName,
|
|
GroupAvatarURL: glo.GroupAvatarUrl,
|
|
MemberCount: glo.MemberCount,
|
|
SpendLimitMicros: nullInt64Ptr(glo.SpendLimitMicros),
|
|
})
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
// @Summary Update chat usage limit config
|
|
// @x-apidocgen {"skip": true}
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) updateChatUsageLimitConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.ChatUsageLimitConfig
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
params := database.UpsertChatUsageLimitConfigParams{
|
|
Enabled: false,
|
|
DefaultLimitMicros: 0,
|
|
Period: "",
|
|
}
|
|
if req.SpendLimitMicros == nil {
|
|
if req.Period != "" && !req.Period.Valid() {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat usage limit period.",
|
|
Detail: "Period must be one of: day, week, month.",
|
|
})
|
|
return
|
|
}
|
|
|
|
params.Enabled = false
|
|
params.DefaultLimitMicros = 0
|
|
params.Period = string(req.Period)
|
|
if params.Period == "" {
|
|
params.Period = string(codersdk.ChatUsageLimitPeriodMonth)
|
|
}
|
|
} else {
|
|
if *req.SpendLimitMicros <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat usage limit spend limit.",
|
|
Detail: "Spend limit must be greater than 0.",
|
|
})
|
|
return
|
|
}
|
|
if !req.Period.Valid() {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat usage limit period.",
|
|
Detail: "Period must be one of: day, week, month.",
|
|
})
|
|
return
|
|
}
|
|
|
|
params.Enabled = true
|
|
params.DefaultLimitMicros = *req.SpendLimitMicros
|
|
params.Period = string(req.Period)
|
|
}
|
|
|
|
config, err := api.Database.UpsertChatUsageLimitConfig(ctx, params)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat usage limit config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
response := codersdk.ChatUsageLimitConfig{
|
|
Period: codersdk.ChatUsageLimitPeriod(config.Period),
|
|
UpdatedAt: config.UpdatedAt,
|
|
}
|
|
if config.Enabled {
|
|
response.SpendLimitMicros = ptr.Ref(config.DefaultLimitMicros)
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
// @Summary Get my chat usage limit status
|
|
// @x-apidocgen {"skip": true}
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// getMyChatUsageLimitStatus returns the current usage-limit status for the
|
|
// authenticated user. No additional RBAC check is required because the
|
|
// endpoint always operates on the requesting user's own data via
|
|
// httpmw.APIKey(r).UserID.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getMyChatUsageLimitStatus(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
// TODO(CODAGT-161): pass real organization ID
|
|
// when the HTTP endpoint supports org-scoped queries.
|
|
status, err := chatd.ResolveUsageLimitStatus(ctx, api.Database, httpmw.APIKey(r).UserID, uuid.NullUUID{}, time.Now())
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat usage limit status.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if status == nil {
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatUsageLimitStatus{IsLimited: false})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, status)
|
|
}
|
|
|
|
// @Summary Upsert chat usage limit override
|
|
// @x-apidocgen {"skip": true}
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) upsertChatUsageLimitOverride(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
userID, ok := parseChatUsageLimitUserID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpsertChatUsageLimitOverrideRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if req.SpendLimitMicros <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat usage limit override.",
|
|
Detail: "Spend limit must be greater than 0.",
|
|
})
|
|
return
|
|
}
|
|
|
|
user, err := api.Database.GetUserByID(ctx, userID)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
|
Message: "User not found.",
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to look up chat usage limit user.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
_, err = api.Database.UpsertChatUsageLimitUserOverride(ctx, database.UpsertChatUsageLimitUserOverrideParams{
|
|
UserID: userID,
|
|
SpendLimitMicros: req.SpendLimitMicros,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to upsert chat usage limit override.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatUsageLimitOverride{
|
|
UserID: user.ID,
|
|
Username: user.Username,
|
|
Name: user.Name,
|
|
AvatarURL: user.AvatarURL,
|
|
SpendLimitMicros: nullInt64Ptr(sql.NullInt64{Int64: req.SpendLimitMicros, Valid: true}),
|
|
})
|
|
}
|
|
|
|
// @Summary Delete chat usage limit override
|
|
// @x-apidocgen {"skip": true}
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) deleteChatUsageLimitOverride(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
userID, ok := parseChatUsageLimitUserID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if _, err := api.Database.GetUserByID(ctx, userID); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
writeChatUsageLimitUserNotFound(ctx, rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to look up chat usage limit user.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if _, err := api.Database.GetChatUsageLimitUserOverride(ctx, userID); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
writeChatUsageLimitOverrideNotFound(ctx, rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to look up chat usage limit override.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if err := api.Database.DeleteChatUsageLimitUserOverride(ctx, userID); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to delete chat usage limit override.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// @Summary Upsert chat usage limit group override
|
|
// @x-apidocgen {"skip": true}
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) upsertChatUsageLimitGroupOverride(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
groupIDStr := chi.URLParam(r, "group")
|
|
groupID, err := uuid.Parse(groupIDStr)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid group ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatUsageLimitGroupOverrideRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
if req.SpendLimitMicros <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat usage limit group override.",
|
|
Detail: "Spend limit (in microdollars) must be greater than 0.",
|
|
})
|
|
return
|
|
}
|
|
|
|
group, err := api.Database.GetGroupByID(ctx, groupID)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
|
Message: "Group not found.",
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to look up group details.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
_, err = api.Database.UpsertChatUsageLimitGroupOverride(ctx, database.UpsertChatUsageLimitGroupOverrideParams{
|
|
GroupID: groupID,
|
|
SpendLimitMicros: req.SpendLimitMicros,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to upsert group usage limit override.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
memberCount, err := api.Database.GetGroupMembersCountByGroupID(ctx, database.GetGroupMembersCountByGroupIDParams{
|
|
GroupID: groupID,
|
|
IncludeSystem: false,
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
writeChatUsageLimitGroupNotFound(ctx, rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to fetch group member count.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatUsageLimitGroupOverride{
|
|
GroupID: group.ID,
|
|
GroupName: group.Name,
|
|
GroupDisplayName: group.DisplayName,
|
|
GroupAvatarURL: group.AvatarURL,
|
|
MemberCount: memberCount,
|
|
SpendLimitMicros: nullInt64Ptr(sql.NullInt64{Int64: req.SpendLimitMicros, Valid: true}),
|
|
})
|
|
}
|
|
|
|
// @Summary Delete chat usage limit group override
|
|
// @x-apidocgen {"skip": true}
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) deleteChatUsageLimitGroupOverride(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
groupIDStr := chi.URLParam(r, "group")
|
|
groupID, err := uuid.Parse(groupIDStr)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid group ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if _, err := api.Database.GetGroupByID(ctx, groupID); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
writeChatUsageLimitGroupNotFound(ctx, rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to look up group details.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if _, err := api.Database.GetChatUsageLimitGroupOverride(ctx, groupID); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
writeChatUsageLimitGroupOverrideNotFound(ctx, rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to look up group usage limit override.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if err := api.Database.DeleteChatUsageLimitGroupOverride(ctx, groupID); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to delete group usage limit override.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Get chat by ID
|
|
// @ID get-chat-by-id
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Success 200 {object} codersdk.Chat
|
|
// @Router /api/experimental/chats/{chat} [get]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
// Use the cached diff status from the database rather than
|
|
// resolving it inline. Inline resolution calls out to the
|
|
// git provider API (e.g. GitHub) on every request which
|
|
// blocks the response for 200-800ms. The background gitsync
|
|
// worker keeps the cached status fresh.
|
|
var diffStatus *database.ChatDiffStatus
|
|
status, err := api.Database.GetChatDiffStatusByChatID(ctx, chat.ID)
|
|
switch {
|
|
case err == nil:
|
|
diffStatus = &status
|
|
case !xerrors.Is(err, sql.ErrNoRows):
|
|
api.Logger.Error(ctx, "failed to get cached chat diff status",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
|
|
// Hydrate file metadata for all files linked to this chat.
|
|
chatFiles := api.fetchChatFileMetadata(ctx, chat.ID)
|
|
|
|
sdkChat := db2sdk.Chat(chat, diffStatus, chatFiles)
|
|
|
|
// For root chats, embed children so callers get a complete
|
|
// tree in a single response.
|
|
if !chat.ParentChatID.Valid {
|
|
// Embed children matching the parent's archive state.
|
|
childRows, err := api.Database.GetChildChatsByParentIDs(ctx, database.GetChildChatsByParentIDsParams{
|
|
ParentIds: []uuid.UUID{chat.ID},
|
|
Archived: sql.NullBool{Bool: chat.Archived, Valid: true},
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to fetch child chats.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
// Look up diff statuses for children.
|
|
childChats := make([]database.Chat, len(childRows))
|
|
for i, row := range childRows {
|
|
childChats[i] = row.Chat
|
|
}
|
|
childDiffStatuses, err := api.getChatDiffStatusesByChatID(ctx, childChats)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to fetch child chat diff statuses.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
sdkChat.Children = db2sdk.ChildChatRows(childRows, childDiffStatuses)
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, sdkChat)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary List chat messages
|
|
// @ID list-chat-messages
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Param before_id query int false "Return messages with id < before_id"
|
|
// @Param after_id query int false "Return messages with id > after_id"
|
|
// @Param limit query int false "Page size, 1 to 200. Defaults to 50."
|
|
// @Success 200 {object} codersdk.ChatMessagesResponse
|
|
// @Router /api/experimental/chats/{chat}/messages [get]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChatMessages(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
// Parse optional cursor-based pagination parameters.
|
|
queryParams := r.URL.Query()
|
|
parser := httpapi.NewQueryParamParser()
|
|
beforeID := parser.PositiveInt64(queryParams, 0, "before_id")
|
|
afterID := parser.PositiveInt64(queryParams, 0, "after_id")
|
|
limit := parser.PositiveInt32(queryParams, 50, "limit")
|
|
if len(parser.Errors) > 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Query parameters have invalid values.",
|
|
Validations: parser.Errors,
|
|
})
|
|
return
|
|
}
|
|
if limit < 1 || limit > 200 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid limit parameter (1-200).",
|
|
})
|
|
return
|
|
}
|
|
// Reject transposed or equal cursors so an empty open range is loud,
|
|
// not silently indistinguishable from "no messages in this range."
|
|
if beforeID > 0 && afterID > 0 && afterID >= beforeID {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "after_id must be less than before_id.",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Polling with only after_id uses ASC so the cursor advances
|
|
// monotonically; a DESC limit would drop rows when a burst larger
|
|
// than `limit` lands between polls. Fetch limit+1 in both paths to
|
|
// detect whether more pages exist.
|
|
var messages []database.ChatMessage
|
|
var err error
|
|
switch {
|
|
case afterID > 0 && beforeID == 0:
|
|
messages, err = api.Database.GetChatMessagesByChatIDAscPaginated(ctx, database.GetChatMessagesByChatIDAscPaginatedParams{
|
|
ChatID: chatID,
|
|
AfterID: afterID,
|
|
LimitVal: limit + 1,
|
|
})
|
|
default:
|
|
messages, err = api.Database.GetChatMessagesByChatIDDescPaginated(ctx, database.GetChatMessagesByChatIDDescPaginatedParams{
|
|
ChatID: chatID,
|
|
BeforeID: beforeID,
|
|
AfterID: afterID,
|
|
LimitVal: limit + 1,
|
|
})
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat messages.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
hasMore := len(messages) > int(limit)
|
|
if hasMore {
|
|
messages = messages[:limit]
|
|
}
|
|
|
|
// Queued messages are only meaningful for the initial top-of-history
|
|
// load. Suppress them whenever any cursor is set so polling callers do
|
|
// not receive the snapshot on every page fetch.
|
|
var queuedMessages []database.ChatQueuedMessage
|
|
if beforeID == 0 && afterID == 0 {
|
|
queuedMessages, err = api.Database.GetChatQueuedMessages(ctx, chatID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get queued messages.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatMessagesResponse{
|
|
Messages: convertChatMessages(messages),
|
|
QueuedMessages: convertChatQueuedMessages(queuedMessages),
|
|
HasMore: hasMore,
|
|
})
|
|
}
|
|
|
|
// @Summary List chat user prompts
|
|
// @ID list-chat-user-prompts
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Param limit query int false "Page size, 0 to 2000. 0 (the default) means the server-side default of 500."
|
|
// @Success 200 {object} codersdk.ChatPromptsResponse
|
|
// @Router /api/experimental/chats/{chat}/prompts [get]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
// @Description
|
|
// @Description Returns the user-authored prompts in a chat, newest first,
|
|
// @Description with each prompt's text parts concatenated in the order they
|
|
// @Description were authored. Used by the composer to power the up/down
|
|
// @Description arrow prompt-history cycle without paging through every
|
|
// @Description message in the chat.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChatUserPrompts(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
queryParams := r.URL.Query()
|
|
parser := httpapi.NewQueryParamParser()
|
|
// Default 0 sentinel; the SQL query treats 0 as "use the built-in
|
|
// default of 500" via COALESCE(NULLIF(@limit_val, 0), 500). The
|
|
// SDK guards opts.Limit > 0 so callers using the typed client only
|
|
// reach here with an explicit value; raw HTTP callers can omit the
|
|
// parameter (or pass 0) to opt into the default.
|
|
limit := parser.PositiveInt32(queryParams, 0, "limit")
|
|
if len(parser.Errors) > 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Query parameters have invalid values.",
|
|
Validations: parser.Errors,
|
|
})
|
|
return
|
|
}
|
|
// PositiveInt32 already rejects negatives via parser.Errors above,
|
|
// so we only need to cap the upper bound here.
|
|
if limit > 2000 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid limit parameter (0-2000).",
|
|
})
|
|
return
|
|
}
|
|
|
|
rows, err := api.Database.GetChatUserPromptsByChatID(ctx, database.GetChatUserPromptsByChatIDParams{
|
|
ChatID: chatID,
|
|
LimitVal: limit,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat user prompts.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
prompts := make([]codersdk.ChatPrompt, 0, len(rows))
|
|
for _, row := range rows {
|
|
prompts = append(prompts, codersdk.ChatPrompt{
|
|
ID: row.ID,
|
|
Text: row.Text,
|
|
})
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatPromptsResponse{
|
|
Prompts: prompts,
|
|
})
|
|
}
|
|
|
|
// authorizeChatWorkspaceExec enforces the workspace-level permissions
|
|
// shared by the chat stream endpoints that proxy a live websocket into
|
|
// the workspace agent (currently /stream/git and /stream/desktop).
|
|
//
|
|
// The chat row only authorizes the chat owner, so callers also need
|
|
// exec-level access (ApplicationConnect or SSH) to the bound workspace.
|
|
// The chat owner's workspace permissions may have been revoked after
|
|
// the chat was bound; skipping this check enabled CODAGT-184.
|
|
//
|
|
// On any failure the response is written and ok=false is returned.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) authorizeChatWorkspaceExec(
|
|
rw http.ResponseWriter,
|
|
r *http.Request,
|
|
chat database.Chat,
|
|
noWorkspaceMessage string,
|
|
) (database.Workspace, bool) {
|
|
ctx := r.Context()
|
|
|
|
if !chat.WorkspaceID.Valid {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: noWorkspaceMessage,
|
|
})
|
|
return database.Workspace{}, false
|
|
}
|
|
|
|
workspace, err := api.Database.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID)
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: codersdk.ChatGitWatchWorkspaceNotFoundMessage,
|
|
})
|
|
return database.Workspace{}, false
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching chat workspace.",
|
|
Detail: err.Error(),
|
|
})
|
|
return database.Workspace{}, false
|
|
}
|
|
|
|
if !api.Authorize(r, policy.ActionApplicationConnect, workspace) &&
|
|
!api.Authorize(r, policy.ActionSSH, workspace) {
|
|
httpapi.Forbidden(rw)
|
|
return database.Workspace{}, false
|
|
}
|
|
|
|
return workspace, true
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Watch chat workspace git state via WebSockets
|
|
// @ID watch-chat-workspace-git-state-via-websockets
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Success 200 {object} codersdk.WorkspaceAgentGitServerMessage
|
|
// @Router /api/experimental/chats/{chat}/stream/git [get]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) watchChatGit(rw http.ResponseWriter, r *http.Request) {
|
|
var (
|
|
ctx = r.Context()
|
|
chat = httpmw.ChatParam(r)
|
|
logger = api.Logger.Named("chat_git_watcher").With(slog.F("chat_id", chat.ID))
|
|
)
|
|
|
|
if _, ok := api.authorizeChatWorkspaceExec(rw, r, chat, codersdk.ChatGitWatchNoWorkspaceMessage); !ok {
|
|
return
|
|
}
|
|
|
|
agents, err := api.Database.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, chat.WorkspaceID.UUID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching workspace agents.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if len(agents) == 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: codersdk.ChatGitWatchWorkspaceNoAgentsMessage,
|
|
})
|
|
return
|
|
}
|
|
|
|
apiAgent, err := db2sdk.WorkspaceAgent(
|
|
api.DERPMap(),
|
|
*api.TailnetCoordinator.Load(),
|
|
agents[0],
|
|
nil,
|
|
nil,
|
|
nil,
|
|
api.AgentInactiveDisconnectTimeout,
|
|
api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
|
|
)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error reading workspace agent.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if apiAgent.Status != codersdk.WorkspaceAgentConnected {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: codersdk.ChatGitWatchAgentStateMessage(apiAgent.Status),
|
|
})
|
|
return
|
|
}
|
|
|
|
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer dialCancel()
|
|
|
|
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agents[0].ID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error dialing workspace agent.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
defer release()
|
|
|
|
agentStream, err := agentConn.WatchGit(ctx, logger, chat.ID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error watching agent's git state.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
defer agentStream.Close(websocket.StatusGoingAway)
|
|
|
|
clientConn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
|
CompressionMode: websocket.CompressionNoContextTakeover,
|
|
})
|
|
if err != nil {
|
|
logger.Error(ctx, "failed to accept websocket", slog.Error(err))
|
|
return
|
|
}
|
|
|
|
clientStream := wsjson.NewStream[
|
|
codersdk.WorkspaceAgentGitClientMessage,
|
|
codersdk.WorkspaceAgentGitServerMessage,
|
|
](clientConn, websocket.MessageText, websocket.MessageText, logger)
|
|
|
|
ctx, cancel := context.WithCancel(r.Context())
|
|
defer cancel()
|
|
|
|
go httpapi.HeartbeatClose(ctx, logger, cancel, clientConn)
|
|
|
|
// Proxy agent → client.
|
|
agentCh := agentStream.Chan()
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
for {
|
|
select {
|
|
case <-api.ctx.Done():
|
|
return
|
|
case <-ctx.Done():
|
|
return
|
|
case msg, ok := <-agentCh:
|
|
if !ok {
|
|
cancel()
|
|
return
|
|
}
|
|
if err := clientStream.Send(msg); err != nil {
|
|
logger.Debug(ctx, "failed to forward agent message to client", slog.Error(err))
|
|
cancel()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Proxy client → agent.
|
|
clientCh := clientStream.Chan()
|
|
proxyLoop:
|
|
for {
|
|
select {
|
|
case <-api.ctx.Done():
|
|
break proxyLoop
|
|
case <-ctx.Done():
|
|
break proxyLoop
|
|
case msg, ok := <-clientCh:
|
|
if !ok {
|
|
break proxyLoop
|
|
}
|
|
if err := agentStream.Send(msg); err != nil {
|
|
logger.Debug(ctx, "failed to forward client message to agent", slog.Error(err))
|
|
break proxyLoop
|
|
}
|
|
}
|
|
}
|
|
|
|
cancel()
|
|
wg.Wait()
|
|
_ = clientStream.Close(websocket.StatusGoingAway)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Connect to chat workspace desktop via WebSockets
|
|
// @ID connect-to-chat-workspace-desktop-via-websockets
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce application/octet-stream
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Success 101
|
|
// @Router /api/experimental/chats/{chat}/stream/desktop [get]
|
|
// @Description Raw binary WebSocket stream of the chat workspace desktop.
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) {
|
|
var (
|
|
ctx = r.Context()
|
|
chat = httpmw.ChatParam(r)
|
|
logger = api.Logger.Named("chat_desktop").With(slog.F("chat_id", chat.ID))
|
|
)
|
|
|
|
if _, ok := api.authorizeChatWorkspaceExec(rw, r, chat, "Chat has no workspace."); !ok {
|
|
return
|
|
}
|
|
|
|
agents, err := api.Database.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, chat.WorkspaceID.UUID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching workspace agents.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if len(agents) == 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat workspace has no agents.",
|
|
})
|
|
return
|
|
}
|
|
|
|
apiAgent, err := db2sdk.WorkspaceAgent(
|
|
api.DERPMap(),
|
|
*api.TailnetCoordinator.Load(),
|
|
agents[0],
|
|
nil,
|
|
nil,
|
|
nil,
|
|
api.AgentInactiveDisconnectTimeout,
|
|
api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
|
|
)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error reading workspace agent.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if apiAgent.Status != codersdk.WorkspaceAgentConnected {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("Agent state is %q, must be connected.", apiAgent.Status),
|
|
})
|
|
return
|
|
}
|
|
|
|
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer dialCancel()
|
|
|
|
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agents[0].ID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to dial workspace agent.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
defer release()
|
|
|
|
desktopConn, err := agentConn.ConnectDesktopVNC(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to connect to agent desktop.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
defer desktopConn.Close()
|
|
|
|
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
|
CompressionMode: websocket.CompressionDisabled,
|
|
})
|
|
if err != nil {
|
|
logger.Error(ctx, "failed to accept websocket", slog.Error(err))
|
|
return
|
|
}
|
|
|
|
// No read limit — RFB framebuffer updates can be large.
|
|
conn.SetReadLimit(-1)
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
ctx, wsNetConn := workspaceapps.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
|
defer wsNetConn.Close()
|
|
|
|
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
|
|
|
agentssh.Bicopy(ctx, wsNetConn, desktopConn)
|
|
logger.Debug(ctx, "desktop Bicopy finished")
|
|
}
|
|
|
|
func (api *API) applyChatTitleUpdate(
|
|
ctx context.Context,
|
|
rw http.ResponseWriter,
|
|
chat database.Chat,
|
|
rawTitle string,
|
|
) (database.Chat, bool) {
|
|
trimmedTitle := strings.TrimSpace(rawTitle)
|
|
if trimmedTitle == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Title cannot be empty.",
|
|
})
|
|
return chat, true
|
|
}
|
|
const maxChatTitleRunes = 200
|
|
if utf8.RuneCountInString(trimmedTitle) > maxChatTitleRunes {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("Title must be at most %d characters.", maxChatTitleRunes),
|
|
})
|
|
return chat, true
|
|
}
|
|
if trimmedTitle == chat.Title {
|
|
return chat, false
|
|
}
|
|
|
|
var (
|
|
updatedChat database.Chat
|
|
wrote bool
|
|
err error
|
|
)
|
|
if api.chatDaemon != nil {
|
|
updatedChat, wrote, err = api.chatDaemon.RenameChatTitle(ctx, chat, trimmedTitle)
|
|
} else {
|
|
err = api.Database.InTx(func(tx database.Store) error {
|
|
currentChat, txErr := tx.GetChatByID(ctx, chat.ID)
|
|
if txErr != nil {
|
|
return txErr
|
|
}
|
|
if trimmedTitle == currentChat.Title {
|
|
updatedChat = currentChat
|
|
wrote = false
|
|
return nil
|
|
}
|
|
updatedChat, txErr = tx.UpdateChatTitleByID(ctx, database.UpdateChatTitleByIDParams{
|
|
ID: chat.ID,
|
|
Title: trimmedTitle,
|
|
})
|
|
if txErr != nil {
|
|
return txErr
|
|
}
|
|
wrote = true
|
|
return nil
|
|
}, nil)
|
|
}
|
|
if err != nil {
|
|
if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) {
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Title regeneration already in progress for this chat.",
|
|
})
|
|
return chat, true
|
|
}
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return chat, true
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat title.",
|
|
Detail: err.Error(),
|
|
})
|
|
return chat, true
|
|
}
|
|
if wrote {
|
|
if api.chatDaemon != nil {
|
|
api.chatDaemon.PublishTitleChange(updatedChat)
|
|
} else {
|
|
publishChatTitleChange(api.Logger, api.Pubsub, updatedChat)
|
|
}
|
|
}
|
|
return updatedChat, false
|
|
}
|
|
|
|
// patchChat updates a chat resource. Supports updating labels,
|
|
// workspace binding, archiving, pinning, and pinned-chat ordering.
|
|
//
|
|
// @Summary Update chat
|
|
// @ID update-chat
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Accept json
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Param request body codersdk.UpdateChatRequest true "Update chat request"
|
|
// @Success 204
|
|
// @Router /api/experimental/chats/{chat} [patch]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
aReq, commitAudit := audit.InitRequest[database.Chat](rw, &audit.RequestParams{
|
|
Audit: *api.Auditor.Load(),
|
|
Log: api.Logger,
|
|
Request: r,
|
|
Action: database.AuditActionWrite,
|
|
})
|
|
defer commitAudit()
|
|
aReq.Old = chat
|
|
aReq.UpdateOrganizationID(chat.OrganizationID)
|
|
|
|
var req codersdk.UpdateChatRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
var planModeUpdate *database.NullChatPlanMode
|
|
if req.PlanMode != nil {
|
|
if !validateChatPlanMode(*req.PlanMode) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid plan_mode value.",
|
|
})
|
|
return
|
|
}
|
|
resolvedPlanMode := planModeToNullChatPlanMode(*req.PlanMode)
|
|
planModeUpdate = &resolvedPlanMode
|
|
}
|
|
|
|
if req.Title != nil {
|
|
updatedChat, handled := api.applyChatTitleUpdate(ctx, rw, chat, *req.Title)
|
|
if handled {
|
|
return
|
|
}
|
|
chat = updatedChat
|
|
}
|
|
if req.Labels != nil {
|
|
if errs := httpapi.ValidateChatLabels(*req.Labels); len(errs) > 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid labels.",
|
|
Validations: errs,
|
|
})
|
|
return
|
|
}
|
|
labelsJSON, err := json.Marshal(*req.Labels)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to marshal labels.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
updatedChat, err := api.Database.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{
|
|
ID: chat.ID,
|
|
Labels: labelsJSON,
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat labels.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
chat = updatedChat
|
|
}
|
|
|
|
if req.Archived != nil {
|
|
archived := *req.Archived
|
|
if archived == chat.Archived {
|
|
state := "archived"
|
|
if !archived {
|
|
state = "not archived"
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("Chat is already %s.", state),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Archive invariant is one-way: parent archived implies
|
|
// child archived. Parent archive/unarchive cascade via
|
|
// root_chat_id; individual child archive is permitted;
|
|
// child unarchive while the parent is archived is rejected
|
|
// (enforced atomically in chatd.Server.UnarchiveChat).
|
|
if chat.ParentChatID.Valid && !archived {
|
|
if done := api.writeChildUnarchiveGuard(ctx, rw, chat); done {
|
|
return
|
|
}
|
|
}
|
|
var err error
|
|
// Use chatDaemon when available so it can interrupt active
|
|
// processing before broadcasting archive state. Fall back to
|
|
// direct DB when no daemon is running.
|
|
if archived {
|
|
if api.chatDaemon != nil {
|
|
err = api.chatDaemon.ArchiveChat(ctx, chat)
|
|
} else {
|
|
_, err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
|
}
|
|
} else {
|
|
if api.chatDaemon != nil {
|
|
err = api.chatDaemon.UnarchiveChat(ctx, chat)
|
|
} else {
|
|
_, err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
|
}
|
|
}
|
|
if err != nil {
|
|
if errors.Is(err, chatd.ErrChildUnarchiveParentArchived) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot unarchive a child chat while its parent is archived. Unarchive the parent chat to cascade.",
|
|
})
|
|
return
|
|
}
|
|
action := "archive"
|
|
if !archived {
|
|
action = "unarchive"
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: fmt.Sprintf("Failed to %s chat.", action),
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
if req.PinOrder != nil {
|
|
pinOrder := *req.PinOrder
|
|
if pinOrder < 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Pin order must be non-negative.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if pinOrder > 0 && chat.Archived {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot pin an archived chat.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if pinOrder > 0 && chat.ParentChatID.Valid {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot pin a child chat.",
|
|
})
|
|
return
|
|
}
|
|
|
|
// The behavior depends on current pin state:
|
|
// - pinOrder == 0: unpin.
|
|
// - pinOrder > 0 && already pinned: reorder (shift
|
|
// neighbors, clamp to [1, count]).
|
|
// - pinOrder > 0 && not pinned: append to end. The
|
|
// requested value is intentionally ignored; the
|
|
// SQL ORDER BY sorts pinned chats first so they
|
|
// appear on page 1 of the paginated sidebar.
|
|
var err error
|
|
errMsg := "Failed to pin chat."
|
|
switch {
|
|
case pinOrder == 0:
|
|
errMsg = "Failed to unpin chat."
|
|
err = api.Database.UnpinChatByID(ctx, chat.ID)
|
|
case chat.PinOrder > 0:
|
|
errMsg = "Failed to reorder pinned chat."
|
|
err = api.Database.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
|
ID: chat.ID,
|
|
PinOrder: pinOrder,
|
|
})
|
|
default:
|
|
err = api.Database.PinChatByID(ctx, chat.ID)
|
|
}
|
|
if err != nil {
|
|
switch {
|
|
case database.IsCheckViolation(err, database.CheckChatsPinOrderParentCheck):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot pin a child chat.",
|
|
})
|
|
case database.IsCheckViolation(err, database.CheckChatsPinOrderArchivedCheck):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot pin an archived chat.",
|
|
})
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: errMsg,
|
|
Detail: err.Error(),
|
|
})
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
if req.WorkspaceID != nil {
|
|
workspaceID := uuid.NullUUID{}
|
|
workspace := database.Workspace{}
|
|
if *req.WorkspaceID != uuid.Nil {
|
|
var status int
|
|
var resp *codersdk.Response
|
|
workspaceID, workspace, status, resp = api.validateChatWorkspaceSelection(ctx, r, req.WorkspaceID)
|
|
if resp != nil {
|
|
httpapi.Write(ctx, rw, status, *resp)
|
|
return
|
|
}
|
|
if workspace.OrganizationID != chat.OrganizationID {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Workspace does not belong to this chat's organization.",
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
updatedChat, err := api.Database.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
|
|
ID: chat.ID,
|
|
WorkspaceID: workspaceID,
|
|
BuildID: uuid.NullUUID{},
|
|
AgentID: uuid.NullUUID{},
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat workspace binding.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
chat = updatedChat
|
|
}
|
|
|
|
if planModeUpdate != nil {
|
|
updatedChat, err := api.Database.UpdateChatPlanModeByID(ctx, database.UpdateChatPlanModeByIDParams{
|
|
PlanMode: *planModeUpdate,
|
|
ID: chat.ID,
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat plan mode.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
chat = updatedChat
|
|
}
|
|
|
|
if refreshed, err := api.Database.GetChatByID(ctx, chat.ID); err == nil {
|
|
aReq.New = refreshed
|
|
} else {
|
|
aReq.New = chat // fallback
|
|
api.Logger.Error(ctx, "failed to refresh chat for audit", slog.F("chat_id", chat.ID), slog.Error(err))
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// writeChildUnarchiveGuard returns a 400 early when a child unarchive
|
|
// request obviously races an archived parent. The durable invariant
|
|
// is enforced atomically in chatd.Server.UnarchiveChat; this guard
|
|
// just surfaces the error before we take any locks.
|
|
//
|
|
// Returns true when a response has been written.
|
|
func (api *API) writeChildUnarchiveGuard(
|
|
ctx context.Context,
|
|
rw http.ResponseWriter,
|
|
chat database.Chat,
|
|
) bool {
|
|
parent, err := api.Database.GetChatByID(ctx, chat.ParentChatID.UUID)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return true
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to load parent chat.",
|
|
Detail: err.Error(),
|
|
})
|
|
return true
|
|
}
|
|
if parent.Archived {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot unarchive a child chat while its parent is archived. Unarchive the parent chat to cascade.",
|
|
})
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Send chat message
|
|
// @ID send-chat-message
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Accept json
|
|
// @Produce json
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Param request body codersdk.CreateChatMessageRequest true "Create chat message request"
|
|
// @Success 200 {object} codersdk.CreateChatMessageResponse
|
|
// @Router /api/experimental/chats/{chat}/messages [post]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
// Sending a message triggers LLM inference, requiring update
|
|
// permission on the org-scoped chat resource.
|
|
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
// Only the chat owner may send messages. Org admins pass the
|
|
// RBAC check above (org-level ActionUpdate), but chat
|
|
// processing forwards the *owner's* credentials (OIDC tokens,
|
|
// provider API keys) to external services. Allowing a
|
|
// non-owner to trigger processing would leak the owner's
|
|
// tokens to MCP servers the caller controls.
|
|
if apiKey.UserID != chat.OwnerID {
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
|
Message: "Only the chat owner may send messages.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if chat.Archived {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot send messages to an archived chat.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.CreateChatMessageRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
contentBlocks, _, fileIDs, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
|
if inputError != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: inputError.Message,
|
|
Detail: inputError.Detail,
|
|
})
|
|
return
|
|
}
|
|
|
|
// Validate MCP server IDs exist.
|
|
if req.MCPServerIDs != nil && len(*req.MCPServerIDs) > 0 {
|
|
//nolint:gocritic // Need to validate MCP server IDs exist.
|
|
existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), *req.MCPServerIDs)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to validate MCP server IDs.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if len(existingConfigs) != len(*req.MCPServerIDs) {
|
|
found := make(map[uuid.UUID]struct{}, len(existingConfigs))
|
|
for _, c := range existingConfigs {
|
|
found[c.ID] = struct{}{}
|
|
}
|
|
var missing []string
|
|
for _, id := range *req.MCPServerIDs {
|
|
if _, ok := found[id]; !ok {
|
|
missing = append(missing, id.String())
|
|
}
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "One or more MCP server IDs are invalid.",
|
|
Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
if req.PlanMode != nil {
|
|
if !validateChatPlanMode(*req.PlanMode) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid plan_mode value.",
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
var sendPlanMode *database.NullChatPlanMode
|
|
if req.PlanMode != nil {
|
|
resolvedPlanMode := planModeToNullChatPlanMode(*req.PlanMode)
|
|
sendPlanMode = &resolvedPlanMode
|
|
}
|
|
|
|
busyBehavior := chatd.SendMessageBusyBehaviorQueue
|
|
switch req.BusyBehavior {
|
|
case codersdk.ChatBusyBehaviorInterrupt:
|
|
busyBehavior = chatd.SendMessageBusyBehaviorInterrupt
|
|
case codersdk.ChatBusyBehaviorQueue, "":
|
|
// Default to queue.
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid busy_behavior value.",
|
|
Detail: `Must be "queue" or "interrupt".`,
|
|
})
|
|
return
|
|
}
|
|
|
|
modelConfigID := uuid.Nil
|
|
if req.ModelConfigID != nil {
|
|
modelConfigID = *req.ModelConfigID
|
|
}
|
|
|
|
sendResult, sendErr := api.chatDaemon.SendMessage(
|
|
ctx,
|
|
chatd.SendMessageOptions{
|
|
ChatID: chatID,
|
|
CreatedBy: apiKey.UserID,
|
|
Content: contentBlocks,
|
|
ModelConfigID: modelConfigID,
|
|
BusyBehavior: busyBehavior,
|
|
PlanMode: sendPlanMode,
|
|
MCPServerIDs: req.MCPServerIDs,
|
|
},
|
|
)
|
|
if sendErr != nil {
|
|
if maybeWriteLimitErr(ctx, rw, sendErr) {
|
|
return
|
|
}
|
|
if xerrors.Is(sendErr, chatd.ErrChatArchived) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot send messages to an archived chat.",
|
|
})
|
|
return
|
|
}
|
|
if xerrors.Is(sendErr, chatd.ErrMessageQueueFull) {
|
|
httpapi.Write(ctx, rw, http.StatusTooManyRequests, codersdk.Response{
|
|
Message: "Message queue is full.",
|
|
Detail: fmt.Sprintf("Maximum %d messages can be queued.", chatd.MaxQueueSize),
|
|
})
|
|
return
|
|
}
|
|
if xerrors.Is(sendErr, chatd.ErrInvalidModelConfigID) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config ID.",
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat message.",
|
|
Detail: sendErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Link any user-uploaded files referenced in this message
|
|
// to the chat (best-effort; cap enforced in SQL).
|
|
unlinked, capExceeded := api.linkFilesToChat(ctx, chatID, fileIDs)
|
|
response := codersdk.CreateChatMessageResponse{Queued: sendResult.Queued}
|
|
if sendResult.Queued {
|
|
if sendResult.QueuedMessage != nil {
|
|
response.QueuedMessage = convertChatQueuedMessagePtr(*sendResult.QueuedMessage)
|
|
}
|
|
} else {
|
|
message := convertChatMessage(sendResult.Message)
|
|
response.Message = &message
|
|
}
|
|
if len(unlinked) > 0 {
|
|
if capExceeded {
|
|
response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked)))
|
|
} else {
|
|
response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked)))
|
|
}
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Edit chat message
|
|
// @ID edit-chat-message
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Accept json
|
|
// @Produce json
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Param message path int true "Message ID"
|
|
// @Param request body codersdk.EditChatMessageRequest true "Edit chat message request"
|
|
// @Success 200 {object} codersdk.EditChatMessageResponse
|
|
// @Router /api/experimental/chats/{chat}/messages/{message} [patch]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
// Only the chat owner may edit messages. See postChatMessages
|
|
// for the security rationale.
|
|
if apiKey.UserID != chat.OwnerID {
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
|
Message: "Only the chat owner may edit messages.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if chat.Archived {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot edit messages in an archived chat.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
messageIDStr := chi.URLParam(r, "message")
|
|
messageID, err := strconv.ParseInt(messageIDStr, 10, 64)
|
|
if err != nil || messageID <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat message ID.",
|
|
Detail: "Message ID must be a positive integer.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.EditChatMessageRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
contentBlocks, _, fileIDs, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
|
if inputError != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: inputError.Message,
|
|
Detail: inputError.Detail,
|
|
})
|
|
return
|
|
}
|
|
|
|
editModelConfigID := uuid.Nil
|
|
if req.ModelConfigID != nil {
|
|
editModelConfigID = *req.ModelConfigID
|
|
}
|
|
|
|
editResult, editErr := api.chatDaemon.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
CreatedBy: apiKey.UserID,
|
|
EditedMessageID: messageID,
|
|
Content: contentBlocks,
|
|
ModelConfigID: editModelConfigID,
|
|
})
|
|
if editErr != nil {
|
|
if maybeWriteLimitErr(ctx, rw, editErr) {
|
|
return
|
|
}
|
|
|
|
switch {
|
|
case xerrors.Is(editErr, chatd.ErrChatArchived):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot edit messages in an archived chat.",
|
|
})
|
|
case xerrors.Is(editErr, chatd.ErrEditedMessageNotFound):
|
|
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
|
Message: "Chat message not found.",
|
|
Detail: "Message does not belong to this chat.",
|
|
})
|
|
case xerrors.Is(editErr, chatd.ErrEditedMessageNotUser):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Only user messages can be edited.",
|
|
})
|
|
case xerrors.Is(editErr, chatd.ErrInvalidModelConfigID):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config ID.",
|
|
})
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to edit chat message.",
|
|
Detail: editErr.Error(),
|
|
})
|
|
}
|
|
return
|
|
}
|
|
|
|
// Link any user-uploaded files referenced in the edited
|
|
// message to the chat (best-effort; cap enforced in SQL).
|
|
unlinked, capExceeded := api.linkFilesToChat(ctx, chat.ID, fileIDs)
|
|
response := codersdk.EditChatMessageResponse{
|
|
Message: convertChatMessage(editResult.Message),
|
|
}
|
|
if len(unlinked) > 0 {
|
|
if capExceeded {
|
|
response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked)))
|
|
} else {
|
|
response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked)))
|
|
}
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) deleteChatQueuedMessage(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
queuedMessageIDStr := chi.URLParam(r, "queuedMessage")
|
|
queuedMessageID, err := strconv.ParseInt(queuedMessageIDStr, 10, 64)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid queued message ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon != nil {
|
|
err = api.chatDaemon.DeleteQueued(ctx, chatID, queuedMessageID)
|
|
} else {
|
|
err = api.Database.DeleteChatQueuedMessage(ctx, database.DeleteChatQueuedMessageParams{
|
|
ID: queuedMessageID,
|
|
ChatID: chatID,
|
|
})
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to delete queued message.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) promoteChatQueuedMessage(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
|
|
// Promoting a queued message triggers LLM inference,
|
|
// requiring update permission on the org-scoped chat resource.
|
|
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
// Only the chat owner may promote messages. See
|
|
// postChatMessages for the security rationale.
|
|
if apiKey.UserID != chat.OwnerID {
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
|
Message: "Only the chat owner may promote queued messages.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if chat.Archived {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot promote queued messages in an archived chat.",
|
|
})
|
|
return
|
|
}
|
|
|
|
queuedMessageIDStr := chi.URLParam(r, "queuedMessage")
|
|
queuedMessageID, err := strconv.ParseInt(queuedMessageIDStr, 10, 64)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid queued message ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
_, txErr := api.chatDaemon.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chatID,
|
|
CreatedBy: apiKey.UserID,
|
|
QueuedMessageID: queuedMessageID,
|
|
})
|
|
|
|
if txErr != nil {
|
|
if maybeWriteLimitErr(ctx, rw, txErr) {
|
|
return
|
|
}
|
|
if xerrors.Is(txErr, chatd.ErrChatArchived) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot promote queued messages in an archived chat.",
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to promote queued message.",
|
|
Detail: txErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusAccepted, codersdk.Response{
|
|
Message: "Queued message promotion accepted.",
|
|
})
|
|
}
|
|
|
|
// markChatAsRead updates the last read message ID for a chat to the
|
|
// latest message, so subsequent unread checks treat all current
|
|
// messages as seen. This is called on stream connect and disconnect
|
|
// to avoid per-message API calls during active streaming.
|
|
func (api *API) markChatAsRead(ctx context.Context, chatID uuid.UUID) {
|
|
lastMsg, err := api.Database.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
})
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
// No assistant messages yet, nothing to mark as read.
|
|
return
|
|
}
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to get last assistant message for read marker",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
|
|
err = api.Database.UpdateChatLastReadMessageID(ctx, database.UpdateChatLastReadMessageIDParams{
|
|
ID: chatID,
|
|
LastReadMessageID: lastMsg.ID,
|
|
})
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "failed to update chat last read message ID",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Stream chat events via WebSockets
|
|
// @ID stream-chat-events-via-websockets
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Success 200 {object} codersdk.ChatStreamEvent
|
|
// @Router /api/experimental/chats/{chat}/stream [get]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID))
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat streaming is not available.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var afterMessageID int64
|
|
if v := r.URL.Query().Get("after_id"); v != "" {
|
|
var err error
|
|
afterMessageID, err = strconv.ParseInt(v, 10, 64)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid after_id parameter.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
// Subscribe before accepting the WebSocket so that failures
|
|
// can still be reported as normal HTTP errors.
|
|
snapshot, events, cancelSub, ok := api.chatDaemon.SubscribeAuthorized(ctx, chat, r.Header, afterMessageID)
|
|
// Subscribe only fails today when the receiver is nil, which
|
|
// the chatDaemon == nil guard above already catches. This is
|
|
// defensive against future Subscribe failure modes.
|
|
if !ok {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat streaming is not available.",
|
|
Detail: "Chat stream state is not configured.",
|
|
})
|
|
return
|
|
}
|
|
defer cancelSub()
|
|
|
|
conn, err := websocket.Accept(rw, r, nil)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to open chat stream.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
_ = conn.CloseRead(context.Background())
|
|
|
|
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
|
defer wsNetConn.Close()
|
|
|
|
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
|
|
|
// The last_read_message_id field is owner-scoped. Shared readers
|
|
// intentionally lack chat update permission, so their streams must not
|
|
// update it.
|
|
if chat.OwnerID == httpmw.APIKey(r).UserID {
|
|
api.markChatAsRead(ctx, chatID)
|
|
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
|
|
}
|
|
|
|
encoder := json.NewEncoder(wsNetConn)
|
|
|
|
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
|
if len(batch) == 0 {
|
|
return nil
|
|
}
|
|
return encoder.Encode(batch)
|
|
}
|
|
|
|
drainChatStreamBatch := func(
|
|
first codersdk.ChatStreamEvent,
|
|
maxBatchSize int,
|
|
) ([]codersdk.ChatStreamEvent, bool) {
|
|
batch := []codersdk.ChatStreamEvent{first}
|
|
if maxBatchSize <= 1 {
|
|
return batch, false
|
|
}
|
|
|
|
for len(batch) < maxBatchSize {
|
|
select {
|
|
case event, ok := <-events:
|
|
if !ok {
|
|
return batch, true
|
|
}
|
|
batch = append(batch, event)
|
|
default:
|
|
return batch, false
|
|
}
|
|
}
|
|
|
|
return batch, false
|
|
}
|
|
|
|
for start := 0; start < len(snapshot); start += chatStreamBatchSize {
|
|
end := start + chatStreamBatchSize
|
|
if end > len(snapshot) {
|
|
end = len(snapshot)
|
|
}
|
|
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
|
|
logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
|
return
|
|
}
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case firstEvent, ok := <-events:
|
|
if !ok {
|
|
return
|
|
}
|
|
batch, streamClosed := drainChatStreamBatch(
|
|
firstEvent,
|
|
chatStreamBatchSize,
|
|
)
|
|
if err := sendChatStreamBatch(batch); err != nil {
|
|
logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
|
return
|
|
}
|
|
if streamClosed {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Interrupt chat
|
|
// @ID interrupt-chat
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Produce json
|
|
// @Success 200 {object} codersdk.Chat
|
|
// @Router /api/experimental/chats/{chat}/interrupt [post]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
chatID := chat.ID
|
|
logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID))
|
|
|
|
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon != nil {
|
|
chat = api.chatDaemon.InterruptChat(ctx, chat)
|
|
} else {
|
|
updatedChat, updateErr := api.Database.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chatID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: pqtype.NullRawMessage{},
|
|
})
|
|
if updateErr != nil {
|
|
logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr))
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to interrupt chat.",
|
|
Detail: updateErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
chat = updatedChat
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil, nil))
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Regenerate chat title
|
|
// @ID regenerate-chat-title
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Success 200 {object} codersdk.Chat
|
|
// @Router /api/experimental/chats/{chat}/title/regenerate [post]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) regenerateChatTitle(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
// Only the chat owner may regenerate titles. See
|
|
// postChatMessages for the security rationale.
|
|
if apiKey.UserID != chat.OwnerID {
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
|
Message: "Only the chat owner may regenerate the title.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
updatedChat, err := api.chatDaemon.RegenerateChatTitle(ctx, chat)
|
|
if err != nil {
|
|
if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) {
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Title regeneration already in progress for this chat.",
|
|
})
|
|
return
|
|
}
|
|
if maybeWriteLimitErr(ctx, rw, err) {
|
|
return
|
|
}
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to regenerate chat title.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil, nil))
|
|
}
|
|
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) proposeChatTitle(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
// Only the chat owner may propose titles. See
|
|
// postChatMessages for the security rationale.
|
|
if apiKey.UserID != chat.OwnerID {
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
|
Message: "Only the chat owner may propose a title.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if api.chatDaemon == nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Chat processor is unavailable.",
|
|
Detail: "Chat processor is not configured.",
|
|
})
|
|
return
|
|
}
|
|
|
|
title, err := api.chatDaemon.ProposeChatTitle(ctx, chat)
|
|
if err != nil {
|
|
if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) {
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Title regeneration already in progress for this chat.",
|
|
})
|
|
return
|
|
}
|
|
if maybeWriteLimitErr(ctx, rw, err) {
|
|
return
|
|
}
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to generate chat title.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ProposeChatTitleResponse{Title: title})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Get chat diff contents
|
|
// @ID get-chat-diff-contents
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Param chat path string true "Chat ID" format(uuid)
|
|
// @Success 200 {object} codersdk.ChatDiffContents
|
|
// @Router /api/experimental/chats/{chat}/diff [get]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) getChatDiffContents(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
diff, err := api.resolveChatDiffContents(ctx, chat)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat diff.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, diff)
|
|
}
|
|
|
|
// chatCreateWorkspace provides workspace creation for the chat
|
|
// processor. RBAC authorization uses context-based checks via
|
|
// dbauthz.As rather than fake *http.Request objects.
|
|
func (api *API) chatCreateWorkspace(
|
|
ctx context.Context,
|
|
ownerID uuid.UUID,
|
|
req codersdk.CreateWorkspaceRequest,
|
|
) (codersdk.Workspace, error) {
|
|
actor, _, err := httpmw.UserRBACSubject(ctx, api.Database, ownerID, rbac.ScopeAll)
|
|
if err != nil {
|
|
return codersdk.Workspace{}, xerrors.Errorf("load user authorization: %w", err)
|
|
}
|
|
ctx = dbauthz.As(ctx, actor)
|
|
|
|
ownerUser, err := api.Database.GetUserByID(ctx, ownerID)
|
|
if err != nil {
|
|
return codersdk.Workspace{}, xerrors.Errorf("get workspace owner: %w", err)
|
|
}
|
|
owner := workspaceOwner{
|
|
ID: ownerUser.ID,
|
|
Username: ownerUser.Username,
|
|
AvatarURL: ownerUser.AvatarURL,
|
|
}
|
|
|
|
auditor := api.Auditor.Load()
|
|
if auditor == nil {
|
|
return codersdk.Workspace{}, xerrors.New("auditor is not configured")
|
|
}
|
|
|
|
// The audit system requires a ResponseWriter to capture the
|
|
// HTTP status code. Since this is a programmatic call, we use
|
|
// a recorder. The audit entry still captures the owner, action,
|
|
// and resource correctly.
|
|
rw := httptest.NewRecorder()
|
|
sw := &tracing.StatusWriter{ResponseWriter: rw}
|
|
|
|
// Build a minimal synthetic request so the audit commit
|
|
// closure can extract a request ID and user agent. The RBAC
|
|
// subject is already on the context via dbauthz.As above.
|
|
auditReq, err := http.NewRequestWithContext(
|
|
httpmw.WithRequestID(ctx, uuid.New()),
|
|
http.MethodPost,
|
|
"http://localhost/internal/chat/workspace",
|
|
nil,
|
|
)
|
|
if err != nil {
|
|
return codersdk.Workspace{}, xerrors.Errorf("create audit request: %w", err)
|
|
}
|
|
|
|
aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](sw, &audit.RequestParams{
|
|
Audit: *auditor,
|
|
Log: api.Logger,
|
|
Request: auditReq,
|
|
Action: database.AuditActionCreate,
|
|
AdditionalFields: audit.AdditionalFields{
|
|
WorkspaceOwner: owner.Username,
|
|
},
|
|
})
|
|
aReq.UserID = ownerID
|
|
defer commitAudit()
|
|
|
|
workspace, err := createWorkspace(ctx, aReq, ownerID, api, owner, req, nil)
|
|
if err != nil {
|
|
sw.WriteHeader(chatWorkspaceAuditStatus(err))
|
|
return codersdk.Workspace{}, err
|
|
}
|
|
|
|
sw.WriteHeader(http.StatusCreated)
|
|
return workspace, nil
|
|
}
|
|
|
|
// chatStartWorkspace starts a stopped workspace by creating a new
|
|
// build with the "start" transition. It mirrors chatCreateWorkspace
|
|
// but for the start path.
|
|
//
|
|
// Aliased as ChatStartWorkspace in coderd/export_test.go so external
|
|
// tests in the coderd_test package can drive the auto-update path
|
|
// end-to-end. The proper fix is to extract the request building into
|
|
// a pure function; tracked in CODAGT-292.
|
|
func (api *API) chatStartWorkspace(
|
|
ctx context.Context,
|
|
ownerID uuid.UUID,
|
|
workspaceID uuid.UUID,
|
|
req codersdk.CreateWorkspaceBuildRequest,
|
|
) (codersdk.WorkspaceBuild, error) {
|
|
actor, _, err := httpmw.UserRBACSubject(ctx, api.Database, ownerID, rbac.ScopeAll)
|
|
if err != nil {
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("load user authorization: %w", err)
|
|
}
|
|
ctx = dbauthz.As(ctx, actor)
|
|
|
|
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
|
|
if err != nil {
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("get workspace: %w", err)
|
|
}
|
|
|
|
updatedToActiveVersion := false
|
|
if req.Transition == codersdk.WorkspaceTransitionStart {
|
|
template, err := api.Database.GetTemplateByID(ctx, workspace.TemplateID)
|
|
if err != nil {
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("get template: %w", err)
|
|
}
|
|
|
|
templateAccessControl := (*(api.AccessControlStore.Load())).GetTemplateAccessControl(template)
|
|
if templateAccessControl.RequireActiveVersion {
|
|
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
|
|
if err != nil {
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("get latest workspace build: %w", err)
|
|
}
|
|
|
|
updatedToActiveVersion = latestBuild.TemplateVersionID != template.ActiveVersionID
|
|
req.TemplateVersionID = template.ActiveVersionID
|
|
}
|
|
}
|
|
|
|
// Build a synthetic API key so postWorkspaceBuildsInternal can
|
|
// record the correct initiator.
|
|
syntheticKey := database.APIKey{
|
|
UserID: ownerID,
|
|
}
|
|
|
|
apiBuild, err := api.postWorkspaceBuildsInternal(
|
|
ctx,
|
|
syntheticKey,
|
|
workspace,
|
|
req,
|
|
func(action policy.Action, object rbac.Objecter) bool {
|
|
// Authorization is handled by dbauthz on the context.
|
|
authErr := api.HTTPAuth.Authorizer.Authorize(ctx, actor, action, object.RBACObject())
|
|
return authErr == nil
|
|
},
|
|
audit.WorkspaceBuildBaggage{},
|
|
)
|
|
if err != nil {
|
|
if updatedToActiveVersion && isChatStartWorkspaceManualUpdateRequiredError(err) {
|
|
const retryInstructions = "The workspace needs the template's active version before it can start. Use read_template with this workspace's template_id to inspect the active version's required parameters, then retry start_workspace with a parameters object that supplies any missing or changed values. If the correct value for a parameter is not obvious from its description or defaults, ask the user rather than guessing."
|
|
if responder, ok := httperror.IsResponder(err); ok {
|
|
status, resp := responder.Response()
|
|
resp = rewriteChatStartWorkspaceManualUpdateResponse(resp, err.Error(), retryInstructions)
|
|
return codersdk.WorkspaceBuild{}, httperror.NewResponseError(status, resp)
|
|
}
|
|
return codersdk.WorkspaceBuild{}, httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
|
|
Message: retryInstructions,
|
|
Detail: err.Error(),
|
|
})
|
|
}
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("create workspace build: %w", err)
|
|
}
|
|
|
|
return apiBuild, nil
|
|
}
|
|
|
|
// chatStopWorkspace stops a workspace by creating a new build with the
|
|
// "stop" transition. It mirrors chatStartWorkspace, without start-only
|
|
// active-version behavior.
|
|
func (api *API) chatStopWorkspace(
|
|
ctx context.Context,
|
|
ownerID uuid.UUID,
|
|
workspaceID uuid.UUID,
|
|
req codersdk.CreateWorkspaceBuildRequest,
|
|
) (codersdk.WorkspaceBuild, error) {
|
|
actor, _, err := httpmw.UserRBACSubject(ctx, api.Database, ownerID, rbac.ScopeAll)
|
|
if err != nil {
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("load user authorization: %w", err)
|
|
}
|
|
ctx = dbauthz.As(ctx, actor)
|
|
|
|
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
|
|
if err != nil {
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("get workspace: %w", err)
|
|
}
|
|
|
|
req.Transition = codersdk.WorkspaceTransitionStop
|
|
|
|
// Build a synthetic API key so postWorkspaceBuildsInternal can
|
|
// record the correct initiator.
|
|
syntheticKey := database.APIKey{
|
|
UserID: ownerID,
|
|
}
|
|
|
|
apiBuild, err := api.postWorkspaceBuildsInternal(
|
|
ctx,
|
|
syntheticKey,
|
|
workspace,
|
|
req,
|
|
func(action policy.Action, object rbac.Objecter) bool {
|
|
// Authorization is handled by dbauthz on the context.
|
|
authErr := api.HTTPAuth.Authorizer.Authorize(ctx, actor, action, object.RBACObject())
|
|
return authErr == nil
|
|
},
|
|
audit.WorkspaceBuildBaggage{},
|
|
)
|
|
if err != nil {
|
|
return codersdk.WorkspaceBuild{}, xerrors.Errorf("create workspace build: %w", err)
|
|
}
|
|
|
|
return apiBuild, nil
|
|
}
|
|
|
|
func rewriteChatStartWorkspaceManualUpdateResponse(resp codersdk.Response, fallbackDetail string, retryInstructions string) codersdk.Response {
|
|
originalMessage := resp.Message
|
|
resp.Message = retryInstructions
|
|
if len(resp.Validations) == 0 && originalMessage != "" {
|
|
if resp.Detail == "" {
|
|
resp.Detail = originalMessage
|
|
} else {
|
|
resp.Detail = originalMessage + ": " + resp.Detail
|
|
}
|
|
} else if resp.Detail == "" {
|
|
resp.Detail = fallbackDetail
|
|
}
|
|
return resp
|
|
}
|
|
|
|
func isChatStartWorkspaceManualUpdateRequiredError(err error) bool {
|
|
var diagnosticErr *dynamicparameters.DiagnosticError
|
|
if errors.As(err, &diagnosticErr) {
|
|
return true
|
|
}
|
|
|
|
return errors.Is(err, wsbuilder.ErrParameterValidation)
|
|
}
|
|
|
|
func chatWorkspaceAuditStatus(err error) int {
|
|
if responder, ok := httperror.IsResponder(err); ok {
|
|
status, _ := responder.Response()
|
|
return status
|
|
}
|
|
return http.StatusInternalServerError
|
|
}
|
|
|
|
func (api *API) resolveChatDiffContents(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (codersdk.ChatDiffContents, error) {
|
|
result := codersdk.ChatDiffContents{ChatID: chat.ID}
|
|
|
|
status, found, err := api.getCachedChatDiffStatus(ctx, chat.ID)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
reference, err := api.resolveChatDiffReference(ctx, chat, found, status)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
if reference.RepositoryRef != nil {
|
|
provider := strings.TrimSpace(reference.RepositoryRef.Provider)
|
|
if provider != "" {
|
|
result.Provider = &provider
|
|
}
|
|
|
|
origin := strings.TrimSpace(reference.RepositoryRef.RemoteOrigin)
|
|
if origin != "" {
|
|
result.RemoteOrigin = &origin
|
|
}
|
|
|
|
branch := strings.TrimSpace(reference.RepositoryRef.Branch)
|
|
if branch != "" {
|
|
result.Branch = &branch
|
|
}
|
|
}
|
|
|
|
if reference.PullRequestURL != "" {
|
|
pullRequestURL := strings.TrimSpace(reference.PullRequestURL)
|
|
result.PullRequestURL = &pullRequestURL
|
|
if !found || !strings.EqualFold(strings.TrimSpace(status.Url.String), pullRequestURL) {
|
|
_, err := api.upsertChatDiffStatusReference(ctx, chat.ID, pullRequestURL, time.Now().UTC().Add(-time.Second))
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
}
|
|
}
|
|
|
|
if reference.RepositoryRef == nil {
|
|
return result, nil
|
|
}
|
|
|
|
gp := api.resolveGitProvider(reference.RepositoryRef.RemoteOrigin)
|
|
if gp == nil {
|
|
return result, nil
|
|
}
|
|
|
|
token, err := api.resolveChatGitAccessToken(ctx, chat.OwnerID, reference.RepositoryRef.RemoteOrigin)
|
|
if err != nil {
|
|
return result, xerrors.Errorf("resolve git access token: %w", err)
|
|
} else if token == nil {
|
|
return result, xerrors.New("nil git access token")
|
|
}
|
|
|
|
if reference.PullRequestURL != "" {
|
|
ref, ok := gp.ParsePullRequestURL(reference.PullRequestURL)
|
|
if !ok {
|
|
return result, xerrors.Errorf("invalid pull request URL %q", reference.PullRequestURL)
|
|
}
|
|
diff, err := gp.FetchPullRequestDiff(ctx, *token, ref)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
result.Diff = diff
|
|
return result, nil
|
|
}
|
|
diff, err := gp.FetchBranchDiff(ctx, *token, gitprovider.BranchRef{
|
|
Owner: reference.RepositoryRef.Owner,
|
|
Repo: reference.RepositoryRef.Repo,
|
|
Branch: reference.RepositoryRef.Branch,
|
|
})
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
result.Diff = diff
|
|
return result, nil
|
|
}
|
|
|
|
// resolveChatDiffReference builds the diff reference from the cached
|
|
// status stored in the database. The git branch and remote origin are
|
|
// populated by the workspace agent during git operations (via the
|
|
// gitaskpass flow), so no SSH into the workspace is needed here.
|
|
//
|
|
//nolint:revive // Boolean indicates whether diff status was found.
|
|
func (api *API) resolveChatDiffReference(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
found bool,
|
|
status database.ChatDiffStatus,
|
|
) (chatDiffReference, error) {
|
|
reference := chatDiffReference{}
|
|
if !found {
|
|
return reference, nil
|
|
}
|
|
|
|
reference.PullRequestURL = strings.TrimSpace(status.Url.String)
|
|
|
|
// Build the repository ref from the stored git branch/origin
|
|
// that the agent reported.
|
|
reference.RepositoryRef = api.buildChatRepositoryRefFromStatus(status)
|
|
|
|
// If we have a repo ref with a branch, try to resolve the
|
|
// current open PR. This picks up new PRs after the previous
|
|
// one was closed.
|
|
if reference.RepositoryRef != nil && reference.RepositoryRef.Owner != "" {
|
|
gp := api.resolveGitProvider(reference.RepositoryRef.RemoteOrigin)
|
|
if gp != nil {
|
|
token, err := api.resolveChatGitAccessToken(ctx, chat.OwnerID, reference.RepositoryRef.RemoteOrigin)
|
|
if token == nil || errors.Is(err, gitsync.ErrNoTokenAvailable) {
|
|
// No token available yet.
|
|
return reference, nil
|
|
} else if err != nil {
|
|
return chatDiffReference{}, xerrors.Errorf("resolve git access token: %w", err)
|
|
}
|
|
prRef, lookupErr := gp.ResolveBranchPullRequest(ctx, *token, gitprovider.BranchRef{
|
|
Owner: reference.RepositoryRef.Owner,
|
|
Repo: reference.RepositoryRef.Repo,
|
|
Branch: reference.RepositoryRef.Branch,
|
|
})
|
|
if lookupErr != nil {
|
|
api.Logger.Debug(ctx, "failed to resolve pull request from repository reference",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("provider", reference.RepositoryRef.Provider),
|
|
slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin),
|
|
slog.F("branch", reference.RepositoryRef.Branch),
|
|
slog.Error(lookupErr),
|
|
)
|
|
} else if prRef != nil {
|
|
reference.PullRequestURL = gp.BuildPullRequestURL(*prRef)
|
|
}
|
|
reference.PullRequestURL = gp.NormalizePullRequestURL(reference.PullRequestURL)
|
|
}
|
|
}
|
|
|
|
// If we have a PR URL but no repo ref (e.g. the agent hasn't
|
|
// reported branch/origin yet), derive a partial ref from the
|
|
// PR URL so the caller can still show provider/owner/repo.
|
|
if reference.RepositoryRef == nil && reference.PullRequestURL != "" {
|
|
for _, extAuth := range api.ExternalAuthConfigs {
|
|
gp := extAuth.Git(api.HTTPClient)
|
|
if gp == nil {
|
|
continue
|
|
}
|
|
if parsed, ok := gp.ParsePullRequestURL(reference.PullRequestURL); ok {
|
|
reference.RepositoryRef = &chatRepositoryRef{
|
|
Provider: strings.ToLower(extAuth.Type),
|
|
Owner: parsed.Owner,
|
|
Repo: parsed.Repo,
|
|
RemoteOrigin: gp.BuildRepositoryURL(parsed.Owner, parsed.Repo),
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
return reference, nil
|
|
}
|
|
|
|
// buildChatRepositoryRefFromStatus constructs a chatRepositoryRef
|
|
// from the git branch and remote origin stored in the cached status.
|
|
// Returns nil if no ref data is available.
|
|
func (api *API) buildChatRepositoryRefFromStatus(status database.ChatDiffStatus) *chatRepositoryRef {
|
|
branch := strings.TrimSpace(status.GitBranch)
|
|
origin := strings.TrimSpace(status.GitRemoteOrigin)
|
|
if branch == "" || origin == "" {
|
|
return nil
|
|
}
|
|
|
|
providerType, gp := api.resolveExternalAuth(origin)
|
|
repoRef := &chatRepositoryRef{
|
|
Provider: providerType,
|
|
RemoteOrigin: origin,
|
|
Branch: branch,
|
|
}
|
|
if gp != nil {
|
|
if owner, repo, normalizedOrigin, ok := gp.ParseRepositoryOrigin(repoRef.RemoteOrigin); ok {
|
|
repoRef.RemoteOrigin = normalizedOrigin
|
|
repoRef.Owner = owner
|
|
repoRef.Repo = repo
|
|
}
|
|
}
|
|
|
|
if repoRef.Provider == "" {
|
|
return nil
|
|
}
|
|
|
|
return repoRef
|
|
}
|
|
|
|
func (api *API) upsertChatDiffStatusReference(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
pullRequestURL string,
|
|
staleAt time.Time,
|
|
) (database.ChatDiffStatus, error) {
|
|
status, err := api.Database.UpsertChatDiffStatusReference(
|
|
ctx,
|
|
database.UpsertChatDiffStatusReferenceParams{
|
|
ChatID: chatID,
|
|
Url: sql.NullString{
|
|
String: pullRequestURL,
|
|
Valid: strings.TrimSpace(pullRequestURL) != "",
|
|
},
|
|
// Empty strings preserve existing values via the
|
|
// CASE expression in the SQL query.
|
|
GitBranch: "",
|
|
GitRemoteOrigin: "",
|
|
StaleAt: staleAt,
|
|
},
|
|
)
|
|
if err != nil {
|
|
return database.ChatDiffStatus{}, xerrors.Errorf("upsert chat diff status reference: %w", err)
|
|
}
|
|
return status, nil
|
|
}
|
|
|
|
func (api *API) getCachedChatDiffStatus(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
) (database.ChatDiffStatus, bool, error) {
|
|
status, err := api.Database.GetChatDiffStatusByChatID(ctx, chatID)
|
|
if err == nil {
|
|
return status, true, nil
|
|
}
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return database.ChatDiffStatus{}, false, nil
|
|
}
|
|
return database.ChatDiffStatus{}, false, xerrors.Errorf(
|
|
"get chat diff status: %w",
|
|
err,
|
|
)
|
|
}
|
|
|
|
// resolveExternalAuth finds the external auth config matching the
|
|
// given remote origin URL and returns both the provider type string
|
|
// (e.g. "github") and the gitprovider.Provider. Returns ("", nil)
|
|
// if no matching config is found.
|
|
func (api *API) resolveExternalAuth(origin string) (providerType string, gp gitprovider.Provider) {
|
|
origin = strings.TrimSpace(origin)
|
|
if origin == "" {
|
|
return "", nil
|
|
}
|
|
for _, extAuth := range api.ExternalAuthConfigs {
|
|
if extAuth.Regex == nil || !extAuth.Regex.MatchString(origin) {
|
|
continue
|
|
}
|
|
return strings.ToLower(strings.TrimSpace(extAuth.Type)),
|
|
extAuth.Git(api.HTTPClient)
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
// resolveGitProvider finds the external auth config matching the
|
|
// given remote origin URL and returns its git provider. Returns
|
|
// nil if no matching git provider is configured.
|
|
func (api *API) resolveGitProvider(origin string) gitprovider.Provider {
|
|
_, gp := api.resolveExternalAuth(origin)
|
|
return gp
|
|
}
|
|
|
|
func (api *API) resolveChatGitAccessToken(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
origin string,
|
|
) (*string, error) {
|
|
origin = strings.TrimSpace(origin)
|
|
|
|
// If we have an origin, find the specific matching config first.
|
|
// This ensures multi-provider setups (github.com + GHE) get the
|
|
// correct token.
|
|
if origin != "" {
|
|
for _, config := range api.ExternalAuthConfigs {
|
|
if config.Regex == nil || !config.Regex.MatchString(origin) {
|
|
continue
|
|
}
|
|
//nolint:gocritic // System access needed to read external auth
|
|
// links when called from the gitsync worker (chatd context).
|
|
link, err := api.Database.GetExternalAuthLink(dbauthz.AsSystemRestricted(ctx),
|
|
database.GetExternalAuthLinkParams{
|
|
ProviderID: config.ID,
|
|
UserID: userID,
|
|
},
|
|
)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
//nolint:gocritic // System context carried through for token refresh.
|
|
refreshed, refreshErr := config.RefreshToken(dbauthz.AsSystemRestricted(ctx), api.Database, link)
|
|
if refreshErr == nil {
|
|
link = refreshed
|
|
}
|
|
token := strings.TrimSpace(link.OAuthAccessToken)
|
|
if token != "" {
|
|
return ptr.Ref(token), nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fallback: iterate all external auth configs.
|
|
// Used when origin is empty (inline refresh from HTTP handler)
|
|
// or when the origin-specific lookup above failed.
|
|
configs := make(map[string]*externalauth.Config)
|
|
providerIDs := []string{}
|
|
for _, config := range api.ExternalAuthConfigs {
|
|
providerIDs = append(providerIDs, config.ID)
|
|
configs[config.ID] = config
|
|
}
|
|
|
|
seen := map[string]struct{}{}
|
|
for _, providerID := range providerIDs {
|
|
if _, ok := seen[providerID]; ok {
|
|
continue
|
|
}
|
|
seen[providerID] = struct{}{}
|
|
|
|
//nolint:gocritic // System access needed to read external auth
|
|
// links when called from the gitsync worker (chatd context).
|
|
link, err := api.Database.GetExternalAuthLink(
|
|
dbauthz.AsSystemRestricted(ctx),
|
|
database.GetExternalAuthLinkParams{
|
|
ProviderID: providerID,
|
|
UserID: userID,
|
|
},
|
|
)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Refresh the token if there is a matching config, mirroring
|
|
// the same code path used by provisionerdserver when handing
|
|
// tokens to provisioners.
|
|
if cfg, ok := configs[providerID]; ok {
|
|
//nolint:gocritic // System context carried through for token refresh.
|
|
refreshed, refreshErr := cfg.RefreshToken(dbauthz.AsSystemRestricted(ctx), api.Database, link)
|
|
if refreshErr != nil {
|
|
api.Logger.Debug(ctx, "failed to refresh external auth token for chat diff",
|
|
slog.F("provider_id", providerID),
|
|
slog.F("user_id", userID),
|
|
slog.Error(refreshErr),
|
|
)
|
|
// Fall through — the existing token may still work
|
|
// (e.g. GitHub tokens with no expiry).
|
|
} else {
|
|
link = refreshed
|
|
}
|
|
}
|
|
|
|
token := strings.TrimSpace(link.OAuthAccessToken)
|
|
if token != "" {
|
|
return ptr.Ref(token), nil
|
|
}
|
|
}
|
|
|
|
return nil, gitsync.ErrNoTokenAvailable
|
|
}
|
|
|
|
type createChatWorkspaceSelection struct {
|
|
WorkspaceID uuid.NullUUID
|
|
}
|
|
|
|
func (api *API) validateChatWorkspaceSelection(
|
|
ctx context.Context,
|
|
r *http.Request,
|
|
workspaceID *uuid.UUID,
|
|
) (
|
|
uuid.NullUUID,
|
|
database.Workspace,
|
|
int,
|
|
*codersdk.Response,
|
|
) {
|
|
if workspaceID == nil {
|
|
return uuid.NullUUID{}, database.Workspace{}, 0, nil
|
|
}
|
|
|
|
workspace, err := api.Database.GetWorkspaceByID(ctx, *workspaceID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
return uuid.NullUUID{}, database.Workspace{}, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Workspace not found or you do not have access to this resource",
|
|
}
|
|
}
|
|
return uuid.NullUUID{}, database.Workspace{}, http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Failed to get workspace.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
|
|
selection := uuid.NullUUID{
|
|
UUID: workspace.ID,
|
|
Valid: true,
|
|
}
|
|
if !api.Authorize(r, policy.ActionSSH, workspace) {
|
|
return uuid.NullUUID{}, database.Workspace{}, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Workspace not found or you do not have access to this resource",
|
|
}
|
|
}
|
|
|
|
return selection, workspace, 0, nil
|
|
}
|
|
|
|
func (api *API) validateCreateChatWorkspaceSelection(
|
|
ctx context.Context,
|
|
r *http.Request,
|
|
req codersdk.CreateChatRequest,
|
|
) (
|
|
createChatWorkspaceSelection,
|
|
int,
|
|
*codersdk.Response,
|
|
) {
|
|
selection := createChatWorkspaceSelection{}
|
|
workspaceID, workspace, status, resp := api.validateChatWorkspaceSelection(ctx, r, req.WorkspaceID)
|
|
if resp != nil {
|
|
return selection, status, resp
|
|
}
|
|
selection.WorkspaceID = workspaceID
|
|
if !workspaceID.Valid {
|
|
return selection, 0, nil
|
|
}
|
|
if workspace.OrganizationID != req.OrganizationID {
|
|
return selection, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Workspace does not belong to the specified organization.",
|
|
}
|
|
}
|
|
|
|
return selection, 0, nil
|
|
}
|
|
|
|
func (api *API) resolveCreateChatModelConfigID(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
req codersdk.CreateChatRequest,
|
|
) (uuid.UUID, int, *codersdk.Response) {
|
|
if req.ModelConfigID != nil {
|
|
if *req.ModelConfigID == uuid.Nil {
|
|
return uuid.Nil, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "Invalid model config ID.",
|
|
}
|
|
}
|
|
return *req.ModelConfigID, 0, nil
|
|
}
|
|
|
|
personalOverridesEnabled, err := api.Database.GetChatPersonalModelOverridesEnabled(ctx)
|
|
if err != nil {
|
|
return uuid.Nil, http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Failed to resolve chat model config.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
if !personalOverridesEnabled {
|
|
return api.defaultCreateChatModelConfigID(ctx)
|
|
}
|
|
|
|
raw, err := api.Database.GetUserChatPersonalModelOverride(ctx, database.GetUserChatPersonalModelOverrideParams{
|
|
UserID: userID,
|
|
Key: chatd.ChatPersonalModelOverrideKey(codersdk.ChatPersonalModelOverrideContextRoot),
|
|
})
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return uuid.Nil, http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Failed to resolve chat model config.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
if err == nil {
|
|
parsed := parseChatPersonalModelOverrideValue(
|
|
raw,
|
|
codersdk.ChatPersonalModelOverrideContextRoot,
|
|
)
|
|
if parsed.Malformed {
|
|
api.Logger.Debug(
|
|
ctx,
|
|
"unsupported personal root model override mode, using default model",
|
|
slog.F("user_id", userID),
|
|
slog.F("raw_value", raw),
|
|
)
|
|
}
|
|
switch parsed.Mode {
|
|
case codersdk.ChatPersonalModelOverrideModeChatDefault:
|
|
// For root context, chat_default and the defensive default
|
|
// case both fall through to the deployment default model below.
|
|
case codersdk.ChatPersonalModelOverrideModeModel:
|
|
reason, err := api.userCanUseChatModelConfig(
|
|
ctx,
|
|
userID,
|
|
parsed.ModelConfigID,
|
|
)
|
|
if err != nil {
|
|
return uuid.Nil, http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Failed to resolve chat model config.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
if reason == chatModelConfigAvailable {
|
|
return parsed.ModelConfigID, 0, nil
|
|
}
|
|
api.Logger.Debug(
|
|
ctx,
|
|
"personal root model override is unavailable, using default model",
|
|
slog.F("user_id", userID),
|
|
slog.F("model_config_id", parsed.ModelConfigID),
|
|
slog.F("reason", reason),
|
|
)
|
|
default:
|
|
api.Logger.Warn(
|
|
ctx,
|
|
"unsupported personal root model override mode, using default model",
|
|
slog.F("user_id", userID),
|
|
slog.F("mode", parsed.Mode),
|
|
)
|
|
}
|
|
}
|
|
|
|
return api.defaultCreateChatModelConfigID(ctx)
|
|
}
|
|
|
|
func (api *API) defaultCreateChatModelConfigID(
|
|
ctx context.Context,
|
|
) (uuid.UUID, int, *codersdk.Response) {
|
|
defaultModelConfig, err := api.Database.GetDefaultChatModelConfig(ctx)
|
|
if err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return uuid.Nil, http.StatusBadRequest, &codersdk.Response{
|
|
Message: "No default chat model config is configured.",
|
|
}
|
|
}
|
|
return uuid.Nil, http.StatusInternalServerError, &codersdk.Response{
|
|
Message: "Failed to resolve chat model config.",
|
|
Detail: err.Error(),
|
|
}
|
|
}
|
|
|
|
return defaultModelConfig.ID, 0, nil
|
|
}
|
|
|
|
func normalizeChatCompressionThreshold(
|
|
requested *int32,
|
|
fallback int32,
|
|
) (int32, error) {
|
|
threshold := fallback
|
|
if requested != nil {
|
|
threshold = *requested
|
|
}
|
|
|
|
if threshold < minChatContextCompressionThreshold ||
|
|
threshold > maxChatContextCompressionThreshold {
|
|
return 0, xerrors.Errorf(
|
|
"context_compression_threshold must be between %d and %d",
|
|
minChatContextCompressionThreshold,
|
|
maxChatContextCompressionThreshold,
|
|
)
|
|
}
|
|
|
|
return threshold, nil
|
|
}
|
|
|
|
func parseCompactionThresholdKey(key string) (uuid.UUID, error) {
|
|
if !strings.HasPrefix(key, codersdk.ChatCompactionThresholdKeyPrefix) {
|
|
return uuid.Nil, xerrors.Errorf("invalid compaction threshold key: %q", key)
|
|
}
|
|
id, err := uuid.Parse(key[len(codersdk.ChatCompactionThresholdKeyPrefix):])
|
|
if err != nil {
|
|
return uuid.Nil, xerrors.Errorf("invalid model config ID in key %q: %w", key, err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
config, err := api.Database.GetChatSystemPromptConfig(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching chat system prompt configuration.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPromptResponse{
|
|
SystemPrompt: config.ChatSystemPrompt,
|
|
IncludeDefaultSystemPrompt: config.IncludeDefaultSystemPrompt,
|
|
DefaultSystemPrompt: chatd.DefaultSystemPrompt,
|
|
})
|
|
}
|
|
|
|
func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
// Cap the raw request body to prevent excessive memory use from
|
|
// payloads padded with invisible characters that sanitize away.
|
|
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
|
var req codersdk.UpdateChatSystemPromptRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
sanitizedPrompt := chatd.SanitizePromptText(req.SystemPrompt)
|
|
// 128 KiB is generous for a system prompt while still
|
|
// preventing abuse or accidental pastes of large content.
|
|
if len(sanitizedPrompt) > maxSystemPromptLenBytes {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "System prompt exceeds maximum length.",
|
|
Detail: fmt.Sprintf("Maximum length is %d bytes, got %d.", maxSystemPromptLenBytes, len(sanitizedPrompt)),
|
|
})
|
|
return
|
|
}
|
|
err := api.Database.InTx(func(tx database.Store) error {
|
|
if err := tx.UpsertChatSystemPrompt(ctx, sanitizedPrompt); err != nil {
|
|
return err
|
|
}
|
|
// Only update the include-default flag when the caller explicitly
|
|
// provides it. Omitting the field preserves whatever is currently
|
|
// stored (or the schema-level default for new deployments),
|
|
// avoiding a backward-compatibility regression for older clients
|
|
// that only send system_prompt.
|
|
if req.IncludeDefaultSystemPrompt != nil {
|
|
return tx.UpsertChatIncludeDefaultSystemPrompt(ctx, *req.IncludeDefaultSystemPrompt)
|
|
}
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating chat system prompt configuration.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatPlanModeInstructions(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
instructions, err := api.Database.GetChatPlanModeInstructions(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching plan mode instructions.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatPlanModeInstructionsResponse{
|
|
PlanModeInstructions: instructions,
|
|
})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putChatPlanModeInstructions(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
// Cap the raw request body to prevent excessive memory use from
|
|
// payloads padded with invisible characters that sanitize away.
|
|
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
|
|
|
var req codersdk.UpdateChatPlanModeInstructionsRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
sanitizedInstructions := chatd.SanitizePromptText(req.PlanModeInstructions)
|
|
if len(sanitizedInstructions) > maxSystemPromptLenBytes {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Plan mode instructions exceed maximum length.",
|
|
Detail: fmt.Sprintf("Maximum length is %d bytes, got %d.", maxSystemPromptLenBytes, len(sanitizedInstructions)),
|
|
})
|
|
return
|
|
}
|
|
|
|
if err := api.Database.UpsertChatPlanModeInstructions(ctx, sanitizedInstructions); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating plan mode instructions.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func readChatModelOverrideContext(
|
|
rw http.ResponseWriter,
|
|
r *http.Request,
|
|
) (codersdk.ChatModelOverrideContext, bool) {
|
|
ctx := r.Context()
|
|
rawContext := chi.URLParam(r, "context")
|
|
overrideContext, err := parseChatModelOverrideContext(rawContext)
|
|
if err == nil {
|
|
return overrideContext, true
|
|
}
|
|
validContextValues := make(
|
|
[]string,
|
|
0,
|
|
len(codersdk.AllChatModelOverrideContexts()),
|
|
)
|
|
for _, overrideContext := range codersdk.AllChatModelOverrideContexts() {
|
|
validContextValues = append(validContextValues, string(overrideContext))
|
|
}
|
|
validContexts := strings.Join(validContextValues, ", ")
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat model override context.",
|
|
Detail: fmt.Sprintf(
|
|
"Expected one of %s. Got %q.",
|
|
validContexts,
|
|
rawContext,
|
|
),
|
|
})
|
|
return "", false
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatModelOverride(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
overrideContext, ok := readChatModelOverrideContext(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
modelConfigID, isMalformed, label, err := api.readChatModelOverrideConfig(ctx, overrideContext)
|
|
if err != nil {
|
|
if label == "" {
|
|
label = string(overrideContext)
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: fmt.Sprintf("Internal error fetching %s model override.", label),
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
resp := codersdk.ChatModelOverrideResponse{
|
|
Context: overrideContext,
|
|
ModelConfigID: formatChatModelOverride(modelConfigID),
|
|
IsMalformed: isMalformed,
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putChatModelOverride(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
overrideContext, ok := readChatModelOverrideContext(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatModelOverrideRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
modelConfigID, err := parseChatModelOverride(req.ModelConfigID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model_config_id.",
|
|
Detail: fmt.Sprintf("Value %q is not a valid UUID.", req.ModelConfigID),
|
|
})
|
|
return
|
|
}
|
|
|
|
status, resp := validateChatModelOverrideID(ctx, api.Database, modelConfigID)
|
|
if resp != nil {
|
|
httpapi.Write(ctx, rw, status, *resp)
|
|
return
|
|
}
|
|
|
|
label, err := api.upsertChatModelOverrideConfig(ctx, overrideContext, modelConfigID)
|
|
if err != nil {
|
|
if label == "" {
|
|
label = string(overrideContext)
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: fmt.Sprintf("Internal error updating %s model override.", label),
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func readChatPersonalModelOverrideContext(
|
|
rw http.ResponseWriter,
|
|
r *http.Request,
|
|
) (codersdk.ChatPersonalModelOverrideContext, bool) {
|
|
ctx := r.Context()
|
|
rawContext := chi.URLParam(r, "context")
|
|
overrideContext, ok := parseChatPersonalModelOverrideContext(rawContext)
|
|
if ok {
|
|
return overrideContext, true
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat personal model override context.",
|
|
Detail: fmt.Sprintf(
|
|
"Expected one of %s. Got %q.",
|
|
chatPersonalModelOverrideContextsJoined(),
|
|
rawContext,
|
|
),
|
|
})
|
|
return "", false
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatPersonalModelOverridesAdminSettings(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
enabled, err := api.Database.GetChatPersonalModelOverridesEnabled(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching personal model override setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatPersonalModelOverridesAdminSettings{
|
|
AllowUsers: enabled,
|
|
})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putChatPersonalModelOverridesAdminSettings(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatPersonalModelOverridesAdminSettingsRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if err := api.Database.UpsertChatPersonalModelOverridesEnabled(ctx, req.AllowUsers); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating personal model override setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getUserChatPersonalModelOverrides(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
enabled, err := api.Database.GetChatPersonalModelOverridesEnabled(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching personal model override setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rows, err := api.Database.ListUserChatPersonalModelOverrides(ctx, apiKey.UserID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching user personal model overrides.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
values := make(map[codersdk.ChatPersonalModelOverrideContext]string, len(rows))
|
|
for _, row := range rows {
|
|
rawContext, ok := strings.CutPrefix(row.Key, chatd.ChatPersonalModelOverrideKeyPrefix)
|
|
if !ok {
|
|
continue
|
|
}
|
|
overrideContext, ok := parseChatPersonalModelOverrideContext(rawContext)
|
|
if !ok {
|
|
continue
|
|
}
|
|
values[overrideContext] = row.Value
|
|
}
|
|
|
|
deploymentDefaults, err := api.chatPersonalModelOverrideDeploymentDefaults(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching deployment model defaults.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
response := codersdk.UserChatPersonalModelOverridesResponse{
|
|
Enabled: enabled,
|
|
DeploymentDefaults: deploymentDefaults,
|
|
}
|
|
for _, overrideContext := range chatPersonalModelOverrideContexts {
|
|
raw, isSet := values[overrideContext]
|
|
override := chatPersonalModelOverrideResponse(overrideContext, raw, isSet)
|
|
switch overrideContext {
|
|
case codersdk.ChatPersonalModelOverrideContextRoot:
|
|
response.Root = override
|
|
case codersdk.ChatPersonalModelOverrideContextGeneral:
|
|
response.General = override
|
|
case codersdk.ChatPersonalModelOverrideContextExplore:
|
|
response.Explore = override
|
|
}
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, response)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putUserChatPersonalModelOverride(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
enabled, err := api.Database.GetChatPersonalModelOverridesEnabled(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching personal model override setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if !enabled {
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
|
Message: "An administrator has not enabled user personal model overrides.",
|
|
})
|
|
return
|
|
}
|
|
|
|
overrideContext, ok := readChatPersonalModelOverrideContext(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateUserChatPersonalModelOverrideRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
modelConfigID := ""
|
|
rawModelConfigID := strings.TrimSpace(req.ModelConfigID)
|
|
switch req.Mode {
|
|
case codersdk.ChatPersonalModelOverrideModeChatDefault:
|
|
if rawModelConfigID != "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "model_config_id must be empty unless mode is model.",
|
|
})
|
|
return
|
|
}
|
|
case codersdk.ChatPersonalModelOverrideModeDeploymentDefault:
|
|
if overrideContext == codersdk.ChatPersonalModelOverrideContextRoot {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "deployment_default is not supported for root personal model overrides.",
|
|
})
|
|
return
|
|
}
|
|
if rawModelConfigID != "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "model_config_id must be empty unless mode is model.",
|
|
})
|
|
return
|
|
}
|
|
case codersdk.ChatPersonalModelOverrideModeModel:
|
|
if rawModelConfigID == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "model_config_id is required when mode is model.",
|
|
})
|
|
return
|
|
}
|
|
parsedModelConfigID, err := uuid.Parse(rawModelConfigID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model_config_id.",
|
|
Detail: fmt.Sprintf("Value %q is not a valid UUID.", req.ModelConfigID),
|
|
})
|
|
return
|
|
}
|
|
if parsedModelConfigID == uuid.Nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model_config_id.",
|
|
})
|
|
return
|
|
}
|
|
status, resp := api.validateUserChatModelConfigAvailable(ctx, apiKey.UserID, parsedModelConfigID)
|
|
if resp != nil {
|
|
httpapi.Write(ctx, rw, status, *resp)
|
|
return
|
|
}
|
|
modelConfigID = parsedModelConfigID.String()
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid personal model override mode.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if err := api.Database.UpsertUserChatPersonalModelOverride(ctx, database.UpsertUserChatPersonalModelOverrideParams{
|
|
UserID: apiKey.UserID,
|
|
Key: chatd.ChatPersonalModelOverrideKey(overrideContext),
|
|
Value: formatChatPersonalModelOverrideValue(req.Mode, modelConfigID),
|
|
}); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating user personal model override.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
enabled, err := api.Database.GetChatDesktopEnabled(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching desktop setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatDesktopEnabledResponse{
|
|
EnableDesktop: enabled,
|
|
})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatDesktopEnabledRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if err := api.Database.UpsertChatDesktopEnabled(ctx, req.EnableDesktop); httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
} else if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating desktop setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatComputerUseProvider(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
provider, err := api.Database.GetChatComputerUseProvider(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching computer use provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatComputerUseProviderResponse{
|
|
Provider: chattool.DefaultComputerUseProvider(provider),
|
|
})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putChatComputerUseProvider(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatComputerUseProviderRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if !chattool.IsSupportedComputerUseProvider(req.Provider) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid computer use provider.",
|
|
Detail: fmt.Sprintf(
|
|
"Expected one of: %s. Got %q.",
|
|
strings.Join(chattool.SupportedComputerUseProviders(), ", "),
|
|
req.Provider,
|
|
),
|
|
})
|
|
return
|
|
}
|
|
|
|
if err := api.Database.UpsertChatComputerUseProvider(ctx, req.Provider); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating computer use provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (api *API) deploymentChatDebugLoggingEnabled() bool {
|
|
return api.DeploymentValues != nil && api.DeploymentValues.AI.Chat.DebugLoggingEnabled.Value()
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatDebugLogging(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
allowUsers, err := api.Database.GetChatDebugLoggingAllowUsers(ctx)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching chat debug logging setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatDebugLoggingAdminSettings{
|
|
AllowUsers: err == nil && allowUsers,
|
|
ForcedByDeployment: api.deploymentChatDebugLoggingEnabled(),
|
|
})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putChatDebugLogging(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatDebugLoggingAllowUsersRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if err := api.Database.UpsertChatDebugLoggingAllowUsers(ctx, req.AllowUsers); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating chat debug logging setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getUserChatDebugLogging(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
forcedByDeployment := api.deploymentChatDebugLoggingEnabled()
|
|
allowUsers := false
|
|
if !forcedByDeployment {
|
|
enabled, err := api.Database.GetChatDebugLoggingAllowUsers(ctx)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching chat debug logging setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
allowUsers = err == nil && enabled
|
|
}
|
|
|
|
debugEnabled := forcedByDeployment
|
|
if allowUsers {
|
|
enabled, err := api.Database.GetUserChatDebugLoggingEnabled(ctx, apiKey.UserID)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching user chat debug logging setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
debugEnabled = err == nil && enabled
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatDebugLoggingSettings{
|
|
DebugLoggingEnabled: debugEnabled,
|
|
UserToggleAllowed: !forcedByDeployment && allowUsers,
|
|
ForcedByDeployment: forcedByDeployment,
|
|
})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putUserChatDebugLogging(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
if api.deploymentChatDebugLoggingEnabled() {
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat debug logging is already forced on by deployment configuration.",
|
|
})
|
|
return
|
|
}
|
|
|
|
allowUsers, err := api.Database.GetChatDebugLoggingAllowUsers(ctx)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching chat debug logging setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if err != nil || !allowUsers {
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
|
Message: "An administrator has not enabled user-controlled chat debug logging.",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateUserChatDebugLoggingRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if err := api.Database.UpsertUserChatDebugLoggingEnabled(ctx, database.UpsertUserChatDebugLoggingEnabledParams{
|
|
UserID: apiKey.UserID,
|
|
DebugLoggingEnabled: req.DebugLoggingEnabled,
|
|
}); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating user chat debug logging setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatAdvisorConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
raw, err := api.Database.GetChatAdvisorConfig(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching advisor configuration.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
var resp codersdk.AdvisorConfig
|
|
if err := json.Unmarshal([]byte(raw), &resp); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Stored advisor configuration is invalid.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
resp.MaxUsesPerRun = max(resp.MaxUsesPerRun, 0)
|
|
resp.MaxOutputTokens = max(resp.MaxOutputTokens, 0)
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putChatAdvisorConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateAdvisorConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if req.MaxUsesPerRun < 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("max_uses_per_run %d must be non-negative.", req.MaxUsesPerRun),
|
|
})
|
|
return
|
|
}
|
|
if req.MaxOutputTokens < 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("max_output_tokens %d must be non-negative.", req.MaxOutputTokens),
|
|
})
|
|
return
|
|
}
|
|
if req.ModelConfigID != uuid.Nil {
|
|
// Use system context because GetChatModelConfigByID requires
|
|
// deployment-config read access, which can be broader than the
|
|
// handler's explicit update check. The lookup only validates that
|
|
// the referenced model exists before persisting deployment config.
|
|
//nolint:gocritic // This admin-authorized validation lookup intentionally bypasses read authz.
|
|
if _, err := api.Database.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), req.ModelConfigID); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) || httpapi.Is404Error(err) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("model_config_id %q does not match any existing model config.", req.ModelConfigID),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error validating advisor model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
raw, err := json.Marshal(req)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error encoding advisor configuration.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if err := api.Database.UpsertChatAdvisorConfig(ctx, string(raw)); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating advisor configuration.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventAdvisorConfig, uuid.Nil)
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
raw, err := api.Database.GetChatWorkspaceTTL(ctx)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching workspace TTL setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
// Validate/default the stored value so callers always receive a
|
|
// well-formed duration string.
|
|
d, err := codersdk.ParseChatWorkspaceTTL(raw)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Stored workspace TTL is invalid.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatWorkspaceTTLResponse{
|
|
WorkspaceTTLMillis: d.Milliseconds(),
|
|
})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatWorkspaceTTLRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
// Validate before converting to avoid int64 overflow in the
|
|
// multiplication by time.Millisecond.
|
|
if req.WorkspaceTTLMillis < 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Workspace TTL must be non-negative.",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Convert milliseconds to duration.
|
|
d := time.Duration(req.WorkspaceTTLMillis) * time.Millisecond
|
|
|
|
// Technically a duplication of validWorkspaceTTL but this is not scoped to templates.
|
|
if d > 0 && d < ttlMinimum {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Workspace TTL must not be less than 1 minute.",
|
|
})
|
|
return
|
|
}
|
|
if d > ttlMaximum {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Workspace TTL must not exceed 30 days.",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Store the canonicalized duration string.
|
|
if err := api.Database.UpsertChatWorkspaceTTL(ctx, d.String()); httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
} else if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating workspace TTL setting.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// @Summary Get chat retention days
|
|
// @ID get-chat-retention-days
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Success 200 {object} codersdk.ChatRetentionDaysResponse
|
|
// @Router /api/experimental/chats/config/retention-days [get]
|
|
// @x-apidocgen {"skip": true}
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
retentionDays, err := api.Database.GetChatRetentionDays(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat retention days.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatRetentionDaysResponse{
|
|
RetentionDays: retentionDays,
|
|
})
|
|
}
|
|
|
|
// Keep in sync with retentionDaysMaximum in
|
|
// site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx.
|
|
const retentionDaysMaximum = 3650 // ~10 years
|
|
|
|
// @Summary Update chat retention days
|
|
// @ID update-chat-retention-days
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Accept json
|
|
// @Param request body codersdk.UpdateChatRetentionDaysRequest true "Request body"
|
|
// @Success 204
|
|
// @Router /api/experimental/chats/config/retention-days [put]
|
|
// @x-apidocgen {"skip": true}
|
|
func (api *API) putChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
var req codersdk.UpdateChatRetentionDaysRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if req.RetentionDays < 0 || req.RetentionDays > retentionDaysMaximum {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("Retention days must be between 0 and %d.", retentionDaysMaximum),
|
|
})
|
|
return
|
|
}
|
|
if err := api.Database.UpsertChatRetentionDays(ctx, req.RetentionDays); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat retention days.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// getChatDebugRetentionDays returns the deployment-wide chat debug run
|
|
// retention window. Any authenticated user can read it; writes require admin.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatDebugRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
retentionDays, err := api.Database.GetChatDebugRetentionDays(ctx, codersdk.DefaultChatDebugRetentionDays)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat debug retention days.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatDebugRetentionDaysResponse{
|
|
DebugRetentionDays: retentionDays,
|
|
})
|
|
}
|
|
|
|
// Keep in sync with the validation schema in
|
|
// site/src/pages/AgentsPage/components/DebugRetentionSettings.tsx.
|
|
const chatDebugRetentionDaysMaximum = 3650 // ~10 years
|
|
|
|
// putChatDebugRetentionDays updates the deployment-wide chat debug run
|
|
// retention window. Admin-only.
|
|
func (api *API) putChatDebugRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
var req codersdk.UpdateChatDebugRetentionDaysRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if req.DebugRetentionDays < 0 || req.DebugRetentionDays > chatDebugRetentionDaysMaximum {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("Chat debug retention days must be between 0 and %d.", chatDebugRetentionDaysMaximum),
|
|
})
|
|
return
|
|
}
|
|
if err := api.Database.UpsertChatDebugRetentionDays(ctx, req.DebugRetentionDays); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat debug retention days.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// getChatAutoArchiveDays returns the deployment-wide auto-archive
|
|
// window. Any authenticated user can read it (same as retention
|
|
// days); writes require admin.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatAutoArchiveDays(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
autoArchiveDays, err := api.Database.GetChatAutoArchiveDays(ctx, codersdk.DefaultChatAutoArchiveDays)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat auto-archive days.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatAutoArchiveDaysResponse{
|
|
AutoArchiveDays: autoArchiveDays,
|
|
})
|
|
}
|
|
|
|
// Upper bound for the auto-archive window. Keep in sync with
|
|
// the validation schema in site/src/pages/AgentsPage/components/AutoArchiveSettings.tsx.
|
|
const autoArchiveDaysMaximum = 3650 // ~10 years
|
|
|
|
// putChatAutoArchiveDays updates the deployment-wide auto-archive
|
|
// window. Admin-only; documented in docs/ai-coder/agents/chats-api.md.
|
|
func (api *API) putChatAutoArchiveDays(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
var req codersdk.UpdateChatAutoArchiveDaysRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if req.AutoArchiveDays < 0 || req.AutoArchiveDays > autoArchiveDaysMaximum {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: fmt.Sprintf("Auto-archive days must be between 0 and %d.", autoArchiveDaysMaximum),
|
|
})
|
|
return
|
|
}
|
|
if err := api.Database.UpsertChatAutoArchiveDays(ctx, req.AutoArchiveDays); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat auto-archive days.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
raw, err := api.Database.GetChatTemplateAllowlist(ctx)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching chat template allowlist.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
parsed, parseErr := xjson.ParseUUIDList(raw)
|
|
if parseErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Stored template allowlist is corrupt.",
|
|
Detail: parseErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
ids := make([]string, len(parsed))
|
|
for i, id := range parsed {
|
|
ids[i] = id.String()
|
|
}
|
|
resp := codersdk.ChatTemplateAllowlist{
|
|
TemplateIDs: ids,
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.ChatTemplateAllowlist
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
// Validate all entries are valid UUIDs and deduplicate.
|
|
seen := make(map[string]struct{}, len(req.TemplateIDs))
|
|
deduped := make([]string, 0, len(req.TemplateIDs))
|
|
for _, id := range req.TemplateIDs {
|
|
parsed, err := uuid.Parse(id)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid template ID in allowlist.",
|
|
Detail: fmt.Sprintf("%q is not a valid UUID.", id),
|
|
})
|
|
return
|
|
}
|
|
// Canonicalize to lowercase so deduplication is
|
|
// case-insensitive and stored values are consistent.
|
|
canonical := parsed.String()
|
|
if _, ok := seen[canonical]; !ok {
|
|
seen[canonical] = struct{}{}
|
|
deduped = append(deduped, canonical)
|
|
}
|
|
}
|
|
|
|
// Convert to UUIDs for the database query.
|
|
parsedUUIDs := make([]uuid.UUID, len(deduped))
|
|
for i, s := range deduped {
|
|
// Already validated above, safe to ignore error.
|
|
parsedUUIDs[i], _ = uuid.Parse(s)
|
|
}
|
|
|
|
raw, err := json.Marshal(deduped)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error encoding template allowlist.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
err = api.Database.InTx(func(tx database.Store) error {
|
|
// Verify all IDs refer to existing, non-deprecated templates
|
|
// in a single query.
|
|
if len(parsedUUIDs) > 0 {
|
|
found, err := tx.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{
|
|
IDs: parsedUUIDs,
|
|
Deprecated: sql.NullBool{
|
|
Bool: false,
|
|
Valid: true,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("fetch templates: %w", err)
|
|
}
|
|
if len(found) != len(parsedUUIDs) {
|
|
foundSet := make(map[uuid.UUID]struct{}, len(found))
|
|
for _, t := range found {
|
|
foundSet[t.ID] = struct{}{}
|
|
}
|
|
var missing []string
|
|
for _, id := range parsedUUIDs {
|
|
if _, ok := foundSet[id]; !ok {
|
|
missing = append(missing, id.String())
|
|
}
|
|
}
|
|
return xerrors.Errorf("templates not found or deprecated: %s", strings.Join(missing, ", "))
|
|
}
|
|
}
|
|
return tx.UpsertChatTemplateAllowlist(ctx, string(raw))
|
|
}, nil)
|
|
if err != nil {
|
|
// If the error mentions "not found or deprecated", it's a
|
|
// validation failure, not an internal error.
|
|
if strings.Contains(err.Error(), "not found or deprecated") {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "One or more templates not found or deprecated.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error updating chat template allowlist.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request) {
|
|
var (
|
|
ctx = r.Context()
|
|
apiKey = httpmw.APIKey(r)
|
|
)
|
|
|
|
customPrompt, err := api.Database.GetUserChatCustomPrompt(ctx, apiKey.UserID)
|
|
if err != nil {
|
|
if !errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Error reading user chat custom prompt.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
customPrompt = ""
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
|
|
CustomPrompt: customPrompt,
|
|
})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request) {
|
|
var (
|
|
ctx = r.Context()
|
|
apiKey = httpmw.APIKey(r)
|
|
)
|
|
// Cap the raw request body to prevent excessive memory use from
|
|
// payloads padded with invisible characters that sanitize away.
|
|
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
|
|
|
var params codersdk.UserChatCustomPrompt
|
|
if !httpapi.Read(ctx, rw, r, ¶ms) {
|
|
return
|
|
}
|
|
|
|
sanitizedPrompt := chatd.SanitizePromptText(params.CustomPrompt)
|
|
// Apply the same 128 KiB limit as the deployment system prompt.
|
|
if len(sanitizedPrompt) > maxSystemPromptLenBytes {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Custom prompt exceeds maximum length.",
|
|
Detail: fmt.Sprintf("Maximum length is %d bytes, got %d.", maxSystemPromptLenBytes, len(sanitizedPrompt)),
|
|
})
|
|
return
|
|
}
|
|
|
|
updatedConfig, err := api.Database.UpdateUserChatCustomPrompt(ctx, database.UpdateUserChatCustomPromptParams{
|
|
UserID: apiKey.UserID,
|
|
ChatCustomPrompt: sanitizedPrompt,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Error updating user chat custom prompt.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventUserPrompt, apiKey.UserID)
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
|
|
CustomPrompt: updatedConfig.Value,
|
|
})
|
|
}
|
|
|
|
// @Summary Get user chat compaction thresholds
|
|
// @x-apidocgen {"skip": true}
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getUserChatCompactionThresholds(rw http.ResponseWriter, r *http.Request) {
|
|
var (
|
|
ctx = r.Context()
|
|
apiKey = httpmw.APIKey(r)
|
|
)
|
|
|
|
rows, err := api.Database.ListUserChatCompactionThresholds(ctx, apiKey.UserID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Error listing user chat compaction thresholds.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
resp := codersdk.UserChatCompactionThresholds{
|
|
Thresholds: make([]codersdk.UserChatCompactionThreshold, 0, len(rows)),
|
|
}
|
|
for _, row := range rows {
|
|
modelConfigID, err := parseCompactionThresholdKey(row.Key)
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "skipping malformed user chat compaction threshold key",
|
|
slog.F("key", row.Key),
|
|
slog.F("value", row.Value),
|
|
slog.Error(err),
|
|
)
|
|
continue
|
|
}
|
|
|
|
thresholdPercent, err := strconv.ParseInt(row.Value, 10, 32)
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "skipping malformed user chat compaction threshold value",
|
|
slog.F("key", row.Key),
|
|
slog.F("value", row.Value),
|
|
slog.Error(err),
|
|
)
|
|
continue
|
|
}
|
|
if thresholdPercent < int64(minChatContextCompressionThreshold) ||
|
|
thresholdPercent > int64(maxChatContextCompressionThreshold) {
|
|
api.Logger.Warn(ctx, "skipping out-of-range user chat compaction threshold",
|
|
slog.F("key", row.Key),
|
|
slog.F("value", row.Value),
|
|
)
|
|
continue
|
|
}
|
|
|
|
resp.Thresholds = append(resp.Thresholds, codersdk.UserChatCompactionThreshold{
|
|
ModelConfigID: modelConfigID,
|
|
ThresholdPercent: int32(thresholdPercent),
|
|
})
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
|
}
|
|
|
|
// @Summary Set user chat compaction threshold for a model config
|
|
// @x-apidocgen {"skip": true}
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) putUserChatCompactionThreshold(rw http.ResponseWriter, r *http.Request) {
|
|
var (
|
|
ctx = r.Context()
|
|
apiKey = httpmw.APIKey(r)
|
|
)
|
|
|
|
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateUserChatCompactionThresholdRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if req.ThresholdPercent < minChatContextCompressionThreshold ||
|
|
req.ThresholdPercent > maxChatContextCompressionThreshold {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "threshold_percent is out of range.",
|
|
Detail: fmt.Sprintf(
|
|
"threshold_percent must be between %d and %d, got %d.",
|
|
minChatContextCompressionThreshold,
|
|
maxChatContextCompressionThreshold,
|
|
req.ThresholdPercent,
|
|
),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Use system context because GetChatModelConfigByID requires
|
|
// deployment-config read access, which non-admin users lack.
|
|
// The user is only checking if the model exists and is enabled
|
|
// before writing their own personal preference.
|
|
//nolint:gocritic // Non-admin users need this lookup to save their own setting.
|
|
modelConfig, err := api.Database.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), modelConfigID)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) || httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if !modelConfig.Enabled {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Model config is disabled.",
|
|
})
|
|
return
|
|
}
|
|
|
|
_, err = api.Database.UpdateUserChatCompactionThreshold(ctx, database.UpdateUserChatCompactionThresholdParams{
|
|
UserID: apiKey.UserID,
|
|
Key: codersdk.CompactionThresholdKey(modelConfigID),
|
|
ThresholdPercent: req.ThresholdPercent,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Error updating user chat compaction threshold.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCompactionThreshold{
|
|
ModelConfigID: modelConfigID,
|
|
ThresholdPercent: req.ThresholdPercent,
|
|
})
|
|
}
|
|
|
|
// @Summary Delete user chat compaction threshold for a model config
|
|
// @x-apidocgen {"skip": true}
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
func (api *API) deleteUserChatCompactionThreshold(rw http.ResponseWriter, r *http.Request) {
|
|
var (
|
|
ctx = r.Context()
|
|
apiKey = httpmw.APIKey(r)
|
|
)
|
|
|
|
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if err := api.Database.DeleteUserChatCompactionThreshold(ctx, database.DeleteUserChatCompactionThresholdParams{
|
|
UserID: apiKey.UserID,
|
|
Key: codersdk.CompactionThresholdKey(modelConfigID),
|
|
}); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Error deleting user chat compaction threshold.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Upload chat file
|
|
// @ID upload-chat-file
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Accept image/png,image/jpeg,image/gif,image/webp,text/plain,text/markdown,text/csv,application/json,application/pdf
|
|
// @Produce json
|
|
// @Param organization query string true "Organization ID" format(uuid)
|
|
// @Success 201 {object} codersdk.UploadChatFileResponse
|
|
// @Router /api/experimental/chats/files [post]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
orgIDStr := r.URL.Query().Get("organization")
|
|
if orgIDStr == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Missing organization query parameter.",
|
|
})
|
|
return
|
|
}
|
|
orgID, err := uuid.Parse(orgIDStr)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid organization ID.",
|
|
})
|
|
return
|
|
}
|
|
// NOTE: This authorize check is intentionally placed after query
|
|
// parameter parsing because we need orgID to scope the RBAC check
|
|
// to the correct org.
|
|
if !api.Authorize(r, policy.ActionCreate, rbac.ResourceChat.WithOwner(apiKey.UserID.String()).InOrg(orgID)) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
contentType := r.Header.Get("Content-Type")
|
|
if contentType == "" {
|
|
contentType = "application/octet-stream"
|
|
}
|
|
// Strip parameters (e.g. "image/png; charset=utf-8" → "image/png")
|
|
// so the allowlist check matches the base media type.
|
|
if mediaType, _, err := mime.ParseMediaType(contentType); err == nil {
|
|
contentType = mediaType
|
|
}
|
|
// application/octet-stream means the client could not classify the file
|
|
// ahead of time, so we defer to byte classification below.
|
|
if contentType != "application/octet-stream" && !chatfiles.IsAllowedStoredMediaType(contentType) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Unsupported file type.",
|
|
Detail: fmt.Sprintf("Allowed types: %s.", chatfiles.AllowedStoredMediaTypesString()),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Extract filename from Content-Disposition header if provided.
|
|
var filename string
|
|
if cd := r.Header.Get("Content-Disposition"); cd != "" {
|
|
if _, params, err := mime.ParseMediaType(cd); err == nil {
|
|
filename = params["filename"]
|
|
}
|
|
}
|
|
|
|
r.Body = http.MaxBytesReader(rw, r.Body, codersdk.MaxChatFileSizeBytes)
|
|
data, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
var maxBytesErr *http.MaxBytesError
|
|
if errors.As(err, &maxBytesErr) {
|
|
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
|
|
Message: "File too large.",
|
|
Detail: fmt.Sprintf("Maximum file size is %d bytes.", codersdk.MaxChatFileSizeBytes),
|
|
})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Failed to read file from request.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Verify the actual content matches an allowed file type so that
|
|
// a client cannot spoof Content-Type to serve active content.
|
|
filename, detected, err := chatfiles.PrepareStoredFile(filename, filename, data)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, chatfiles.ErrStoredFileNameRequired):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Filename is required.",
|
|
Detail: "Provide a filename in the Content-Disposition header.",
|
|
})
|
|
case errors.Is(err, chatfiles.ErrUnsupportedStoredFileType):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Unsupported file type.",
|
|
Detail: fmt.Sprintf("Allowed types: %s.", chatfiles.AllowedStoredMediaTypesString()),
|
|
})
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid file.",
|
|
Detail: err.Error(),
|
|
})
|
|
}
|
|
return
|
|
}
|
|
// The compatibility check below is security-critical: it keeps exact
|
|
// media-type matching by default while allowing application/
|
|
// octet-stream uploads to defer to byte classification, and letting
|
|
// text/plain refine to safe text subtypes such as JSON, CSV, and
|
|
// Markdown. Combined with the X-Content-Type-Options: nosniff header
|
|
// applied globally, this still prevents clients from smuggling binary
|
|
// or active content under a safer declared Content-Type.
|
|
if !chatfiles.IsCompatibleUploadMediaType(contentType, detected) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "File content type does not match Content-Type header.",
|
|
Detail: fmt.Sprintf("Header declared %q but file content was detected as %q.", contentType, detected),
|
|
})
|
|
return
|
|
}
|
|
chatFile, err := api.Database.InsertChatFile(ctx, database.InsertChatFileParams{
|
|
OwnerID: apiKey.UserID,
|
|
OrganizationID: orgID,
|
|
Name: filename,
|
|
Mimetype: detected,
|
|
Data: data,
|
|
})
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to save chat file.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.UploadChatFileResponse{
|
|
ID: chatFile.ID,
|
|
})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
// @Summary Get chat file
|
|
// @ID get-chat-file
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce image/png,image/jpeg,image/gif,image/webp,text/plain,text/markdown,text/csv,application/json,application/pdf
|
|
// @Param file path string true "File ID" format(uuid)
|
|
// @Success 200
|
|
// @Router /api/experimental/chats/files/{file} [get]
|
|
// @Description Experimental: this endpoint is subject to change.
|
|
func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
fileIDStr := chi.URLParam(r, "file")
|
|
fileID, err := uuid.Parse(fileIDStr)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid file ID.",
|
|
})
|
|
return
|
|
}
|
|
|
|
chatFile, err := api.Database.GetChatFileByID(ctx, fileID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat file.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
rw.Header().Set("Content-Type", chatFile.Mimetype)
|
|
disposition := "attachment"
|
|
if chatfiles.IsInlineRenderableStoredMediaType(chatFile.Mimetype) {
|
|
disposition = "inline"
|
|
}
|
|
if chatFile.Name != "" {
|
|
rw.Header().Set("Content-Disposition", mime.FormatMediaType(disposition, map[string]string{"filename": chatFile.Name}))
|
|
} else {
|
|
rw.Header().Set("Content-Disposition", disposition)
|
|
}
|
|
rw.Header().Set("Cache-Control", "private, max-age=31536000, immutable")
|
|
rw.Header().Set("Content-Length", strconv.Itoa(len(chatFile.Data)))
|
|
rw.WriteHeader(http.StatusOK)
|
|
if _, err := rw.Write(chatFile.Data); err != nil {
|
|
api.Logger.Debug(ctx, "failed to write chat file response", slog.Error(err))
|
|
}
|
|
}
|
|
|
|
func createChatInputFromRequest(ctx context.Context, db database.Store, req codersdk.CreateChatRequest) (
|
|
[]codersdk.ChatMessagePart,
|
|
string,
|
|
[]uuid.UUID,
|
|
*codersdk.Response,
|
|
) {
|
|
return createChatInputFromParts(ctx, db, req.Content, "content")
|
|
}
|
|
|
|
func createChatInputFromParts(
|
|
ctx context.Context,
|
|
db database.Store,
|
|
parts []codersdk.ChatInputPart,
|
|
fieldName string,
|
|
) ([]codersdk.ChatMessagePart, string, []uuid.UUID, *codersdk.Response) {
|
|
if len(parts) == 0 {
|
|
return nil, "", nil, &codersdk.Response{
|
|
Message: "Content is required.",
|
|
Detail: "Content cannot be empty.",
|
|
}
|
|
}
|
|
|
|
var fileIDs []uuid.UUID
|
|
content := make([]codersdk.ChatMessagePart, 0, len(parts))
|
|
textParts := make([]string, 0, len(parts))
|
|
for i, part := range parts {
|
|
switch strings.ToLower(strings.TrimSpace(string(part.Type))) {
|
|
case string(codersdk.ChatInputPartTypeText):
|
|
text := strings.TrimSpace(part.Text)
|
|
if text == "" {
|
|
return nil, "", nil, &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i),
|
|
}
|
|
}
|
|
content = append(content, codersdk.ChatMessageText(text))
|
|
textParts = append(textParts, text)
|
|
case string(codersdk.ChatInputPartTypeFile):
|
|
if part.FileID == uuid.Nil {
|
|
return nil, "", nil, &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf("%s[%d].file_id is required for file parts.", fieldName, i),
|
|
}
|
|
}
|
|
// Validate that the file exists and get its media type.
|
|
// File data is not loaded here; it's resolved at LLM
|
|
// dispatch time via chatFileResolver.
|
|
chatFile, err := db.GetChatFileByID(ctx, part.FileID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
return nil, "", nil, &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf("%s[%d].file_id references a file that does not exist.", fieldName, i),
|
|
}
|
|
}
|
|
return nil, "", nil, &codersdk.Response{
|
|
Message: "Internal error.",
|
|
Detail: fmt.Sprintf("Failed to retrieve file for %s[%d].", fieldName, i),
|
|
}
|
|
}
|
|
content = append(content, codersdk.ChatMessageFile(part.FileID, chatFile.Mimetype, chatFile.Name))
|
|
fileIDs = append(fileIDs, part.FileID)
|
|
// file-reference parts carry inline code snippets, not uploaded
|
|
// files. They have no FileID and are excluded from file tracking.
|
|
case string(codersdk.ChatInputPartTypeFileReference):
|
|
if part.FileName == "" {
|
|
return nil, "", nil, &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf("%s[%d].file_name cannot be empty for file-reference.", fieldName, i),
|
|
}
|
|
}
|
|
content = append(content, codersdk.ChatMessageFileReference(part.FileName, part.StartLine, part.EndLine, part.Content))
|
|
// Build text representation for title generation.
|
|
lineRange := fmt.Sprintf("%d", part.StartLine)
|
|
if part.StartLine != part.EndLine {
|
|
lineRange = fmt.Sprintf("%d-%d", part.StartLine, part.EndLine)
|
|
}
|
|
var sb strings.Builder
|
|
_, _ = fmt.Fprintf(&sb, "[file-reference] %s:%s", part.FileName, lineRange)
|
|
if strings.TrimSpace(part.Content) != "" {
|
|
_, _ = fmt.Fprintf(&sb, "\n```%s\n%s\n```", part.FileName, strings.TrimSpace(part.Content))
|
|
}
|
|
textParts = append(textParts, sb.String())
|
|
default:
|
|
return nil, "", nil, &codersdk.Response{
|
|
Message: "Invalid input part.",
|
|
Detail: fmt.Sprintf(
|
|
"%s[%d].type %q is not supported.",
|
|
fieldName,
|
|
i,
|
|
part.Type,
|
|
),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Allow file-only messages. The titleSource may be empty
|
|
// when only file parts are provided, callers handle this.
|
|
if len(content) == 0 {
|
|
return nil, "", nil, &codersdk.Response{
|
|
Message: "Content is required.",
|
|
Detail: fmt.Sprintf("%s must include at least one text or file part.", fieldName),
|
|
}
|
|
}
|
|
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
|
|
return content, titleSource, fileIDs, nil
|
|
}
|
|
|
|
func chatTitleFromMessage(message string) string {
|
|
const maxWords = 6
|
|
const maxRunes = 80
|
|
words := strings.Fields(message)
|
|
if len(words) == 0 {
|
|
return "New Chat"
|
|
}
|
|
truncated := false
|
|
if len(words) > maxWords {
|
|
words = words[:maxWords]
|
|
truncated = true
|
|
}
|
|
title := strings.Join(words, " ")
|
|
if truncated {
|
|
title += "…"
|
|
}
|
|
return truncateRunes(title, maxRunes)
|
|
}
|
|
|
|
// truncateRunes returns value cut down to at most maxLen runes. A
// non-positive maxLen yields the empty string, and a value that already
// fits is returned unchanged. Truncation happens on rune boundaries, never
// in the middle of a multi-byte character.
func truncateRunes(value string, maxLen int) string {
	switch {
	case maxLen <= 0:
		return ""
	case utf8.RuneCountInString(value) <= maxLen:
		return value
	default:
		return string([]rune(value)[:maxLen])
	}
}
|
|
|
|
// linkFilesToChat inserts file-link rows into the chat_file_links
|
|
// join table. Cap enforcement and dedup are handled atomically in
|
|
// SQL. On success returns (nil, false). On failure returns the full
|
|
// input fileIDs slice — linking is all-or-nothing because the
|
|
// SQL operates on the batch atomically. capExceeded indicates
|
|
// whether the failure was due to the cap being exceeded (true)
|
|
// or a database error (false).
|
|
// Failures are logged but never block the caller.
|
|
func (api *API) linkFilesToChat(ctx context.Context, chatID uuid.UUID, fileIDs []uuid.UUID) (unlinked []uuid.UUID, capExceeded bool) {
|
|
if len(fileIDs) == 0 {
|
|
return nil, false
|
|
}
|
|
rejected, err := api.Database.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
|
ChatID: chatID,
|
|
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
|
FileIds: fileIDs,
|
|
})
|
|
if err != nil {
|
|
api.Logger.Error(ctx, "failed to link files to chat",
|
|
slog.F("chat_id", chatID),
|
|
slog.F("file_ids", fileIDs),
|
|
slog.Error(err),
|
|
)
|
|
return fileIDs, false
|
|
}
|
|
if rejected > 0 {
|
|
api.Logger.Warn(ctx, "file cap reached, files not linked",
|
|
slog.F("chat_id", chatID),
|
|
slog.F("file_ids", fileIDs),
|
|
slog.F("max_file_links", codersdk.MaxChatFileIDs),
|
|
)
|
|
return fileIDs, true
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
// fileLinkCapWarning builds a user-facing warning when a batch
|
|
// of file IDs was atomically rejected because the resulting
|
|
// array would exceed the per-chat file cap.
|
|
func fileLinkCapWarning(count int) string {
|
|
return fmt.Sprintf("file linking skipped: batch of %d file(s) would exceed limit of %d", count, codersdk.MaxChatFileIDs)
|
|
}
|
|
|
|
// fileLinkErrorWarning returns the user-facing warning for count files that
// could not be linked to a chat because of a database error.
func fileLinkErrorWarning(count int) string {
	const format = "%d file(s) could not be linked due to a server error"
	return fmt.Sprintf(format, count)
}
|
|
|
|
// fetchChatFileMetadata returns metadata for all files linked to
|
|
// the given chat. Errors are logged and result in a nil return
|
|
// (callers treat file metadata as best-effort).
|
|
func (api *API) fetchChatFileMetadata(ctx context.Context, chatID uuid.UUID) []database.GetChatFileMetadataByChatIDRow {
|
|
rows, err := api.Database.GetChatFileMetadataByChatID(ctx, chatID)
|
|
if err != nil {
|
|
api.Logger.Error(ctx, "failed to fetch chat file metadata",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
return nil
|
|
}
|
|
return rows
|
|
}
|
|
|
|
func convertChatCostModelBreakdown(model database.GetChatCostPerModelRow) codersdk.ChatCostModelBreakdown {
|
|
displayName := strings.TrimSpace(model.DisplayName)
|
|
if displayName == "" {
|
|
displayName = model.Model
|
|
}
|
|
return codersdk.ChatCostModelBreakdown{
|
|
ModelConfigID: model.ModelConfigID,
|
|
DisplayName: displayName,
|
|
Provider: model.Provider,
|
|
Model: model.Model,
|
|
TotalCostMicros: model.TotalCostMicros,
|
|
MessageCount: model.MessageCount,
|
|
TotalInputTokens: model.TotalInputTokens,
|
|
TotalOutputTokens: model.TotalOutputTokens,
|
|
TotalCacheReadTokens: model.TotalCacheReadTokens,
|
|
TotalCacheCreationTokens: model.TotalCacheCreationTokens,
|
|
TotalRuntimeMs: model.TotalRuntimeMs,
|
|
}
|
|
}
|
|
|
|
func convertChatCostChatBreakdown(chat database.GetChatCostPerChatRow) codersdk.ChatCostChatBreakdown {
|
|
return codersdk.ChatCostChatBreakdown{
|
|
RootChatID: chat.RootChatID,
|
|
ChatTitle: chat.ChatTitle,
|
|
TotalCostMicros: chat.TotalCostMicros,
|
|
MessageCount: chat.MessageCount,
|
|
TotalInputTokens: chat.TotalInputTokens,
|
|
TotalOutputTokens: chat.TotalOutputTokens,
|
|
TotalCacheReadTokens: chat.TotalCacheReadTokens,
|
|
TotalCacheCreationTokens: chat.TotalCacheCreationTokens,
|
|
TotalRuntimeMs: chat.TotalRuntimeMs,
|
|
}
|
|
}
|
|
|
|
func convertChatCostUserRollup(user database.GetChatCostPerUserRow) codersdk.ChatCostUserRollup {
|
|
return codersdk.ChatCostUserRollup{
|
|
UserID: user.UserID,
|
|
Username: user.Username,
|
|
Name: user.Name,
|
|
AvatarURL: user.AvatarURL,
|
|
TotalCostMicros: user.TotalCostMicros,
|
|
MessageCount: user.MessageCount,
|
|
ChatCount: user.ChatCount,
|
|
TotalInputTokens: user.TotalInputTokens,
|
|
TotalOutputTokens: user.TotalOutputTokens,
|
|
TotalCacheReadTokens: user.TotalCacheReadTokens,
|
|
TotalCacheCreationTokens: user.TotalCacheCreationTokens,
|
|
TotalRuntimeMs: user.TotalRuntimeMs,
|
|
}
|
|
}
|
|
|
|
func convertChatQueuedMessage(m database.ChatQueuedMessage) codersdk.ChatQueuedMessage {
|
|
return db2sdk.ChatQueuedMessage(m)
|
|
}
|
|
|
|
func convertChatQueuedMessagePtr(m database.ChatQueuedMessage) *codersdk.ChatQueuedMessage {
|
|
qm := convertChatQueuedMessage(m)
|
|
return &qm
|
|
}
|
|
|
|
func convertChatQueuedMessages(msgs []database.ChatQueuedMessage) []codersdk.ChatQueuedMessage {
|
|
result := make([]codersdk.ChatQueuedMessage, 0, len(msgs))
|
|
for _, m := range msgs {
|
|
result = append(result, convertChatQueuedMessage(m))
|
|
}
|
|
return result
|
|
}
|
|
|
|
func convertChatMessage(m database.ChatMessage) codersdk.ChatMessage {
|
|
return db2sdk.ChatMessage(m)
|
|
}
|
|
|
|
func convertChatMessages(messages []database.ChatMessage) []codersdk.ChatMessage {
|
|
result := make([]codersdk.ChatMessage, 0, len(messages))
|
|
for _, m := range messages {
|
|
result = append(result, convertChatMessage(m))
|
|
}
|
|
return result
|
|
}
|
|
|
|
func parseUserAIProviderID(r *http.Request) (uuid.UUID, error) {
|
|
return uuid.Parse(chi.URLParam(r, "aiProvider"))
|
|
}
|
|
|
|
func convertAIProviderSummary(provider database.AIProvider) codersdk.AIProviderSummary {
|
|
displayName := provider.Name
|
|
if provider.DisplayName.Valid && provider.DisplayName.String != "" {
|
|
displayName = provider.DisplayName.String
|
|
}
|
|
return codersdk.AIProviderSummary{
|
|
ID: provider.ID,
|
|
Type: codersdk.AIProviderType(provider.Type),
|
|
Name: provider.Name,
|
|
DisplayName: displayName,
|
|
Enabled: provider.Enabled,
|
|
Deleted: provider.Deleted,
|
|
}
|
|
}
|
|
|
|
func (api *API) listUserAIProviderKeyConfigs(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
targetUser := httpmw.UserParam(r)
|
|
//nolint:gocritic // Users can list limited provider metadata to manage their own AI provider keys.
|
|
metadataCtx := dbauthz.AsAIProviderMetadataReader(ctx)
|
|
providers, err := api.Database.GetAIProviders(metadataCtx, database.GetAIProvidersParams{IncludeDisabled: true})
|
|
if err != nil {
|
|
api.Logger.Error(ctx, "failed to list user AI provider configs", slog.Error(err), slog.F("user_id", targetUser.ID))
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI providers."})
|
|
return
|
|
}
|
|
keys, err := api.Database.GetUserAIProviderKeysByUserID(ctx, targetUser.ID)
|
|
if err != nil {
|
|
api.Logger.Error(ctx, "failed to list user AI provider keys", slog.Error(err), slog.F("user_id", targetUser.ID))
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list user AI provider keys."})
|
|
return
|
|
}
|
|
|
|
keysByProviderID := make(map[uuid.UUID]struct{}, len(keys))
|
|
for _, key := range keys {
|
|
keysByProviderID[key.AIProviderID] = struct{}{}
|
|
}
|
|
|
|
visibleProviders := make([]database.AIProvider, 0, len(providers))
|
|
visibleProviderIDs := make([]uuid.UUID, 0, len(providers))
|
|
for _, provider := range providers {
|
|
_, hasUserKey := keysByProviderID[provider.ID]
|
|
if !provider.Enabled && !hasUserKey {
|
|
continue
|
|
}
|
|
visibleProviders = append(visibleProviders, provider)
|
|
visibleProviderIDs = append(visibleProviderIDs, provider.ID)
|
|
}
|
|
|
|
providerKeysByProviderID := make(map[uuid.UUID]struct{}, len(visibleProviderIDs))
|
|
if len(visibleProviderIDs) > 0 {
|
|
providerKeyIDs, err := api.Database.GetAIProviderKeyPresence(metadataCtx, visibleProviderIDs)
|
|
if err != nil {
|
|
api.Logger.Error(ctx, "failed to list AI provider key presence", slog.Error(err), slog.F("user_id", targetUser.ID))
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI provider keys."})
|
|
return
|
|
}
|
|
for _, providerID := range providerKeyIDs {
|
|
providerKeysByProviderID[providerID] = struct{}{}
|
|
}
|
|
}
|
|
|
|
byokEnabled := api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value()
|
|
configs := make([]codersdk.UserAIProviderKeyConfig, 0, len(visibleProviders))
|
|
for _, provider := range visibleProviders {
|
|
_, hasUserKey := keysByProviderID[provider.ID]
|
|
_, hasProviderKey := providerKeysByProviderID[provider.ID]
|
|
configs = append(configs, codersdk.UserAIProviderKeyConfig{
|
|
Provider: convertAIProviderSummary(provider),
|
|
HasUserAPIKey: hasUserKey,
|
|
HasProviderAPIKey: hasProviderKey,
|
|
BYOKEnabled: byokEnabled,
|
|
})
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, configs)
|
|
}
|
|
|
|
func (api *API) upsertUserAIProviderKey(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value() {
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{Message: "BYOK is disabled."})
|
|
return
|
|
}
|
|
targetUser := httpmw.UserParam(r)
|
|
providerID, err := parseUserAIProviderID(r)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."})
|
|
return
|
|
}
|
|
//nolint:gocritic // Users can attach their own key to an enabled provider without AI provider admin permissions.
|
|
metadataCtx := dbauthz.AsAIProviderMetadataReader(ctx)
|
|
provider, err := api.Database.GetAIProviderByID(metadataCtx, providerID)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."})
|
|
return
|
|
}
|
|
api.Logger.Error(ctx, "failed to get AI provider", slog.Error(err), slog.F("ai_provider_id", providerID))
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."})
|
|
return
|
|
}
|
|
if !provider.Enabled {
|
|
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
|
|
return
|
|
}
|
|
var req codersdk.CreateUserAIProviderKeyRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
if err := validateChatProviderAPIKeySize(req.APIKey); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "API key too large.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if req.APIKey == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key is required."})
|
|
return
|
|
}
|
|
if strings.TrimSpace(req.APIKey) != req.APIKey {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key must not contain leading or trailing whitespace."})
|
|
return
|
|
}
|
|
providerKeys, err := api.Database.GetAIProviderKeyPresence(metadataCtx, []uuid.UUID{providerID})
|
|
if err != nil {
|
|
api.Logger.Error(ctx, "failed to list AI provider key presence", slog.Error(err), slog.F("ai_provider_id", providerID))
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI provider keys."})
|
|
return
|
|
}
|
|
now := api.Clock.Now()
|
|
_, err = api.Database.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{
|
|
ID: uuid.New(),
|
|
UserID: targetUser.ID,
|
|
AIProviderID: providerID,
|
|
APIKey: req.APIKey,
|
|
ApiKeyKeyID: sql.NullString{},
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
})
|
|
if err != nil {
|
|
api.Logger.Error(ctx, "failed to update user AI provider key", slog.Error(err), slog.F("user_id", targetUser.ID), slog.F("ai_provider_id", providerID))
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to update user AI provider key."})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserAIProviderKeyConfig{
|
|
Provider: convertAIProviderSummary(provider),
|
|
HasUserAPIKey: true,
|
|
HasProviderAPIKey: len(providerKeys) > 0,
|
|
BYOKEnabled: true,
|
|
})
|
|
}
|
|
|
|
func (api *API) deleteUserAIProviderKey(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
targetUser := httpmw.UserParam(r)
|
|
providerID, err := parseUserAIProviderID(r)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."})
|
|
return
|
|
}
|
|
if err := api.Database.DeleteUserAIProviderKey(ctx, database.DeleteUserAIProviderKeyParams{UserID: targetUser.ID, AIProviderID: providerID}); err != nil {
|
|
api.Logger.Error(ctx, "failed to delete user AI provider key", slog.Error(err), slog.F("user_id", targetUser.ID), slog.F("ai_provider_id", providerID))
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete user AI provider key."})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusNoContent, nil)
|
|
}
|
|
|
|
func (api *API) configuredProvidersFromAIProviders(ctx context.Context, providers []database.AIProvider) ([]chatprovider.ConfiguredProvider, error) {
|
|
if len(providers) == 0 {
|
|
return nil, nil
|
|
}
|
|
providerIDs := make([]uuid.UUID, 0, len(providers))
|
|
for _, provider := range providers {
|
|
providerIDs = append(providerIDs, provider.ID)
|
|
}
|
|
keys, err := api.Database.GetAIProviderKeysByProviderIDs(ctx, providerIDs)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("get AI provider keys: %w", err)
|
|
}
|
|
keysByProviderID := make(map[uuid.UUID][]database.AIProviderKey, len(providers))
|
|
for _, key := range keys {
|
|
keysByProviderID[key.ProviderID] = append(keysByProviderID[key.ProviderID], key)
|
|
}
|
|
configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers))
|
|
for _, provider := range providers {
|
|
configuredProviders = append(configuredProviders, api.configuredProviderFromAIProviderKeys(provider, keysByProviderID[provider.ID]))
|
|
}
|
|
return configuredProviders, nil
|
|
}
|
|
|
|
func (api *API) configuredProviderFromAIProviderKeys(provider database.AIProvider, keys []database.AIProviderKey) chatprovider.ConfiguredProvider {
|
|
apiKey := ""
|
|
for _, key := range keys {
|
|
if key.APIKey != "" {
|
|
apiKey = key.APIKey
|
|
break
|
|
}
|
|
}
|
|
return chatprovider.ConfiguredProvider{
|
|
ProviderID: provider.ID,
|
|
Provider: string(provider.Type),
|
|
APIKey: apiKey,
|
|
BaseURL: provider.BaseUrl,
|
|
CentralAPIKeyEnabled: true,
|
|
AllowUserAPIKey: api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(),
|
|
AllowCentralAPIKeyFallback: true,
|
|
}
|
|
}
|
|
|
|
func writeLegacyChatProviderGone(rw http.ResponseWriter, r *http.Request) {
|
|
httpapi.Write(r.Context(), rw, http.StatusGone, codersdk.Response{
|
|
Message: "Legacy chat provider APIs were removed. Use AI provider APIs instead.",
|
|
Detail: "See https://coder.com/docs/ai-coder/agents/models#providers for AI provider configuration.",
|
|
})
|
|
}
|
|
|
|
func (*API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
|
writeLegacyChatProviderGone(rw, r)
|
|
}
|
|
|
|
func (*API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
|
writeLegacyChatProviderGone(rw, r)
|
|
}
|
|
|
|
func (*API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
|
|
writeLegacyChatProviderGone(rw, r)
|
|
}
|
|
|
|
func (*API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) {
|
|
writeLegacyChatProviderGone(rw, r)
|
|
}
|
|
|
|
func (*API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Request) {
|
|
writeLegacyChatProviderGone(rw, r)
|
|
}
|
|
|
|
func (*API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Request) {
|
|
writeLegacyChatProviderGone(rw, r)
|
|
}
|
|
|
|
func (*API) deleteUserChatProviderKey(rw http.ResponseWriter, r *http.Request) {
|
|
writeLegacyChatProviderGone(rw, r)
|
|
}
|
|
|
|
func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
// Admin users can see all model configs (including disabled ones)
|
|
// for management purposes. Non-admin users see only enabled
|
|
// configs, which is sufficient for using the chat feature.
|
|
isAdmin := api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig)
|
|
|
|
var configs []database.ChatModelConfig
|
|
var err error
|
|
if isAdmin {
|
|
configs, err = api.Database.GetChatModelConfigs(ctx)
|
|
} else {
|
|
//nolint:gocritic // All authenticated users need to read enabled model configs to use the chat feature.
|
|
configs, err = api.Database.GetEnabledChatModelConfigs(dbauthz.AsChatd(ctx))
|
|
}
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to list chat model configs.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
resp := make([]codersdk.ChatModelConfig, 0, len(configs))
|
|
for _, config := range configs {
|
|
resp = append(resp, convertChatModelConfig(config))
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
|
}
|
|
|
|
func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
var req codersdk.CreateChatModelConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
if req.AIProviderID == nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider ID is required."})
|
|
return
|
|
}
|
|
//nolint:gocritic // The route already authorized chat model config updates.
|
|
aiProvider, err := api.Database.GetAIProviderByID(dbauthz.AsChatd(ctx), *req.AIProviderID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is not configured."})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get AI provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if !aiProvider.Enabled {
|
|
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
|
|
return
|
|
}
|
|
provider := string(aiProvider.Type)
|
|
aiProviderID := uuid.NullUUID{UUID: aiProvider.ID, Valid: true}
|
|
|
|
model := strings.TrimSpace(req.Model)
|
|
if model == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Model is required.",
|
|
})
|
|
return
|
|
}
|
|
|
|
enabled := true
|
|
if req.Enabled != nil {
|
|
enabled = *req.Enabled
|
|
}
|
|
isDefault := false
|
|
if req.IsDefault != nil {
|
|
isDefault = *req.IsDefault
|
|
}
|
|
|
|
if req.ContextLimit == nil || *req.ContextLimit <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Context limit is required.",
|
|
Detail: "context_limit must be greater than zero.",
|
|
})
|
|
return
|
|
}
|
|
contextLimit := *req.ContextLimit
|
|
|
|
compressionThreshold, thresholdErr := normalizeChatCompressionThreshold(
|
|
req.CompressionThreshold,
|
|
defaultChatContextCompressionThreshold,
|
|
)
|
|
if thresholdErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid compression threshold.",
|
|
Detail: thresholdErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
modelConfigRaw, modelConfigErr := marshalChatModelCallConfig(req.ModelConfig)
|
|
if modelConfigErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config.",
|
|
Detail: modelConfigErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
insertParams := database.InsertChatModelConfigParams{
|
|
Provider: provider,
|
|
Model: model,
|
|
DisplayName: strings.TrimSpace(req.DisplayName),
|
|
Enabled: enabled,
|
|
IsDefault: isDefault,
|
|
ContextLimit: contextLimit,
|
|
CompressionThreshold: compressionThreshold,
|
|
Options: modelConfigRaw,
|
|
AIProviderID: aiProviderID,
|
|
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
UpdatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
}
|
|
|
|
var inserted database.ChatModelConfig
|
|
err = api.Database.InTx(func(tx database.Store) error {
|
|
//nolint:gocritic // The route already authorized chat model config updates.
|
|
lockedAIProvider, err := tx.GetAIProviderByIDForReferenceLock(dbauthz.AsChatd(ctx), insertParams.AIProviderID.UUID)
|
|
if err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return errChatProviderNotConfigured
|
|
}
|
|
return xerrors.Errorf("get AI provider for update: %w", err)
|
|
}
|
|
if !lockedAIProvider.Enabled {
|
|
return errChatProviderNotConfigured
|
|
}
|
|
insertParams.Provider = string(lockedAIProvider.Type)
|
|
|
|
insertAsDefault := isDefault
|
|
if !insertAsDefault {
|
|
_, err := tx.GetDefaultChatModelConfig(ctx)
|
|
switch {
|
|
case err == nil:
|
|
// A default already exists.
|
|
case xerrors.Is(err, sql.ErrNoRows):
|
|
insertAsDefault = true
|
|
default:
|
|
return xerrors.Errorf("get default model config: %w", err)
|
|
}
|
|
}
|
|
|
|
if insertAsDefault {
|
|
if err := tx.UnsetDefaultChatModelConfigs(ctx); err != nil {
|
|
return xerrors.Errorf("unset default model configs: %w", err)
|
|
}
|
|
}
|
|
insertParams.IsDefault = insertAsDefault
|
|
|
|
config, err := tx.InsertChatModelConfig(ctx, insertParams)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
inserted = config
|
|
|
|
if err := ensureDefaultChatModelConfig(ctx, tx); err != nil {
|
|
return err
|
|
}
|
|
|
|
refreshedConfig, err := tx.GetChatModelConfigByID(ctx, inserted.ID)
|
|
if err != nil {
|
|
return xerrors.Errorf("refresh inserted chat model config: %w", err)
|
|
}
|
|
inserted = refreshedConfig
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
switch {
|
|
case database.IsUniqueViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat model config already exists.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
case xerrors.Is(err, errChatProviderNotConfigured):
|
|
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{
|
|
Message: "Chat provider is not configured.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to create chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventModelConfig, inserted.ID)
|
|
|
|
httpapi.Write(ctx, rw, http.StatusCreated, convertChatModelConfig(inserted))
|
|
}
|
|
|
|
func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
apiKey := httpmw.APIKey(r)
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
existing, err := api.Database.GetChatModelConfigByID(ctx, modelConfigID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
var req codersdk.UpdateChatModelConfigRequest
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
if strings.TrimSpace(req.Provider) != "" && req.AIProviderID == nil {
|
|
requestedProvider := chatprovider.NormalizeProvider(req.Provider)
|
|
if requestedProvider == "" {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid provider."})
|
|
return
|
|
}
|
|
if requestedProvider != existing.Provider {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider ID is required when updating provider."})
|
|
return
|
|
}
|
|
}
|
|
|
|
provider := existing.Provider
|
|
aiProviderID := existing.AIProviderID
|
|
if req.AIProviderID != nil {
|
|
//nolint:gocritic // The route already authorized chat model config updates.
|
|
aiProvider, err := api.Database.GetAIProviderByID(dbauthz.AsChatd(ctx), *req.AIProviderID)
|
|
if err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is not configured."})
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get AI provider.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
if !aiProvider.Enabled {
|
|
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
|
|
return
|
|
}
|
|
provider = string(aiProvider.Type)
|
|
aiProviderID = uuid.NullUUID{UUID: aiProvider.ID, Valid: true}
|
|
}
|
|
|
|
model := existing.Model
|
|
if trimmed := strings.TrimSpace(req.Model); trimmed != "" {
|
|
model = trimmed
|
|
}
|
|
|
|
displayName := existing.DisplayName
|
|
if trimmed := strings.TrimSpace(req.DisplayName); trimmed != "" {
|
|
displayName = trimmed
|
|
}
|
|
|
|
enabled := existing.Enabled
|
|
if req.Enabled != nil {
|
|
enabled = *req.Enabled
|
|
}
|
|
isDefault := existing.IsDefault
|
|
if req.IsDefault != nil {
|
|
isDefault = *req.IsDefault
|
|
}
|
|
|
|
contextLimit := existing.ContextLimit
|
|
if req.ContextLimit != nil {
|
|
if *req.ContextLimit <= 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Context limit must be greater than zero.",
|
|
})
|
|
return
|
|
}
|
|
contextLimit = *req.ContextLimit
|
|
}
|
|
|
|
compressionThreshold, thresholdErr := normalizeChatCompressionThreshold(
|
|
req.CompressionThreshold,
|
|
existing.CompressionThreshold,
|
|
)
|
|
if thresholdErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid compression threshold.",
|
|
Detail: thresholdErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
modelConfigRaw := existing.Options
|
|
if req.ModelConfig != nil {
|
|
encodedModelConfig, modelConfigErr := marshalChatModelCallConfig(req.ModelConfig)
|
|
if modelConfigErr != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid model config.",
|
|
Detail: modelConfigErr.Error(),
|
|
})
|
|
return
|
|
}
|
|
modelConfigRaw = encodedModelConfig
|
|
}
|
|
|
|
updateParams := database.UpdateChatModelConfigParams{
|
|
Provider: provider,
|
|
Model: model,
|
|
DisplayName: displayName,
|
|
Enabled: enabled,
|
|
IsDefault: isDefault,
|
|
ContextLimit: contextLimit,
|
|
CompressionThreshold: compressionThreshold,
|
|
Options: modelConfigRaw,
|
|
AIProviderID: aiProviderID,
|
|
UpdatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
|
ID: existing.ID,
|
|
}
|
|
|
|
var updated database.ChatModelConfig
|
|
err = api.Database.InTx(func(tx database.Store) error {
|
|
if updateParams.AIProviderID.Valid && req.AIProviderID != nil {
|
|
//nolint:gocritic // The route already authorized chat model config updates.
|
|
aiProvider, err := tx.GetAIProviderByIDForReferenceLock(dbauthz.AsChatd(ctx), updateParams.AIProviderID.UUID)
|
|
if err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return errChatProviderNotConfigured
|
|
}
|
|
return xerrors.Errorf("get AI provider for update: %w", err)
|
|
}
|
|
if !aiProvider.Enabled {
|
|
return errChatProviderNotConfigured
|
|
}
|
|
updateParams.Provider = string(aiProvider.Type)
|
|
}
|
|
|
|
setAsDefault := updateParams.IsDefault && !existing.IsDefault
|
|
if setAsDefault {
|
|
if err := tx.UnsetDefaultChatModelConfigs(ctx); err != nil {
|
|
return xerrors.Errorf("unset default model configs: %w", err)
|
|
}
|
|
}
|
|
|
|
_, err := tx.UpdateChatModelConfig(ctx, updateParams)
|
|
if err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return errChatModelConfigNotFound
|
|
}
|
|
return err
|
|
}
|
|
|
|
excludeConfigID := uuid.Nil
|
|
if existing.IsDefault && req.IsDefault != nil && !*req.IsDefault {
|
|
excludeConfigID = existing.ID
|
|
}
|
|
|
|
if err := ensureDefaultChatModelConfig(
|
|
ctx,
|
|
tx,
|
|
excludeConfigID,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
refreshedConfig, err := tx.GetChatModelConfigByID(ctx, existing.ID)
|
|
if err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
// Do not wrap with %w. The outer handler maps target misses to 404.
|
|
return xerrors.Errorf("refresh updated chat model config: %v", err)
|
|
}
|
|
return xerrors.Errorf("refresh updated chat model config: %w", err)
|
|
}
|
|
updated = refreshedConfig
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
switch {
|
|
case database.IsUniqueViolation(err):
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat model config already exists.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
case xerrors.Is(err, errChatProviderNotConfigured):
|
|
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{
|
|
Message: "Chat provider is not configured.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
case xerrors.Is(err, errChatModelConfigNotFound):
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
default:
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to update chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventModelConfig, updated.ID)
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, convertChatModelConfig(updated))
|
|
}
|
|
|
|
func (api *API) deleteChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if _, err := api.Database.GetChatModelConfigByID(ctx, modelConfigID); err != nil {
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to get chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if err := api.Database.InTx(func(tx database.Store) error {
|
|
if err := tx.DeleteChatModelConfigByID(ctx, modelConfigID); err != nil {
|
|
return err
|
|
}
|
|
return ensureDefaultChatModelConfig(ctx, tx)
|
|
}, nil); err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Failed to delete chat model config.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventModelConfig, modelConfigID)
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func ensureDefaultChatModelConfig(
|
|
ctx context.Context,
|
|
tx database.Store,
|
|
excludedConfigIDs ...uuid.UUID,
|
|
) error {
|
|
_, err := tx.GetDefaultChatModelConfig(ctx)
|
|
switch {
|
|
case err == nil:
|
|
return nil
|
|
case !xerrors.Is(err, sql.ErrNoRows):
|
|
return xerrors.Errorf("get default model config: %w", err)
|
|
}
|
|
|
|
modelConfigs, err := tx.GetChatModelConfigs(ctx)
|
|
if err != nil {
|
|
return xerrors.Errorf("list chat model configs: %w", err)
|
|
}
|
|
if len(modelConfigs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
candidateConfig := modelConfigs[0]
|
|
excluded := make(map[uuid.UUID]struct{}, len(excludedConfigIDs))
|
|
for _, configID := range excludedConfigIDs {
|
|
if configID == uuid.Nil {
|
|
continue
|
|
}
|
|
excluded[configID] = struct{}{}
|
|
}
|
|
for _, config := range modelConfigs {
|
|
if _, skip := excluded[config.ID]; skip {
|
|
continue
|
|
}
|
|
candidateConfig = config
|
|
break
|
|
}
|
|
|
|
if err := tx.UnsetDefaultChatModelConfigs(ctx); err != nil {
|
|
return xerrors.Errorf("unset default model configs: %w", err)
|
|
}
|
|
|
|
params := chatModelConfigToUpdateParams(candidateConfig)
|
|
params.IsDefault = true
|
|
if _, err := tx.UpdateChatModelConfig(ctx, params); err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
// Do not wrap with %w. Callers map target misses to 404, but a
|
|
// default-candidate race is an internal retryable failure.
|
|
return xerrors.Errorf("set default model config: %v", err)
|
|
}
|
|
return xerrors.Errorf("set default model config: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func chatModelConfigToUpdateParams(
|
|
config database.ChatModelConfig,
|
|
) database.UpdateChatModelConfigParams {
|
|
return database.UpdateChatModelConfigParams{
|
|
Provider: config.Provider,
|
|
Model: config.Model,
|
|
DisplayName: config.DisplayName,
|
|
Enabled: config.Enabled,
|
|
IsDefault: config.IsDefault,
|
|
ContextLimit: config.ContextLimit,
|
|
CompressionThreshold: config.CompressionThreshold,
|
|
Options: config.Options,
|
|
AIProviderID: config.AIProviderID,
|
|
UpdatedBy: uuid.NullUUID{},
|
|
ID: config.ID,
|
|
}
|
|
}
|
|
|
|
// nullInt64Ptr converts a sql.NullInt64 into an optional int64: nil when the
// value is not valid, otherwise a pointer to a copy of the stored integer.
func nullInt64Ptr(n sql.NullInt64) *int64 {
	if n.Valid {
		value := n.Int64
		return &value
	}
	return nil
}
|
|
|
|
func writeChatUsageLimitUserNotFound(ctx context.Context, rw http.ResponseWriter) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "User not found.",
|
|
})
|
|
}
|
|
|
|
func writeChatUsageLimitOverrideNotFound(ctx context.Context, rw http.ResponseWriter) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat usage limit override not found.",
|
|
})
|
|
}
|
|
|
|
func writeChatUsageLimitGroupOverrideNotFound(ctx context.Context, rw http.ResponseWriter) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Chat usage limit group override not found.",
|
|
})
|
|
}
|
|
|
|
func writeChatUsageLimitGroupNotFound(ctx context.Context, rw http.ResponseWriter) {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Group not found.",
|
|
})
|
|
}
|
|
|
|
func parseChatUsageLimitUserID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
|
userID, err := uuid.Parse(chi.URLParam(r, "user"))
|
|
if err != nil {
|
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat usage limit user ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return uuid.Nil, false
|
|
}
|
|
return userID, true
|
|
}
|
|
|
|
func parseChatModelConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
|
modelConfigID, err := uuid.Parse(chi.URLParam(r, "modelConfig"))
|
|
if err != nil {
|
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid chat model config ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return uuid.Nil, false
|
|
}
|
|
return modelConfigID, true
|
|
}
|
|
|
|
func convertChatModelConfig(config database.ChatModelConfig) codersdk.ChatModelConfig {
|
|
var aiProviderID *uuid.UUID
|
|
if config.AIProviderID.Valid {
|
|
aiProviderID = &config.AIProviderID.UUID
|
|
}
|
|
return codersdk.ChatModelConfig{
|
|
ID: config.ID,
|
|
Provider: config.Provider,
|
|
AIProviderID: aiProviderID,
|
|
Model: config.Model,
|
|
DisplayName: config.DisplayName,
|
|
Enabled: config.Enabled,
|
|
IsDefault: config.IsDefault,
|
|
ContextLimit: config.ContextLimit,
|
|
CompressionThreshold: config.CompressionThreshold,
|
|
ModelConfig: unmarshalChatModelCallConfig(config.Options),
|
|
CreatedAt: config.CreatedAt,
|
|
UpdatedAt: config.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
func marshalChatModelCallConfig(
|
|
modelConfig *codersdk.ChatModelCallConfig,
|
|
) (json.RawMessage, error) {
|
|
if modelConfig == nil {
|
|
return json.RawMessage("{}"), nil
|
|
}
|
|
|
|
if err := validateChatModelCallConfig(modelConfig); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
encoded, err := json.Marshal(modelConfig)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("encode model config: %w", err)
|
|
}
|
|
return encoded, nil
|
|
}
|
|
|
|
func validateChatModelCallConfig(modelConfig *codersdk.ChatModelCallConfig) error {
|
|
if modelConfig == nil {
|
|
return nil
|
|
}
|
|
|
|
costConfig := codersdk.ModelCostConfig{}
|
|
if modelConfig.Cost != nil {
|
|
costConfig = *modelConfig.Cost
|
|
}
|
|
|
|
pricingFields := []struct {
|
|
name string
|
|
value *decimal.Decimal
|
|
}{
|
|
{name: "cost.input_price_per_million_tokens", value: costConfig.InputPricePerMillionTokens},
|
|
{name: "cost.output_price_per_million_tokens", value: costConfig.OutputPricePerMillionTokens},
|
|
{name: "cost.cache_read_price_per_million_tokens", value: costConfig.CacheReadPricePerMillionTokens},
|
|
{name: "cost.cache_write_price_per_million_tokens", value: costConfig.CacheWritePricePerMillionTokens},
|
|
}
|
|
for _, field := range pricingFields {
|
|
if err := validateNonNegativeDecimalField(field.name, field.value); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateNonNegativeDecimalField(name string, value *decimal.Decimal) error {
|
|
if value == nil {
|
|
return nil
|
|
}
|
|
if value.IsNegative() {
|
|
return xerrors.Errorf("%s must be greater than or equal to zero", name)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func unmarshalChatModelCallConfig(
|
|
raw json.RawMessage,
|
|
) *codersdk.ChatModelCallConfig {
|
|
if len(raw) == 0 {
|
|
return nil
|
|
}
|
|
|
|
decoded := &codersdk.ChatModelCallConfig{}
|
|
if err := json.Unmarshal(raw, decoded); err != nil {
|
|
return nil
|
|
}
|
|
if isZeroChatModelCallConfig(decoded) {
|
|
return nil
|
|
}
|
|
return decoded
|
|
}
|
|
|
|
func isZeroChatModelCallConfig(config *codersdk.ChatModelCallConfig) bool {
|
|
if config == nil {
|
|
return true
|
|
}
|
|
|
|
return config.MaxOutputTokens == nil &&
|
|
config.Temperature == nil &&
|
|
config.TopP == nil &&
|
|
config.TopK == nil &&
|
|
config.PresencePenalty == nil &&
|
|
config.FrequencyPenalty == nil &&
|
|
isZeroModelCostConfig(config.Cost) &&
|
|
isZeroChatModelProviderOptions(config.ProviderOptions)
|
|
}
|
|
|
|
func isZeroModelCostConfig(cost *codersdk.ModelCostConfig) bool {
|
|
if cost == nil {
|
|
return true
|
|
}
|
|
|
|
return cost.InputPricePerMillionTokens == nil &&
|
|
cost.OutputPricePerMillionTokens == nil &&
|
|
cost.CacheReadPricePerMillionTokens == nil &&
|
|
cost.CacheWritePricePerMillionTokens == nil
|
|
}
|
|
|
|
func isZeroChatModelProviderOptions(options *codersdk.ChatModelProviderOptions) bool {
|
|
if options == nil {
|
|
return true
|
|
}
|
|
|
|
return options.OpenAI == nil &&
|
|
options.Anthropic == nil &&
|
|
options.Google == nil &&
|
|
options.OpenAICompat == nil &&
|
|
options.OpenRouter == nil &&
|
|
options.Vercel == nil
|
|
}
|
|
|
|
// maxChatProviderAPIKeySize is the largest accepted API key length in bytes.
const maxChatProviderAPIKeySize = 10 * 1024 // 10 KB
|
|
|
|
func validateChatProviderAPIKeySize(apiKey string) error {
|
|
if len(apiKey) > maxChatProviderAPIKeySize {
|
|
return xerrors.Errorf("API key exceeds maximum size of 10 KB (%d bytes)", maxChatProviderAPIKeySize)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var (
|
|
errChatModelConfigNotFound = xerrors.New("chat model config not found")
|
|
errChatProviderNotConfigured = xerrors.New("chat provider is not configured")
|
|
)
|
|
|
|
// ChatProviderAPIKeysFromDeploymentValues returns deployment-backed chat
|
|
// provider API keys.
|
|
func ChatProviderAPIKeysFromDeploymentValues(
|
|
_ *codersdk.DeploymentValues,
|
|
) chatprovider.ProviderAPIKeys {
|
|
// AI bridge deployment config is intentionally not reused for chat
|
|
// provider credentials. Bridge keys serve the AI task subsystem and
|
|
// should not silently broaden into chat execution paths.
|
|
return chatprovider.ProviderAPIKeys{}
|
|
}
|
|
|
|
// @Summary Get PR insights
|
|
// @ID get-pr-insights
|
|
// @Security CoderSessionToken
|
|
// @Tags Chats
|
|
// @Produce json
|
|
// @Param start_date query string true "Start date (RFC3339)"
|
|
// @Param end_date query string true "End date (RFC3339)"
|
|
// @Success 200 {object} codersdk.PRInsightsResponse
|
|
// @Router /api/experimental/chats/insights/pull-requests [get]
|
|
// @x-apidocgen {"skip": true}
|
|
func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
// Admin-only endpoint.
|
|
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
|
httpapi.Forbidden(rw)
|
|
return
|
|
}
|
|
|
|
// Parse date range.
|
|
now := time.Now()
|
|
defaultStart := now.AddDate(0, 0, -30)
|
|
|
|
qp := r.URL.Query()
|
|
p := httpapi.NewQueryParamParser()
|
|
startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339)
|
|
endDate := p.Time(qp, now, "end_date", time.RFC3339)
|
|
p.ErrorExcessParams(qp)
|
|
if len(p.Errors) > 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid query parameters.",
|
|
Validations: p.Errors,
|
|
})
|
|
return
|
|
}
|
|
|
|
// Calculate previous period of equal length for trend comparison.
|
|
duration := endDate.Sub(startDate)
|
|
prevStart := startDate.Add(-duration)
|
|
|
|
// No owner filter — admin sees all data.
|
|
ownerID := uuid.NullUUID{}
|
|
|
|
// Run all queries in parallel.
|
|
var (
|
|
currentSummary database.GetPRInsightsSummaryRow
|
|
previousSummary database.GetPRInsightsSummaryRow
|
|
timeSeries []database.GetPRInsightsTimeSeriesRow
|
|
byModel []database.GetPRInsightsPerModelRow
|
|
recentPRs []database.GetPRInsightsPullRequestsRow
|
|
)
|
|
|
|
eg, egCtx := errgroup.WithContext(ctx)
|
|
eg.SetLimit(5)
|
|
|
|
eg.Go(func() error {
|
|
var err error
|
|
currentSummary, err = api.Database.GetPRInsightsSummary(egCtx, database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: ownerID,
|
|
})
|
|
return err
|
|
})
|
|
|
|
eg.Go(func() error {
|
|
var err error
|
|
previousSummary, err = api.Database.GetPRInsightsSummary(egCtx, database.GetPRInsightsSummaryParams{
|
|
StartDate: prevStart,
|
|
EndDate: startDate,
|
|
OwnerID: ownerID,
|
|
})
|
|
return err
|
|
})
|
|
|
|
eg.Go(func() error {
|
|
var err error
|
|
timeSeries, err = api.Database.GetPRInsightsTimeSeries(egCtx, database.GetPRInsightsTimeSeriesParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: ownerID,
|
|
})
|
|
return err
|
|
})
|
|
|
|
eg.Go(func() error {
|
|
var err error
|
|
byModel, err = api.Database.GetPRInsightsPerModel(egCtx, database.GetPRInsightsPerModelParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: ownerID,
|
|
})
|
|
return err
|
|
})
|
|
|
|
eg.Go(func() error {
|
|
var err error
|
|
recentPRs, err = api.Database.GetPRInsightsPullRequests(egCtx, database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: ownerID,
|
|
})
|
|
return err
|
|
})
|
|
|
|
if err := eg.Wait(); err != nil {
|
|
httpapi.InternalServerError(rw, err)
|
|
return
|
|
}
|
|
|
|
// Build summary with computed fields.
|
|
summary := codersdk.PRInsightsSummary{
|
|
TotalPRsCreated: currentSummary.TotalPrsCreated,
|
|
TotalPRsMerged: currentSummary.TotalPrsMerged,
|
|
TotalAdditions: currentSummary.TotalAdditions,
|
|
TotalDeletions: currentSummary.TotalDeletions,
|
|
TotalCostMicros: currentSummary.TotalCostMicros,
|
|
PrevTotalPRsCreated: previousSummary.TotalPrsCreated,
|
|
PrevTotalPRsMerged: previousSummary.TotalPrsMerged,
|
|
}
|
|
if summary.TotalPRsCreated > 0 {
|
|
summary.MergeRate = float64(summary.TotalPRsMerged) / float64(summary.TotalPRsCreated)
|
|
}
|
|
if summary.TotalPRsMerged > 0 {
|
|
summary.CostPerMergedPRMicros = currentSummary.MergedCostMicros / summary.TotalPRsMerged
|
|
}
|
|
if summary.PrevTotalPRsCreated > 0 {
|
|
summary.PrevMergeRate = float64(summary.PrevTotalPRsMerged) / float64(summary.PrevTotalPRsCreated)
|
|
}
|
|
if summary.PrevTotalPRsMerged > 0 {
|
|
summary.PrevCostPerMergedPRMicros = previousSummary.MergedCostMicros / summary.PrevTotalPRsMerged
|
|
}
|
|
|
|
// Convert time series.
|
|
tsEntries := make([]codersdk.PRInsightsTimeSeriesEntry, 0, len(timeSeries))
|
|
for _, ts := range timeSeries {
|
|
tsEntries = append(tsEntries, codersdk.PRInsightsTimeSeriesEntry{
|
|
Date: ts.Date,
|
|
PRsCreated: ts.PrsCreated,
|
|
PRsMerged: ts.PrsMerged,
|
|
PRsClosed: ts.PrsClosed,
|
|
})
|
|
}
|
|
|
|
// Convert model breakdown.
|
|
modelEntries := make([]codersdk.PRInsightsModelBreakdown, 0, len(byModel))
|
|
for _, m := range byModel {
|
|
entry := codersdk.PRInsightsModelBreakdown{
|
|
ModelConfigID: m.ModelConfigID.UUID,
|
|
DisplayName: m.DisplayName,
|
|
Provider: m.Provider,
|
|
TotalPRs: m.TotalPrs,
|
|
MergedPRs: m.MergedPrs,
|
|
TotalAdditions: m.TotalAdditions,
|
|
TotalDeletions: m.TotalDeletions,
|
|
TotalCostMicros: m.TotalCostMicros,
|
|
}
|
|
if entry.TotalPRs > 0 {
|
|
entry.MergeRate = float64(entry.MergedPRs) / float64(entry.TotalPRs)
|
|
}
|
|
if entry.MergedPRs > 0 {
|
|
entry.CostPerMergedPRMicros = m.MergedCostMicros / entry.MergedPRs
|
|
}
|
|
modelEntries = append(modelEntries, entry)
|
|
}
|
|
|
|
// Convert recent PRs.
|
|
prEntries := make([]codersdk.PRInsightsPullRequest, 0, len(recentPRs))
|
|
for _, pr := range recentPRs {
|
|
entry := codersdk.PRInsightsPullRequest{
|
|
ChatID: pr.ChatID,
|
|
PRTitle: pr.PrTitle,
|
|
Draft: pr.Draft,
|
|
Additions: pr.Additions,
|
|
Deletions: pr.Deletions,
|
|
ChangedFiles: pr.ChangedFiles,
|
|
ChangesRequested: pr.ChangesRequested,
|
|
BaseBranch: pr.BaseBranch,
|
|
ModelDisplayName: pr.ModelDisplayName,
|
|
CostMicros: pr.CostMicros,
|
|
CreatedAt: pr.CreatedAt,
|
|
}
|
|
if pr.PrUrl.Valid {
|
|
entry.PRURL = &pr.PrUrl.String
|
|
}
|
|
if pr.PrNumber.Valid {
|
|
entry.PRNumber = &pr.PrNumber.Int32
|
|
}
|
|
if pr.State.Valid {
|
|
entry.State = pr.State.String
|
|
}
|
|
if pr.Commits.Valid {
|
|
entry.Commits = &pr.Commits.Int32
|
|
}
|
|
if pr.Approved.Valid {
|
|
entry.Approved = &pr.Approved.Bool
|
|
}
|
|
if pr.ReviewerCount.Valid {
|
|
entry.ReviewerCount = &pr.ReviewerCount.Int32
|
|
}
|
|
if pr.AuthorLogin.Valid {
|
|
entry.AuthorLogin = &pr.AuthorLogin.String
|
|
}
|
|
if pr.AuthorAvatarUrl.Valid {
|
|
entry.AuthorAvatarURL = &pr.AuthorAvatarUrl.String
|
|
}
|
|
prEntries = append(prEntries, entry)
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, codersdk.PRInsightsResponse{
|
|
Summary: summary,
|
|
TimeSeries: tsEntries,
|
|
ByModel: modelEntries,
|
|
PullRequests: prEntries,
|
|
})
|
|
}
|
|
|
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
|
//
|
|
//nolint:revive // HTTP handler writes to ResponseWriter.
|
|
func (api *API) postChatToolResults(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
apiKey := httpmw.APIKey(r)
|
|
|
|
// Submitting tool results resumes LLM inference,
|
|
// requiring update permission on the org-scoped chat resource.
|
|
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
// Only the chat owner may submit tool results. See
|
|
// postChatMessages for the security rationale.
|
|
if apiKey.UserID != chat.OwnerID {
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
|
Message: "Only the chat owner may submit tool results.",
|
|
})
|
|
return
|
|
}
|
|
|
|
if chat.Archived {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot submit tool results to an archived chat.",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Cap the raw request body to prevent excessive memory use.
|
|
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
|
var req codersdk.SubmitToolResultsRequest
|
|
|
|
if !httpapi.Read(ctx, rw, r, &req) {
|
|
return
|
|
}
|
|
|
|
if len(req.Results) == 0 {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "At least one tool result is required.",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Fast-path check outside the transaction. The authoritative
|
|
// check happens inside SubmitToolResults under a row lock.
|
|
if chat.Status != database.ChatStatusRequiresAction {
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat is not waiting for tool results.",
|
|
Detail: fmt.Sprintf("Chat status is %q, expected %q.", chat.Status, database.ChatStatusRequiresAction),
|
|
})
|
|
return
|
|
}
|
|
|
|
var dynamicTools json.RawMessage
|
|
if chat.DynamicTools.Valid {
|
|
dynamicTools = chat.DynamicTools.RawMessage
|
|
}
|
|
|
|
err := api.chatDaemon.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
|
ChatID: chat.ID,
|
|
UserID: apiKey.UserID,
|
|
ModelConfigID: chat.LastModelConfigID,
|
|
Results: req.Results,
|
|
DynamicTools: dynamicTools,
|
|
})
|
|
if err != nil {
|
|
var validationErr *chatd.ToolResultValidationError
|
|
var conflictErr *chatd.ToolResultStatusConflictError
|
|
switch {
|
|
case xerrors.Is(err, chatd.ErrChatArchived):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Cannot submit tool results to an archived chat.",
|
|
})
|
|
case errors.As(err, &conflictErr):
|
|
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
|
Message: "Chat is not waiting for tool results.",
|
|
Detail: err.Error(),
|
|
})
|
|
case errors.As(err, &validationErr):
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: validationErr.Message,
|
|
Detail: validationErr.Detail,
|
|
})
|
|
default:
|
|
api.Logger.Error(ctx, "tool results submission failed",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error submitting tool results.",
|
|
})
|
|
}
|
|
return
|
|
}
|
|
|
|
rw.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// getChatDebugRuns returns a list of debug run summaries for a chat.
|
|
// EXPERIMENTAL
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatDebugRuns(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
const maxDebugRuns = 100
|
|
runs, err := api.Database.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
|
|
ChatID: chat.ID,
|
|
LimitVal: maxDebugRuns,
|
|
})
|
|
if err != nil {
|
|
// The chat may have been deleted or access revoked between
|
|
// middleware extraction and this query (dbauthz re-authorizes
|
|
// on read). Surface those races as 404 to match the rest of
|
|
// this API and avoid leaking backend details.
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching debug runs.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
summaries := make([]codersdk.ChatDebugRunSummary, 0, len(runs))
|
|
for _, run := range runs {
|
|
summaries = append(summaries, db2sdk.ChatDebugRunSummary(run))
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusOK, summaries)
|
|
}
|
|
|
|
// getChatDebugRun returns a single debug run with its steps.
|
|
// EXPERIMENTAL
|
|
//
|
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
|
func (api *API) getChatDebugRun(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
chat := httpmw.ChatParam(r)
|
|
|
|
runIDStr := chi.URLParam(r, "debugRun")
|
|
runID, err := uuid.Parse(runIDStr)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Invalid debug run ID.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
run, err := api.Database.GetChatDebugRunByID(ctx, runID)
|
|
if err != nil {
|
|
// Treat both not-found and authorization failures as 404 to
|
|
// avoid leaking the existence of runs the caller cannot access.
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching debug run.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Verify the run belongs to this chat.
|
|
if run.ChatID != chat.ID {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
|
|
steps, err := api.Database.GetChatDebugStepsByRunID(ctx, run.ID)
|
|
if err != nil {
|
|
// The run may have been deleted or access may have changed
|
|
// between the two queries. Treat not-found/authz errors as
|
|
// 404 for consistency with the run lookup above.
|
|
if httpapi.Is404Error(err) {
|
|
httpapi.ResourceNotFound(rw)
|
|
return
|
|
}
|
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
|
Message: "Internal error fetching debug steps.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.ChatDebugRunDetail(run, steps))
|
|
}
|