diff --git a/.github/workflows/precommit.yml b/.github/workflows/precommit.yml index 1f0e667..7500c9c 100644 --- a/.github/workflows/precommit.yml +++ b/.github/workflows/precommit.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/setup-python@v2 - uses: actions/setup-go@v2 with: - go-version: '1.17' + go-version: '1.24' - name: Install goimports run: go install golang.org/x/tools/cmd/goimports@latest - - uses: pre-commit/action@v2.0.0 + - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 394387d..57343ef 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -18,7 +18,7 @@ jobs: name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.17 + go-version: 1.24 - name: Login to Public ECR env: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f44bf91..0dc6791 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/dnephin/pre-commit-golang - rev: master + rev: v0.5.1 hooks: - id: go-fmt - id: go-vet diff --git a/cmd/credential_process.go b/cmd/credential_process.go index 0bc417b..96ff346 100644 --- a/cmd/credential_process.go +++ b/cmd/credential_process.go @@ -156,6 +156,6 @@ func printCredentialProcess(credentials *aws.Credentials) error { logging.LogError(err, "Error parsing credential response") return err } - fmt.Printf(string(b)) + fmt.Print(string(b)) return nil } diff --git a/go.mod b/go.mod index 50a4c80..e9e6a70 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/netflix/weep -go 1.17 +go 1.24 require ( github.com/aws/aws-sdk-go v1.40.39 diff --git a/pkg/config/config.go b/pkg/config/config.go index f0699b6..f30416a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -73,6 +73,7 @@ func getDefaultLogFile() string { // - /etc/weep/weep.yaml // - ~/.weep/weep.yaml // - ./weep.yaml +// // If a config file is specified via CLI arg, it will be read exclusively and not merged with other // configuration. func InitConfig(filename string) error { diff --git a/pkg/creds/consoleme.go b/pkg/creds/consoleme.go index d5f7ee4..be706f3 100644 --- a/pkg/creds/consoleme.go +++ b/pkg/creds/consoleme.go @@ -273,7 +273,7 @@ func parseWebError(rawErrorResponse []byte) error { if err := json.Unmarshal(rawErrorResponse, &errorResponse); err != nil { return errors.Wrap(err, "failed to unmarshal JSON") } - return fmt.Errorf(strings.Join(errorResponse.Errors, "\n")) + return errors.New(strings.Join(errorResponse.Errors, "\n")) } func parseError(statusCode int, rawErrorResponse []byte) error { diff --git a/pkg/creds/refreshable.go b/pkg/creds/refreshable.go index 363f498..3b47fbb 100644 --- a/pkg/creds/refreshable.go +++ b/pkg/creds/refreshable.go @@ -17,7 +17,7 @@ package creds import ( - "fmt" + stdErrors "errors" "strings" "time" @@ -107,7 +107,7 @@ func (rp *RefreshableProvider) refresh() error { // The http.Client, with the best of intentions, will hold the connection open, // meaning that an auto-updated cert won't be used by the client. rp.client.CloseIdleConnections() - return fmt.Errorf(viper.GetString("mtls_settings.old_cert_message")) + return stdErrors.New(viper.GetString("mtls_settings.old_cert_message")) } else { return err } diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index ba23d64..5318485 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -23,6 +23,7 @@ func (e Error) Error() string { return string(e) } const ( NoCredentialsFoundInCache = Error("no credentials found in cache") NoTokenFoundInCache = Error("no token found in cache") + InvalidTokenFoundInCache = Error("token's cached attributes are invalid") NoDefaultRoleSet = Error("no default role set") BrowserOpenError = Error("could not launch browser, open link manually") CredentialRetrievalError = Error("failed to retrieve credentials from broker") diff --git a/pkg/session/cache.go b/pkg/session/cache.go index 340bec1..5dc9279 100644 --- a/pkg/session/cache.go +++ b/pkg/session/cache.go @@ -14,8 +14,7 @@ import ( const letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-" type tokenCache struct { - sync.RWMutex - TokenMap + sync.Map } type tokenAttributes struct { @@ -40,9 +39,7 @@ func randomString(n int) string { } func createCache() *tokenCache { - c := &tokenCache{ - TokenMap: make(map[string]*tokenAttributes), - } + c := &tokenCache{} go c.startWatcher() return c } @@ -59,12 +56,18 @@ func (c *tokenCache) startWatcher() { } func (c *tokenCache) clean() { - for token, attr := range c.TokenMap { - if attr.Expiration.Before(time.Now()) { - logging.Log.Debugf("deleting token with expiration %v", attr.Expiration) - c.delete(token) + c.Range(func(key, value interface{}) bool { + if attr, ok := value.(*tokenAttributes); ok { + if attr.Expiration.Before(time.Now()) { + logging.Log.Debugf("deleting token with expiration %v", attr.Expiration) + c.Delete(key) + } + } else { + logging.Log.Debugf("deleting token with invalid attributes %v", value) + c.Delete(key) } - } + return true + }) } func (c *tokenCache) generateToken(role string, ttlSeconds int) string { @@ -74,27 +77,19 @@ func (c *tokenCache) generateToken(role string, ttlSeconds int) string { } func (c *tokenCache) checkToken(token string) (bool, int) { - attr, err := sessions.Get(token) + attr, err := c.Get(token) if err != nil { - logging.Log.Warning("invalid session token") + logging.Log.Warningf("invalid session token: %v", err) return false, 0 } if attr.Expiration.Before(time.Now()) { logging.Log.Warning("session token is expired") return false, 0 } - remainingTtl := time.Now().Sub(attr.Expiration) + remainingTtl := attr.Expiration.Sub(time.Now()) return true, int(remainingTtl.Seconds()) } -func (c *tokenCache) delete(token string) { - c.Lock() - defer c.Unlock() - if _, ok := c.TokenMap[token]; ok { - delete(c.TokenMap, token) - } -} - func (c *tokenCache) Set(token, role string, ttl int) { expiration := time.Now().Add(time.Duration(ttl) * time.Second) attr := tokenAttributes{ @@ -102,17 +97,19 @@ func (c *tokenCache) Set(token, role string, ttl int) { Expiration: expiration, Role: role, } - c.Lock() - defer c.Unlock() - c.TokenMap[token] = &attr + c.Store(token, &attr) } func (c *tokenCache) Get(token string) (*tokenAttributes, error) { - c.RLock() - defer c.RUnlock() - attr, ok := c.TokenMap[token] + var value interface{} + var ok bool + var attr *tokenAttributes + value, ok = c.Load(token) if !ok { return nil, errors.NoTokenFoundInCache } - return attr, nil + if attr, ok = value.(*tokenAttributes); ok { + return attr, nil + } + return nil, errors.InvalidTokenFoundInCache } diff --git a/pkg/session/cache_test.go b/pkg/session/cache_test.go new file mode 100644 index 0000000..850ba97 --- /dev/null +++ b/pkg/session/cache_test.go @@ -0,0 +1,388 @@ +package session + +import ( + "testing" + "time" +) + +func TestRandomString(t *testing.T) { + t.Run("generates string of correct length", func(t *testing.T) { + length := 64 + result := randomString(length) + if len(result) != length { + t.Errorf("expected length %d, got %d", length, len(result)) + } + }) + + t.Run("generates different strings on multiple calls", func(t *testing.T) { + result1 := randomString(64) + result2 := randomString(64) + if result1 == result2 { + t.Error("expected different strings but got the same") + } + }) + + t.Run("only contains valid characters", func(t *testing.T) { + result := randomString(100) + for _, char := range result { + found := false + for _, validChar := range letters { + if char == validChar { + found = true + break + } + } + if !found { + t.Errorf("invalid character %c in random string", char) + } + } + }) + + t.Run("handles zero length", func(t *testing.T) { + result := randomString(0) + if len(result) != 0 { + t.Errorf("expected empty string, got length %d", len(result)) + } + }) +} + +func TestCreateCache(t *testing.T) { + t.Run("creates a new cache", func(t *testing.T) { + cache := createCache() + if cache == nil { + t.Error("expected non-nil cache") + } + }) +} + +func TestTokenCache_Set(t *testing.T) { + t.Run("sets a token with attributes", func(t *testing.T) { + cache := &tokenCache{} + token := "test-token" + role := "test-role" + ttl := 3600 + + cache.Set(token, role, ttl) + + attr, err := cache.Get(token) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if attr.Role != role { + t.Errorf("expected role %s, got %s", role, attr.Role) + } + if attr.InitialTtl != ttl { + t.Errorf("expected ttl %d, got %d", ttl, attr.InitialTtl) + } + if !attr.Expiration.After(time.Now()) { + t.Error("expected expiration to be in the future") + } + }) + + t.Run("sets expiration time correctly", func(t *testing.T) { + cache := &tokenCache{} + token := "test-token" + role := "test-role" + ttl := 10 + + before := time.Now() + cache.Set(token, role, ttl) + after := time.Now() + + attr, err := cache.Get(token) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + expectedMin := before.Add(time.Duration(ttl) * time.Second) + expectedMax := after.Add(time.Duration(ttl) * time.Second) + + if attr.Expiration.Before(expectedMin) { + t.Errorf("expiration %v is before expected minimum %v", attr.Expiration, expectedMin) + } + if attr.Expiration.After(expectedMax) { + t.Errorf("expiration %v is after expected maximum %v", attr.Expiration, expectedMax) + } + }) +} + +func TestTokenCache_Get(t *testing.T) { + t.Run("retrieves existing token", func(t *testing.T) { + cache := &tokenCache{} + token := "test-token" + role := "test-role" + ttl := 3600 + + cache.Set(token, role, ttl) + + attr, err := cache.Get(token) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if attr == nil { + t.Error("expected non-nil attributes") + } + if attr.Role != role { + t.Errorf("expected role %s, got %s", role, attr.Role) + } + if attr.InitialTtl != ttl { + t.Errorf("expected ttl %d, got %d", ttl, attr.InitialTtl) + } + }) + + t.Run("returns error for non-existent token", func(t *testing.T) { + cache := &tokenCache{} + + attr, err := cache.Get("non-existent-token") + if err == nil { + t.Error("expected error for non-existent token") + } + if attr != nil { + t.Error("expected nil attributes for non-existent token") + } + }) + + t.Run("handles concurrent access", func(t *testing.T) { + cache := &tokenCache{} + token := "test-token" + role := "test-role" + ttl := 3600 + + cache.Set(token, role, ttl) + + done := make(chan bool) + errors := make(chan error, 10) + + for i := 0; i < 10; i++ { + go func() { + attr, err := cache.Get(token) + if err != nil { + errors <- err + } else if attr.Role != role { + errors <- err + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + <-done + } + + close(errors) + for err := range errors { + if err != nil { + t.Errorf("concurrent access error: %v", err) + } + } + }) +} + +func TestTokenCache_GenerateToken(t *testing.T) { + t.Run("generates token with correct length", func(t *testing.T) { + cache := &tokenCache{} + role := "test-role" + ttl := 3600 + + token := cache.generateToken(role, ttl) + if len(token) != 64 { + t.Errorf("expected token length 64, got %d", len(token)) + } + }) + + t.Run("stores token in cache", func(t *testing.T) { + cache := &tokenCache{} + role := "test-role" + ttl := 3600 + + token := cache.generateToken(role, ttl) + + attr, err := cache.Get(token) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if attr.Role != role { + t.Errorf("expected role %s, got %s", role, attr.Role) + } + if attr.InitialTtl != ttl { + t.Errorf("expected ttl %d, got %d", ttl, attr.InitialTtl) + } + }) + + t.Run("generates unique tokens", func(t *testing.T) { + cache := &tokenCache{} + role := "test-role" + ttl := 3600 + + token1 := cache.generateToken(role, ttl) + token2 := cache.generateToken(role, ttl) + + if token1 == token2 { + t.Error("expected unique tokens but got duplicates") + } + }) +} + +func TestTokenCache_CheckToken(t *testing.T) { + t.Run("validates existing valid token", func(t *testing.T) { + cache := &tokenCache{} + token := "test-token" + role := "test-role" + ttl := 3600 + + cache.Set(token, role, ttl) + + valid, remainingTtl := cache.checkToken(token) + if !valid { + t.Error("expected token to be valid") + } + if remainingTtl < 3595 || remainingTtl > 3600 { + t.Errorf("expected remaining TTL to be slightly less than 3600, got %d", remainingTtl) + } + }) + + t.Run("rejects non-existent token", func(t *testing.T) { + cache := &tokenCache{} + + valid, remainingTtl := cache.checkToken("non-existent-token") + if valid { + t.Error("expected token to be invalid") + } + if remainingTtl != 0 { + t.Errorf("expected remainingTtl 0, got %d", remainingTtl) + } + }) + + t.Run("rejects expired token", func(t *testing.T) { + cache := &tokenCache{} + token := "test-token" + role := "test-role" + ttl := -1 // Expired immediately + + cache.Set(token, role, ttl) + time.Sleep(10 * time.Millisecond) // Ensure it's expired + + valid, remainingTtl := cache.checkToken(token) + if valid { + t.Error("expected token to be invalid") + } + if remainingTtl != 0 { + t.Errorf("expected remainingTtl 0, got %d", remainingTtl) + } + }) +} + +func TestTokenCache_Clean(t *testing.T) { + t.Run("removes expired tokens", func(t *testing.T) { + cache := &tokenCache{} + expiredToken := "expired-token" + validToken := "valid-token" + + cache.Set(expiredToken, "role1", -1) // Already expired + cache.Set(validToken, "role2", 3600) // Valid for an hour + + time.Sleep(10 * time.Millisecond) // Ensure expiration check works + + cache.clean() + + // Expired token should be removed + _, err := cache.Get(expiredToken) + if err == nil { + t.Error("expected expired token to be removed") + } + + // Valid token should still exist + attr, err := cache.Get(validToken) + if err != nil { + t.Errorf("unexpected error for valid token: %v", err) + } + if attr == nil { + t.Error("expected valid token to still exist") + } + }) + + t.Run("removes tokens with invalid attributes", func(t *testing.T) { + cache := &tokenCache{} + token := "test-token" + + // Store invalid data + cache.Store(token, "not-a-pointer") + + cache.clean() + + // Token should be removed + _, err := cache.Get(token) + if err == nil { + t.Error("expected token with invalid attributes to be removed") + } + }) + + t.Run("handles empty cache", func(t *testing.T) { + cache := &tokenCache{} + + // Should not panic + defer func() { + if r := recover(); r != nil { + t.Errorf("clean panicked on empty cache: %v", r) + } + }() + + cache.clean() + }) +} + +func TestTokenCache_Concurrency(t *testing.T) { + t.Run("handles concurrent Set and Get operations", func(t *testing.T) { + cache := &tokenCache{} + done := make(chan bool) + + // Concurrent writes + for i := 0; i < 10; i++ { + go func(id int) { + cache.Set("token-"+string(rune(id)), "role", 3600) + done <- true + }(i) + } + + // Wait for writes + for i := 0; i < 10; i++ { + <-done + } + + // Concurrent reads + for i := 0; i < 10; i++ { + go func(id int) { + cache.Get("token-" + string(rune(id))) + done <- true + }(i) + } + + // Wait for reads + for i := 0; i < 10; i++ { + <-done + } + }) + + t.Run("handles concurrent clean operations", func(t *testing.T) { + cache := &tokenCache{} + done := make(chan bool) + + // Add some tokens + for i := 0; i < 10; i++ { + cache.Set("token-"+string(rune(i)), "role", 3600) + } + + // Run clean concurrently + for i := 0; i < 5; i++ { + go func() { + cache.clean() + done <- true + }() + } + + // Wait for all cleans + for i := 0; i < 5; i++ { + <-done + } + }) +}