mirror of
https://github.com/coder/coder.git
synced 2026-06-03 04:58:23 +00:00
7b903cad73
## Problem Centralized requests recorded *the first available key from the pool at `CreateInterceptor` time* as `credential_hint`, so the interception could be persisted in the database with a hint that didn't match the key that actually served the request. The fix stores, at the end of the interception, the hint of the key that succeeded, or of the last attempted key if all keys are unavailable. ## Changes - Add `Key.Hint()` and update `credential_hint` on every failover attempt so it reflects the actually-used key. - Stop pre-populating `credential_hint` at `CreateInterceptor`. For centralized credentials, the hint starts empty and is updated by the key failover loop. - Persist the final hint via `RecordInterceptionEnded`; the SQL updates `credential_hint` only when `credential_kind = 'centralized'`, so BYOK keeps its start-time value. - Log the actually-used hint on interception end/failure; the start log uses a `<keypool-pending>` placeholder for centralized credentials. > [!NOTE] > Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
687 lines
24 KiB
Go
687 lines
24 KiB
Go
package aibridgedserver
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"net/url"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/hashicorp/go-multierror"
|
|
"golang.org/x/xerrors"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/coderd/aibridged"
|
|
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
|
"github.com/coder/coder/v2/coderd/aiseats"
|
|
"github.com/coder/coder/v2/coderd/apikey"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
|
"github.com/coder/coder/v2/coderd/externalauth"
|
|
"github.com/coder/coder/v2/coderd/httpmw"
|
|
codermcp "github.com/coder/coder/v2/coderd/mcp"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
)
|
|
|
|
var (
|
|
ErrExpiredOrInvalidOAuthToken = xerrors.New("expired or invalid OAuth2 token")
|
|
ErrNoMCPConfigFound = xerrors.New("no MCP config found")
|
|
|
|
// These errors are returned by IsAuthorized. Since they're just returned as
|
|
// a generic dRPC error, it's difficult to tell them apart without string
|
|
// matching.
|
|
// TODO: return these errors to the client in a more structured/comparable
|
|
// way.
|
|
ErrInvalidKey = xerrors.New("invalid key")
|
|
ErrUnknownKey = xerrors.New("unknown key")
|
|
ErrExpired = xerrors.New("expired")
|
|
ErrUnknownUser = xerrors.New("unknown user")
|
|
ErrDeletedUser = xerrors.New("deleted user")
|
|
ErrSystemUser = xerrors.New("system user")
|
|
ErrAmbiguousAuth = xerrors.New("both key and key_id set; exactly one required")
|
|
|
|
ErrNoExternalAuthLinkFound = xerrors.New("no external auth link found")
|
|
)
|
|
|
|
const (
	// InterceptionLogMarker is the message used for every structured
	// interception log line, so they can be filtered as a group.
	InterceptionLogMarker = "interception log"

	// MetadataUserAgentKey is the interception metadata key under which the
	// request's user agent is stored.
	MetadataUserAgentKey = "request_user_agent"
)
|
|
|
|
// Compile-time check that *Server satisfies aibridged.DRPCServer.
var _ aibridged.DRPCServer = (*Server)(nil)
|
|
|
|
type store interface {
|
|
// Recorder-related queries.
|
|
InsertAIBridgeInterception(ctx context.Context, arg database.InsertAIBridgeInterceptionParams) (database.AIBridgeInterception, error)
|
|
InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error)
|
|
InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) (database.AIBridgeUserPrompt, error)
|
|
InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) (database.AIBridgeToolUsage, error)
|
|
InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error)
|
|
UpdateAIBridgeInterceptionEnded(ctx context.Context, intcID database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error)
|
|
GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (database.GetAIBridgeInterceptionLineageByToolCallIDRow, error)
|
|
|
|
// MCPConfigurator-related queries.
|
|
GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.ExternalAuthLink, error)
|
|
|
|
// Authorizer-related queries.
|
|
GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error)
|
|
GetUserByID(ctx context.Context, id uuid.UUID) (database.User, error)
|
|
}
|
|
|
|
type Server struct {
|
|
// lifecycleCtx must be tied to the API server's lifecycle
|
|
// as when the API server shuts down, we want to cancel any
|
|
// long-running operations.
|
|
lifecycleCtx context.Context
|
|
store store
|
|
logger slog.Logger
|
|
externalAuthConfigs map[string]*externalauth.Config
|
|
|
|
coderMCPConfig *proto.MCPServerConfig // may be nil if not available
|
|
structuredLogging bool
|
|
aiSeatTracker aiseats.SeatTracker
|
|
}
|
|
|
|
func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, accessURL string,
|
|
bridgeCfg codersdk.AIBridgeConfig, externalAuthConfigs []*externalauth.Config, experiments codersdk.Experiments,
|
|
aiSeatTracker aiseats.SeatTracker,
|
|
) (*Server, error) {
|
|
eac := make(map[string]*externalauth.Config, len(externalAuthConfigs))
|
|
|
|
for _, cfg := range externalAuthConfigs {
|
|
// Only External Auth configs which are configured with an MCP URL are relevant to aibridged.
|
|
if cfg.MCPURL == "" {
|
|
continue
|
|
}
|
|
eac[cfg.ID] = cfg
|
|
}
|
|
|
|
srv := &Server{
|
|
lifecycleCtx: lifecycleCtx,
|
|
store: store,
|
|
logger: logger,
|
|
externalAuthConfigs: eac,
|
|
structuredLogging: bridgeCfg.StructuredLogging.Value(),
|
|
aiSeatTracker: aiSeatTracker,
|
|
}
|
|
|
|
if bridgeCfg.InjectCoderMCPTools {
|
|
logger.Warn(lifecycleCtx, "inject MCP tools option is deprecated and will be removed in a future release")
|
|
coderMCPConfig, err := getCoderMCPServerConfig(experiments, accessURL)
|
|
if err != nil {
|
|
logger.Warn(lifecycleCtx, "failed to retrieve coder MCP server config, Coder MCP will not be available", slog.Error(err))
|
|
}
|
|
srv.coderMCPConfig = coderMCPConfig
|
|
}
|
|
|
|
return srv, nil
|
|
}
|
|
|
|
func (s *Server) RecordInterception(ctx context.Context, in *proto.RecordInterceptionRequest) (*proto.RecordInterceptionResponse, error) {
|
|
//nolint:gocritic // AIBridged has specific authz rules.
|
|
ctx = dbauthz.AsAIBridged(ctx)
|
|
|
|
intcID, err := uuid.Parse(in.GetId())
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("invalid interception ID %q: %w", in.GetId(), err)
|
|
}
|
|
initID, err := uuid.Parse(in.GetInitiatorId())
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("invalid initiator ID %q: %w", in.GetInitiatorId(), err)
|
|
}
|
|
if in.ApiKeyId == "" {
|
|
return nil, xerrors.Errorf("empty API key ID")
|
|
}
|
|
|
|
metadata := metadataToMap(in.GetMetadata())
|
|
|
|
if in.UserAgent != "" {
|
|
if _, ok := metadata[MetadataUserAgentKey]; ok {
|
|
s.logger.Warn(ctx, "interception metadata contains user agent key, will be overwritten")
|
|
}
|
|
metadata[MetadataUserAgentKey] = in.UserAgent
|
|
}
|
|
|
|
// Look up the interception lineage using the correlating tool call ID.
|
|
parentID, rootID := s.findInterceptionLineage(ctx, in.GetCorrelatingToolCallId())
|
|
|
|
if s.structuredLogging {
|
|
s.logger.Info(ctx, InterceptionLogMarker,
|
|
slog.F("record_type", "interception_start"),
|
|
slog.F("interception_id", intcID.String()),
|
|
slog.F("initiator_id", initID.String()),
|
|
slog.F("api_key_id", in.ApiKeyId),
|
|
slog.F("provider", in.Provider),
|
|
slog.F("model", in.Model),
|
|
slog.F("client", in.Client),
|
|
slog.F("client_session_id", in.GetClientSessionId()),
|
|
slog.F("started_at", in.StartedAt.AsTime()),
|
|
slog.F("metadata", metadata),
|
|
slog.F("correlating_tool_call_id", in.GetCorrelatingToolCallId()),
|
|
slog.F("thread_parent_id", parentID),
|
|
slog.F("thread_root_id", rootID),
|
|
)
|
|
}
|
|
|
|
out, err := json.Marshal(metadata)
|
|
if err != nil {
|
|
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
|
}
|
|
|
|
providerName := strings.TrimSpace(in.ProviderName)
|
|
if providerName == "" {
|
|
providerName = in.Provider
|
|
}
|
|
|
|
_, err = s.store.InsertAIBridgeInterception(ctx, database.InsertAIBridgeInterceptionParams{
|
|
ID: intcID,
|
|
APIKeyID: sql.NullString{String: in.ApiKeyId, Valid: true},
|
|
Client: sql.NullString{String: in.Client, Valid: in.Client != ""},
|
|
ClientSessionID: sql.NullString{String: in.GetClientSessionId(), Valid: in.GetClientSessionId() != ""},
|
|
InitiatorID: initID,
|
|
Provider: in.Provider,
|
|
ProviderName: providerName,
|
|
Model: in.Model,
|
|
Metadata: out,
|
|
StartedAt: in.StartedAt.AsTime(),
|
|
ThreadParentInterceptionID: uuid.NullUUID{UUID: parentID, Valid: parentID != uuid.Nil},
|
|
ThreadRootInterceptionID: uuid.NullUUID{UUID: rootID, Valid: rootID != uuid.Nil},
|
|
CredentialKind: credentialKindOrDefault(in.CredentialKind),
|
|
CredentialHint: in.CredentialHint,
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("start interception: %w", err)
|
|
}
|
|
|
|
reason := aiseats.ReasonAIBridge("provider=" + in.Provider + ", model=" + in.Model)
|
|
s.aiSeatTracker.RecordUsage(ctx, initID, reason)
|
|
return &proto.RecordInterceptionResponse{}, nil
|
|
}
|
|
|
|
func (s *Server) RecordInterceptionEnded(ctx context.Context, in *proto.RecordInterceptionEndedRequest) (*proto.RecordInterceptionEndedResponse, error) {
|
|
//nolint:gocritic // AIBridged has specific authz rules.
|
|
ctx = dbauthz.AsAIBridged(ctx)
|
|
|
|
intcID, err := uuid.Parse(in.GetId())
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("invalid interception ID %q: %w", in.GetId(), err)
|
|
}
|
|
|
|
if s.structuredLogging {
|
|
s.logger.Info(ctx, InterceptionLogMarker,
|
|
slog.F("record_type", "interception_end"),
|
|
slog.F("interception_id", intcID.String()),
|
|
slog.F("ended_at", in.EndedAt.AsTime()),
|
|
)
|
|
}
|
|
|
|
_, err = s.store.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
|
ID: intcID,
|
|
EndedAt: in.EndedAt.AsTime(),
|
|
CredentialHint: in.CredentialHint,
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("end interception: %w", err)
|
|
}
|
|
|
|
return &proto.RecordInterceptionEndedResponse{}, nil
|
|
}
|
|
|
|
func (s *Server) RecordTokenUsage(ctx context.Context, in *proto.RecordTokenUsageRequest) (*proto.RecordTokenUsageResponse, error) {
|
|
//nolint:gocritic // AIBridged has specific authz rules.
|
|
ctx = dbauthz.AsAIBridged(ctx)
|
|
|
|
intcID, err := uuid.Parse(in.GetInterceptionId())
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
|
|
}
|
|
|
|
metadata := metadataToMap(in.GetMetadata())
|
|
|
|
if s.structuredLogging {
|
|
s.logger.Info(ctx, InterceptionLogMarker,
|
|
slog.F("record_type", "token_usage"),
|
|
slog.F("interception_id", intcID.String()),
|
|
slog.F("msg_id", in.GetMsgId()),
|
|
slog.F("input_tokens", in.GetInputTokens()),
|
|
slog.F("output_tokens", in.GetOutputTokens()),
|
|
slog.F("cache_read_input_tokens", in.GetCacheReadInputTokens()),
|
|
slog.F("cache_write_input_tokens", in.GetCacheWriteInputTokens()),
|
|
slog.F("created_at", in.GetCreatedAt().AsTime()),
|
|
slog.F("metadata", metadata),
|
|
)
|
|
}
|
|
|
|
out, err := json.Marshal(metadata)
|
|
if err != nil {
|
|
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
|
}
|
|
|
|
_, err = s.store.InsertAIBridgeTokenUsage(ctx, database.InsertAIBridgeTokenUsageParams{
|
|
ID: uuid.New(),
|
|
InterceptionID: intcID,
|
|
ProviderResponseID: in.GetMsgId(),
|
|
InputTokens: in.GetInputTokens(),
|
|
OutputTokens: in.GetOutputTokens(),
|
|
CacheReadInputTokens: in.GetCacheReadInputTokens(),
|
|
CacheWriteInputTokens: in.GetCacheWriteInputTokens(),
|
|
Metadata: out,
|
|
CreatedAt: in.GetCreatedAt().AsTime(),
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("insert token usage: %w", err)
|
|
}
|
|
|
|
return &proto.RecordTokenUsageResponse{}, nil
|
|
}
|
|
|
|
func (s *Server) RecordPromptUsage(ctx context.Context, in *proto.RecordPromptUsageRequest) (*proto.RecordPromptUsageResponse, error) {
|
|
//nolint:gocritic // AIBridged has specific authz rules.
|
|
ctx = dbauthz.AsAIBridged(ctx)
|
|
|
|
intcID, err := uuid.Parse(in.GetInterceptionId())
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
|
|
}
|
|
|
|
metadata := metadataToMap(in.GetMetadata())
|
|
|
|
if s.structuredLogging {
|
|
s.logger.Info(ctx, InterceptionLogMarker,
|
|
slog.F("record_type", "prompt_usage"),
|
|
slog.F("interception_id", intcID.String()),
|
|
slog.F("msg_id", in.GetMsgId()),
|
|
slog.F("prompt", in.GetPrompt()),
|
|
slog.F("created_at", in.GetCreatedAt().AsTime()),
|
|
slog.F("metadata", metadata),
|
|
)
|
|
}
|
|
|
|
out, err := json.Marshal(metadata)
|
|
if err != nil {
|
|
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
|
}
|
|
|
|
_, err = s.store.InsertAIBridgeUserPrompt(ctx, database.InsertAIBridgeUserPromptParams{
|
|
ID: uuid.New(),
|
|
InterceptionID: intcID,
|
|
ProviderResponseID: in.GetMsgId(),
|
|
Prompt: in.GetPrompt(),
|
|
Metadata: out,
|
|
CreatedAt: in.GetCreatedAt().AsTime(),
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("insert user prompt: %w", err)
|
|
}
|
|
|
|
return &proto.RecordPromptUsageResponse{}, nil
|
|
}
|
|
|
|
func (s *Server) RecordToolUsage(ctx context.Context, in *proto.RecordToolUsageRequest) (*proto.RecordToolUsageResponse, error) {
|
|
//nolint:gocritic // AIBridged has specific authz rules.
|
|
ctx = dbauthz.AsAIBridged(ctx)
|
|
|
|
intcID, err := uuid.Parse(in.GetInterceptionId())
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
|
|
}
|
|
|
|
metadata := metadataToMap(in.GetMetadata())
|
|
|
|
if s.structuredLogging {
|
|
s.logger.Info(ctx, InterceptionLogMarker,
|
|
slog.F("record_type", "tool_usage"),
|
|
slog.F("interception_id", intcID.String()),
|
|
slog.F("msg_id", in.GetMsgId()),
|
|
slog.F("tool_call_id", in.GetToolCallId()),
|
|
slog.F("tool", in.GetTool()),
|
|
slog.F("input", in.GetInput()),
|
|
slog.F("server_url", in.GetServerUrl()),
|
|
slog.F("injected", in.GetInjected()),
|
|
slog.F("invocation_error", in.GetInvocationError()),
|
|
slog.F("created_at", in.GetCreatedAt().AsTime()),
|
|
slog.F("metadata", metadata),
|
|
)
|
|
}
|
|
|
|
out, err := json.Marshal(metadata)
|
|
if err != nil {
|
|
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
|
}
|
|
|
|
_, err = s.store.InsertAIBridgeToolUsage(ctx, database.InsertAIBridgeToolUsageParams{
|
|
ID: uuid.New(),
|
|
InterceptionID: intcID,
|
|
ProviderResponseID: in.GetMsgId(),
|
|
ProviderToolCallID: sql.NullString{String: in.GetToolCallId(), Valid: in.GetToolCallId() != ""},
|
|
ServerUrl: sql.NullString{String: in.GetServerUrl(), Valid: in.ServerUrl != nil},
|
|
Tool: in.GetTool(),
|
|
Input: in.GetInput(),
|
|
Injected: in.GetInjected(),
|
|
InvocationError: sql.NullString{String: in.GetInvocationError(), Valid: in.InvocationError != nil},
|
|
Metadata: out,
|
|
CreatedAt: in.GetCreatedAt().AsTime(),
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("insert tool usage: %w", err)
|
|
}
|
|
|
|
return &proto.RecordToolUsageResponse{}, nil
|
|
}
|
|
|
|
func (s *Server) RecordModelThought(ctx context.Context, in *proto.RecordModelThoughtRequest) (*proto.RecordModelThoughtResponse, error) {
|
|
//nolint:gocritic // AIBridged has specific authz rules.
|
|
ctx = dbauthz.AsAIBridged(ctx)
|
|
|
|
intcID, err := uuid.Parse(in.GetInterceptionId())
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
|
|
}
|
|
|
|
metadata := metadataToMap(in.GetMetadata())
|
|
|
|
if s.structuredLogging {
|
|
s.logger.Info(ctx, InterceptionLogMarker,
|
|
slog.F("record_type", "model_thought"),
|
|
slog.F("interception_id", intcID.String()),
|
|
slog.F("content", in.GetContent()),
|
|
slog.F("created_at", in.GetCreatedAt().AsTime()),
|
|
slog.F("metadata", metadata),
|
|
)
|
|
}
|
|
|
|
out, err := json.Marshal(metadata)
|
|
if err != nil {
|
|
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
|
}
|
|
|
|
_, err = s.store.InsertAIBridgeModelThought(ctx, database.InsertAIBridgeModelThoughtParams{
|
|
InterceptionID: intcID,
|
|
Content: in.GetContent(),
|
|
Metadata: out,
|
|
CreatedAt: in.GetCreatedAt().AsTime(),
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("insert model thought: %w", err)
|
|
}
|
|
|
|
return &proto.RecordModelThoughtResponse{}, nil
|
|
}
|
|
|
|
// findInterceptionLineage looks up the parent interception and the root
|
|
// of the thread by finding which interception recorded a tool usage with
|
|
// the given tool call ID. Returns (parentID, rootID); both will be
|
|
// uuid.Nil if no match is found or the tool call ID is empty.
|
|
func (s *Server) findInterceptionLineage(ctx context.Context, toolCallID string) (parent uuid.UUID, root uuid.UUID) {
|
|
if toolCallID == "" {
|
|
return uuid.Nil, uuid.Nil
|
|
}
|
|
|
|
lineage, err := s.store.GetAIBridgeInterceptionLineageByToolCallID(ctx, toolCallID)
|
|
if err != nil {
|
|
s.logger.Warn(ctx, "failed to retrieve interception lineage",
|
|
slog.Error(err), slog.F("tool_call_id", toolCallID))
|
|
return uuid.Nil, uuid.Nil
|
|
}
|
|
|
|
return lineage.ThreadParentID, lineage.ThreadRootID
|
|
}
|
|
|
|
func (s *Server) GetMCPServerConfigs(_ context.Context, _ *proto.GetMCPServerConfigsRequest) (*proto.GetMCPServerConfigsResponse, error) {
|
|
cfgs := make([]*proto.MCPServerConfig, 0, len(s.externalAuthConfigs))
|
|
for _, eac := range s.externalAuthConfigs {
|
|
var allowlist, denylist string
|
|
if eac.MCPToolAllowRegex != nil {
|
|
allowlist = eac.MCPToolAllowRegex.String()
|
|
}
|
|
if eac.MCPToolDenyRegex != nil {
|
|
denylist = eac.MCPToolDenyRegex.String()
|
|
}
|
|
|
|
cfgs = append(cfgs, &proto.MCPServerConfig{
|
|
Id: eac.ID,
|
|
Url: eac.MCPURL,
|
|
ToolAllowRegex: allowlist,
|
|
ToolDenyRegex: denylist,
|
|
})
|
|
}
|
|
|
|
return &proto.GetMCPServerConfigsResponse{
|
|
CoderMcpConfig: s.coderMCPConfig, // it's fine if this is nil
|
|
ExternalAuthMcpConfigs: cfgs,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Server) GetMCPServerAccessTokensBatch(ctx context.Context, in *proto.GetMCPServerAccessTokensBatchRequest) (*proto.GetMCPServerAccessTokensBatchResponse, error) {
|
|
if len(in.GetMcpServerConfigIds()) == 0 {
|
|
return &proto.GetMCPServerAccessTokensBatchResponse{}, nil
|
|
}
|
|
|
|
userID, err := uuid.Parse(in.GetUserId())
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse user_id: %w", err)
|
|
}
|
|
|
|
//nolint:gocritic // AIBridged has specific authz rules.
|
|
ctx = dbauthz.AsAIBridged(ctx)
|
|
links, err := s.store.GetExternalAuthLinksByUserID(ctx, userID)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("fetch external auth links: %w", err)
|
|
}
|
|
|
|
if len(links) == 0 {
|
|
return &proto.GetMCPServerAccessTokensBatchResponse{}, nil
|
|
}
|
|
|
|
// Ensure unique to prevent unnecessary effort.
|
|
ids := in.GetMcpServerConfigIds()
|
|
slices.Sort(ids)
|
|
ids = slices.Compact(ids)
|
|
|
|
var (
|
|
wg sync.WaitGroup
|
|
errs error
|
|
|
|
mu sync.Mutex
|
|
tokens = make(map[string]string, len(ids))
|
|
tokenErrs = make(map[string]string)
|
|
)
|
|
|
|
externalAuthLoop:
|
|
for _, id := range ids {
|
|
eac, ok := s.externalAuthConfigs[id]
|
|
if !ok {
|
|
mu.Lock()
|
|
s.logger.Warn(ctx, "no MCP server config found by given ID", slog.F("id", id))
|
|
tokenErrs[id] = ErrNoMCPConfigFound.Error()
|
|
mu.Unlock()
|
|
continue
|
|
}
|
|
|
|
for _, link := range links {
|
|
if link.ProviderID != eac.ID {
|
|
continue
|
|
}
|
|
|
|
// Validate all configured External Auth links concurrently.
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
// TODO: timeout.
|
|
valid, _, validateErr := eac.ValidateToken(ctx, link.OAuthToken())
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
if !valid {
|
|
// TODO: attempt refresh.
|
|
s.logger.Warn(ctx, "invalid/expired access token, cannot auto-configure MCP", slog.F("provider", link.ProviderID), slog.Error(validateErr))
|
|
tokenErrs[id] = ErrExpiredOrInvalidOAuthToken.Error()
|
|
return
|
|
}
|
|
|
|
if validateErr != nil {
|
|
errs = multierror.Append(errs, validateErr)
|
|
tokenErrs[id] = validateErr.Error()
|
|
} else {
|
|
tokens[id] = link.OAuthAccessToken
|
|
}
|
|
}()
|
|
|
|
continue externalAuthLoop
|
|
}
|
|
|
|
// No link found for this external auth config, so include a generic
|
|
// error.
|
|
mu.Lock()
|
|
tokenErrs[id] = ErrNoExternalAuthLinkFound.Error()
|
|
mu.Unlock()
|
|
}
|
|
|
|
wg.Wait()
|
|
return &proto.GetMCPServerAccessTokensBatchResponse{
|
|
AccessTokens: tokens,
|
|
Errors: tokenErrs,
|
|
}, errs
|
|
}
|
|
|
|
// IsAuthorized validates a given Coder API key and returns the user ID to which it belongs (if valid).
|
|
//
|
|
// SECURITY: when in.KeyId is set (the "delegated" path), this method trusts the
|
|
// caller's claim of identity and skips the key-secret check. This is safe only
|
|
// because the DRPCServer is reachable solely via the in-process
|
|
// [aibridged.MemTransportPipe]; the handler itself cannot tell whether it was
|
|
// invoked over the in-memory pipe or a network socket. If this RPC is ever
|
|
// exposed over a network boundary, any caller who knows a valid 10-char key ID
|
|
// (which is not secret) could authenticate as the key's owner without the
|
|
// secret. Do not bind this DRPCServer to a network listener.
|
|
//
|
|
// NOTE: this should really be using the code from [httpmw.ExtractAPIKey]. That function not only validates the key
|
|
// but handles many other cases like updating last used, expiry, etc. This code does not currently use it for
|
|
// a few reasons:
|
|
//
|
|
// 1. [httpmw.ExtractAPIKey] relies on keys being given in specific headers [httpmw.APITokenFromRequest] which AI
|
|
// bridge requests will not conform to.
|
|
// 2. The code mixes many different concerns, and handles HTTP responses too, which is undesirable here.
|
|
// 3. The core logic would need to be extracted, but that will surely be a complex & time-consuming distraction right now.
|
|
// 4. Once we have an Early Access release of AI Bridge, we need to return to this.
|
|
//
|
|
// TODO: replace with logic from [httpmw.ExtractAPIKey].
|
|
func (s *Server) IsAuthorized(ctx context.Context, in *proto.IsAuthorizedRequest) (*proto.IsAuthorizedResponse, error) {
|
|
//nolint:gocritic // AIBridged has specific authz rules.
|
|
ctx = dbauthz.AsAIBridged(ctx)
|
|
|
|
var (
|
|
keyID string
|
|
keySecret string
|
|
// delegated requests skip the secret check: the caller never
|
|
// has the secret. Trust is established at the in-process
|
|
// transport boundary, not in this RPC.
|
|
delegated bool
|
|
)
|
|
switch {
|
|
case in.GetKey() != "" && in.GetKeyId() != "":
|
|
return nil, ErrAmbiguousAuth
|
|
case in.GetKeyId() != "":
|
|
keyID = in.GetKeyId()
|
|
delegated = true
|
|
default:
|
|
var err error
|
|
keyID, keySecret, err = httpmw.SplitAPIToken(in.GetKey())
|
|
if err != nil {
|
|
return nil, ErrInvalidKey
|
|
}
|
|
}
|
|
|
|
// Key exists.
|
|
key, err := s.store.GetAPIKeyByID(ctx, keyID)
|
|
if err != nil {
|
|
s.logger.Warn(ctx, "failed to retrieve API key by id", slog.F("key_id", keyID), slog.Error(err))
|
|
return nil, ErrUnknownKey
|
|
}
|
|
|
|
// Key has not expired.
|
|
now := dbtime.Now()
|
|
if key.ExpiresAt.Before(now) {
|
|
return nil, ErrExpired
|
|
}
|
|
|
|
// Key secret matches (skipped for delegated callers).
|
|
if !delegated && !apikey.ValidateHash(key.HashedSecret, keySecret) {
|
|
return nil, ErrInvalidKey
|
|
}
|
|
|
|
// User exists.
|
|
user, err := s.store.GetUserByID(ctx, key.UserID)
|
|
if err != nil {
|
|
s.logger.Warn(ctx, "failed to retrieve API key user", slog.F("key_id", keyID), slog.F("user_id", key.UserID), slog.Error(err))
|
|
return nil, ErrUnknownUser
|
|
}
|
|
|
|
// User is not deleted or a system user.
|
|
if user.Deleted {
|
|
return nil, ErrDeletedUser
|
|
}
|
|
if user.IsSystem {
|
|
return nil, ErrSystemUser
|
|
}
|
|
|
|
return &proto.IsAuthorizedResponse{
|
|
OwnerId: key.UserID.String(),
|
|
ApiKeyId: key.ID,
|
|
Username: user.Username,
|
|
}, nil
|
|
}
|
|
|
|
// Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.
|
|
func getCoderMCPServerConfig(experiments codersdk.Experiments, accessURL string) (*proto.MCPServerConfig, error) {
|
|
// Both the MCP & OAuth2 experiments are currently required in order to use our
|
|
// internal MCP server.
|
|
if !experiments.Enabled(codersdk.ExperimentMCPServerHTTP) {
|
|
return nil, xerrors.Errorf("%q experiment not enabled", codersdk.ExperimentMCPServerHTTP)
|
|
}
|
|
if !experiments.Enabled(codersdk.ExperimentOAuth2) {
|
|
return nil, xerrors.Errorf("%q experiment not enabled", codersdk.ExperimentOAuth2)
|
|
}
|
|
|
|
u, err := url.JoinPath(accessURL, codermcp.MCPEndpoint)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("build MCP URL with %q: %w", accessURL, err)
|
|
}
|
|
|
|
return &proto.MCPServerConfig{
|
|
Id: aibridged.InternalMCPServerID,
|
|
Url: u,
|
|
}, nil
|
|
}
|
|
|
|
// credentialKindOrDefault converts the proto credential kind string to
|
|
// the database enum, defaulting to "centralized" when the value is
|
|
// empty or not a valid enum member.
|
|
func credentialKindOrDefault(kind string) database.CredentialKind {
|
|
ck := database.CredentialKind(kind)
|
|
if !ck.Valid() {
|
|
return database.CredentialKindCentralized
|
|
}
|
|
return ck
|
|
}
|
|
|
|
func metadataToMap(in map[string]*anypb.Any) map[string]any {
|
|
meta := make(map[string]any, len(in))
|
|
for k, v := range in {
|
|
if v == nil {
|
|
continue
|
|
}
|
|
var sv structpb.Value
|
|
if err := v.UnmarshalTo(&sv); err == nil {
|
|
meta[k] = sv.AsInterface()
|
|
}
|
|
}
|
|
return meta
|
|
}
|