mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: sanitize workspace agent logs before insert (#24028)
Workspace agent logs could still fail after the earlier invalid UTF-8 fix because NUL bytes are valid Go/protobuf strings but are rejected by Postgres text columns. The legacy HTTP log upload path also bypassed the old sanitization entirely, and both server insert paths computed logs_length from the unsanitized input. Add a shared log-output sanitizer in agentsdk, use it in the protobuf conversion path and both server-side insert paths, and compute OutputLength from the sanitized string so overflow accounting matches what is actually stored. This keeps the old invalid UTF-8 behavior while also handling embedded NUL bytes consistently across DRPC and HTTP log ingestion. Refs [#23292 ](https://github.com/coder/coder/issues/23292) Refs [#13433 ](https://github.com/coder/coder/issues/13433)
This commit is contained in:
committed by
GitHub
parent
7caef4987f
commit
f4240bb8c1
@@ -77,8 +77,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
|||||||
level := make([]database.LogLevel, 0)
|
level := make([]database.LogLevel, 0)
|
||||||
outputLength := 0
|
outputLength := 0
|
||||||
for _, logEntry := range req.Logs {
|
for _, logEntry := range req.Logs {
|
||||||
output = append(output, logEntry.Output)
|
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||||
outputLength += len(logEntry.Output)
|
output = append(output, sanitizedOutput)
|
||||||
|
outputLength += len(sanitizedOutput)
|
||||||
|
|
||||||
var dbLevel database.LogLevel
|
var dbLevel database.LogLevel
|
||||||
switch logEntry.Level {
|
switch logEntry.Level {
|
||||||
|
|||||||
@@ -139,6 +139,59 @@ func TestBatchCreateLogs(t *testing.T) {
|
|||||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("SanitizesOutput", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||||
|
now := dbtime.Now()
|
||||||
|
api := &agentapi.LogsAPI{
|
||||||
|
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||||
|
return agent, nil
|
||||||
|
},
|
||||||
|
Database: dbM,
|
||||||
|
Log: testutil.Logger(t),
|
||||||
|
TimeNowFn: func() time.Time {
|
||||||
|
return now
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rawOutput := "before\x00middle\xc3\x28after"
|
||||||
|
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||||
|
expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small.
|
||||||
|
req := &agentproto.BatchCreateLogsRequest{
|
||||||
|
LogSourceId: logSource.ID[:],
|
||||||
|
Logs: []*agentproto.Log{
|
||||||
|
{
|
||||||
|
CreatedAt: timestamppb.New(now),
|
||||||
|
Level: agentproto.Log_WARN,
|
||||||
|
Output: rawOutput,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{
|
||||||
|
AgentID: agent.ID,
|
||||||
|
LogSourceID: logSource.ID,
|
||||||
|
CreatedAt: now,
|
||||||
|
Output: []string{sanitizedOutput},
|
||||||
|
Level: []database.LogLevel{database.LogLevelWarn},
|
||||||
|
OutputLength: expectedOutputLength,
|
||||||
|
}).Return([]database.WorkspaceAgentLog{
|
||||||
|
{
|
||||||
|
AgentID: agent.ID,
|
||||||
|
CreatedAt: now,
|
||||||
|
ID: 1,
|
||||||
|
Output: sanitizedOutput,
|
||||||
|
Level: database.LogLevelWarn,
|
||||||
|
LogSourceID: logSource.ID,
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
resp, err := api.BatchCreateLogs(context.Background(), req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
|
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -181,8 +181,9 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
|
|||||||
level := make([]database.LogLevel, 0)
|
level := make([]database.LogLevel, 0)
|
||||||
outputLength := 0
|
outputLength := 0
|
||||||
for _, logEntry := range req.Logs {
|
for _, logEntry := range req.Logs {
|
||||||
output = append(output, logEntry.Output)
|
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||||
outputLength += len(logEntry.Output)
|
output = append(output, sanitizedOutput)
|
||||||
|
outputLength += len(sanitizedOutput)
|
||||||
if logEntry.Level == "" {
|
if logEntry.Level == "" {
|
||||||
// Default to "info" to support older agents that didn't have the level field.
|
// Default to "info" to support older agents that didn't have the level field.
|
||||||
logEntry.Level = codersdk.LogLevelInfo
|
logEntry.Level = codersdk.LogLevelInfo
|
||||||
|
|||||||
@@ -260,6 +260,50 @@ func TestWorkspaceAgentLogs(t *testing.T) {
|
|||||||
require.Equal(t, "testing", logChunk[0].Output)
|
require.Equal(t, "testing", logChunk[0].Output)
|
||||||
require.Equal(t, "testing2", logChunk[1].Output)
|
require.Equal(t, "testing2", logChunk[1].Output)
|
||||||
})
|
})
|
||||||
|
t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
client, db := coderdtest.NewWithDatabase(t, nil)
|
||||||
|
user := coderdtest.CreateFirstUser(t, client)
|
||||||
|
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||||
|
OrganizationID: user.OrganizationID,
|
||||||
|
OwnerID: user.UserID,
|
||||||
|
}).WithAgent().Do()
|
||||||
|
|
||||||
|
rawOutput := "before\x00after"
|
||||||
|
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||||
|
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||||
|
err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
|
||||||
|
Logs: []agentsdk.Log{
|
||||||
|
{
|
||||||
|
CreatedAt: dbtime.Now(),
|
||||||
|
Output: rawOutput,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, len(sanitizedOutput), agent.LogsLength)
|
||||||
|
|
||||||
|
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = closer.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var logChunk []codersdk.WorkspaceAgentLog
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case logChunk = <-logs:
|
||||||
|
}
|
||||||
|
require.NoError(t, ctx.Err())
|
||||||
|
require.Len(t, logChunk, 1)
|
||||||
|
require.Equal(t, sanitizedOutput, logChunk[0].Output)
|
||||||
|
})
|
||||||
t.Run("Close logs on outdated build", func(t *testing.T) {
|
t.Run("Close logs on outdated build", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
|||||||
@@ -376,7 +376,7 @@ func ProtoFromLog(log Log) (*proto.Log, error) {
|
|||||||
}
|
}
|
||||||
return &proto.Log{
|
return &proto.Log{
|
||||||
CreatedAt: timestamppb.New(log.CreatedAt),
|
CreatedAt: timestamppb.New(log.CreatedAt),
|
||||||
Output: strings.ToValidUTF8(log.Output, "❌"),
|
Output: SanitizeLogOutput(log.Output),
|
||||||
Level: proto.Log_Level(lvl),
|
Level: proto.Log_Level(lvl),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -229,7 +229,7 @@ func TestLogSender_SkipHugeLog(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, context.Canceled)
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLogSender_InvalidUTF8(t *testing.T) {
|
func TestLogSender_SanitizeOutput(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
testCtx := testutil.Context(t, testutil.WaitShort)
|
testCtx := testutil.Context(t, testutil.WaitShort)
|
||||||
ctx, cancel := context.WithCancel(testCtx)
|
ctx, cancel := context.WithCancel(testCtx)
|
||||||
@@ -243,7 +243,7 @@ func TestLogSender_InvalidUTF8(t *testing.T) {
|
|||||||
uut.Enqueue(ls1,
|
uut.Enqueue(ls1,
|
||||||
Log{
|
Log{
|
||||||
CreatedAt: t0,
|
CreatedAt: t0,
|
||||||
Output: "test log 0, src 1\xc3\x28",
|
Output: "test log 0, src 1\x00\xc3\x28",
|
||||||
Level: codersdk.LogLevelInfo,
|
Level: codersdk.LogLevelInfo,
|
||||||
},
|
},
|
||||||
Log{
|
Log{
|
||||||
@@ -260,10 +260,10 @@ func TestLogSender_InvalidUTF8(t *testing.T) {
|
|||||||
|
|
||||||
req := testutil.TryReceive(ctx, t, fDest.reqs)
|
req := testutil.TryReceive(ctx, t, fDest.reqs)
|
||||||
require.NotNil(t, req)
|
require.NotNil(t, req)
|
||||||
require.Len(t, req.Logs, 2, "it should sanitize invalid UTF-8, but still send")
|
require.Len(t, req.Logs, 2, "it should sanitize invalid output, but still send")
|
||||||
// the 0xc3, 0x28 is an invalid 2-byte sequence in UTF-8. The sanitizer replaces 0xc3 with ❌, and then
|
// The sanitizer replaces the NUL byte and invalid UTF-8 with ❌ while
|
||||||
// interprets 0x28 as a 1-byte sequence "("
|
// preserving the valid "(" byte that follows 0xc3.
|
||||||
require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput())
|
require.Equal(t, "test log 0, src 1❌❌(", req.Logs[0].GetOutput())
|
||||||
require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel())
|
require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel())
|
||||||
require.Equal(t, "test log 1, src 1", req.Logs[1].GetOutput())
|
require.Equal(t, "test log 1, src 1", req.Logs[1].GetOutput())
|
||||||
require.Equal(t, proto.Log_INFO, req.Logs[1].GetLevel())
|
require.Equal(t, proto.Log_INFO, req.Logs[1].GetLevel())
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package agentsdk
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// SanitizeLogOutput replaces invalid UTF-8 and NUL characters in log output.
|
||||||
|
// Invalid UTF-8 cannot be transported in protobuf string fields, and PostgreSQL
|
||||||
|
// rejects NUL bytes in text columns.
|
||||||
|
func SanitizeLogOutput(s string) string {
|
||||||
|
s = strings.ToValidUTF8(s, "❌")
|
||||||
|
return strings.ReplaceAll(s, "\x00", "❌")
|
||||||
|
}
|
||||||
@@ -17,6 +17,54 @@ import (
|
|||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestSanitizeLogOutput(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid",
|
||||||
|
in: "hello world",
|
||||||
|
want: "hello world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid utf8",
|
||||||
|
in: "test log\xc3\x28",
|
||||||
|
want: "test log❌(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nul byte",
|
||||||
|
in: "before\x00after",
|
||||||
|
want: "before❌after",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid utf8 and nul byte",
|
||||||
|
in: "before\x00middle\xc3\x28after",
|
||||||
|
want: "before❌middle❌(after",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nul byte at edges",
|
||||||
|
in: "\x00middle\x00",
|
||||||
|
want: "❌middle❌",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid utf8 at edges",
|
||||||
|
in: "\xc3middle\xc3",
|
||||||
|
want: "❌middle❌",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require.Equal(t, tt.want, agentsdk.SanitizeLogOutput(tt.in))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStartupLogsWriter_Write(t *testing.T) {
|
func TestStartupLogsWriter_Write(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user