fix: sanitize workspace agent logs before insert (#24028)

Workspace agent logs could still fail after the earlier invalid UTF-8
fix because NUL bytes are valid Go/protobuf strings but are rejected by
Postgres text columns. The legacy HTTP log upload path also bypassed the
old sanitization entirely, and both server insert paths computed
logs_length from the unsanitized input.

Add a shared log-output sanitizer in agentsdk, use it in the protobuf
conversion path and both server-side insert paths, and compute
OutputLength from the sanitized string so overflow accounting matches
what is actually stored. This keeps the old invalid UTF-8 behavior while
also handling embedded NUL bytes consistently across DRPC and HTTP log
ingestion.

Refs [#23292 ](https://github.com/coder/coder/issues/23292)
Refs [#13433 ](https://github.com/coder/coder/issues/13433)
This commit is contained in:
dylanhuff-at-coder
2026-04-08 16:29:38 -07:00
committed by GitHub
parent 7caef4987f
commit f4240bb8c1
8 changed files with 169 additions and 11 deletions
+3 -2
View File
@@ -77,8 +77,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
level := make([]database.LogLevel, 0) level := make([]database.LogLevel, 0)
outputLength := 0 outputLength := 0
for _, logEntry := range req.Logs { for _, logEntry := range req.Logs {
output = append(output, logEntry.Output) sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
outputLength += len(logEntry.Output) output = append(output, sanitizedOutput)
outputLength += len(sanitizedOutput)
var dbLevel database.LogLevel var dbLevel database.LogLevel
switch logEntry.Level { switch logEntry.Level {
+53
View File
@@ -139,6 +139,59 @@ func TestBatchCreateLogs(t *testing.T) {
require.True(t, publishWorkspaceAgentLogsUpdateCalled) require.True(t, publishWorkspaceAgentLogsUpdateCalled)
}) })
t.Run("SanitizesOutput", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
now := dbtime.Now()
api := &agentapi.LogsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: testutil.Logger(t),
TimeNowFn: func() time.Time {
return now
},
}
rawOutput := "before\x00middle\xc3\x28after"
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small.
req := &agentproto.BatchCreateLogsRequest{
LogSourceId: logSource.ID[:],
Logs: []*agentproto.Log{
{
CreatedAt: timestamppb.New(now),
Level: agentproto.Log_WARN,
Output: rawOutput,
},
},
}
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{
AgentID: agent.ID,
LogSourceID: logSource.ID,
CreatedAt: now,
Output: []string{sanitizedOutput},
Level: []database.LogLevel{database.LogLevelWarn},
OutputLength: expectedOutputLength,
}).Return([]database.WorkspaceAgentLog{
{
AgentID: agent.ID,
CreatedAt: now,
ID: 1,
Output: sanitizedOutput,
Level: database.LogLevelWarn,
LogSourceID: logSource.ID,
},
}, nil)
resp, err := api.BatchCreateLogs(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
})
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) { t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
t.Parallel() t.Parallel()
+3 -2
View File
@@ -181,8 +181,9 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
level := make([]database.LogLevel, 0) level := make([]database.LogLevel, 0)
outputLength := 0 outputLength := 0
for _, logEntry := range req.Logs { for _, logEntry := range req.Logs {
output = append(output, logEntry.Output) sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
outputLength += len(logEntry.Output) output = append(output, sanitizedOutput)
outputLength += len(sanitizedOutput)
if logEntry.Level == "" { if logEntry.Level == "" {
// Default to "info" to support older agents that didn't have the level field. // Default to "info" to support older agents that didn't have the level field.
logEntry.Level = codersdk.LogLevelInfo logEntry.Level = codersdk.LogLevelInfo
+44
View File
@@ -260,6 +260,50 @@ func TestWorkspaceAgentLogs(t *testing.T) {
require.Equal(t, "testing", logChunk[0].Output) require.Equal(t, "testing", logChunk[0].Output)
require.Equal(t, "testing2", logChunk[1].Output) require.Equal(t, "testing2", logChunk[1].Output)
}) })
t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent().Do()
rawOutput := "before\x00after"
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
Logs: []agentsdk.Log{
{
CreatedAt: dbtime.Now(),
Output: rawOutput,
},
},
})
require.NoError(t, err)
agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID)
require.NoError(t, err)
require.EqualValues(t, len(sanitizedOutput), agent.LogsLength)
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
require.NoError(t, err)
defer func() {
_ = closer.Close()
}()
var logChunk []codersdk.WorkspaceAgentLog
select {
case <-ctx.Done():
case logChunk = <-logs:
}
require.NoError(t, ctx.Err())
require.Len(t, logChunk, 1)
require.Equal(t, sanitizedOutput, logChunk[0].Output)
})
t.Run("Close logs on outdated build", func(t *testing.T) { t.Run("Close logs on outdated build", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium) ctx := testutil.Context(t, testutil.WaitMedium)
+1 -1
View File
@@ -376,7 +376,7 @@ func ProtoFromLog(log Log) (*proto.Log, error) {
} }
return &proto.Log{ return &proto.Log{
CreatedAt: timestamppb.New(log.CreatedAt), CreatedAt: timestamppb.New(log.CreatedAt),
Output: strings.ToValidUTF8(log.Output, "❌"), Output: SanitizeLogOutput(log.Output),
Level: proto.Log_Level(lvl), Level: proto.Log_Level(lvl),
}, nil }, nil
} }
+6 -6
View File
@@ -229,7 +229,7 @@ func TestLogSender_SkipHugeLog(t *testing.T) {
require.ErrorIs(t, err, context.Canceled) require.ErrorIs(t, err, context.Canceled)
} }
func TestLogSender_InvalidUTF8(t *testing.T) { func TestLogSender_SanitizeOutput(t *testing.T) {
t.Parallel() t.Parallel()
testCtx := testutil.Context(t, testutil.WaitShort) testCtx := testutil.Context(t, testutil.WaitShort)
ctx, cancel := context.WithCancel(testCtx) ctx, cancel := context.WithCancel(testCtx)
@@ -243,7 +243,7 @@ func TestLogSender_InvalidUTF8(t *testing.T) {
uut.Enqueue(ls1, uut.Enqueue(ls1,
Log{ Log{
CreatedAt: t0, CreatedAt: t0,
Output: "test log 0, src 1\xc3\x28", Output: "test log 0, src 1\x00\xc3\x28",
Level: codersdk.LogLevelInfo, Level: codersdk.LogLevelInfo,
}, },
Log{ Log{
@@ -260,10 +260,10 @@ func TestLogSender_InvalidUTF8(t *testing.T) {
req := testutil.TryReceive(ctx, t, fDest.reqs) req := testutil.TryReceive(ctx, t, fDest.reqs)
require.NotNil(t, req) require.NotNil(t, req)
require.Len(t, req.Logs, 2, "it should sanitize invalid UTF-8, but still send") require.Len(t, req.Logs, 2, "it should sanitize invalid output, but still send")
// the 0xc3, 0x28 is an invalid 2-byte sequence in UTF-8. The sanitizer replaces 0xc3 with ❌, and then // The sanitizer replaces the NUL byte and invalid UTF-8 with ❌ while
// interprets 0x28 as a 1-byte sequence "(" // preserving the valid "(" byte that follows 0xc3.
require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput()) require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput())
require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel()) require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel())
require.Equal(t, "test log 1, src 1", req.Logs[1].GetOutput()) require.Equal(t, "test log 1, src 1", req.Logs[1].GetOutput())
require.Equal(t, proto.Log_INFO, req.Logs[1].GetLevel()) require.Equal(t, proto.Log_INFO, req.Logs[1].GetLevel())
+11
View File
@@ -0,0 +1,11 @@
package agentsdk
import "strings"
// SanitizeLogOutput replaces invalid UTF-8 and NUL characters in log output.
// Invalid UTF-8 cannot be transported in protobuf string fields, and PostgreSQL
// rejects NUL bytes in text columns.
func SanitizeLogOutput(s string) string {
s = strings.ToValidUTF8(s, "❌")
return strings.ReplaceAll(s, "\x00", "❌")
}
+48
View File
@@ -17,6 +17,54 @@ import (
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
func TestSanitizeLogOutput(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want string
}{
{
name: "valid",
in: "hello world",
want: "hello world",
},
{
name: "invalid utf8",
in: "test log\xc3\x28",
want: "test log❌(",
},
{
name: "nul byte",
in: "before\x00after",
want: "before❌after",
},
{
name: "invalid utf8 and nul byte",
in: "before\x00middle\xc3\x28after",
want: "before❌middle❌(after",
},
{
name: "nul byte at edges",
in: "\x00middle\x00",
want: "❌middle❌",
},
{
name: "invalid utf8 at edges",
in: "\xc3middle\xc3",
want: "❌middle❌",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, agentsdk.SanitizeLogOutput(tt.in))
})
}
}
func TestStartupLogsWriter_Write(t *testing.T) { func TestStartupLogsWriter_Write(t *testing.T) {
t.Parallel() t.Parallel()