diff --git a/src/go.mod b/src/go.mod index 3b0d5e4..804e62a 100644 --- a/src/go.mod +++ b/src/go.mod @@ -18,33 +18,16 @@ require ( github.com/microcosm-cc/bluemonday v1.0.26 github.com/monperrus/crawler-user-agents v1.1.0 github.com/shirou/gopsutil/v4 v4.25.1 + github.com/stretchr/testify v1.10.0 github.com/syndtr/goleveldb v1.0.0 golang.org/x/net v0.33.0 + golang.org/x/oauth2 v0.24.0 golang.org/x/text v0.21.0 ) require ( cloud.google.com/go/auth v0.13.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect - github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resourcegraph/armresourcegraph v0.9.0 // indirect - github.com/benbjohnson/clock v1.3.0 // indirect - github.com/ebitengine/purego v0.8.2 // indirect - github.com/go-ole/go-ole v1.2.6 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect - github.com/golang/snappy v0.0.1 // indirect - github.com/huaweicloud/huaweicloud-sdk-go-v3 v0.1.128 // indirect - github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect - github.com/peterhellberg/link v1.2.0 // indirect - github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect - github.com/shopspring/decimal v1.3.1 // indirect - github.com/tjfoc/gmsm v1.4.1 // indirect - github.com/vultr/govultr/v3 v3.9.1 // indirect - github.com/yusufpapurcu/wmi v1.2.4 // indirect - go.mongodb.org/mongo-driver v1.12.0 // indirect - golang.org/x/sys v0.28.0 // indirect -) - -require ( cloud.google.com/go/compute/metadata v0.6.0 // indirect github.com/AdamSLevy/jsonrpc2/v14 v14.1.0 // indirect github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect @@ -53,6 +36,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns v1.2.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.3.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resourcegraph/armresourcegraph v0.9.0 // indirect github.com/Azure/go-autorest v14.2.0+incompatible // indirect github.com/Azure/go-autorest/autorest v0.11.29 // indirect github.com/Azure/go-autorest/autorest/adal v0.9.22 // indirect @@ -82,6 +66,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.33.3 // indirect github.com/aws/smithy-go v1.22.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect + github.com/benbjohnson/clock v1.3.0 // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/cenkalti/backoff v2.2.1+incompatible // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect @@ -94,6 +79,7 @@ require ( github.com/dnsimple/dnsimple-go v1.7.0 // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/ebitengine/purego v0.8.2 // indirect github.com/fatih/structs v1.1.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.8.0 // indirect @@ -102,11 +88,14 @@ require ( github.com/go-jose/go-jose/v4 v4.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-resty/resty/v2 v2.16.2 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/goccy/go-json v0.10.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/golang/snappy v0.0.1 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect @@ -119,6 +108,7 @@ require ( github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/go-retryablehttp v0.7.7 // indirect + github.com/huaweicloud/huaweicloud-sdk-go-v3 v0.1.128 // indirect github.com/iij/doapi v0.0.0-20190504054126-0bbf12d6d7df // indirect github.com/infobloxopen/infoblox-go-client v1.1.1 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect @@ -155,29 +145,36 @@ require ( github.com/nzdjb/go-metaname v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect + github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect github.com/ovh/go-ovh v1.6.0 // indirect + github.com/peterhellberg/link v1.2.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/pquerna/otp v1.4.0 // indirect github.com/sacloud/api-client-go v0.2.10 // indirect github.com/sacloud/go-http v0.1.8 // indirect github.com/sacloud/iaas-api-go v1.14.0 // indirect github.com/sacloud/packages-go v0.0.10 // indirect github.com/scaleway/scaleway-sdk-go v1.0.0-beta.30 // indirect + github.com/shopspring/decimal v1.3.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/smartystreets/go-aws-auth v0.0.0-20180515143844-0c1422d1fdb9 // indirect github.com/softlayer/softlayer-go v1.1.7 // indirect github.com/softlayer/xmlrpc v0.0.0-20200409220501-5f089df7cb7e // indirect github.com/spf13/cast v1.6.0 // indirect - github.com/stretchr/testify v1.10.0 // indirect github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1065 // indirect github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.1065 // indirect + github.com/tjfoc/gmsm v1.4.1 // indirect github.com/transip/gotransip/v6 v6.26.0 // indirect github.com/ultradns/ultradns-go-sdk v1.8.0-20241010134910-243eeec // indirect github.com/vinyldns/go-vinyldns v0.9.16 // indirect + github.com/vultr/govultr/v3 v3.9.1 // indirect github.com/yandex-cloud/go-genproto v0.0.0-20241220122821-aeb3b05efd1c // indirect github.com/yandex-cloud/go-sdk v0.0.0-20241220131134-2393e243c134 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect + go.mongodb.org/mongo-driver v1.12.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect go.opentelemetry.io/otel v1.29.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.27.0 // indirect @@ -187,8 +184,8 @@ require ( go.uber.org/ratelimit v0.3.0 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/mod v0.22.0 // indirect - golang.org/x/oauth2 v0.24.0 // indirect golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.28.0 // indirect golang.org/x/time v0.8.0 // indirect golang.org/x/tools v0.28.0 // indirect google.golang.org/api v0.214.0 // indirect diff --git a/src/mod/auth/sso/forward/const.go b/src/mod/auth/sso/forward/const.go index 164ef79..a4e935d 100644 --- a/src/mod/auth/sso/forward/const.go +++ b/src/mod/auth/sso/forward/const.go @@ -11,6 +11,7 @@ const ( DatabaseKeyResponseHeaders = "responseHeaders" DatabaseKeyResponseClientHeaders = "responseClientHeaders" DatabaseKeyRequestHeaders = "requestHeaders" + DatabaseKeyRequestIncludedCookies = "requestIncludedCookies" DatabaseKeyRequestExcludedCookies = "requestExcludedCookies" HeaderXForwardedProto = "X-Forwarded-Proto" diff --git a/src/mod/auth/sso/forward/forward.go b/src/mod/auth/sso/forward/forward.go index 16a84c9..f2b94cf 100644 --- a/src/mod/auth/sso/forward/forward.go +++ b/src/mod/auth/sso/forward/forward.go @@ -3,7 +3,6 @@ package forward import ( "encoding/json" "io" - "net" "net/http" "strings" @@ -28,6 +27,10 @@ type AuthRouterOptions struct { // headers are copied. RequestHeaders []string + // RequestIncludedCookies is a list of cookie keys that if defined will be the only cookies sent in the request to + // the authorization server. + RequestIncludedCookies []string + // RequestExcludedCookies is a list of cookie keys that should be removed from every request sent to the upstream. RequestExcludedCookies []string @@ -47,16 +50,18 @@ func NewAuthRouter(options *AuthRouterOptions) *AuthRouter { //Read settings from database if available. options.Database.Read(DatabaseTable, DatabaseKeyAddress, &options.Address) - responseHeaders, responseClientHeaders, requestHeaders, requestExcludedCookies := "", "", "", "" + responseHeaders, responseClientHeaders, requestHeaders, requestIncludedCookies, requestExcludedCookies := "", "", "", "", "" options.Database.Read(DatabaseTable, DatabaseKeyResponseHeaders, &responseHeaders) options.Database.Read(DatabaseTable, DatabaseKeyResponseClientHeaders, &responseClientHeaders) options.Database.Read(DatabaseTable, DatabaseKeyRequestHeaders, &requestHeaders) + options.Database.Read(DatabaseTable, DatabaseKeyRequestIncludedCookies, &requestIncludedCookies) options.Database.Read(DatabaseTable, DatabaseKeyRequestExcludedCookies, &requestExcludedCookies) options.ResponseHeaders = strings.Split(responseHeaders, ",") options.ResponseClientHeaders = strings.Split(responseClientHeaders, ",") options.RequestHeaders = strings.Split(requestHeaders, ",") + options.RequestIncludedCookies = strings.Split(requestIncludedCookies, ",") options.RequestExcludedCookies = strings.Split(requestExcludedCookies, ",") return &AuthRouter{ @@ -87,6 +92,7 @@ func (ar *AuthRouter) handleOptionsGET(w http.ResponseWriter, r *http.Request) { DatabaseKeyResponseHeaders: ar.options.ResponseHeaders, DatabaseKeyResponseClientHeaders: ar.options.ResponseClientHeaders, DatabaseKeyRequestHeaders: ar.options.RequestHeaders, + DatabaseKeyRequestIncludedCookies: ar.options.RequestIncludedCookies, DatabaseKeyRequestExcludedCookies: ar.options.RequestExcludedCookies, }) @@ -108,6 +114,7 @@ func (ar *AuthRouter) handleOptionsPOST(w http.ResponseWriter, r *http.Request) responseHeaders, _ := utils.PostPara(r, DatabaseKeyResponseHeaders) responseClientHeaders, _ := utils.PostPara(r, DatabaseKeyResponseClientHeaders) requestHeaders, _ := utils.PostPara(r, DatabaseKeyRequestHeaders) + requestIncludedCookies, _ := utils.PostPara(r, DatabaseKeyRequestIncludedCookies) requestExcludedCookies, _ := utils.PostPara(r, DatabaseKeyRequestExcludedCookies) // Write changes to runtime @@ -115,6 +122,7 @@ func (ar *AuthRouter) handleOptionsPOST(w http.ResponseWriter, r *http.Request) ar.options.ResponseHeaders = strings.Split(responseHeaders, ",") ar.options.ResponseClientHeaders = strings.Split(responseClientHeaders, ",") ar.options.RequestHeaders = strings.Split(requestHeaders, ",") + ar.options.RequestIncludedCookies = strings.Split(requestIncludedCookies, ",") ar.options.RequestExcludedCookies = strings.Split(requestExcludedCookies, ",") // Write changes to database @@ -122,6 +130,7 @@ func (ar *AuthRouter) handleOptionsPOST(w http.ResponseWriter, r *http.Request) ar.options.Database.Write(DatabaseTable, DatabaseKeyResponseHeaders, responseHeaders) ar.options.Database.Write(DatabaseTable, DatabaseKeyResponseClientHeaders, responseClientHeaders) ar.options.Database.Write(DatabaseTable, DatabaseKeyRequestHeaders, requestHeaders) + ar.options.Database.Write(DatabaseTable, DatabaseKeyRequestIncludedCookies, requestIncludedCookies) ar.options.Database.Write(DatabaseTable, DatabaseKeyRequestExcludedCookies, requestExcludedCookies) utils.SendOK(w) @@ -144,6 +153,9 @@ func (ar *AuthRouter) HandleAuthProviderRouting(w http.ResponseWriter, r *http.R } // Make a request to Authz Server to verify the request + // TODO: Add opt-in support for copying the request body to the forward auth request. Currently it's just an + // empty body which is usually fine in most instances. It's likely best to see if anyone wants this feature + // as I'm unaware of any specific forward auth implementation that needs it. req, err := http.NewRequest(http.MethodGet, ar.options.Address, nil) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -153,10 +165,11 @@ func (ar *AuthRouter) HandleAuthProviderRouting(w http.ResponseWriter, r *http.R return ErrInternalServerError } - // TODO: Add opt-in support for copying the request body to the forward auth request. headerCopyIncluded(r.Header, req.Header, ar.options.RequestHeaders, true) + headerCookieRedact(r, ar.options.RequestIncludedCookies, false) - // TODO: Add support for upstream headers. + // TODO: Add support for headers from upstream proxies. This will likely involve implementing some form of + // proxy specific trust system within Zoraxy. rSetForwardedHeaders(r, req) // Make the Authz Request. @@ -186,10 +199,7 @@ func (ar *AuthRouter) HandleAuthProviderRouting(w http.ResponseWriter, r *http.R headerCopyIncluded(respForwarded.Header, w.Header(), ar.options.ResponseClientHeaders, false) } - if len(ar.options.RequestExcludedCookies) != 0 { - // If the user has specified a list of cookies to be removed from the request, deterministically remove them. - headerCookieRedact(r, ar.options.RequestExcludedCookies) - } + headerCookieRedact(r, ar.options.RequestExcludedCookies, true) if len(ar.options.ResponseHeaders) != 0 { // Copy specific user-specified headers from the response of the forward auth request to the request sent to the @@ -197,10 +207,11 @@ func (ar *AuthRouter) HandleAuthProviderRouting(w http.ResponseWriter, r *http.R headerCopyIncluded(respForwarded.Header, w.Header(), ar.options.ResponseHeaders, false) } + // Return the request to the proxy for forwarding to the backend. return nil } - // Copy the response. + // Copy the unsuccessful response. headerCopyExcluded(respForwarded.Header, w.Header(), nil) w.WriteHeader(respForwarded.StatusCode) @@ -214,121 +225,3 @@ func (ar *AuthRouter) HandleAuthProviderRouting(w http.ResponseWriter, r *http.R return ErrUnauthorized } - -func scheme(r *http.Request) string { - if r.TLS != nil { - return "https" - } - - return "http" -} - -func headerCookieRedact(r *http.Request, excluded []string) { - original := r.Cookies() - - if len(original) == 0 { - return - } - - var cookies []string - - for _, cookie := range original { - if stringInSlice(cookie.Name, excluded) { - continue - } - - cookies = append(cookies, cookie.String()) - } - - r.Header.Set(HeaderCookie, strings.Join(cookies, "; ")) -} - -func headerCopyExcluded(original, destination http.Header, excludedHeaders []string) { - for key, values := range original { - // We should never copy the headers in the below list. - if stringInSliceFold(key, doNotCopyHeaders) { - continue - } - - if stringInSliceFold(key, excludedHeaders) { - continue - } - - destination[key] = append(destination[key], values...) - } -} - -func headerCopyIncluded(original, destination http.Header, includedHeaders []string, allIfEmpty bool) { - if allIfEmpty && len(includedHeaders) == 0 { - headerCopyAll(original, destination) - } else { - headerCopyIncludedExact(original, destination, includedHeaders) - } -} - -func headerCopyAll(original, destination http.Header) { - for key, values := range original { - // We should never copy the headers in the below list, even if they're in the list provided by a user. - if stringInSliceFold(key, doNotCopyHeaders) { - continue - } - - destination[key] = append(destination[key], values...) - } -} - -func headerCopyIncludedExact(original, destination http.Header, keys []string) { - for _, key := range keys { - // We should never copy the headers in the below list, even if they're in the list provided by a user. - if stringInSliceFold(key, doNotCopyHeaders) { - continue - } - - if values, ok := original[key]; ok { - destination[key] = append(destination[key], values...) - } - } -} - -func stringInSlice(needle string, haystack []string) bool { - if len(haystack) == 0 { - return false - } - - for _, v := range haystack { - if needle == v { - return true - } - } - - return false -} - -func stringInSliceFold(needle string, haystack []string) bool { - if len(haystack) == 0 { - return false - } - - for _, v := range haystack { - if strings.EqualFold(needle, v) { - return true - } - } - - return false -} - -func rSetForwardedHeaders(r, req *http.Request) { - if r.RemoteAddr != "" { - before, _, _ := strings.Cut(r.RemoteAddr, ":") - - if ip := net.ParseIP(before); ip != nil { - req.Header.Set(HeaderXForwardedFor, ip.String()) - } - } - - req.Header.Set(HeaderXForwardedMethod, r.Method) - req.Header.Set(HeaderXForwardedProto, scheme(r)) - req.Header.Set(HeaderXForwardedHost, r.Host) - req.Header.Set(HeaderXForwardedURI, r.URL.Path) -} diff --git a/src/mod/auth/sso/forward/util.go b/src/mod/auth/sso/forward/util.go new file mode 100644 index 0000000..7ba9b3a --- /dev/null +++ b/src/mod/auth/sso/forward/util.go @@ -0,0 +1,137 @@ +package forward + +import ( + "net" + "net/http" + "strings" +) + +func scheme(r *http.Request) string { + if r.TLS != nil { + return "https" + } + + return "http" +} + +func headerCookieRedact(r *http.Request, names []string, exclude bool) { + if len(names) == 0 { + return + } + + original := r.Cookies() + + if len(original) == 0 { + return + } + + var cookies []string + + for _, cookie := range original { + if exclude && stringInSlice(cookie.Name, names) { + continue + } else if !exclude && !stringInSlice(cookie.Name, names) { + continue + } + + cookies = append(cookies, cookie.String()) + } + + value := strings.Join(cookies, "; ") + + r.Header.Set(HeaderCookie, value) + + return +} + +func headerCopyExcluded(original, destination http.Header, excludedHeaders []string) { + for key, values := range original { + // We should never copy the headers in the below list. + if stringInSliceFold(key, doNotCopyHeaders) { + continue + } + + if stringInSliceFold(key, excludedHeaders) { + continue + } + + destination[key] = append(destination[key], values...) + } +} + +func headerCopyIncluded(original, destination http.Header, includedHeaders []string, allIfEmpty bool) { + if allIfEmpty && len(includedHeaders) == 0 { + headerCopyAll(original, destination) + } else { + headerCopyIncludedExact(original, destination, includedHeaders) + } +} + +func headerCopyAll(original, destination http.Header) { + for key, values := range original { + // We should never copy the headers in the below list, even if they're in the list provided by a user. + if stringInSliceFold(key, doNotCopyHeaders) { + continue + } + + destination[key] = append(destination[key], values...) + } +} + +func headerCopyIncludedExact(original, destination http.Header, keys []string) { + for key, values := range original { + // We should never copy the headers in the below list, even if they're in the list provided by a user. + if stringInSliceFold(key, doNotCopyHeaders) { + continue + } + + if !stringInSliceFold(key, keys) { + continue + } + + destination[key] = append(destination[key], values...) + } +} + +func stringInSlice(needle string, haystack []string) bool { + if len(haystack) == 0 { + return false + } + + for _, v := range haystack { + if needle == v { + return true + } + } + + return false +} + +func stringInSliceFold(needle string, haystack []string) bool { + if len(haystack) == 0 { + return false + } + + for _, v := range haystack { + if strings.EqualFold(needle, v) { + return true + } + } + + return false +} + +func rSetForwardedHeaders(r, req *http.Request) { + if r.RemoteAddr != "" { + before, _, _ := strings.Cut(r.RemoteAddr, ":") + + if ip := net.ParseIP(before); ip != nil { + req.Header.Set(HeaderXForwardedFor, ip.String()) + } + } + + req.Header.Set(HeaderXForwardedMethod, r.Method) + req.Header.Set(HeaderXForwardedProto, scheme(r)) + req.Header.Set(HeaderXForwardedHost, r.Host) + req.Header.Set(HeaderXForwardedURI, r.URL.Path) +} diff --git a/src/mod/auth/sso/forward/util_test.go b/src/mod/auth/sso/forward/util_test.go new file mode 100644 index 0000000..abd0241 --- /dev/null +++ b/src/mod/auth/sso/forward/util_test.go @@ -0,0 +1,217 @@ +package forward + +import ( + "crypto/tls" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestScheme(t *testing.T) { + testCases := []struct { + name string + have *http.Request + expected string + }{ + { + "ShouldHandleDefault", + &http.Request{}, + "http", + }, + { + "ShouldHandleExplicit", + &http.Request{ + TLS: nil, + }, + "http", + }, + { + "ShouldHandleHTTPS", + &http.Request{ + TLS: &tls.ConnectionState{}, + }, + "https", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, scheme(tc.have)) + }) + } +} + +func TestHeaderCookieRedact(t *testing.T) { + testCases := []struct { + name string + have string + names []string + expectedInclude string + expectedExclude string + }{ + { + "ShouldHandleIncludeEmptyWithoutSettings", + "", + nil, + "", + "", + }, + { + "ShouldHandleIncludeEmptyWithSettings", + "", + []string{"include"}, + "", + "", + }, + { + "ShouldHandleValueWithoutSettings", + "include=value; exclude=value", + nil, + "include=value; exclude=value", + "include=value; exclude=value", + }, + { + "ShouldHandleValueWithSettings", + "include=value; exclude=value", + []string{"include"}, + "include=value", + "exclude=value", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var include, exclude *http.Request + + include, exclude = &http.Request{Header: http.Header{}}, &http.Request{Header: http.Header{}} + + if tc.have != "" { + include.Header.Set(HeaderCookie, tc.have) + exclude.Header.Set(HeaderCookie, tc.have) + } + + headerCookieRedact(include, tc.names, false) + + assert.Equal(t, tc.expectedInclude, include.Header.Get(HeaderCookie)) + + headerCookieRedact(exclude, tc.names, true) + + assert.Equal(t, tc.expectedExclude, exclude.Header.Get(HeaderCookie)) + }) + } +} + +func TestHeaderCopyExcluded(t *testing.T) { + testCases := []struct { + name string + original http.Header + excluded []string + expected http.Header + }{ + { + "ShouldHandleNoSettingsNoHeaders", + http.Header{}, + nil, + http.Header{}, + }, + { + "ShouldHandleNoSettingsWithHeaders", + http.Header{ + "Example": []string{"value", "other"}, + "Exclude": []string{"value", "other"}, + HeaderUpgrade: []string{"do", "not", "copy"}, + }, + nil, + http.Header{ + "Example": []string{"value", "other"}, + "Exclude": []string{"value", "other"}, + }, + }, + { + "ShouldHandleSettingsWithHeaders", + http.Header{ + "Example": []string{"value", "other"}, + "Exclude": []string{"value", "other"}, + HeaderUpgrade: []string{"do", "not", "copy"}, + }, + []string{"exclude"}, + http.Header{ + "Example": []string{"value", "other"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + headers := http.Header{} + + headerCopyExcluded(tc.original, headers, tc.excluded) + + assert.Equal(t, tc.expected, headers) + }) + } +} + +func TestHeaderCopyIncluded(t *testing.T) { + testCases := []struct { + name string + original http.Header + included []string + expected http.Header + expectedAll http.Header + }{ + { + "ShouldHandleNoSettingsNoHeaders", + http.Header{}, + nil, + http.Header{}, + http.Header{}, + }, + { + "ShouldHandleNoSettingsWithHeaders", + http.Header{ + "Example": []string{"value", "other"}, + "Include": []string{"value", "other"}, + HeaderUpgrade: []string{"do", "not", "copy"}, + }, + nil, + http.Header{}, + http.Header{ + "Example": []string{"value", "other"}, + "Include": []string{"value", "other"}, + }, + }, + { + "ShouldHandleSettingsWithHeaders", + http.Header{ + "Example": []string{"value", "other"}, + "Include": []string{"value", "other"}, + HeaderUpgrade: []string{"do", "not", "copy"}, + }, + []string{"include"}, + http.Header{ + "Include": []string{"value", "other"}, + }, + http.Header{ + "Include": []string{"value", "other"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + headers := http.Header{} + + headerCopyIncluded(tc.original, headers, tc.included, false) + + assert.Equal(t, tc.expected, headers) + + headers = http.Header{} + + headerCopyIncluded(tc.original, headers, tc.included, true) + + assert.Equal(t, tc.expectedAll, headers) + }) + } +} diff --git a/src/web/components/sso.html b/src/web/components/sso.html index 605ccdd..d0cadd6 100644 --- a/src/web/components/sso.html +++ b/src/web/components/sso.html @@ -25,6 +25,7 @@