mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat(coderd): filter expired API tokens server-side (#22263)
## Summary
Moves expired token filtering from client-side to server-side by adding
an `include_expired` parameter to the `GetAPIKeysByLoginType` and
`GetAPIKeysByUserID` database queries. This is more efficient for large
deployments with many expired/short-lived tokens.
## Changes
- Add `include_expired` parameter to SQL queries using `OR`
short-circuit
- Add `include_expired` query parameter to `GET
/users/{user}/keys/tokens`
- Add `IncludeExpired` field to `codersdk.TokensFilter`
- Remove client-side filtering from CLI `tokens list` command
- Add `TestTokensFilterExpired` test
Fixes coder/internal#1357
This commit is contained in:
+2
-15
@@ -241,26 +241,13 @@ func (r *RootCmd) listTokens() *serpent.Command {
|
||||
}
|
||||
|
||||
tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
|
||||
IncludeAll: all,
|
||||
IncludeAll: all,
|
||||
IncludeExpired: includeExpired,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("list tokens: %w", err)
|
||||
}
|
||||
|
||||
// Filter out expired tokens unless --include-expired is set
|
||||
// TODO(Cian): This _could_ get too big for client-side filtering.
|
||||
// If it causes issues, we can filter server-side.
|
||||
if !includeExpired {
|
||||
now := time.Now()
|
||||
filtered := make([]codersdk.APIKeyWithOwner, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
if token.ExpiresAt.After(now) {
|
||||
filtered = append(filtered, token)
|
||||
}
|
||||
}
|
||||
tokens = filtered
|
||||
}
|
||||
|
||||
displayTokens = make([]tokenListRow, len(tokens))
|
||||
|
||||
for i, token := range tokens {
|
||||
|
||||
Generated
+6
@@ -8238,6 +8238,12 @@ const docTemplate = `{
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "boolean",
|
||||
"description": "Include expired tokens in the list",
|
||||
"name": "include_expired",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
|
||||
Generated
+6
@@ -7285,6 +7285,12 @@
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "boolean",
|
||||
"description": "Include expired tokens in the list",
|
||||
"name": "include_expired",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
|
||||
+14
-8
@@ -307,20 +307,26 @@ func (api *API) apiKeyByName(rw http.ResponseWriter, r *http.Request) {
|
||||
// @Tags Users
|
||||
// @Param user path string true "User ID, name, or me"
|
||||
// @Success 200 {array} codersdk.APIKey
|
||||
// @Param include_expired query bool false "Include expired tokens in the list"
|
||||
// @Router /users/{user}/keys/tokens [get]
|
||||
func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
keys []database.APIKey
|
||||
err error
|
||||
queryStr = r.URL.Query().Get("include_all")
|
||||
includeAll, _ = strconv.ParseBool(queryStr)
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
keys []database.APIKey
|
||||
err error
|
||||
queryStr = r.URL.Query().Get("include_all")
|
||||
includeAll, _ = strconv.ParseBool(queryStr)
|
||||
expiredStr = r.URL.Query().Get("include_expired")
|
||||
includeExpired, _ = strconv.ParseBool(expiredStr)
|
||||
)
|
||||
|
||||
if includeAll {
|
||||
// get tokens for all users
|
||||
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.LoginTypeToken)
|
||||
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.GetAPIKeysByLoginTypeParams{
|
||||
LoginType: database.LoginTypeToken,
|
||||
IncludeExpired: includeExpired,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching API keys.",
|
||||
@@ -330,7 +336,7 @@ func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
} else {
|
||||
// get user's tokens only
|
||||
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID})
|
||||
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID, IncludeExpired: includeExpired})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching API keys.",
|
||||
|
||||
@@ -69,6 +69,44 @@ func TestTokenCRUD(t *testing.T) {
|
||||
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[numLogs-1].Action)
|
||||
}
|
||||
|
||||
func TestTokensFilterExpired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
adminClient := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
// Create a token.
|
||||
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
|
||||
Lifetime: time.Hour * 24 * 7,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keyID := strings.Split(res.Key, "-")[0]
|
||||
|
||||
// List tokens without including expired - should see the token.
|
||||
keys, err := adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
|
||||
// Expire the token.
|
||||
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List tokens without including expired - should NOT see expired token.
|
||||
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, keys)
|
||||
|
||||
// List tokens WITH including expired - should see expired token.
|
||||
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{
|
||||
IncludeExpired: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, keyID, keys[0].ID)
|
||||
}
|
||||
|
||||
func TestTokenScoped(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -2194,12 +2194,12 @@ func (q *querier) GetAPIKeyByName(ctx context.Context, arg database.GetAPIKeyByN
|
||||
return fetch(q.log, q.auth, q.db.GetAPIKeyByName)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByLoginType)(ctx, loginType)
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysByUserID(ctx context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, database.GetAPIKeysByUserIDParams{LoginType: params.LoginType, UserID: params.UserID})
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, params)
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]database.APIKey, error) {
|
||||
|
||||
@@ -237,8 +237,8 @@ func (s *MethodTestSuite) TestAPIKey() {
|
||||
s.Run("GetAPIKeysByLoginType", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
a := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
|
||||
b := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
|
||||
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.LoginTypePassword).Return([]database.APIKey{a, b}, nil).AnyTimes()
|
||||
check.Args(database.LoginTypePassword).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
|
||||
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Return([]database.APIKey{a, b}, nil).AnyTimes()
|
||||
check.Args(database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
|
||||
}))
|
||||
s.Run("GetAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u1 := testutil.Fake(s.T(), faker, database.User{})
|
||||
|
||||
@@ -774,7 +774,7 @@ func (m queryMetricsStore) GetAPIKeyByName(ctx context.Context, arg database.Get
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAPIKeysByLoginType(ctx, loginType)
|
||||
m.queryLatencies.WithLabelValues("GetAPIKeysByLoginType").Observe(time.Since(start).Seconds())
|
||||
|
||||
@@ -1305,18 +1305,18 @@ func (mr *MockStoreMockRecorder) GetAPIKeyByName(ctx, arg any) *gomock.Call {
|
||||
}
|
||||
|
||||
// GetAPIKeysByLoginType mocks base method.
|
||||
func (m *MockStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
func (m *MockStore) GetAPIKeysByLoginType(ctx context.Context, arg database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAPIKeysByLoginType", ctx, loginType)
|
||||
ret := m.ctrl.Call(m, "GetAPIKeysByLoginType", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.APIKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAPIKeysByLoginType indicates an expected call of GetAPIKeysByLoginType.
|
||||
func (mr *MockStoreMockRecorder) GetAPIKeysByLoginType(ctx, loginType any) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) GetAPIKeysByLoginType(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysByLoginType", reflect.TypeOf((*MockStore)(nil).GetAPIKeysByLoginType), ctx, loginType)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysByLoginType", reflect.TypeOf((*MockStore)(nil).GetAPIKeysByLoginType), ctx, arg)
|
||||
}
|
||||
|
||||
// GetAPIKeysByUserID mocks base method.
|
||||
|
||||
@@ -169,7 +169,7 @@ type sqlcQuerier interface {
|
||||
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
|
||||
// there is no unique constraint on empty token names
|
||||
GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error)
|
||||
GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error)
|
||||
GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error)
|
||||
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
|
||||
|
||||
@@ -8195,8 +8195,9 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
|
||||
// All keys are present before deletion
|
||||
keys, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
IncludeExpired: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, len(expiredTimes)+len(unexpiredTimes))
|
||||
@@ -8212,8 +8213,9 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
|
||||
// Ensure it was deleted
|
||||
remaining, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
IncludeExpired: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, remaining, len(expiredTimes)+len(unexpiredTimes)-1)
|
||||
@@ -8228,8 +8230,9 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
|
||||
// Ensure only unexpired keys remain
|
||||
remaining, err = db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
IncludeExpired: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, remaining, len(unexpiredTimes))
|
||||
|
||||
@@ -1270,10 +1270,16 @@ func (q *sqlQuerier) GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNamePar
|
||||
|
||||
const getAPIKeysByLoginType = `-- name: GetAPIKeysByLoginType :many
|
||||
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, token_name, scopes, allow_list FROM api_keys WHERE login_type = $1
|
||||
AND ($2::bool OR expires_at > now())
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByLoginType, loginType)
|
||||
type GetAPIKeysByLoginTypeParams struct {
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
IncludeExpired bool `db:"include_expired" json:"include_expired"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByLoginType, arg.LoginType, arg.IncludeExpired)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1311,15 +1317,17 @@ func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, loginType LoginT
|
||||
|
||||
const getAPIKeysByUserID = `-- name: GetAPIKeysByUserID :many
|
||||
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, token_name, scopes, allow_list FROM api_keys WHERE login_type = $1 AND user_id = $2
|
||||
AND ($3::bool OR expires_at > now())
|
||||
`
|
||||
|
||||
type GetAPIKeysByUserIDParams struct {
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
IncludeExpired bool `db:"include_expired" json:"include_expired"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByUserID, arg.LoginType, arg.UserID)
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByUserID, arg.LoginType, arg.UserID, arg.IncludeExpired)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -25,10 +25,12 @@ LIMIT
|
||||
SELECT * FROM api_keys WHERE last_used > $1;
|
||||
|
||||
-- name: GetAPIKeysByLoginType :many
|
||||
SELECT * FROM api_keys WHERE login_type = $1;
|
||||
SELECT * FROM api_keys WHERE login_type = $1
|
||||
AND (@include_expired::bool OR expires_at > now());
|
||||
|
||||
-- name: GetAPIKeysByUserID :many
|
||||
SELECT * FROM api_keys WHERE login_type = $1 AND user_id = $2;
|
||||
SELECT * FROM api_keys WHERE login_type = $1 AND user_id = $2
|
||||
AND (@include_expired::bool OR expires_at > now());
|
||||
|
||||
-- name: InsertAPIKey :one
|
||||
INSERT INTO
|
||||
|
||||
+3
-1
@@ -94,7 +94,8 @@ func (c *Client) CreateAPIKey(ctx context.Context, user string) (GenerateAPIKeyR
|
||||
}
|
||||
|
||||
type TokensFilter struct {
|
||||
IncludeAll bool `json:"include_all"`
|
||||
IncludeAll bool `json:"include_all"`
|
||||
IncludeExpired bool `json:"include_expired"`
|
||||
}
|
||||
|
||||
type APIKeyWithOwner struct {
|
||||
@@ -112,6 +113,7 @@ func (f TokensFilter) asRequestOption() RequestOption {
|
||||
return func(r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
q.Set("include_all", fmt.Sprintf("%t", f.IncludeAll))
|
||||
q.Set("include_expired", fmt.Sprintf("%t", f.IncludeExpired))
|
||||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,8 +105,13 @@ expires the token, (soft-delete):
|
||||
coder tokens remove <name|id>
|
||||
```
|
||||
|
||||
Expired tokens can no longer be used for authentication but remain visible in
|
||||
token listings.
|
||||
Expired tokens can no longer be used for authentication and are hidden from
|
||||
token listings by default. To include expired tokens, use the
|
||||
`--include-expired` flag:
|
||||
|
||||
```console
|
||||
coder tokens list --include-expired
|
||||
```
|
||||
|
||||
To hard-delete a token, use the `--delete` flag:
|
||||
|
||||
|
||||
Generated
+4
-3
@@ -746,9 +746,10 @@ curl -X GET http://coder-server:8080/api/v2/users/{user}/keys/tokens \
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|--------|------|--------|----------|----------------------|
|
||||
| `user` | path | string | true | User ID, name, or me |
|
||||
| Name | In | Type | Required | Description |
|
||||
|-------------------|-------|---------|----------|------------------------------------|
|
||||
| `user` | path | string | true | User ID, name, or me |
|
||||
| `include_expired` | query | boolean | false | Include expired tokens in the list |
|
||||
|
||||
### Example responses
|
||||
|
||||
|
||||
Generated
+1
@@ -5558,6 +5558,7 @@ export interface TokenConfig {
|
||||
// From codersdk/apikey.go
|
||||
export interface TokensFilter {
|
||||
readonly include_all: boolean;
|
||||
readonly include_expired: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/deployment.go
|
||||
|
||||
@@ -27,6 +27,7 @@ const TokensPage: FC = () => {
|
||||
// we currently do not show all tokens in the UI, even if
|
||||
// the user has read all permissions
|
||||
include_all: false,
|
||||
include_expired: false,
|
||||
});
|
||||
|
||||
return (
|
||||
|
||||
@@ -8,11 +8,14 @@ import {
|
||||
} from "react-query";
|
||||
|
||||
// Load all tokens
|
||||
export const useTokensData = ({ include_all }: TokensFilter) => {
|
||||
const queryKey = ["tokens", include_all];
|
||||
export const useTokensData = ({
|
||||
include_all,
|
||||
include_expired,
|
||||
}: TokensFilter) => {
|
||||
const queryKey = ["tokens", include_all, include_expired];
|
||||
const result = useQuery({
|
||||
queryKey,
|
||||
queryFn: () => API.getTokens({ include_all }),
|
||||
queryFn: () => API.getTokens({ include_all, include_expired }),
|
||||
});
|
||||
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user