From 44f7a3ca90e856fa800a79814acfaaacbcda3754 Mon Sep 17 00:00:00 2001 From: Omar Jarjur Date: Wed, 3 Dec 2025 13:18:39 -0800 Subject: [PATCH 1/5] [agent] Add a test for the websocket-shim protocol handlers --- agent/websockets/websockets_test.go | 155 ++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) diff --git a/agent/websockets/websockets_test.go b/agent/websockets/websockets_test.go index b2310c2..d6bca59 100644 --- a/agent/websockets/websockets_test.go +++ b/agent/websockets/websockets_test.go @@ -23,6 +23,8 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/gorilla/websocket" + + "github.com/google/inverting-proxy/agent/metrics" ) func TestInjectWebsocketMessage(t *testing.T) { @@ -184,3 +186,156 @@ func TestInjectWebsocketMessage(t *testing.T) { }) } } + +func TestShimHandlers(t *testing.T) { + testShimPath := "/websocket-shim" + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !websocket.IsWebSocketUpgrade(r) { + http.Error(w, "only websocket connections are supported", http.StatusBadRequest) + } + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, http.Header{}) + if err != nil { + t.Fatalf("Failure upgrading the websocket connection: %+v", err) + } + defer conn.Close() + for { + messageType, msg, err := conn.ReadMessage() + var closeError *websocket.CloseError + if errors.As(err, &closeError) && closeError.Code == websocket.CloseNormalClosure { + return + } + if err != nil { + t.Logf("Error returned from ReadMessage: %+v", err) + return + } + if err := conn.WriteMessage(messageType, msg); err != nil { + t.Logf("Error returned from WriteMessage: %+v", err) + return + } + } + }) + server := httptest.NewServer(h) + defer server.Close() + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatal(err) + } + openWrapper := func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler { + return h + } + p, err := Proxy(t.Context(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil) + if err != nil { + t.Fatalf("Failure creating the websocket shim proxy: %+v", err) + } + proxyServer := httptest.NewServer(p) + defer proxyServer.Close() + + openConnection := func() (string, error) { + openResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "open"), "", strings.NewReader(server.URL)) + if err != nil { + return "", fmt.Errorf("failure opening the shimmed websocket connection: %+v", err) + } + defer openResp.Body.Close() + respBytes, err := io.ReadAll(openResp.Body) + if err != nil { + return "", fmt.Errorf("failure reading the open response: %+v", err) + } + var openedSession sessionMessage + if err := json.Unmarshal(respBytes, &openedSession); err != nil { + return "", fmt.Errorf("failure parsing the open response: %+v", err) + } + return openedSession.ID, nil + } + + closeConnection := func(sessionID string) error { + closeMessage := &sessionMessage{ + ID: sessionID, + } + closeBytes, err := json.Marshal(closeMessage) + if err != nil { + return fmt.Errorf("failure marshalling the close session message: %+v", err) + } + closeResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "close"), "", bytes.NewReader(closeBytes)) + if err != nil { + return fmt.Errorf("failure closing the shimmed websocket connection: %+v", err) + } + defer closeResp.Body.Close() + if got, want := closeResp.StatusCode, http.StatusOK; got != want { + return fmt.Errorf("unexpected close response status code: got %v, want %v", got, want) + } + return nil + } + + sendMessage := func(sessionID, message string) error { + dataMessages := []*sessionMessage{ + &sessionMessage{ + ID: sessionID, + Message: message, + }, + } + dataBytes, err := json.Marshal(dataMessages) + if err != nil { + return fmt.Errorf("failure marshalling the data session messages: %+v", err) + } + dataResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "data"), "", bytes.NewReader(dataBytes)) + if err != nil { + return fmt.Errorf("failure writing to the shimmed websocket connection: %+v", err) + } + defer dataResp.Body.Close() + if got, want := dataResp.StatusCode, http.StatusOK; got != want { + return fmt.Errorf("unexpected data response status code: got %v, want %v", got, want) + } + return nil + } + + readMessage := func(sessionID string) (string, error) { + pollMessage := &sessionMessage{ + ID: sessionID, + } + pollBytes, err := json.Marshal(pollMessage) + if err != nil { + return "", fmt.Errorf("failure marshalling the poll session messages: %+v", err) + } + pollResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "poll"), "", bytes.NewReader(pollBytes)) + if err != nil { + return "", fmt.Errorf("failure reading from the shimmed websocket connection: %+v", err) + } + defer pollResp.Body.Close() + if got, want := pollResp.StatusCode, http.StatusOK; got != want { + return "", fmt.Errorf("unexpected poll response status code: got %v, want %v", got, want) + } + respBytes, err := io.ReadAll(pollResp.Body) + if err != nil { + return "", fmt.Errorf("failure reading the poll response: %+v", err) + } + var readMessages []string + if err := json.Unmarshal(respBytes, &readMessages); err != nil { + return "", fmt.Errorf("failure parsing the poll response: %+v", err) + } + if got, want := len(readMessages), 1; got != want { + return "", fmt.Errorf("unexpected number of websocket messages read; got %d, want %d", got, want) + } + return readMessages[0], nil + } + + for i := 0; i < 100; i++ { + sessionID, err := openConnection() + if err != nil { + t.Errorf("Failure opening the connection: %v", err) + } + for j := 0; j < 100; j++ { + msg := fmt.Sprintf("connection #%d, message #%d", i, j) + if err := sendMessage(sessionID, msg); err != nil { + t.Errorf("Failure sending message %q: %v", msg, err) + } else if readMsg, err := readMessage(sessionID); err != nil { + t.Errorf("Failure reading back the message %q: %v", msg, err) + } else if got, want := readMsg, msg; got != want { + t.Errorf("Unexpected message read back; got %q, want %q", got, want) + } + } + if err := closeConnection(sessionID); err != nil { + t.Errorf("Failure closing the connection: %v", err) + } + } +} From 85730e3498e9f157c1161632e7db862f150dc586 Mon Sep 17 00:00:00 2001 From: Omar Jarjur Date: Wed, 3 Dec 2025 13:22:17 -0800 Subject: [PATCH 2/5] [agent] Fix the missing imports from websockets_test.go --- agent/websockets/websockets_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/agent/websockets/websockets_test.go b/agent/websockets/websockets_test.go index d6bca59..6c4e637 100644 --- a/agent/websockets/websockets_test.go +++ b/agent/websockets/websockets_test.go @@ -17,7 +17,16 @@ limitations under the License. package websockets import ( + "bytes" "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "path" + "strings" "testing" "github.com/google/go-cmp/cmp" From 8a3e62f9a3043349f23bc8a5a5aaeddf2ea29d5d Mon Sep 17 00:00:00 2001 From: Omar Jarjur Date: Wed, 3 Dec 2025 13:26:34 -0800 Subject: [PATCH 3/5] [agent] Fix the websockets_test.go test to not require the latest version of Go --- agent/websockets/websockets_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/websockets/websockets_test.go b/agent/websockets/websockets_test.go index 6c4e637..613be2f 100644 --- a/agent/websockets/websockets_test.go +++ b/agent/websockets/websockets_test.go @@ -233,7 +233,7 @@ func TestShimHandlers(t *testing.T) { openWrapper := func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler { return h } - p, err := Proxy(t.Context(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil) + p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil) if err != nil { t.Fatalf("Failure creating the websocket shim proxy: %+v", err) } From adfc3e76d0ec0dfb32a52abf4dbddc24f5f5d5d5 Mon Sep 17 00:00:00 2001 From: Omar Jarjur Date: Wed, 3 Dec 2025 13:28:31 -0800 Subject: [PATCH 4/5] [agent] Fix missing import in websockets_test.go --- agent/websockets/websockets_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/agent/websockets/websockets_test.go b/agent/websockets/websockets_test.go index 613be2f..1eefa8f 100644 --- a/agent/websockets/websockets_test.go +++ b/agent/websockets/websockets_test.go @@ -18,6 +18,7 @@ package websockets import ( "bytes" + "context" "encoding/json" "errors" "fmt" From 5ed933ab681b518d535e122bfd42e6552e7114ac Mon Sep 17 00:00:00 2001 From: Omar Jarjur Date: Fri, 19 Dec 2025 17:36:39 -0800 Subject: [PATCH 5/5] [agent] Ensure the websockets test does not attempt to log after the test completes --- agent/websockets/websockets_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/agent/websockets/websockets_test.go b/agent/websockets/websockets_test.go index 1eefa8f..0d674a3 100644 --- a/agent/websockets/websockets_test.go +++ b/agent/websockets/websockets_test.go @@ -28,6 +28,7 @@ import ( "net/url" "path" "strings" + "sync" "testing" "github.com/google/go-cmp/cmp" @@ -199,7 +200,11 @@ func TestInjectWebsocketMessage(t *testing.T) { func TestShimHandlers(t *testing.T) { testShimPath := "/websocket-shim" + var wg sync.WaitGroup + defer wg.Wait() h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wg.Add(1) + defer wg.Done() if !websocket.IsWebSocketUpgrade(r) { http.Error(w, "only websocket connections are supported", http.StatusBadRequest) }