chore!: route connection logs to new table (#18340)

### Breaking Change (changelog note):
> User connections to workspaces, and the opening of workspace apps or ports will no longer create entries in the audit log. Those events will now be included in the 'Connection Log'.
Please see the 'Connection Log' page in the dashboard, and the Connection Log [documentation](https://coder.com/docs/admin/monitoring/connection-logs) for details. Those with permission to view the Audit Log will also be able to view the Connection Log. The new Connection Log has the same licensing restrictions as the Audit Log, and requires a Premium Coder deployment.

### Context

This is the first PR of a few for moving connection events out of the audit log, and into a new database table and web UI page called the 'Connection Log'.

This PR:
- Creates the new table
- Adds and tests queries for inserting and reading, including reading with an RBAC filter.
- Implements the corresponding RBAC changes, such that anyone who can view the audit log can read from the table
- Implements, under the enterprise package, a `ConnectionLogger` abstraction to replace the `Auditor` abstraction for these logs. (No-op'd in AGPL, like the `Auditor`)
- Routes SSH connection and Workspace App events into the new `ConnectionLogger`
- Updates all existing tests to check the values of the `ConnectionLogger` instead of the `Auditor`.

Future PRs:
- Add filtering to the query
- Add an enterprise endpoint to query the new table
- Write a query to delete old events from the audit log, call it from dbpurge.
- Implement a table in the Web UI for viewing connection logs.


> [!NOTE]
> The PRs in this stack obviously won't be (completely) atomic. Whilst they'll each pass CI, the stack is designed to be merged all at once. I'm splitting them up for the sake of those reviewing, and so changes can be reviewed as early as possible.  Despite this, it's really hard to make this PR any smaller than it already is. I'll be keeping it in draft until it's actually ready to merge.
This commit is contained in:
Ethan
2025-07-15 14:36:06 +10:00
committed by GitHub
parent 43b0bb7f61
commit 08e17a07fc
54 changed files with 2199 additions and 493 deletions
+8 -8
View File
@@ -19,7 +19,7 @@ import (
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi/resourcesmonitor"
"github.com/coder/coder/v2/coderd/appearance"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/externalauth"
@@ -50,7 +50,7 @@ type API struct {
*ResourcesMonitoringAPI
*LogsAPI
*ScriptsAPI
*AuditAPI
*ConnLogAPI
*SubAgentAPI
*tailnet.DRPCService
@@ -71,7 +71,7 @@ type Options struct {
Database database.Store
NotificationsEnqueuer notifications.Enqueuer
Pubsub pubsub.Pubsub
Auditor *atomic.Pointer[audit.Auditor]
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
DerpMapFn func() *tailcfg.DERPMap
TailnetCoordinator *atomic.Pointer[tailnet.Coordinator]
StatsReporter *workspacestats.Reporter
@@ -180,11 +180,11 @@ func New(opts Options) *API {
Database: opts.Database,
}
api.AuditAPI = &AuditAPI{
AgentFn: api.agent,
Auditor: opts.Auditor,
Database: opts.Database,
Log: opts.Log,
api.ConnLogAPI = &ConnLogAPI{
AgentFn: api.agent,
ConnectionLogger: opts.ConnectionLogger,
Database: opts.Database,
Log: opts.Log,
}
api.DRPCService = &tailnet.DRPCService{
-105
View File
@@ -1,105 +0,0 @@
package agentapi
import (
"context"
"encoding/json"
"strconv"
"sync/atomic"
"github.com/google/uuid"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/emptypb"
"cdr.dev/slog"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
type AuditAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Auditor *atomic.Pointer[audit.Auditor]
Database database.Store
Log slog.Logger
}
func (a *AuditAPI) ReportConnection(ctx context.Context, req *agentproto.ReportConnectionRequest) (*emptypb.Empty, error) {
// We will use connection ID as request ID, typically this is the
// SSH session ID as reported by the agent.
connectionID, err := uuid.FromBytes(req.GetConnection().GetId())
if err != nil {
return nil, xerrors.Errorf("connection id from bytes: %w", err)
}
action, err := db2sdk.AuditActionFromAgentProtoConnectionAction(req.GetConnection().GetAction())
if err != nil {
return nil, err
}
connectionType, err := agentsdk.ConnectionTypeFromProto(req.GetConnection().GetType())
if err != nil {
return nil, err
}
// Fetch contextual data for this audit event.
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, xerrors.Errorf("get agent: %w", err)
}
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, xerrors.Errorf("get workspace by agent id: %w", err)
}
build, err := a.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
if err != nil {
return nil, xerrors.Errorf("get latest workspace build by workspace id: %w", err)
}
// We pass the below information to the Auditor so that it
// can form a friendly string for the user to view in the UI.
type additionalFields struct {
audit.AdditionalFields
ConnectionType agentsdk.ConnectionType `json:"connection_type"`
Reason string `json:"reason,omitempty"`
}
resourceInfo := additionalFields{
AdditionalFields: audit.AdditionalFields{
WorkspaceID: workspace.ID,
WorkspaceName: workspace.Name,
WorkspaceOwner: workspace.OwnerUsername,
BuildNumber: strconv.FormatInt(int64(build.BuildNumber), 10),
BuildReason: database.BuildReason(string(build.Reason)),
},
ConnectionType: connectionType,
Reason: req.GetConnection().GetReason(),
}
riBytes, err := json.Marshal(resourceInfo)
if err != nil {
a.Log.Error(ctx, "marshal resource info for agent connection failed", slog.Error(err))
riBytes = []byte("{}")
}
audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceAgent]{
Audit: *a.Auditor.Load(),
Log: a.Log,
Time: req.GetConnection().GetTimestamp().AsTime(),
OrganizationID: workspace.OrganizationID,
RequestID: connectionID,
Action: action,
New: workspaceAgent,
Old: workspaceAgent,
IP: req.GetConnection().GetIp(),
Status: int(req.GetConnection().GetStatusCode()),
AdditionalFields: riBytes,
// It's not possible to tell which user connected. Once we have
// the capability, this may be reported by the agent.
UserID: uuid.Nil,
})
return &emptypb.Empty{}, nil
}
+106
View File
@@ -0,0 +1,106 @@
package agentapi
import (
"context"
"database/sql"
"sync/atomic"
"github.com/google/uuid"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/emptypb"
"cdr.dev/slog"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
)
type ConnLogAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
Database database.Store
Log slog.Logger
}
func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.ReportConnectionRequest) (*emptypb.Empty, error) {
// We use the connection ID to identify which connection log event to mark
// as closed, when we receive a close action for that ID.
connectionID, err := uuid.FromBytes(req.GetConnection().GetId())
if err != nil {
return nil, xerrors.Errorf("connection id from bytes: %w", err)
}
if connectionID == uuid.Nil {
return nil, xerrors.New("connection ID cannot be nil")
}
action, err := db2sdk.ConnectionLogStatusFromAgentProtoConnectionAction(req.GetConnection().GetAction())
if err != nil {
return nil, err
}
connectionType, err := db2sdk.ConnectionLogConnectionTypeFromAgentProtoConnectionType(req.GetConnection().GetType())
if err != nil {
return nil, err
}
var code sql.NullInt32
if action == database.ConnectionStatusDisconnected {
code = sql.NullInt32{
Int32: req.GetConnection().GetStatusCode(),
Valid: true,
}
}
// Fetch contextual data for this connection log event.
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, xerrors.Errorf("get agent: %w", err)
}
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, xerrors.Errorf("get workspace by agent id: %w", err)
}
reason := req.GetConnection().GetReason()
connLogger := *a.ConnectionLogger.Load()
err = connLogger.Upsert(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: req.GetConnection().GetTimestamp().AsTime(),
OrganizationID: workspace.OrganizationID,
WorkspaceOwnerID: workspace.OwnerID,
WorkspaceID: workspace.ID,
WorkspaceName: workspace.Name,
AgentName: workspaceAgent.Name,
Type: connectionType,
Code: code,
Ip: database.ParseIP(req.GetConnection().GetIp()),
ConnectionID: uuid.NullUUID{
UUID: connectionID,
Valid: true,
},
DisconnectReason: sql.NullString{
String: reason,
Valid: reason != "",
},
// We supply the action:
// - So the DB can handle duplicate connections or disconnections properly.
// - To make it clear whether this is a connection or disconnection
// prior to it's insertion into the DB (logs)
ConnectionStatus: action,
// It's not possible to tell which user connected. Once we have
// the capability, this may be reported by the agent.
UserID: uuid.NullUUID{
Valid: false,
},
// N/A
UserAgent: sql.NullString{},
// N/A
SlugOrPort: sql.NullString{},
})
if err != nil {
return nil, xerrors.Errorf("export connection log: %w", err)
}
return &emptypb.Empty{}, nil
}
@@ -2,7 +2,7 @@ package agentapi_test
import (
"context"
"encoding/json"
"database/sql"
"net"
"sync/atomic"
"testing"
@@ -16,15 +16,14 @@ import (
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
func TestAuditReport(t *testing.T) {
func TestConnectionLog(t *testing.T) {
t.Parallel()
var (
@@ -38,10 +37,6 @@ func TestAuditReport(t *testing.T) {
OwnerID: owner.ID,
Name: "cool-workspace",
}
build = database.WorkspaceBuild{
ID: uuid.New(),
WorkspaceID: workspace.ID,
}
agent = database.WorkspaceAgent{
ID: uuid.New(),
}
@@ -62,7 +57,7 @@ func TestAuditReport(t *testing.T) {
id: uuid.New(),
action: agentproto.Connection_CONNECT.Enum(),
typ: agentproto.Connection_SSH.Enum(),
time: time.Now(),
time: dbtime.Now(),
ip: "127.0.0.1",
status: 200,
},
@@ -71,7 +66,7 @@ func TestAuditReport(t *testing.T) {
id: uuid.New(),
action: agentproto.Connection_CONNECT.Enum(),
typ: agentproto.Connection_VSCODE.Enum(),
time: time.Now(),
time: dbtime.Now(),
ip: "8.8.8.8",
},
{
@@ -79,28 +74,28 @@ func TestAuditReport(t *testing.T) {
id: uuid.New(),
action: agentproto.Connection_CONNECT.Enum(),
typ: agentproto.Connection_JETBRAINS.Enum(),
time: time.Now(),
time: dbtime.Now(),
},
{
name: "Reconnecting PTY Connect",
id: uuid.New(),
action: agentproto.Connection_CONNECT.Enum(),
typ: agentproto.Connection_RECONNECTING_PTY.Enum(),
time: time.Now(),
time: dbtime.Now(),
},
{
name: "SSH Disconnect",
id: uuid.New(),
action: agentproto.Connection_DISCONNECT.Enum(),
typ: agentproto.Connection_SSH.Enum(),
time: time.Now(),
time: dbtime.Now(),
},
{
name: "SSH Disconnect",
id: uuid.New(),
action: agentproto.Connection_DISCONNECT.Enum(),
typ: agentproto.Connection_SSH.Enum(),
time: time.Now(),
time: dbtime.Now(),
status: 500,
reason: "because error says so",
},
@@ -110,15 +105,14 @@ func TestAuditReport(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mAudit := audit.NewMock()
connLogger := connectionlog.NewFake()
mDB := dbmock.NewMockStore(gomock.NewController(t))
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspace.ID).Return(build, nil)
api := &agentapi.AuditAPI{
Auditor: asAtomicPointer[audit.Auditor](mAudit),
Database: mDB,
api := &agentapi.ConnLogAPI{
ConnectionLogger: asAtomicPointer[connectionlog.ConnectionLogger](connLogger),
Database: mDB,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
@@ -135,41 +129,48 @@ func TestAuditReport(t *testing.T) {
},
})
require.True(t, mAudit.Contains(t, database.AuditLog{
Time: dbtime.Time(tt.time).In(time.UTC),
Action: agentProtoConnectionActionToAudit(t, *tt.action),
OrganizationID: workspace.OrganizationID,
UserID: uuid.Nil,
RequestID: tt.id,
ResourceType: database.ResourceTypeWorkspaceAgent,
ResourceID: agent.ID,
ResourceTarget: agent.Name,
Ip: pqtype.Inet{Valid: true, IPNet: net.IPNet{IP: net.ParseIP(tt.ip), Mask: net.CIDRMask(32, 32)}},
StatusCode: tt.status,
}))
require.True(t, connLogger.Contains(t, database.UpsertConnectionLogParams{
Time: dbtime.Time(tt.time).In(time.UTC),
OrganizationID: workspace.OrganizationID,
WorkspaceOwnerID: workspace.OwnerID,
WorkspaceID: workspace.ID,
WorkspaceName: workspace.Name,
AgentName: agent.Name,
UserID: uuid.NullUUID{
UUID: uuid.Nil,
Valid: false,
},
ConnectionStatus: agentProtoConnectionActionToConnectionLog(t, *tt.action),
// Check some additional fields.
var m map[string]any
err := json.Unmarshal(mAudit.AuditLogs()[0].AdditionalFields, &m)
require.NoError(t, err)
require.Equal(t, string(agentProtoConnectionTypeToSDK(t, *tt.typ)), m["connection_type"].(string))
if tt.reason != "" {
require.Equal(t, tt.reason, m["reason"])
}
Code: sql.NullInt32{
Int32: tt.status,
Valid: *tt.action == agentproto.Connection_DISCONNECT,
},
Ip: pqtype.Inet{Valid: true, IPNet: net.IPNet{IP: net.ParseIP(tt.ip), Mask: net.CIDRMask(32, 32)}},
Type: agentProtoConnectionTypeToConnectionLog(t, *tt.typ),
DisconnectReason: sql.NullString{
String: tt.reason,
Valid: tt.reason != "",
},
ConnectionID: uuid.NullUUID{
UUID: tt.id,
Valid: tt.id != uuid.Nil,
},
}))
})
}
}
func agentProtoConnectionActionToAudit(t *testing.T, action agentproto.Connection_Action) database.AuditAction {
a, err := db2sdk.AuditActionFromAgentProtoConnectionAction(action)
func agentProtoConnectionTypeToConnectionLog(t *testing.T, typ agentproto.Connection_Type) database.ConnectionType {
a, err := db2sdk.ConnectionLogConnectionTypeFromAgentProtoConnectionType(typ)
require.NoError(t, err)
return a
}
func agentProtoConnectionTypeToSDK(t *testing.T, typ agentproto.Connection_Type) agentsdk.ConnectionType {
action, err := agentsdk.ConnectionTypeFromProto(typ)
func agentProtoConnectionActionToConnectionLog(t *testing.T, action agentproto.Connection_Action) database.ConnectionStatus {
a, err := db2sdk.ConnectionLogStatusFromAgentProtoConnectionAction(action)
require.NoError(t, err)
return action
return a
}
func asAtomicPointer[T any](v T) *atomic.Pointer[T] {
+2
View File
@@ -15364,6 +15364,7 @@ const docTemplate = `{
"assign_org_role",
"assign_role",
"audit_log",
"connection_log",
"crypto_key",
"debug_info",
"deployment_config",
@@ -15403,6 +15404,7 @@ const docTemplate = `{
"ResourceAssignOrgRole",
"ResourceAssignRole",
"ResourceAuditLog",
"ResourceConnectionLog",
"ResourceCryptoKey",
"ResourceDebugInfo",
"ResourceDeploymentConfig",
+2
View File
@@ -13936,6 +13936,7 @@
"assign_org_role",
"assign_role",
"audit_log",
"connection_log",
"crypto_key",
"debug_info",
"deployment_config",
@@ -13975,6 +13976,7 @@
"ResourceAssignOrgRole",
"ResourceAssignRole",
"ResourceAuditLog",
"ResourceConnectionLog",
"ResourceCryptoKey",
"ResourceDebugInfo",
"ResourceDeploymentConfig",
+2 -20
View File
@@ -6,13 +6,11 @@ import (
"encoding/json"
"flag"
"fmt"
"net"
"net/http"
"strconv"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"go.opentelemetry.io/otel/baggage"
"golang.org/x/xerrors"
@@ -434,7 +432,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
action = req.Action
}
ip := ParseIP(p.Request.RemoteAddr)
ip := database.ParseIP(p.Request.RemoteAddr)
auditLog := database.AuditLog{
ID: uuid.New(),
Time: dbtime.Now(),
@@ -466,7 +464,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
// BackgroundAudit creates an audit log for a background event.
// The audit log is committed upon invocation.
func BackgroundAudit[T Auditable](ctx context.Context, p *BackgroundAuditParams[T]) {
ip := ParseIP(p.IP)
ip := database.ParseIP(p.IP)
diff := Diff(p.Audit, p.Old, p.New)
var err error
@@ -581,19 +579,3 @@ func either[T Auditable, R any](old, newVal T, fn func(T) R, auditAction databas
panic("both old and new are nil")
}
}
func ParseIP(ipStr string) pqtype.Inet {
ip := net.ParseIP(ipStr)
ipNet := net.IPNet{}
if ip != nil {
ipNet = net.IPNet{
IP: ip,
Mask: net.CIDRMask(len(ip)*8, len(ip)*8),
}
}
return pqtype.Inet{
IPNet: ipNet,
Valid: ip != nil,
}
}
+9 -1
View File
@@ -59,6 +59,7 @@ import (
"github.com/coder/coder/v2/coderd/appearance"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/awsidentity"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbrollup"
@@ -154,6 +155,7 @@ type Options struct {
CacheDir string
Auditor audit.Auditor
ConnectionLogger connectionlog.ConnectionLogger
AgentConnectionUpdateFrequency time.Duration
AgentInactiveDisconnectTimeout time.Duration
AWSCertificates awsidentity.Certificates
@@ -400,6 +402,9 @@ func New(options *Options) *API {
if options.Auditor == nil {
options.Auditor = audit.NewNop()
}
if options.ConnectionLogger == nil {
options.ConnectionLogger = connectionlog.NewNop()
}
if options.SSHConfig.HostnamePrefix == "" {
options.SSHConfig.HostnamePrefix = "coder."
}
@@ -568,6 +573,7 @@ func New(options *Options) *API {
},
metricsCache: metricsCache,
Auditor: atomic.Pointer[audit.Auditor]{},
ConnectionLogger: atomic.Pointer[connectionlog.ConnectionLogger]{},
TailnetCoordinator: atomic.Pointer[tailnet.Coordinator]{},
UpdatesProvider: updatesProvider,
TemplateScheduleStore: options.TemplateScheduleStore,
@@ -589,7 +595,7 @@ func New(options *Options) *API {
options.Logger.Named("workspaceapps"),
options.AccessURL,
options.Authorizer,
&api.Auditor,
&api.ConnectionLogger,
options.Database,
options.DeploymentValues,
oauthConfigs,
@@ -691,6 +697,7 @@ func New(options *Options) *API {
}
api.Auditor.Store(&options.Auditor)
api.ConnectionLogger.Store(&options.ConnectionLogger)
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
dialer := &InmemTailnetDialer{
CoordPtr: &api.TailnetCoordinator,
@@ -1613,6 +1620,7 @@ type API struct {
// specific replica.
ID uuid.UUID
Auditor atomic.Pointer[audit.Auditor]
ConnectionLogger atomic.Pointer[connectionlog.ConnectionLogger]
WorkspaceClientCoordinateOverride atomic.Pointer[func(rw http.ResponseWriter) bool]
TailnetCoordinator atomic.Pointer[tailnet.Coordinator]
NetworkTelemetryBatcher *tailnet.NetworkTelemetryBatcher
+1
View File
@@ -451,6 +451,7 @@ func randomRBACType() string {
all := []string{
rbac.ResourceWorkspace.Type,
rbac.ResourceAuditLog.Type,
rbac.ResourceConnectionLog.Type,
rbac.ResourceTemplate.Type,
rbac.ResourceGroup.Type,
rbac.ResourceFile.Type,
+9
View File
@@ -61,6 +61,7 @@ import (
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/autobuild"
"github.com/coder/coder/v2/coderd/awsidentity"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
@@ -125,6 +126,7 @@ type Options struct {
TemplateScheduleStore schedule.TemplateScheduleStore
Coordinator tailnet.Coordinator
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
ConnectionLogger connectionlog.ConnectionLogger
HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
HealthcheckTimeout time.Duration
@@ -356,6 +358,12 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
}
auditor.Store(&options.Auditor)
var connectionLogger atomic.Pointer[connectionlog.ConnectionLogger]
if options.ConnectionLogger == nil {
options.ConnectionLogger = connectionlog.NewNop()
}
connectionLogger.Store(&options.ConnectionLogger)
ctx, cancelFunc := context.WithCancel(context.Background())
experiments := coderd.ReadExperiments(*options.Logger, options.DeploymentValues.Experiments)
lifecycleExecutor := autobuild.NewExecutor(
@@ -543,6 +551,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
ExternalAuthConfigs: options.ExternalAuthConfigs,
Auditor: options.Auditor,
ConnectionLogger: options.ConnectionLogger,
AWSCertificates: options.AWSCertificates,
AzureCertificates: options.AzureCertificates,
GithubOAuth2Config: options.GithubOAuth2Config,
+121
View File
@@ -0,0 +1,121 @@
package connectionlog
import (
"context"
"sync"
"testing"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/database"
)
type ConnectionLogger interface {
Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error
}
type nop struct{}
func NewNop() ConnectionLogger {
return nop{}
}
func (nop) Upsert(context.Context, database.UpsertConnectionLogParams) error {
return nil
}
func NewFake() *FakeConnectionLogger {
return &FakeConnectionLogger{}
}
type FakeConnectionLogger struct {
mu sync.Mutex
upsertions []database.UpsertConnectionLogParams
}
func (m *FakeConnectionLogger) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.upsertions = make([]database.UpsertConnectionLogParams, 0)
}
func (m *FakeConnectionLogger) ConnectionLogs() []database.UpsertConnectionLogParams {
m.mu.Lock()
defer m.mu.Unlock()
return m.upsertions
}
func (m *FakeConnectionLogger) Upsert(_ context.Context, clog database.UpsertConnectionLogParams) error {
m.mu.Lock()
defer m.mu.Unlock()
m.upsertions = append(m.upsertions, clog)
return nil
}
func (m *FakeConnectionLogger) Contains(t testing.TB, expected database.UpsertConnectionLogParams) bool {
m.mu.Lock()
defer m.mu.Unlock()
for idx, cl := range m.upsertions {
if expected.ID != uuid.Nil && cl.ID != expected.ID {
t.Logf("connection log %d: expected ID %s, got %s", idx+1, expected.ID, cl.ID)
continue
}
if !expected.Time.IsZero() && expected.Time != cl.Time {
t.Logf("connection log %d: expected Time %s, got %s", idx+1, expected.Time, cl.Time)
continue
}
if expected.OrganizationID != uuid.Nil && cl.OrganizationID != expected.OrganizationID {
t.Logf("connection log %d: expected OrganizationID %s, got %s", idx+1, expected.OrganizationID, cl.OrganizationID)
continue
}
if expected.WorkspaceOwnerID != uuid.Nil && cl.WorkspaceOwnerID != expected.WorkspaceOwnerID {
t.Logf("connection log %d: expected WorkspaceOwnerID %s, got %s", idx+1, expected.WorkspaceOwnerID, cl.WorkspaceOwnerID)
continue
}
if expected.WorkspaceID != uuid.Nil && cl.WorkspaceID != expected.WorkspaceID {
t.Logf("connection log %d: expected WorkspaceID %s, got %s", idx+1, expected.WorkspaceID, cl.WorkspaceID)
continue
}
if expected.WorkspaceName != "" && cl.WorkspaceName != expected.WorkspaceName {
t.Logf("connection log %d: expected WorkspaceName %s, got %s", idx+1, expected.WorkspaceName, cl.WorkspaceName)
continue
}
if expected.AgentName != "" && cl.AgentName != expected.AgentName {
t.Logf("connection log %d: expected AgentName %s, got %s", idx+1, expected.AgentName, cl.AgentName)
continue
}
if expected.Type != "" && cl.Type != expected.Type {
t.Logf("connection log %d: expected Type %s, got %s", idx+1, expected.Type, cl.Type)
continue
}
if expected.Code.Valid && cl.Code.Int32 != expected.Code.Int32 {
t.Logf("connection log %d: expected Code %d, got %d", idx+1, expected.Code.Int32, cl.Code.Int32)
continue
}
if expected.Ip.Valid && cl.Ip.IPNet.String() != expected.Ip.IPNet.String() {
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.Ip.IPNet, cl.Ip.IPNet)
continue
}
if expected.UserAgent.Valid && cl.UserAgent.String != expected.UserAgent.String {
t.Logf("connection log %d: expected UserAgent %s, got %s", idx+1, expected.UserAgent.String, cl.UserAgent.String)
continue
}
if expected.UserID.Valid && cl.UserID.UUID != expected.UserID.UUID {
t.Logf("connection log %d: expected UserID %s, got %s", idx+1, expected.UserID.UUID, cl.UserID.UUID)
continue
}
if expected.SlugOrPort.Valid && cl.SlugOrPort.String != expected.SlugOrPort.String {
t.Logf("connection log %d: expected SlugOrPort %s, got %s", idx+1, expected.SlugOrPort.String, cl.SlugOrPort.String)
continue
}
if expected.ConnectionID.Valid && cl.ConnectionID.UUID != expected.ConnectionID.UUID {
t.Logf("connection log %d: expected ConnectionID %s, got %s", idx+1, expected.ConnectionID.UUID, cl.ConnectionID.UUID)
continue
}
return true
}
return false
}
+19 -14
View File
@@ -781,26 +781,31 @@ func TemplateRoleActions(role codersdk.TemplateRole) []policy.Action {
return []policy.Action{}
}
func AuditActionFromAgentProtoConnectionAction(action agentproto.Connection_Action) (database.AuditAction, error) {
switch action {
case agentproto.Connection_CONNECT:
return database.AuditActionConnect, nil
case agentproto.Connection_DISCONNECT:
return database.AuditActionDisconnect, nil
func ConnectionLogConnectionTypeFromAgentProtoConnectionType(typ agentproto.Connection_Type) (database.ConnectionType, error) {
switch typ {
case agentproto.Connection_SSH:
return database.ConnectionTypeSsh, nil
case agentproto.Connection_JETBRAINS:
return database.ConnectionTypeJetbrains, nil
case agentproto.Connection_VSCODE:
return database.ConnectionTypeVscode, nil
case agentproto.Connection_RECONNECTING_PTY:
return database.ConnectionTypeReconnectingPty, nil
default:
// Also Connection_ACTION_UNSPECIFIED, no mapping.
return "", xerrors.Errorf("unknown agent connection action %q", action)
// Also Connection_TYPE_UNSPECIFIED, no mapping.
return "", xerrors.Errorf("unknown agent connection type %q", typ)
}
}
func AgentProtoConnectionActionToAuditAction(action database.AuditAction) (agentproto.Connection_Action, error) {
func ConnectionLogStatusFromAgentProtoConnectionAction(action agentproto.Connection_Action) (database.ConnectionStatus, error) {
switch action {
case database.AuditActionConnect:
return agentproto.Connection_CONNECT, nil
case database.AuditActionDisconnect:
return agentproto.Connection_DISCONNECT, nil
case agentproto.Connection_CONNECT:
return database.ConnectionStatusConnected, nil
case agentproto.Connection_DISCONNECT:
return database.ConnectionStatusDisconnected, nil
default:
return agentproto.Connection_ACTION_UNSPECIFIED, xerrors.Errorf("unknown agent connection action %q", action)
// Also Connection_ACTION_UNSPECIFIED, no mapping.
return "", xerrors.Errorf("unknown agent connection action %q", action)
}
}
+48
View File
@@ -306,6 +306,24 @@ var (
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectConnectionLogger = rbac.Subject{
Type: rbac.SubjectTypeConnectionLogger,
FriendlyName: "Connection Logger",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Identifier: rbac.RoleIdentifier{Name: "connectionlogger"},
DisplayName: "Connection Logger",
Site: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceConnectionLog.Type: {policy.ActionUpdate, policy.ActionRead},
}),
Org: map[string][]rbac.Permission{},
User: []rbac.Permission{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectNotifier = rbac.Subject{
Type: rbac.SubjectTypeNotifier,
FriendlyName: "Notifier",
@@ -521,6 +539,10 @@ func AsKeyReader(ctx context.Context) context.Context {
return As(ctx, subjectCryptoKeyReader)
}
func AsConnectionLogger(ctx context.Context) context.Context {
return As(ctx, subjectConnectionLogger)
}
// AsNotifier returns a context with an actor that has permissions required for
// creating/reading/updating/deleting notifications.
func AsNotifier(ctx context.Context) context.Context {
@@ -1856,6 +1878,21 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
return q.db.GetAuthorizationUserRoles(ctx, userID)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
// Just like with the audit logs query, shortcut if the user is an owner.
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
if err == nil {
return q.db.GetConnectionLogsOffset(ctx, arg)
}
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceConnectionLog.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.GetAuthorizedConnectionLogsOffset(ctx, arg, prep)
}
func (q *querier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return "", err
@@ -5099,6 +5136,13 @@ func (q *querier) UpsertApplicationName(ctx context.Context, value string) error
return q.db.UpsertApplicationName(ctx, value)
}
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
}
return q.db.UpsertConnectionLog(ctx, arg)
}
func (q *querier) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
@@ -5344,3 +5388,7 @@ func (q *querier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg database
func (q *querier) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, _ rbac.PreparedAuthorized) (int64, error) {
return q.CountAuditLogs(ctx, arg)
}
func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
return q.GetConnectionLogsOffset(ctx, arg)
}
+69
View File
@@ -339,6 +339,75 @@ func (s *MethodTestSuite) TestAuditLogs() {
}))
}
func (s *MethodTestSuite) TestConnectionLogs() {
createWorkspace := func(t *testing.T, db database.Store) database.WorkspaceTable {
u := dbgen.User(s.T(), db, database.User{})
o := dbgen.Organization(s.T(), db, database.Organization{})
tpl := dbgen.Template(s.T(), db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
return dbgen.Workspace(s.T(), db, database.WorkspaceTable{
ID: uuid.New(),
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
}
s.Run("UpsertConnectionLog", s.Subtest(func(db database.Store, check *expects) {
ws := createWorkspace(s.T(), db)
check.Args(database.UpsertConnectionLogParams{
Ip: defaultIPAddress(),
Type: database.ConnectionTypeSsh,
WorkspaceID: ws.ID,
OrganizationID: ws.OrganizationID,
ConnectionStatus: database.ConnectionStatusConnected,
WorkspaceOwnerID: ws.OwnerID,
}).Asserts(rbac.ResourceConnectionLog, policy.ActionUpdate)
}))
s.Run("GetConnectionLogsOffset", s.Subtest(func(db database.Store, check *expects) {
ws := createWorkspace(s.T(), db)
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
Ip: defaultIPAddress(),
Type: database.ConnectionTypeSsh,
WorkspaceID: ws.ID,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
})
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
Ip: defaultIPAddress(),
Type: database.ConnectionTypeSsh,
WorkspaceID: ws.ID,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
})
check.Args(database.GetConnectionLogsOffsetParams{
LimitOpt: 10,
}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead).WithNotAuthorized("nil")
}))
s.Run("GetAuthorizedConnectionLogsOffset", s.Subtest(func(db database.Store, check *expects) {
ws := createWorkspace(s.T(), db)
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
Ip: defaultIPAddress(),
Type: database.ConnectionTypeSsh,
WorkspaceID: ws.ID,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
})
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
Ip: defaultIPAddress(),
Type: database.ConnectionTypeSsh,
WorkspaceID: ws.ID,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
})
check.Args(database.GetConnectionLogsOffsetParams{
LimitOpt: 10,
}, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead)
}))
}
func (s *MethodTestSuite) TestFile() {
s.Run("GetFileByHashAndCreator", s.Subtest(func(db database.Store, check *expects) {
f := dbgen.File(s.T(), db, database.File{})
+47
View File
@@ -73,6 +73,53 @@ func AuditLog(t testing.TB, db database.Store, seed database.AuditLog) database.
return log
}
func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnectionLogParams) database.ConnectionLog {
log, err := db.UpsertConnectionLog(genCtx, database.UpsertConnectionLogParams{
ID: takeFirst(seed.ID, uuid.New()),
Time: takeFirst(seed.Time, dbtime.Now()),
OrganizationID: takeFirst(seed.OrganizationID, uuid.New()),
WorkspaceOwnerID: takeFirst(seed.WorkspaceOwnerID, uuid.New()),
WorkspaceID: takeFirst(seed.WorkspaceID, uuid.New()),
WorkspaceName: takeFirst(seed.WorkspaceName, testutil.GetRandomName(t)),
AgentName: takeFirst(seed.AgentName, testutil.GetRandomName(t)),
Type: takeFirst(seed.Type, database.ConnectionTypeSsh),
Code: sql.NullInt32{
Int32: takeFirst(seed.Code.Int32, 0),
Valid: takeFirst(seed.Code.Valid, false),
},
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
UserAgent: sql.NullString{
String: takeFirst(seed.UserAgent.String, ""),
Valid: takeFirst(seed.UserAgent.Valid, false),
},
UserID: uuid.NullUUID{
UUID: takeFirst(seed.UserID.UUID, uuid.Nil),
Valid: takeFirst(seed.UserID.Valid, false),
},
SlugOrPort: sql.NullString{
String: takeFirst(seed.SlugOrPort.String, ""),
Valid: takeFirst(seed.SlugOrPort.Valid, false),
},
ConnectionID: uuid.NullUUID{
UUID: takeFirst(seed.ConnectionID.UUID, uuid.Nil),
Valid: takeFirst(seed.ConnectionID.Valid, false),
},
DisconnectReason: sql.NullString{
String: takeFirst(seed.DisconnectReason.String, ""),
Valid: takeFirst(seed.DisconnectReason.Valid, false),
},
ConnectionStatus: takeFirst(seed.ConnectionStatus, database.ConnectionStatusConnected),
})
require.NoError(t, err, "insert connection log")
return log
}
func Template(t testing.TB, db database.Store, seed database.Template) database.Template {
id := takeFirst(seed.ID, uuid.New())
if seed.GroupACL == nil {
+21
View File
@@ -656,6 +656,13 @@ func (m queryMetricsStore) GetAuthorizationUserRoles(ctx context.Context, userID
return row, err
}
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
start := time.Now()
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
m.queryLatencies.WithLabelValues("GetConnectionLogsOffset").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetCoordinatorResumeTokenSigningKey(ctx)
@@ -3162,6 +3169,13 @@ func (m queryMetricsStore) UpsertApplicationName(ctx context.Context, value stri
return r0
}
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
start := time.Now()
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertConnectionLog").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertCoordinatorResumeTokenSigningKey(ctx, value)
@@ -3392,3 +3406,10 @@ func (m queryMetricsStore) CountAuthorizedAuditLogs(ctx context.Context, arg dat
m.queryLatencies.WithLabelValues("CountAuthorizedAuditLogs").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
start := time.Now()
r0, r1 := m.s.GetAuthorizedConnectionLogsOffset(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedConnectionLogsOffset").Observe(time.Since(start).Seconds())
return r0, r1
}
+45
View File
@@ -1248,6 +1248,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedAuditLogsOffset(ctx, arg, prepared
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedAuditLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedAuditLogsOffset), ctx, arg, prepared)
}
// GetAuthorizedConnectionLogsOffset mocks base method.
func (m *MockStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthorizedConnectionLogsOffset", ctx, arg, prepared)
ret0, _ := ret[0].([]database.GetConnectionLogsOffsetRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAuthorizedConnectionLogsOffset indicates an expected call of GetAuthorizedConnectionLogsOffset.
func (mr *MockStoreMockRecorder) GetAuthorizedConnectionLogsOffset(ctx, arg, prepared any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedConnectionLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedConnectionLogsOffset), ctx, arg, prepared)
}
// GetAuthorizedTemplates mocks base method.
func (m *MockStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
m.ctrl.T.Helper()
@@ -1323,6 +1338,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspacesAndAgentsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspacesAndAgentsByOwnerID), ctx, ownerID, prepared)
}
// GetConnectionLogsOffset mocks base method.
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetConnectionLogsOffset", ctx, arg)
ret0, _ := ret[0].([]database.GetConnectionLogsOffsetRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetConnectionLogsOffset indicates an expected call of GetConnectionLogsOffset.
func (mr *MockStoreMockRecorder) GetConnectionLogsOffset(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConnectionLogsOffset", reflect.TypeOf((*MockStore)(nil).GetConnectionLogsOffset), ctx, arg)
}
// GetCoordinatorResumeTokenSigningKey mocks base method.
func (m *MockStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
@@ -6698,6 +6728,21 @@ func (mr *MockStoreMockRecorder) UpsertApplicationName(ctx, value any) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertApplicationName", reflect.TypeOf((*MockStore)(nil).UpsertApplicationName), ctx, value)
}
// UpsertConnectionLog mocks base method.
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertConnectionLog", ctx, arg)
ret0, _ := ret[0].(database.ConnectionLog)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertConnectionLog indicates an expected call of UpsertConnectionLog.
func (mr *MockStoreMockRecorder) UpsertConnectionLog(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertConnectionLog", reflect.TypeOf((*MockStore)(nil).UpsertConnectionLog), ctx, arg)
}
// UpsertCoordinatorResumeTokenSigningKey mocks base method.
func (m *MockStore) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
m.ctrl.T.Helper()
+73
View File
@@ -38,6 +38,8 @@ CREATE TYPE audit_action AS ENUM (
'close'
);
COMMENT ON TYPE audit_action IS 'NOTE: `connect`, `disconnect`, `open`, and `close` are deprecated and no longer used - these events are now tracked in the connection_logs table.';
CREATE TYPE automatic_updates AS ENUM (
'always',
'never'
@@ -52,6 +54,20 @@ CREATE TYPE build_reason AS ENUM (
'autodelete'
);
CREATE TYPE connection_status AS ENUM (
'connected',
'disconnected'
);
CREATE TYPE connection_type AS ENUM (
'ssh',
'vscode',
'jetbrains',
'reconnecting_pty',
'workspace_app',
'port_forwarding'
);
CREATE TYPE crypto_key_feature AS ENUM (
'workspace_apps_token',
'workspace_apps_api_key',
@@ -823,6 +839,39 @@ CREATE TABLE audit_logs (
resource_icon text NOT NULL
);
CREATE TABLE connection_logs (
id uuid NOT NULL,
connect_time timestamp with time zone NOT NULL,
organization_id uuid NOT NULL,
workspace_owner_id uuid NOT NULL,
workspace_id uuid NOT NULL,
workspace_name text NOT NULL,
agent_name text NOT NULL,
type connection_type NOT NULL,
ip inet NOT NULL,
code integer,
user_agent text,
user_id uuid,
slug_or_port text,
connection_id uuid,
disconnect_time timestamp with time zone,
disconnect_reason text
);
COMMENT ON COLUMN connection_logs.code IS 'Either the HTTP status code of the web request, or the exit code of an SSH connection. For non-web connections, this is Null until we receive a disconnect event for the same connection_id.';
COMMENT ON COLUMN connection_logs.user_agent IS 'Null for SSH events. For web connections, this is the User-Agent header from the request.';
COMMENT ON COLUMN connection_logs.user_id IS 'Null for SSH events. For web connections, this is the ID of the user that made the request.';
COMMENT ON COLUMN connection_logs.slug_or_port IS 'Null for SSH events. For web connections, this is the slug of the app or the port number being forwarded.';
COMMENT ON COLUMN connection_logs.connection_id IS 'The SSH connection ID. Used to correlate connections and disconnections. As it originates from the agent, it is not guaranteed to be unique.';
COMMENT ON COLUMN connection_logs.disconnect_time IS 'The time the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.';
COMMENT ON COLUMN connection_logs.disconnect_reason IS 'The reason the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.';
CREATE TABLE crypto_keys (
feature crypto_key_feature NOT NULL,
sequence integer NOT NULL,
@@ -2413,6 +2462,9 @@ ALTER TABLE ONLY api_keys
ALTER TABLE ONLY audit_logs
ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
ALTER TABLE ONLY connection_logs
ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
ALTER TABLE ONLY crypto_keys
ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence);
@@ -2699,6 +2751,18 @@ CREATE INDEX idx_audit_log_user_id ON audit_logs USING btree (user_id);
CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
CREATE INDEX idx_connection_logs_connect_time_desc ON connection_logs USING btree (connect_time DESC);
CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
COMMENT ON INDEX idx_connection_logs_connection_id_workspace_id_agent_name IS 'Connection ID is NULL for web events, but present for SSH events. Therefore, this index allows multiple web events for the same workspace & agent. For SSH events, the upsertion query handles duplicates on this index by upserting the disconnect_time and disconnect_reason for the same connection_id when the connection is closed.';
CREATE INDEX idx_connection_logs_organization_id ON connection_logs USING btree (organization_id);
CREATE INDEX idx_connection_logs_workspace_id ON connection_logs USING btree (workspace_id);
CREATE INDEX idx_connection_logs_workspace_owner_id ON connection_logs USING btree (workspace_owner_id);
CREATE INDEX idx_custom_roles_id ON custom_roles USING btree (id);
CREATE UNIQUE INDEX idx_custom_roles_name_lower ON custom_roles USING btree (lower(name));
@@ -2906,6 +2970,15 @@ forward without requiring a migration to clean up historical data.';
ALTER TABLE ONLY api_keys
ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY connection_logs
ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ALTER TABLE ONLY connection_logs
ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ALTER TABLE ONLY connection_logs
ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY crypto_keys
ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
@@ -7,6 +7,9 @@ type ForeignKeyConstraint string
// ForeignKeyConstraint enums.
const (
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyConnectionLogsOrganizationID ForeignKeyConstraint = "connection_logs_organization_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ForeignKeyConnectionLogsWorkspaceOwnerID ForeignKeyConstraint = "connection_logs_workspace_owner_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyCryptoKeysSecretKeyID ForeignKeyConstraint = "crypto_keys_secret_key_id_fkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyFkOauth2ProviderAppTokensUserID ForeignKeyConstraint = "fk_oauth2_provider_app_tokens_user_id" // ALTER TABLE ONLY oauth2_provider_app_tokens ADD CONSTRAINT fk_oauth2_provider_app_tokens_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyGitAuthLinksOauthAccessTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
@@ -0,0 +1,11 @@
DROP INDEX IF EXISTS idx_connection_logs_workspace_id;
DROP INDEX IF EXISTS idx_connection_logs_workspace_owner_id;
DROP INDEX IF EXISTS idx_connection_logs_organization_id;
DROP INDEX IF EXISTS idx_connection_logs_connect_time_desc;
DROP INDEX IF EXISTS idx_connection_logs_connection_id_workspace_id_agent_name;
DROP TABLE IF EXISTS connection_logs;
DROP TYPE IF EXISTS connection_type;
DROP TYPE IF EXISTS connection_status;
@@ -0,0 +1,68 @@
CREATE TYPE connection_status AS ENUM (
'connected',
'disconnected'
);
CREATE TYPE connection_type AS ENUM (
-- SSH events
'ssh',
'vscode',
'jetbrains',
'reconnecting_pty',
-- Web events
'workspace_app',
'port_forwarding'
);
CREATE TABLE connection_logs (
id uuid NOT NULL,
connect_time timestamp with time zone NOT NULL,
organization_id uuid NOT NULL REFERENCES organizations (id) ON DELETE CASCADE,
workspace_owner_id uuid NOT NULL REFERENCES users (id) ON DELETE CASCADE,
workspace_id uuid NOT NULL REFERENCES workspaces (id) ON DELETE CASCADE,
workspace_name text NOT NULL,
agent_name text NOT NULL,
type connection_type NOT NULL,
ip inet NOT NULL,
code integer,
-- Only set for web events
user_agent text,
user_id uuid,
slug_or_port text,
-- Null for web events
connection_id uuid,
disconnect_time timestamp with time zone, -- Null until we upsert a disconnect log for the same connection_id.
disconnect_reason text,
PRIMARY KEY (id)
);
COMMENT ON COLUMN connection_logs.code IS 'Either the HTTP status code of the web request, or the exit code of an SSH connection. For non-web connections, this is Null until we receive a disconnect event for the same connection_id.';
COMMENT ON COLUMN connection_logs.user_agent IS 'Null for SSH events. For web connections, this is the User-Agent header from the request.';
COMMENT ON COLUMN connection_logs.user_id IS 'Null for SSH events. For web connections, this is the ID of the user that made the request.';
COMMENT ON COLUMN connection_logs.slug_or_port IS 'Null for SSH events. For web connections, this is the slug of the app or the port number being forwarded.';
COMMENT ON COLUMN connection_logs.connection_id IS 'The SSH connection ID. Used to correlate connections and disconnections. As it originates from the agent, it is not guaranteed to be unique.';
COMMENT ON COLUMN connection_logs.disconnect_time IS 'The time the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.';
COMMENT ON COLUMN connection_logs.disconnect_reason IS 'The reason the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.';
COMMENT ON TYPE audit_action IS 'NOTE: `connect`, `disconnect`, `open`, and `close` are deprecated and no longer used - these events are now tracked in the connection_logs table.';
-- To associate connection closure events with the connection start events.
CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name
ON connection_logs (connection_id, workspace_id, agent_name);
COMMENT ON INDEX idx_connection_logs_connection_id_workspace_id_agent_name IS 'Connection ID is NULL for web events, but present for SSH events. Therefore, this index allows multiple web events for the same workspace & agent. For SSH events, the upsertion query handles duplicates on this index by upserting the disconnect_time and disconnect_reason for the same connection_id when the connection is closed.';
CREATE INDEX idx_connection_logs_connect_time_desc ON connection_logs USING btree (connect_time DESC);
CREATE INDEX idx_connection_logs_organization_id ON connection_logs USING btree (organization_id);
CREATE INDEX idx_connection_logs_workspace_owner_id ON connection_logs USING btree (workspace_owner_id);
CREATE INDEX idx_connection_logs_workspace_id ON connection_logs USING btree (workspace_id);
@@ -0,0 +1,53 @@
INSERT INTO connection_logs (
id,
connect_time,
organization_id,
workspace_owner_id,
workspace_id,
workspace_name,
agent_name,
type,
code,
ip,
user_agent,
user_id,
slug_or_port,
connection_id,
disconnect_time,
disconnect_reason
) VALUES (
'00000000-0000-0000-0000-000000000001', -- log id
'2023-10-01 12:00:00+00', -- start time
'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', -- organization id
'a0061a8e-7db7-4585-838c-3116a003dd21', -- workspace owner id
'3a9a1feb-e89d-457c-9d53-ac751b198ebe', -- workspace id
'Test Workspace', -- workspace name
'test-agent', -- agent name
'ssh', -- type
0, -- code
'127.0.0.1', -- ip
NULL, -- user agent
NULL, -- user id
NULL, -- slug or port
'00000000-0000-0000-0000-000000000003', -- connection id
'2023-10-01 12:00:10+00', -- close time
'server shut down' -- reason
),
(
'00000000-0000-0000-0000-000000000002', -- log id
'2023-10-01 12:05:00+00', -- start time
'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', -- organization id
'a0061a8e-7db7-4585-838c-3116a003dd21', -- workspace owner id
'3a9a1feb-e89d-457c-9d53-ac751b198ebe', -- workspace id
'Test Workspace', -- workspace name
'test-agent', -- agent name
'workspace_app', -- type
200, -- code
'127.0.0.1',
'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.4896.127 Safari/537.36',
'a0061a8e-7db7-4585-838c-3116a003dd21', -- user id
'code-server', -- slug or port
NULL, -- connection id (request ID)
NULL, -- close time
NULL -- reason
);
+13
View File
@@ -117,6 +117,19 @@ func (w AuditLog) RBACObject() rbac.Object {
return obj
}
func (w GetConnectionLogsOffsetRow) RBACObject() rbac.Object {
return w.ConnectionLog.RBACObject()
}
func (w ConnectionLog) RBACObject() rbac.Object {
obj := rbac.ResourceConnectionLog.WithID(w.ID)
if w.OrganizationID != uuid.Nil {
obj = obj.InOrg(w.OrganizationID)
}
return obj
}
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
switch s {
case APIKeyScopeAll:
+76
View File
@@ -50,6 +50,7 @@ type customQuerier interface {
workspaceQuerier
userQuerier
auditLogQuerier
connectionLogQuerier
}
type templateQuerier interface {
@@ -611,6 +612,81 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
return count, nil
}
type connectionLogQuerier interface {
GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error)
}
func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) {
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
VariableConverter: regosql.ConnectionLogConverter(),
})
if err != nil {
return nil, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(getConnectionLogsOffset, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return nil, xerrors.Errorf("insert authorized filter: %w", err)
}
query := fmt.Sprintf("-- name: GetAuthorizedConnectionLogsOffset :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
arg.OffsetOpt,
arg.LimitOpt,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetConnectionLogsOffsetRow
for rows.Next() {
var i GetConnectionLogsOffsetRow
if err := rows.Scan(
&i.ConnectionLog.ID,
&i.ConnectionLog.ConnectTime,
&i.ConnectionLog.OrganizationID,
&i.ConnectionLog.WorkspaceOwnerID,
&i.ConnectionLog.WorkspaceID,
&i.ConnectionLog.WorkspaceName,
&i.ConnectionLog.AgentName,
&i.ConnectionLog.Type,
&i.ConnectionLog.Ip,
&i.ConnectionLog.Code,
&i.ConnectionLog.UserAgent,
&i.ConnectionLog.UserID,
&i.ConnectionLog.SlugOrPort,
&i.ConnectionLog.ConnectionID,
&i.ConnectionLog.DisconnectTime,
&i.ConnectionLog.DisconnectReason,
&i.UserUsername,
&i.UserName,
&i.UserEmail,
&i.UserCreatedAt,
&i.UserUpdatedAt,
&i.UserLastSeenAt,
&i.UserStatus,
&i.UserLoginType,
&i.UserRoles,
&i.UserAvatarUrl,
&i.UserDeleted,
&i.UserQuietHoursSchedule,
&i.WorkspaceOwnerUsername,
&i.OrganizationName,
&i.OrganizationDisplayName,
&i.OrganizationIcon,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
if !strings.Contains(query, authorizedQueryPlaceholder) {
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
+155
View File
@@ -196,6 +196,7 @@ func AllAppSharingLevelValues() []AppSharingLevel {
}
}
// NOTE: `connect`, `disconnect`, `open`, and `close` are deprecated and no longer used - these events are now tracked in the connection_logs table.
type AuditAction string
const (
@@ -415,6 +416,134 @@ func AllBuildReasonValues() []BuildReason {
}
}
type ConnectionStatus string
const (
ConnectionStatusConnected ConnectionStatus = "connected"
ConnectionStatusDisconnected ConnectionStatus = "disconnected"
)
func (e *ConnectionStatus) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ConnectionStatus(s)
case string:
*e = ConnectionStatus(s)
default:
return fmt.Errorf("unsupported scan type for ConnectionStatus: %T", src)
}
return nil
}
type NullConnectionStatus struct {
ConnectionStatus ConnectionStatus `json:"connection_status"`
Valid bool `json:"valid"` // Valid is true if ConnectionStatus is not NULL
}
// Scan implements the Scanner interface.
func (ns *NullConnectionStatus) Scan(value interface{}) error {
if value == nil {
ns.ConnectionStatus, ns.Valid = "", false
return nil
}
ns.Valid = true
return ns.ConnectionStatus.Scan(value)
}
// Value implements the driver Valuer interface.
func (ns NullConnectionStatus) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return string(ns.ConnectionStatus), nil
}
func (e ConnectionStatus) Valid() bool {
switch e {
case ConnectionStatusConnected,
ConnectionStatusDisconnected:
return true
}
return false
}
func AllConnectionStatusValues() []ConnectionStatus {
return []ConnectionStatus{
ConnectionStatusConnected,
ConnectionStatusDisconnected,
}
}
type ConnectionType string
const (
ConnectionTypeSsh ConnectionType = "ssh"
ConnectionTypeVscode ConnectionType = "vscode"
ConnectionTypeJetbrains ConnectionType = "jetbrains"
ConnectionTypeReconnectingPty ConnectionType = "reconnecting_pty"
ConnectionTypeWorkspaceApp ConnectionType = "workspace_app"
ConnectionTypePortForwarding ConnectionType = "port_forwarding"
)
func (e *ConnectionType) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ConnectionType(s)
case string:
*e = ConnectionType(s)
default:
return fmt.Errorf("unsupported scan type for ConnectionType: %T", src)
}
return nil
}
type NullConnectionType struct {
ConnectionType ConnectionType `json:"connection_type"`
Valid bool `json:"valid"` // Valid is true if ConnectionType is not NULL
}
// Scan implements the Scanner interface.
func (ns *NullConnectionType) Scan(value interface{}) error {
if value == nil {
ns.ConnectionType, ns.Valid = "", false
return nil
}
ns.Valid = true
return ns.ConnectionType.Scan(value)
}
// Value implements the driver Valuer interface.
func (ns NullConnectionType) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return string(ns.ConnectionType), nil
}
func (e ConnectionType) Valid() bool {
switch e {
case ConnectionTypeSsh,
ConnectionTypeVscode,
ConnectionTypeJetbrains,
ConnectionTypeReconnectingPty,
ConnectionTypeWorkspaceApp,
ConnectionTypePortForwarding:
return true
}
return false
}
func AllConnectionTypeValues() []ConnectionType {
return []ConnectionType{
ConnectionTypeSsh,
ConnectionTypeVscode,
ConnectionTypeJetbrains,
ConnectionTypeReconnectingPty,
ConnectionTypeWorkspaceApp,
ConnectionTypePortForwarding,
}
}
type CryptoKeyFeature string
const (
@@ -2784,6 +2913,32 @@ type AuditLog struct {
ResourceIcon string `db:"resource_icon" json:"resource_icon"`
}
type ConnectionLog struct {
ID uuid.UUID `db:"id" json:"id"`
ConnectTime time.Time `db:"connect_time" json:"connect_time"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
AgentName string `db:"agent_name" json:"agent_name"`
Type ConnectionType `db:"type" json:"type"`
Ip pqtype.Inet `db:"ip" json:"ip"`
// Either the HTTP status code of the web request, or the exit code of an SSH connection. For non-web connections, this is Null until we receive a disconnect event for the same connection_id.
Code sql.NullInt32 `db:"code" json:"code"`
// Null for SSH events. For web connections, this is the User-Agent header from the request.
UserAgent sql.NullString `db:"user_agent" json:"user_agent"`
// Null for SSH events. For web connections, this is the ID of the user that made the request.
UserID uuid.NullUUID `db:"user_id" json:"user_id"`
// Null for SSH events. For web connections, this is the slug of the app or the port number being forwarded.
SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"`
// The SSH connection ID. Used to correlate connections and disconnections. As it originates from the agent, it is not guaranteed to be unique.
ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"`
// The time the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.
DisconnectTime sql.NullTime `db:"disconnect_time" json:"disconnect_time"`
// The reason the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.
DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"`
}
type CryptoKey struct {
Feature CryptoKeyFeature `db:"feature" json:"feature"`
Sequence int32 `db:"sequence" json:"sequence"`
+2
View File
@@ -156,6 +156,7 @@ type sqlcQuerier interface {
// This function returns roles for authorization purposes. Implied member roles
// are included.
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
@@ -647,6 +648,7 @@ type sqlcQuerier interface {
UpsertAnnouncementBanners(ctx context.Context, value string) error
UpsertAppSecurityKey(ctx context.Context, value string) error
UpsertApplicationName(ctx context.Context, value string) error
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error
// The default proxy is implied and not actually stored in the database.
// So we need to store it's configuration here for display purposes.
+443
View File
@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"sort"
"testing"
"time"
@@ -13,6 +14,7 @@ import (
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -2085,6 +2087,447 @@ func auditOnlyIDs[T database.AuditLog | database.GetAuditLogsOffsetRow](logs []T
return ids
}
func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
t.Parallel()
var allLogs []database.ConnectionLog
db, _ := dbtestutil.NewDB(t)
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
authDb := dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
orgA := dbfake.Organization(t, db).Do()
orgB := dbfake.Organization(t, db).Do()
user := dbgen.User(t, db, database.User{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: orgA.Org.ID,
CreatedBy: user.ID,
})
wsID := uuid.New()
createTemplateVersion(t, db, tpl, tvArgs{
WorkspaceTransition: database.WorkspaceTransitionStart,
Status: database.ProvisionerJobStatusSucceeded,
CreateWorkspace: true,
WorkspaceID: wsID,
})
// This map is a simple way to insert a given number of organizations
// and audit logs for each organization.
// map[orgID][]ConnectionLogID
orgConnectionLogs := map[uuid.UUID][]uuid.UUID{
orgA.Org.ID: {uuid.New(), uuid.New()},
orgB.Org.ID: {uuid.New(), uuid.New()},
}
orgIDs := make([]uuid.UUID, 0, len(orgConnectionLogs))
for orgID := range orgConnectionLogs {
orgIDs = append(orgIDs, orgID)
}
for orgID, ids := range orgConnectionLogs {
for _, id := range ids {
allLogs = append(allLogs, dbgen.ConnectionLog(t, authDb, database.UpsertConnectionLogParams{
WorkspaceID: wsID,
WorkspaceOwnerID: user.ID,
ID: id,
OrganizationID: orgID,
}))
}
}
// Now fetch all the logs
ctx := testutil.Context(t, testutil.WaitLong)
auditorRole, err := rbac.RoleByName(rbac.RoleAuditor())
require.NoError(t, err)
memberRole, err := rbac.RoleByName(rbac.RoleMember())
require.NoError(t, err)
orgAuditorRoles := func(t *testing.T, orgID uuid.UUID) rbac.Role {
t.Helper()
role, err := rbac.RoleByName(rbac.ScopedRoleOrgAuditor(orgID))
require.NoError(t, err)
return role
}
t.Run("NoAccess", func(t *testing.T) {
t.Parallel()
// Given: A user who is a member of 0 organizations
memberCtx := dbauthz.As(ctx, rbac.Subject{
FriendlyName: "member",
ID: uuid.NewString(),
Roles: rbac.Roles{memberRole},
Scope: rbac.ScopeAll,
})
// When: The user queries for connection logs
logs, err := authDb.GetConnectionLogsOffset(memberCtx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
// Then: No logs returned
require.Len(t, logs, 0, "no logs should be returned")
})
t.Run("SiteWideAuditor", func(t *testing.T) {
t.Parallel()
// Given: A site wide auditor
siteAuditorCtx := dbauthz.As(ctx, rbac.Subject{
FriendlyName: "owner",
ID: uuid.NewString(),
Roles: rbac.Roles{auditorRole},
Scope: rbac.ScopeAll,
})
// When: the auditor queries for connection logs
logs, err := authDb.GetConnectionLogsOffset(siteAuditorCtx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
// Then: All logs are returned
require.ElementsMatch(t, connectionOnlyIDs(allLogs), connectionOnlyIDs(logs))
})
t.Run("SingleOrgAuditor", func(t *testing.T) {
t.Parallel()
orgID := orgIDs[0]
// Given: An organization scoped auditor
orgAuditCtx := dbauthz.As(ctx, rbac.Subject{
FriendlyName: "org-auditor",
ID: uuid.NewString(),
Roles: rbac.Roles{orgAuditorRoles(t, orgID)},
Scope: rbac.ScopeAll,
})
// When: The auditor queries for connection logs
logs, err := authDb.GetConnectionLogsOffset(orgAuditCtx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
// Then: Only the logs for the organization are returned
require.ElementsMatch(t, orgConnectionLogs[orgID], connectionOnlyIDs(logs))
})
t.Run("TwoOrgAuditors", func(t *testing.T) {
t.Parallel()
first := orgIDs[0]
second := orgIDs[1]
// Given: A user who is an auditor for two organizations
multiOrgAuditCtx := dbauthz.As(ctx, rbac.Subject{
FriendlyName: "org-auditor",
ID: uuid.NewString(),
Roles: rbac.Roles{orgAuditorRoles(t, first), orgAuditorRoles(t, second)},
Scope: rbac.ScopeAll,
})
// When: The user queries for connection logs
logs, err := authDb.GetConnectionLogsOffset(multiOrgAuditCtx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
// Then: All logs for both organizations are returned
require.ElementsMatch(t, append(orgConnectionLogs[first], orgConnectionLogs[second]...), connectionOnlyIDs(logs))
})
t.Run("ErroneousOrg", func(t *testing.T) {
t.Parallel()
// Given: A user who is an auditor for an organization that has 0 logs
userCtx := dbauthz.As(ctx, rbac.Subject{
FriendlyName: "org-auditor",
ID: uuid.NewString(),
Roles: rbac.Roles{orgAuditorRoles(t, uuid.New())},
Scope: rbac.ScopeAll,
})
// When: The user queries for audit logs
logs, err := authDb.GetConnectionLogsOffset(userCtx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
// Then: No logs are returned
require.Len(t, logs, 0, "no logs should be returned")
})
}
func connectionOnlyIDs[T database.ConnectionLog | database.GetConnectionLogsOffsetRow](logs []T) []uuid.UUID {
ids := make([]uuid.UUID, 0, len(logs))
for _, log := range logs {
switch log := any(log).(type) {
case database.ConnectionLog:
ids = append(ids, log.ID)
case database.GetConnectionLogsOffsetRow:
ids = append(ids, log.ConnectionLog.ID)
default:
panic("unreachable")
}
}
return ids
}
func TestUpsertConnectionLog(t *testing.T) {
t.Parallel()
createWorkspace := func(t *testing.T, db database.Store) database.WorkspaceTable {
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
return dbgen.Workspace(t, db, database.WorkspaceTable{
ID: uuid.New(),
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
}
t.Run("ConnectThenDisconnect", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connectionID := uuid.New()
agentName := "test-agent"
// 1. Insert a 'connect' event.
connectTime := dbtime.Now()
connectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
}
log1, err := db.UpsertConnectionLog(ctx, connectParams)
require.NoError(t, err)
require.Equal(t, connectParams.ID, log1.ID)
require.False(t, log1.DisconnectTime.Valid, "CloseTime should not be set on connect")
// Check that one row exists.
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
// 2. Insert a 'disconnected' event for the same connection.
disconnectTime := connectTime.Add(time.Second)
disconnectParams := database.UpsertConnectionLogParams{
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
WorkspaceID: ws.ID,
AgentName: agentName,
ConnectionStatus: database.ConnectionStatusDisconnected,
// Updated to:
Time: disconnectTime,
DisconnectReason: sql.NullString{String: "test disconnect", Valid: true},
Code: sql.NullInt32{Int32: 1, Valid: true},
// Ignored
ID: uuid.New(),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceName: ws.Name,
Type: database.ConnectionTypeSsh,
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 254),
},
Valid: true,
},
}
log2, err := db.UpsertConnectionLog(ctx, disconnectParams)
require.NoError(t, err)
// Updated
require.Equal(t, log1.ID, log2.ID)
require.True(t, log2.DisconnectTime.Valid)
require.True(t, disconnectTime.Equal(log2.DisconnectTime.Time))
require.Equal(t, disconnectParams.DisconnectReason.String, log2.DisconnectReason.String)
rows, err = db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, rows, 1)
})
t.Run("ConnectDoesNotUpdate", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connectionID := uuid.New()
agentName := "test-agent"
// 1. Insert a 'connect' event.
connectTime := dbtime.Now()
connectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
}
log, err := db.UpsertConnectionLog(ctx, connectParams)
require.NoError(t, err)
// 2. Insert another 'connect' event for the same connection.
connectTime2 := connectTime.Add(time.Second)
connectParams2 := database.UpsertConnectionLogParams{
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
WorkspaceID: ws.ID,
AgentName: agentName,
ConnectionStatus: database.ConnectionStatusConnected,
// Ignored
ID: uuid.New(),
Time: connectTime2,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceName: ws.Name,
Type: database.ConnectionTypeSsh,
Code: sql.NullInt32{Int32: 0, Valid: false},
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 254),
},
Valid: true,
},
}
origLog, err := db.UpsertConnectionLog(ctx, connectParams2)
require.NoError(t, err)
require.Equal(t, log, origLog, "connect update should be a no-op")
// Check that still only one row exists.
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, rows, 1)
require.Equal(t, log, rows[0].ConnectionLog)
})
t.Run("DisconnectThenConnect", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connectionID := uuid.New()
agentName := "test-agent"
// Insert just a 'disconect' event
disconnectTime := dbtime.Now()
disconnectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: disconnectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
DisconnectReason: sql.NullString{String: "server shutting down", Valid: true},
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
}
_, err := db.UpsertConnectionLog(ctx, disconnectParams)
require.NoError(t, err)
firstRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, firstRows, 1)
// We expect the connection event to be marked as closed with the start
// and close time being the same.
require.True(t, firstRows[0].ConnectionLog.DisconnectTime.Valid)
require.Equal(t, disconnectTime, firstRows[0].ConnectionLog.DisconnectTime.Time.UTC())
require.Equal(t, firstRows[0].ConnectionLog.ConnectTime.UTC(), firstRows[0].ConnectionLog.DisconnectTime.Time.UTC())
// Now insert a 'connect' event for the same connection.
// This should be a no op
connectTime := disconnectTime.Add(time.Second)
connectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
DisconnectReason: sql.NullString{String: "reconnected", Valid: true},
Code: sql.NullInt32{Int32: 0, Valid: false},
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
}
_, err = db.UpsertConnectionLog(ctx, connectParams)
require.NoError(t, err)
secondRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, secondRows, 1)
require.Equal(t, firstRows, secondRows)
// Upsert a disconnection, which should also be a no op
disconnectParams.DisconnectReason = sql.NullString{
String: "updated close reason",
Valid: true,
}
_, err = db.UpsertConnectionLog(ctx, disconnectParams)
require.NoError(t, err)
thirdRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, secondRows, 1)
// The close reason shouldn't be updated
require.Equal(t, secondRows, thirdRows)
})
}
type tvArgs struct {
Status database.ProvisionerJobStatus
// CreateWorkspace is true if we should create a workspace for the template version
+240
View File
@@ -880,6 +880,246 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam
return i, err
}
const getConnectionLogsOffset = `-- name: GetConnectionLogsOffset :many
SELECT
connection_logs.id, connection_logs.connect_time, connection_logs.organization_id, connection_logs.workspace_owner_id, connection_logs.workspace_id, connection_logs.workspace_name, connection_logs.agent_name, connection_logs.type, connection_logs.ip, connection_logs.code, connection_logs.user_agent, connection_logs.user_id, connection_logs.slug_or_port, connection_logs.connection_id, connection_logs.disconnect_time, connection_logs.disconnect_reason,
-- sqlc.embed(users) would be nice but it does not seem to play well with
-- left joins. This user metadata is necessary for parity with the audit logs
-- API.
users.username AS user_username,
users.name AS user_name,
users.email AS user_email,
users.created_at AS user_created_at,
users.updated_at AS user_updated_at,
users.last_seen_at AS user_last_seen_at,
users.status AS user_status,
users.login_type AS user_login_type,
users.rbac_roles AS user_roles,
users.avatar_url AS user_avatar_url,
users.deleted AS user_deleted,
users.quiet_hours_schedule AS user_quiet_hours_schedule,
workspace_owner.username AS workspace_owner_username,
organizations.name AS organization_name,
organizations.display_name AS organization_display_name,
organizations.icon AS organization_icon
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE TRUE
-- Authorize Filter clause will be injected below in
-- GetAuthorizedConnectionLogsOffset
-- @authorize_filter
ORDER BY
connect_time DESC
LIMIT
-- a limit of 0 means "no limit". The connection log table is unbounded
-- in size, and is expected to be quite large. Implement a default
-- limit of 100 to prevent accidental excessively large queries.
COALESCE(NULLIF($2 :: int, 0), 100)
OFFSET
$1
`
type GetConnectionLogsOffsetParams struct {
OffsetOpt int32 `db:"offset_opt" json:"offset_opt"`
LimitOpt int32 `db:"limit_opt" json:"limit_opt"`
}
type GetConnectionLogsOffsetRow struct {
ConnectionLog ConnectionLog `db:"connection_log" json:"connection_log"`
UserUsername sql.NullString `db:"user_username" json:"user_username"`
UserName sql.NullString `db:"user_name" json:"user_name"`
UserEmail sql.NullString `db:"user_email" json:"user_email"`
UserCreatedAt sql.NullTime `db:"user_created_at" json:"user_created_at"`
UserUpdatedAt sql.NullTime `db:"user_updated_at" json:"user_updated_at"`
UserLastSeenAt sql.NullTime `db:"user_last_seen_at" json:"user_last_seen_at"`
UserStatus NullUserStatus `db:"user_status" json:"user_status"`
UserLoginType NullLoginType `db:"user_login_type" json:"user_login_type"`
UserRoles pq.StringArray `db:"user_roles" json:"user_roles"`
UserAvatarUrl sql.NullString `db:"user_avatar_url" json:"user_avatar_url"`
UserDeleted sql.NullBool `db:"user_deleted" json:"user_deleted"`
UserQuietHoursSchedule sql.NullString `db:"user_quiet_hours_schedule" json:"user_quiet_hours_schedule"`
WorkspaceOwnerUsername string `db:"workspace_owner_username" json:"workspace_owner_username"`
OrganizationName string `db:"organization_name" json:"organization_name"`
OrganizationDisplayName string `db:"organization_display_name" json:"organization_display_name"`
OrganizationIcon string `db:"organization_icon" json:"organization_icon"`
}
func (q *sqlQuerier) GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error) {
rows, err := q.db.QueryContext(ctx, getConnectionLogsOffset, arg.OffsetOpt, arg.LimitOpt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetConnectionLogsOffsetRow
for rows.Next() {
var i GetConnectionLogsOffsetRow
if err := rows.Scan(
&i.ConnectionLog.ID,
&i.ConnectionLog.ConnectTime,
&i.ConnectionLog.OrganizationID,
&i.ConnectionLog.WorkspaceOwnerID,
&i.ConnectionLog.WorkspaceID,
&i.ConnectionLog.WorkspaceName,
&i.ConnectionLog.AgentName,
&i.ConnectionLog.Type,
&i.ConnectionLog.Ip,
&i.ConnectionLog.Code,
&i.ConnectionLog.UserAgent,
&i.ConnectionLog.UserID,
&i.ConnectionLog.SlugOrPort,
&i.ConnectionLog.ConnectionID,
&i.ConnectionLog.DisconnectTime,
&i.ConnectionLog.DisconnectReason,
&i.UserUsername,
&i.UserName,
&i.UserEmail,
&i.UserCreatedAt,
&i.UserUpdatedAt,
&i.UserLastSeenAt,
&i.UserStatus,
&i.UserLoginType,
&i.UserRoles,
&i.UserAvatarUrl,
&i.UserDeleted,
&i.UserQuietHoursSchedule,
&i.WorkspaceOwnerUsername,
&i.OrganizationName,
&i.OrganizationDisplayName,
&i.OrganizationIcon,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const upsertConnectionLog = `-- name: UpsertConnectionLog :one
INSERT INTO connection_logs (
id,
connect_time,
organization_id,
workspace_owner_id,
workspace_id,
workspace_name,
agent_name,
type,
code,
ip,
user_agent,
user_id,
slug_or_port,
connection_id,
disconnect_reason,
disconnect_time
) VALUES
($1, $15, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14,
-- If we've only received a disconnect event, mark the event as immediately
-- closed.
CASE
WHEN $16::connection_status = 'disconnected'
THEN $15 :: timestamp with time zone
ELSE NULL
END)
ON CONFLICT (connection_id, workspace_id, agent_name)
DO UPDATE SET
-- No-op if the connection is still open.
disconnect_time = CASE
WHEN $16::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.disconnect_time IS NULL
THEN EXCLUDED.connect_time
ELSE connection_logs.disconnect_time
END,
disconnect_reason = CASE
WHEN $16::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.disconnect_reason IS NULL
THEN EXCLUDED.disconnect_reason
ELSE connection_logs.disconnect_reason
END,
code = CASE
WHEN $16::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.code IS NULL
THEN EXCLUDED.code
ELSE connection_logs.code
END
RETURNING id, connect_time, organization_id, workspace_owner_id, workspace_id, workspace_name, agent_name, type, ip, code, user_agent, user_id, slug_or_port, connection_id, disconnect_time, disconnect_reason
`
type UpsertConnectionLogParams struct {
ID uuid.UUID `db:"id" json:"id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
AgentName string `db:"agent_name" json:"agent_name"`
Type ConnectionType `db:"type" json:"type"`
Code sql.NullInt32 `db:"code" json:"code"`
Ip pqtype.Inet `db:"ip" json:"ip"`
UserAgent sql.NullString `db:"user_agent" json:"user_agent"`
UserID uuid.NullUUID `db:"user_id" json:"user_id"`
SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"`
ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"`
DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"`
Time time.Time `db:"time" json:"time"`
ConnectionStatus ConnectionStatus `db:"connection_status" json:"connection_status"`
}
func (q *sqlQuerier) UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error) {
row := q.db.QueryRowContext(ctx, upsertConnectionLog,
arg.ID,
arg.OrganizationID,
arg.WorkspaceOwnerID,
arg.WorkspaceID,
arg.WorkspaceName,
arg.AgentName,
arg.Type,
arg.Code,
arg.Ip,
arg.UserAgent,
arg.UserID,
arg.SlugOrPort,
arg.ConnectionID,
arg.DisconnectReason,
arg.Time,
arg.ConnectionStatus,
)
var i ConnectionLog
err := row.Scan(
&i.ID,
&i.ConnectTime,
&i.OrganizationID,
&i.WorkspaceOwnerID,
&i.WorkspaceID,
&i.WorkspaceName,
&i.AgentName,
&i.Type,
&i.Ip,
&i.Code,
&i.UserAgent,
&i.UserID,
&i.SlugOrPort,
&i.ConnectionID,
&i.DisconnectTime,
&i.DisconnectReason,
)
return i, err
}
const deleteCryptoKey = `-- name: DeleteCryptoKey :one
UPDATE crypto_keys
SET secret = NULL, secret_key_id = NULL
@@ -0,0 +1,97 @@
-- name: GetConnectionLogsOffset :many
SELECT
sqlc.embed(connection_logs),
-- sqlc.embed(users) would be nice but it does not seem to play well with
-- left joins. This user metadata is necessary for parity with the audit logs
-- API.
users.username AS user_username,
users.name AS user_name,
users.email AS user_email,
users.created_at AS user_created_at,
users.updated_at AS user_updated_at,
users.last_seen_at AS user_last_seen_at,
users.status AS user_status,
users.login_type AS user_login_type,
users.rbac_roles AS user_roles,
users.avatar_url AS user_avatar_url,
users.deleted AS user_deleted,
users.quiet_hours_schedule AS user_quiet_hours_schedule,
workspace_owner.username AS workspace_owner_username,
organizations.name AS organization_name,
organizations.display_name AS organization_display_name,
organizations.icon AS organization_icon
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE TRUE
-- Authorize Filter clause will be injected below in
-- GetAuthorizedConnectionLogsOffset
-- @authorize_filter
ORDER BY
connect_time DESC
LIMIT
-- a limit of 0 means "no limit". The connection log table is unbounded
-- in size, and is expected to be quite large. Implement a default
-- limit of 100 to prevent accidental excessively large queries.
COALESCE(NULLIF(@limit_opt :: int, 0), 100)
OFFSET
@offset_opt;
-- name: UpsertConnectionLog :one
INSERT INTO connection_logs (
id,
connect_time,
organization_id,
workspace_owner_id,
workspace_id,
workspace_name,
agent_name,
type,
code,
ip,
user_agent,
user_id,
slug_or_port,
connection_id,
disconnect_reason,
disconnect_time
) VALUES
($1, @time, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14,
-- If we've only received a disconnect event, mark the event as immediately
-- closed.
CASE
WHEN @connection_status::connection_status = 'disconnected'
THEN @time :: timestamp with time zone
ELSE NULL
END)
ON CONFLICT (connection_id, workspace_id, agent_name)
DO UPDATE SET
-- No-op if the connection is still open.
disconnect_time = CASE
WHEN @connection_status::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.disconnect_time IS NULL
THEN EXCLUDED.connect_time
ELSE connection_logs.disconnect_time
END,
disconnect_reason = CASE
WHEN @connection_status::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.disconnect_reason IS NULL
THEN EXCLUDED.disconnect_reason
ELSE connection_logs.disconnect_reason
END,
code = CASE
WHEN @connection_status::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.code IS NULL
THEN EXCLUDED.code
ELSE connection_logs.code
END
RETURNING *;
+18
View File
@@ -4,10 +4,12 @@ import (
"database/sql/driver"
"encoding/json"
"fmt"
"net"
"strings"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/rbac/policy"
@@ -237,3 +239,19 @@ func (a *UserLinkClaims) Scan(src interface{}) error {
func (a UserLinkClaims) Value() (driver.Value, error) {
return json.Marshal(a)
}
func ParseIP(ipStr string) pqtype.Inet {
ip := net.ParseIP(ipStr)
ipNet := net.IPNet{}
if ip != nil {
ipNet = net.IPNet{
IP: ip,
Mask: net.CIDRMask(len(ip)*8, len(ip)*8),
}
}
return pqtype.Inet{
IPNet: ipNet,
Valid: ip != nil,
}
}
+2
View File
@@ -9,6 +9,7 @@ const (
UniqueAgentStatsPkey UniqueConstraint = "agent_stats_pkey" // ALTER TABLE ONLY workspace_agent_stats ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id);
UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id);
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
UniqueConnectionLogsPkey UniqueConstraint = "connection_logs_pkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence);
UniqueCustomRolesUniqueKey UniqueConstraint = "custom_roles_unique_key" // ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id);
UniqueDbcryptKeysActiveKeyDigestKey UniqueConstraint = "dbcrypt_keys_active_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest);
@@ -100,6 +101,7 @@ const (
UniqueWorkspaceResourcesPkey UniqueConstraint = "workspace_resources_pkey" // ALTER TABLE ONLY workspace_resources ADD CONSTRAINT workspace_resources_pkey PRIMARY KEY (id);
UniqueWorkspacesPkey UniqueConstraint = "workspaces_pkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id);
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
UniqueIndexConnectionLogsConnectionIDWorkspaceIDAgentName UniqueConstraint = "idx_connection_logs_connection_id_workspace_id_agent_name" // CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
UniqueIndexCustomRolesNameLower UniqueConstraint = "idx_custom_roles_name_lower" // CREATE UNIQUE INDEX idx_custom_roles_name_lower ON custom_roles USING btree (lower(name));
UniqueIndexOrganizationNameLower UniqueConstraint = "idx_organization_name_lower" // CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name)) WHERE (deleted = false);
UniqueIndexProvisionerDaemonsOrgNameOwnerKey UniqueConstraint = "idx_provisioner_daemons_org_name_owner_key" // CREATE UNIQUE INDEX idx_provisioner_daemons_org_name_owner_key ON provisioner_daemons USING btree (organization_id, name, lower(COALESCE((tags ->> 'owner'::text), ''::text)));
+1
View File
@@ -65,6 +65,7 @@ const (
SubjectTypeUser SubjectType = "user"
SubjectTypeProvisionerd SubjectType = "provisionerd"
SubjectTypeAutostart SubjectType = "autostart"
SubjectTypeConnectionLogger SubjectType = "connection_logger"
SubjectTypeJobReaper SubjectType = "job_reaper"
SubjectTypeResourceMonitor SubjectType = "resource_monitor"
SubjectTypeCryptoKeyRotator SubjectType = "crypto_key_rotator"
+9
View File
@@ -54,6 +54,14 @@ var (
Type: "audit_log",
}
// ResourceConnectionLog
// Valid Actions
// - "ActionRead" :: read connection logs
// - "ActionUpdate" :: upsert connection log entries
ResourceConnectionLog = Object{
Type: "connection_log",
}
// ResourceCryptoKey
// Valid Actions
// - "ActionCreate" :: create crypto keys
@@ -368,6 +376,7 @@ func AllResources() []Objecter {
ResourceAssignOrgRole,
ResourceAssignRole,
ResourceAuditLog,
ResourceConnectionLog,
ResourceCryptoKey,
ResourceDebugInfo,
ResourceDeploymentConfig,
+6
View File
@@ -138,6 +138,12 @@ var RBACPermissions = map[string]PermissionDefinition{
ActionCreate: actDef("create new audit log entries"),
},
},
"connection_log": {
Actions: map[Action]ActionDefinition{
ActionRead: actDef("read connection logs"),
ActionUpdate: actDef("upsert connection log entries"),
},
},
"deployment_config": {
Actions: map[Action]ActionDefinition{
ActionRead: actDef("read deployment config"),
+14
View File
@@ -50,6 +50,20 @@ func AuditLogConverter() *sqltypes.VariableConverter {
return matcher
}
func ConnectionLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
// Connection logs have no user owner, only owner by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
matcher.RegisterMatcher(
sqltypes.AlwaysFalse(groupACLMatcher(matcher)),
sqltypes.AlwaysFalse(userACLMatcher(matcher)),
)
return matcher
}
func UserConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
+3 -1
View File
@@ -315,6 +315,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
Site: Permissions(map[string][]policy.Action{
ResourceAssignOrgRole.Type: {policy.ActionRead},
ResourceAuditLog.Type: {policy.ActionRead},
ResourceConnectionLog.Type: {policy.ActionRead},
// Allow auditors to see the resources that audit logs reflect.
ResourceTemplate.Type: {policy.ActionRead, policy.ActionViewInsights},
ResourceUser.Type: {policy.ActionRead},
@@ -456,7 +457,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
Site: []Permission{},
Org: map[string][]Permission{
organizationID.String(): Permissions(map[string][]policy.Action{
ResourceAuditLog.Type: {policy.ActionRead},
ResourceAuditLog.Type: {policy.ActionRead},
ResourceConnectionLog.Type: {policy.ActionRead},
// Allow auditors to see the resources that audit logs reflect.
ResourceTemplate.Type: {policy.ActionRead, policy.ActionViewInsights},
ResourceGroup.Type: {policy.ActionRead},
+9
View File
@@ -849,6 +849,15 @@ func TestRolePermissions(t *testing.T) {
},
},
},
{
Name: "ConnectionLogs",
Actions: []policy.Action{policy.ActionRead, policy.ActionUpdate},
Resource: rbac.ResourceConnectionLog,
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner},
false: {setOtherOrg, setOrgNotMe, memberMe, orgMemberMe, templateAdmin, userAdmin},
},
},
}
// We expect every permission to be tested above.
+1 -1
View File
@@ -139,7 +139,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
Database: api.Database,
NotificationsEnqueuer: api.NotificationsEnqueuer,
Pubsub: api.Pubsub,
Auditor: &api.Auditor,
ConnectionLogger: &api.ConnectionLogger,
DerpMapFn: api.DERPMap,
TailnetCoordinator: &api.TailnetCoordinator,
AppearanceFetcher: &api.AppearanceFetcher,
+62 -81
View File
@@ -3,7 +3,6 @@ package workspaceapps
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"net/url"
@@ -18,7 +17,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -40,7 +39,7 @@ type DBTokenProvider struct {
// DashboardURL is the main dashboard access URL for error pages.
DashboardURL *url.URL
Authorizer rbac.Authorizer
Auditor *atomic.Pointer[audit.Auditor]
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
Database database.Store
DeploymentValues *codersdk.DeploymentValues
OAuth2Configs *httpmw.OAuth2Configs
@@ -54,7 +53,7 @@ var _ SignedTokenProvider = &DBTokenProvider{}
func NewDBTokenProvider(log slog.Logger,
accessURL *url.URL,
authz rbac.Authorizer,
auditor *atomic.Pointer[audit.Auditor],
connectionLogger *atomic.Pointer[connectionlog.ConnectionLogger],
db database.Store,
cfg *codersdk.DeploymentValues,
oauth2Cfgs *httpmw.OAuth2Configs,
@@ -73,7 +72,7 @@ func NewDBTokenProvider(log slog.Logger,
Logger: log,
DashboardURL: accessURL,
Authorizer: authz,
Auditor: auditor,
ConnectionLogger: connectionLogger,
Database: db,
DeploymentValues: cfg,
OAuth2Configs: oauth2Cfgs,
@@ -95,7 +94,7 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *
// // permissions.
dangerousSystemCtx := dbauthz.AsSystemRestricted(ctx)
aReq, commitAudit := p.auditInitRequest(ctx, rw, r)
aReq, commitAudit := p.connLogInitRequest(ctx, rw, r)
defer commitAudit()
appReq := issueReq.AppRequest.Normalize()
@@ -386,20 +385,20 @@ func (p *DBTokenProvider) authorizeRequest(ctx context.Context, roles *rbac.Subj
return false, warnings, nil
}
type auditRequest struct {
type connLogRequest struct {
time time.Time
apiKey *database.APIKey
dbReq *databaseRequest
}
// auditInitRequest creates a new audit session and audit log for the given
// request, if one does not already exist. If an audit session already exists,
// it will be updated with the current timestamp. A session is used to reduce
// the number of audit logs created.
// connLogInitRequest creates a new connection log session and connect log for the
// given request, if one does not already exist. If a connection log session
// already exists, it will be updated with the current timestamp. A session is used to
// reduce the number of connection logs created.
//
// A session is unique to the agent, app, user and users IP. If any of these
// values change, a new session and audit log is created.
func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (aReq *auditRequest, commit func()) {
// values change, a new session and connect log is created.
func (p *DBTokenProvider) connLogInitRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (aReq *connLogRequest, commit func()) {
// Get the status writer from the request context so we can figure
// out the HTTP status and autocommit the audit log.
sw, ok := w.(*tracing.StatusWriter)
@@ -407,12 +406,12 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW
panic("dev error: http.ResponseWriter is not *tracing.StatusWriter")
}
aReq = &auditRequest{
aReq = &connLogRequest{
time: dbtime.Now(),
}
// Set the commit function on the status writer to create an audit
// log, this ensures that the status and response body are available.
// Set the commit function on the status writer to create a connection log
// this ensures that the status and response body are available.
var committed bool
return aReq, func() {
if committed {
@@ -422,7 +421,7 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW
if aReq.dbReq == nil {
// App doesn't exist, there's information in the Request
// struct but we need UUIDs for audit logging.
// struct but we need UUIDs for connection logging.
return
}
@@ -434,28 +433,25 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW
ip := r.RemoteAddr
// Approximation of the status code.
statusCode := sw.Status
// #nosec G115 - Safe conversion as HTTP status code is expected to be within int32 range (typically 100-599)
var statusCode int32 = int32(sw.Status)
if statusCode == 0 {
statusCode = http.StatusOK
}
type additionalFields struct {
audit.AdditionalFields
SlugOrPort string `json:"slug_or_port,omitempty"`
}
appInfo := additionalFields{
AdditionalFields: audit.AdditionalFields{
WorkspaceOwner: aReq.dbReq.Workspace.OwnerUsername,
WorkspaceName: aReq.dbReq.Workspace.Name,
WorkspaceID: aReq.dbReq.Workspace.ID,
},
}
var (
connType database.ConnectionType
slugOrPort = aReq.dbReq.AppSlugOrPort
)
switch {
case aReq.dbReq.AccessMethod == AccessMethodTerminal:
appInfo.SlugOrPort = "terminal"
connType = database.ConnectionTypeWorkspaceApp
slugOrPort = "terminal"
case aReq.dbReq.App.ID == uuid.Nil:
// If this isn't an app or a terminal, it's a port.
appInfo.SlugOrPort = aReq.dbReq.AppSlugOrPort
connType = database.ConnectionTypePortForwarding
default:
connType = database.ConnectionTypeWorkspaceApp
}
// If we end up logging, ensure relevant fields are set.
@@ -465,7 +461,7 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW
slog.F("app_id", aReq.dbReq.App.ID),
slog.F("user_id", userID),
slog.F("user_agent", userAgent),
slog.F("app_slug_or_port", appInfo.SlugOrPort),
slog.F("app_slug_or_port", slugOrPort),
slog.F("status_code", statusCode),
)
@@ -485,9 +481,8 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW
UserID: userID, // Can be unset, in which case uuid.Nil is fine.
Ip: ip,
UserAgent: userAgent,
SlugOrPort: appInfo.SlugOrPort,
// #nosec G115 - Safe conversion as HTTP status code is expected to be within int32 range (typically 100-599)
StatusCode: int32(statusCode),
SlugOrPort: slugOrPort,
StatusCode: statusCode,
StartedAt: aReq.time,
UpdatedAt: aReq.time,
})
@@ -500,7 +495,7 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW
if err != nil {
logger.Error(ctx, "update workspace app audit session failed", slog.Error(err))
// Avoid spamming the audit log if deduplication failed, this should
// Avoid spamming the connection log if deduplication failed, this should
// only happen if there are problems communicating with the database.
return
}
@@ -511,51 +506,37 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW
return
}
// Marshal additional fields only if we're writing an audit log entry.
appInfoBytes, err := json.Marshal(appInfo)
connLogger := *p.ConnectionLogger.Load()
err = connLogger.Upsert(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: aReq.time,
OrganizationID: aReq.dbReq.Workspace.OrganizationID,
WorkspaceOwnerID: aReq.dbReq.Workspace.OwnerID,
WorkspaceID: aReq.dbReq.Workspace.ID,
WorkspaceName: aReq.dbReq.Workspace.Name,
AgentName: aReq.dbReq.Agent.Name,
Type: connType,
Code: sql.NullInt32{
Int32: statusCode,
Valid: true,
},
Ip: database.ParseIP(ip),
UserAgent: sql.NullString{Valid: userAgent != "", String: userAgent},
UserID: uuid.NullUUID{
UUID: userID,
Valid: userID != uuid.Nil,
},
SlugOrPort: sql.NullString{Valid: slugOrPort != "", String: slugOrPort},
ConnectionStatus: database.ConnectionStatusConnected,
// N/A
ConnectionID: uuid.NullUUID{},
DisconnectReason: sql.NullString{},
})
if err != nil {
logger.Error(ctx, "marshal additional fields failed", slog.Error(err))
}
// We use the background audit function instead of init request
// here because we don't know the resource type ahead of time.
// This also allows us to log unauthenticated access.
auditor := *p.Auditor.Load()
requestID := httpmw.RequestID(r)
switch {
case aReq.dbReq.App.ID != uuid.Nil:
audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceApp]{
Audit: auditor,
Log: logger,
Action: database.AuditActionOpen,
OrganizationID: aReq.dbReq.Workspace.OrganizationID,
UserID: userID,
RequestID: requestID,
Time: aReq.time,
Status: statusCode,
IP: ip,
UserAgent: userAgent,
New: aReq.dbReq.App,
AdditionalFields: appInfoBytes,
})
default:
// Web terminal, port app, etc.
audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceAgent]{
Audit: auditor,
Log: logger,
Action: database.AuditActionOpen,
OrganizationID: aReq.dbReq.Workspace.OrganizationID,
UserID: userID,
RequestID: requestID,
Time: aReq.time,
Status: statusCode,
IP: ip,
UserAgent: userAgent,
New: aReq.dbReq.Agent,
AdditionalFields: appInfoBytes,
})
logger.Error(ctx, "upsert connection log failed", slog.Error(err))
return
}
}
}
+119 -151
View File
@@ -3,7 +3,6 @@ package workspaceapps_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"net"
@@ -22,10 +21,9 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/tracing"
@@ -83,12 +81,12 @@ func Test_ResolveRequest(t *testing.T) {
deploymentValues.Dangerous.AllowPathAppSharing = true
deploymentValues.Dangerous.AllowPathAppSiteOwnerAccess = true
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
t.Cleanup(func() {
if t.Failed() {
return
}
assert.Len(t, auditor.AuditLogs(), 0, "one or more test cases produced unexpected audit logs, did you replace the auditor or forget to call ResetLogs?")
assert.Len(t, connLogger.ConnectionLogs(), 0, "one or more test cases produced unexpected connection logs, did you replace the auditor or forget to call ResetLogs?")
})
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
AppHostname: "*.test.coder.com",
@@ -105,7 +103,7 @@ func Test_ResolveRequest(t *testing.T) {
"CF-Connecting-IP",
},
},
Auditor: auditor,
ConnectionLogger: connLogger,
})
t.Cleanup(func() {
_ = closer.Close()
@@ -231,23 +229,8 @@ func Test_ResolveRequest(t *testing.T) {
}
require.NotEqual(t, uuid.Nil, agentID)
//nolint:gocritic // This is a test, allow dbauthz.AsSystemRestricted.
agent, err := api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
require.NoError(t, err)
//nolint:gocritic // This is a test, allow dbauthz.AsSystemRestricted.
apps, err := api.Database.GetWorkspaceAppsByAgentID(dbauthz.AsSystemRestricted(ctx), agentID)
require.NoError(t, err)
appsBySlug := make(map[string]database.WorkspaceApp, len(apps))
for _, app := range apps {
appsBySlug[app.Slug] = app
}
// Reset audit logs so cleanup check can pass.
auditor.ResetLogs()
assertAuditAgent := auditAsserter[database.WorkspaceAgent](workspace)
assertAuditApp := auditAsserter[database.WorkspaceApp](workspace)
connLogger.Reset()
t.Run("OK", func(t *testing.T) {
t.Parallel()
@@ -285,9 +268,9 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: app,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
auditableUA := "Tidua"
auditableUA := "Noitcennoc"
t.Log("app", app)
rw := httptest.NewRecorder()
@@ -297,7 +280,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set("User-Agent", auditableUA)
// Try resolving the request without a token.
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -333,8 +316,8 @@ func Test_ResolveRequest(t *testing.T) {
require.Equal(t, codersdk.SignedAppTokenCookie, cookie.Name)
require.Equal(t, req.BasePath, cookie.Path)
assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil)
require.Len(t, auditor.AuditLogs(), 1, "audit log count")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
var parsedToken workspaceapps.SignedToken
err := jwtutils.Verify(ctx, api.AppSigningKeyCache, cookie.Value, &parsedToken)
@@ -350,7 +333,7 @@ func Test_ResolveRequest(t *testing.T) {
r.AddCookie(cookie)
r.RemoteAddr = auditableIP
secondToken, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
secondToken, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -363,7 +346,7 @@ func Test_ResolveRequest(t *testing.T) {
require.WithinDuration(t, token.Expiry.Time(), secondToken.Expiry.Time(), 2*time.Second)
secondToken.Expiry = token.Expiry
require.Equal(t, token, secondToken)
require.Len(t, auditor.AuditLogs(), 1, "no new audit log, FromRequest returned the same token and is not audited")
require.Len(t, connLogger.ConnectionLogs(), 1, "no new connection log, FromRequest returned the same token and is not logged")
}
})
}
@@ -382,7 +365,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: app,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
t.Log("app", app)
@@ -391,7 +374,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, secondUserClient.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -406,14 +389,15 @@ func Test_ResolveRequest(t *testing.T) {
require.Nil(t, token)
require.NotZero(t, w.StatusCode)
require.Equal(t, http.StatusNotFound, w.StatusCode)
require.Len(t, connLogger.ConnectionLogs(), 1)
return
}
require.True(t, ok)
require.NotNil(t, token)
require.Zero(t, w.StatusCode)
assertAuditApp(t, rw, r, auditor, appsBySlug[app], secondUser.ID, nil)
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, secondUser.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
}
})
@@ -430,14 +414,14 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: app,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
t.Log("app", app)
rw := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/app", nil)
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -452,8 +436,8 @@ func Test_ResolveRequest(t *testing.T) {
require.NotZero(t, rw.Code)
require.NotEqual(t, http.StatusOK, rw.Code)
assertAuditApp(t, rw, r, auditor, appsBySlug[app], uuid.Nil, nil)
require.Len(t, auditor.AuditLogs(), 1, "audit log for unauthenticated requests")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, uuid.Nil)
require.Len(t, connLogger.ConnectionLogs(), 1)
} else {
if !assert.True(t, ok) {
dump, err := httputil.DumpResponse(w, true)
@@ -466,8 +450,8 @@ func Test_ResolveRequest(t *testing.T) {
t.Fatalf("expected 200 (or unset) response code, got %d", rw.Code)
}
assertAuditApp(t, rw, r, auditor, appsBySlug[app], uuid.Nil, nil)
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, uuid.Nil)
require.Len(t, connLogger.ConnectionLogs(), 1)
}
_ = w.Body.Close()
}
@@ -479,12 +463,12 @@ func Test_ResolveRequest(t *testing.T) {
req := (workspaceapps.Request{
AccessMethod: "invalid",
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/app", nil)
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -494,7 +478,7 @@ func Test_ResolveRequest(t *testing.T) {
})
require.False(t, ok)
require.Nil(t, token)
require.Len(t, auditor.AuditLogs(), 0, "no audit logs for invalid requests")
require.Len(t, connLogger.ConnectionLogs(), 0)
})
t.Run("SplitWorkspaceAndAgent", func(t *testing.T) {
@@ -562,7 +546,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: appNamePublic,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -570,7 +554,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -591,11 +575,11 @@ func Test_ResolveRequest(t *testing.T) {
require.Equal(t, token.AgentNameOrID, c.agent)
require.Equal(t, token.WorkspaceID, workspace.ID)
require.Equal(t, token.AgentID, agentID)
assertAuditApp(t, rw, r, auditor, appsBySlug[token.AppSlugOrPort], me.ID, nil)
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, token.AppSlugOrPort, database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
} else {
require.Nil(t, token)
require.Len(t, auditor.AuditLogs(), 0, "no audit logs")
require.Len(t, connLogger.ConnectionLogs(), 0)
}
_ = w.Body.Close()
})
@@ -637,7 +621,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: appNameOwner,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -651,7 +635,7 @@ func Test_ResolveRequest(t *testing.T) {
// Even though the token is invalid, we should still perform request
// resolution without failure since we'll just ignore the bad token.
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -676,8 +660,8 @@ func Test_ResolveRequest(t *testing.T) {
require.NoError(t, err)
require.Equal(t, appNameOwner, parsedToken.AppSlugOrPort)
assertAuditApp(t, rw, r, auditor, appsBySlug[appNameOwner], me.ID, nil)
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, appNameOwner, database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
})
t.Run("PortPathBlocked", func(t *testing.T) {
@@ -692,7 +676,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: "8080",
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -700,7 +684,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -715,7 +699,7 @@ func Test_ResolveRequest(t *testing.T) {
_ = w.Body.Close()
// TODO(mafredri): Verify this is the correct status code.
require.Equal(t, http.StatusInternalServerError, w.StatusCode)
require.Len(t, auditor.AuditLogs(), 0, "no audit logs for port path blocked requests")
require.Len(t, connLogger.ConnectionLogs(), 0, "no connection logs for port path blocked requests")
})
t.Run("PortSubdomain", func(t *testing.T) {
@@ -730,7 +714,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: "9090",
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -738,7 +722,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -749,11 +733,8 @@ func Test_ResolveRequest(t *testing.T) {
require.True(t, ok)
require.Equal(t, req.AppSlugOrPort, token.AppSlugOrPort)
require.Equal(t, "http://127.0.0.1:9090", token.AppURL)
assertAuditAgent(t, rw, r, auditor, agent, me.ID, map[string]any{
"slug_or_port": "9090",
})
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, "9090", database.ConnectionTypePortForwarding, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
})
t.Run("PortSubdomainHTTPSS", func(t *testing.T) {
@@ -768,7 +749,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: "9090ss",
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -776,7 +757,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
_, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
_, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -792,7 +773,7 @@ func Test_ResolveRequest(t *testing.T) {
require.NoError(t, err)
require.Contains(t, string(b), "404 - Application Not Found")
require.Equal(t, http.StatusNotFound, w.StatusCode)
require.Len(t, auditor.AuditLogs(), 0, "no audit logs for invalid requests")
require.Len(t, connLogger.ConnectionLogs(), 0)
})
t.Run("SubdomainEndsInS", func(t *testing.T) {
@@ -807,7 +788,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: appNameEndsInS,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -815,7 +796,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -825,8 +806,8 @@ func Test_ResolveRequest(t *testing.T) {
})
require.True(t, ok)
require.Equal(t, req.AppSlugOrPort, token.AppSlugOrPort)
assertAuditApp(t, rw, r, auditor, appsBySlug[appNameEndsInS], me.ID, nil)
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, appNameEndsInS, database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
})
t.Run("Terminal", func(t *testing.T) {
@@ -838,7 +819,7 @@ func Test_ResolveRequest(t *testing.T) {
AgentNameOrID: agentID.String(),
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -846,7 +827,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -862,10 +843,8 @@ func Test_ResolveRequest(t *testing.T) {
require.Equal(t, req.AgentNameOrID, token.Request.AgentNameOrID)
require.Empty(t, token.AppSlugOrPort)
require.Empty(t, token.AppURL)
assertAuditAgent(t, rw, r, auditor, agent, me.ID, map[string]any{
"slug_or_port": "terminal",
})
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, "terminal", database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
})
t.Run("InsufficientPermissions", func(t *testing.T) {
@@ -880,7 +859,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: appNameOwner,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -888,7 +867,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, secondUserClient.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -898,8 +877,8 @@ func Test_ResolveRequest(t *testing.T) {
})
require.False(t, ok)
require.Nil(t, token)
assertAuditApp(t, rw, r, auditor, appsBySlug[appNameOwner], secondUser.ID, nil)
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, appNameOwner, database.ConnectionTypeWorkspaceApp, secondUser.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
})
t.Run("UserNotFound", func(t *testing.T) {
@@ -913,7 +892,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: appNameOwner,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -921,7 +900,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -931,7 +910,7 @@ func Test_ResolveRequest(t *testing.T) {
})
require.False(t, ok)
require.Nil(t, token)
require.Len(t, auditor.AuditLogs(), 0, "no audit logs for user not found")
require.Len(t, connLogger.ConnectionLogs(), 0)
})
t.Run("RedirectSubdomainAuth", func(t *testing.T) {
@@ -946,7 +925,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: appNameOwner,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -955,7 +934,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Host = "app.com"
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -972,8 +951,8 @@ func Test_ResolveRequest(t *testing.T) {
require.Equal(t, http.StatusSeeOther, w.StatusCode)
// Note that we don't capture the owner UUID here because the apiKey
// check/authorization exits early.
assertAuditApp(t, rw, r, auditor, appsBySlug[appNameOwner], uuid.Nil, nil)
require.Len(t, auditor.AuditLogs(), 1, "autit log entry for redirect")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, appNameOwner, database.ConnectionTypeWorkspaceApp, uuid.Nil)
require.Len(t, connLogger.ConnectionLogs(), 1)
loc, err := w.Location()
require.NoError(t, err)
@@ -1012,7 +991,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: appNameAgentUnhealthy,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -1020,7 +999,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -1034,8 +1013,8 @@ func Test_ResolveRequest(t *testing.T) {
w := rw.Result()
defer w.Body.Close()
require.Equal(t, http.StatusBadGateway, w.StatusCode)
assertAuditApp(t, rw, r, auditor, appsBySlug[appNameAgentUnhealthy], me.ID, nil)
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentNameUnhealthy, appNameAgentUnhealthy, database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
body, err := io.ReadAll(w.Body)
require.NoError(t, err)
@@ -1075,7 +1054,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: appNameInitializing,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -1083,7 +1062,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -1093,8 +1072,8 @@ func Test_ResolveRequest(t *testing.T) {
})
require.True(t, ok, "ResolveRequest failed, should pass even though app is initializing")
require.NotNil(t, token)
assertAuditApp(t, rw, r, auditor, appsBySlug[token.AppSlugOrPort], me.ID, nil)
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, token.AppSlugOrPort, database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
})
// Unhealthy apps are now permitted to connect anyways. This wasn't always
@@ -1133,7 +1112,7 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: appNameUnhealthy,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
rw := httptest.NewRecorder()
@@ -1141,7 +1120,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -1151,11 +1130,11 @@ func Test_ResolveRequest(t *testing.T) {
})
require.True(t, ok, "ResolveRequest failed, should pass even though app is unhealthy")
require.NotNil(t, token)
assertAuditApp(t, rw, r, auditor, appsBySlug[token.AppSlugOrPort], me.ID, nil)
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, token.AppSlugOrPort, database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
})
t.Run("AuditLogging", func(t *testing.T) {
t.Run("ConnectionLogging", func(t *testing.T) {
t.Parallel()
for _, app := range allApps {
@@ -1168,18 +1147,18 @@ func Test_ResolveRequest(t *testing.T) {
AppSlugOrPort: app,
}).Normalize()
auditor := audit.NewMock()
connLogger := connectionlog.NewFake()
auditableIP := testutil.RandomIPv6(t)
t.Log("app", app)
// First request, new audit log.
// First request, new connection log.
rw := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/app", nil)
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
_, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
_, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -1188,8 +1167,8 @@ func Test_ResolveRequest(t *testing.T) {
AppRequest: req,
})
require.True(t, ok)
assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil)
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 1)
// Second request, no audit log because the session is active.
rw = httptest.NewRecorder()
@@ -1197,7 +1176,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
_, ok = workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
_, ok = workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -1206,7 +1185,7 @@ func Test_ResolveRequest(t *testing.T) {
AppRequest: req,
})
require.True(t, ok)
require.Len(t, auditor.AuditLogs(), 1, "single audit log, previous session active")
require.Len(t, connLogger.ConnectionLogs(), 1, "single connection log, previous session active")
// Third request, session timed out, new audit log.
rw = httptest.NewRecorder()
@@ -1214,7 +1193,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
sessionTimeoutTokenProvider := signedTokenProviderWithAuditor(t, api.WorkspaceAppsProvider, auditor, 0)
sessionTimeoutTokenProvider := signedTokenProviderWithConnLogger(t, api.WorkspaceAppsProvider, connLogger, 0)
_, ok = workspaceappsResolveRequest(t, nil, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: sessionTimeoutTokenProvider,
@@ -1224,8 +1203,8 @@ func Test_ResolveRequest(t *testing.T) {
AppRequest: req,
})
require.True(t, ok)
assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil)
require.Len(t, auditor.AuditLogs(), 2, "two audit logs, session timed out")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 2, "two connection logs, session timed out")
// Fourth request, new IP produces new audit log.
auditableIP = testutil.RandomIPv6(t)
@@ -1234,7 +1213,7 @@ func Test_ResolveRequest(t *testing.T) {
r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
r.RemoteAddr = auditableIP
_, ok = workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{
_, ok = workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{
Logger: api.Logger,
SignedTokenProvider: api.WorkspaceAppsProvider,
DashboardURL: api.AccessURL,
@@ -1243,16 +1222,16 @@ func Test_ResolveRequest(t *testing.T) {
AppRequest: req,
})
require.True(t, ok)
assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil)
require.Len(t, auditor.AuditLogs(), 3, "three audit logs, new IP")
assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, me.ID)
require.Len(t, connLogger.ConnectionLogs(), 3, "three connection logs, new IP")
}
})
}
func workspaceappsResolveRequest(t testing.TB, auditor audit.Auditor, w http.ResponseWriter, r *http.Request, opts workspaceapps.ResolveRequestOptions) (token *workspaceapps.SignedToken, ok bool) {
func workspaceappsResolveRequest(t testing.TB, connLogger connectionlog.ConnectionLogger, w http.ResponseWriter, r *http.Request, opts workspaceapps.ResolveRequestOptions) (token *workspaceapps.SignedToken, ok bool) {
t.Helper()
if opts.SignedTokenProvider != nil && auditor != nil {
opts.SignedTokenProvider = signedTokenProviderWithAuditor(t, opts.SignedTokenProvider, auditor, time.Hour)
if opts.SignedTokenProvider != nil && connLogger != nil {
opts.SignedTokenProvider = signedTokenProviderWithConnLogger(t, opts.SignedTokenProvider, connLogger, time.Hour)
}
tracing.StatusWriterMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -1264,52 +1243,41 @@ func workspaceappsResolveRequest(t testing.TB, auditor audit.Auditor, w http.Res
return token, ok
}
func signedTokenProviderWithAuditor(t testing.TB, provider workspaceapps.SignedTokenProvider, auditor audit.Auditor, sessionTimeout time.Duration) workspaceapps.SignedTokenProvider {
func signedTokenProviderWithConnLogger(t testing.TB, provider workspaceapps.SignedTokenProvider, connLogger connectionlog.ConnectionLogger, sessionTimeout time.Duration) workspaceapps.SignedTokenProvider {
t.Helper()
p, ok := provider.(*workspaceapps.DBTokenProvider)
require.True(t, ok, "provider is not a DBTokenProvider")
shallowCopy := *p
shallowCopy.Auditor = &atomic.Pointer[audit.Auditor]{}
shallowCopy.Auditor.Store(&auditor)
shallowCopy.ConnectionLogger = &atomic.Pointer[connectionlog.ConnectionLogger]{}
shallowCopy.ConnectionLogger.Store(&connLogger)
shallowCopy.WorkspaceAppAuditSessionTimeout = sessionTimeout
return &shallowCopy
}
func auditAsserter[T audit.Auditable](workspace codersdk.Workspace) func(t testing.TB, rr *httptest.ResponseRecorder, r *http.Request, auditor *audit.MockAuditor, auditable T, userID uuid.UUID, additionalFields map[string]any) {
return func(t testing.TB, rr *httptest.ResponseRecorder, r *http.Request, auditor *audit.MockAuditor, auditable T, userID uuid.UUID, additionalFields map[string]any) {
t.Helper()
func assertConnLogContains(t *testing.T, rr *httptest.ResponseRecorder, r *http.Request, connLogger *connectionlog.FakeConnectionLogger, workspace codersdk.Workspace, agentName string, slugOrPort string, typ database.ConnectionType, userID uuid.UUID) {
t.Helper()
resp := rr.Result()
defer resp.Body.Close()
resp := rr.Result()
defer resp.Body.Close()
require.True(t, auditor.Contains(t, database.AuditLog{
OrganizationID: workspace.OrganizationID,
Action: database.AuditActionOpen,
ResourceType: audit.ResourceType(auditable),
ResourceID: audit.ResourceID(auditable),
ResourceTarget: audit.ResourceTarget(auditable),
UserID: userID,
Ip: audit.ParseIP(r.RemoteAddr),
UserAgent: sql.NullString{Valid: r.UserAgent() != "", String: r.UserAgent()},
StatusCode: int32(resp.StatusCode), //nolint:gosec
}), "audit log")
// Verify additional fields, assume the last log entry.
alog := auditor.AuditLogs()[len(auditor.AuditLogs())-1]
// Contains does not verify uuid.Nil.
if userID == uuid.Nil {
require.Equal(t, uuid.Nil, alog.UserID, "unauthenticated user")
}
add := make(map[string]any)
if len(alog.AdditionalFields) > 0 {
err := json.Unmarshal([]byte(alog.AdditionalFields), &add)
require.NoError(t, err, "audit log unmarhsal additional fields")
}
for k, v := range additionalFields {
require.Equal(t, v, add[k], "audit log additional field %s: additional fields: %v", k, add)
}
}
require.True(t, connLogger.Contains(t, database.UpsertConnectionLogParams{
OrganizationID: workspace.OrganizationID,
WorkspaceOwnerID: workspace.OwnerID,
WorkspaceID: workspace.ID,
WorkspaceName: workspace.Name,
AgentName: agentName,
Type: typ,
Ip: database.ParseIP(r.RemoteAddr),
UserAgent: sql.NullString{Valid: r.UserAgent() != "", String: r.UserAgent()},
Code: sql.NullInt32{
Int32: int32(resp.StatusCode), // nolint:gosec
Valid: true,
},
UserID: uuid.NullUUID{
UUID: userID,
Valid: true,
},
SlugOrPort: sql.NullString{Valid: slugOrPort != "", String: slugOrPort},
}))
}