diff --git a/coderd/agentapi/logs.go b/coderd/agentapi/logs.go index 9561db7651..34826ef867 100644 --- a/coderd/agentapi/logs.go +++ b/coderd/agentapi/logs.go @@ -77,8 +77,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea level := make([]database.LogLevel, 0) outputLength := 0 for _, logEntry := range req.Logs { - output = append(output, logEntry.Output) - outputLength += len(logEntry.Output) + sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output) + output = append(output, sanitizedOutput) + outputLength += len(sanitizedOutput) var dbLevel database.LogLevel switch logEntry.Level { diff --git a/coderd/agentapi/logs_test.go b/coderd/agentapi/logs_test.go index 9828f0ce47..08ee1bc9a7 100644 --- a/coderd/agentapi/logs_test.go +++ b/coderd/agentapi/logs_test.go @@ -139,6 +139,59 @@ func TestBatchCreateLogs(t *testing.T) { require.True(t, publishWorkspaceAgentLogsUpdateCalled) }) + t.Run("SanitizesOutput", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + now := dbtime.Now() + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: testutil.Logger(t), + TimeNowFn: func() time.Time { + return now + }, + } + + rawOutput := "before\x00middle\xc3\x28after" + sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput) + expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small. + req := &agentproto.BatchCreateLogsRequest{ + LogSourceId: logSource.ID[:], + Logs: []*agentproto.Log{ + { + CreatedAt: timestamppb.New(now), + Level: agentproto.Log_WARN, + Output: rawOutput, + }, + }, + } + + dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{ + AgentID: agent.ID, + LogSourceID: logSource.ID, + CreatedAt: now, + Output: []string{sanitizedOutput}, + Level: []database.LogLevel{database.LogLevelWarn}, + OutputLength: expectedOutputLength, + }).Return([]database.WorkspaceAgentLog{ + { + AgentID: agent.ID, + CreatedAt: now, + ID: 1, + Output: sanitizedOutput, + Level: database.LogLevelWarn, + LogSourceID: logSource.ID, + }, + }, nil) + + resp, err := api.BatchCreateLogs(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) + }) + t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) { t.Parallel() diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 12b8a98251..def90f23d2 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -181,8 +181,9 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request) level := make([]database.LogLevel, 0) outputLength := 0 for _, logEntry := range req.Logs { - output = append(output, logEntry.Output) - outputLength += len(logEntry.Output) + sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output) + output = append(output, sanitizedOutput) + outputLength += len(sanitizedOutput) if logEntry.Level == "" { // Default to "info" to support older agents that didn't have the level field. logEntry.Level = codersdk.LogLevelInfo diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index f47f1a39d2..cb2f9167f4 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -260,6 +260,50 @@ func TestWorkspaceAgentLogs(t *testing.T) { require.Equal(t, "testing", logChunk[0].Output) require.Equal(t, "testing2", logChunk[1].Output) }) + t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + client, db := coderdtest.NewWithDatabase(t, nil) + user := coderdtest.CreateFirstUser(t, client) + r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + + rawOutput := "before\x00after" + sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) + err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{ + Logs: []agentsdk.Log{ + { + CreatedAt: dbtime.Now(), + Output: rawOutput, + }, + }, + }) + require.NoError(t, err) + + agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID) + require.NoError(t, err) + require.EqualValues(t, len(sanitizedOutput), agent.LogsLength) + + workspace, err := client.Workspace(ctx, r.Workspace.ID) + require.NoError(t, err) + logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true) + require.NoError(t, err) + defer func() { + _ = closer.Close() + }() + + var logChunk []codersdk.WorkspaceAgentLog + select { + case <-ctx.Done(): + case logChunk = <-logs: + } + require.NoError(t, ctx.Err()) + require.Len(t, logChunk, 1) + require.Equal(t, sanitizedOutput, logChunk[0].Output) + }) t.Run("Close logs on outdated build", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitMedium) diff --git a/codersdk/agentsdk/convert.go b/codersdk/agentsdk/convert.go index 470e141e3a..ca36a6eba1 100644 --- a/codersdk/agentsdk/convert.go +++ b/codersdk/agentsdk/convert.go @@ -376,7 +376,7 @@ func ProtoFromLog(log Log) (*proto.Log, error) { } return &proto.Log{ CreatedAt: timestamppb.New(log.CreatedAt), - Output: strings.ToValidUTF8(log.Output, "❌"), + Output: SanitizeLogOutput(log.Output), Level: proto.Log_Level(lvl), }, nil } diff --git a/codersdk/agentsdk/logs_internal_test.go b/codersdk/agentsdk/logs_internal_test.go index a8e4210239..e4524ed53b 100644 --- a/codersdk/agentsdk/logs_internal_test.go +++ b/codersdk/agentsdk/logs_internal_test.go @@ -229,7 +229,7 @@ func TestLogSender_SkipHugeLog(t *testing.T) { require.ErrorIs(t, err, context.Canceled) } -func TestLogSender_InvalidUTF8(t *testing.T) { +func TestLogSender_SanitizeOutput(t *testing.T) { t.Parallel() testCtx := testutil.Context(t, testutil.WaitShort) ctx, cancel := context.WithCancel(testCtx) @@ -243,7 +243,7 @@ func TestLogSender_InvalidUTF8(t *testing.T) { uut.Enqueue(ls1, Log{ CreatedAt: t0, - Output: "test log 0, src 1\xc3\x28", + Output: "test log 0, src 1\x00\xc3\x28", Level: codersdk.LogLevelInfo, }, Log{ @@ -260,10 +260,10 @@ func TestLogSender_InvalidUTF8(t *testing.T) { req := testutil.TryReceive(ctx, t, fDest.reqs) require.NotNil(t, req) - require.Len(t, req.Logs, 2, "it should sanitize invalid UTF-8, but still send") - // the 0xc3, 0x28 is an invalid 2-byte sequence in UTF-8. The sanitizer replaces 0xc3 with ❌, and then - // interprets 0x28 as a 1-byte sequence "(" - require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput()) + require.Len(t, req.Logs, 2, "it should sanitize invalid output, but still send") + // The sanitizer replaces the NUL byte and invalid UTF-8 with ❌ while + // preserving the valid "(" byte that follows 0xc3. + require.Equal(t, "test log 0, src 1❌❌(", req.Logs[0].GetOutput()) require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel()) require.Equal(t, "test log 1, src 1", req.Logs[1].GetOutput()) require.Equal(t, proto.Log_INFO, req.Logs[1].GetLevel()) diff --git a/codersdk/agentsdk/logs_sanitize.go b/codersdk/agentsdk/logs_sanitize.go new file mode 100644 index 0000000000..ef5a34df5b --- /dev/null +++ b/codersdk/agentsdk/logs_sanitize.go @@ -0,0 +1,11 @@ +package agentsdk + +import "strings" + +// SanitizeLogOutput replaces invalid UTF-8 and NUL characters in log output. +// Invalid UTF-8 cannot be transported in protobuf string fields, and PostgreSQL +// rejects NUL bytes in text columns. +func SanitizeLogOutput(s string) string { + s = strings.ToValidUTF8(s, "❌") + return strings.ReplaceAll(s, "\x00", "❌") +} diff --git a/codersdk/agentsdk/logs_test.go b/codersdk/agentsdk/logs_test.go index 05e4bc574e..56347466d3 100644 --- a/codersdk/agentsdk/logs_test.go +++ b/codersdk/agentsdk/logs_test.go @@ -17,6 +17,54 @@ import ( "github.com/coder/coder/v2/testutil" ) +func TestSanitizeLogOutput(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in string + want string + }{ + { + name: "valid", + in: "hello world", + want: "hello world", + }, + { + name: "invalid utf8", + in: "test log\xc3\x28", + want: "test log❌(", + }, + { + name: "nul byte", + in: "before\x00after", + want: "before❌after", + }, + { + name: "invalid utf8 and nul byte", + in: "before\x00middle\xc3\x28after", + want: "before❌middle❌(after", + }, + { + name: "nul byte at edges", + in: "\x00middle\x00", + want: "❌middle❌", + }, + { + name: "invalid utf8 at edges", + in: "\xc3middle\xc3", + want: "❌middle❌", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, agentsdk.SanitizeLogOutput(tt.in)) + }) + } +} + func TestStartupLogsWriter_Write(t *testing.T) { t.Parallel()