From a44056cff51e1383f21381644da38ad747207bbe Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 24 Jan 2022 11:07:42 -0600 Subject: [PATCH] feat: Add project API endpoints (#51) * feat: Add project models * Add project query functions * Add organization parameter query * Add project URL parameter parse * Add project create and list endpoints * Add test for organization provided * Remove unimplemented routes * Decrease conn timeout * Add test for UnbiasedModulo32 * Fix expected value * Add single user endpoint * Add query for project versions * Fix linting errors * Add comments * Add test for invalid archive * Check unauthenticated endpoints * Add check if no change happened * Ensure context close ends listener * Fix parallel test run * Test empty * Fix organization param comment --- .golangci.yml | 1 + Makefile | 5 +- coderd/cmd/root_test.go | 1 + coderd/coderd.go | 22 ++ coderd/coderdtest/coderdtest_test.go | 1 + coderd/projects.go | 229 +++++++++++ coderd/projects_test.go | 183 +++++++++ coderd/userpassword/userpassword_test.go | 6 + coderd/users_test.go | 1 + codersdk/projects.go | 86 +++++ codersdk/projects_test.go | 130 +++++++ codersdk/users_test.go | 5 + cryptorand/numbers_test.go | 3 + cryptorand/strings_test.go | 1 + database/databasefake/databasefake.go | 122 ++++++ database/dump.sql | 77 ++++ database/migrate_test.go | 33 +- database/migrations/000002_projects.down.sql | 0 database/migrations/000002_projects.up.sql | 84 ++++ database/models.go | 98 +++++ database/pubsub_test.go | 17 + database/querier.go | 9 + database/query.sql | 146 ++++++- database/query.sql.go | 384 ++++++++++++++++++- database/sqlc.yaml | 1 + go.mod | 1 + go.sum | 2 + httpapi/httpapi_test.go | 10 + httpmw/apikey_test.go | 12 + httpmw/organizationparam.go | 86 +++++ httpmw/organizationparam_test.go | 165 ++++++++ httpmw/projectparam.go | 60 +++ httpmw/projectparam_test.go | 151 ++++++++ httpmw/userparam_test.go | 4 + peer/conn_test.go | 1 + peerbroker/listen_test.go | 3 + provisionersdk/serve_test.go | 3 + 37 files changed, 2121 insertions(+), 22 deletions(-) create mode 100644 coderd/projects.go create mode 100644 coderd/projects_test.go create mode 100644 codersdk/projects.go create mode 100644 codersdk/projects_test.go create mode 100644 database/migrations/000002_projects.down.sql create mode 100644 database/migrations/000002_projects.up.sql create mode 100644 httpmw/organizationparam.go create mode 100644 httpmw/organizationparam_test.go create mode 100644 httpmw/projectparam.go create mode 100644 httpmw/projectparam_test.go diff --git a/.golangci.yml b/.golangci.yml index 859e160899..1231a65ddb 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -234,6 +234,7 @@ linters: - misspell - nilnil - noctx + - paralleltest - revive - rowserrcheck - sqlclosecheck diff --git a/Makefile b/Makefile index c255da17de..572afc6109 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ database/dump.sql: $(wildcard database/migrations/*.sql) go run database/dump/main.go # Generates Go code for querying the database. -database/generate: database/dump.sql database/query.sql +database/generate: fmt/sql database/dump.sql database/query.sql cd database && sqlc generate && rm db_tmp.go cd database && gofmt -w -r 'Querier -> querier' *.go cd database && gofmt -w -r 'Queries -> sqlQuerier' *.go @@ -27,12 +27,13 @@ else endif .PHONY: fmt/prettier -fmt/sql: +fmt/sql: ./database/query.sql npx sql-formatter \ --language postgresql \ --lines-between-queries 2 \ ./database/query.sql \ --output ./database/query.sql + sed -i 's/@ /@/g' ./database/query.sql fmt: fmt/prettier fmt/sql .PHONY: fmt diff --git a/coderd/cmd/root_test.go b/coderd/cmd/root_test.go index 59996051e5..caf11ecdb9 100644 --- a/coderd/cmd/root_test.go +++ b/coderd/cmd/root_test.go @@ -10,6 +10,7 @@ import ( ) func TestRoot(t *testing.T) { + t.Parallel() ctx, cancelFunc := context.WithCancel(context.Background()) go cancelFunc() err := cmd.Root().ExecuteContext(ctx) diff --git a/coderd/coderd.go b/coderd/coderd.go index 4ad9463c73..b666df9810 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -20,6 +20,9 @@ type Options struct { // New constructs the Coder API into an HTTP handler. func New(options *Options) http.Handler { + projects := &projects{ + Database: options.Database, + } users := &users{ Database: options.Database, } @@ -44,6 +47,25 @@ func New(options *Options) http.Handler { r.Get("/{user}/organizations", users.userOrganizations) }) }) + r.Route("/projects", func(r chi.Router) { + r.Use( + httpmw.ExtractAPIKey(options.Database, nil), + ) + r.Get("/", projects.allProjects) + r.Route("/{organization}", func(r chi.Router) { + r.Use(httpmw.ExtractOrganizationParam(options.Database)) + r.Get("/", projects.allProjectsForOrganization) + r.Post("/", projects.createProject) + r.Route("/{project}", func(r chi.Router) { + r.Use(httpmw.ExtractProjectParameter(options.Database)) + r.Get("/", projects.project) + r.Route("/versions", func(r chi.Router) { + r.Get("/", projects.projectVersions) + r.Post("/", projects.createProjectVersion) + }) + }) + }) + }) }) r.NotFound(site.Handler().ServeHTTP) return r diff --git a/coderd/coderdtest/coderdtest_test.go b/coderd/coderdtest/coderdtest_test.go index 127b479941..e36d1c1408 100644 --- a/coderd/coderdtest/coderdtest_test.go +++ b/coderd/coderdtest/coderdtest_test.go @@ -13,6 +13,7 @@ func TestMain(m *testing.M) { } func TestNew(t *testing.T) { + t.Parallel() server := coderdtest.New(t) _ = server.RandomInitialUser(t) } diff --git a/coderd/projects.go b/coderd/projects.go new file mode 100644 index 0000000000..be326ee15e --- /dev/null +++ b/coderd/projects.go @@ -0,0 +1,229 @@ +package coderd + +import ( + "archive/tar" + "bytes" + "database/sql" + "errors" + "fmt" + "net/http" + "time" + + "github.com/go-chi/render" + "github.com/google/uuid" + + "github.com/moby/moby/pkg/namesgenerator" + + "github.com/coder/coder/database" + "github.com/coder/coder/httpapi" + "github.com/coder/coder/httpmw" +) + +// Project is the JSON representation of a Coder project. +// This type matches the database object for now, but is +// abstracted for ease of change later on. +type Project database.Project + +// ProjectVersion is the JSON representation of a Coder project version. +type ProjectVersion struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Name string `json:"name"` + StorageMethod database.ProjectStorageMethod `json:"storage_method"` +} + +// CreateProjectRequest enables callers to create a new Project. +type CreateProjectRequest struct { + Name string `json:"name" validate:"username,required"` + Provisioner database.ProvisionerType `json:"provisioner" validate:"oneof=terraform cdr-basic,required"` +} + +// CreateProjectVersionRequest enables callers to create a new Project Version. +type CreateProjectVersionRequest struct { + Name string `json:"name,omitempty" validate:"username"` + StorageMethod database.ProjectStorageMethod `json:"storage_method" validate:"oneof=inline-archive,required"` + StorageSource []byte `json:"storage_source" validate:"max=1048576,required"` +} + +type projects struct { + Database database.Store +} + +// allProjects lists all projects across organizations for a user. +func (p *projects) allProjects(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + organizations, err := p.Database.GetOrganizationsByUserID(r.Context(), apiKey.UserID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get organizations: %s", err.Error()), + }) + return + } + organizationIDs := make([]string, 0, len(organizations)) + for _, organization := range organizations { + organizationIDs = append(organizationIDs, organization.ID) + } + projects, err := p.Database.GetProjectsByOrganizationIDs(r.Context(), organizationIDs) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get projects: %s", err.Error()), + }) + return + } + render.Status(r, http.StatusOK) + render.JSON(rw, r, projects) +} + +// allProjectsForOrganization lists all projects for a specific organization. +func (p *projects) allProjectsForOrganization(rw http.ResponseWriter, r *http.Request) { + organization := httpmw.OrganizationParam(r) + projects, err := p.Database.GetProjectsByOrganizationIDs(r.Context(), []string{organization.ID}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get projects: %s", err.Error()), + }) + return + } + render.Status(r, http.StatusOK) + render.JSON(rw, r, projects) +} + +// createProject makes a new project in an organization. +func (p *projects) createProject(rw http.ResponseWriter, r *http.Request) { + var createProject CreateProjectRequest + if !httpapi.Read(rw, r, &createProject) { + return + } + organization := httpmw.OrganizationParam(r) + _, err := p.Database.GetProjectByOrganizationAndName(r.Context(), database.GetProjectByOrganizationAndNameParams{ + OrganizationID: organization.ID, + Name: createProject.Name, + }) + if err == nil { + httpapi.Write(rw, http.StatusConflict, httpapi.Response{ + Message: fmt.Sprintf("project %q already exists", createProject.Name), + Errors: []httpapi.Error{{ + Field: "name", + Code: "exists", + }}, + }) + return + } + if !errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get project by name: %s", err.Error()), + }) + return + } + + project, err := p.Database.InsertProject(r.Context(), database.InsertProjectParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + OrganizationID: organization.ID, + Name: createProject.Name, + Provisioner: createProject.Provisioner, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("insert project: %s", err), + }) + return + } + render.Status(r, http.StatusCreated) + render.JSON(rw, r, project) +} + +// project returns a single project parsed from the URL path. +func (*projects) project(rw http.ResponseWriter, r *http.Request) { + project := httpmw.ProjectParam(r) + + render.Status(r, http.StatusOK) + render.JSON(rw, r, project) +} + +// projectVersions lists versions for a single project. +func (p *projects) projectVersions(rw http.ResponseWriter, r *http.Request) { + project := httpmw.ProjectParam(r) + + history, err := p.Database.GetProjectHistoryByProjectID(r.Context(), project.ID) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get project history: %s", err), + }) + return + } + versions := make([]ProjectVersion, 0) + for _, version := range history { + versions = append(versions, convertProjectHistory(version)) + } + render.Status(r, http.StatusOK) + render.JSON(rw, r, versions) +} + +func (p *projects) createProjectVersion(rw http.ResponseWriter, r *http.Request) { + var createProjectVersion CreateProjectVersionRequest + if !httpapi.Read(rw, r, &createProjectVersion) { + return + } + + switch createProjectVersion.StorageMethod { + case database.ProjectStorageMethodInlineArchive: + tarReader := tar.NewReader(bytes.NewReader(createProjectVersion.StorageSource)) + _, err := tarReader.Next() + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "the archive must be a tar", + }) + return + } + default: + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("unsupported storage method %s", createProjectVersion.StorageMethod), + }) + return + } + + project := httpmw.ProjectParam(r) + history, err := p.Database.InsertProjectHistory(r.Context(), database.InsertProjectHistoryParams{ + ID: uuid.New(), + ProjectID: project.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + Name: namesgenerator.GetRandomName(1), + StorageMethod: createProjectVersion.StorageMethod, + StorageSource: createProjectVersion.StorageSource, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("insert project history: %s", err), + }) + return + } + + // TODO: A job to process the new version should occur here. + + render.Status(r, http.StatusCreated) + render.JSON(rw, r, convertProjectHistory(history)) +} + +func convertProjectHistory(history database.ProjectHistory) ProjectVersion { + return ProjectVersion{ + ID: history.ID, + ProjectID: history.ProjectID, + CreatedAt: history.CreatedAt, + UpdatedAt: history.UpdatedAt, + Name: history.Name, + } +} diff --git a/coderd/projects_test.go b/coderd/projects_test.go new file mode 100644 index 0000000000..fb653e1701 --- /dev/null +++ b/coderd/projects_test.go @@ -0,0 +1,183 @@ +package coderd_test + +import ( + "archive/tar" + "bytes" + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/database" +) + +func TestProjects(t *testing.T) { + t.Parallel() + + t.Run("Create", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + _, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "someproject", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + }) + + t.Run("AlreadyExists", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + _, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "someproject", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + _, err = server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "someproject", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.Error(t, err) + }) + + t.Run("ListEmpty", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + _ = server.RandomInitialUser(t) + projects, err := server.Client.Projects(context.Background(), "") + require.NoError(t, err) + require.Len(t, projects, 0) + }) + + t.Run("List", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + _, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "someproject", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + // Ensure global query works. + projects, err := server.Client.Projects(context.Background(), "") + require.NoError(t, err) + require.Len(t, projects, 1) + + // Ensure specified query works. + projects, err = server.Client.Projects(context.Background(), user.Organization) + require.NoError(t, err) + require.Len(t, projects, 1) + }) + + t.Run("ListEmpty", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + + projects, err := server.Client.Projects(context.Background(), user.Organization) + require.NoError(t, err) + require.Len(t, projects, 0) + }) + + t.Run("Single", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "someproject", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + _, err = server.Client.Project(context.Background(), user.Organization, project.Name) + require.NoError(t, err) + }) + + t.Run("NoVersions", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "someproject", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + versions, err := server.Client.ProjectVersions(context.Background(), user.Organization, project.Name) + require.NoError(t, err) + require.Len(t, versions, 0) + }) + + t.Run("CreateVersion", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "someproject", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + var buffer bytes.Buffer + writer := tar.NewWriter(&buffer) + err = writer.WriteHeader(&tar.Header{ + Name: "file", + Size: 1 << 10, + }) + require.NoError(t, err) + _, err = writer.Write(make([]byte, 1<<10)) + require.NoError(t, err) + _, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{ + Name: "moo", + StorageMethod: database.ProjectStorageMethodInlineArchive, + StorageSource: buffer.Bytes(), + }) + require.NoError(t, err) + versions, err := server.Client.ProjectVersions(context.Background(), user.Organization, project.Name) + require.NoError(t, err) + require.Len(t, versions, 1) + }) + + t.Run("CreateVersionArchiveTooBig", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "someproject", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + var buffer bytes.Buffer + writer := tar.NewWriter(&buffer) + err = writer.WriteHeader(&tar.Header{ + Name: "file", + Size: 1 << 21, + }) + require.NoError(t, err) + _, err = writer.Write(make([]byte, 1<<21)) + require.NoError(t, err) + _, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{ + Name: "moo", + StorageMethod: database.ProjectStorageMethodInlineArchive, + StorageSource: buffer.Bytes(), + }) + require.Error(t, err) + }) + + t.Run("CreateVersionInvalidArchive", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "someproject", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + _, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{ + Name: "moo", + StorageMethod: database.ProjectStorageMethodInlineArchive, + StorageSource: []byte{}, + }) + require.Error(t, err) + }) +} diff --git a/coderd/userpassword/userpassword_test.go b/coderd/userpassword/userpassword_test.go index a5a879e3df..0546163d24 100644 --- a/coderd/userpassword/userpassword_test.go +++ b/coderd/userpassword/userpassword_test.go @@ -9,7 +9,9 @@ import ( ) func TestUserPassword(t *testing.T) { + t.Parallel() t.Run("Legacy", func(t *testing.T) { + t.Parallel() // Ensures legacy v1 passwords function for v2. // This has is manually generated using a print statement from v1 code. equal, err := userpassword.Compare("$pbkdf2-sha256$65535$z8c1p1C2ru9EImBP1I+ZNA$pNjE3Yk0oG0PmJ0Je+y7ENOVlSkn/b0BEqqdKsq6Y97wQBq0xT+lD5bWJpyIKJqQICuPZcEaGDKrXJn8+SIHRg", "tomato") @@ -18,6 +20,7 @@ func TestUserPassword(t *testing.T) { }) t.Run("Same", func(t *testing.T) { + t.Parallel() hash, err := userpassword.Hash("password") require.NoError(t, err) equal, err := userpassword.Compare(hash, "password") @@ -26,6 +29,7 @@ func TestUserPassword(t *testing.T) { }) t.Run("Different", func(t *testing.T) { + t.Parallel() hash, err := userpassword.Hash("password") require.NoError(t, err) equal, err := userpassword.Compare(hash, "notpassword") @@ -34,12 +38,14 @@ func TestUserPassword(t *testing.T) { }) t.Run("Invalid", func(t *testing.T) { + t.Parallel() equal, err := userpassword.Compare("invalidhash", "password") require.False(t, equal) require.Error(t, err) }) t.Run("InvalidParts", func(t *testing.T) { + t.Parallel() equal, err := userpassword.Compare("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz", "test") require.False(t, equal) require.Error(t, err) diff --git a/coderd/users_test.go b/coderd/users_test.go index cd6deda103..59602f2592 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -35,6 +35,7 @@ func TestUsers(t *testing.T) { }) t.Run("Login", func(t *testing.T) { + t.Parallel() server := coderdtest.New(t) user := server.RandomInitialUser(t) _, err := server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{ diff --git a/codersdk/projects.go b/codersdk/projects.go new file mode 100644 index 0000000000..cb3806d915 --- /dev/null +++ b/codersdk/projects.go @@ -0,0 +1,86 @@ +package codersdk + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/coder/coder/coderd" +) + +// Projects lists projects inside an organization. +// If organization is an empty string, all projects will be returned +// for the authenticated user. +func (c *Client) Projects(ctx context.Context, organization string) ([]coderd.Project, error) { + route := "/api/v2/projects" + if organization != "" { + route = fmt.Sprintf("/api/v2/projects/%s", organization) + } + res, err := c.request(ctx, http.MethodGet, route, nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, readBodyAsError(res) + } + var projects []coderd.Project + return projects, json.NewDecoder(res.Body).Decode(&projects) +} + +// Project returns a single project. +func (c *Client) Project(ctx context.Context, organization, project string) (coderd.Project, error) { + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/projects/%s/%s", organization, project), nil) + if err != nil { + return coderd.Project{}, nil + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return coderd.Project{}, readBodyAsError(res) + } + var resp coderd.Project + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// CreateProject creates a new project inside an organization. +func (c *Client) CreateProject(ctx context.Context, organization string, request coderd.CreateProjectRequest) (coderd.Project, error) { + res, err := c.request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/projects/%s", organization), request) + if err != nil { + return coderd.Project{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return coderd.Project{}, readBodyAsError(res) + } + var project coderd.Project + return project, json.NewDecoder(res.Body).Decode(&project) +} + +// ProjectVersions lists history for a project. +func (c *Client) ProjectVersions(ctx context.Context, organization, project string) ([]coderd.ProjectVersion, error) { + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/projects/%s/%s/versions", organization, project), nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, readBodyAsError(res) + } + var projectVersions []coderd.ProjectVersion + return projectVersions, json.NewDecoder(res.Body).Decode(&projectVersions) +} + +// CreateProjectVersion inserts a new version for the project. +func (c *Client) CreateProjectVersion(ctx context.Context, organization, project string, request coderd.CreateProjectVersionRequest) (coderd.ProjectVersion, error) { + res, err := c.request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/projects/%s/%s/versions", organization, project), request) + if err != nil { + return coderd.ProjectVersion{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return coderd.ProjectVersion{}, readBodyAsError(res) + } + var projectVersion coderd.ProjectVersion + return projectVersion, json.NewDecoder(res.Body).Decode(&projectVersion) +} diff --git a/codersdk/projects_test.go b/codersdk/projects_test.go new file mode 100644 index 0000000000..e3914b5f94 --- /dev/null +++ b/codersdk/projects_test.go @@ -0,0 +1,130 @@ +package codersdk_test + +import ( + "archive/tar" + "bytes" + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/database" +) + +func TestProjects(t *testing.T) { + t.Parallel() + + t.Run("UnauthenticatedList", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + _, err := server.Client.Projects(context.Background(), "") + require.Error(t, err) + }) + + t.Run("List", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + _, err := server.Client.Projects(context.Background(), "") + require.NoError(t, err) + _, err = server.Client.Projects(context.Background(), user.Organization) + require.NoError(t, err) + }) + + t.Run("UnauthenticatedCreate", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + _, err := server.Client.CreateProject(context.Background(), "", coderd.CreateProjectRequest{}) + require.Error(t, err) + }) + + t.Run("Create", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + _, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "bananas", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + }) + + t.Run("UnauthenticatedSingle", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + _, err := server.Client.Project(context.Background(), "wow", "example") + require.Error(t, err) + }) + + t.Run("Single", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + _, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "bananas", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + _, err = server.Client.Project(context.Background(), user.Organization, "bananas") + require.NoError(t, err) + }) + + t.Run("UnauthenticatedVersions", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + _, err := server.Client.ProjectVersions(context.Background(), "org", "project") + require.Error(t, err) + }) + + t.Run("Versions", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "bananas", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + _, err = server.Client.ProjectVersions(context.Background(), user.Organization, project.Name) + require.NoError(t, err) + }) + + t.Run("CreateVersionUnauthenticated", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + _, err := server.Client.CreateProjectVersion(context.Background(), "org", "project", coderd.CreateProjectVersionRequest{ + Name: "hello", + StorageMethod: database.ProjectStorageMethodInlineArchive, + StorageSource: []byte{}, + }) + require.Error(t, err) + }) + + t.Run("CreateVersion", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "bananas", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + var buffer bytes.Buffer + writer := tar.NewWriter(&buffer) + err = writer.WriteHeader(&tar.Header{ + Name: "file", + Size: 1 << 10, + }) + require.NoError(t, err) + _, err = writer.Write(make([]byte, 1<<10)) + require.NoError(t, err) + _, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{ + Name: "hello", + StorageMethod: database.ProjectStorageMethodInlineArchive, + StorageSource: buffer.Bytes(), + }) + require.NoError(t, err) + }) +} diff --git a/codersdk/users_test.go b/codersdk/users_test.go index 2304bc9398..ee59e97330 100644 --- a/codersdk/users_test.go +++ b/codersdk/users_test.go @@ -11,7 +11,9 @@ import ( ) func TestUsers(t *testing.T) { + t.Parallel() t.Run("CreateInitial", func(t *testing.T) { + t.Parallel() server := coderdtest.New(t) _, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateInitialUserRequest{ Email: "wowie@coder.com", @@ -23,12 +25,14 @@ func TestUsers(t *testing.T) { }) t.Run("NoUser", func(t *testing.T) { + t.Parallel() server := coderdtest.New(t) _, err := server.Client.User(context.Background(), "") require.Error(t, err) }) t.Run("User", func(t *testing.T) { + t.Parallel() server := coderdtest.New(t) _ = server.RandomInitialUser(t) _, err := server.Client.User(context.Background(), "") @@ -36,6 +40,7 @@ func TestUsers(t *testing.T) { }) t.Run("UserOrganizations", func(t *testing.T) { + t.Parallel() server := coderdtest.New(t) _ = server.RandomInitialUser(t) orgs, err := server.Client.UserOrganizations(context.Background(), "") diff --git a/cryptorand/numbers_test.go b/cryptorand/numbers_test.go index b1602df404..105ce080b0 100644 --- a/cryptorand/numbers_test.go +++ b/cryptorand/numbers_test.go @@ -47,6 +47,9 @@ func TestUnbiasedModulo32(t *testing.T) { const mod = 7 dist := [mod]uint32{} + _, err := cryptorand.UnbiasedModulo32(0, mod) + require.NoError(t, err) + for i := 0; i < 1000; i++ { b := [4]byte{} _, _ = rand.Read(b[:]) diff --git a/cryptorand/strings_test.go b/cryptorand/strings_test.go index 3f6025e0f9..50730b8e09 100644 --- a/cryptorand/strings_test.go +++ b/cryptorand/strings_test.go @@ -91,6 +91,7 @@ func TestStringCharset(t *testing.T) { }, } + //nolint:paralleltest for _, test := range tests { test := test t.Run(test.Name, func(t *testing.T) { diff --git a/database/databasefake/databasefake.go b/database/databasefake/databasefake.go index 07c7e11ac7..8511bad058 100644 --- a/database/databasefake/databasefake.go +++ b/database/databasefake/databasefake.go @@ -3,6 +3,9 @@ package databasefake import ( "context" "database/sql" + "strings" + + "github.com/google/uuid" "github.com/coder/coder/database" ) @@ -14,15 +17,25 @@ func New() database.Store { organizations: make([]database.Organization, 0), organizationMembers: make([]database.OrganizationMember, 0), users: make([]database.User, 0), + + project: make([]database.Project, 0), + projectHistory: make([]database.ProjectHistory, 0), + projectParameter: make([]database.ProjectParameter, 0), } } // fakeQuerier replicates database functionality to enable quick testing. type fakeQuerier struct { + // Legacy tables apiKeys []database.APIKey organizations []database.Organization organizationMembers []database.OrganizationMember users []database.User + + // New tables + project []database.Project + projectHistory []database.ProjectHistory + projectParameter []database.ProjectParameter } // InTx doesn't rollback data properly for in-memory yet. @@ -89,6 +102,62 @@ func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID string) return organizations, nil } +func (q *fakeQuerier) GetProjectByOrganizationAndName(_ context.Context, arg database.GetProjectByOrganizationAndNameParams) (database.Project, error) { + for _, project := range q.project { + if project.OrganizationID != arg.OrganizationID { + continue + } + if !strings.EqualFold(project.Name, arg.Name) { + continue + } + return project, nil + } + return database.Project{}, sql.ErrNoRows +} + +func (q *fakeQuerier) GetProjectHistoryByProjectID(_ context.Context, projectID uuid.UUID) ([]database.ProjectHistory, error) { + history := make([]database.ProjectHistory, 0) + for _, projectHistory := range q.projectHistory { + if projectHistory.ProjectID.String() != projectID.String() { + continue + } + history = append(history, projectHistory) + } + if len(history) == 0 { + return nil, sql.ErrNoRows + } + return history, nil +} + +func (q *fakeQuerier) GetProjectsByOrganizationIDs(_ context.Context, ids []string) ([]database.Project, error) { + projects := make([]database.Project, 0) + for _, project := range q.project { + for _, id := range ids { + if project.OrganizationID == id { + projects = append(projects, project) + break + } + } + } + if len(projects) == 0 { + return nil, sql.ErrNoRows + } + return projects, nil +} + +func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { + for _, organizationMember := range q.organizationMembers { + if organizationMember.OrganizationID != arg.OrganizationID { + continue + } + if organizationMember.UserID != arg.UserID { + continue + } + return organizationMember, nil + } + return database.OrganizationMember{}, sql.ErrNoRows +} + func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { //nolint:gosimple key := database.APIKey{ @@ -136,6 +205,59 @@ func (q *fakeQuerier) InsertOrganizationMember(_ context.Context, arg database.I return organizationMember, nil } +func (q *fakeQuerier) InsertProject(_ context.Context, arg database.InsertProjectParams) (database.Project, error) { + project := database.Project{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + OrganizationID: arg.OrganizationID, + Name: arg.Name, + Provisioner: arg.Provisioner, + } + q.project = append(q.project, project) + return project, nil +} + +func (q *fakeQuerier) InsertProjectHistory(_ context.Context, arg database.InsertProjectHistoryParams) (database.ProjectHistory, error) { + //nolint:gosimple + history := database.ProjectHistory{ + ID: arg.ID, + ProjectID: arg.ProjectID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Name: arg.Name, + Description: arg.Description, + StorageMethod: arg.StorageMethod, + StorageSource: arg.StorageSource, + ImportJobID: arg.ImportJobID, + } + q.projectHistory = append(q.projectHistory, history) + return history, nil +} + +func (q *fakeQuerier) InsertProjectParameter(_ context.Context, arg database.InsertProjectParameterParams) (database.ProjectParameter, error) { + //nolint:gosimple + param := database.ProjectParameter{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + ProjectHistoryID: arg.ProjectHistoryID, + Name: arg.Name, + Description: arg.Description, + DefaultSource: arg.DefaultSource, + AllowOverrideSource: arg.AllowOverrideSource, + DefaultDestination: arg.DefaultDestination, + AllowOverrideDestination: arg.AllowOverrideDestination, + DefaultRefresh: arg.DefaultRefresh, + RedisplayValue: arg.RedisplayValue, + ValidationError: arg.ValidationError, + ValidationCondition: arg.ValidationCondition, + ValidationTypeSystem: arg.ValidationTypeSystem, + ValidationValueType: arg.ValidationValueType, + } + q.projectParameter = append(q.projectParameter, param) + return param, nil +} + func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) { user := database.User{ ID: arg.ID, diff --git a/database/dump.sql b/database/dump.sql index e0771d8d39..5e8a1bf7cb 100644 --- a/database/dump.sql +++ b/database/dump.sql @@ -6,6 +6,19 @@ CREATE TYPE login_type AS ENUM ( 'oidc' ); +CREATE TYPE parameter_type_system AS ENUM ( + 'hcl' +); + +CREATE TYPE project_storage_method AS ENUM ( + 'inline-archive' +); + +CREATE TYPE provisioner_type AS ENUM ( + 'terraform', + 'cdr-basic' +); + CREATE TYPE userstatus AS ENUM ( 'active', 'dormant', @@ -57,6 +70,46 @@ CREATE TABLE organizations ( workspace_auto_off boolean DEFAULT false NOT NULL ); +CREATE TABLE project ( + id uuid NOT NULL, + created_at timestamp with time zone NOT NULL, + updated_at timestamp with time zone NOT NULL, + organization_id text NOT NULL, + name character varying(64) NOT NULL, + provisioner provisioner_type NOT NULL, + active_version_id uuid +); + +CREATE TABLE project_history ( + id uuid NOT NULL, + project_id uuid NOT NULL, + created_at timestamp with time zone NOT NULL, + updated_at timestamp with time zone NOT NULL, + name character varying(64) NOT NULL, + description character varying(1048576) NOT NULL, + storage_method project_storage_method NOT NULL, + storage_source bytea NOT NULL, + import_job_id uuid NOT NULL +); + +CREATE TABLE project_parameter ( + id uuid NOT NULL, + created_at timestamp with time zone NOT NULL, + project_history_id uuid NOT NULL, + name character varying(64) NOT NULL, + description character varying(8192) DEFAULT ''::character varying NOT NULL, + default_source text, + allow_override_source boolean NOT NULL, + default_destination text, + allow_override_destination boolean NOT NULL, + default_refresh text NOT NULL, + redisplay_value boolean NOT NULL, + validation_error character varying(256) NOT NULL, + validation_condition character varying(512) NOT NULL, + validation_type_system parameter_type_system NOT NULL, + validation_value_type character varying(64) NOT NULL +); + CREATE TABLE users ( id text NOT NULL, email text NOT NULL, @@ -79,3 +132,27 @@ CREATE TABLE users ( shell text DEFAULT ''::text NOT NULL ); +ALTER TABLE ONLY project_history + ADD CONSTRAINT project_history_id_key UNIQUE (id); + +ALTER TABLE ONLY project_history + ADD CONSTRAINT project_history_project_id_name_key UNIQUE (project_id, name); + +ALTER TABLE ONLY project + ADD CONSTRAINT project_id_key UNIQUE (id); + +ALTER TABLE ONLY project + ADD CONSTRAINT project_organization_id_name_key UNIQUE (organization_id, name); + +ALTER TABLE ONLY project_parameter + ADD CONSTRAINT project_parameter_id_key UNIQUE (id); + +ALTER TABLE ONLY project_parameter + ADD CONSTRAINT project_parameter_project_history_id_name_key UNIQUE (project_history_id, name); + +ALTER TABLE ONLY project_history + ADD CONSTRAINT project_history_project_id_fkey FOREIGN KEY (project_id) REFERENCES project(id); + +ALTER TABLE ONLY project_parameter + ADD CONSTRAINT project_parameter_project_history_id_fkey FOREIGN KEY (project_history_id) REFERENCES project_history(id) ON DELETE CASCADE; + diff --git a/database/migrate_test.go b/database/migrate_test.go index d16671198b..5627e7606c 100644 --- a/database/migrate_test.go +++ b/database/migrate_test.go @@ -20,12 +20,29 @@ func TestMain(m *testing.M) { func TestMigrate(t *testing.T) { t.Parallel() - connection, closeFn, err := postgres.Open() - require.NoError(t, err) - defer closeFn() - db, err := sql.Open("postgres", connection) - require.NoError(t, err) - defer db.Close() - err = database.Migrate(db) - require.NoError(t, err) + t.Run("Once", func(t *testing.T) { + t.Parallel() + connection, closeFn, err := postgres.Open() + require.NoError(t, err) + defer closeFn() + db, err := sql.Open("postgres", connection) + require.NoError(t, err) + defer db.Close() + err = database.Migrate(db) + require.NoError(t, err) + }) + + t.Run("Twice", func(t *testing.T) { + t.Parallel() + connection, closeFn, err := postgres.Open() + require.NoError(t, err) + defer closeFn() + db, err := sql.Open("postgres", connection) + require.NoError(t, err) + defer db.Close() + err = database.Migrate(db) + require.NoError(t, err) + err = database.Migrate(db) + require.NoError(t, err) + }) } diff --git a/database/migrations/000002_projects.down.sql b/database/migrations/000002_projects.down.sql new file mode 100644 index 0000000000..e69de29bb2 diff --git a/database/migrations/000002_projects.up.sql b/database/migrations/000002_projects.up.sql new file mode 100644 index 0000000000..3483dcd9ff --- /dev/null +++ b/database/migrations/000002_projects.up.sql @@ -0,0 +1,84 @@ +CREATE TYPE provisioner_type AS ENUM ('terraform', 'cdr-basic'); + +-- Project defines infrastructure that your software project +-- requires for development. +CREATE TABLE project ( + id uuid NOT NULL UNIQUE, + created_at timestamptz NOT NULL, + updated_at timestamptz NOT NULL, + -- Projects must be scoped to an organization. + organization_id text NOT NULL, + name varchar(64) NOT NULL, + provisioner provisioner_type NOT NULL, + -- Target's a Project Version to use for Workspaces. + -- If a Workspace doesn't match this version, it will be prompted to rebuild. + active_version_id uuid, + -- Disallow projects to have the same name under + -- the same organization. + UNIQUE(organization_id, name) +); + +CREATE TYPE project_storage_method AS ENUM ('inline-archive'); + +-- Project Versions store Project history. When a Project Version is imported, +-- an "import" job is queued to parse parameters. A Project Version +-- can only be used if the import job succeeds. +CREATE TABLE project_history ( + id uuid NOT NULL UNIQUE, + -- This should be indexed. + project_id uuid NOT NULL REFERENCES project (id), + created_at timestamptz NOT NULL, + updated_at timestamptz NOT NULL, + -- Name is generated for ease of differentiation. + -- eg. TheCozyRabbit16 + name varchar(64) NOT NULL, + -- Extracted from a README.md on import. + -- Maximum of 1MB. + description varchar(1048576) NOT NULL, + storage_method project_storage_method NOT NULL, + storage_source bytea NOT NULL, + -- The import job for a Project Version. This is used + -- to detect if an import was successful. + import_job_id uuid NOT NULL, + -- Disallow projects to have the same build name + -- multiple times. + UNIQUE(project_id, name) +); + +-- Types of parameters the automator supports. +CREATE TYPE parameter_type_system AS ENUM ('hcl'); + +-- Stores project version parameters parsed on import. +-- No secrets are stored here. +-- +-- All parameter validation occurs server-side to process +-- complex validations. +-- +-- Parameter types, description, and validation will produce +-- a UI for users to enter values. +-- Needs to be made consistent with the examples below. +CREATE TABLE project_parameter ( + id uuid NOT NULL UNIQUE, + created_at timestamptz NOT NULL, + project_history_id uuid NOT NULL REFERENCES project_history(id) ON DELETE CASCADE, + name varchar(64) NOT NULL, + -- 8KB limit + description varchar(8192) NOT NULL DEFAULT '', + -- eg. data://inlinevalue + default_source text, + -- Allows the user to override the source. + allow_override_source boolean NOT null, + -- eg. env://SOME_VARIABLE, tfvars://example + default_destination text, + -- Allows the user to override the destination. + allow_override_destination boolean NOT null, + default_refresh text NOT NULL, + -- Whether the consumer can view the source and destinations. + redisplay_value boolean NOT null, + -- This error would appear in the UI if the condition is not met. + validation_error varchar(256) NOT NULL, + validation_condition varchar(512) NOT NULL, + validation_type_system parameter_type_system NOT NULL, + validation_value_type varchar(64) NOT NULL, + UNIQUE(project_history_id, name) +); diff --git a/database/models.go b/database/models.go index ac57bc0646..341e7821ec 100644 --- a/database/models.go +++ b/database/models.go @@ -3,9 +3,12 @@ package database import ( + "database/sql" "encoding/json" "fmt" "time" + + "github.com/google/uuid" ) type LoginType string @@ -28,6 +31,61 @@ func (e *LoginType) Scan(src interface{}) error { return nil } +type ParameterTypeSystem string + +const ( + ParameterTypeSystemHCL ParameterTypeSystem = "hcl" +) + +func (e *ParameterTypeSystem) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ParameterTypeSystem(s) + case string: + *e = ParameterTypeSystem(s) + default: + return fmt.Errorf("unsupported scan type for ParameterTypeSystem: %T", src) + } + return nil +} + +type ProjectStorageMethod string + +const ( + ProjectStorageMethodInlineArchive ProjectStorageMethod = "inline-archive" +) + +func (e *ProjectStorageMethod) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ProjectStorageMethod(s) + case string: + *e = ProjectStorageMethod(s) + default: + return fmt.Errorf("unsupported scan type for ProjectStorageMethod: %T", src) + } + return nil +} + +type ProvisionerType string + +const ( + ProvisionerTypeTerraform ProvisionerType = "terraform" + ProvisionerTypeCdrBasic ProvisionerType = "cdr-basic" +) + +func (e *ProvisionerType) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ProvisionerType(s) + case string: + *e = ProvisionerType(s) + default: + return fmt.Errorf("unsupported scan type for ProvisionerType: %T", src) + } + return nil +} + type UserStatus string const ( @@ -93,6 +151,46 @@ type OrganizationMember struct { Roles []string `db:"roles" json:"roles"` } +type Project struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + OrganizationID string `db:"organization_id" json:"organization_id"` + Name string `db:"name" json:"name"` + Provisioner ProvisionerType `db:"provisioner" json:"provisioner"` + ActiveVersionID uuid.NullUUID `db:"active_version_id" json:"active_version_id"` +} + +type ProjectHistory struct { + ID uuid.UUID `db:"id" json:"id"` + ProjectID uuid.UUID `db:"project_id" json:"project_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Name string `db:"name" json:"name"` + Description string `db:"description" json:"description"` + StorageMethod ProjectStorageMethod `db:"storage_method" json:"storage_method"` + StorageSource []byte `db:"storage_source" json:"storage_source"` + ImportJobID uuid.UUID `db:"import_job_id" json:"import_job_id"` +} + +type ProjectParameter struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + ProjectHistoryID uuid.UUID `db:"project_history_id" json:"project_history_id"` + Name string `db:"name" json:"name"` + Description string `db:"description" json:"description"` + DefaultSource sql.NullString `db:"default_source" json:"default_source"` + AllowOverrideSource bool `db:"allow_override_source" json:"allow_override_source"` + DefaultDestination sql.NullString `db:"default_destination" json:"default_destination"` + AllowOverrideDestination bool `db:"allow_override_destination" json:"allow_override_destination"` + DefaultRefresh string `db:"default_refresh" json:"default_refresh"` + RedisplayValue bool `db:"redisplay_value" json:"redisplay_value"` + ValidationError string `db:"validation_error" json:"validation_error"` + ValidationCondition string `db:"validation_condition" json:"validation_condition"` + ValidationTypeSystem ParameterTypeSystem `db:"validation_type_system" json:"validation_type_system"` + ValidationValueType string `db:"validation_value_type" json:"validation_value_type"` +} + type User struct { ID string `db:"id" json:"id"` Email string `db:"email" json:"email"` diff --git a/database/pubsub_test.go b/database/pubsub_test.go index 0a1eba426f..fb21383d7f 100644 --- a/database/pubsub_test.go +++ b/database/pubsub_test.go @@ -18,6 +18,7 @@ func TestPubsub(t *testing.T) { t.Parallel() t.Run("Postgres", func(t *testing.T) { + t.Parallel() ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -45,4 +46,20 @@ func TestPubsub(t *testing.T) { message := <-messageChannel assert.Equal(t, string(message), data) }) + + t.Run("PostgresCloseCancel", func(t *testing.T) { + t.Parallel() + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + connectionURL, close, err := postgres.Open() + require.NoError(t, err) + defer close() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + pubsub, err := database.NewPubsub(ctx, db, connectionURL) + require.NoError(t, err) + defer pubsub.Close() + cancelFunc() + }) } diff --git a/database/querier.go b/database/querier.go index ce7658c10e..255d55face 100644 --- a/database/querier.go +++ b/database/querier.go @@ -4,18 +4,27 @@ package database import ( "context" + + "github.com/google/uuid" ) type querier interface { GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) GetOrganizationByName(ctx context.Context, name string) (Organization, error) + GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error) + GetProjectByOrganizationAndName(ctx context.Context, arg GetProjectByOrganizationAndNameParams) (Project, error) + GetProjectHistoryByProjectID(ctx context.Context, projectID uuid.UUID) ([]ProjectHistory, error) + GetProjectsByOrganizationIDs(ctx context.Context, ids []string) ([]Project, error) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id string) (User, error) GetUserCount(ctx context.Context) (int64, error) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error) InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error) + InsertProject(ctx context.Context, arg InsertProjectParams) (Project, error) + InsertProjectHistory(ctx context.Context, arg InsertProjectHistoryParams) (ProjectHistory, error) + InsertProjectParameter(ctx context.Context, arg InsertProjectParameterParams) (ProjectParameter, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error } diff --git a/database/query.sql b/database/query.sql index 9bc8ad891e..74ab452b61 100644 --- a/database/query.sql +++ b/database/query.sql @@ -42,12 +42,67 @@ FROM users; -- name: GetOrganizationByName :one -SELECT * FROM organizations WHERE name = $1 LIMIT 1; +SELECT + * +FROM + organizations +WHERE + name = $1 +LIMIT + 1; -- name: GetOrganizationsByUserID :many -SELECT * FROM organizations WHERE id = ( - SELECT organization_id FROM organization_members WHERE user_id = $1 -); +SELECT + * +FROM + organizations +WHERE + id = ( + SELECT + organization_id + FROM + organization_members + WHERE + user_id = $1 + ); + +-- name: GetOrganizationMemberByUserID :one +SELECT + * +FROM + organization_members +WHERE + organization_id = $1 + AND user_id = $2 +LIMIT + 1; + +-- name: GetProjectByOrganizationAndName :one +SELECT + * +FROM + project +WHERE + organization_id = $1 + AND name = $2 +LIMIT + 1; + +-- name: GetProjectsByOrganizationIDs :many +SELECT + * +FROM + project +WHERE + organization_id = ANY(@ids :: text [ ]); + +-- name: GetProjectHistoryByProjectID :many +SELECT + * +FROM + project_history +WHERE + project_id = $1; -- name: InsertAPIKey :one INSERT INTO @@ -88,10 +143,89 @@ VALUES ) RETURNING *; -- name: InsertOrganization :one -INSERT INTO organizations (id, name, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING *; +INSERT INTO + organizations (id, name, description, created_at, updated_at) +VALUES + ($1, $2, $3, $4, $5) RETURNING *; -- name: InsertOrganizationMember :one -INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles) VALUES ($1, $2, $3, $4, $5) RETURNING *; +INSERT INTO + organization_members ( + organization_id, + user_id, + created_at, + updated_at, + roles + ) +VALUES + ($1, $2, $3, $4, $5) RETURNING *; + +-- name: InsertProject :one +INSERT INTO + project ( + id, + created_at, + updated_at, + organization_id, + name, + provisioner + ) +VALUES + ($1, $2, $3, $4, $5, $6) RETURNING *; + +-- name: InsertProjectHistory :one +INSERT INTO + project_history ( + id, + project_id, + created_at, + updated_at, + name, + description, + storage_method, + storage_source, + import_job_id + ) +VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *; + +-- name: InsertProjectParameter :one +INSERT INTO + project_parameter ( + id, + created_at, + project_history_id, + name, + description, + default_source, + allow_override_source, + default_destination, + allow_override_destination, + default_refresh, + redisplay_value, + validation_error, + validation_condition, + validation_type_system, + validation_value_type + ) +VALUES + ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9, + $10, + $11, + $12, + $13, + $14, + $15 + ) RETURNING *; -- name: InsertUser :one INSERT INTO diff --git a/database/query.sql.go b/database/query.sql.go index cc6977aa22..096baad8bb 100644 --- a/database/query.sql.go +++ b/database/query.sql.go @@ -5,8 +5,10 @@ package database import ( "context" + "database/sql" "time" + "github.com/google/uuid" "github.com/lib/pq" ) @@ -45,7 +47,14 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro } const getOrganizationByName = `-- name: GetOrganizationByName :one -SELECT id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off FROM organizations WHERE name = $1 LIMIT 1 +SELECT + id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off +FROM + organizations +WHERE + name = $1 +LIMIT + 1 ` func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Organization, error) { @@ -66,10 +75,50 @@ func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Or return i, err } +const getOrganizationMemberByUserID = `-- name: GetOrganizationMemberByUserID :one +SELECT + organization_id, user_id, created_at, updated_at, roles +FROM + organization_members +WHERE + organization_id = $1 + AND user_id = $2 +LIMIT + 1 +` + +type GetOrganizationMemberByUserIDParams struct { + OrganizationID string `db:"organization_id" json:"organization_id"` + UserID string `db:"user_id" json:"user_id"` +} + +func (q *sqlQuerier) GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) { + row := q.db.QueryRowContext(ctx, getOrganizationMemberByUserID, arg.OrganizationID, arg.UserID) + var i OrganizationMember + err := row.Scan( + &i.OrganizationID, + &i.UserID, + &i.CreatedAt, + &i.UpdatedAt, + pq.Array(&i.Roles), + ) + return i, err +} + const getOrganizationsByUserID = `-- name: GetOrganizationsByUserID :many -SELECT id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off FROM organizations WHERE id = ( - SELECT organization_id FROM organization_members WHERE user_id = $1 -) +SELECT + id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off +FROM + organizations +WHERE + id = ( + SELECT + organization_id + FROM + organization_members + WHERE + user_id = $1 + ) ` func (q *sqlQuerier) GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error) { @@ -106,6 +155,120 @@ func (q *sqlQuerier) GetOrganizationsByUserID(ctx context.Context, userID string return items, nil } +const getProjectByOrganizationAndName = `-- name: GetProjectByOrganizationAndName :one +SELECT + id, created_at, updated_at, organization_id, name, provisioner, active_version_id +FROM + project +WHERE + organization_id = $1 + AND name = $2 +LIMIT + 1 +` + +type GetProjectByOrganizationAndNameParams struct { + OrganizationID string `db:"organization_id" json:"organization_id"` + Name string `db:"name" json:"name"` +} + +func (q *sqlQuerier) GetProjectByOrganizationAndName(ctx context.Context, arg GetProjectByOrganizationAndNameParams) (Project, error) { + row := q.db.QueryRowContext(ctx, getProjectByOrganizationAndName, arg.OrganizationID, arg.Name) + var i Project + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.OrganizationID, + &i.Name, + &i.Provisioner, + &i.ActiveVersionID, + ) + return i, err +} + +const getProjectHistoryByProjectID = `-- name: GetProjectHistoryByProjectID :many +SELECT + id, project_id, created_at, updated_at, name, description, storage_method, storage_source, import_job_id +FROM + project_history +WHERE + project_id = $1 +` + +func (q *sqlQuerier) GetProjectHistoryByProjectID(ctx context.Context, projectID uuid.UUID) ([]ProjectHistory, error) { + rows, err := q.db.QueryContext(ctx, getProjectHistoryByProjectID, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ProjectHistory + for rows.Next() { + var i ProjectHistory + if err := rows.Scan( + &i.ID, + &i.ProjectID, + &i.CreatedAt, + &i.UpdatedAt, + &i.Name, + &i.Description, + &i.StorageMethod, + &i.StorageSource, + &i.ImportJobID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getProjectsByOrganizationIDs = `-- name: GetProjectsByOrganizationIDs :many +SELECT + id, created_at, updated_at, organization_id, name, provisioner, active_version_id +FROM + project +WHERE + organization_id = ANY($1 :: text [ ]) +` + +func (q *sqlQuerier) GetProjectsByOrganizationIDs(ctx context.Context, ids []string) ([]Project, error) { + rows, err := q.db.QueryContext(ctx, getProjectsByOrganizationIDs, pq.Array(ids)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Project + for rows.Next() { + var i Project + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.OrganizationID, + &i.Name, + &i.Provisioner, + &i.ActiveVersionID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one SELECT id, email, name, revoked, login_type, hashed_password, created_at, updated_at, temporary_password, avatar_hash, ssh_key_regenerated_at, username, dotfiles_git_uri, roles, status, relatime, gpg_key_regenerated_at, _decomissioned, shell @@ -299,7 +462,10 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( } const insertOrganization = `-- name: InsertOrganization :one -INSERT INTO organizations (id, name, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off +INSERT INTO + organizations (id, name, description, created_at, updated_at) +VALUES + ($1, $2, $3, $4, $5) RETURNING id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off ` type InsertOrganizationParams struct { @@ -335,7 +501,16 @@ func (q *sqlQuerier) InsertOrganization(ctx context.Context, arg InsertOrganizat } const insertOrganizationMember = `-- name: InsertOrganizationMember :one -INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles) VALUES ($1, $2, $3, $4, $5) RETURNING organization_id, user_id, created_at, updated_at, roles +INSERT INTO + organization_members ( + organization_id, + user_id, + created_at, + updated_at, + roles + ) +VALUES + ($1, $2, $3, $4, $5) RETURNING organization_id, user_id, created_at, updated_at, roles ` type InsertOrganizationMemberParams struct { @@ -365,6 +540,203 @@ func (q *sqlQuerier) InsertOrganizationMember(ctx context.Context, arg InsertOrg return i, err } +const insertProject = `-- name: InsertProject :one +INSERT INTO + project ( + id, + created_at, + updated_at, + organization_id, + name, + provisioner + ) +VALUES + ($1, $2, $3, $4, $5, $6) RETURNING id, created_at, updated_at, organization_id, name, provisioner, active_version_id +` + +type InsertProjectParams struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + OrganizationID string `db:"organization_id" json:"organization_id"` + Name string `db:"name" json:"name"` + Provisioner ProvisionerType `db:"provisioner" json:"provisioner"` +} + +func (q *sqlQuerier) InsertProject(ctx context.Context, arg InsertProjectParams) (Project, error) { + row := q.db.QueryRowContext(ctx, insertProject, + arg.ID, + arg.CreatedAt, + arg.UpdatedAt, + arg.OrganizationID, + arg.Name, + arg.Provisioner, + ) + var i Project + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.OrganizationID, + &i.Name, + &i.Provisioner, + &i.ActiveVersionID, + ) + return i, err +} + +const insertProjectHistory = `-- name: InsertProjectHistory :one +INSERT INTO + project_history ( + id, + project_id, + created_at, + updated_at, + name, + description, + storage_method, + storage_source, + import_job_id + ) +VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, project_id, created_at, updated_at, name, description, storage_method, storage_source, import_job_id +` + +type InsertProjectHistoryParams struct { + ID uuid.UUID `db:"id" json:"id"` + ProjectID uuid.UUID `db:"project_id" json:"project_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Name string `db:"name" json:"name"` + Description string `db:"description" json:"description"` + StorageMethod ProjectStorageMethod `db:"storage_method" json:"storage_method"` + StorageSource []byte `db:"storage_source" json:"storage_source"` + ImportJobID uuid.UUID `db:"import_job_id" json:"import_job_id"` +} + +func (q *sqlQuerier) InsertProjectHistory(ctx context.Context, arg InsertProjectHistoryParams) (ProjectHistory, error) { + row := q.db.QueryRowContext(ctx, insertProjectHistory, + arg.ID, + arg.ProjectID, + arg.CreatedAt, + arg.UpdatedAt, + arg.Name, + arg.Description, + arg.StorageMethod, + arg.StorageSource, + arg.ImportJobID, + ) + var i ProjectHistory + err := row.Scan( + &i.ID, + &i.ProjectID, + &i.CreatedAt, + &i.UpdatedAt, + &i.Name, + &i.Description, + &i.StorageMethod, + &i.StorageSource, + &i.ImportJobID, + ) + return i, err +} + +const insertProjectParameter = `-- name: InsertProjectParameter :one +INSERT INTO + project_parameter ( + id, + created_at, + project_history_id, + name, + description, + default_source, + allow_override_source, + default_destination, + allow_override_destination, + default_refresh, + redisplay_value, + validation_error, + validation_condition, + validation_type_system, + validation_value_type + ) +VALUES + ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9, + $10, + $11, + $12, + $13, + $14, + $15 + ) RETURNING id, created_at, project_history_id, name, description, default_source, allow_override_source, default_destination, allow_override_destination, default_refresh, redisplay_value, validation_error, validation_condition, validation_type_system, validation_value_type +` + +type InsertProjectParameterParams struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + ProjectHistoryID uuid.UUID `db:"project_history_id" json:"project_history_id"` + Name string `db:"name" json:"name"` + Description string `db:"description" json:"description"` + DefaultSource sql.NullString `db:"default_source" json:"default_source"` + AllowOverrideSource bool `db:"allow_override_source" json:"allow_override_source"` + DefaultDestination sql.NullString `db:"default_destination" json:"default_destination"` + AllowOverrideDestination bool `db:"allow_override_destination" json:"allow_override_destination"` + DefaultRefresh string `db:"default_refresh" json:"default_refresh"` + RedisplayValue bool `db:"redisplay_value" json:"redisplay_value"` + ValidationError string `db:"validation_error" json:"validation_error"` + ValidationCondition string `db:"validation_condition" json:"validation_condition"` + ValidationTypeSystem ParameterTypeSystem `db:"validation_type_system" json:"validation_type_system"` + ValidationValueType string `db:"validation_value_type" json:"validation_value_type"` +} + +func (q *sqlQuerier) InsertProjectParameter(ctx context.Context, arg InsertProjectParameterParams) (ProjectParameter, error) { + row := q.db.QueryRowContext(ctx, insertProjectParameter, + arg.ID, + arg.CreatedAt, + arg.ProjectHistoryID, + arg.Name, + arg.Description, + arg.DefaultSource, + arg.AllowOverrideSource, + arg.DefaultDestination, + arg.AllowOverrideDestination, + arg.DefaultRefresh, + arg.RedisplayValue, + arg.ValidationError, + arg.ValidationCondition, + arg.ValidationTypeSystem, + arg.ValidationValueType, + ) + var i ProjectParameter + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.ProjectHistoryID, + &i.Name, + &i.Description, + &i.DefaultSource, + &i.AllowOverrideSource, + &i.DefaultDestination, + &i.AllowOverrideDestination, + &i.DefaultRefresh, + &i.RedisplayValue, + &i.ValidationError, + &i.ValidationCondition, + &i.ValidationTypeSystem, + &i.ValidationValueType, + ) + return i, err +} + const insertUser = `-- name: InsertUser :one INSERT INTO users ( diff --git a/database/sqlc.yaml b/database/sqlc.yaml index 93279b4229..01ad9dc733 100644 --- a/database/sqlc.yaml +++ b/database/sqlc.yaml @@ -25,4 +25,5 @@ rename: oidc_expiry: OIDCExpiry oidc_id_token: OIDCIDToken oidc_refresh_token: OIDCRefreshToken + parameter_type_system_hcl: ParameterTypeSystemHCL userstatus: UserStatus diff --git a/go.mod b/go.mod index cc36094c30..4a63f8f6c8 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/hashicorp/terraform-exec v0.15.0 github.com/justinas/nosurf v1.1.1 github.com/lib/pq v1.10.4 + github.com/moby/moby v20.10.12+incompatible github.com/ory/dockertest/v3 v3.8.1 github.com/pion/datachannel v1.5.2 github.com/pion/logging v0.2.2 diff --git a/go.sum b/go.sum index 514074701a..ce47f85d8a 100644 --- a/go.sum +++ b/go.sum @@ -903,6 +903,8 @@ github.com/mitchellh/osext v0.0.0-20151018003038-5e2d6d41470f/go.mod h1:OkQIRizQ github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/moby/locker v1.0.1/go.mod h1:S7SDdo5zpBK84bzzVlKr2V0hz+7x9hWbYC/kq7oQppc= +github.com/moby/moby v20.10.12+incompatible h1:MJVrdG0tIQqVJQBTdtooPuZQFIgski5pYTXlcW8ToE0= +github.com/moby/moby v20.10.12+incompatible/go.mod h1:fDXVQ6+S340veQPv35CzDahGBmHsiclFwfEygB/TWMc= github.com/moby/sys/mountinfo v0.4.0/go.mod h1:rEr8tzG/lsIZHBtN/JjGG+LMYx9eXgW2JI+6q0qou+A= github.com/moby/sys/mountinfo v0.4.1/go.mod h1:rEr8tzG/lsIZHBtN/JjGG+LMYx9eXgW2JI+6q0qou+A= github.com/moby/sys/symlink v0.1.0/go.mod h1:GGDODQmbFOjFsXvfLVn3+ZRxkch54RkSiGqsZeMYowQ= diff --git a/httpapi/httpapi_test.go b/httpapi/httpapi_test.go index a459640d39..46cd85e960 100644 --- a/httpapi/httpapi_test.go +++ b/httpapi/httpapi_test.go @@ -13,7 +13,9 @@ import ( ) func TestWrite(t *testing.T) { + t.Parallel() t.Run("NoErrors", func(t *testing.T) { + t.Parallel() rw := httptest.NewRecorder() httpapi.Write(rw, http.StatusOK, httpapi.Response{ Message: "wow", @@ -27,7 +29,9 @@ func TestWrite(t *testing.T) { } func TestRead(t *testing.T) { + t.Parallel() t.Run("EmptyStruct", func(t *testing.T) { + t.Parallel() rw := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", bytes.NewBufferString("{}")) v := struct{}{} @@ -35,6 +39,7 @@ func TestRead(t *testing.T) { }) t.Run("NoBody", func(t *testing.T) { + t.Parallel() rw := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", nil) var v json.RawMessage @@ -42,6 +47,7 @@ func TestRead(t *testing.T) { }) t.Run("Validate", func(t *testing.T) { + t.Parallel() type toValidate struct { Value string `json:"value" validate:"required"` } @@ -54,6 +60,7 @@ func TestRead(t *testing.T) { }) t.Run("ValidateFailure", func(t *testing.T) { + t.Parallel() type toValidate struct { Value string `json:"value" validate:"required"` } @@ -72,6 +79,7 @@ func TestRead(t *testing.T) { } func TestReadUsername(t *testing.T) { + t.Parallel() // Tests whether usernames are valid or not. testCases := []struct { Username string @@ -121,7 +129,9 @@ func TestReadUsername(t *testing.T) { Username string `json:"username" validate:"username"` } for _, testCase := range testCases { + testCase := testCase t.Run(testCase.Username, func(t *testing.T) { + t.Parallel() rw := httptest.NewRecorder() data, err := json.Marshal(toValidate{testCase.Username}) require.NoError(t, err) diff --git a/httpmw/apikey_test.go b/httpmw/apikey_test.go index 6bdebcde6e..4af9fb7173 100644 --- a/httpmw/apikey_test.go +++ b/httpmw/apikey_test.go @@ -35,6 +35,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("NoCookie", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() r = httptest.NewRequest("GET", "/", nil) @@ -47,6 +48,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("InvalidFormat", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() r = httptest.NewRequest("GET", "/", nil) @@ -64,6 +66,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("InvalidIDLength", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() r = httptest.NewRequest("GET", "/", nil) @@ -81,6 +84,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("InvalidSecretLength", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() r = httptest.NewRequest("GET", "/", nil) @@ -98,6 +102,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("NotFound", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() id, secret = randomAPIKeyParts() @@ -116,6 +121,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("InvalidSecret", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() id, secret = randomAPIKeyParts() @@ -141,6 +147,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("Expired", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() id, secret = randomAPIKeyParts() @@ -165,6 +172,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("Valid", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() id, secret = randomAPIKeyParts() @@ -203,6 +211,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("ValidUpdateLastUsed", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() id, secret = randomAPIKeyParts() @@ -235,6 +244,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("ValidUpdateExpiry", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() id, secret = randomAPIKeyParts() @@ -267,6 +277,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("OIDCNotExpired", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() id, secret = randomAPIKeyParts() @@ -300,6 +311,7 @@ func TestAPIKey(t *testing.T) { }) t.Run("OIDCRefresh", func(t *testing.T) { + t.Parallel() var ( db = databasefake.New() id, secret = randomAPIKeyParts() diff --git a/httpmw/organizationparam.go b/httpmw/organizationparam.go new file mode 100644 index 0000000000..62f46e60d7 --- /dev/null +++ b/httpmw/organizationparam.go @@ -0,0 +1,86 @@ +package httpmw + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + + "github.com/go-chi/chi" + + "github.com/coder/coder/database" + "github.com/coder/coder/httpapi" +) + +type organizationParamContextKey struct{} +type organizationMemberParamContextKey struct{} + +// OrganizationParam returns the organization from the ExtractOrganizationParam handler. +func OrganizationParam(r *http.Request) database.Organization { + organization, ok := r.Context().Value(organizationParamContextKey{}).(database.Organization) + if !ok { + panic("developer error: organization param middleware not provided") + } + return organization +} + +// OrganizationMemberParam returns the organization membership that allowed the query +// from the ExtractOrganizationParam handler. +func OrganizationMemberParam(r *http.Request) database.OrganizationMember { + organizationMember, ok := r.Context().Value(organizationMemberParamContextKey{}).(database.OrganizationMember) + if !ok { + panic("developer error: organization param middleware not provided") + } + return organizationMember +} + +// ExtractOrganizationParam grabs an organization and user membership from the "organization" URL parameter. +// This middleware requires the API key middleware higher in the call stack for authentication. +func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := APIKey(r) + organizationName := chi.URLParam(r, "organization") + if organizationName == "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "organization name must be provided", + }) + return + } + organization, err := db.GetOrganizationByName(r.Context(), organizationName) + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ + Message: fmt.Sprintf("organization %q does not exist", organizationName), + }) + return + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get organization: %s", err.Error()), + }) + return + } + organizationMember, err := db.GetOrganizationMemberByUserID(r.Context(), database.GetOrganizationMemberByUserIDParams{ + OrganizationID: organization.ID, + UserID: apiKey.UserID, + }) + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: "not a member of the organization", + }) + return + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get organization member: %s", err.Error()), + }) + return + } + + ctx := context.WithValue(r.Context(), organizationParamContextKey{}, organization) + ctx = context.WithValue(ctx, organizationMemberParamContextKey{}, organizationMember) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} diff --git a/httpmw/organizationparam_test.go b/httpmw/organizationparam_test.go new file mode 100644 index 0000000000..b1b43356a9 --- /dev/null +++ b/httpmw/organizationparam_test.go @@ -0,0 +1,165 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cryptorand" + "github.com/coder/coder/database" + "github.com/coder/coder/database/databasefake" + "github.com/coder/coder/httpmw" +) + +func TestOrganizationParam(t *testing.T) { + t.Parallel() + + setupAuthentication := func(db database.Store) (*http.Request, database.User) { + var ( + id, secret = randomAPIKeyParts() + r = httptest.NewRequest("GET", "/", nil) + hashed = sha256.Sum256([]byte(secret)) + ) + r.AddCookie(&http.Cookie{ + Name: httpmw.AuthCookie, + Value: fmt.Sprintf("%s-%s", id, secret), + }) + userID, err := cryptorand.String(16) + require.NoError(t, err) + username, err := cryptorand.String(8) + require.NoError(t, err) + user, err := db.InsertUser(r.Context(), database.InsertUserParams{ + ID: userID, + Email: "testaccount@coder.com", + Name: "example", + LoginType: database.LoginTypeBuiltIn, + HashedPassword: hashed[:], + Username: username, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + _, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ + ID: id, + UserID: user.ID, + HashedSecret: hashed[:], + LastUsed: database.Now(), + ExpiresAt: database.Now().Add(time.Minute), + }) + require.NoError(t, err) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext())) + return r, user + } + + t.Run("None", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + rw = httptest.NewRecorder() + r, _ = setupAuthentication(db) + rtr = chi.NewRouter() + ) + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractOrganizationParam(db), + ) + rtr.Get("/", nil) + rtr.ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusBadRequest, res.StatusCode) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + rw = httptest.NewRecorder() + r, _ = setupAuthentication(db) + rtr = chi.NewRouter() + ) + chi.RouteContext(r.Context()).URLParams.Add("organization", "nothin") + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractOrganizationParam(db), + ) + rtr.Get("/", nil) + rtr.ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusNotFound, res.StatusCode) + }) + + t.Run("NotInOrganization", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + rw = httptest.NewRecorder() + r, _ = setupAuthentication(db) + rtr = chi.NewRouter() + ) + organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{ + ID: uuid.NewString(), + Name: "test", + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + chi.RouteContext(r.Context()).URLParams.Add("organization", organization.Name) + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractOrganizationParam(db), + ) + rtr.Get("/", nil) + rtr.ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + + t.Run("Success", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + rw = httptest.NewRecorder() + r, user = setupAuthentication(db) + rtr = chi.NewRouter() + ) + organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{ + ID: uuid.NewString(), + Name: "test", + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + _, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{ + OrganizationID: organization.ID, + UserID: user.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + chi.RouteContext(r.Context()).URLParams.Add("organization", organization.Name) + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractOrganizationParam(db), + ) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + _ = httpmw.OrganizationParam(r) + _ = httpmw.OrganizationMemberParam(r) + rw.WriteHeader(http.StatusOK) + }) + rtr.ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) +} diff --git a/httpmw/projectparam.go b/httpmw/projectparam.go new file mode 100644 index 0000000000..7daf681304 --- /dev/null +++ b/httpmw/projectparam.go @@ -0,0 +1,60 @@ +package httpmw + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + + "github.com/go-chi/chi" + + "github.com/coder/coder/database" + "github.com/coder/coder/httpapi" +) + +type projectParamContextKey struct{} + +// ProjectParam returns the project from the ExtractProjectParameter handler. +func ProjectParam(r *http.Request) database.Project { + project, ok := r.Context().Value(projectParamContextKey{}).(database.Project) + if !ok { + panic("developer error: project param middleware not provided") + } + return project +} + +// ExtractProjectParameter grabs a project from the "project" URL parameter. +func ExtractProjectParameter(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + organization := OrganizationParam(r) + projectName := chi.URLParam(r, "project") + if projectName == "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "project name must be provided", + }) + return + } + project, err := db.GetProjectByOrganizationAndName(r.Context(), database.GetProjectByOrganizationAndNameParams{ + OrganizationID: organization.ID, + Name: projectName, + }) + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ + Message: fmt.Sprintf("project %q does not exist", projectName), + }) + return + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get project: %s", err.Error()), + }) + return + } + + ctx := context.WithValue(r.Context(), projectParamContextKey{}, project) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} diff --git a/httpmw/projectparam_test.go b/httpmw/projectparam_test.go new file mode 100644 index 0000000000..129109df50 --- /dev/null +++ b/httpmw/projectparam_test.go @@ -0,0 +1,151 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cryptorand" + "github.com/coder/coder/database" + "github.com/coder/coder/database/databasefake" + "github.com/coder/coder/httpmw" +) + +func TestProjectParam(t *testing.T) { + t.Parallel() + + setupAuthentication := func(db database.Store) (*http.Request, database.Organization) { + var ( + id, secret = randomAPIKeyParts() + hashed = sha256.Sum256([]byte(secret)) + ) + r := httptest.NewRequest("GET", "/", nil) + r.AddCookie(&http.Cookie{ + Name: httpmw.AuthCookie, + Value: fmt.Sprintf("%s-%s", id, secret), + }) + userID, err := cryptorand.String(16) + require.NoError(t, err) + username, err := cryptorand.String(8) + require.NoError(t, err) + user, err := db.InsertUser(r.Context(), database.InsertUserParams{ + ID: userID, + Email: "testaccount@coder.com", + Name: "example", + LoginType: database.LoginTypeBuiltIn, + HashedPassword: hashed[:], + Username: username, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + _, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ + ID: id, + UserID: user.ID, + HashedSecret: hashed[:], + LastUsed: database.Now(), + ExpiresAt: database.Now().Add(time.Minute), + }) + require.NoError(t, err) + orgID, err := cryptorand.String(16) + require.NoError(t, err) + organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{ + ID: orgID, + Name: "banana", + Description: "wowie", + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + _, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{ + OrganizationID: orgID, + UserID: user.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + + ctx := chi.NewRouteContext() + ctx.URLParams.Add("organization", organization.Name) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx)) + return r, organization + } + + t.Run("None", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractOrganizationParam(db), + httpmw.ExtractProjectParameter(db), + ) + rtr.Get("/", nil) + r, _ := setupAuthentication(db) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusBadRequest, res.StatusCode) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractOrganizationParam(db), + httpmw.ExtractProjectParameter(db), + ) + rtr.Get("/", nil) + + r, _ := setupAuthentication(db) + chi.RouteContext(r.Context()).URLParams.Add("project", "nothin") + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusNotFound, res.StatusCode) + }) + + t.Run("Project", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractOrganizationParam(db), + httpmw.ExtractProjectParameter(db), + ) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + _ = httpmw.ProjectParam(r) + rw.WriteHeader(http.StatusOK) + }) + + r, org := setupAuthentication(db) + project, err := db.InsertProject(context.Background(), database.InsertProjectParams{ + ID: uuid.New(), + OrganizationID: org.ID, + Name: "moo", + }) + require.NoError(t, err) + chi.RouteContext(r.Context()).URLParams.Add("project", project.Name) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) +} diff --git a/httpmw/userparam_test.go b/httpmw/userparam_test.go index 58204d899f..867833ffbd 100644 --- a/httpmw/userparam_test.go +++ b/httpmw/userparam_test.go @@ -18,6 +18,7 @@ import ( ) func TestUserParam(t *testing.T) { + t.Parallel() setup := func(t *testing.T) (database.Store, *httptest.ResponseRecorder, *http.Request) { var ( db = databasefake.New() @@ -47,6 +48,7 @@ func TestUserParam(t *testing.T) { } t.Run("None", func(t *testing.T) { + t.Parallel() db, rw, r := setup(t) httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) { @@ -62,6 +64,7 @@ func TestUserParam(t *testing.T) { }) t.Run("NotMe", func(t *testing.T) { + t.Parallel() db, rw, r := setup(t) httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) { @@ -80,6 +83,7 @@ func TestUserParam(t *testing.T) { }) t.Run("Me", func(t *testing.T) { + t.Parallel() db, rw, r := setup(t) httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) { diff --git a/peer/conn_test.go b/peer/conn_test.go index 29ec872a9f..0b954579c9 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -211,6 +211,7 @@ func TestConn(t *testing.T) { }) t.Run("CloseWithError", func(t *testing.T) { + t.Parallel() conn, err := peer.Client([]webrtc.ICEServer{}, nil) require.NoError(t, err) expectedErr := errors.New("wow") diff --git a/peerbroker/listen_test.go b/peerbroker/listen_test.go index 246d72bf15..c66d8a480a 100644 --- a/peerbroker/listen_test.go +++ b/peerbroker/listen_test.go @@ -14,9 +14,11 @@ import ( ) func TestListen(t *testing.T) { + t.Parallel() // Ensures connections blocked on Accept() are // closed if the listener is. t.Run("NoAcceptClosed", func(t *testing.T) { + t.Parallel() ctx := context.Background() client, server := provisionersdk.TransportPipe() defer client.Close() @@ -37,6 +39,7 @@ func TestListen(t *testing.T) { // Ensures Accept() properly exits when Close() is called. t.Run("AcceptClosed", func(t *testing.T) { + t.Parallel() client, server := provisionersdk.TransportPipe() defer client.Close() defer server.Close() diff --git a/provisionersdk/serve_test.go b/provisionersdk/serve_test.go index ed2b96ae6d..a08180edb0 100644 --- a/provisionersdk/serve_test.go +++ b/provisionersdk/serve_test.go @@ -18,7 +18,9 @@ func TestMain(m *testing.M) { } func TestProvisionerSDK(t *testing.T) { + t.Parallel() t.Run("Serve", func(t *testing.T) { + t.Parallel() client, server := provisionersdk.TransportPipe() defer client.Close() defer server.Close() @@ -37,6 +39,7 @@ func TestProvisionerSDK(t *testing.T) { require.Equal(t, drpcerr.Unimplemented, int(drpcerr.Code(err))) }) t.Run("ServeClosedPipe", func(t *testing.T) { + t.Parallel() client, server := provisionersdk.TransportPipe() _ = client.Close() _ = server.Close()