Files
coder/coderd/agentapi/connectionlog.go
T
Callum Styan 8ed1c1d372 perf: reduce calls to GetWorkspaceByAgentID in GetWorkspaceAgentByID (#21046)
This PR piggybacks on the agent API cached workspace added in an earlier PR to provide a fast path for avoiding `GetWorkspaceByAgentID` calls in dbauthz's `GetWorkspaceAgentByID`. This query is not the most expensive, but has a significant call volume at ~16 million calls per week.

Signed-off-by: Callum Styan <callumstyan@gmail.com>
2025-12-10 14:03:24 -08:00

134 lines
4.2 KiB
Go

package agentapi
import (
"context"
"database/sql"
"sync/atomic"
"github.com/google/uuid"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/emptypb"
"cdr.dev/slog"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
)
// ConnLogAPI handles connection reports sent by a workspace agent and records
// them through the configured connection logger (see ReportConnection).
type ConnLogAPI struct {
// AgentFn fetches the workspace agent whose ID and name ReportConnection
// uses when building the connection log entry.
AgentFn func(context.Context) (database.WorkspaceAgent, error)
// ConnectionLogger is loaded atomically and dereferenced on every
// ReportConnection call, so it must hold a non-nil logger by then.
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
// Workspace holds cached workspace fields. When it yields a workspace
// identity, ReportConnection uses it instead of querying the database.
Workspace *CachedWorkspaceFields
// Database is used to look up the workspace by agent ID when no cached
// workspace identity is available.
Database database.Store
// Log receives debug output from ReportConnection.
Log slog.Logger
}
// ReportConnection records a connection or disconnection event reported by the
// agent as a connection log entry.
//
// The request's connection ID must decode to a non-nil UUID; it is stored with
// the entry so a later disconnect report for the same ID can be matched to it.
// The workspace identity comes from the cached workspace fields when
// available, otherwise from a GetWorkspaceByAgentID lookup.
//
// An error is returned when the connection ID is malformed or nil, when the
// reported action or connection type cannot be converted, when the agent or
// workspace cannot be fetched, or when the connection logger's Upsert fails.
// On success an empty message is returned.
func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.ReportConnectionRequest) (*emptypb.Empty, error) {
	// Generated protobuf getters are nil-safe, so fetching the connection once
	// is equivalent to calling req.GetConnection() at every use.
	conn := req.GetConnection()

	// We use the connection ID to identify which connection log event to mark
	// as closed, when we receive a close action for that ID.
	connectionID, err := uuid.FromBytes(conn.GetId())
	if err != nil {
		return nil, xerrors.Errorf("connection id from bytes: %w", err)
	}
	if connectionID == uuid.Nil {
		return nil, xerrors.New("connection ID cannot be nil")
	}

	action, err := db2sdk.ConnectionLogStatusFromAgentProtoConnectionAction(conn.GetAction())
	if err != nil {
		return nil, err
	}
	connectionType, err := db2sdk.ConnectionLogConnectionTypeFromAgentProtoConnectionType(conn.GetType())
	if err != nil {
		return nil, err
	}

	// A status code is only recorded for disconnections.
	var code sql.NullInt32
	if action == database.ConnectionStatusDisconnected {
		code = sql.NullInt32{
			Int32: conn.GetStatusCode(),
			Valid: true,
		}
	}

	// Inject RBAC object into context for dbauthz fast path, avoid having to
	// call GetWorkspaceByAgentID on every connection report.
	rbacCtx := ctx
	var ws database.WorkspaceIdentity
	if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
		ws = dbws
		injectedCtx, rbacErr := dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
		if rbacErr != nil {
			// Don't error level log here, will exit the function. We want to
			// fall back to GetWorkspaceByAgentID. The context returned
			// alongside an error is discarded so AgentFn always receives the
			// caller's original, known-good ctx.
			//nolint:gocritic
			a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", rbacErr))
		} else {
			rbacCtx = injectedCtx
		}
	}

	// Fetch contextual data for this connection log event.
	workspaceAgent, err := a.AgentFn(rbacCtx)
	if err != nil {
		return nil, xerrors.Errorf("get agent: %w", err)
	}

	// No cached identity was available: look the workspace up by agent ID.
	if ws.Equal(database.WorkspaceIdentity{}) {
		workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
		if err != nil {
			return nil, xerrors.Errorf("get workspace by agent id: %w", err)
		}
		ws = database.WorkspaceIdentityFromWorkspace(workspace)
	}

	// Some older clients may incorrectly report "localhost" as the IP address.
	// Related to https://github.com/coder/coder/issues/20194
	logIPRaw := conn.GetIp()
	if logIPRaw == "localhost" {
		logIPRaw = "127.0.0.1"
	}
	logIP := database.ParseIP(logIPRaw) // will return null if invalid

	reason := conn.GetReason()
	connLogger := *a.ConnectionLogger.Load()
	err = connLogger.Upsert(ctx, database.UpsertConnectionLogParams{
		ID:               uuid.New(),
		Time:             conn.GetTimestamp().AsTime(),
		OrganizationID:   ws.OrganizationID,
		WorkspaceOwnerID: ws.OwnerID,
		WorkspaceID:      ws.ID,
		WorkspaceName:    ws.Name,
		AgentName:        workspaceAgent.Name,
		Type:             connectionType,
		Code:             code,
		Ip:               logIP,
		ConnectionID: uuid.NullUUID{
			UUID:  connectionID,
			Valid: true,
		},
		// An empty reason is stored as NULL rather than as an empty string.
		DisconnectReason: sql.NullString{
			String: reason,
			Valid:  reason != "",
		},
		// We supply the action:
		// - So the DB can handle duplicate connections or disconnections properly.
		// - To make it clear whether this is a connection or disconnection
		//   prior to its insertion into the DB (logs)
		ConnectionStatus: action,
		// It's not possible to tell which user connected. Once we have
		// the capability, this may be reported by the agent.
		UserID: uuid.NullUUID{
			Valid: false,
		},
		// N/A
		UserAgent: sql.NullString{},
		// N/A
		SlugOrPort: sql.NullString{},
	})
	if err != nil {
		return nil, xerrors.Errorf("export connection log: %w", err)
	}
	return &emptypb.Empty{}, nil
}