// extractJSON strips an optional markdown code fence (```json or
// ```) that LLMs sometimes wrap around JSON output and returns the
// trimmed inner text.
//
// The input is only rewritten when, after trimming whitespace, it
// starts with ``` and contains a newline: everything up to and
// including that first newline is treated as the opening fence line
// and dropped. The body then ends at the first line that begins with
// ```, so commentary a model appends after the closing fence is
// discarded too. When no such line exists, a ``` suffix at the very
// end of the text is removed instead.
//
// Input without an opening fence, or with a fence that is not
// followed by a newline, is returned whitespace-trimmed but otherwise
// untouched, leaving it to json.Unmarshal in the caller to reject it.
func extractJSON(s string) string {
	s = strings.TrimSpace(s)
	if !strings.HasPrefix(s, "```") {
		return s
	}

	// Without a newline we cannot reliably tell the fence line from
	// the content, so leave the input alone.
	_, body, ok := strings.Cut(s, "\n")
	if !ok {
		return s
	}

	// Valid JSON cannot contain a raw newline inside a string, so the
	// first line starting with ``` can only be the closing fence.
	if inner, _, closed := strings.Cut(body, "\n```"); closed {
		body = inner
	} else {
		// Closing fence on the same line as the end of the JSON, or
		// no closing fence at all.
		body = strings.TrimSuffix(body, "```")
	}
	return strings.TrimSpace(body)
}
-func generateFromAnthropic(ctx context.Context, prompt string, apiKey string, model anthropic.Model) (TaskName, error) { +func generateFromAnthropic(ctx context.Context, prompt string, apiKey string, model anthropic.Model, opts ...anthropicoption.RequestOption) (TaskName, error) { anthropicModel := model if anthropicModel == "" { anthropicModel = defaultModel @@ -216,6 +237,7 @@ func generateFromAnthropic(ctx context.Context, prompt string, apiKey string, mo anthropicOptions := anthropic.DefaultClientOptions() anthropicOptions = append(anthropicOptions, anthropicoption.WithAPIKey(apiKey)) + anthropicOptions = append(anthropicOptions, opts...) anthropicClient := anthropic.NewClient(anthropicOptions...) stream, err := anthropicDataStream(ctx, anthropicClient, anthropicModel, conversation) @@ -234,9 +256,11 @@ func generateFromAnthropic(ctx context.Context, prompt string, apiKey string, mo return TaskName{}, ErrNoNameGenerated } - // Parse the JSON response + // Parse the JSON response. LLMs sometimes wrap JSON in + // markdown code fences (```json ... ```), so we strip + // those before unmarshalling. 
var taskNameResponse TaskName - if err := json.Unmarshal([]byte(acc.Messages()[0].Content), &taskNameResponse); err != nil { + if err := json.Unmarshal([]byte(extractJSON(acc.Messages()[0].Content)), &taskNameResponse); err != nil { return TaskName{}, xerrors.Errorf("failed to parse anthropic response: %w", err) } diff --git a/coderd/taskname/taskname_internal_test.go b/coderd/taskname/taskname_internal_test.go index 4613123250..eff0b30de6 100644 --- a/coderd/taskname/taskname_internal_test.go +++ b/coderd/taskname/taskname_internal_test.go @@ -1,10 +1,15 @@ package taskname import ( + "encoding/json" "fmt" + "net/http" + "net/http/httptest" "strings" "testing" + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/codersdk" @@ -113,6 +118,151 @@ func TestGenerateFromPrompt(t *testing.T) { } } +func TestExtractJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "BareJSON", + input: `{"display_name": "Fix bug", "task_name": "fix-bug"}`, + expected: `{"display_name": "Fix bug", "task_name": "fix-bug"}`, + }, + { + name: "FencedJSON", + input: "```json\n{\"display_name\": \"Fix bug\", \"task_name\": \"fix-bug\"}\n```", + expected: `{"display_name": "Fix bug", "task_name": "fix-bug"}`, + }, + { + name: "FencedNoLanguage", + input: "```\n{\"display_name\": \"Fix bug\", \"task_name\": \"fix-bug\"}\n```", + expected: `{"display_name": "Fix bug", "task_name": "fix-bug"}`, + }, + { + name: "FencedWithSurroundingWhitespace", + input: " \n```json\n{\"display_name\": \"Fix bug\", \"task_name\": \"fix-bug\"}\n```\n ", + expected: `{"display_name": "Fix bug", "task_name": "fix-bug"}`, + }, + { + name: "BareJSONWithWhitespace", + input: " \n{\"display_name\": \"Fix bug\", \"task_name\": \"fix-bug\"}\n ", + expected: `{"display_name": "Fix bug", "task_name": "fix-bug"}`, + }, + { + name: 
"FencedMultilineJSON", + input: "```json\n{\n \"display_name\": \"Fix bug\",\n \"task_name\": \"fix-bug\"\n}\n```", + expected: "{\n \"display_name\": \"Fix bug\",\n \"task_name\": \"fix-bug\"\n}", + }, + { + name: "FencedNoNewlinePassthrough", + input: "```json{\"display_name\": \"Fix bug\", \"task_name\": \"fix-bug\"}```", + expected: "```json{\"display_name\": \"Fix bug\", \"task_name\": \"fix-bug\"}```", + }, + { + name: "NonJSONFencedContent", + input: "```foo: {}, bar: {}```", + expected: "```foo: {}, bar: {}```", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := extractJSON(tc.input) + require.Equal(t, tc.expected, got) + }) + } +} + +// fakeAnthropicSSE builds a minimal Anthropic Messages SSE stream +// whose sole text content is the provided string. +func fakeAnthropicSSE(t *testing.T, text string) string { + t.Helper() + + // Use json.Marshal to produce a correctly escaped JSON + // string value, then strip the surrounding quotes. 
+ escapedBytes, err := json.Marshal(text) + require.NoError(t, err) + escaped := string(escapedBytes[1 : len(escapedBytes)-1]) + + return fmt.Sprintf(`event: message_start +data: {"type":"message_start","message":{"id":"msg_test","type":"message","role":"assistant","model":"claude-haiku-4-5-20241022","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"%s"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":20}} + +event: message_stop +data: {"type":"message_stop"} +`, escaped) +} + +func TestGenerateFromAnthropicMock(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + responseText string + expectedDisplayName string + expectedNamePrefix string + }{ + { + name: "BareJSON", + responseText: `{"display_name": "Fix bug", "task_name": "fix-bug"}`, + expectedDisplayName: "Fix bug", + expectedNamePrefix: "fix-bug-", + }, + { + name: "FencedJSON", + responseText: "```json\n{\"display_name\": \"Debug auth\", \"task_name\": \"debug-auth\"}\n```", + expectedDisplayName: "Debug auth", + expectedNamePrefix: "debug-auth-", + }, + { + name: "FencedNoLanguage", + responseText: "```\n{\"display_name\": \"Setup CI\", \"task_name\": \"setup-ci\"}\n```", + expectedDisplayName: "Setup CI", + expectedNamePrefix: "setup-ci-", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(fakeAnthropicSSE(t, tc.responseText))) + })) + 
t.Cleanup(srv.Close) + + ctx := testutil.Context(t, testutil.WaitShort) + + taskName, err := generateFromAnthropic( + ctx, "test prompt", "fake-key", + anthropic.ModelClaudeHaiku4_5, + anthropicoption.WithBaseURL(srv.URL), + ) + require.NoError(t, err) + require.NoError(t, codersdk.NameValid(taskName.Name)) + require.True(t, strings.HasPrefix(taskName.Name, tc.expectedNamePrefix), + "expected name %q to have prefix %q", taskName.Name, tc.expectedNamePrefix) + require.Equal(t, tc.expectedDisplayName, taskName.DisplayName) + }) + } +} + func TestGenerateFromAnthropic(t *testing.T) { t.Parallel()