diff --git a/cli/exp_scaletest.go b/cli/exp_scaletest.go index 2515012b28..559ffbebd1 100644 --- a/cli/exp_scaletest.go +++ b/cli/exp_scaletest.go @@ -29,7 +29,6 @@ import ( "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/httpapi" - notificationsLib "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -40,7 +39,6 @@ import ( "github.com/coder/coder/v2/scaletest/dashboard" "github.com/coder/coder/v2/scaletest/harness" "github.com/coder/coder/v2/scaletest/loadtestutil" - "github.com/coder/coder/v2/scaletest/notifications" "github.com/coder/coder/v2/scaletest/reconnectingpty" "github.com/coder/coder/v2/scaletest/workspacebuild" "github.com/coder/coder/v2/scaletest/workspacetraffic" @@ -1922,259 +1920,6 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command { return cmd } -func (r *RootCmd) scaletestNotifications() *serpent.Command { - var ( - userCount int64 - ownerUserPercentage float64 - notificationTimeout time.Duration - dialTimeout time.Duration - noCleanup bool - - tracingFlags = &scaletestTracingFlags{} - - // This test requires unlimited concurrency. - timeoutStrategy = &timeoutFlags{} - cleanupStrategy = newScaletestCleanupStrategy() - output = &scaletestOutputFlags{} - prometheusFlags = &scaletestPrometheusFlags{} - ) - - cmd := &serpent.Command{ - Use: "notifications", - Short: "Simulate notification delivery by creating many users listening to notifications.", - Handler: func(inv *serpent.Invocation) error { - ctx := inv.Context() - client, err := r.InitClient(inv) - if err != nil { - return err - } - - notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...) - defer stop() - ctx = notifyCtx - - me, err := requireAdmin(ctx, client) - if err != nil { - return err - } - - client.HTTPClient = &http.Client{ - Transport: &codersdk.HeaderTransport{ - Transport: http.DefaultTransport, - Header: map[string][]string{ - codersdk.BypassRatelimitHeader: {"true"}, - }, - }, - } - - if userCount <= 0 { - return xerrors.Errorf("--user-count must be greater than 0") - } - - if ownerUserPercentage < 0 || ownerUserPercentage > 100 { - return xerrors.Errorf("--owner-user-percentage must be between 0 and 100") - } - - ownerUserCount := int64(float64(userCount) * ownerUserPercentage / 100) - if ownerUserCount == 0 && ownerUserPercentage > 0 { - ownerUserCount = 1 - } - regularUserCount := userCount - ownerUserCount - - _, _ = fmt.Fprintf(inv.Stderr, "Distribution plan:\n") - _, _ = fmt.Fprintf(inv.Stderr, " Total users: %d\n", userCount) - _, _ = fmt.Fprintf(inv.Stderr, " Owner users: %d (%.1f%%)\n", ownerUserCount, ownerUserPercentage) - _, _ = fmt.Fprintf(inv.Stderr, " Regular users: %d (%.1f%%)\n", regularUserCount, 100.0-ownerUserPercentage) - - outputs, err := output.parse() - if err != nil { - return xerrors.Errorf("could not parse --output flags") - } - - tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx) - if err != nil { - return xerrors.Errorf("create tracer provider: %w", err) - } - tracer := tracerProvider.Tracer(scaletestTracerName) - - reg := prometheus.NewRegistry() - metrics := notifications.NewMetrics(reg) - - logger := inv.Logger - prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus") - defer prometheusSrvClose() - - defer func() { - _, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...") - if err := closeTracing(ctx); err != nil { - _, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err) - } - // Wait for prometheus metrics to be scraped - _, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait) - <-time.After(prometheusFlags.Wait) - }() - - _, _ = fmt.Fprintln(inv.Stderr, "Creating users...") - - dialBarrier := &sync.WaitGroup{} - ownerWatchBarrier := &sync.WaitGroup{} - dialBarrier.Add(int(userCount)) - ownerWatchBarrier.Add(int(ownerUserCount)) - - expectedNotifications := map[uuid.UUID]chan time.Time{ - notificationsLib.TemplateUserAccountCreated: make(chan time.Time, 1), - notificationsLib.TemplateUserAccountDeleted: make(chan time.Time, 1), - } - - configs := make([]notifications.Config, 0, userCount) - for range ownerUserCount { - config := notifications.Config{ - User: createusers.Config{ - OrganizationID: me.OrganizationIDs[0], - }, - Roles: []string{codersdk.RoleOwner}, - NotificationTimeout: notificationTimeout, - DialTimeout: dialTimeout, - DialBarrier: dialBarrier, - ReceivingWatchBarrier: ownerWatchBarrier, - ExpectedNotifications: expectedNotifications, - Metrics: metrics, - } - if err := config.Validate(); err != nil { - return xerrors.Errorf("validate config: %w", err) - } - configs = append(configs, config) - } - for range regularUserCount { - config := notifications.Config{ - User: createusers.Config{ - OrganizationID: me.OrganizationIDs[0], - }, - Roles: []string{}, - NotificationTimeout: notificationTimeout, - DialTimeout: dialTimeout, - DialBarrier: dialBarrier, - ReceivingWatchBarrier: ownerWatchBarrier, - Metrics: metrics, - } - if err := config.Validate(); err != nil { - return xerrors.Errorf("validate config: %w", err) - } - configs = append(configs, config) - } - - go triggerUserNotifications( - ctx, - logger, - client, - me.OrganizationIDs[0], - dialBarrier, - dialTimeout, - expectedNotifications, - ) - - th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy()) - - for i, config := range configs { - id := strconv.Itoa(i) - name := fmt.Sprintf("notifications-%s", id) - var runner harness.Runnable = notifications.NewRunner(client, config) - if tracingEnabled { - runner = &runnableTraceWrapper{ - tracer: tracer, - spanName: name, - runner: runner, - } - } - - th.AddRun(name, id, runner) - } - - _, _ = fmt.Fprintln(inv.Stderr, "Running notification delivery scaletest...") - testCtx, testCancel := timeoutStrategy.toContext(ctx) - defer testCancel() - err = th.Run(testCtx) - if err != nil { - return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err) - } - - // If the command was interrupted, skip stats. - if notifyCtx.Err() != nil { - return notifyCtx.Err() - } - - res := th.Results() - for _, o := range outputs { - err = o.write(res, inv.Stdout) - if err != nil { - return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err) - } - } - - if !noCleanup { - _, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...") - cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx) - defer cleanupCancel() - err = th.Cleanup(cleanupCtx) - if err != nil { - return xerrors.Errorf("cleanup tests: %w", err) - } - } - - if res.TotalFail > 0 { - return xerrors.New("load test failed, see above for more details") - } - - return nil - }, - } - - cmd.Options = serpent.OptionSet{ - { - Flag: "user-count", - FlagShorthand: "c", - Env: "CODER_SCALETEST_NOTIFICATION_USER_COUNT", - Description: "Required: Total number of users to create.", - Value: serpent.Int64Of(&userCount), - Required: true, - }, - { - Flag: "owner-user-percentage", - Env: "CODER_SCALETEST_NOTIFICATION_OWNER_USER_PERCENTAGE", - Default: "20.0", - Description: "Percentage of users to assign Owner role to (0-100).", - Value: serpent.Float64Of(&ownerUserPercentage), - }, - { - Flag: "notification-timeout", - Env: "CODER_SCALETEST_NOTIFICATION_TIMEOUT", - Default: "5m", - Description: "How long to wait for notifications after triggering.", - Value: serpent.DurationOf(¬ificationTimeout), - }, - { - Flag: "dial-timeout", - Env: "CODER_SCALETEST_DIAL_TIMEOUT", - Default: "2m", - Description: "Timeout for dialing the notification websocket endpoint.", - Value: serpent.DurationOf(&dialTimeout), - }, - { - Flag: "no-cleanup", - Env: "CODER_SCALETEST_NO_CLEANUP", - Description: "Do not clean up resources after the test completes.", - Value: serpent.BoolOf(&noCleanup), - }, - } - - tracingFlags.attach(&cmd.Options) - timeoutStrategy.attach(&cmd.Options) - cleanupStrategy.attach(&cmd.Options) - output.attach(&cmd.Options) - prometheusFlags.attach(&cmd.Options) - return cmd -} - type runnableTraceWrapper struct { tracer trace.Tracer spanName string @@ -2184,8 +1929,9 @@ type runnableTraceWrapper struct { } var ( - _ harness.Runnable = &runnableTraceWrapper{} - _ harness.Cleanable = &runnableTraceWrapper{} + _ harness.Runnable = &runnableTraceWrapper{} + _ harness.Cleanable = &runnableTraceWrapper{} + _ harness.Collectable = &runnableTraceWrapper{} ) func (r *runnableTraceWrapper) Run(ctx context.Context, id string, logs io.Writer) error { @@ -2227,6 +1973,14 @@ func (r *runnableTraceWrapper) Cleanup(ctx context.Context, id string, logs io.W return c.Cleanup(ctx, id, logs) } +func (r *runnableTraceWrapper) GetMetrics() map[string]any { + c, ok := r.runner.(harness.Collectable) + if !ok { + return nil + } + return c.GetMetrics() +} + func getScaletestWorkspaces(ctx context.Context, client *codersdk.Client, owner, template string) ([]codersdk.Workspace, int, error) { var ( pageNumber = 0 @@ -2375,73 +2129,6 @@ func parseTargetRange(name, targets string) (start, end int, err error) { return start, end, nil } -// triggerUserNotifications waits for all test users to connect, -// then creates and deletes a test user to trigger notification events for testing. -func triggerUserNotifications( - ctx context.Context, - logger slog.Logger, - client *codersdk.Client, - orgID uuid.UUID, - dialBarrier *sync.WaitGroup, - dialTimeout time.Duration, - expectedNotifications map[uuid.UUID]chan time.Time, -) { - logger.Info(ctx, "waiting for all users to connect") - - // Wait for all users to connect - waitCtx, cancel := context.WithTimeout(ctx, dialTimeout+30*time.Second) - defer cancel() - - done := make(chan struct{}) - go func() { - dialBarrier.Wait() - close(done) - }() - - select { - case <-done: - logger.Info(ctx, "all users connected") - case <-waitCtx.Done(): - if waitCtx.Err() == context.DeadlineExceeded { - logger.Error(ctx, "timeout waiting for users to connect") - } else { - logger.Info(ctx, "context canceled while waiting for users") - } - return - } - - const ( - triggerUsername = "scaletest-trigger-user" - triggerEmail = "scaletest-trigger@example.com" - ) - - logger.Info(ctx, "creating test user to test notifications", - slog.F("username", triggerUsername), - slog.F("email", triggerEmail), - slog.F("org_id", orgID)) - - testUser, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{ - OrganizationIDs: []uuid.UUID{orgID}, - Username: triggerUsername, - Email: triggerEmail, - Password: "test-password-123", - }) - if err != nil { - logger.Error(ctx, "create test user", slog.Error(err)) - return - } - expectedNotifications[notificationsLib.TemplateUserAccountCreated] <- time.Now() - - err = client.DeleteUser(ctx, testUser.ID) - if err != nil { - logger.Error(ctx, "delete test user", slog.Error(err)) - return - } - expectedNotifications[notificationsLib.TemplateUserAccountDeleted] <- time.Now() - close(expectedNotifications[notificationsLib.TemplateUserAccountCreated]) - close(expectedNotifications[notificationsLib.TemplateUserAccountDeleted]) -} - func createWorkspaceAppConfig(client *codersdk.Client, appHost, app string, workspace codersdk.Workspace, agent codersdk.WorkspaceAgent) (workspacetraffic.AppConfig, error) { if app == "" { return workspacetraffic.AppConfig{}, nil diff --git a/cli/exp_scaletest_notifications.go b/cli/exp_scaletest_notifications.go new file mode 100644 index 0000000000..1ea4785893 --- /dev/null +++ b/cli/exp_scaletest_notifications.go @@ -0,0 +1,447 @@ +//go:build !slim + +package cli + +import ( + "context" + "fmt" + "net/http" + "os/signal" + "strconv" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + notificationsLib "github.com/coder/coder/v2/coderd/notifications" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/scaletest/createusers" + "github.com/coder/coder/v2/scaletest/harness" + "github.com/coder/coder/v2/scaletest/notifications" + "github.com/coder/serpent" +) + +func (r *RootCmd) scaletestNotifications() *serpent.Command { + var ( + userCount int64 + ownerUserPercentage float64 + notificationTimeout time.Duration + dialTimeout time.Duration + noCleanup bool + smtpAPIURL string + + tracingFlags = &scaletestTracingFlags{} + + // This test requires unlimited concurrency. + timeoutStrategy = &timeoutFlags{} + cleanupStrategy = newScaletestCleanupStrategy() + output = &scaletestOutputFlags{} + prometheusFlags = &scaletestPrometheusFlags{} + ) + + cmd := &serpent.Command{ + Use: "notifications", + Short: "Simulate notification delivery by creating many users listening to notifications.", + Handler: func(inv *serpent.Invocation) error { + ctx := inv.Context() + client, err := r.InitClient(inv) + if err != nil { + return err + } + + notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...) + defer stop() + ctx = notifyCtx + + me, err := requireAdmin(ctx, client) + if err != nil { + return err + } + + client.HTTPClient = &http.Client{ + Transport: &codersdk.HeaderTransport{ + Transport: http.DefaultTransport, + Header: map[string][]string{ + codersdk.BypassRatelimitHeader: {"true"}, + }, + }, + } + + if userCount <= 0 { + return xerrors.Errorf("--user-count must be greater than 0") + } + + if ownerUserPercentage < 0 || ownerUserPercentage > 100 { + return xerrors.Errorf("--owner-user-percentage must be between 0 and 100") + } + + if smtpAPIURL != "" && !strings.HasPrefix(smtpAPIURL, "http://") && !strings.HasPrefix(smtpAPIURL, "https://") { + return xerrors.Errorf("--smtp-api-url must start with http:// or https://") + } + + ownerUserCount := int64(float64(userCount) * ownerUserPercentage / 100) + if ownerUserCount == 0 && ownerUserPercentage > 0 { + ownerUserCount = 1 + } + regularUserCount := userCount - ownerUserCount + + _, _ = fmt.Fprintf(inv.Stderr, "Distribution plan:\n") + _, _ = fmt.Fprintf(inv.Stderr, " Total users: %d\n", userCount) + _, _ = fmt.Fprintf(inv.Stderr, " Owner users: %d (%.1f%%)\n", ownerUserCount, ownerUserPercentage) + _, _ = fmt.Fprintf(inv.Stderr, " Regular users: %d (%.1f%%)\n", regularUserCount, 100.0-ownerUserPercentage) + + outputs, err := output.parse() + if err != nil { + return xerrors.Errorf("could not parse --output flags") + } + + tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx) + if err != nil { + return xerrors.Errorf("create tracer provider: %w", err) + } + tracer := tracerProvider.Tracer(scaletestTracerName) + + reg := prometheus.NewRegistry() + metrics := notifications.NewMetrics(reg) + + logger := inv.Logger + prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus") + defer prometheusSrvClose() + + defer func() { + _, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...") + if err := closeTracing(ctx); err != nil { + _, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err) + } + // Wait for prometheus metrics to be scraped + _, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait) + <-time.After(prometheusFlags.Wait) + }() + + _, _ = fmt.Fprintln(inv.Stderr, "Creating users...") + + dialBarrier := &sync.WaitGroup{} + ownerWatchBarrier := &sync.WaitGroup{} + dialBarrier.Add(int(userCount)) + ownerWatchBarrier.Add(int(ownerUserCount)) + + expectedNotificationIDs := map[uuid.UUID]struct{}{ + notificationsLib.TemplateUserAccountCreated: {}, + notificationsLib.TemplateUserAccountDeleted: {}, + } + + triggerTimes := make(map[uuid.UUID]chan time.Time, len(expectedNotificationIDs)) + for id := range expectedNotificationIDs { + triggerTimes[id] = make(chan time.Time, 1) + } + + configs := make([]notifications.Config, 0, userCount) + for range ownerUserCount { + config := notifications.Config{ + User: createusers.Config{ + OrganizationID: me.OrganizationIDs[0], + }, + Roles: []string{codersdk.RoleOwner}, + NotificationTimeout: notificationTimeout, + DialTimeout: dialTimeout, + DialBarrier: dialBarrier, + ReceivingWatchBarrier: ownerWatchBarrier, + ExpectedNotificationsIDs: expectedNotificationIDs, + Metrics: metrics, + SMTPApiURL: smtpAPIURL, + } + if err := config.Validate(); err != nil { + return xerrors.Errorf("validate config: %w", err) + } + configs = append(configs, config) + } + for range regularUserCount { + config := notifications.Config{ + User: createusers.Config{ + OrganizationID: me.OrganizationIDs[0], + }, + Roles: []string{}, + NotificationTimeout: notificationTimeout, + DialTimeout: dialTimeout, + DialBarrier: dialBarrier, + ReceivingWatchBarrier: ownerWatchBarrier, + Metrics: metrics, + SMTPApiURL: smtpAPIURL, + } + if err := config.Validate(); err != nil { + return xerrors.Errorf("validate config: %w", err) + } + configs = append(configs, config) + } + + go triggerUserNotifications( + ctx, + logger, + client, + me.OrganizationIDs[0], + dialBarrier, + dialTimeout, + triggerTimes, + ) + + th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy()) + + for i, config := range configs { + id := strconv.Itoa(i) + name := fmt.Sprintf("notifications-%s", id) + var runner harness.Runnable = notifications.NewRunner(client, config) + if tracingEnabled { + runner = &runnableTraceWrapper{ + tracer: tracer, + spanName: name, + runner: runner, + } + } + + th.AddRun(name, id, runner) + } + + _, _ = fmt.Fprintln(inv.Stderr, "Running notification delivery scaletest...") + testCtx, testCancel := timeoutStrategy.toContext(ctx) + defer testCancel() + err = th.Run(testCtx) + if err != nil { + return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err) + } + + // If the command was interrupted, skip stats. + if notifyCtx.Err() != nil { + return notifyCtx.Err() + } + + res := th.Results() + + if err := computeNotificationLatencies(ctx, logger, triggerTimes, res, metrics); err != nil { + return xerrors.Errorf("compute notification latencies: %w", err) + } + + for _, o := range outputs { + err = o.write(res, inv.Stdout) + if err != nil { + return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err) + } + } + + if !noCleanup { + _, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...") + cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx) + defer cleanupCancel() + err = th.Cleanup(cleanupCtx) + if err != nil { + return xerrors.Errorf("cleanup tests: %w", err) + } + } + + if res.TotalFail > 0 { + return xerrors.New("load test failed, see above for more details") + } + + return nil + }, + } + + cmd.Options = serpent.OptionSet{ + { + Flag: "user-count", + FlagShorthand: "c", + Env: "CODER_SCALETEST_NOTIFICATION_USER_COUNT", + Description: "Required: Total number of users to create.", + Value: serpent.Int64Of(&userCount), + Required: true, + }, + { + Flag: "owner-user-percentage", + Env: "CODER_SCALETEST_NOTIFICATION_OWNER_USER_PERCENTAGE", + Default: "20.0", + Description: "Percentage of users to assign Owner role to (0-100).", + Value: serpent.Float64Of(&ownerUserPercentage), + }, + { + Flag: "notification-timeout", + Env: "CODER_SCALETEST_NOTIFICATION_TIMEOUT", + Default: "5m", + Description: "How long to wait for notifications after triggering.", + Value: serpent.DurationOf(¬ificationTimeout), + }, + { + Flag: "dial-timeout", + Env: "CODER_SCALETEST_DIAL_TIMEOUT", + Default: "2m", + Description: "Timeout for dialing the notification websocket endpoint.", + Value: serpent.DurationOf(&dialTimeout), + }, + { + Flag: "no-cleanup", + Env: "CODER_SCALETEST_NO_CLEANUP", + Description: "Do not clean up resources after the test completes.", + Value: serpent.BoolOf(&noCleanup), + }, + { + Flag: "smtp-api-url", + Env: "CODER_SCALETEST_SMTP_API_URL", + Description: "SMTP mock HTTP API address.", + Value: serpent.StringOf(&smtpAPIURL), + }, + } + + tracingFlags.attach(&cmd.Options) + timeoutStrategy.attach(&cmd.Options) + cleanupStrategy.attach(&cmd.Options) + output.attach(&cmd.Options) + prometheusFlags.attach(&cmd.Options) + return cmd +} + +func computeNotificationLatencies( + ctx context.Context, + logger slog.Logger, + expectedNotifications map[uuid.UUID]chan time.Time, + results harness.Results, + metrics *notifications.Metrics, +) error { + triggerTimes := make(map[uuid.UUID]time.Time) + for notificationID, triggerTimeChan := range expectedNotifications { + select { + case triggerTime := <-triggerTimeChan: + triggerTimes[notificationID] = triggerTime + logger.Info(ctx, "received trigger time", + slog.F("notification_id", notificationID), + slog.F("trigger_time", triggerTime)) + default: + logger.Warn(ctx, "no trigger time received for notification", + slog.F("notification_id", notificationID)) + } + } + + if len(triggerTimes) == 0 { + logger.Warn(ctx, "no trigger times available, skipping latency computation") + return nil + } + + var totalLatencies int + for runID, runResult := range results.Runs { + if runResult.Error != nil { + logger.Debug(ctx, "skipping failed run for latency computation", + slog.F("run_id", runID)) + continue + } + + if runResult.Metrics == nil { + continue + } + + // Process websocket notifications. + if wsReceiptTimes, ok := runResult.Metrics[notifications.WebsocketNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time); ok { + for notificationID, receiptTime := range wsReceiptTimes { + if triggerTime, ok := triggerTimes[notificationID]; ok { + latency := receiptTime.Sub(triggerTime) + metrics.RecordLatency(latency, notificationID.String(), notifications.NotificationTypeWebsocket) + totalLatencies++ + logger.Debug(ctx, "computed websocket latency", + slog.F("run_id", runID), + slog.F("notification_id", notificationID), + slog.F("latency", latency)) + } + } + } + + // Process SMTP notifications + if smtpReceiptTimes, ok := runResult.Metrics[notifications.SMTPNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time); ok { + for notificationID, receiptTime := range smtpReceiptTimes { + if triggerTime, ok := triggerTimes[notificationID]; ok { + latency := receiptTime.Sub(triggerTime) + metrics.RecordLatency(latency, notificationID.String(), notifications.NotificationTypeSMTP) + totalLatencies++ + logger.Debug(ctx, "computed SMTP latency", + slog.F("run_id", runID), + slog.F("notification_id", notificationID), + slog.F("latency", latency)) + } + } + } + } + + logger.Info(ctx, "finished computing notification latencies", + slog.F("total_runs", results.TotalRuns), + slog.F("total_latencies_computed", totalLatencies)) + + return nil +} + +// triggerUserNotifications waits for all test users to connect, +// then creates and deletes a test user to trigger notification events for testing. +func triggerUserNotifications( + ctx context.Context, + logger slog.Logger, + client *codersdk.Client, + orgID uuid.UUID, + dialBarrier *sync.WaitGroup, + dialTimeout time.Duration, + expectedNotifications map[uuid.UUID]chan time.Time, +) { + logger.Info(ctx, "waiting for all users to connect") + + // Wait for all users to connect + waitCtx, cancel := context.WithTimeout(ctx, dialTimeout+30*time.Second) + defer cancel() + + done := make(chan struct{}) + go func() { + dialBarrier.Wait() + close(done) + }() + + select { + case <-done: + logger.Info(ctx, "all users connected") + case <-waitCtx.Done(): + if waitCtx.Err() == context.DeadlineExceeded { + logger.Error(ctx, "timeout waiting for users to connect") + } else { + logger.Info(ctx, "context canceled while waiting for users") + } + return + } + + const ( + triggerUsername = "scaletest-trigger-user" + triggerEmail = "scaletest-trigger@example.com" + ) + + logger.Info(ctx, "creating test user to test notifications", + slog.F("username", triggerUsername), + slog.F("email", triggerEmail), + slog.F("org_id", orgID)) + + testUser, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{ + OrganizationIDs: []uuid.UUID{orgID}, + Username: triggerUsername, + Email: triggerEmail, + Password: "test-password-123", + }) + if err != nil { + logger.Error(ctx, "create test user", slog.Error(err)) + return + } + expectedNotifications[notificationsLib.TemplateUserAccountCreated] <- time.Now() + + err = client.DeleteUser(ctx, testUser.ID) + if err != nil { + logger.Error(ctx, "delete test user", slog.Error(err)) + return + } + expectedNotifications[notificationsLib.TemplateUserAccountDeleted] <- time.Now() + close(expectedNotifications[notificationsLib.TemplateUserAccountCreated]) + close(expectedNotifications[notificationsLib.TemplateUserAccountDeleted]) +} diff --git a/scaletest/notifications/config.go b/scaletest/notifications/config.go index ac1c6da49a..ac8daeb9ef 100644 --- a/scaletest/notifications/config.go +++ b/scaletest/notifications/config.go @@ -24,9 +24,8 @@ type Config struct { // DialTimeout is how long to wait for websocket connection. DialTimeout time.Duration `json:"dial_timeout"` - // ExpectedNotifications maps notification template IDs to channels - // that receive the trigger time for each notification. - ExpectedNotifications map[uuid.UUID]chan time.Time `json:"-"` + // ExpectedNotificationsIDs is the list of notification template IDs to expect. + ExpectedNotificationsIDs map[uuid.UUID]struct{} `json:"-"` Metrics *Metrics `json:"-"` @@ -35,6 +34,9 @@ type Config struct { // ReceivingWatchBarrier is the barrier for receiving users. Regular users wait on this to disconnect after receiving users complete. ReceivingWatchBarrier *sync.WaitGroup `json:"-"` + + // SMTPApiUrl is the URL of the SMTP mock HTTP API + SMTPApiURL string `json:"smtp_api_url"` } func (c Config) Validate() error { diff --git a/scaletest/notifications/metrics.go b/scaletest/notifications/metrics.go index c9e7374250..0bf3ebad74 100644 --- a/scaletest/notifications/metrics.go +++ b/scaletest/notifications/metrics.go @@ -6,10 +6,16 @@ import ( "github.com/prometheus/client_golang/prometheus" ) +type NotificationType string + +const ( + NotificationTypeWebsocket NotificationType = "websocket" + NotificationTypeSMTP NotificationType = "smtp" +) + type Metrics struct { notificationLatency *prometheus.HistogramVec notificationErrors *prometheus.CounterVec - missedNotifications *prometheus.CounterVec } func NewMetrics(reg prometheus.Registerer) *Metrics { @@ -22,37 +28,26 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { Subsystem: "scaletest", Name: "notification_delivery_latency_seconds", Help: "Time between notification-creating action and receipt of notification by client", - }, []string{"username", "notification_type"}) + }, []string{"notification_id", "notification_type"}) errors := prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: "coderd", Subsystem: "scaletest", Name: "notification_delivery_errors_total", Help: "Total number of notification delivery errors", - }, []string{"username", "action"}) - missed := prometheus.NewCounterVec(prometheus.CounterOpts{ - Namespace: "coderd", - Subsystem: "scaletest", - Name: "notification_delivery_missed_total", - Help: "Total number of missed notifications", - }, []string{"username"}) + }, []string{"action"}) - reg.MustRegister(latency, errors, missed) + reg.MustRegister(latency, errors) return &Metrics{ notificationLatency: latency, notificationErrors: errors, - missedNotifications: missed, } } -func (m *Metrics) RecordLatency(latency time.Duration, username, notificationType string) { - m.notificationLatency.WithLabelValues(username, notificationType).Observe(latency.Seconds()) +func (m *Metrics) RecordLatency(latency time.Duration, notificationID string, notificationType NotificationType) { + m.notificationLatency.WithLabelValues(notificationID, string(notificationType)).Observe(latency.Seconds()) } -func (m *Metrics) AddError(username, action string) { - m.notificationErrors.WithLabelValues(username, action).Inc() -} - -func (m *Metrics) RecordMissed(username string) { - m.missedNotifications.WithLabelValues(username).Inc() +func (m *Metrics) AddError(action string) { + m.notificationErrors.WithLabelValues(action).Inc() } diff --git a/scaletest/notifications/run.go b/scaletest/notifications/run.go index d3d68e78ac..abe8445746 100644 --- a/scaletest/notifications/run.go +++ b/scaletest/notifications/run.go @@ -3,12 +3,16 @@ package notifications import ( "context" "encoding/json" + "errors" "fmt" "io" + "maps" "net/http" + "sync" "time" "github.com/google/uuid" + "golang.org/x/sync/errgroup" "golang.org/x/xerrors" "cdr.dev/slog" @@ -19,6 +23,8 @@ import ( "github.com/coder/coder/v2/scaletest/createusers" "github.com/coder/coder/v2/scaletest/harness" "github.com/coder/coder/v2/scaletest/loadtestutil" + "github.com/coder/coder/v2/scaletest/smtpmock" + "github.com/coder/quartz" "github.com/coder/websocket" ) @@ -28,18 +34,32 @@ type Runner struct { createUserRunner *createusers.Runner - // notificationLatencies stores the latency for each notification type - notificationLatencies map[uuid.UUID]time.Duration + // websocketReceiptTimes stores the receipt time for websocket notifications + websocketReceiptTimes map[uuid.UUID]time.Time + websocketReceiptTimesMu sync.RWMutex + + // smtpReceiptTimes stores the receipt time for SMTP notifications + smtpReceiptTimes map[uuid.UUID]time.Time + smtpReceiptTimesMu sync.RWMutex + + clock quartz.Clock } func NewRunner(client *codersdk.Client, cfg Config) *Runner { return &Runner{ client: client, cfg: cfg, - notificationLatencies: make(map[uuid.UUID]time.Duration), + websocketReceiptTimes: make(map[uuid.UUID]time.Time), + smtpReceiptTimes: make(map[uuid.UUID]time.Time), + clock: quartz.NewReal(), } } +func (r *Runner) WithClock(clock quartz.Clock) *Runner { + r.clock = clock + return r +} + var ( _ harness.Runnable = &Runner{} _ harness.Cleanable = &Runner{} @@ -59,7 +79,7 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error { reachedReceivingWatchBarrier := false defer func() { - if len(r.cfg.ExpectedNotifications) > 0 && !reachedReceivingWatchBarrier { + if len(r.cfg.ExpectedNotificationsIDs) > 0 && !reachedReceivingWatchBarrier { r.cfg.ReceivingWatchBarrier.Done() } }() @@ -72,7 +92,7 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error { r.createUserRunner = createusers.NewRunner(r.client, r.cfg.User) newUserAndToken, err := r.createUserRunner.RunReturningUser(ctx, id, logs) if err != nil { - r.cfg.Metrics.AddError("", "create_user") + r.cfg.Metrics.AddError("create_user") return xerrors.Errorf("create user: %w", err) } newUser := newUserAndToken.User @@ -90,7 +110,7 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error { Roles: r.cfg.Roles, }) if err != nil { - r.cfg.Metrics.AddError(newUser.Username, "assign_roles") + r.cfg.Metrics.AddError("assign_roles") return xerrors.Errorf("assign roles: %w", err) } } @@ -101,7 +121,7 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error { defer cancel() logger.Info(ctx, "connecting to notification websocket") - conn, err := r.dialNotificationWebsocket(dialCtx, newUserClient, newUser, logger) + conn, err := r.dialNotificationWebsocket(dialCtx, newUserClient, logger) if err != nil { return xerrors.Errorf("dial notification websocket: %w", err) } @@ -112,7 +132,7 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error { r.cfg.DialBarrier.Done() r.cfg.DialBarrier.Wait() - if len(r.cfg.ExpectedNotifications) == 0 { + if len(r.cfg.ExpectedNotificationsIDs) == 0 { logger.Info(ctx, "maintaining websocket connection, waiting for receiving users to complete") // Wait for receiving users to complete @@ -136,7 +156,20 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error { watchCtx, cancel := context.WithTimeout(ctx, r.cfg.NotificationTimeout) defer cancel() - if err := r.watchNotifications(watchCtx, conn, newUser, logger, r.cfg.ExpectedNotifications); err != nil { + eg, egCtx := errgroup.WithContext(watchCtx) + + eg.Go(func() error { + return r.watchNotifications(egCtx, conn, newUser, logger, r.cfg.ExpectedNotificationsIDs) + }) + + if r.cfg.SMTPApiURL != "" { + logger.Info(ctx, "running SMTP notification watcher") + eg.Go(func() error { + return r.watchNotificationsSMTP(egCtx, newUser, logger, r.cfg.ExpectedNotificationsIDs) + }) + } + + if err := eg.Wait(); err != nil { return xerrors.Errorf("notification watch failed: %w", err) } @@ -157,19 +190,31 @@ func (r *Runner) Cleanup(ctx context.Context, id string, logs io.Writer) error { return nil } -const NotificationDeliveryLatencyMetric = "notification_delivery_latency_seconds" +const ( + WebsocketNotificationReceiptTimeMetric = "notification_websocket_receipt_time" + SMTPNotificationReceiptTimeMetric = "notification_smtp_receipt_time" +) func (r *Runner) GetMetrics() map[string]any { + r.websocketReceiptTimesMu.RLock() + websocketReceiptTimes := maps.Clone(r.websocketReceiptTimes) + r.websocketReceiptTimesMu.RUnlock() + + r.smtpReceiptTimesMu.RLock() + smtpReceiptTimes := maps.Clone(r.smtpReceiptTimes) + r.smtpReceiptTimesMu.RUnlock() + return map[string]any{ - NotificationDeliveryLatencyMetric: r.notificationLatencies, + WebsocketNotificationReceiptTimeMetric: websocketReceiptTimes, + SMTPNotificationReceiptTimeMetric: smtpReceiptTimes, } } -func (r *Runner) dialNotificationWebsocket(ctx context.Context, client *codersdk.Client, user codersdk.User, logger slog.Logger) (*websocket.Conn, error) { +func (r *Runner) dialNotificationWebsocket(ctx context.Context, client *codersdk.Client, logger slog.Logger) (*websocket.Conn, error) { u, err := client.URL.Parse("/api/v2/notifications/inbox/watch") if err != nil { logger.Error(ctx, "parse notification URL", slog.Error(err)) - r.cfg.Metrics.AddError(user.Username, "parse_url") + r.cfg.Metrics.AddError("parse_url") return nil, xerrors.Errorf("parse notification URL: %w", err) } @@ -186,7 +231,7 @@ func (r *Runner) dialNotificationWebsocket(ctx context.Context, client *codersdk } } logger.Error(ctx, "dial notification websocket", slog.Error(err)) - r.cfg.Metrics.AddError(user.Username, "dial") + r.cfg.Metrics.AddError("dial") return nil, xerrors.Errorf("dial notification websocket: %w", err) } @@ -195,7 +240,7 @@ func (r *Runner) dialNotificationWebsocket(ctx context.Context, client *codersdk // watchNotifications reads notifications from the websocket and returns error or nil // once all expected notifications are received. -func (r *Runner) watchNotifications(ctx context.Context, conn *websocket.Conn, user codersdk.User, logger slog.Logger, expectedNotifications map[uuid.UUID]chan time.Time) error { +func (r *Runner) watchNotifications(ctx context.Context, conn *websocket.Conn, user codersdk.User, logger slog.Logger, expectedNotifications map[uuid.UUID]struct{}) error { logger.Info(ctx, "waiting for notifications", slog.F("username", user.Username), slog.F("expected_count", len(expectedNotifications))) @@ -217,28 +262,23 @@ func (r *Runner) watchNotifications(ctx context.Context, conn *websocket.Conn, u notif, err := readNotification(ctx, conn) if err != nil { logger.Error(ctx, "read notification", slog.Error(err)) - r.cfg.Metrics.AddError(user.Username, "read_notification") + r.cfg.Metrics.AddError("read_notification_websocket") return xerrors.Errorf("read notification: %w", err) } templateID := notif.Notification.TemplateID - if triggerTimeChan, exists := expectedNotifications[templateID]; exists { - if _, exists := receivedNotifications[templateID]; !exists { + if _, exists := expectedNotifications[templateID]; exists { + if _, received := receivedNotifications[templateID]; !received { receiptTime := time.Now() - select { - case triggerTime := <-triggerTimeChan: - latency := receiptTime.Sub(triggerTime) - r.notificationLatencies[templateID] = latency - r.cfg.Metrics.RecordLatency(latency, user.Username, templateID.String()) - receivedNotifications[templateID] = struct{}{} + r.websocketReceiptTimesMu.Lock() + r.websocketReceiptTimes[templateID] = receiptTime + r.websocketReceiptTimesMu.Unlock() + receivedNotifications[templateID] = struct{}{} - logger.Info(ctx, "received expected notification", - slog.F("template_id", templateID), - slog.F("title", notif.Notification.Title), - slog.F("latency", latency)) - case <-ctx.Done(): - return xerrors.Errorf("context canceled while waiting for trigger time: %w", ctx.Err()) - } + logger.Info(ctx, "received expected notification", + slog.F("template_id", templateID), + slog.F("title", notif.Notification.Title), + slog.F("receipt_time", receiptTime)) } } else { logger.Debug(ctx, "received notification not being tested", @@ -248,6 +288,97 @@ func (r *Runner) watchNotifications(ctx context.Context, conn *websocket.Conn, u } } +// watchNotificationsSMTP polls the SMTP HTTP API for notifications and returns error or nil +// once all expected notifications are received. +func (r *Runner) watchNotificationsSMTP(ctx context.Context, user codersdk.User, logger slog.Logger, expectedNotifications map[uuid.UUID]struct{}) error { + logger.Info(ctx, "polling SMTP API for notifications", + slog.F("email", user.Email), + slog.F("expected_count", len(expectedNotifications)), + ) + receivedNotifications := make(map[uuid.UUID]struct{}) + + apiURL := fmt.Sprintf("%s/messages?email=%s", r.cfg.SMTPApiURL, user.Email) + httpClient := &http.Client{ + Timeout: 10 * time.Second, + } + + const smtpPollInterval = 2 * time.Second + done := xerrors.New("done") + + tkr := r.clock.TickerFunc(ctx, smtpPollInterval, func() error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) + if err != nil { + logger.Error(ctx, "create SMTP API request", slog.Error(err)) + r.cfg.Metrics.AddError("smtp_create_request") + return xerrors.Errorf("create SMTP API request: %w", err) + } + + resp, err := httpClient.Do(req) + if err != nil { + logger.Error(ctx, "poll smtp api for notifications", slog.Error(err)) + r.cfg.Metrics.AddError("smtp_poll") + return xerrors.Errorf("poll smtp api: %w", err) + } + + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + logger.Error(ctx, "smtp api returned non-200 status", slog.F("status", resp.StatusCode)) + r.cfg.Metrics.AddError("smtp_bad_status") + return xerrors.Errorf("smtp api returned status %d", resp.StatusCode) + } + + var summaries []smtpmock.EmailSummary + if err := json.NewDecoder(resp.Body).Decode(&summaries); err != nil { + _ = resp.Body.Close() + logger.Error(ctx, "decode smtp api response", slog.Error(err)) + r.cfg.Metrics.AddError("smtp_decode") + return xerrors.Errorf("decode smtp api response: %w", err) + } + _ = resp.Body.Close() + + // Process each email summary + for _, summary := range summaries { + notificationID := summary.NotificationTemplateID + if notificationID == uuid.Nil { + continue + } + + if _, exists := expectedNotifications[notificationID]; exists { + if _, received := receivedNotifications[notificationID]; !received { + receiptTime := summary.Date + if receiptTime.IsZero() { + receiptTime = time.Now() + } + + r.smtpReceiptTimesMu.Lock() + r.smtpReceiptTimes[notificationID] = receiptTime + r.smtpReceiptTimesMu.Unlock() + receivedNotifications[notificationID] = struct{}{} + + logger.Info(ctx, "received expected notification via SMTP", + slog.F("notification_id", notificationID), + slog.F("subject", summary.Subject), + slog.F("receipt_time", receiptTime)) + } + } + } + + if len(receivedNotifications) == len(expectedNotifications) { + logger.Info(ctx, "received all expected notifications via SMTP") + return done + } + + return nil + }, "smtp") + + err := tkr.Wait() + if errors.Is(err, done) { + return nil + } + + return err +} + func readNotification(ctx context.Context, conn *websocket.Conn) (codersdk.GetInboxNotificationResponse, error) { _, message, err := conn.Read(ctx) if err != nil { diff --git a/scaletest/notifications/run_test.go b/scaletest/notifications/run_test.go index e94e6d82ea..1e198e9edd 100644 --- a/scaletest/notifications/run_test.go +++ b/scaletest/notifications/run_test.go @@ -1,7 +1,11 @@ package notifications_test import ( + "context" + "encoding/json" "io" + "net/http" + "net/http/httptest" "strconv" "sync" "testing" @@ -9,22 +13,19 @@ import ( "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" - "golang.org/x/xerrors" - - "github.com/coder/serpent" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtestutil" notificationsLib "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/notifications/dispatch" + "github.com/coder/coder/v2/coderd/notifications/types" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/scaletest/createusers" "github.com/coder/coder/v2/scaletest/notifications" + "github.com/coder/coder/v2/scaletest/smtpmock" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) @@ -36,39 +37,11 @@ func TestRun(t *testing.T) { logger := testutil.Logger(t) db, ps := dbtestutil.NewDB(t) - // Setup notifications manager with inbox handler - cfg := defaultNotificationsConfig(database.NotificationMethodSmtp) - mgr, err := notificationsLib.NewManager( - cfg, - db, - ps, - defaultHelpers(), - notificationsLib.NewMetrics(prometheus.NewRegistry()), - logger.Named("manager"), - ) - require.NoError(t, err) - - mgr.WithHandlers(map[database.NotificationMethod]notificationsLib.Handler{ - database.NotificationMethodInbox: dispatch.NewInboxHandler(logger.Named("inbox"), db, ps), - }) - t.Cleanup(func() { - assert.NoError(t, mgr.Stop(dbauthz.AsNotifier(ctx))) - }) - mgr.Run(dbauthz.AsNotifier(ctx)) - - enqueuer, err := notificationsLib.NewStoreEnqueuer( - cfg, - db, - defaultHelpers(), - logger.Named("enqueuer"), - quartz.NewReal(), - ) - require.NoError(t, err) + inboxHandler := dispatch.NewInboxHandler(logger.Named("inbox"), db, ps) client := coderdtest.New(t, &coderdtest.Options{ - Database: db, - Pubsub: ps, - NotificationsEnqueuer: enqueuer, + Database: db, + Pubsub: ps, }) firstUser := coderdtest.CreateFirstUser(t, client) @@ -82,9 +55,9 @@ func TestRun(t *testing.T) { eg, runCtx := errgroup.WithContext(ctx) - expectedNotifications := map[uuid.UUID]chan time.Time{ - notificationsLib.TemplateUserAccountCreated: make(chan time.Time, 1), - notificationsLib.TemplateUserAccountDeleted: make(chan time.Time, 1), + expectedNotificationsIDs := map[uuid.UUID]struct{}{ + notificationsLib.TemplateUserAccountCreated: {}, + notificationsLib.TemplateUserAccountDeleted: {}, } // Start receiving runners who will receive notifications @@ -93,14 +66,15 @@ func TestRun(t *testing.T) { runnerCfg := notifications.Config{ User: createusers.Config{ OrganizationID: firstUser.OrganizationID, + Username: "receiving-user-" + strconv.Itoa(i), }, - Roles: []string{codersdk.RoleOwner}, - NotificationTimeout: testutil.WaitLong, - DialTimeout: testutil.WaitLong, - Metrics: metrics, - DialBarrier: dialBarrier, - ReceivingWatchBarrier: receivingWatchBarrier, - ExpectedNotifications: expectedNotifications, + Roles: []string{codersdk.RoleOwner}, + NotificationTimeout: testutil.WaitLong, + DialTimeout: testutil.WaitLong, + Metrics: metrics, + DialBarrier: dialBarrier, + ReceivingWatchBarrier: receivingWatchBarrier, + ExpectedNotificationsIDs: expectedNotificationsIDs, } err := runnerCfg.Validate() require.NoError(t, err) @@ -141,31 +115,17 @@ func TestRun(t *testing.T) { // Wait for all runners to connect dialBarrier.Wait() - createTime := time.Now() - newUser, err := client.CreateUserWithOrgs(runCtx, codersdk.CreateUserRequestWithOrgs{ - OrganizationIDs: []uuid.UUID{firstUser.OrganizationID}, - Email: "test-user@coder.com", - Username: "test-user", - Password: "SomeSecurePassword!", - }) - if err != nil { - return xerrors.Errorf("create test user: %w", err) + for i := 0; i < numReceivingUsers; i++ { + err := sendInboxNotification(runCtx, t, db, inboxHandler, "receiving-user-"+strconv.Itoa(i), notificationsLib.TemplateUserAccountCreated) + require.NoError(t, err) + err = sendInboxNotification(runCtx, t, db, inboxHandler, "receiving-user-"+strconv.Itoa(i), notificationsLib.TemplateUserAccountDeleted) + require.NoError(t, err) } - expectedNotifications[notificationsLib.TemplateUserAccountCreated] <- createTime - - deleteTime := time.Now() - if err := client.DeleteUser(runCtx, newUser.ID); err != nil { - return xerrors.Errorf("delete test user: %w", err) - } - expectedNotifications[notificationsLib.TemplateUserAccountDeleted] <- deleteTime - - close(expectedNotifications[notificationsLib.TemplateUserAccountCreated]) - close(expectedNotifications[notificationsLib.TemplateUserAccountDeleted]) return nil }) - err = eg.Wait() + err := eg.Wait() require.NoError(t, err, "runner execution should complete successfully") cleanupEg, cleanupCtx := errgroup.WithContext(ctx) @@ -188,34 +148,196 @@ func TestRun(t *testing.T) { require.Equal(t, firstUser.UserID, users.Users[0].ID) for _, runner := range receivingRunners { - runnerMetrics := runner.GetMetrics()[notifications.NotificationDeliveryLatencyMetric].(map[uuid.UUID]time.Duration) - require.Contains(t, runnerMetrics, notificationsLib.TemplateUserAccountCreated) - require.Contains(t, runnerMetrics, notificationsLib.TemplateUserAccountDeleted) + metrics := runner.GetMetrics() + websocketReceiptTimes := metrics[notifications.WebsocketNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time) + + require.Contains(t, websocketReceiptTimes, notificationsLib.TemplateUserAccountCreated) + require.Contains(t, websocketReceiptTimes, notificationsLib.TemplateUserAccountDeleted) } } -func defaultNotificationsConfig(method database.NotificationMethod) codersdk.NotificationsConfig { - return codersdk.NotificationsConfig{ - Method: serpent.String(method), - MaxSendAttempts: 5, - FetchInterval: serpent.Duration(time.Millisecond * 100), - StoreSyncInterval: serpent.Duration(time.Millisecond * 200), - LeasePeriod: serpent.Duration(time.Second * 10), - DispatchTimeout: serpent.Duration(time.Second * 5), - RetryInterval: serpent.Duration(time.Millisecond * 50), - LeaseCount: 10, - StoreSyncBufferSize: 50, - Inbox: codersdk.NotificationsInboxConfig{ - Enabled: serpent.Bool(true), - }, +func TestRunWithSMTP(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + logger := testutil.Logger(t) + db, ps := dbtestutil.NewDB(t) + + inboxHandler := dispatch.NewInboxHandler(logger.Named("inbox"), db, ps) + + client := coderdtest.New(t, &coderdtest.Options{ + Database: db, + Pubsub: ps, + }) + firstUser := coderdtest.CreateFirstUser(t, client) + + smtpAPIMux := http.NewServeMux() + smtpAPIMux.HandleFunc("/messages", func(w http.ResponseWriter, r *http.Request) { + summaries := []smtpmock.EmailSummary{ + { + Subject: "TemplateUserAccountCreated", + Date: time.Now(), + NotificationTemplateID: notificationsLib.TemplateUserAccountCreated, + }, + { + Subject: "TemplateUserAccountDeleted", + Date: time.Now(), + NotificationTemplateID: notificationsLib.TemplateUserAccountDeleted, + }, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(summaries) + }) + + smtpAPIServer := httptest.NewServer(smtpAPIMux) + defer smtpAPIServer.Close() + + const numReceivingUsers = 2 + const numRegularUsers = 2 + dialBarrier := new(sync.WaitGroup) + receivingWatchBarrier := new(sync.WaitGroup) + dialBarrier.Add(numReceivingUsers + numRegularUsers) + receivingWatchBarrier.Add(numReceivingUsers) + metrics := notifications.NewMetrics(prometheus.NewRegistry()) + + eg, runCtx := errgroup.WithContext(ctx) + + expectedNotificationsIDs := map[uuid.UUID]struct{}{ + notificationsLib.TemplateUserAccountCreated: {}, + notificationsLib.TemplateUserAccountDeleted: {}, + } + + mClock := quartz.NewMock(t) + smtpTrap := mClock.Trap().TickerFunc("smtp") + defer smtpTrap.Close() + + // Start receiving runners who will receive notifications + receivingRunners := make([]*notifications.Runner, 0, numReceivingUsers) + for i := range numReceivingUsers { + runnerCfg := notifications.Config{ + User: createusers.Config{ + OrganizationID: firstUser.OrganizationID, + Username: "receiving-user-" + strconv.Itoa(i), + }, + Roles: []string{codersdk.RoleOwner}, + NotificationTimeout: testutil.WaitLong, + DialTimeout: testutil.WaitLong, + Metrics: metrics, + DialBarrier: dialBarrier, + ReceivingWatchBarrier: receivingWatchBarrier, + ExpectedNotificationsIDs: expectedNotificationsIDs, + SMTPApiURL: smtpAPIServer.URL, + } + err := runnerCfg.Validate() + require.NoError(t, err) + + runner := notifications.NewRunner(client, runnerCfg).WithClock(mClock) + receivingRunners = append(receivingRunners, runner) + eg.Go(func() error { + return runner.Run(runCtx, "receiving-"+strconv.Itoa(i), io.Discard) + }) + } + + // Start regular user runners who will maintain websocket connections + regularRunners := make([]*notifications.Runner, 0, numRegularUsers) + for i := range numRegularUsers { + runnerCfg := notifications.Config{ + User: createusers.Config{ + OrganizationID: firstUser.OrganizationID, + }, + Roles: []string{}, + NotificationTimeout: testutil.WaitLong, + DialTimeout: testutil.WaitLong, + Metrics: metrics, + DialBarrier: dialBarrier, + ReceivingWatchBarrier: receivingWatchBarrier, + } + err := runnerCfg.Validate() + require.NoError(t, err) + + runner := notifications.NewRunner(client, runnerCfg) + regularRunners = append(regularRunners, runner) + eg.Go(func() error { + return runner.Run(runCtx, "regular-"+strconv.Itoa(i), io.Discard) + }) + } + + // Trigger notifications by creating and deleting a user + eg.Go(func() error { + // Wait for all runners to connect + dialBarrier.Wait() + + for i := 0; i < numReceivingUsers; i++ { + smtpTrap.MustWait(runCtx).MustRelease(runCtx) + } + + for i := 0; i < numReceivingUsers; i++ { + err := sendInboxNotification(runCtx, t, db, inboxHandler, "receiving-user-"+strconv.Itoa(i), notificationsLib.TemplateUserAccountCreated) + require.NoError(t, err) + err = sendInboxNotification(runCtx, t, db, inboxHandler, "receiving-user-"+strconv.Itoa(i), notificationsLib.TemplateUserAccountDeleted) + require.NoError(t, err) + } + + _, w := mClock.AdvanceNext() + w.MustWait(runCtx) + + return nil + }) + + err := eg.Wait() + require.NoError(t, err, "runner execution with SMTP should complete successfully") + + cleanupEg, cleanupCtx := errgroup.WithContext(ctx) + for i, runner := range receivingRunners { + cleanupEg.Go(func() error { + return runner.Cleanup(cleanupCtx, "receiving-"+strconv.Itoa(i), io.Discard) + }) + } + for i, runner := range regularRunners { + cleanupEg.Go(func() error { + return runner.Cleanup(cleanupCtx, "regular-"+strconv.Itoa(i), io.Discard) + }) + } + err = cleanupEg.Wait() + require.NoError(t, err) + + users, err := client.Users(ctx, codersdk.UsersRequest{}) + require.NoError(t, err) + require.Len(t, users.Users, 1) + require.Equal(t, firstUser.UserID, users.Users[0].ID) + + // Verify that notifications were received via both websocket and SMTP + for _, runner := range receivingRunners { + metrics := runner.GetMetrics() + websocketReceiptTimes := metrics[notifications.WebsocketNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time) + smtpReceiptTimes := metrics[notifications.SMTPNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time) + + require.Contains(t, websocketReceiptTimes, notificationsLib.TemplateUserAccountCreated) + require.Contains(t, websocketReceiptTimes, notificationsLib.TemplateUserAccountDeleted) + require.Contains(t, smtpReceiptTimes, notificationsLib.TemplateUserAccountCreated) + require.Contains(t, smtpReceiptTimes, notificationsLib.TemplateUserAccountDeleted) } } -func defaultHelpers() map[string]any { - return map[string]any{ - "base_url": func() string { return "http://test.com" }, - "current_year": func() string { return "2024" }, - "logo_url": func() string { return "https://coder.com/coder-logo-horizontal.png" }, - "app_name": func() string { return "Coder" }, +func sendInboxNotification(ctx context.Context, t *testing.T, db database.Store, inboxHandler *dispatch.InboxHandler, username string, templateID uuid.UUID) error { + user, err := db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{ + Username: username, + }) + require.NoError(t, err) + + dispatchFunc, err := inboxHandler.Dispatcher(types.MessagePayload{ + UserID: user.ID.String(), + NotificationTemplateID: templateID.String(), + }, "", "", nil) + if err != nil { + return err } + + _, err = dispatchFunc(ctx, uuid.New()) + if err != nil { + return err + } + + return nil }