feat(scaletest): extend notifications runner with smtp support (#20222)

This PR extends the scaletest notification runner with SMTP support.

If the `--smtp-api-url` flag is provided, the runner will also watch for SMTP notifications using the specified URL.

#### Changes
- Added a new watcher to retrieve emails sent to the runner user  
- Tracked WebSocket and SMTP latencies separately  
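- Updated metrics: the latency histogram is now labeled by `notification_id` (replacing `username`) and `notification_type`, the error counter no longer carries a `username` label, and the `notification_delivery_missed_total` counter was removed  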

#### CLI Flags
- `--smtp-api-url`: Address of the SMTP mock HTTP API used to retrieve email notifications  

#### Metrics
- `notification_delivery_latency_seconds` now includes:
  - `notification_id`
  - `notification_type` (`websocket` or `smtp`)
This commit is contained in:
Kacper Sawicki
2025-10-22 12:09:35 +02:00
committed by GitHub
parent 7bbeef4999
commit 1230cacf78
6 changed files with 850 additions and 466 deletions
+11 -324
View File
@@ -29,7 +29,6 @@ import (
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/httpapi"
notificationsLib "github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
@@ -40,7 +39,6 @@ import (
"github.com/coder/coder/v2/scaletest/dashboard"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/coder/v2/scaletest/loadtestutil"
"github.com/coder/coder/v2/scaletest/notifications"
"github.com/coder/coder/v2/scaletest/reconnectingpty"
"github.com/coder/coder/v2/scaletest/workspacebuild"
"github.com/coder/coder/v2/scaletest/workspacetraffic"
@@ -1922,259 +1920,6 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
return cmd
}
func (r *RootCmd) scaletestNotifications() *serpent.Command {
var (
userCount int64
ownerUserPercentage float64
notificationTimeout time.Duration
dialTimeout time.Duration
noCleanup bool
tracingFlags = &scaletestTracingFlags{}
// This test requires unlimited concurrency.
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
cmd := &serpent.Command{
Use: "notifications",
Short: "Simulate notification delivery by creating many users listening to notifications.",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
defer stop()
ctx = notifyCtx
me, err := requireAdmin(ctx, client)
if err != nil {
return err
}
client.HTTPClient = &http.Client{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},
}
if userCount <= 0 {
return xerrors.Errorf("--user-count must be greater than 0")
}
if ownerUserPercentage < 0 || ownerUserPercentage > 100 {
return xerrors.Errorf("--owner-user-percentage must be between 0 and 100")
}
ownerUserCount := int64(float64(userCount) * ownerUserPercentage / 100)
if ownerUserCount == 0 && ownerUserPercentage > 0 {
ownerUserCount = 1
}
regularUserCount := userCount - ownerUserCount
_, _ = fmt.Fprintf(inv.Stderr, "Distribution plan:\n")
_, _ = fmt.Fprintf(inv.Stderr, " Total users: %d\n", userCount)
_, _ = fmt.Fprintf(inv.Stderr, " Owner users: %d (%.1f%%)\n", ownerUserCount, ownerUserPercentage)
_, _ = fmt.Fprintf(inv.Stderr, " Regular users: %d (%.1f%%)\n", regularUserCount, 100.0-ownerUserPercentage)
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("could not parse --output flags")
}
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
tracer := tracerProvider.Tracer(scaletestTracerName)
reg := prometheus.NewRegistry()
metrics := notifications.NewMetrics(reg)
logger := inv.Logger
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
defer func() {
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
if err := closeTracing(ctx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
}
// Wait for prometheus metrics to be scraped
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
_, _ = fmt.Fprintln(inv.Stderr, "Creating users...")
dialBarrier := &sync.WaitGroup{}
ownerWatchBarrier := &sync.WaitGroup{}
dialBarrier.Add(int(userCount))
ownerWatchBarrier.Add(int(ownerUserCount))
expectedNotifications := map[uuid.UUID]chan time.Time{
notificationsLib.TemplateUserAccountCreated: make(chan time.Time, 1),
notificationsLib.TemplateUserAccountDeleted: make(chan time.Time, 1),
}
configs := make([]notifications.Config, 0, userCount)
for range ownerUserCount {
config := notifications.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Roles: []string{codersdk.RoleOwner},
NotificationTimeout: notificationTimeout,
DialTimeout: dialTimeout,
DialBarrier: dialBarrier,
ReceivingWatchBarrier: ownerWatchBarrier,
ExpectedNotifications: expectedNotifications,
Metrics: metrics,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
configs = append(configs, config)
}
for range regularUserCount {
config := notifications.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Roles: []string{},
NotificationTimeout: notificationTimeout,
DialTimeout: dialTimeout,
DialBarrier: dialBarrier,
ReceivingWatchBarrier: ownerWatchBarrier,
Metrics: metrics,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
configs = append(configs, config)
}
go triggerUserNotifications(
ctx,
logger,
client,
me.OrganizationIDs[0],
dialBarrier,
dialTimeout,
expectedNotifications,
)
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for i, config := range configs {
id := strconv.Itoa(i)
name := fmt.Sprintf("notifications-%s", id)
var runner harness.Runnable = notifications.NewRunner(client, config)
if tracingEnabled {
runner = &runnableTraceWrapper{
tracer: tracer,
spanName: name,
runner: runner,
}
}
th.AddRun(name, id, runner)
}
_, _ = fmt.Fprintln(inv.Stderr, "Running notification delivery scaletest...")
testCtx, testCancel := timeoutStrategy.toContext(ctx)
defer testCancel()
err = th.Run(testCtx)
if err != nil {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if !noCleanup {
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
err = th.Cleanup(cleanupCtx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "user-count",
FlagShorthand: "c",
Env: "CODER_SCALETEST_NOTIFICATION_USER_COUNT",
Description: "Required: Total number of users to create.",
Value: serpent.Int64Of(&userCount),
Required: true,
},
{
Flag: "owner-user-percentage",
Env: "CODER_SCALETEST_NOTIFICATION_OWNER_USER_PERCENTAGE",
Default: "20.0",
Description: "Percentage of users to assign Owner role to (0-100).",
Value: serpent.Float64Of(&ownerUserPercentage),
},
{
Flag: "notification-timeout",
Env: "CODER_SCALETEST_NOTIFICATION_TIMEOUT",
Default: "5m",
Description: "How long to wait for notifications after triggering.",
Value: serpent.DurationOf(&notificationTimeout),
},
{
Flag: "dial-timeout",
Env: "CODER_SCALETEST_DIAL_TIMEOUT",
Default: "2m",
Description: "Timeout for dialing the notification websocket endpoint.",
Value: serpent.DurationOf(&dialTimeout),
},
{
Flag: "no-cleanup",
Env: "CODER_SCALETEST_NO_CLEANUP",
Description: "Do not clean up resources after the test completes.",
Value: serpent.BoolOf(&noCleanup),
},
}
tracingFlags.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
output.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
return cmd
}
type runnableTraceWrapper struct {
tracer trace.Tracer
spanName string
@@ -2184,8 +1929,9 @@ type runnableTraceWrapper struct {
}
var (
_ harness.Runnable = &runnableTraceWrapper{}
_ harness.Cleanable = &runnableTraceWrapper{}
_ harness.Runnable = &runnableTraceWrapper{}
_ harness.Cleanable = &runnableTraceWrapper{}
_ harness.Collectable = &runnableTraceWrapper{}
)
func (r *runnableTraceWrapper) Run(ctx context.Context, id string, logs io.Writer) error {
@@ -2227,6 +1973,14 @@ func (r *runnableTraceWrapper) Cleanup(ctx context.Context, id string, logs io.W
return c.Cleanup(ctx, id, logs)
}
// GetMetrics forwards to the wrapped runner when it implements
// harness.Collectable. A wrapped runner that does not collect metrics yields
// a nil map.
func (r *runnableTraceWrapper) GetMetrics() map[string]any {
	if collectable, ok := r.runner.(harness.Collectable); ok {
		return collectable.GetMetrics()
	}
	return nil
}
func getScaletestWorkspaces(ctx context.Context, client *codersdk.Client, owner, template string) ([]codersdk.Workspace, int, error) {
var (
pageNumber = 0
@@ -2375,73 +2129,6 @@ func parseTargetRange(name, targets string) (start, end int, err error) {
return start, end, nil
}
// triggerUserNotifications waits for all test users to connect,
// then creates and deletes a test user to trigger notification events for testing.
func triggerUserNotifications(
ctx context.Context,
logger slog.Logger,
client *codersdk.Client,
orgID uuid.UUID,
dialBarrier *sync.WaitGroup,
dialTimeout time.Duration,
expectedNotifications map[uuid.UUID]chan time.Time,
) {
logger.Info(ctx, "waiting for all users to connect")
// Wait for all users to connect
waitCtx, cancel := context.WithTimeout(ctx, dialTimeout+30*time.Second)
defer cancel()
done := make(chan struct{})
go func() {
dialBarrier.Wait()
close(done)
}()
select {
case <-done:
logger.Info(ctx, "all users connected")
case <-waitCtx.Done():
if waitCtx.Err() == context.DeadlineExceeded {
logger.Error(ctx, "timeout waiting for users to connect")
} else {
logger.Info(ctx, "context canceled while waiting for users")
}
return
}
const (
triggerUsername = "scaletest-trigger-user"
triggerEmail = "scaletest-trigger@example.com"
)
logger.Info(ctx, "creating test user to test notifications",
slog.F("username", triggerUsername),
slog.F("email", triggerEmail),
slog.F("org_id", orgID))
testUser, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{orgID},
Username: triggerUsername,
Email: triggerEmail,
Password: "test-password-123",
})
if err != nil {
logger.Error(ctx, "create test user", slog.Error(err))
return
}
expectedNotifications[notificationsLib.TemplateUserAccountCreated] <- time.Now()
err = client.DeleteUser(ctx, testUser.ID)
if err != nil {
logger.Error(ctx, "delete test user", slog.Error(err))
return
}
expectedNotifications[notificationsLib.TemplateUserAccountDeleted] <- time.Now()
close(expectedNotifications[notificationsLib.TemplateUserAccountCreated])
close(expectedNotifications[notificationsLib.TemplateUserAccountDeleted])
}
func createWorkspaceAppConfig(client *codersdk.Client, appHost, app string, workspace codersdk.Workspace, agent codersdk.WorkspaceAgent) (workspacetraffic.AppConfig, error) {
if app == "" {
return workspacetraffic.AppConfig{}, nil
+447
View File
@@ -0,0 +1,447 @@
//go:build !slim
package cli
import (
"context"
"fmt"
"net/http"
"os/signal"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/xerrors"
"cdr.dev/slog"
notificationsLib "github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/createusers"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/coder/v2/scaletest/notifications"
"github.com/coder/serpent"
)
// scaletestNotifications builds the `notifications` scaletest subcommand.
//
// The handler plans --user-count runners, of which --owner-user-percentage
// are given the Owner role and are the only ones configured to expect the
// "user account created" and "user account deleted" notifications. A
// background goroutine (triggerUserNotifications) creates and deletes a
// throwaway user once every runner has passed the dial barrier, recording the
// trigger time of each notification. When --smtp-api-url is set, its value is
// handed to every runner through Config.SMTPApiURL. After the harness has
// finished, computeNotificationLatencies combines the trigger times with the
// receipt times reported by each run, the results are written to the
// configured outputs, and (unless --no-cleanup is given) the harness cleanup
// is run.
//
// The handler returns an error for invalid flags, for setup or harness
// failures, when interrupted by a stop signal, or when at least one run
// failed.
func (r *RootCmd) scaletestNotifications() *serpent.Command {
	var (
		userCount           int64
		ownerUserPercentage float64
		notificationTimeout time.Duration
		dialTimeout         time.Duration
		noCleanup           bool
		smtpAPIURL          string

		tracingFlags = &scaletestTracingFlags{}

		// This test requires unlimited concurrency.
		timeoutStrategy = &timeoutFlags{}
		cleanupStrategy = newScaletestCleanupStrategy()
		output          = &scaletestOutputFlags{}
		prometheusFlags = &scaletestPrometheusFlags{}
	)

	cmd := &serpent.Command{
		Use:   "notifications",
		Short: "Simulate notification delivery by creating many users listening to notifications.",
		Handler: func(inv *serpent.Invocation) error {
			ctx := inv.Context()
			client, err := r.InitClient(inv)
			if err != nil {
				return err
			}

			// From here on, a stop signal cancels ctx.
			notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
			defer stop()
			ctx = notifyCtx

			me, err := requireAdmin(ctx, client)
			if err != nil {
				return err
			}

			// Every request made through this client carries the
			// rate-limit bypass header.
			client.HTTPClient = &http.Client{
				Transport: &codersdk.HeaderTransport{
					Transport: http.DefaultTransport,
					Header: map[string][]string{
						codersdk.BypassRatelimitHeader: {"true"},
					},
				},
			}

			if userCount <= 0 {
				return xerrors.Errorf("--user-count must be greater than 0")
			}
			if ownerUserPercentage < 0 || ownerUserPercentage > 100 {
				return xerrors.Errorf("--owner-user-percentage must be between 0 and 100")
			}
			if smtpAPIURL != "" && !strings.HasPrefix(smtpAPIURL, "http://") && !strings.HasPrefix(smtpAPIURL, "https://") {
				return xerrors.Errorf("--smtp-api-url must start with http:// or https://")
			}
			// All test users are created in the admin's first organization;
			// fail with a clear error instead of panicking on an empty slice.
			if len(me.OrganizationIDs) == 0 {
				return xerrors.New("admin user does not belong to any organization")
			}
			orgID := me.OrganizationIDs[0]

			// A non-zero percentage always yields at least one owner, even
			// when the truncating conversion rounds the count down to zero.
			ownerUserCount := int64(float64(userCount) * ownerUserPercentage / 100)
			if ownerUserCount == 0 && ownerUserPercentage > 0 {
				ownerUserCount = 1
			}
			regularUserCount := userCount - ownerUserCount

			_, _ = fmt.Fprintf(inv.Stderr, "Distribution plan:\n")
			_, _ = fmt.Fprintf(inv.Stderr, " Total users: %d\n", userCount)
			_, _ = fmt.Fprintf(inv.Stderr, " Owner users: %d (%.1f%%)\n", ownerUserCount, ownerUserPercentage)
			_, _ = fmt.Fprintf(inv.Stderr, " Regular users: %d (%.1f%%)\n", regularUserCount, 100.0-ownerUserPercentage)

			outputs, err := output.parse()
			if err != nil {
				// Keep the underlying cause so the user can see which
				// --output value was rejected.
				return xerrors.Errorf("could not parse --output flags: %w", err)
			}

			tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
			if err != nil {
				return xerrors.Errorf("create tracer provider: %w", err)
			}
			tracer := tracerProvider.Tracer(scaletestTracerName)

			reg := prometheus.NewRegistry()
			metrics := notifications.NewMetrics(reg)

			logger := inv.Logger
			prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
			defer prometheusSrvClose()

			defer func() {
				_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
				if err := closeTracing(ctx); err != nil {
					_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
				}
				// Wait for prometheus metrics to be scraped
				_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
				<-time.After(prometheusFlags.Wait)
			}()

			_, _ = fmt.Fprintln(inv.Stderr, "Creating users...")

			// dialBarrier counts every runner; ownerWatchBarrier counts only
			// the owner runners.
			dialBarrier := &sync.WaitGroup{}
			ownerWatchBarrier := &sync.WaitGroup{}
			dialBarrier.Add(int(userCount))
			ownerWatchBarrier.Add(int(ownerUserCount))

			expectedNotificationIDs := map[uuid.UUID]struct{}{
				notificationsLib.TemplateUserAccountCreated: {},
				notificationsLib.TemplateUserAccountDeleted: {},
			}

			// One buffered channel per expected notification, so the trigger
			// goroutine can hand over its timestamp without blocking.
			triggerTimes := make(map[uuid.UUID]chan time.Time, len(expectedNotificationIDs))
			for id := range expectedNotificationIDs {
				triggerTimes[id] = make(chan time.Time, 1)
			}

			// newConfig builds and validates one runner config. Owner and
			// regular runners differ only in their roles and in whether they
			// expect notifications.
			newConfig := func(roles []string, expected map[uuid.UUID]struct{}) (notifications.Config, error) {
				config := notifications.Config{
					User: createusers.Config{
						OrganizationID: orgID,
					},
					Roles:                    roles,
					NotificationTimeout:      notificationTimeout,
					DialTimeout:              dialTimeout,
					DialBarrier:              dialBarrier,
					ReceivingWatchBarrier:    ownerWatchBarrier,
					ExpectedNotificationsIDs: expected,
					Metrics:                  metrics,
					SMTPApiURL:               smtpAPIURL,
				}
				if err := config.Validate(); err != nil {
					return notifications.Config{}, xerrors.Errorf("validate config: %w", err)
				}
				return config, nil
			}

			configs := make([]notifications.Config, 0, userCount)
			for range ownerUserCount {
				config, err := newConfig([]string{codersdk.RoleOwner}, expectedNotificationIDs)
				if err != nil {
					return err
				}
				configs = append(configs, config)
			}
			for range regularUserCount {
				// Regular users expect no notifications.
				config, err := newConfig([]string{}, nil)
				if err != nil {
					return err
				}
				configs = append(configs, config)
			}

			go triggerUserNotifications(
				ctx,
				logger,
				client,
				orgID,
				dialBarrier,
				dialTimeout,
				triggerTimes,
			)

			th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())

			for i, config := range configs {
				id := strconv.Itoa(i)
				name := fmt.Sprintf("notifications-%s", id)
				var runner harness.Runnable = notifications.NewRunner(client, config)
				if tracingEnabled {
					runner = &runnableTraceWrapper{
						tracer:   tracer,
						spanName: name,
						runner:   runner,
					}
				}

				th.AddRun(name, id, runner)
			}

			_, _ = fmt.Fprintln(inv.Stderr, "Running notification delivery scaletest...")
			testCtx, testCancel := timeoutStrategy.toContext(ctx)
			defer testCancel()
			err = th.Run(testCtx)
			if err != nil {
				return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
			}

			// If the command was interrupted, skip stats.
			if notifyCtx.Err() != nil {
				return notifyCtx.Err()
			}

			res := th.Results()

			if err := computeNotificationLatencies(ctx, logger, triggerTimes, res, metrics); err != nil {
				return xerrors.Errorf("compute notification latencies: %w", err)
			}

			for _, o := range outputs {
				err = o.write(res, inv.Stdout)
				if err != nil {
					return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
				}
			}

			if !noCleanup {
				_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
				cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
				defer cleanupCancel()
				err = th.Cleanup(cleanupCtx)
				if err != nil {
					return xerrors.Errorf("cleanup tests: %w", err)
				}
			}

			if res.TotalFail > 0 {
				return xerrors.New("load test failed, see above for more details")
			}

			return nil
		},
	}

	cmd.Options = serpent.OptionSet{
		{
			Flag:          "user-count",
			FlagShorthand: "c",
			Env:           "CODER_SCALETEST_NOTIFICATION_USER_COUNT",
			Description:   "Required: Total number of users to create.",
			Value:         serpent.Int64Of(&userCount),
			Required:      true,
		},
		{
			Flag:        "owner-user-percentage",
			Env:         "CODER_SCALETEST_NOTIFICATION_OWNER_USER_PERCENTAGE",
			Default:     "20.0",
			Description: "Percentage of users to assign Owner role to (0-100).",
			Value:       serpent.Float64Of(&ownerUserPercentage),
		},
		{
			Flag:        "notification-timeout",
			Env:         "CODER_SCALETEST_NOTIFICATION_TIMEOUT",
			Default:     "5m",
			Description: "How long to wait for notifications after triggering.",
			Value:       serpent.DurationOf(&notificationTimeout),
		},
		{
			Flag:        "dial-timeout",
			Env:         "CODER_SCALETEST_DIAL_TIMEOUT",
			Default:     "2m",
			Description: "Timeout for dialing the notification websocket endpoint.",
			Value:       serpent.DurationOf(&dialTimeout),
		},
		{
			Flag:        "no-cleanup",
			Env:         "CODER_SCALETEST_NO_CLEANUP",
			Description: "Do not clean up resources after the test completes.",
			Value:       serpent.BoolOf(&noCleanup),
		},
		{
			Flag:        "smtp-api-url",
			Env:         "CODER_SCALETEST_SMTP_API_URL",
			Description: "SMTP mock HTTP API address.",
			Value:       serpent.StringOf(&smtpAPIURL),
		},
	}

	tracingFlags.attach(&cmd.Options)
	timeoutStrategy.attach(&cmd.Options)
	cleanupStrategy.attach(&cmd.Options)
	output.attach(&cmd.Options)
	prometheusFlags.attach(&cmd.Options)
	return cmd
}
// computeNotificationLatencies converts trigger times and per-run receipt
// times into latency observations on metrics.
//
// expectedNotifications maps a notification template ID to the channel its
// trigger time is delivered on. Each channel is polled exactly once without
// blocking; an ID whose channel has nothing to read is logged and left out.
// If no trigger time is available at all, the function logs a warning and
// returns.
//
// Runs that failed or carry no metrics are skipped. For every other run the
// websocket receipt times are processed first, then the SMTP receipt times;
// for each notification ID with a known trigger time,
// receiptTime - triggerTime is recorded under the matching notification type.
//
// The returned error is always nil.
func computeNotificationLatencies(
	ctx context.Context,
	logger slog.Logger,
	expectedNotifications map[uuid.UUID]chan time.Time,
	results harness.Results,
	metrics *notifications.Metrics,
) error {
	triggeredAt := make(map[uuid.UUID]time.Time)
	for id, ch := range expectedNotifications {
		// Non-blocking receive: never stall on a notification that was not
		// triggered.
		select {
		case ts := <-ch:
			triggeredAt[id] = ts
			logger.Info(ctx, "received trigger time",
				slog.F("notification_id", id),
				slog.F("trigger_time", ts))
		default:
			logger.Warn(ctx, "no trigger time received for notification",
				slog.F("notification_id", id))
		}
	}

	if len(triggeredAt) == 0 {
		logger.Warn(ctx, "no trigger times available, skipping latency computation")
		return nil
	}

	// One entry per delivery channel, handled in this order for every run.
	sources := []struct {
		metricKey string
		kind      notifications.NotificationType
		logMsg    string
	}{
		{
			metricKey: notifications.WebsocketNotificationReceiptTimeMetric,
			kind:      notifications.NotificationTypeWebsocket,
			logMsg:    "computed websocket latency",
		},
		{
			metricKey: notifications.SMTPNotificationReceiptTimeMetric,
			kind:      notifications.NotificationTypeSMTP,
			logMsg:    "computed SMTP latency",
		},
	}

	computed := 0
	for runID, run := range results.Runs {
		if run.Error != nil {
			logger.Debug(ctx, "skipping failed run for latency computation",
				slog.F("run_id", runID))
			continue
		}
		if run.Metrics == nil {
			continue
		}

		for _, src := range sources {
			// A missing key or a value of another type means this run
			// reported nothing for this delivery channel.
			receipts, ok := run.Metrics[src.metricKey].(map[uuid.UUID]time.Time)
			if !ok {
				continue
			}
			for id, receivedAt := range receipts {
				start, known := triggeredAt[id]
				if !known {
					continue
				}
				latency := receivedAt.Sub(start)
				metrics.RecordLatency(latency, id.String(), src.kind)
				computed++
				logger.Debug(ctx, src.logMsg,
					slog.F("run_id", runID),
					slog.F("notification_id", id),
					slog.F("latency", latency))
			}
		}
	}

	logger.Info(ctx, "finished computing notification latencies",
		slog.F("total_runs", results.TotalRuns),
		slog.F("total_latencies_computed", computed))
	return nil
}
// triggerUserNotifications blocks until dialBarrier is released, then creates
// and immediately deletes a throwaway user so that the "user account created"
// and "user account deleted" notifications are emitted.
//
// The wait on dialBarrier is bounded by dialTimeout plus 30 seconds and by
// ctx; if either fires first, the function logs and returns without
// triggering anything. After each successful API call the current time is
// sent on the matching channel in expectedNotifications. Both channels are
// closed only when both calls succeeded; on an API error the function logs
// and returns, leaving the channels open.
func triggerUserNotifications(
	ctx context.Context,
	logger slog.Logger,
	client *codersdk.Client,
	orgID uuid.UUID,
	dialBarrier *sync.WaitGroup,
	dialTimeout time.Duration,
	expectedNotifications map[uuid.UUID]chan time.Time,
) {
	const (
		triggerUsername = "scaletest-trigger-user"
		triggerEmail    = "scaletest-trigger@example.com"
	)

	logger.Info(ctx, "waiting for all users to connect")

	// Wait for all users to connect
	waitCtx, cancel := context.WithTimeout(ctx, dialTimeout+30*time.Second)
	defer cancel()

	// Bridge the WaitGroup to a channel so it can take part in the select.
	allDialed := make(chan struct{})
	go func() {
		dialBarrier.Wait()
		close(allDialed)
	}()

	select {
	case <-waitCtx.Done():
		switch waitCtx.Err() {
		case context.DeadlineExceeded:
			logger.Error(ctx, "timeout waiting for users to connect")
		default:
			logger.Info(ctx, "context canceled while waiting for users")
		}
		return
	case <-allDialed:
		logger.Info(ctx, "all users connected")
	}

	createdCh := expectedNotifications[notificationsLib.TemplateUserAccountCreated]
	deletedCh := expectedNotifications[notificationsLib.TemplateUserAccountDeleted]

	logger.Info(ctx, "creating test user to test notifications",
		slog.F("username", triggerUsername),
		slog.F("email", triggerEmail),
		slog.F("org_id", orgID))

	req := codersdk.CreateUserRequestWithOrgs{
		OrganizationIDs: []uuid.UUID{orgID},
		Username:        triggerUsername,
		Email:           triggerEmail,
		Password:        "test-password-123",
	}
	testUser, err := client.CreateUserWithOrgs(ctx, req)
	if err != nil {
		logger.Error(ctx, "create test user", slog.Error(err))
		return
	}
	// The timestamp is taken after the create call has returned.
	createdCh <- time.Now()

	if err := client.DeleteUser(ctx, testUser.ID); err != nil {
		logger.Error(ctx, "delete test user", slog.Error(err))
		return
	}
	// The timestamp is taken after the delete call has returned.
	deletedCh <- time.Now()

	close(createdCh)
	close(deletedCh)
}
+5 -3
View File
@@ -24,9 +24,8 @@ type Config struct {
// DialTimeout is how long to wait for websocket connection.
DialTimeout time.Duration `json:"dial_timeout"`
// ExpectedNotifications maps notification template IDs to channels
// that receive the trigger time for each notification.
ExpectedNotifications map[uuid.UUID]chan time.Time `json:"-"`
// ExpectedNotificationsIDs is the list of notification template IDs to expect.
ExpectedNotificationsIDs map[uuid.UUID]struct{} `json:"-"`
Metrics *Metrics `json:"-"`
@@ -35,6 +34,9 @@ type Config struct {
// ReceivingWatchBarrier is the barrier for receiving users. Regular users wait on this to disconnect after receiving users complete.
ReceivingWatchBarrier *sync.WaitGroup `json:"-"`
// SMTPApiURL is the URL of the SMTP mock HTTP API.
SMTPApiURL string `json:"smtp_api_url"`
}
func (c Config) Validate() error {
+14 -19
View File
@@ -6,10 +6,16 @@ import (
"github.com/prometheus/client_golang/prometheus"
)
// NotificationType names the delivery channel a notification arrived on. Its
// string value is used as the `notification_type` label when a latency is
// recorded.
type NotificationType string
const (
// NotificationTypeWebsocket labels notifications received over the websocket.
NotificationTypeWebsocket NotificationType = "websocket"
// NotificationTypeSMTP labels notifications received via SMTP.
NotificationTypeSMTP NotificationType = "smtp"
)
type Metrics struct {
notificationLatency *prometheus.HistogramVec
notificationErrors *prometheus.CounterVec
missedNotifications *prometheus.CounterVec
}
func NewMetrics(reg prometheus.Registerer) *Metrics {
@@ -22,37 +28,26 @@ func NewMetrics(reg prometheus.Registerer) *Metrics {
Subsystem: "scaletest",
Name: "notification_delivery_latency_seconds",
Help: "Time between notification-creating action and receipt of notification by client",
}, []string{"username", "notification_type"})
}, []string{"notification_id", "notification_type"})
errors := prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "notification_delivery_errors_total",
Help: "Total number of notification delivery errors",
}, []string{"username", "action"})
missed := prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "notification_delivery_missed_total",
Help: "Total number of missed notifications",
}, []string{"username"})
}, []string{"action"})
reg.MustRegister(latency, errors, missed)
reg.MustRegister(latency, errors)
return &Metrics{
notificationLatency: latency,
notificationErrors: errors,
missedNotifications: missed,
}
}
func (m *Metrics) RecordLatency(latency time.Duration, username, notificationType string) {
m.notificationLatency.WithLabelValues(username, notificationType).Observe(latency.Seconds())
func (m *Metrics) RecordLatency(latency time.Duration, notificationID string, notificationType NotificationType) {
m.notificationLatency.WithLabelValues(notificationID, string(notificationType)).Observe(latency.Seconds())
}
func (m *Metrics) AddError(username, action string) {
m.notificationErrors.WithLabelValues(username, action).Inc()
}
func (m *Metrics) RecordMissed(username string) {
m.missedNotifications.WithLabelValues(username).Inc()
func (m *Metrics) AddError(action string) {
m.notificationErrors.WithLabelValues(action).Inc()
}
+162 -31
View File
@@ -3,12 +3,16 @@ package notifications
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"maps"
"net/http"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"cdr.dev/slog"
@@ -19,6 +23,8 @@ import (
"github.com/coder/coder/v2/scaletest/createusers"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/coder/v2/scaletest/loadtestutil"
"github.com/coder/coder/v2/scaletest/smtpmock"
"github.com/coder/quartz"
"github.com/coder/websocket"
)
@@ -28,18 +34,32 @@ type Runner struct {
createUserRunner *createusers.Runner
// notificationLatencies stores the latency for each notification type
notificationLatencies map[uuid.UUID]time.Duration
// websocketReceiptTimes stores the receipt time for websocket notifications
websocketReceiptTimes map[uuid.UUID]time.Time
websocketReceiptTimesMu sync.RWMutex
// smtpReceiptTimes stores the receipt time for SMTP notifications
smtpReceiptTimes map[uuid.UUID]time.Time
smtpReceiptTimesMu sync.RWMutex
clock quartz.Clock
}
func NewRunner(client *codersdk.Client, cfg Config) *Runner {
return &Runner{
client: client,
cfg: cfg,
notificationLatencies: make(map[uuid.UUID]time.Duration),
websocketReceiptTimes: make(map[uuid.UUID]time.Time),
smtpReceiptTimes: make(map[uuid.UUID]time.Time),
clock: quartz.NewReal(),
}
}
// WithClock replaces the runner's clock with the given one and returns the
// same runner, so the call can be chained onto NewRunner.
func (r *Runner) WithClock(clock quartz.Clock) *Runner {
r.clock = clock
return r
}
var (
_ harness.Runnable = &Runner{}
_ harness.Cleanable = &Runner{}
@@ -59,7 +79,7 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
reachedReceivingWatchBarrier := false
defer func() {
if len(r.cfg.ExpectedNotifications) > 0 && !reachedReceivingWatchBarrier {
if len(r.cfg.ExpectedNotificationsIDs) > 0 && !reachedReceivingWatchBarrier {
r.cfg.ReceivingWatchBarrier.Done()
}
}()
@@ -72,7 +92,7 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
r.createUserRunner = createusers.NewRunner(r.client, r.cfg.User)
newUserAndToken, err := r.createUserRunner.RunReturningUser(ctx, id, logs)
if err != nil {
r.cfg.Metrics.AddError("", "create_user")
r.cfg.Metrics.AddError("create_user")
return xerrors.Errorf("create user: %w", err)
}
newUser := newUserAndToken.User
@@ -90,7 +110,7 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
Roles: r.cfg.Roles,
})
if err != nil {
r.cfg.Metrics.AddError(newUser.Username, "assign_roles")
r.cfg.Metrics.AddError("assign_roles")
return xerrors.Errorf("assign roles: %w", err)
}
}
@@ -101,7 +121,7 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
defer cancel()
logger.Info(ctx, "connecting to notification websocket")
conn, err := r.dialNotificationWebsocket(dialCtx, newUserClient, newUser, logger)
conn, err := r.dialNotificationWebsocket(dialCtx, newUserClient, logger)
if err != nil {
return xerrors.Errorf("dial notification websocket: %w", err)
}
@@ -112,7 +132,7 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
r.cfg.DialBarrier.Done()
r.cfg.DialBarrier.Wait()
if len(r.cfg.ExpectedNotifications) == 0 {
if len(r.cfg.ExpectedNotificationsIDs) == 0 {
logger.Info(ctx, "maintaining websocket connection, waiting for receiving users to complete")
// Wait for receiving users to complete
@@ -136,7 +156,20 @@ func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
watchCtx, cancel := context.WithTimeout(ctx, r.cfg.NotificationTimeout)
defer cancel()
if err := r.watchNotifications(watchCtx, conn, newUser, logger, r.cfg.ExpectedNotifications); err != nil {
eg, egCtx := errgroup.WithContext(watchCtx)
eg.Go(func() error {
return r.watchNotifications(egCtx, conn, newUser, logger, r.cfg.ExpectedNotificationsIDs)
})
if r.cfg.SMTPApiURL != "" {
logger.Info(ctx, "running SMTP notification watcher")
eg.Go(func() error {
return r.watchNotificationsSMTP(egCtx, newUser, logger, r.cfg.ExpectedNotificationsIDs)
})
}
if err := eg.Wait(); err != nil {
return xerrors.Errorf("notification watch failed: %w", err)
}
@@ -157,19 +190,31 @@ func (r *Runner) Cleanup(ctx context.Context, id string, logs io.Writer) error {
return nil
}
const NotificationDeliveryLatencyMetric = "notification_delivery_latency_seconds"
const (
WebsocketNotificationReceiptTimeMetric = "notification_websocket_receipt_time"
SMTPNotificationReceiptTimeMetric = "notification_smtp_receipt_time"
)
func (r *Runner) GetMetrics() map[string]any {
r.websocketReceiptTimesMu.RLock()
websocketReceiptTimes := maps.Clone(r.websocketReceiptTimes)
r.websocketReceiptTimesMu.RUnlock()
r.smtpReceiptTimesMu.RLock()
smtpReceiptTimes := maps.Clone(r.smtpReceiptTimes)
r.smtpReceiptTimesMu.RUnlock()
return map[string]any{
NotificationDeliveryLatencyMetric: r.notificationLatencies,
WebsocketNotificationReceiptTimeMetric: websocketReceiptTimes,
SMTPNotificationReceiptTimeMetric: smtpReceiptTimes,
}
}
func (r *Runner) dialNotificationWebsocket(ctx context.Context, client *codersdk.Client, user codersdk.User, logger slog.Logger) (*websocket.Conn, error) {
func (r *Runner) dialNotificationWebsocket(ctx context.Context, client *codersdk.Client, logger slog.Logger) (*websocket.Conn, error) {
u, err := client.URL.Parse("/api/v2/notifications/inbox/watch")
if err != nil {
logger.Error(ctx, "parse notification URL", slog.Error(err))
r.cfg.Metrics.AddError(user.Username, "parse_url")
r.cfg.Metrics.AddError("parse_url")
return nil, xerrors.Errorf("parse notification URL: %w", err)
}
@@ -186,7 +231,7 @@ func (r *Runner) dialNotificationWebsocket(ctx context.Context, client *codersdk
}
}
logger.Error(ctx, "dial notification websocket", slog.Error(err))
r.cfg.Metrics.AddError(user.Username, "dial")
r.cfg.Metrics.AddError("dial")
return nil, xerrors.Errorf("dial notification websocket: %w", err)
}
@@ -195,7 +240,7 @@ func (r *Runner) dialNotificationWebsocket(ctx context.Context, client *codersdk
// watchNotifications reads notifications from the websocket and returns error or nil
// once all expected notifications are received.
func (r *Runner) watchNotifications(ctx context.Context, conn *websocket.Conn, user codersdk.User, logger slog.Logger, expectedNotifications map[uuid.UUID]chan time.Time) error {
func (r *Runner) watchNotifications(ctx context.Context, conn *websocket.Conn, user codersdk.User, logger slog.Logger, expectedNotifications map[uuid.UUID]struct{}) error {
logger.Info(ctx, "waiting for notifications",
slog.F("username", user.Username),
slog.F("expected_count", len(expectedNotifications)))
@@ -217,28 +262,23 @@ func (r *Runner) watchNotifications(ctx context.Context, conn *websocket.Conn, u
notif, err := readNotification(ctx, conn)
if err != nil {
logger.Error(ctx, "read notification", slog.Error(err))
r.cfg.Metrics.AddError(user.Username, "read_notification")
r.cfg.Metrics.AddError("read_notification_websocket")
return xerrors.Errorf("read notification: %w", err)
}
templateID := notif.Notification.TemplateID
if triggerTimeChan, exists := expectedNotifications[templateID]; exists {
if _, exists := receivedNotifications[templateID]; !exists {
if _, exists := expectedNotifications[templateID]; exists {
if _, received := receivedNotifications[templateID]; !received {
receiptTime := time.Now()
select {
case triggerTime := <-triggerTimeChan:
latency := receiptTime.Sub(triggerTime)
r.notificationLatencies[templateID] = latency
r.cfg.Metrics.RecordLatency(latency, user.Username, templateID.String())
receivedNotifications[templateID] = struct{}{}
r.websocketReceiptTimesMu.Lock()
r.websocketReceiptTimes[templateID] = receiptTime
r.websocketReceiptTimesMu.Unlock()
receivedNotifications[templateID] = struct{}{}
logger.Info(ctx, "received expected notification",
slog.F("template_id", templateID),
slog.F("title", notif.Notification.Title),
slog.F("latency", latency))
case <-ctx.Done():
return xerrors.Errorf("context canceled while waiting for trigger time: %w", ctx.Err())
}
logger.Info(ctx, "received expected notification",
slog.F("template_id", templateID),
slog.F("title", notif.Notification.Title),
slog.F("receipt_time", receiptTime))
}
} else {
logger.Debug(ctx, "received notification not being tested",
@@ -248,6 +288,97 @@ func (r *Runner) watchNotifications(ctx context.Context, conn *websocket.Conn, u
}
}
// watchNotificationsSMTP queries the SMTP HTTP API for emails addressed to
// user on a fixed 2s interval (driven by r.clock.TickerFunc) and stores the
// receipt time of every expected notification it finds in r.smtpReceiptTimes.
//
// It returns nil once every ID in expectedNotifications has been observed.
// Any failure while polling (building the request, performing it, a non-200
// status, or an undecodable body) is logged, counted in the metrics and
// returned; otherwise the error reported by the ticker's Wait is returned.
func (r *Runner) watchNotificationsSMTP(ctx context.Context, user codersdk.User, logger slog.Logger, expectedNotifications map[uuid.UUID]struct{}) error {
	logger.Info(ctx, "polling SMTP API for notifications",
		slog.F("email", user.Email),
		slog.F("expected_count", len(expectedNotifications)),
	)

	const smtpPollInterval = 2 * time.Second

	// NOTE(review): the email is interpolated into the query string verbatim;
	// addresses containing characters such as '+' or '&' may not round-trip
	// through the server's query parsing — consider escaping it.
	endpoint := fmt.Sprintf("%s/messages?email=%s", r.cfg.SMTPApiURL, user.Email)
	poller := &http.Client{
		Timeout: 10 * time.Second,
	}

	// IDs already recorded, so a notification returned by several consecutive
	// polls is only counted once.
	seen := make(map[uuid.UUID]struct{})

	// Sentinel returned by the poll callback when everything has arrived; it
	// is translated back into a nil error after Wait.
	errAllReceived := xerrors.New("done")

	pollOnce := func() error {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
		if err != nil {
			logger.Error(ctx, "create SMTP API request", slog.Error(err))
			r.cfg.Metrics.AddError("smtp_create_request")
			return xerrors.Errorf("create SMTP API request: %w", err)
		}

		resp, err := poller.Do(req)
		if err != nil {
			logger.Error(ctx, "poll smtp api for notifications", slog.Error(err))
			r.cfg.Metrics.AddError("smtp_poll")
			return xerrors.Errorf("poll smtp api: %w", err)
		}

		if code := resp.StatusCode; code != http.StatusOK {
			// Close errors are ignored: the response is being discarded anyway.
			_ = resp.Body.Close()
			logger.Error(ctx, "smtp api returned non-200 status", slog.F("status", code))
			r.cfg.Metrics.AddError("smtp_bad_status")
			return xerrors.Errorf("smtp api returned status %d", code)
		}

		var emails []smtpmock.EmailSummary
		decodeErr := json.NewDecoder(resp.Body).Decode(&emails)
		// The body is finished with whether or not decoding succeeded; a close
		// error would not change the outcome, so it is ignored.
		_ = resp.Body.Close()
		if decodeErr != nil {
			logger.Error(ctx, "decode smtp api response", slog.Error(decodeErr))
			r.cfg.Metrics.AddError("smtp_decode")
			return xerrors.Errorf("decode smtp api response: %w", decodeErr)
		}

		for _, email := range emails {
			id := email.NotificationTemplateID
			if id == uuid.Nil {
				continue
			}
			if _, wanted := expectedNotifications[id]; !wanted {
				continue
			}
			if _, already := seen[id]; already {
				continue
			}

			// Prefer the timestamp carried by the email summary; fall back to
			// the local time when it is unset.
			receivedAt := email.Date
			if receivedAt.IsZero() {
				receivedAt = time.Now()
			}

			r.smtpReceiptTimesMu.Lock()
			r.smtpReceiptTimes[id] = receivedAt
			r.smtpReceiptTimesMu.Unlock()
			seen[id] = struct{}{}

			logger.Info(ctx, "received expected notification via SMTP",
				slog.F("notification_id", id),
				slog.F("subject", email.Subject),
				slog.F("receipt_time", receivedAt))
		}

		if len(seen) != len(expectedNotifications) {
			// Still waiting on at least one notification; poll again.
			return nil
		}
		logger.Info(ctx, "received all expected notifications via SMTP")
		return errAllReceived
	}

	ticker := r.clock.TickerFunc(ctx, smtpPollInterval, pollOnce, "smtp")
	if waitErr := ticker.Wait(); !errors.Is(waitErr, errAllReceived) {
		return waitErr
	}
	return nil
}
func readNotification(ctx context.Context, conn *websocket.Conn) (codersdk.GetInboxNotificationResponse, error) {
_, message, err := conn.Read(ctx)
if err != nil {
+211 -89
View File
@@ -1,7 +1,11 @@
package notifications_test
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
@@ -9,22 +13,19 @@ import (
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
notificationsLib "github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/dispatch"
"github.com/coder/coder/v2/coderd/notifications/types"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/createusers"
"github.com/coder/coder/v2/scaletest/notifications"
"github.com/coder/coder/v2/scaletest/smtpmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
@@ -36,39 +37,11 @@ func TestRun(t *testing.T) {
logger := testutil.Logger(t)
db, ps := dbtestutil.NewDB(t)
// Setup notifications manager with inbox handler
cfg := defaultNotificationsConfig(database.NotificationMethodSmtp)
mgr, err := notificationsLib.NewManager(
cfg,
db,
ps,
defaultHelpers(),
notificationsLib.NewMetrics(prometheus.NewRegistry()),
logger.Named("manager"),
)
require.NoError(t, err)
mgr.WithHandlers(map[database.NotificationMethod]notificationsLib.Handler{
database.NotificationMethodInbox: dispatch.NewInboxHandler(logger.Named("inbox"), db, ps),
})
t.Cleanup(func() {
assert.NoError(t, mgr.Stop(dbauthz.AsNotifier(ctx)))
})
mgr.Run(dbauthz.AsNotifier(ctx))
enqueuer, err := notificationsLib.NewStoreEnqueuer(
cfg,
db,
defaultHelpers(),
logger.Named("enqueuer"),
quartz.NewReal(),
)
require.NoError(t, err)
inboxHandler := dispatch.NewInboxHandler(logger.Named("inbox"), db, ps)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
NotificationsEnqueuer: enqueuer,
Database: db,
Pubsub: ps,
})
firstUser := coderdtest.CreateFirstUser(t, client)
@@ -82,9 +55,9 @@ func TestRun(t *testing.T) {
eg, runCtx := errgroup.WithContext(ctx)
expectedNotifications := map[uuid.UUID]chan time.Time{
notificationsLib.TemplateUserAccountCreated: make(chan time.Time, 1),
notificationsLib.TemplateUserAccountDeleted: make(chan time.Time, 1),
expectedNotificationsIDs := map[uuid.UUID]struct{}{
notificationsLib.TemplateUserAccountCreated: {},
notificationsLib.TemplateUserAccountDeleted: {},
}
// Start receiving runners who will receive notifications
@@ -93,14 +66,15 @@ func TestRun(t *testing.T) {
runnerCfg := notifications.Config{
User: createusers.Config{
OrganizationID: firstUser.OrganizationID,
Username: "receiving-user-" + strconv.Itoa(i),
},
Roles: []string{codersdk.RoleOwner},
NotificationTimeout: testutil.WaitLong,
DialTimeout: testutil.WaitLong,
Metrics: metrics,
DialBarrier: dialBarrier,
ReceivingWatchBarrier: receivingWatchBarrier,
ExpectedNotifications: expectedNotifications,
Roles: []string{codersdk.RoleOwner},
NotificationTimeout: testutil.WaitLong,
DialTimeout: testutil.WaitLong,
Metrics: metrics,
DialBarrier: dialBarrier,
ReceivingWatchBarrier: receivingWatchBarrier,
ExpectedNotificationsIDs: expectedNotificationsIDs,
}
err := runnerCfg.Validate()
require.NoError(t, err)
@@ -141,31 +115,17 @@ func TestRun(t *testing.T) {
// Wait for all runners to connect
dialBarrier.Wait()
createTime := time.Now()
newUser, err := client.CreateUserWithOrgs(runCtx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{firstUser.OrganizationID},
Email: "test-user@coder.com",
Username: "test-user",
Password: "SomeSecurePassword!",
})
if err != nil {
return xerrors.Errorf("create test user: %w", err)
for i := 0; i < numReceivingUsers; i++ {
err := sendInboxNotification(runCtx, t, db, inboxHandler, "receiving-user-"+strconv.Itoa(i), notificationsLib.TemplateUserAccountCreated)
require.NoError(t, err)
err = sendInboxNotification(runCtx, t, db, inboxHandler, "receiving-user-"+strconv.Itoa(i), notificationsLib.TemplateUserAccountDeleted)
require.NoError(t, err)
}
expectedNotifications[notificationsLib.TemplateUserAccountCreated] <- createTime
deleteTime := time.Now()
if err := client.DeleteUser(runCtx, newUser.ID); err != nil {
return xerrors.Errorf("delete test user: %w", err)
}
expectedNotifications[notificationsLib.TemplateUserAccountDeleted] <- deleteTime
close(expectedNotifications[notificationsLib.TemplateUserAccountCreated])
close(expectedNotifications[notificationsLib.TemplateUserAccountDeleted])
return nil
})
err = eg.Wait()
err := eg.Wait()
require.NoError(t, err, "runner execution should complete successfully")
cleanupEg, cleanupCtx := errgroup.WithContext(ctx)
@@ -188,34 +148,196 @@ func TestRun(t *testing.T) {
require.Equal(t, firstUser.UserID, users.Users[0].ID)
for _, runner := range receivingRunners {
runnerMetrics := runner.GetMetrics()[notifications.NotificationDeliveryLatencyMetric].(map[uuid.UUID]time.Duration)
require.Contains(t, runnerMetrics, notificationsLib.TemplateUserAccountCreated)
require.Contains(t, runnerMetrics, notificationsLib.TemplateUserAccountDeleted)
metrics := runner.GetMetrics()
websocketReceiptTimes := metrics[notifications.WebsocketNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time)
require.Contains(t, websocketReceiptTimes, notificationsLib.TemplateUserAccountCreated)
require.Contains(t, websocketReceiptTimes, notificationsLib.TemplateUserAccountDeleted)
}
}
func defaultNotificationsConfig(method database.NotificationMethod) codersdk.NotificationsConfig {
return codersdk.NotificationsConfig{
Method: serpent.String(method),
MaxSendAttempts: 5,
FetchInterval: serpent.Duration(time.Millisecond * 100),
StoreSyncInterval: serpent.Duration(time.Millisecond * 200),
LeasePeriod: serpent.Duration(time.Second * 10),
DispatchTimeout: serpent.Duration(time.Second * 5),
RetryInterval: serpent.Duration(time.Millisecond * 50),
LeaseCount: 10,
StoreSyncBufferSize: 50,
Inbox: codersdk.NotificationsInboxConfig{
Enabled: serpent.Bool(true),
},
// TestRunWithSMTP runs receiving and regular notification runners against a
// stub SMTP HTTP API and checks that every receiving runner records receipt
// times for the expected notifications on both the websocket and SMTP
// channels, and that cleanup removes all users the runners created.
func TestRunWithSMTP(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)
	logger := testutil.Logger(t)
	db, ps := dbtestutil.NewDB(t)

	inboxHandler := dispatch.NewInboxHandler(logger.Named("inbox"), db, ps)

	client := coderdtest.New(t, &coderdtest.Options{
		Database: db,
		Pubsub:   ps,
	})
	firstUser := coderdtest.CreateFirstUser(t, client)

	// Stub of the SMTP HTTP API: every request to /messages gets one summary
	// per expected notification template, regardless of the query parameters.
	smtpAPIMux := http.NewServeMux()
	smtpAPIMux.HandleFunc("/messages", func(w http.ResponseWriter, _ *http.Request) {
		summaries := []smtpmock.EmailSummary{
			{
				Subject:                "TemplateUserAccountCreated",
				Date:                   time.Now(),
				NotificationTemplateID: notificationsLib.TemplateUserAccountCreated,
			},
			{
				Subject:                "TemplateUserAccountDeleted",
				Date:                   time.Now(),
				NotificationTemplateID: notificationsLib.TemplateUserAccountDeleted,
			},
		}

		w.Header().Set("Content-Type", "application/json")
		// The encode error is deliberately ignored: the handler has no way to
		// report it other than the (possibly truncated) response itself.
		_ = json.NewEncoder(w).Encode(summaries)
	})

	smtpAPIServer := httptest.NewServer(smtpAPIMux)
	defer smtpAPIServer.Close()

	const numReceivingUsers = 2
	const numRegularUsers = 2
	dialBarrier := new(sync.WaitGroup)
	receivingWatchBarrier := new(sync.WaitGroup)
	dialBarrier.Add(numReceivingUsers + numRegularUsers)
	receivingWatchBarrier.Add(numReceivingUsers)
	metrics := notifications.NewMetrics(prometheus.NewRegistry())

	eg, runCtx := errgroup.WithContext(ctx)

	expectedNotificationsIDs := map[uuid.UUID]struct{}{
		notificationsLib.TemplateUserAccountCreated: {},
		notificationsLib.TemplateUserAccountDeleted: {},
	}

	// The receiving runners share a mock clock; the trap intercepts their
	// TickerFunc calls tagged "smtp" so the test controls when polling starts.
	mClock := quartz.NewMock(t)
	smtpTrap := mClock.Trap().TickerFunc("smtp")
	defer smtpTrap.Close()

	// Start receiving runners who will receive notifications
	receivingRunners := make([]*notifications.Runner, 0, numReceivingUsers)
	for i := range numReceivingUsers {
		runnerCfg := notifications.Config{
			User: createusers.Config{
				OrganizationID: firstUser.OrganizationID,
				Username:       "receiving-user-" + strconv.Itoa(i),
			},
			Roles:                    []string{codersdk.RoleOwner},
			NotificationTimeout:      testutil.WaitLong,
			DialTimeout:              testutil.WaitLong,
			Metrics:                  metrics,
			DialBarrier:              dialBarrier,
			ReceivingWatchBarrier:    receivingWatchBarrier,
			ExpectedNotificationsIDs: expectedNotificationsIDs,
			SMTPApiURL:               smtpAPIServer.URL,
		}
		err := runnerCfg.Validate()
		require.NoError(t, err)

		runner := notifications.NewRunner(client, runnerCfg).WithClock(mClock)
		receivingRunners = append(receivingRunners, runner)
		eg.Go(func() error {
			return runner.Run(runCtx, "receiving-"+strconv.Itoa(i), io.Discard)
		})
	}

	// Start regular user runners who will maintain websocket connections
	regularRunners := make([]*notifications.Runner, 0, numRegularUsers)
	for i := range numRegularUsers {
		runnerCfg := notifications.Config{
			User: createusers.Config{
				OrganizationID: firstUser.OrganizationID,
			},
			Roles:                 []string{},
			NotificationTimeout:   testutil.WaitLong,
			DialTimeout:           testutil.WaitLong,
			Metrics:               metrics,
			DialBarrier:           dialBarrier,
			ReceivingWatchBarrier: receivingWatchBarrier,
		}
		err := runnerCfg.Validate()
		require.NoError(t, err)

		runner := notifications.NewRunner(client, runnerCfg)
		regularRunners = append(regularRunners, runner)
		eg.Go(func() error {
			return runner.Run(runCtx, "regular-"+strconv.Itoa(i), io.Discard)
		})
	}

	// Trigger notifications by dispatching inbox notifications directly to
	// each receiving user, then advance the mock clock so the SMTP pollers run.
	eg.Go(func() error {
		// Wait for all runners to connect
		dialBarrier.Wait()

		// Let each receiving runner register its SMTP ticker before the clock
		// is advanced.
		for range numReceivingUsers {
			smtpTrap.MustWait(runCtx).MustRelease(runCtx)
		}

		// Failures are returned through the errgroup instead of asserted with
		// require: FailNow must only be called from the test's own goroutine.
		for i := range numReceivingUsers {
			username := "receiving-user-" + strconv.Itoa(i)
			if err := sendInboxNotification(runCtx, t, db, inboxHandler, username, notificationsLib.TemplateUserAccountCreated); err != nil {
				return err
			}
			if err := sendInboxNotification(runCtx, t, db, inboxHandler, username, notificationsLib.TemplateUserAccountDeleted); err != nil {
				return err
			}
		}

		_, w := mClock.AdvanceNext()
		w.MustWait(runCtx)

		return nil
	})

	err := eg.Wait()
	require.NoError(t, err, "runner execution with SMTP should complete successfully")

	cleanupEg, cleanupCtx := errgroup.WithContext(ctx)
	for i, runner := range receivingRunners {
		cleanupEg.Go(func() error {
			return runner.Cleanup(cleanupCtx, "receiving-"+strconv.Itoa(i), io.Discard)
		})
	}
	for i, runner := range regularRunners {
		cleanupEg.Go(func() error {
			return runner.Cleanup(cleanupCtx, "regular-"+strconv.Itoa(i), io.Discard)
		})
	}
	err = cleanupEg.Wait()
	require.NoError(t, err)

	// After cleanup only the first user should remain.
	users, err := client.Users(ctx, codersdk.UsersRequest{})
	require.NoError(t, err)
	require.Len(t, users.Users, 1)
	require.Equal(t, firstUser.UserID, users.Users[0].ID)

	// Verify that notifications were received via both websocket and SMTP
	for _, runner := range receivingRunners {
		runnerMetrics := runner.GetMetrics()
		websocketReceiptTimes := runnerMetrics[notifications.WebsocketNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time)
		smtpReceiptTimes := runnerMetrics[notifications.SMTPNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time)

		require.Contains(t, websocketReceiptTimes, notificationsLib.TemplateUserAccountCreated)
		require.Contains(t, websocketReceiptTimes, notificationsLib.TemplateUserAccountDeleted)
		require.Contains(t, smtpReceiptTimes, notificationsLib.TemplateUserAccountCreated)
		require.Contains(t, smtpReceiptTimes, notificationsLib.TemplateUserAccountDeleted)
	}
}
func defaultHelpers() map[string]any {
return map[string]any{
"base_url": func() string { return "http://test.com" },
"current_year": func() string { return "2024" },
"logo_url": func() string { return "https://coder.com/coder-logo-horizontal.png" },
"app_name": func() string { return "Coder" },
// sendInboxNotification looks up the user named username in db and dispatches
// an inbox notification for templateID to that user through inboxHandler.
//
// Every failure (user lookup, building the dispatcher, dispatching) is
// returned to the caller instead of being asserted with require: this helper
// is invoked from errgroup goroutines, and FailNow must only be called from
// the goroutine running the test.
func sendInboxNotification(ctx context.Context, t *testing.T, db database.Store, inboxHandler *dispatch.InboxHandler, username string, templateID uuid.UUID) error {
	t.Helper()

	user, err := db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
		Username: username,
	})
	if err != nil {
		return err
	}

	dispatchFunc, err := inboxHandler.Dispatcher(types.MessagePayload{
		UserID:                 user.ID.String(),
		NotificationTemplateID: templateID.String(),
	}, "", "", nil)
	if err != nil {
		return err
	}

	// A fresh UUID is supplied for each dispatch (presumably the message ID —
	// confirm against the dispatch function's signature).
	_, err = dispatchFunc(ctx, uuid.New())
	return err
}