diff --git a/src/mod/dynamicproxy/modh2c/modh2c.go b/src/mod/dynamicproxy/modh2c/modh2c.go index 8bf005b..6023daa 100644 --- a/src/mod/dynamicproxy/modh2c/modh2c.go +++ b/src/mod/dynamicproxy/modh2c/modh2c.go @@ -28,11 +28,14 @@ func (h2c *H2CRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - req, err := http.NewRequestWithContext(ctx, req.Method, req.RequestURI, nil) + req, err := http.NewRequestWithContext(ctx, req.Method, req.URL.String(), req.Body) if err != nil { return nil, err } + // Copy headers + req.Header = req.Header.Clone() + tr := &http2.Transport{ AllowHTTP: true, DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { @@ -43,3 +46,20 @@ func (h2c *H2CRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return tr.RoundTrip(req) } + +func (h2c *H2CRoundTripper) CheckServerSupportsH2C(serverURL string) bool { + req, err := http.NewRequest("GET", serverURL, nil) + if err != nil { + return false + } + + tr := &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, addr) + }, + } + _, err = tr.RoundTrip(req) + return err == nil +} diff --git a/src/mod/dynamicproxy/modh2c/modh2c_test.go b/src/mod/dynamicproxy/modh2c/modh2c_test.go new file mode 100644 index 0000000..b9a5615 --- /dev/null +++ b/src/mod/dynamicproxy/modh2c/modh2c_test.go @@ -0,0 +1,83 @@ +package modh2c + +import ( + "net/http" + "net/http/httptest" + "testing" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +func TestH2CRoundTripper_RoundTrip(t *testing.T) { + // Create a test server that supports h2c + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.Proto != "HTTP/2.0" { + t.Errorf("Expected HTTP/2.0, got %s", r.Proto) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello, h2c!")) + }) + + server := httptest.NewServer(h2c.NewHandler(mux, &http2.Server{})) + defer server.Close() + + // Create the round tripper + rt := NewH2CRoundTripper() + + // Create a request + req, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Perform the round trip + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Check the response body + body := make([]byte, 1024) + n, err := resp.Body.Read(body) + if err != nil && err.Error() != "EOF" { + t.Fatalf("Failed to read body: %v", err) + } + if string(body[:n]) != "Hello, h2c!" { + t.Errorf("Expected 'Hello, h2c!', got '%s'", string(body[:n])) + } +} + +func TestH2CRoundTripper_CheckServerSupportsH2C(t *testing.T) { + // Test with h2c server + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + server := httptest.NewServer(h2c.NewHandler(mux, &http2.Server{})) + defer server.Close() + + rt := NewH2CRoundTripper() + supports := rt.CheckServerSupportsH2C(server.URL) + if !supports { + t.Error("Expected server to support h2c") + } + + // Test with non-h2c server (regular HTTP/1.1) + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server2.Close() + + supports2 := rt.CheckServerSupportsH2C(server2.URL) + if supports2 { + t.Error("Expected server to not support h2c") + } +}