diff --git a/access_key.go b/access_key.go index 33cf9bd..99a23ec 100644 --- a/access_key.go +++ b/access_key.go @@ -4,11 +4,13 @@ import ( "context" "errors" "fmt" + "net/http" "strings" "crypto/rand" "encoding/binary" + "github.com/go-chi/transport" "github.com/goware/base64" "github.com/jxskiss/base62" ) @@ -57,6 +59,18 @@ func GetAccessKeyPrefix(accessKey string) string { return strings.Join(parts[:len(parts)-1], Separator) } +func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper { + return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) { + r := transport.CloneRequest(req) + + if accessKey, ok := GetAccessKey(req.Context()); ok { + r.Header.Set(HeaderAccessKey, accessKey) + } + + return next.RoundTrip(r) + }) +} + type Encoding interface { Version() byte Encode(ctx context.Context, projectID uint64) string diff --git a/access_key_test.go b/access_key_test.go index 6f68b58..4f98249 100644 --- a/access_key_test.go +++ b/access_key_test.go @@ -2,6 +2,8 @@ package authcontrol_test import ( "context" + "net/http" + "net/http/httptest" "testing" "github.com/0xsequence/authcontrol" @@ -57,3 +59,34 @@ func TestDecode(t *testing.T) { accessKey := authcontrol.GenerateAccessKey(ctx, 237) t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey)) } + +func TestForwardAccessKeyTransport(t *testing.T) { + // Create a test server that captures the request headers + var capturedHeaders http.Header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create context with access key + accessKey := "test-access-key-123" + ctx := authcontrol.WithAccessKey(context.Background(), accessKey) + + // Create HTTP client with ForwardAccessKeyTransport + client := &http.Client{ + Transport: authcontrol.ForwardAccessKeyTransport(http.DefaultTransport), + } + + // Create request with the context + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + require.NoError(t, err) + + // Make the request + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify the access key header was set + require.Equal(t, accessKey, capturedHeaders.Get(authcontrol.HeaderAccessKey)) +}