Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions access_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"context"
"errors"
"fmt"
"net/http"
"strings"

"crypto/rand"
"encoding/binary"

"github.com/go-chi/transport"
"github.com/goware/base64"
"github.com/jxskiss/base62"
)
Expand Down Expand Up @@ -57,6 +59,18 @@ func GetAccessKeyPrefix(accessKey string) string {
return strings.Join(parts[:len(parts)-1], Separator)
}

func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper {
return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) {
r := transport.CloneRequest(req)

if accessKey, ok := GetAccessKey(req.Context()); ok {
r.Header.Set(HeaderAccessKey, accessKey)
}

return next.RoundTrip(r)
})
}

type Encoding interface {
Version() byte
Encode(ctx context.Context, projectID uint64) string
Expand Down
33 changes: 33 additions & 0 deletions access_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package authcontrol_test

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/0xsequence/authcontrol"
Expand Down Expand Up @@ -57,3 +59,34 @@ func TestDecode(t *testing.T) {
accessKey := authcontrol.GenerateAccessKey(ctx, 237)
t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey))
}

func TestForwardAccessKeyTransport(t *testing.T) {
// Create a test server that captures the request headers
var capturedHeaders http.Header
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

// Create context with access key
accessKey := "test-access-key-123"
ctx := authcontrol.WithAccessKey(context.Background(), accessKey)

// Create HTTP client with ForwardAccessKeyTransport
client := &http.Client{
Transport: authcontrol.ForwardAccessKeyTransport(http.DefaultTransport),
}

// Create request with the context
req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
require.NoError(t, err)

// Make the request
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify the access key header was set
require.Equal(t, accessKey, capturedHeaders.Get(authcontrol.HeaderAccessKey))
}