mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: sanitize workspace agent logs before insert (#24028)
Workspace agent logs could still fail after the earlier invalid UTF-8 fix because NUL bytes are valid in Go and protobuf strings but are rejected by Postgres text columns. The legacy HTTP log upload path also bypassed the old sanitization entirely, and both server insert paths computed logs_length from the unsanitized input. Add a shared log-output sanitizer in agentsdk, use it in the protobuf conversion path and both server-side insert paths, and compute OutputLength from the sanitized string so overflow accounting matches what is actually stored. This keeps the old invalid UTF-8 behavior while also handling embedded NUL bytes consistently across DRPC and HTTP log ingestion. Refs [#23292](https://github.com/coder/coder/issues/23292) Refs [#13433](https://github.com/coder/coder/issues/13433)
This commit is contained in:
committed by
GitHub
parent
7caef4987f
commit
f4240bb8c1
@@ -77,8 +77,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
|
||||
var dbLevel database.LogLevel
|
||||
switch logEntry.Level {
|
||||
|
||||
@@ -139,6 +139,59 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
|
||||
t.Run("SanitizesOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
now := dbtime.Now()
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
rawOutput := "before\x00middle\xc3\x28after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small.
|
||||
req := &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: logSource.ID[:],
|
||||
Logs: []*agentproto.Log{
|
||||
{
|
||||
CreatedAt: timestamppb.New(now),
|
||||
Level: agentproto.Log_WARN,
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{
|
||||
AgentID: agent.ID,
|
||||
LogSourceID: logSource.ID,
|
||||
CreatedAt: now,
|
||||
Output: []string{sanitizedOutput},
|
||||
Level: []database.LogLevel{database.LogLevelWarn},
|
||||
OutputLength: expectedOutputLength,
|
||||
}).Return([]database.WorkspaceAgentLog{
|
||||
{
|
||||
AgentID: agent.ID,
|
||||
CreatedAt: now,
|
||||
ID: 1,
|
||||
Output: sanitizedOutput,
|
||||
Level: database.LogLevelWarn,
|
||||
LogSourceID: logSource.ID,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||
})
|
||||
|
||||
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -181,8 +181,9 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
if logEntry.Level == "" {
|
||||
// Default to "info" to support older agents that didn't have the level field.
|
||||
logEntry.Level = codersdk.LogLevelInfo
|
||||
|
||||
@@ -260,6 +260,50 @@ func TestWorkspaceAgentLogs(t *testing.T) {
|
||||
require.Equal(t, "testing", logChunk[0].Output)
|
||||
require.Equal(t, "testing2", logChunk[1].Output)
|
||||
})
|
||||
t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
client, db := coderdtest.NewWithDatabase(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
rawOutput := "before\x00after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
|
||||
Logs: []agentsdk.Log{
|
||||
{
|
||||
CreatedAt: dbtime.Now(),
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(sanitizedOutput), agent.LogsLength)
|
||||
|
||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = closer.Close()
|
||||
}()
|
||||
|
||||
var logChunk []codersdk.WorkspaceAgentLog
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case logChunk = <-logs:
|
||||
}
|
||||
require.NoError(t, ctx.Err())
|
||||
require.Len(t, logChunk, 1)
|
||||
require.Equal(t, sanitizedOutput, logChunk[0].Output)
|
||||
})
|
||||
t.Run("Close logs on outdated build", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
@@ -376,7 +376,7 @@ func ProtoFromLog(log Log) (*proto.Log, error) {
|
||||
}
|
||||
return &proto.Log{
|
||||
CreatedAt: timestamppb.New(log.CreatedAt),
|
||||
Output: strings.ToValidUTF8(log.Output, "❌"),
|
||||
Output: SanitizeLogOutput(log.Output),
|
||||
Level: proto.Log_Level(lvl),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -229,7 +229,7 @@ func TestLogSender_SkipHugeLog(t *testing.T) {
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
func TestLogSender_InvalidUTF8(t *testing.T) {
|
||||
func TestLogSender_SanitizeOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCtx := testutil.Context(t, testutil.WaitShort)
|
||||
ctx, cancel := context.WithCancel(testCtx)
|
||||
@@ -243,7 +243,7 @@ func TestLogSender_InvalidUTF8(t *testing.T) {
|
||||
uut.Enqueue(ls1,
|
||||
Log{
|
||||
CreatedAt: t0,
|
||||
Output: "test log 0, src 1\xc3\x28",
|
||||
Output: "test log 0, src 1\x00\xc3\x28",
|
||||
Level: codersdk.LogLevelInfo,
|
||||
},
|
||||
Log{
|
||||
@@ -260,10 +260,10 @@ func TestLogSender_InvalidUTF8(t *testing.T) {
|
||||
|
||||
req := testutil.TryReceive(ctx, t, fDest.reqs)
|
||||
require.NotNil(t, req)
|
||||
require.Len(t, req.Logs, 2, "it should sanitize invalid UTF-8, but still send")
|
||||
// the 0xc3, 0x28 is an invalid 2-byte sequence in UTF-8. The sanitizer replaces 0xc3 with ❌, and then
|
||||
// interprets 0x28 as a 1-byte sequence "("
|
||||
require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput())
|
||||
require.Len(t, req.Logs, 2, "it should sanitize invalid output, but still send")
|
||||
// The sanitizer replaces the NUL byte and invalid UTF-8 with ❌ while
|
||||
// preserving the valid "(" byte that follows 0xc3.
|
||||
require.Equal(t, "test log 0, src 1❌❌(", req.Logs[0].GetOutput())
|
||||
require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel())
|
||||
require.Equal(t, "test log 1, src 1", req.Logs[1].GetOutput())
|
||||
require.Equal(t, proto.Log_INFO, req.Logs[1].GetLevel())
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package agentsdk
|
||||
|
||||
import "strings"
|
||||
|
||||
// SanitizeLogOutput makes log output safe to transport and store by
// substituting "❌" for content that downstream consumers cannot accept.
// Invalid UTF-8 is replaced because protobuf string fields cannot carry it,
// and NUL bytes are replaced because PostgreSQL rejects them in text columns.
//
// The result can differ in byte length from the input, so callers that track
// output size should measure the sanitized string rather than the raw one.
func SanitizeLogOutput(s string) string {
	const replacement = "❌"

	// Repair the encoding first. A NUL byte is valid UTF-8 and passes through
	// this step untouched, so it is handled separately afterwards.
	validUTF8 := strings.ToValidUTF8(s, replacement)
	return strings.ReplaceAll(validUTF8, "\x00", replacement)
}
|
||||
@@ -17,6 +17,54 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestSanitizeLogOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
in: "hello world",
|
||||
want: "hello world",
|
||||
},
|
||||
{
|
||||
name: "invalid utf8",
|
||||
in: "test log\xc3\x28",
|
||||
want: "test log❌(",
|
||||
},
|
||||
{
|
||||
name: "nul byte",
|
||||
in: "before\x00after",
|
||||
want: "before❌after",
|
||||
},
|
||||
{
|
||||
name: "invalid utf8 and nul byte",
|
||||
in: "before\x00middle\xc3\x28after",
|
||||
want: "before❌middle❌(after",
|
||||
},
|
||||
{
|
||||
name: "nul byte at edges",
|
||||
in: "\x00middle\x00",
|
||||
want: "❌middle❌",
|
||||
},
|
||||
{
|
||||
name: "invalid utf8 at edges",
|
||||
in: "\xc3middle\xc3",
|
||||
want: "❌middle❌",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tt.want, agentsdk.SanitizeLogOutput(tt.in))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupLogsWriter_Write(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user