feat: add endpoint to list aibridge interceptions (#19929)

Co-authored-by: Dean Sheather <dean@deansheather.com>
This commit is contained in:
Paweł Banaszewski
2025-09-26 16:20:33 +02:00
committed by GitHub
parent d70e26d2e3
commit 0a6ba5d51a
29 changed files with 2646 additions and 144 deletions
+197
View File
@@ -85,6 +85,51 @@ const docTemplate = `{
}
}
},
"/api/experimental/aibridge/interceptions": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"AIBridge"
],
"summary": "List AIBridge interceptions",
"operationId": "list-aibridge-interceptions",
"parameters": [
{
"type": "string",
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, started_after, started_before.",
"name": "q",
"in": "query"
},
{
"type": "integer",
"description": "Page limit",
"name": "limit",
"in": "query"
},
{
"type": "string",
"description": "Cursor pagination after ID",
"name": "after_id",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.AIBridgeListInterceptionsResponse"
}
}
}
}
},
"/appearance": {
"get": {
"security": [
@@ -11226,6 +11271,62 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeInterception": {
"type": "object",
"properties": {
"id": {
"type": "string",
"format": "uuid"
},
"initiator_id": {
"type": "string",
"format": "uuid"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"started_at": {
"type": "string",
"format": "date-time"
},
"token_usages": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeTokenUsage"
}
},
"tool_usages": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeToolUsage"
}
},
"user_prompts": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeUserPrompt"
}
}
}
},
"codersdk.AIBridgeListInterceptionsResponse": {
"type": "object",
"properties": {
"results": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeInterception"
}
}
}
},
"codersdk.AIBridgeOpenAIConfig": {
"type": "object",
"properties": {
@@ -11237,6 +11338,102 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeTokenUsage": {
"type": "object",
"properties": {
"created_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
},
"input_tokens": {
"type": "integer"
},
"interception_id": {
"type": "string",
"format": "uuid"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"output_tokens": {
"type": "integer"
},
"provider_response_id": {
"type": "string"
}
}
},
"codersdk.AIBridgeToolUsage": {
"type": "object",
"properties": {
"created_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
},
"injected": {
"type": "boolean"
},
"input": {
"type": "string"
},
"interception_id": {
"type": "string",
"format": "uuid"
},
"invocation_error": {
"type": "string"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"provider_response_id": {
"type": "string"
},
"server_url": {
"type": "string"
},
"tool": {
"type": "string"
}
}
},
"codersdk.AIBridgeUserPrompt": {
"type": "object",
"properties": {
"created_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
},
"interception_id": {
"type": "string",
"format": "uuid"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"prompt": {
"type": "string"
},
"provider_response_id": {
"type": "string"
}
}
},
"codersdk.AIConfig": {
"type": "object",
"properties": {
+193
View File
@@ -65,6 +65,47 @@
}
}
},
"/api/experimental/aibridge/interceptions": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["AIBridge"],
"summary": "List AIBridge interceptions",
"operationId": "list-aibridge-interceptions",
"parameters": [
{
"type": "string",
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, started_after, started_before.",
"name": "q",
"in": "query"
},
{
"type": "integer",
"description": "Page limit",
"name": "limit",
"in": "query"
},
{
"type": "string",
"description": "Cursor pagination after ID",
"name": "after_id",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.AIBridgeListInterceptionsResponse"
}
}
}
}
},
"/appearance": {
"get": {
"security": [
@@ -9954,6 +9995,62 @@
}
}
},
"codersdk.AIBridgeInterception": {
"type": "object",
"properties": {
"id": {
"type": "string",
"format": "uuid"
},
"initiator_id": {
"type": "string",
"format": "uuid"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"started_at": {
"type": "string",
"format": "date-time"
},
"token_usages": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeTokenUsage"
}
},
"tool_usages": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeToolUsage"
}
},
"user_prompts": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeUserPrompt"
}
}
}
},
"codersdk.AIBridgeListInterceptionsResponse": {
"type": "object",
"properties": {
"results": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeInterception"
}
}
}
},
"codersdk.AIBridgeOpenAIConfig": {
"type": "object",
"properties": {
@@ -9965,6 +10062,102 @@
}
}
},
"codersdk.AIBridgeTokenUsage": {
"type": "object",
"properties": {
"created_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
},
"input_tokens": {
"type": "integer"
},
"interception_id": {
"type": "string",
"format": "uuid"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"output_tokens": {
"type": "integer"
},
"provider_response_id": {
"type": "string"
}
}
},
"codersdk.AIBridgeToolUsage": {
"type": "object",
"properties": {
"created_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
},
"injected": {
"type": "boolean"
},
"input": {
"type": "string"
},
"interception_id": {
"type": "string",
"format": "uuid"
},
"invocation_error": {
"type": "string"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"provider_response_id": {
"type": "string"
},
"server_url": {
"type": "string"
},
"tool": {
"type": "string"
}
}
},
"codersdk.AIBridgeUserPrompt": {
"type": "object",
"properties": {
"created_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
},
"interception_id": {
"type": "string",
"format": "uuid"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"prompt": {
"type": "string"
},
"provider_response_id": {
"type": "string"
}
}
},
"codersdk.AIConfig": {
"type": "object",
"properties": {
+31 -26
View File
@@ -1000,38 +1000,41 @@ func New(options *Options) *API {
// Experimental routes are not guaranteed to be stable and may change at any time.
r.Route("/api/experimental", func(r chi.Router) {
api.ExperimentalHandler = r
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
// Only this group should be subject to apiKeyMiddleware; aibridged will mount its own
// router and handles key validation in a different fashion.
// See enterprise/x/aibridged/http.go.
r.Group(func(r chi.Router) {
r.Use(
// Specific routes can specify different limits, but every rate
// limit must be configurable by the admin.
apiRateLimiter,
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
)
r.Route("/aitasks", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Get("/prompts", api.aiTasksPrompts)
})
r.Route("/tasks", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Route("/aitasks", func(r chi.Router) {
r.Get("/prompts", api.aiTasksPrompts)
})
r.Route("/tasks", func(r chi.Router) {
r.Use(apiRateLimiter)
r.Get("/", api.tasksList)
r.Get("/", api.tasksList)
r.Route("/{user}", func(r chi.Router) {
r.Use(httpmw.ExtractOrganizationMembersParam(options.Database, api.HTTPAuth.Authorize))
r.Get("/{id}", api.taskGet)
r.Delete("/{id}", api.taskDelete)
r.Post("/{id}/send", api.taskSend)
r.Get("/{id}/logs", api.taskLogs)
r.Post("/", api.tasksCreate)
})
})
r.Route("/mcp", func(r chi.Router) {
r.Use(
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP),
)
// MCP HTTP transport endpoint with mandatory authentication
r.Mount("/http", api.mcpHTTPHandler())
r.Route("/{user}", func(r chi.Router) {
r.Use(httpmw.ExtractOrganizationMembersParam(options.Database, api.HTTPAuth.Authorize))
r.Get("/{id}", api.taskGet)
r.Delete("/{id}", api.taskDelete)
r.Post("/{id}/send", api.taskSend)
r.Get("/{id}/logs", api.taskLogs)
r.Post("/", api.tasksCreate)
})
})
r.Route("/mcp", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP),
)
// MCP HTTP transport endpoint with mandatory authentication
r.Mount("/http", api.mcpHTTPHandler())
})
})
r.Route("/api/v2", func(r chi.Router) {
@@ -1727,6 +1730,8 @@ type API struct {
// APIHandler serves "/api/v2"
APIHandler chi.Router
// ExperimentalHandler serves "/api/experimental"
ExperimentalHandler chi.Router
// RootHandler serves "/"
RootHandler chi.Router
+82
View File
@@ -13,6 +13,7 @@ import (
"github.com/google/uuid"
"github.com/hashicorp/hcl/v2"
"github.com/sqlc-dev/pqtype"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
@@ -918,3 +919,84 @@ func PreviewParameterValidation(v *previewtypes.ParameterValidation) codersdk.Pr
Monotonic: v.Monotonic,
}
}
func AIBridgeInterception(interception database.AIBridgeInterception, tokenUsages []database.AIBridgeTokenUsage, userPrompts []database.AIBridgeUserPrompt, toolUsages []database.AIBridgeToolUsage) codersdk.AIBridgeInterception {
sdkTokenUsages := List(tokenUsages, AIBridgeTokenUsage)
sort.Slice(sdkTokenUsages, func(i, j int) bool {
// created_at ASC
return sdkTokenUsages[i].CreatedAt.Before(sdkTokenUsages[j].CreatedAt)
})
sdkUserPrompts := List(userPrompts, AIBridgeUserPrompt)
sort.Slice(sdkUserPrompts, func(i, j int) bool {
// created_at ASC
return sdkUserPrompts[i].CreatedAt.Before(sdkUserPrompts[j].CreatedAt)
})
sdkToolUsages := List(toolUsages, AIBridgeToolUsage)
sort.Slice(sdkToolUsages, func(i, j int) bool {
// created_at ASC
return sdkToolUsages[i].CreatedAt.Before(sdkToolUsages[j].CreatedAt)
})
return codersdk.AIBridgeInterception{
ID: interception.ID,
InitiatorID: interception.InitiatorID,
Provider: interception.Provider,
Model: interception.Model,
Metadata: jsonOrEmptyMap(interception.Metadata),
StartedAt: interception.StartedAt,
TokenUsages: sdkTokenUsages,
UserPrompts: sdkUserPrompts,
ToolUsages: sdkToolUsages,
}
}
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
return codersdk.AIBridgeTokenUsage{
ID: usage.ID,
InterceptionID: usage.InterceptionID,
ProviderResponseID: usage.ProviderResponseID,
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
Metadata: jsonOrEmptyMap(usage.Metadata),
CreatedAt: usage.CreatedAt,
}
}
func AIBridgeUserPrompt(prompt database.AIBridgeUserPrompt) codersdk.AIBridgeUserPrompt {
return codersdk.AIBridgeUserPrompt{
ID: prompt.ID,
InterceptionID: prompt.InterceptionID,
ProviderResponseID: prompt.ProviderResponseID,
Prompt: prompt.Prompt,
Metadata: jsonOrEmptyMap(prompt.Metadata),
CreatedAt: prompt.CreatedAt,
}
}
func AIBridgeToolUsage(usage database.AIBridgeToolUsage) codersdk.AIBridgeToolUsage {
return codersdk.AIBridgeToolUsage{
ID: usage.ID,
InterceptionID: usage.InterceptionID,
ProviderResponseID: usage.ProviderResponseID,
ServerURL: usage.ServerUrl.String,
Tool: usage.Tool,
Input: usage.Input,
Injected: usage.Injected,
InvocationError: usage.InvocationError.String,
Metadata: jsonOrEmptyMap(usage.Metadata),
CreatedAt: usage.CreatedAt,
}
}
func jsonOrEmptyMap(rawMessage pqtype.NullRawMessage) map[string]any {
var m map[string]any
if !rawMessage.Valid {
return m
}
err := json.Unmarshal(rawMessage.RawMessage, &m)
if err != nil {
// Don't reuse m
return map[string]any{}
}
return m
}
+51 -6
View File
@@ -3842,26 +3842,26 @@ func (q *querier) InsertAIBridgeInterception(ctx context.Context, arg database.I
return insert(q.log, q.auth, rbac.ResourceAibridgeInterception.WithOwner(arg.InitiatorID.String()), q.db.InsertAIBridgeInterception)(ctx, arg)
}
func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) error {
func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
// All aibridge_token_usages records belong to the initiator of their associated interception.
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
return err
return database.AIBridgeTokenUsage{}, err
}
return q.db.InsertAIBridgeTokenUsage(ctx, arg)
}
func (q *querier) InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) error {
func (q *querier) InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) (database.AIBridgeToolUsage, error) {
// All aibridge_tool_usages records belong to the initiator of their associated interception.
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
return err
return database.AIBridgeToolUsage{}, err
}
return q.db.InsertAIBridgeToolUsage(ctx, arg)
}
func (q *querier) InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) error {
func (q *querier) InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) (database.AIBridgeUserPrompt, error) {
// All aibridge_user_prompts records belong to the initiator of their associated interception.
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
return err
return database.AIBridgeUserPrompt{}, err
}
return q.db.InsertAIBridgeUserPrompt(ctx, arg)
}
@@ -4409,6 +4409,44 @@ func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg datab
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
}
func (q *querier) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.AIBridgeInterception, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prep)
}
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIDs)
}
func (q *querier) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeToolUsage, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.ListAIBridgeToolUsagesByInterceptionIDs(ctx, interceptionIDs)
}
func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeUserPrompt, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
}
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
}
@@ -5761,3 +5799,10 @@ func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg dat
func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, _ rbac.PreparedAuthorized) (int64, error) {
return q.CountConnectionLogs(ctx, arg)
}
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) ([]database.AIBridgeInterception, error) {
// TODO: Delete this function, all ListAIBridgeInterceptions should be authorized. For now just call ListAIBridgeInterceptions on the authz querier.
// This cannot be deleted for now because it's included in the
// database.Store interface, so dbauthz needs to implement it.
return q.ListAIBridgeInterceptions(ctx, arg)
}
+65 -27
View File
@@ -4334,13 +4334,6 @@ func TestInsertAPIKey_AsPrebuildsUser(t *testing.T) {
}
func (s *MethodTestSuite) TestAIBridge() {
s.Run("GetAIBridgeInterceptionByID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes()
check.Args(intID).Asserts(intc, policy.ActionRead).Returns(intc)
}))
s.Run("InsertAIBridgeInterception", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
initID := uuid.UUID{3}
user := testutil.Fake(s.T(), faker, database.User{ID: initID})
@@ -4360,30 +4353,43 @@ func (s *MethodTestSuite) TestAIBridge() {
s.Run("InsertAIBridgeTokenUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
params := database.InsertAIBridgeTokenUsageParams{InterceptionID: intc.ID}
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
db.EXPECT().InsertAIBridgeTokenUsage(gomock.Any(), params).Return(nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate)
}))
s.Run("InsertAIBridgeToolUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
params := database.InsertAIBridgeToolUsageParams{InterceptionID: intc.ID}
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
db.EXPECT().InsertAIBridgeToolUsage(gomock.Any(), params).Return(nil).AnyTimes()
params := database.InsertAIBridgeTokenUsageParams{InterceptionID: intc.ID}
expected := testutil.Fake(s.T(), faker, database.AIBridgeTokenUsage{InterceptionID: intc.ID})
db.EXPECT().InsertAIBridgeTokenUsage(gomock.Any(), params).Return(expected, nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate)
}))
s.Run("InsertAIBridgeUserPrompt", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
params := database.InsertAIBridgeUserPromptParams{InterceptionID: intc.ID}
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
db.EXPECT().InsertAIBridgeUserPrompt(gomock.Any(), params).Return(nil).AnyTimes()
params := database.InsertAIBridgeUserPromptParams{InterceptionID: intc.ID}
expected := testutil.Fake(s.T(), faker, database.AIBridgeUserPrompt{InterceptionID: intc.ID})
db.EXPECT().InsertAIBridgeUserPrompt(gomock.Any(), params).Return(expected, nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate)
}))
s.Run("InsertAIBridgeToolUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
params := database.InsertAIBridgeToolUsageParams{InterceptionID: intc.ID}
expected := testutil.Fake(s.T(), faker, database.AIBridgeToolUsage{InterceptionID: intc.ID})
db.EXPECT().InsertAIBridgeToolUsage(gomock.Any(), params).Return(expected, nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate)
}))
s.Run("GetAIBridgeInterceptionByID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes()
check.Args(intID).Asserts(intc, policy.ActionRead).Returns(intc)
}))
s.Run("GetAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
a := testutil.Fake(s.T(), faker, database.AIBridgeInterception{})
b := testutil.Fake(s.T(), faker, database.AIBridgeInterception{})
@@ -4401,6 +4407,16 @@ func (s *MethodTestSuite) TestAIBridge() {
check.Args(intID).Asserts(intc, policy.ActionRead).Returns(toks)
}))
s.Run("GetAIBridgeUserPromptsByInterceptionID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
pr := testutil.Fake(s.T(), faker, database.AIBridgeUserPrompt{InterceptionID: intID})
prs := []database.AIBridgeUserPrompt{pr}
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
db.EXPECT().GetAIBridgeUserPromptsByInterceptionID(gomock.Any(), intID).Return(prs, nil).AnyTimes()
check.Args(intID).Asserts(intc, policy.ActionRead).Returns(prs)
}))
s.Run("GetAIBridgeToolUsagesByInterceptionID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
@@ -4411,13 +4427,35 @@ func (s *MethodTestSuite) TestAIBridge() {
check.Args(intID).Asserts(intc, policy.ActionRead).Returns(tools)
}))
s.Run("GetAIBridgeUserPromptsByInterceptionID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
pr := testutil.Fake(s.T(), faker, database.AIBridgeUserPrompt{InterceptionID: intID})
prs := []database.AIBridgeUserPrompt{pr}
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
db.EXPECT().GetAIBridgeUserPromptsByInterceptionID(gomock.Any(), intID).Return(prs, nil).AnyTimes()
check.Args(intID).Asserts(intc, policy.ActionRead).Returns(prs)
s.Run("ListAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
params := database.ListAIBridgeInterceptionsParams{}
db.EXPECT().ListAuthorizedAIBridgeInterceptions(gomock.Any(), params, gomock.Any()).Return([]database.AIBridgeInterception{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params).Asserts()
}))
s.Run("ListAuthorizedAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
params := database.ListAIBridgeInterceptionsParams{}
db.EXPECT().ListAuthorizedAIBridgeInterceptions(gomock.Any(), params, gomock.Any()).Return([]database.AIBridgeInterception{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ids := []uuid.UUID{{1}}
db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeTokenUsage{})
}))
s.Run("ListAIBridgeUserPromptsByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ids := []uuid.UUID{{1}}
db.EXPECT().ListAIBridgeUserPromptsByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeUserPrompt{}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeUserPrompt{})
}))
s.Run("ListAIBridgeToolUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ids := []uuid.UUID{{1}}
db.EXPECT().ListAIBridgeToolUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeToolUsage{}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
}))
}
+65
View File
@@ -1471,6 +1471,71 @@ func ClaimPrebuild(
return claimedWorkspace
}
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams) database.AIBridgeInterception {
interception, err := db.InsertAIBridgeInterception(genCtx, database.InsertAIBridgeInterceptionParams{
ID: takeFirst(seed.ID, uuid.New()),
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
Provider: takeFirst(seed.Provider, "provider"),
Model: takeFirst(seed.Model, "model"),
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
})
require.NoError(t, err, "insert aibridge interception")
return interception
}
func AIBridgeTokenUsage(t testing.TB, db database.Store, seed database.InsertAIBridgeTokenUsageParams) database.AIBridgeTokenUsage {
usage, err := db.InsertAIBridgeTokenUsage(genCtx, database.InsertAIBridgeTokenUsageParams{
ID: takeFirst(seed.ID, uuid.New()),
InterceptionID: takeFirst(seed.InterceptionID, uuid.New()),
ProviderResponseID: takeFirst(seed.ProviderResponseID, "provider_response_id"),
InputTokens: takeFirst(seed.InputTokens, 100),
OutputTokens: takeFirst(seed.OutputTokens, 100),
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
})
require.NoError(t, err, "insert aibridge token usage")
return usage
}
func AIBridgeUserPrompt(t testing.TB, db database.Store, seed database.InsertAIBridgeUserPromptParams) database.AIBridgeUserPrompt {
prompt, err := db.InsertAIBridgeUserPrompt(genCtx, database.InsertAIBridgeUserPromptParams{
ID: takeFirst(seed.ID, uuid.New()),
InterceptionID: takeFirst(seed.InterceptionID, uuid.New()),
ProviderResponseID: takeFirst(seed.ProviderResponseID, "provider_response_id"),
Prompt: takeFirst(seed.Prompt, "prompt"),
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
})
require.NoError(t, err, "insert aibridge user prompt")
return prompt
}
func AIBridgeToolUsage(t testing.TB, db database.Store, seed database.InsertAIBridgeToolUsageParams) database.AIBridgeToolUsage {
serverURL := sql.NullString{}
if seed.ServerUrl.Valid {
serverURL = seed.ServerUrl
}
invocationError := sql.NullString{}
if seed.InvocationError.Valid {
invocationError = seed.InvocationError
}
toolUsage, err := db.InsertAIBridgeToolUsage(genCtx, database.InsertAIBridgeToolUsageParams{
ID: takeFirst(seed.ID, uuid.New()),
InterceptionID: takeFirst(seed.InterceptionID, uuid.New()),
ProviderResponseID: takeFirst(seed.ProviderResponseID, "provider_response_id"),
Tool: takeFirst(seed.Tool, "tool"),
ServerUrl: serverURL,
Input: takeFirst(seed.Input, "input"),
Injected: takeFirst(seed.Injected, false),
InvocationError: invocationError,
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
})
require.NoError(t, err, "insert aibridge tool usage")
return toolUsage
}
func provisionerJobTiming(t testing.TB, db database.Store, seed database.ProvisionerJobTiming) database.ProvisionerJobTiming {
timing, err := db.InsertProvisionerJobTimings(genCtx, database.InsertProvisionerJobTimingsParams{
JobID: takeFirst(seed.JobID, uuid.New()),
+44 -9
View File
@@ -2210,25 +2210,25 @@ func (m queryMetricsStore) InsertAIBridgeInterception(ctx context.Context, arg d
return r0, r1
}
func (m queryMetricsStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) error {
func (m queryMetricsStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
start := time.Now()
r0 := m.s.InsertAIBridgeTokenUsage(ctx, arg)
r0, r1 := m.s.InsertAIBridgeTokenUsage(ctx, arg)
m.queryLatencies.WithLabelValues("InsertAIBridgeTokenUsage").Observe(time.Since(start).Seconds())
return r0
return r0, r1
}
func (m queryMetricsStore) InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) error {
func (m queryMetricsStore) InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) (database.AIBridgeToolUsage, error) {
start := time.Now()
r0 := m.s.InsertAIBridgeToolUsage(ctx, arg)
r0, r1 := m.s.InsertAIBridgeToolUsage(ctx, arg)
m.queryLatencies.WithLabelValues("InsertAIBridgeToolUsage").Observe(time.Since(start).Seconds())
return r0
return r0, r1
}
func (m queryMetricsStore) InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) error {
func (m queryMetricsStore) InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) (database.AIBridgeUserPrompt, error) {
start := time.Now()
r0 := m.s.InsertAIBridgeUserPrompt(ctx, arg)
r0, r1 := m.s.InsertAIBridgeUserPrompt(ctx, arg)
m.queryLatencies.WithLabelValues("InsertAIBridgeUserPrompt").Observe(time.Since(start).Seconds())
return r0
return r0, r1
}
func (m queryMetricsStore) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) {
@@ -2665,6 +2665,34 @@ func (m queryMetricsStore) InsertWorkspaceResourceMetadata(ctx context.Context,
return metadata, err
}
func (m queryMetricsStore) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.AIBridgeInterception, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeInterceptions(ctx, arg)
m.queryLatencies.WithLabelValues("ListAIBridgeInterceptions").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
m.queryLatencies.WithLabelValues("ListAIBridgeTokenUsagesByInterceptionIDs").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeToolUsage, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeToolUsagesByInterceptionIDs(ctx, interceptionIds)
m.queryLatencies.WithLabelValues("ListAIBridgeToolUsagesByInterceptionIDs").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeUserPrompt, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIds)
m.queryLatencies.WithLabelValues("ListAIBridgeUserPromptsByInterceptionIDs").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
start := time.Now()
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
@@ -3630,3 +3658,10 @@ func (m queryMetricsStore) CountAuthorizedConnectionLogs(ctx context.Context, ar
m.queryLatencies.WithLabelValues("CountAuthorizedConnectionLogs").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.AIBridgeInterception, error) {
start := time.Now()
r0, r1 := m.s.ListAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeInterceptions").Observe(time.Since(start).Seconds())
return r0, r1
}
+87 -9
View File
@@ -4724,11 +4724,12 @@ func (mr *MockStoreMockRecorder) InsertAIBridgeInterception(ctx, arg any) *gomoc
}
// InsertAIBridgeTokenUsage mocks base method.
func (m *MockStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) error {
func (m *MockStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertAIBridgeTokenUsage", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(database.AIBridgeTokenUsage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertAIBridgeTokenUsage indicates an expected call of InsertAIBridgeTokenUsage.
@@ -4738,11 +4739,12 @@ func (mr *MockStoreMockRecorder) InsertAIBridgeTokenUsage(ctx, arg any) *gomock.
}
// InsertAIBridgeToolUsage mocks base method.
func (m *MockStore) InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) error {
func (m *MockStore) InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) (database.AIBridgeToolUsage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertAIBridgeToolUsage", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(database.AIBridgeToolUsage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertAIBridgeToolUsage indicates an expected call of InsertAIBridgeToolUsage.
@@ -4752,11 +4754,12 @@ func (mr *MockStoreMockRecorder) InsertAIBridgeToolUsage(ctx, arg any) *gomock.C
}
// InsertAIBridgeUserPrompt mocks base method.
func (m *MockStore) InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) error {
func (m *MockStore) InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) (database.AIBridgeUserPrompt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertAIBridgeUserPrompt", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(database.AIBridgeUserPrompt)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertAIBridgeUserPrompt indicates an expected call of InsertAIBridgeUserPrompt.
@@ -5680,6 +5683,81 @@ func (mr *MockStoreMockRecorder) InsertWorkspaceResourceMetadata(ctx, arg any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertWorkspaceResourceMetadata", reflect.TypeOf((*MockStore)(nil).InsertWorkspaceResourceMetadata), ctx, arg)
}
// ListAIBridgeInterceptions mocks base method.
func (m *MockStore) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.AIBridgeInterception, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAIBridgeInterceptions", ctx, arg)
ret0, _ := ret[0].([]database.AIBridgeInterception)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAIBridgeInterceptions indicates an expected call of ListAIBridgeInterceptions.
func (mr *MockStoreMockRecorder) ListAIBridgeInterceptions(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptions), ctx, arg)
}
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAIBridgeTokenUsagesByInterceptionIDs", ctx, interceptionIds)
ret0, _ := ret[0].([]database.AIBridgeTokenUsage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAIBridgeTokenUsagesByInterceptionIDs indicates an expected call of ListAIBridgeTokenUsagesByInterceptionIDs.
func (mr *MockStoreMockRecorder) ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeTokenUsagesByInterceptionIDs", reflect.TypeOf((*MockStore)(nil).ListAIBridgeTokenUsagesByInterceptionIDs), ctx, interceptionIds)
}
// ListAIBridgeToolUsagesByInterceptionIDs mocks base method.
func (m *MockStore) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeToolUsage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAIBridgeToolUsagesByInterceptionIDs", ctx, interceptionIds)
ret0, _ := ret[0].([]database.AIBridgeToolUsage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAIBridgeToolUsagesByInterceptionIDs indicates an expected call of ListAIBridgeToolUsagesByInterceptionIDs.
func (mr *MockStoreMockRecorder) ListAIBridgeToolUsagesByInterceptionIDs(ctx, interceptionIds any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeToolUsagesByInterceptionIDs", reflect.TypeOf((*MockStore)(nil).ListAIBridgeToolUsagesByInterceptionIDs), ctx, interceptionIds)
}
// ListAIBridgeUserPromptsByInterceptionIDs mocks base method.
func (m *MockStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeUserPrompt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAIBridgeUserPromptsByInterceptionIDs", ctx, interceptionIds)
ret0, _ := ret[0].([]database.AIBridgeUserPrompt)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAIBridgeUserPromptsByInterceptionIDs indicates an expected call of ListAIBridgeUserPromptsByInterceptionIDs.
func (mr *MockStoreMockRecorder) ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIds any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeUserPromptsByInterceptionIDs", reflect.TypeOf((*MockStore)(nil).ListAIBridgeUserPromptsByInterceptionIDs), ctx, interceptionIds)
}
// ListAuthorizedAIBridgeInterceptions mocks base method.
func (m *MockStore) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.AIBridgeInterception, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeInterceptions", ctx, arg, prepared)
ret0, _ := ret[0].([]database.AIBridgeInterception)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAuthorizedAIBridgeInterceptions indicates an expected call of ListAuthorizedAIBridgeInterceptions.
func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeInterceptions(ctx, arg, prepared any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeInterceptions), ctx, arg, prepared)
}
// ListProvisionerKeysByOrganization mocks base method.
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
m.ctrl.T.Helper()
+6
View File
@@ -3112,6 +3112,12 @@ CREATE INDEX idx_agent_stats_user_id ON workspace_agent_stats USING btree (user_
CREATE INDEX idx_aibridge_interceptions_initiator_id ON aibridge_interceptions USING btree (initiator_id);
CREATE INDEX idx_aibridge_interceptions_model ON aibridge_interceptions USING btree (model);
CREATE INDEX idx_aibridge_interceptions_provider ON aibridge_interceptions USING btree (provider);
CREATE INDEX idx_aibridge_interceptions_started_id_desc ON aibridge_interceptions USING btree (started_at DESC, id DESC);
CREATE INDEX idx_aibridge_token_usages_interception_id ON aibridge_token_usages USING btree (interception_id);
CREATE INDEX idx_aibridge_token_usages_provider_response_id ON aibridge_token_usages USING btree (provider_response_id);
@@ -0,0 +1,5 @@
DROP INDEX IF EXISTS idx_aibridge_interceptions_started_id_desc;
DROP INDEX IF EXISTS idx_aibridge_interceptions_provider;
DROP INDEX IF EXISTS idx_aibridge_interceptions_model;
@@ -0,0 +1,9 @@
-- This is used for consistent cursor pagination.
CREATE INDEX IF NOT EXISTS idx_aibridge_interceptions_started_id_desc
ON aibridge_interceptions (started_at DESC, id DESC);
CREATE INDEX IF NOT EXISTS idx_aibridge_interceptions_provider
ON aibridge_interceptions (provider);
CREATE INDEX IF NOT EXISTS idx_aibridge_interceptions_model
ON aibridge_interceptions (model);
+55
View File
@@ -51,6 +51,7 @@ type customQuerier interface {
userQuerier
auditLogQuerier
connectionLogQuerier
aibridgeQuerier
}
type templateQuerier interface {
@@ -761,6 +762,60 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
return count, nil
}
type aibridgeQuerier interface {
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]AIBridgeInterception, error)
}
func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]AIBridgeInterception, error) {
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
VariableConverter: regosql.AIBridgeInterceptionConverter(),
})
if err != nil {
return nil, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(listAIBridgeInterceptions, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return nil, xerrors.Errorf("insert authorized filter: %w", err)
}
query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeInterceptions :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
arg.StartedAfter,
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.Model,
arg.AfterID,
arg.Limit,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []AIBridgeInterception
for rows.Next() {
var i AIBridgeInterception
if err := rows.Scan(
&i.ID,
&i.InitiatorID,
&i.Provider,
&i.Model,
&i.StartedAt,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
if !strings.Contains(query, authorizedQueryPlaceholder) {
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
+7 -3
View File
@@ -508,9 +508,9 @@ type sqlcQuerier interface {
GetWorkspacesByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceTable, error)
GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]GetWorkspacesEligibleForTransitionRow, error)
InsertAIBridgeInterception(ctx context.Context, arg InsertAIBridgeInterceptionParams) (AIBridgeInterception, error)
InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIBridgeTokenUsageParams) error
InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBridgeToolUsageParams) error
InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIBridgeUserPromptParams) error
InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIBridgeTokenUsageParams) (AIBridgeTokenUsage, error)
InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBridgeToolUsageParams) (AIBridgeToolUsage, error)
InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIBridgeUserPromptParams) (AIBridgeUserPrompt, error)
InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error)
// We use the organization_id as the id
// for simplicity since all users is
@@ -585,6 +585,10 @@ type sqlcQuerier interface {
InsertWorkspaceProxy(ctx context.Context, arg InsertWorkspaceProxyParams) (WorkspaceProxy, error)
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]AIBridgeInterception, error)
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
+320 -19
View File
@@ -112,7 +112,12 @@ func (q *sqlQuerier) ActivityBumpWorkspace(ctx context.Context, arg ActivityBump
}
const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one
SELECT id, initiator_id, provider, model, started_at, metadata FROM aibridge_interceptions WHERE id = $1::uuid
SELECT
id, initiator_id, provider, model, started_at, metadata
FROM
aibridge_interceptions
WHERE
id = $1::uuid
`
func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UUID) (AIBridgeInterception, error) {
@@ -130,7 +135,10 @@ func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UU
}
const getAIBridgeInterceptions = `-- name: GetAIBridgeInterceptions :many
SELECT id, initiator_id, provider, model, started_at, metadata FROM aibridge_interceptions
SELECT
id, initiator_id, provider, model, started_at, metadata
FROM
aibridge_interceptions
`
func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeInterception, error) {
@@ -164,7 +172,13 @@ func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeIn
}
const getAIBridgeTokenUsagesByInterceptionID = `-- name: GetAIBridgeTokenUsagesByInterceptionID :many
SELECT id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at FROM aibridge_token_usages WHERE interception_id = $1::uuid
SELECT
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
FROM
aibridge_token_usages WHERE interception_id = $1::uuid
ORDER BY
created_at ASC,
id ASC
`
func (q *sqlQuerier) GetAIBridgeTokenUsagesByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeTokenUsage, error) {
@@ -199,7 +213,15 @@ func (q *sqlQuerier) GetAIBridgeTokenUsagesByInterceptionID(ctx context.Context,
}
const getAIBridgeToolUsagesByInterceptionID = `-- name: GetAIBridgeToolUsagesByInterceptionID :many
SELECT id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at FROM aibridge_tool_usages WHERE interception_id = $1::uuid
SELECT
id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at
FROM
aibridge_tool_usages
WHERE
interception_id = $1::uuid
ORDER BY
created_at ASC,
id ASC
`
func (q *sqlQuerier) GetAIBridgeToolUsagesByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeToolUsage, error) {
@@ -237,7 +259,15 @@ func (q *sqlQuerier) GetAIBridgeToolUsagesByInterceptionID(ctx context.Context,
}
const getAIBridgeUserPromptsByInterceptionID = `-- name: GetAIBridgeUserPromptsByInterceptionID :many
SELECT id, interception_id, provider_response_id, prompt, metadata, created_at FROM aibridge_user_prompts WHERE interception_id = $1::uuid
SELECT
id, interception_id, provider_response_id, prompt, metadata, created_at
FROM
aibridge_user_prompts
WHERE
interception_id = $1::uuid
ORDER BY
created_at ASC,
id ASC
`
func (q *sqlQuerier) GetAIBridgeUserPromptsByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeUserPrompt, error) {
@@ -271,8 +301,11 @@ func (q *sqlQuerier) GetAIBridgeUserPromptsByInterceptionID(ctx context.Context,
}
const insertAIBridgeInterception = `-- name: InsertAIBridgeInterception :one
INSERT INTO aibridge_interceptions (id, initiator_id, provider, model, metadata, started_at)
VALUES ($1::uuid, $2::uuid, $3, $4, COALESCE($5::jsonb, '{}'::jsonb), $6)
INSERT INTO aibridge_interceptions (
id, initiator_id, provider, model, metadata, started_at
) VALUES (
$1, $2, $3, $4, COALESCE($5::jsonb, '{}'::jsonb), $6
)
RETURNING id, initiator_id, provider, model, started_at, metadata
`
@@ -306,12 +339,13 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
return i, err
}
const insertAIBridgeTokenUsage = `-- name: InsertAIBridgeTokenUsage :exec
const insertAIBridgeTokenUsage = `-- name: InsertAIBridgeTokenUsage :one
INSERT INTO aibridge_token_usages (
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
) VALUES (
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7
)
RETURNING id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
`
type InsertAIBridgeTokenUsageParams struct {
@@ -324,8 +358,8 @@ type InsertAIBridgeTokenUsageParams struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
func (q *sqlQuerier) InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIBridgeTokenUsageParams) error {
_, err := q.db.ExecContext(ctx, insertAIBridgeTokenUsage,
func (q *sqlQuerier) InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIBridgeTokenUsageParams) (AIBridgeTokenUsage, error) {
row := q.db.QueryRowContext(ctx, insertAIBridgeTokenUsage,
arg.ID,
arg.InterceptionID,
arg.ProviderResponseID,
@@ -334,15 +368,26 @@ func (q *sqlQuerier) InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIB
arg.Metadata,
arg.CreatedAt,
)
return err
var i AIBridgeTokenUsage
err := row.Scan(
&i.ID,
&i.InterceptionID,
&i.ProviderResponseID,
&i.InputTokens,
&i.OutputTokens,
&i.Metadata,
&i.CreatedAt,
)
return i, err
}
const insertAIBridgeToolUsage = `-- name: InsertAIBridgeToolUsage :exec
const insertAIBridgeToolUsage = `-- name: InsertAIBridgeToolUsage :one
INSERT INTO aibridge_tool_usages (
id, interception_id, provider_response_id, tool, server_url, input, injected, invocation_error, metadata, created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, COALESCE($9::jsonb, '{}'::jsonb), $10
)
RETURNING id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at
`
type InsertAIBridgeToolUsageParams struct {
@@ -358,8 +403,8 @@ type InsertAIBridgeToolUsageParams struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
func (q *sqlQuerier) InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBridgeToolUsageParams) error {
_, err := q.db.ExecContext(ctx, insertAIBridgeToolUsage,
func (q *sqlQuerier) InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBridgeToolUsageParams) (AIBridgeToolUsage, error) {
row := q.db.QueryRowContext(ctx, insertAIBridgeToolUsage,
arg.ID,
arg.InterceptionID,
arg.ProviderResponseID,
@@ -371,15 +416,29 @@ func (q *sqlQuerier) InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBr
arg.Metadata,
arg.CreatedAt,
)
return err
var i AIBridgeToolUsage
err := row.Scan(
&i.ID,
&i.InterceptionID,
&i.ProviderResponseID,
&i.ServerUrl,
&i.Tool,
&i.Input,
&i.Injected,
&i.InvocationError,
&i.Metadata,
&i.CreatedAt,
)
return i, err
}
const insertAIBridgeUserPrompt = `-- name: InsertAIBridgeUserPrompt :exec
const insertAIBridgeUserPrompt = `-- name: InsertAIBridgeUserPrompt :one
INSERT INTO aibridge_user_prompts (
id, interception_id, provider_response_id, prompt, metadata, created_at
) VALUES (
$1, $2, $3, $4, COALESCE($5::jsonb, '{}'::jsonb), $6
)
RETURNING id, interception_id, provider_response_id, prompt, metadata, created_at
`
type InsertAIBridgeUserPromptParams struct {
@@ -391,8 +450,8 @@ type InsertAIBridgeUserPromptParams struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIBridgeUserPromptParams) error {
_, err := q.db.ExecContext(ctx, insertAIBridgeUserPrompt,
func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIBridgeUserPromptParams) (AIBridgeUserPrompt, error) {
row := q.db.QueryRowContext(ctx, insertAIBridgeUserPrompt,
arg.ID,
arg.InterceptionID,
arg.ProviderResponseID,
@@ -400,7 +459,249 @@ func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIB
arg.Metadata,
arg.CreatedAt,
)
return err
var i AIBridgeUserPrompt
err := row.Scan(
&i.ID,
&i.InterceptionID,
&i.ProviderResponseID,
&i.Prompt,
&i.Metadata,
&i.CreatedAt,
)
return i, err
}
const listAIBridgeInterceptions = `-- name: ListAIBridgeInterceptions :many
SELECT
id, initiator_id, provider, model, started_at, metadata
FROM
aibridge_interceptions
WHERE
-- Filter by time frame
CASE
WHEN $1::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $1::timestamptz
ELSE true
END
AND CASE
WHEN $2::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $2::timestamptz
ELSE true
END
-- Filter initiator_id
AND CASE
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $3::uuid
ELSE true
END
-- Filter provider
AND CASE
WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text
ELSE true
END
-- Filter model
AND CASE
WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text
ELSE true
END
-- Cursor pagination
AND CASE
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
-- The pagination cursor is the last ID of the previous page.
-- The query is ordered by the started_at field, so select all
-- rows before the cursor and before the after_id UUID.
-- This uses a less than operator because we're sorting DESC. The
-- "after_id" terminology comes from our pagination parser in
-- coderd.
(aibridge_interceptions.started_at, aibridge_interceptions.id) < (
(SELECT started_at FROM aibridge_interceptions WHERE id = $6),
$6::uuid
)
)
ELSE true
END
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeInterceptions
-- @authorize_filter
ORDER BY
aibridge_interceptions.started_at DESC,
aibridge_interceptions.id DESC
LIMIT COALESCE(NULLIF($7::integer, 0), 100)
`
type ListAIBridgeInterceptionsParams struct {
StartedAfter time.Time `db:"started_after" json:"started_after"`
StartedBefore time.Time `db:"started_before" json:"started_before"`
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
Provider string `db:"provider" json:"provider"`
Model string `db:"model" json:"model"`
AfterID uuid.UUID `db:"after_id" json:"after_id"`
Limit int32 `db:"limit_" json:"limit_"`
}
func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]AIBridgeInterception, error) {
rows, err := q.db.QueryContext(ctx, listAIBridgeInterceptions,
arg.StartedAfter,
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.Model,
arg.AfterID,
arg.Limit,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []AIBridgeInterception
for rows.Next() {
var i AIBridgeInterception
if err := rows.Scan(
&i.ID,
&i.InitiatorID,
&i.Provider,
&i.Model,
&i.StartedAt,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many
SELECT
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
FROM
aibridge_token_usages
WHERE
interception_id = ANY($1::uuid[])
ORDER BY
created_at ASC,
id ASC
`
func (q *sqlQuerier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error) {
rows, err := q.db.QueryContext(ctx, listAIBridgeTokenUsagesByInterceptionIDs, pq.Array(interceptionIds))
if err != nil {
return nil, err
}
defer rows.Close()
var items []AIBridgeTokenUsage
for rows.Next() {
var i AIBridgeTokenUsage
if err := rows.Scan(
&i.ID,
&i.InterceptionID,
&i.ProviderResponseID,
&i.InputTokens,
&i.OutputTokens,
&i.Metadata,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listAIBridgeToolUsagesByInterceptionIDs = `-- name: ListAIBridgeToolUsagesByInterceptionIDs :many
SELECT
id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at
FROM
aibridge_tool_usages
WHERE
interception_id = ANY($1::uuid[])
ORDER BY
created_at ASC,
id ASC
`
func (q *sqlQuerier) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error) {
rows, err := q.db.QueryContext(ctx, listAIBridgeToolUsagesByInterceptionIDs, pq.Array(interceptionIds))
if err != nil {
return nil, err
}
defer rows.Close()
var items []AIBridgeToolUsage
for rows.Next() {
var i AIBridgeToolUsage
if err := rows.Scan(
&i.ID,
&i.InterceptionID,
&i.ProviderResponseID,
&i.ServerUrl,
&i.Tool,
&i.Input,
&i.Injected,
&i.InvocationError,
&i.Metadata,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listAIBridgeUserPromptsByInterceptionIDs = `-- name: ListAIBridgeUserPromptsByInterceptionIDs :many
SELECT
id, interception_id, provider_response_id, prompt, metadata, created_at
FROM
aibridge_user_prompts
WHERE
interception_id = ANY($1::uuid[])
ORDER BY
created_at ASC,
id ASC
`
func (q *sqlQuerier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error) {
rows, err := q.db.QueryContext(ctx, listAIBridgeUserPromptsByInterceptionIDs, pq.Array(interceptionIds))
if err != nil {
return nil, err
}
defer rows.Close()
var items []AIBridgeUserPrompt
for rows.Next() {
var i AIBridgeUserPrompt
if err := rows.Scan(
&i.ID,
&i.InterceptionID,
&i.ProviderResponseID,
&i.Prompt,
&i.Metadata,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteAPIKeyByID = `-- name: DeleteAPIKeyByID :exec
+136 -13
View File
@@ -1,40 +1,163 @@
-- name: InsertAIBridgeInterception :one
INSERT INTO aibridge_interceptions (id, initiator_id, provider, model, metadata, started_at)
VALUES (@id::uuid, @initiator_id::uuid, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at)
INSERT INTO aibridge_interceptions (
id, initiator_id, provider, model, metadata, started_at
) VALUES (
@id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at
)
RETURNING *;
-- name: InsertAIBridgeTokenUsage :exec
-- name: InsertAIBridgeTokenUsage :one
INSERT INTO aibridge_token_usages (
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
) VALUES (
@id, @interception_id, @provider_response_id, @input_tokens, @output_tokens, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
);
)
RETURNING *;
-- name: InsertAIBridgeUserPrompt :exec
-- name: InsertAIBridgeUserPrompt :one
INSERT INTO aibridge_user_prompts (
id, interception_id, provider_response_id, prompt, metadata, created_at
) VALUES (
@id, @interception_id, @provider_response_id, @prompt, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
);
)
RETURNING *;
-- name: InsertAIBridgeToolUsage :exec
-- name: InsertAIBridgeToolUsage :one
INSERT INTO aibridge_tool_usages (
id, interception_id, provider_response_id, tool, server_url, input, injected, invocation_error, metadata, created_at
) VALUES (
@id, @interception_id, @provider_response_id, @tool, @server_url, @input, @injected, @invocation_error, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
);
)
RETURNING *;
-- name: GetAIBridgeInterceptionByID :one
SELECT * FROM aibridge_interceptions WHERE id = @id::uuid;
SELECT
*
FROM
aibridge_interceptions
WHERE
id = @id::uuid;
-- name: GetAIBridgeInterceptions :many
SELECT * FROM aibridge_interceptions;
SELECT
*
FROM
aibridge_interceptions;
-- name: GetAIBridgeTokenUsagesByInterceptionID :many
SELECT * FROM aibridge_token_usages WHERE interception_id = @interception_id::uuid;
SELECT
*
FROM
aibridge_token_usages WHERE interception_id = @interception_id::uuid
ORDER BY
created_at ASC,
id ASC;
-- name: GetAIBridgeUserPromptsByInterceptionID :many
SELECT * FROM aibridge_user_prompts WHERE interception_id = @interception_id::uuid;
SELECT
*
FROM
aibridge_user_prompts
WHERE
interception_id = @interception_id::uuid
ORDER BY
created_at ASC,
id ASC;
-- name: GetAIBridgeToolUsagesByInterceptionID :many
SELECT * FROM aibridge_tool_usages WHERE interception_id = @interception_id::uuid;
SELECT
*
FROM
aibridge_tool_usages
WHERE
interception_id = @interception_id::uuid
ORDER BY
created_at ASC,
id ASC;
-- name: ListAIBridgeInterceptions :many
SELECT
*
FROM
aibridge_interceptions
WHERE
-- Filter by time frame
CASE
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
ELSE true
END
AND CASE
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
ELSE true
END
-- Filter initiator_id
AND CASE
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
ELSE true
END
-- Filter provider
AND CASE
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
ELSE true
END
-- Filter model
AND CASE
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
ELSE true
END
-- Cursor pagination
AND CASE
WHEN @after_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
-- The pagination cursor is the last ID of the previous page.
-- The query is ordered by the started_at field, so select all
-- rows before the cursor and before the after_id UUID.
-- This uses a less than operator because we're sorting DESC. The
-- "after_id" terminology comes from our pagination parser in
-- coderd.
(aibridge_interceptions.started_at, aibridge_interceptions.id) < (
(SELECT started_at FROM aibridge_interceptions WHERE id = @after_id),
@after_id::uuid
)
)
ELSE true
END
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeInterceptions
-- @authorize_filter
ORDER BY
aibridge_interceptions.started_at DESC,
aibridge_interceptions.id DESC
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
;
-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many
SELECT
*
FROM
aibridge_token_usages
WHERE
interception_id = ANY(@interception_ids::uuid[])
ORDER BY
created_at ASC,
id ASC;
-- name: ListAIBridgeUserPromptsByInterceptionIDs :many
SELECT
*
FROM
aibridge_user_prompts
WHERE
interception_id = ANY(@interception_ids::uuid[])
ORDER BY
created_at ASC,
id ASC;
-- name: ListAIBridgeToolUsagesByInterceptionIDs :many
SELECT
*
FROM
aibridge_tool_usages
WHERE
interception_id = ANY(@interception_ids::uuid[])
ORDER BY
created_at ASC,
id ASC;
+16 -1
View File
@@ -54,7 +54,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
// Aduit logs have no user owner, only owner by an organization.
// Audit logs have no user owner, only owner by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
matcher.RegisterMatcher(
@@ -78,6 +78,21 @@ func ConnectionLogConverter() *sqltypes.VariableConverter {
return matcher
}
func AIBridgeInterceptionConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
// AIBridge interceptions are not tied to any organization.
sqltypes.StringVarMatcher("''", []string{"input", "object", "org_owner"}),
sqltypes.StringVarMatcher("initiator_id :: text", []string{"input", "object", "owner"}),
)
matcher.RegisterMatcher(
// No ACLs on the aibridge interception type
sqltypes.AlwaysFalse(groupACLMatcher(matcher)),
sqltypes.AlwaysFalse(userACLMatcher(matcher)),
)
return matcher
}
func UserConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
+2
View File
@@ -327,6 +327,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
// Allow auditors to query deployment stats and insights.
ResourceDeploymentStats.Type: {policy.ActionRead},
ResourceDeploymentConfig.Type: {policy.ActionRead},
// Allow auditors to query aibridge interceptions.
ResourceAibridgeInterception.Type: {policy.ActionRead},
}),
Org: map[string][]Permission{},
User: []Permission{},
+46 -2
View File
@@ -226,7 +226,8 @@ func Workspaces(ctx context.Context, db database.Store, query string, page coder
filter.HasExternalAgent = parser.NullableBoolean(values, sql.NullBool{}, "has_external_agent")
filter.OrganizationID = parseOrganization(ctx, db, parser, values, "organization")
filter.Shared = parser.NullableBoolean(values, sql.NullBool{}, "shared")
filter.SharedWithUserID = parseUser(ctx, db, parser, values, "shared_with_user")
// TODO: support "me" by passing in the actorID
filter.SharedWithUserID = parseUser(ctx, db, parser, values, "shared_with_user", uuid.Nil)
filter.SharedWithGroupID = parseGroup(ctx, db, parser, values, "shared_with_group")
type paramMatch struct {
@@ -304,6 +305,46 @@ func Templates(ctx context.Context, db database.Store, actorID uuid.UUID, query
return filter, parser.Errors
}
func AIBridgeInterceptions(ctx context.Context, db database.Store, query string, page codersdk.Pagination, actorID uuid.UUID) (database.ListAIBridgeInterceptionsParams, []codersdk.ValidationError) {
// nolint:exhaustruct // Empty values just means "don't filter by that field".
filter := database.ListAIBridgeInterceptionsParams{
AfterID: page.AfterID,
// #nosec G115 - Safe conversion for pagination limit which is expected to be within int32 range
Limit: int32(page.Limit),
}
if query == "" {
return filter, nil
}
values, errors := searchTerms(query, func(term string, values url.Values) error {
// Default to the initiating user
values.Add("user", term)
return nil
})
if len(errors) > 0 {
return filter, errors
}
parser := httpapi.NewQueryParamParser()
filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID)
filter.Provider = parser.String(values, "", "provider")
filter.Model = parser.String(values, "", "model")
// Time must be between started_after and started_before.
filter.StartedAfter = parser.Time3339Nano(values, time.Time{}, "started_after")
filter.StartedBefore = parser.Time3339Nano(values, time.Time{}, "started_before")
if !filter.StartedBefore.IsZero() && !filter.StartedAfter.IsZero() && !filter.StartedBefore.After(filter.StartedAfter) {
parser.Errors = append(parser.Errors, codersdk.ValidationError{
Field: "started_before",
Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`,
})
}
parser.ErrorExcessParams(values)
return filter, parser.Errors
}
func searchTerms(query string, defaultKey func(term string, values url.Values) error) (url.Values, []codersdk.ValidationError) {
searchValues := make(url.Values)
@@ -365,11 +406,14 @@ func parseOrganization(ctx context.Context, db database.Store, parser *httpapi.Q
})
}
func parseUser(ctx context.Context, db database.Store, parser *httpapi.QueryParamParser, vals url.Values, queryParam string) uuid.UUID {
func parseUser(ctx context.Context, db database.Store, parser *httpapi.QueryParamParser, vals url.Values, queryParam string, actorID uuid.UUID) uuid.UUID {
return httpapi.ParseCustom(parser, vals, uuid.Nil, queryParam, func(v string) (uuid.UUID, error) {
if v == "" {
return uuid.Nil, nil
}
if v == codersdk.Me && actorID != uuid.Nil {
return actorID, nil
}
userID, err := uuid.Parse(v)
if err == nil {
return userID, nil
+125
View File
@@ -0,0 +1,125 @@
package codersdk
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/google/uuid"
)
type AIBridgeInterception struct {
ID uuid.UUID `json:"id" format:"uuid"`
InitiatorID uuid.UUID `json:"initiator_id" format:"uuid"`
Provider string `json:"provider"`
Model string `json:"model"`
Metadata map[string]any `json:"metadata"`
StartedAt time.Time `json:"started_at" format:"date-time"`
TokenUsages []AIBridgeTokenUsage `json:"token_usages"`
UserPrompts []AIBridgeUserPrompt `json:"user_prompts"`
ToolUsages []AIBridgeToolUsage `json:"tool_usages"`
}
type AIBridgeTokenUsage struct {
ID uuid.UUID `json:"id" format:"uuid"`
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
ProviderResponseID string `json:"provider_response_id"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
Metadata map[string]any `json:"metadata"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
}
type AIBridgeUserPrompt struct {
ID uuid.UUID `json:"id" format:"uuid"`
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
ProviderResponseID string `json:"provider_response_id"`
Prompt string `json:"prompt"`
Metadata map[string]any `json:"metadata"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
}
type AIBridgeToolUsage struct {
ID uuid.UUID `json:"id" format:"uuid"`
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
ProviderResponseID string `json:"provider_response_id"`
ServerURL string `json:"server_url"`
Tool string `json:"tool"`
Input string `json:"input"`
Injected bool `json:"injected"`
InvocationError string `json:"invocation_error"`
Metadata map[string]any `json:"metadata"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
}
type AIBridgeListInterceptionsResponse struct {
Results []AIBridgeInterception `json:"results"`
}
// @typescript-ignore AIBridgeListInterceptionsFilter
type AIBridgeListInterceptionsFilter struct {
// Limit defaults to 100, max is 1000.
// Offset based pagination is not supported for AIBridge interceptions. Use
// cursor pagination instead with after_id.
Pagination Pagination `json:"pagination,omitempty"`
// Initiator is a user ID, username, or "me".
Initiator string `json:"initiator,omitempty"`
StartedBefore time.Time `json:"started_before,omitempty" format:"date-time"`
StartedAfter time.Time `json:"started_after,omitempty" format:"date-time"`
Provider string `json:"provider,omitempty"`
Model string `json:"model,omitempty"`
FilterQuery string `json:"q,omitempty"`
}
// asRequestOption returns a function that can be used in (*Client).Request.
// It modifies the request query parameters.
func (f AIBridgeListInterceptionsFilter) asRequestOption() RequestOption {
return func(r *http.Request) {
var params []string
// Make sure all user input is quoted to ensure it's parsed as a single
// string.
if f.Initiator != "" {
params = append(params, fmt.Sprintf("initiator:%q", f.Initiator))
}
if !f.StartedBefore.IsZero() {
params = append(params, fmt.Sprintf("started_before:%q", f.StartedBefore.Format(time.RFC3339Nano)))
}
if !f.StartedAfter.IsZero() {
params = append(params, fmt.Sprintf("started_after:%q", f.StartedAfter.Format(time.RFC3339Nano)))
}
if f.Provider != "" {
params = append(params, fmt.Sprintf("provider:%q", f.Provider))
}
if f.Model != "" {
params = append(params, fmt.Sprintf("model:%q", f.Model))
}
if f.FilterQuery != "" {
// If custom stuff is added, just add it on here.
params = append(params, f.FilterQuery)
}
q := r.URL.Query()
q.Set("q", strings.Join(params, " "))
r.URL.RawQuery = q.Encode()
}
}
// AIBridgeListInterceptions returns AIBridge interceptions with the given
// filter.
func (c *ExperimentalClient) AIBridgeListInterceptions(ctx context.Context, filter AIBridgeListInterceptionsFilter) (AIBridgeListInterceptionsResponse, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/aibridge/interceptions", nil, filter.asRequestOption(), filter.Pagination.asRequestOption(), filter.Pagination.asRequestOption())
if err != nil {
return AIBridgeListInterceptionsResponse{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return AIBridgeListInterceptionsResponse{}, ReadBodyAsError(res)
}
var resp AIBridgeListInterceptionsResponse
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
+96
View File
@@ -0,0 +1,96 @@
# AIBridge
## List AIBridge interceptions
### Code samples
```shell
# Example request using curl
curl -X GET http://coder-server:8080/api/v2/api/experimental/aibridge/interceptions \
-H 'Accept: application/json' \
-H 'Coder-Session-Token: API_KEY'
```
`GET /api/experimental/aibridge/interceptions`
### Parameters
| Name | In | Type | Required | Description |
|------------|-------|---------|----------|------------------------------------------------------------------------------------------------------------------------|
| `q` | query | string | false | Search query in the format `key:value`. Available keys are: initiator, provider, model, started_after, started_before. |
| `limit` | query | integer | false | Page limit |
| `after_id` | query | string | false | Cursor pagination after ID |
### Example responses
> 200 Response
```json
{
"results": [
{
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"initiator_id": "06588898-9a84-4b35-ba8f-f9cbd64946f3",
"metadata": {
"property1": null,
"property2": null
},
"model": "string",
"provider": "string",
"started_at": "2019-08-24T14:15:22Z",
"token_usages": [
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"input_tokens": 0,
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"metadata": {
"property1": null,
"property2": null
},
"output_tokens": 0,
"provider_response_id": "string"
}
],
"tool_usages": [
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"injected": true,
"input": "string",
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"invocation_error": "string",
"metadata": {
"property1": null,
"property2": null
},
"provider_response_id": "string",
"server_url": "string",
"tool": "string"
}
],
"user_prompts": [
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"metadata": {
"property1": null,
"property2": null
},
"prompt": "string",
"provider_response_id": "string"
}
]
}
]
}
```
### Responses
| Status | Meaning | Description | Schema |
|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------------------------------------------|
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.AIBridgeListInterceptionsResponse](schemas.md#codersdkaibridgelistinterceptionsresponse) |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
+239
View File
@@ -375,6 +375,151 @@
| `enabled` | boolean | false | | |
| `openai` | [codersdk.AIBridgeOpenAIConfig](#codersdkaibridgeopenaiconfig) | false | | |
## codersdk.AIBridgeInterception
```json
{
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"initiator_id": "06588898-9a84-4b35-ba8f-f9cbd64946f3",
"metadata": {
"property1": null,
"property2": null
},
"model": "string",
"provider": "string",
"started_at": "2019-08-24T14:15:22Z",
"token_usages": [
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"input_tokens": 0,
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"metadata": {
"property1": null,
"property2": null
},
"output_tokens": 0,
"provider_response_id": "string"
}
],
"tool_usages": [
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"injected": true,
"input": "string",
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"invocation_error": "string",
"metadata": {
"property1": null,
"property2": null
},
"provider_response_id": "string",
"server_url": "string",
"tool": "string"
}
],
"user_prompts": [
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"metadata": {
"property1": null,
"property2": null
},
"prompt": "string",
"provider_response_id": "string"
}
]
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|--------------------|---------------------------------------------------------------------|----------|--------------|-------------|
| `id` | string | false | | |
| `initiator_id` | string | false | | |
| `metadata` | object | false | | |
| » `[any property]` | any | false | | |
| `model` | string | false | | |
| `provider` | string | false | | |
| `started_at` | string | false | | |
| `token_usages` | array of [codersdk.AIBridgeTokenUsage](#codersdkaibridgetokenusage) | false | | |
| `tool_usages` | array of [codersdk.AIBridgeToolUsage](#codersdkaibridgetoolusage) | false | | |
| `user_prompts` | array of [codersdk.AIBridgeUserPrompt](#codersdkaibridgeuserprompt) | false | | |
## codersdk.AIBridgeListInterceptionsResponse
```json
{
"results": [
{
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"initiator_id": "06588898-9a84-4b35-ba8f-f9cbd64946f3",
"metadata": {
"property1": null,
"property2": null
},
"model": "string",
"provider": "string",
"started_at": "2019-08-24T14:15:22Z",
"token_usages": [
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"input_tokens": 0,
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"metadata": {
"property1": null,
"property2": null
},
"output_tokens": 0,
"provider_response_id": "string"
}
],
"tool_usages": [
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"injected": true,
"input": "string",
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"invocation_error": "string",
"metadata": {
"property1": null,
"property2": null
},
"provider_response_id": "string",
"server_url": "string",
"tool": "string"
}
],
"user_prompts": [
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"metadata": {
"property1": null,
"property2": null
},
"prompt": "string",
"provider_response_id": "string"
}
]
}
]
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|-----------|-------------------------------------------------------------------------|----------|--------------|-------------|
| `results` | array of [codersdk.AIBridgeInterception](#codersdkaibridgeinterception) | false | | |
## codersdk.AIBridgeOpenAIConfig
```json
@@ -391,6 +536,100 @@
| `base_url` | string | false | | |
| `key` | string | false | | |
## codersdk.AIBridgeTokenUsage
```json
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"input_tokens": 0,
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"metadata": {
"property1": null,
"property2": null
},
"output_tokens": 0,
"provider_response_id": "string"
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|------------------------|---------|----------|--------------|-------------|
| `created_at` | string | false | | |
| `id` | string | false | | |
| `input_tokens` | integer | false | | |
| `interception_id` | string | false | | |
| `metadata` | object | false | | |
| » `[any property]` | any | false | | |
| `output_tokens` | integer | false | | |
| `provider_response_id` | string | false | | |
## codersdk.AIBridgeToolUsage
```json
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"injected": true,
"input": "string",
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"invocation_error": "string",
"metadata": {
"property1": null,
"property2": null
},
"provider_response_id": "string",
"server_url": "string",
"tool": "string"
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|------------------------|---------|----------|--------------|-------------|
| `created_at` | string | false | | |
| `id` | string | false | | |
| `injected` | boolean | false | | |
| `input` | string | false | | |
| `interception_id` | string | false | | |
| `invocation_error` | string | false | | |
| `metadata` | object | false | | |
| » `[any property]` | any | false | | |
| `provider_response_id` | string | false | | |
| `server_url` | string | false | | |
| `tool` | string | false | | |
## codersdk.AIBridgeUserPrompt
```json
{
"created_at": "2019-08-24T14:15:22Z",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824",
"metadata": {
"property1": null,
"property2": null
},
"prompt": "string",
"provider_response_id": "string"
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|------------------------|--------|----------|--------------|-------------|
| `created_at` | string | false | | |
| `id` | string | false | | |
| `interception_id` | string | false | | |
| `metadata` | object | false | | |
| » `[any property]` | any | false | | |
| `prompt` | string | false | | |
| `provider_response_id` | string | false | | |
## codersdk.AIConfig
```json
+159
View File
@@ -0,0 +1,159 @@
package coderd
import (
"context"
"fmt"
"net/http"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/searchquery"
"github.com/coder/coder/v2/codersdk"
)
const (
maxListInterceptionsLimit = 1000
defaultListInterceptionsLimit = 100
)
// aiBridgeListInterceptions returns all AIBridge interceptions a user can read.
// Optional filters with query params
//
// @Summary List AIBridge interceptions
// @ID list-aibridge-interceptions
// @Security CoderSessionToken
// @Produce json
// @Tags AIBridge
// @Param q query string false "Search query in the format `key:value`. Available keys are: initiator, provider, model, started_after, started_before."
// @Param limit query int false "Page limit"
// @Param after_id query string false "Cursor pagination after ID"
// @Success 200 {object} codersdk.AIBridgeListInterceptionsResponse
// @Router /api/experimental/aibridge/interceptions [get]
func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
page, ok := coderd.ParsePagination(rw, r)
if !ok {
return
}
if page.Offset != 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Offset pagination is not supported.",
Detail: "Offset pagination is not supported for AIBridge interceptions. Use cursor pagination instead with after_id..",
})
return
}
if page.Limit == 0 {
page.Limit = defaultListInterceptionsLimit
}
if page.Limit > maxListInterceptionsLimit || page.Limit < 1 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid pagination limit value.",
Detail: fmt.Sprintf("Pagination limit must be in range (0, %d]", maxListInterceptionsLimit),
})
return
}
queryStr := r.URL.Query().Get("q")
filter, errs := searchquery.AIBridgeInterceptions(ctx, api.Database, queryStr, page, apiKey.UserID)
if len(errs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid workspace search query.",
Validations: errs,
})
return
}
var rows []database.AIBridgeInterception
err := api.Database.InTx(func(db database.Store) error {
// Ensure the after_id interception exists and is visible to the user.
if page.AfterID != uuid.Nil {
_, err := db.GetAIBridgeInterceptionByID(ctx, page.AfterID)
if err != nil {
return xerrors.Errorf("get aibridge interception by id %s for cursor pagination: %w", page.AfterID, err)
}
}
var err error
// This only returns authorized interceptions (when using dbauthz).
rows, err = db.ListAIBridgeInterceptions(ctx, filter)
if err != nil {
return xerrors.Errorf("list aibridge interceptions: %w", err)
}
return nil
}, nil)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error getting AIBridge interceptions.",
Detail: err.Error(),
})
return
}
// This fetches the other rows associated with the interceptions.
items, err := populatedAndConvertAIBridgeInterceptions(ctx, api.Database, rows)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error converting database rows to API response.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AIBridgeListInterceptionsResponse{
Results: items,
})
}
func populatedAndConvertAIBridgeInterceptions(ctx context.Context, db database.Store, dbInterceptions []database.AIBridgeInterception) ([]codersdk.AIBridgeInterception, error) {
ids := make([]uuid.UUID, len(dbInterceptions))
for i, row := range dbInterceptions {
ids[i] = row.ID
}
//nolint:gocritic // This is a system function until we implement a join for aibridge interceptions. AIBridge interception subresources use the same authorization call as their parent.
tokenUsagesRows, err := db.ListAIBridgeTokenUsagesByInterceptionIDs(dbauthz.AsSystemRestricted(ctx), ids)
if err != nil {
return nil, xerrors.Errorf("get linked aibridge token usages from database: %w", err)
}
tokenUsagesMap := make(map[uuid.UUID][]database.AIBridgeTokenUsage, len(dbInterceptions))
for _, row := range tokenUsagesRows {
tokenUsagesMap[row.InterceptionID] = append(tokenUsagesMap[row.InterceptionID], row)
}
//nolint:gocritic // This is a system function until we implement a join for aibridge interceptions. AIBridge interception subresources use the same authorization call as their parent.
userPromptRows, err := db.ListAIBridgeUserPromptsByInterceptionIDs(dbauthz.AsSystemRestricted(ctx), ids)
if err != nil {
return nil, xerrors.Errorf("get linked aibridge user prompts from database: %w", err)
}
userPromptsMap := make(map[uuid.UUID][]database.AIBridgeUserPrompt, len(dbInterceptions))
for _, row := range userPromptRows {
userPromptsMap[row.InterceptionID] = append(userPromptsMap[row.InterceptionID], row)
}
//nolint:gocritic // This is a system function until we implement a join for aibridge interceptions. AIBridge interception subresources use the same authorization call as their parent.
toolUsagesRows, err := db.ListAIBridgeToolUsagesByInterceptionIDs(dbauthz.AsSystemRestricted(ctx), ids)
if err != nil {
return nil, xerrors.Errorf("get linked aibridge tool usages from database: %w", err)
}
toolUsagesMap := make(map[uuid.UUID][]database.AIBridgeToolUsage, len(dbInterceptions))
for _, row := range toolUsagesRows {
toolUsagesMap[row.InterceptionID] = append(toolUsagesMap[row.InterceptionID], row)
}
items := make([]codersdk.AIBridgeInterception, len(dbInterceptions))
for i, row := range dbInterceptions {
items[i] = db2sdk.AIBridgeInterception(row, tokenUsagesMap[row.ID], userPromptsMap[row.ID], toolUsagesMap[row.ID])
}
return items, nil
}
+478
View File
@@ -0,0 +1,478 @@
package coderd_test
import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/testutil"
)
func TestAIBridgeListInterceptions(t *testing.T) {
t.Parallel()
t.Run("EmptyDB", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentAIBridge)}
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
})
experimentalClient := codersdk.NewExperimentalClient(client)
ctx := testutil.Context(t, testutil.WaitLong)
res, err := experimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{})
require.NoError(t, err)
require.Empty(t, res.Results)
})
t.Run("OK", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentAIBridge)}
client, db, _ := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
})
experimentalClient := codersdk.NewExperimentalClient(client)
ctx := testutil.Context(t, testutil.WaitLong)
// Insert a bunch of test data.
now := dbtime.Now()
i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
StartedAt: now.Add(-time.Hour),
})
i1tok1 := dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
InterceptionID: i1.ID,
CreatedAt: now,
})
i1tok2 := dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
InterceptionID: i1.ID,
CreatedAt: now.Add(-time.Minute),
})
i1up1 := dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
InterceptionID: i1.ID,
CreatedAt: now,
})
i1up2 := dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
InterceptionID: i1.ID,
CreatedAt: now.Add(-time.Minute),
})
i1tool1 := dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{
InterceptionID: i1.ID,
CreatedAt: now,
})
i1tool2 := dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{
InterceptionID: i1.ID,
CreatedAt: now.Add(-time.Minute),
})
i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
StartedAt: now,
})
// Convert to SDK types for response comparison.
// You may notice that the ordering of the inner arrays are ASC, this is
// intentional.
i1SDK := db2sdk.AIBridgeInterception(i1, []database.AIBridgeTokenUsage{i1tok2, i1tok1}, []database.AIBridgeUserPrompt{i1up2, i1up1}, []database.AIBridgeToolUsage{i1tool2, i1tool1})
i2SDK := db2sdk.AIBridgeInterception(i2, nil, nil, nil)
res, err := experimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{})
require.NoError(t, err)
require.Len(t, res.Results, 2)
require.Equal(t, i2SDK.ID, res.Results[0].ID)
require.Equal(t, i1SDK.ID, res.Results[1].ID)
// Normalize timestamps in the response so we can compare the whole
// thing easily.
res.Results[0].StartedAt = i2SDK.StartedAt
res.Results[1].StartedAt = i1SDK.StartedAt
require.Len(t, res.Results[1].TokenUsages, 2)
require.Equal(t, i1SDK.TokenUsages[0].ID, res.Results[1].TokenUsages[0].ID)
require.Equal(t, i1SDK.TokenUsages[1].ID, res.Results[1].TokenUsages[1].ID)
res.Results[1].TokenUsages[0].CreatedAt = i1SDK.TokenUsages[0].CreatedAt
res.Results[1].TokenUsages[1].CreatedAt = i1SDK.TokenUsages[1].CreatedAt
require.Len(t, res.Results[1].UserPrompts, 2)
require.Equal(t, i1SDK.UserPrompts[0].ID, res.Results[1].UserPrompts[0].ID)
require.Equal(t, i1SDK.UserPrompts[1].ID, res.Results[1].UserPrompts[1].ID)
res.Results[1].UserPrompts[0].CreatedAt = i1SDK.UserPrompts[0].CreatedAt
res.Results[1].UserPrompts[1].CreatedAt = i1SDK.UserPrompts[1].CreatedAt
require.Len(t, res.Results[1].ToolUsages, 2)
require.Equal(t, i1SDK.ToolUsages[0].ID, res.Results[1].ToolUsages[0].ID)
require.Equal(t, i1SDK.ToolUsages[1].ID, res.Results[1].ToolUsages[1].ID)
res.Results[1].ToolUsages[0].CreatedAt = i1SDK.ToolUsages[0].CreatedAt
res.Results[1].ToolUsages[1].CreatedAt = i1SDK.ToolUsages[1].CreatedAt
require.Equal(t, []codersdk.AIBridgeInterception{i2SDK, i1SDK}, res.Results)
})
t.Run("Pagination", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentAIBridge)}
client, db, _ := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
})
experimentalClient := codersdk.NewExperimentalClient(client)
ctx := testutil.Context(t, testutil.WaitLong)
allInterceptionIDs := make([]uuid.UUID, 0, 20)
// Create 10 interceptions with the same started_at time. The returned
// order for these should still be deterministic.
now := dbtime.Now()
for i := range 10 {
interception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
ID: uuid.UUID{byte(i)},
StartedAt: now,
})
allInterceptionIDs = append(allInterceptionIDs, interception.ID)
}
// Create 10 interceptions with a random started_at time.
for i := range 10 {
randomOffset, err := cryptorand.Intn(10000)
require.NoError(t, err)
randomOffsetDur := time.Duration(randomOffset) * time.Second
interception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
ID: uuid.UUID{byte(i + 10)},
StartedAt: now.Add(randomOffsetDur),
})
allInterceptionIDs = append(allInterceptionIDs, interception.ID)
}
// Get all interceptions one by one from the API using cursor
// pagination.
getAllInterceptionsOneByOne := func() []uuid.UUID {
interceptionIDs := []uuid.UUID{}
for {
afterID := uuid.Nil
if len(interceptionIDs) > 0 {
afterID = interceptionIDs[len(interceptionIDs)-1]
}
res, err := experimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{
Pagination: codersdk.Pagination{
AfterID: afterID,
Limit: 1,
},
})
require.NoError(t, err)
if len(res.Results) == 0 {
break
}
require.Len(t, res.Results, 1)
interceptionIDs = append(interceptionIDs, res.Results[0].ID)
}
return interceptionIDs
}
// First attempt: get all interceptions one by one.
gotInterceptionIDs1 := getAllInterceptionsOneByOne()
// We should have all of the interceptions returned:
require.ElementsMatch(t, allInterceptionIDs, gotInterceptionIDs1)
// Second attempt: get all interceptions one by one again.
gotInterceptionIDs2 := getAllInterceptionsOneByOne()
// They should be returned in the exact same order.
require.Equal(t, gotInterceptionIDs1, gotInterceptionIDs2)
// Try to get an invalid limit.
res, err := experimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{
Pagination: codersdk.Pagination{
Limit: 1001,
},
})
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Contains(t, sdkErr.Message, "Invalid pagination limit value.")
require.Empty(t, res.Results)
})
t.Run("Authorized", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentAIBridge)}
adminClient, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
})
adminExperimentalClient := codersdk.NewExperimentalClient(adminClient)
ctx := testutil.Context(t, testutil.WaitLong)
secondUserClient, secondUser := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
secondUserExperimentalClient := codersdk.NewExperimentalClient(secondUserClient)
now := dbtime.Now()
i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: firstUser.UserID,
StartedAt: now,
})
i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: secondUser.ID,
StartedAt: now.Add(-time.Hour),
})
// Admin can see all interceptions.
res, err := adminExperimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{})
require.NoError(t, err)
require.Len(t, res.Results, 2)
require.Equal(t, i1.ID, res.Results[0].ID)
require.Equal(t, i2.ID, res.Results[1].ID)
// Second user can only see their own interceptions.
res, err = secondUserExperimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{})
require.NoError(t, err)
require.Len(t, res.Results, 1)
require.Equal(t, i2.ID, res.Results[0].ID)
})
t.Run("Filter", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentAIBridge)}
client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
})
experimentalClient := codersdk.NewExperimentalClient(client)
_, secondUser := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
// Insert a bunch of test data with varying filterable fields.
now := dbtime.Now()
i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
ID: uuid.MustParse("00000000-0000-0000-0000-000000000001"),
InitiatorID: firstUser.UserID,
Provider: "one",
Model: "one",
StartedAt: now,
})
i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
ID: uuid.MustParse("00000000-0000-0000-0000-000000000002"),
InitiatorID: firstUser.UserID,
Provider: "two",
Model: "two",
StartedAt: now.Add(-time.Hour),
})
i3 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
ID: uuid.MustParse("00000000-0000-0000-0000-000000000003"),
InitiatorID: secondUser.ID,
Provider: "three",
Model: "three",
StartedAt: now.Add(-2 * time.Hour),
})
// Convert to SDK types for response comparison. We don't care about the
// inner arrays for this test.
i1SDK := db2sdk.AIBridgeInterception(i1, nil, nil, nil)
i2SDK := db2sdk.AIBridgeInterception(i2, nil, nil, nil)
i3SDK := db2sdk.AIBridgeInterception(i3, nil, nil, nil)
cases := []struct {
name string
filter codersdk.AIBridgeListInterceptionsFilter
want []codersdk.AIBridgeInterception
}{
{
name: "NoFilter",
filter: codersdk.AIBridgeListInterceptionsFilter{},
want: []codersdk.AIBridgeInterception{i1SDK, i2SDK, i3SDK},
},
{
name: "Initiator/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{Initiator: uuid.New().String()},
want: []codersdk.AIBridgeInterception{},
},
{
name: "Initiator/Me",
filter: codersdk.AIBridgeListInterceptionsFilter{Initiator: codersdk.Me},
want: []codersdk.AIBridgeInterception{i1SDK, i2SDK},
},
{
name: "Initiator/UserID",
filter: codersdk.AIBridgeListInterceptionsFilter{Initiator: secondUser.ID.String()},
want: []codersdk.AIBridgeInterception{i3SDK},
},
{
name: "Initiator/Username",
filter: codersdk.AIBridgeListInterceptionsFilter{Initiator: secondUser.Username},
want: []codersdk.AIBridgeInterception{i3SDK},
},
{
name: "Provider/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{Provider: "nonsense"},
want: []codersdk.AIBridgeInterception{},
},
{
name: "Provider/OK",
filter: codersdk.AIBridgeListInterceptionsFilter{Provider: "two"},
want: []codersdk.AIBridgeInterception{i2SDK},
},
{
name: "Model/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{Model: "nonsense"},
want: []codersdk.AIBridgeInterception{},
},
{
name: "Model/OK",
filter: codersdk.AIBridgeListInterceptionsFilter{Model: "three"},
want: []codersdk.AIBridgeInterception{i3SDK},
},
{
name: "StartedAfter/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedAfter: i1.StartedAt.Add(10 * time.Minute),
},
want: []codersdk.AIBridgeInterception{},
},
{
name: "StartedAfter/OK",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedAfter: i2.StartedAt.Add(-10 * time.Minute),
},
want: []codersdk.AIBridgeInterception{i1SDK, i2SDK},
},
{
name: "StartedBefore/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedBefore: i3.StartedAt.Add(-10 * time.Minute),
},
want: []codersdk.AIBridgeInterception{},
},
{
name: "StartedBefore/OK",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedBefore: i3.StartedAt.Add(10 * time.Minute),
},
want: []codersdk.AIBridgeInterception{i3SDK},
},
{
name: "BothBeforeAndAfter/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedAfter: i1.StartedAt.Add(10 * time.Minute),
StartedBefore: i1.StartedAt.Add(20 * time.Minute),
},
want: []codersdk.AIBridgeInterception{},
},
{
name: "BothBeforeAndAfter/OK",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedAfter: i2.StartedAt.Add(-10 * time.Minute),
StartedBefore: i2.StartedAt.Add(10 * time.Minute),
},
want: []codersdk.AIBridgeInterception{i2SDK},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
res, err := experimentalClient.AIBridgeListInterceptions(ctx, tc.filter)
require.NoError(t, err)
// We just compare UUID strings for the sake of this test.
wantIDs := make([]string, len(tc.want))
for i, r := range tc.want {
wantIDs[i] = r.ID.String()
}
gotIDs := make([]string, len(res.Results))
for i, r := range res.Results {
gotIDs[i] = r.ID.String()
}
require.Equal(t, wantIDs, gotIDs)
})
}
})
t.Run("FilterErrors", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentAIBridge)}
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
})
experimentalClient := codersdk.NewExperimentalClient(client)
// No need to insert any test data, we're just testing the filter
// errors.
cases := []struct {
name string
q string
want []codersdk.ValidationError
}{
{
name: "UnknownUsername",
q: "initiator:unknown",
want: []codersdk.ValidationError{
{
Field: "initiator",
Detail: `Query param "initiator" has invalid value: user "unknown" either does not exist, or you are unauthorized to view them`,
},
},
},
{
name: "InvalidStartedAfter",
q: "started_after:invalid",
want: []codersdk.ValidationError{
{
Field: "started_after",
Detail: `Query param "started_after" must be a valid date format (2006-01-02T15:04:05.999999999Z07:00): parsing time "INVALID" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "INVALID" as "2006"`,
},
},
},
{
name: "InvalidStartedBefore",
q: "started_before:invalid",
want: []codersdk.ValidationError{
{
Field: "started_before",
Detail: `Query param "started_before" must be a valid date format (2006-01-02T15:04:05.999999999Z07:00): parsing time "INVALID" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "INVALID" as "2006"`,
},
},
},
{
name: "InvalidBeforeAfterRange",
// Before MUST be after After if both are set
q: `started_after:"2025-01-01T00:00:00Z" started_before:"2024-01-01T00:00:00Z"`,
want: []codersdk.ValidationError{
{
Field: "started_before",
Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`,
},
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
res, err := experimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{
FilterQuery: tc.q,
})
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, tc.want, sdkErr.Validations)
require.Empty(t, res.Results)
})
}
})
}
+2 -17
View File
@@ -6,16 +6,13 @@ import (
"io"
"net/http"
"github.com/go-chi/chi/v5"
"golang.org/x/xerrors"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/enterprise/x/aibridged"
aibridgedproto "github.com/coder/coder/v2/enterprise/x/aibridged/proto"
@@ -25,24 +22,12 @@ import (
// RegisterInMemoryAIBridgedHTTPHandler mounts [aibridged.Server]'s HTTP router onto
// [API]'s router, so that requests to aibridged will be relayed from Coder's API server
// to the in-memory aibridged.
func (api *API) RegisterInMemoryAIBridgedHTTPHandler(srv *aibridged.Server) {
func (api *API) RegisterInMemoryAIBridgedHTTPHandler(srv http.Handler) {
if srv == nil {
panic("aibridged cannot be nil")
}
if api.AGPL.RootHandler == nil {
panic("api.RootHandler cannot be nil")
}
aibridgeEndpoint := "/api/experimental/aibridge"
r := chi.NewRouter()
r.Group(func(r chi.Router) {
r.Use(httpmw.RequireExperiment(api.AGPL.Experiments, codersdk.ExperimentAIBridge))
r.HandleFunc("/*", http.StripPrefix(aibridgeEndpoint, srv).ServeHTTP)
})
api.AGPL.RootHandler.Mount(aibridgeEndpoint, r)
api.aibridgedHandler = srv
}
// CreateInMemoryAIBridgeServer creates a [aibridged.DRPCServer] and returns a
+26
View File
@@ -226,6 +226,30 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
return api.refreshEntitlements(ctx)
}
api.AGPL.ExperimentalHandler.Group(func(r chi.Router) {
r.Route("/aibridge", func(r chi.Router) {
r.Use(
httpmw.RequireExperimentWithDevBypass(api.AGPL.Experiments, codersdk.ExperimentAIBridge),
)
r.Group(func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Get("/interceptions", api.aiBridgeListInterceptions)
})
// This is a bit funky but since aibridge only exposes a HTTP
// handler, this is how it has to be.
r.HandleFunc("/*", func(rw http.ResponseWriter, r *http.Request) {
if api.aibridgedHandler == nil {
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
Message: "aibridged handler not mounted",
})
return
}
api.aibridgedHandler.ServeHTTP(rw, r)
})
})
})
api.AGPL.APIHandler.Group(func(r chi.Router) {
r.Get("/entitlements", api.serveEntitlements)
// /regions overrides the AGPL /regions endpoint
@@ -677,6 +701,8 @@ type API struct {
licenseMetricsCollector *license.MetricsCollector
tailnetService *tailnet.ClientService
aibridgedHandler http.Handler
}
// writeEntitlementWarningsHeader writes the entitlement warnings to the response header
@@ -53,9 +53,9 @@ var _ aibridged.DRPCServer = &Server{}
type store interface {
// Recorder-related queries.
InsertAIBridgeInterception(ctx context.Context, arg database.InsertAIBridgeInterceptionParams) (database.AIBridgeInterception, error)
InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) error
InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) error
InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) error
InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error)
InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) (database.AIBridgeUserPrompt, error)
InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) (database.AIBridgeToolUsage, error)
// MCPConfigurator-related queries.
GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.ExternalAuthLink, error)
@@ -139,7 +139,7 @@ func (s *Server) RecordTokenUsage(ctx context.Context, in *proto.RecordTokenUsag
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
}
err = s.store.InsertAIBridgeTokenUsage(ctx, database.InsertAIBridgeTokenUsageParams{
_, err = s.store.InsertAIBridgeTokenUsage(ctx, database.InsertAIBridgeTokenUsageParams{
ID: uuid.New(),
InterceptionID: intcID,
ProviderResponseID: in.GetMsgId(),
@@ -163,7 +163,7 @@ func (s *Server) RecordPromptUsage(ctx context.Context, in *proto.RecordPromptUs
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
}
err = s.store.InsertAIBridgeUserPrompt(ctx, database.InsertAIBridgeUserPromptParams{
_, err = s.store.InsertAIBridgeUserPrompt(ctx, database.InsertAIBridgeUserPromptParams{
ID: uuid.New(),
InterceptionID: intcID,
ProviderResponseID: in.GetMsgId(),
@@ -186,7 +186,7 @@ func (s *Server) RecordToolUsage(ctx context.Context, in *proto.RecordToolUsageR
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
}
err = s.store.InsertAIBridgeToolUsage(ctx, database.InsertAIBridgeToolUsageParams{
_, err = s.store.InsertAIBridgeToolUsage(ctx, database.InsertAIBridgeToolUsageParams{
ID: uuid.New(),
InterceptionID: intcID,
ProviderResponseID: in.GetMsgId(),
@@ -464,7 +464,18 @@ func TestRecordTokenUsage(t *testing.T) {
return false
}
return true
})).Return(nil)
})).Return(database.AIBridgeTokenUsage{
ID: uuid.New(),
InterceptionID: interceptionID,
ProviderResponseID: req.GetMsgId(),
InputTokens: req.GetInputTokens(),
OutputTokens: req.GetOutputTokens(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(metadataJSON),
Valid: true,
},
CreatedAt: req.GetCreatedAt().AsTime(),
}, nil)
},
},
{
@@ -488,7 +499,7 @@ func TestRecordTokenUsage(t *testing.T) {
CreatedAt: timestamppb.Now(),
},
setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordTokenUsageRequest) {
db.EXPECT().InsertAIBridgeTokenUsage(gomock.Any(), gomock.Any()).Return(sql.ErrConnDone)
db.EXPECT().InsertAIBridgeTokenUsage(gomock.Any(), gomock.Any()).Return(database.AIBridgeTokenUsage{}, sql.ErrConnDone)
},
expectedErr: "insert token usage",
},
@@ -534,7 +545,17 @@ func TestRecordPromptUsage(t *testing.T) {
return false
}
return true
})).Return(nil)
})).Return(database.AIBridgeUserPrompt{
ID: uuid.New(),
InterceptionID: interceptionID,
ProviderResponseID: req.GetMsgId(),
Prompt: req.GetPrompt(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(metadataJSON),
Valid: true,
},
CreatedAt: req.GetCreatedAt().AsTime(),
}, nil)
},
},
{
@@ -556,7 +577,7 @@ func TestRecordPromptUsage(t *testing.T) {
CreatedAt: timestamppb.Now(),
},
setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordPromptUsageRequest) {
db.EXPECT().InsertAIBridgeUserPrompt(gomock.Any(), gomock.Any()).Return(sql.ErrConnDone)
db.EXPECT().InsertAIBridgeUserPrompt(gomock.Any(), gomock.Any()).Return(database.AIBridgeUserPrompt{}, sql.ErrConnDone)
},
expectedErr: "insert user prompt",
},
@@ -622,7 +643,21 @@ func TestRecordToolUsage(t *testing.T) {
return false
}
return true
})).Return(nil)
})).Return(database.AIBridgeToolUsage{
ID: uuid.New(),
InterceptionID: interceptionID,
ProviderResponseID: req.GetMsgId(),
Tool: req.GetTool(),
ServerUrl: dbServerURL,
Input: req.GetInput(),
Injected: req.GetInjected(),
InvocationError: dbInvocationError,
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(metadataJSON),
Valid: true,
},
CreatedAt: req.GetCreatedAt().AsTime(),
}, nil)
},
},
{
@@ -646,7 +681,7 @@ func TestRecordToolUsage(t *testing.T) {
CreatedAt: timestamppb.Now(),
},
setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordToolUsageRequest) {
db.EXPECT().InsertAIBridgeToolUsage(gomock.Any(), gomock.Any()).Return(sql.ErrConnDone)
db.EXPECT().InsertAIBridgeToolUsage(gomock.Any(), gomock.Any()).Return(database.AIBridgeToolUsage{}, sql.ErrConnDone)
},
expectedErr: "insert tool usage",
},
+57
View File
@@ -19,12 +19,69 @@ export interface AIBridgeConfig {
readonly anthropic: AIBridgeAnthropicConfig;
}
// From codersdk/aibridge.go
export interface AIBridgeInterception {
readonly id: string;
readonly initiator_id: string;
readonly provider: string;
readonly model: string;
// empty interface{} type, falling back to unknown
readonly metadata: Record<string, unknown>;
readonly started_at: string;
readonly token_usages: readonly AIBridgeTokenUsage[];
readonly user_prompts: readonly AIBridgeUserPrompt[];
readonly tool_usages: readonly AIBridgeToolUsage[];
}
// From codersdk/aibridge.go
export interface AIBridgeListInterceptionsResponse {
readonly results: readonly AIBridgeInterception[];
}
// From codersdk/deployment.go
export interface AIBridgeOpenAIConfig {
readonly base_url: string;
readonly key: string;
}
// From codersdk/aibridge.go
export interface AIBridgeTokenUsage {
readonly id: string;
readonly interception_id: string;
readonly provider_response_id: string;
readonly input_tokens: number;
readonly output_tokens: number;
// empty interface{} type, falling back to unknown
readonly metadata: Record<string, unknown>;
readonly created_at: string;
}
// From codersdk/aibridge.go
export interface AIBridgeToolUsage {
readonly id: string;
readonly interception_id: string;
readonly provider_response_id: string;
readonly server_url: string;
readonly tool: string;
readonly input: string;
readonly injected: boolean;
readonly invocation_error: string;
// empty interface{} type, falling back to unknown
readonly metadata: Record<string, unknown>;
readonly created_at: string;
}
// From codersdk/aibridge.go
export interface AIBridgeUserPrompt {
readonly id: string;
readonly interception_id: string;
readonly provider_response_id: string;
readonly prompt: string;
// empty interface{} type, falling back to unknown
readonly metadata: Record<string, unknown>;
readonly created_at: string;
}
// From codersdk/deployment.go
export interface AIConfig {
readonly bridge?: AIBridgeConfig;