feat(sso): forward auth improvements

This adds a couple of key improvements to the Forward Auth SSO implementation. Primarily it adds an included cookies setting which allows filtering cookies to the authorization server. Secondly it fixes a bug where the headerCopyIncluded function was case-sensitive. Documentation in the code and on the web UI is clearer to resolve some common questions and issues. Lastly it moves a lot of funcs to the util.go file and adds fairly comprehensive tests.
This commit is contained in:
James Elliott
2025-06-15 11:54:16 +10:00
parent 31ba4f20ae
commit 26d03f9ad4
6 changed files with 421 additions and 156 deletions

View File

@@ -11,6 +11,7 @@ const (
DatabaseKeyResponseHeaders = "responseHeaders"
DatabaseKeyResponseClientHeaders = "responseClientHeaders"
DatabaseKeyRequestHeaders = "requestHeaders"
DatabaseKeyRequestIncludedCookies = "requestIncludedCookies"
DatabaseKeyRequestExcludedCookies = "requestExcludedCookies"
HeaderXForwardedProto = "X-Forwarded-Proto"

View File

@@ -3,7 +3,6 @@ package forward
import (
"encoding/json"
"io"
"net"
"net/http"
"strings"
@@ -28,6 +27,10 @@ type AuthRouterOptions struct {
// headers are copied.
RequestHeaders []string
// RequestIncludedCookies is a list of cookie keys that if defined will be the only cookies sent in the request to
// the authorization server.
RequestIncludedCookies []string
// RequestExcludedCookies is a list of cookie keys that should be removed from every request sent to the upstream.
RequestExcludedCookies []string
@@ -47,16 +50,18 @@ func NewAuthRouter(options *AuthRouterOptions) *AuthRouter {
//Read settings from database if available.
options.Database.Read(DatabaseTable, DatabaseKeyAddress, &options.Address)
responseHeaders, responseClientHeaders, requestHeaders, requestExcludedCookies := "", "", "", ""
responseHeaders, responseClientHeaders, requestHeaders, requestIncludedCookies, requestExcludedCookies := "", "", "", "", ""
options.Database.Read(DatabaseTable, DatabaseKeyResponseHeaders, &responseHeaders)
options.Database.Read(DatabaseTable, DatabaseKeyResponseClientHeaders, &responseClientHeaders)
options.Database.Read(DatabaseTable, DatabaseKeyRequestHeaders, &requestHeaders)
options.Database.Read(DatabaseTable, DatabaseKeyRequestIncludedCookies, &requestIncludedCookies)
options.Database.Read(DatabaseTable, DatabaseKeyRequestExcludedCookies, &requestExcludedCookies)
options.ResponseHeaders = strings.Split(responseHeaders, ",")
options.ResponseClientHeaders = strings.Split(responseClientHeaders, ",")
options.RequestHeaders = strings.Split(requestHeaders, ",")
options.RequestIncludedCookies = strings.Split(requestIncludedCookies, ",")
options.RequestExcludedCookies = strings.Split(requestExcludedCookies, ",")
return &AuthRouter{
@@ -87,6 +92,7 @@ func (ar *AuthRouter) handleOptionsGET(w http.ResponseWriter, r *http.Request) {
DatabaseKeyResponseHeaders: ar.options.ResponseHeaders,
DatabaseKeyResponseClientHeaders: ar.options.ResponseClientHeaders,
DatabaseKeyRequestHeaders: ar.options.RequestHeaders,
DatabaseKeyRequestIncludedCookies: ar.options.RequestIncludedCookies,
DatabaseKeyRequestExcludedCookies: ar.options.RequestExcludedCookies,
})
@@ -108,6 +114,7 @@ func (ar *AuthRouter) handleOptionsPOST(w http.ResponseWriter, r *http.Request)
responseHeaders, _ := utils.PostPara(r, DatabaseKeyResponseHeaders)
responseClientHeaders, _ := utils.PostPara(r, DatabaseKeyResponseClientHeaders)
requestHeaders, _ := utils.PostPara(r, DatabaseKeyRequestHeaders)
requestIncludedCookies, _ := utils.PostPara(r, DatabaseKeyRequestIncludedCookies)
requestExcludedCookies, _ := utils.PostPara(r, DatabaseKeyRequestExcludedCookies)
// Write changes to runtime
@@ -115,6 +122,7 @@ func (ar *AuthRouter) handleOptionsPOST(w http.ResponseWriter, r *http.Request)
ar.options.ResponseHeaders = strings.Split(responseHeaders, ",")
ar.options.ResponseClientHeaders = strings.Split(responseClientHeaders, ",")
ar.options.RequestHeaders = strings.Split(requestHeaders, ",")
ar.options.RequestIncludedCookies = strings.Split(requestIncludedCookies, ",")
ar.options.RequestExcludedCookies = strings.Split(requestExcludedCookies, ",")
// Write changes to database
@@ -122,6 +130,7 @@ func (ar *AuthRouter) handleOptionsPOST(w http.ResponseWriter, r *http.Request)
ar.options.Database.Write(DatabaseTable, DatabaseKeyResponseHeaders, responseHeaders)
ar.options.Database.Write(DatabaseTable, DatabaseKeyResponseClientHeaders, responseClientHeaders)
ar.options.Database.Write(DatabaseTable, DatabaseKeyRequestHeaders, requestHeaders)
ar.options.Database.Write(DatabaseTable, DatabaseKeyRequestIncludedCookies, requestIncludedCookies)
ar.options.Database.Write(DatabaseTable, DatabaseKeyRequestExcludedCookies, requestExcludedCookies)
utils.SendOK(w)
@@ -144,6 +153,9 @@ func (ar *AuthRouter) HandleAuthProviderRouting(w http.ResponseWriter, r *http.R
}
// Make a request to Authz Server to verify the request
// TODO: Add opt-in support for copying the request body to the forward auth request. Currently it's just an
// empty body which is usually fine in most instances. It's likely best to see if anyone wants this feature
// as I'm unaware of any specific forward auth implementation that needs it.
req, err := http.NewRequest(http.MethodGet, ar.options.Address, nil)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@@ -153,10 +165,11 @@ func (ar *AuthRouter) HandleAuthProviderRouting(w http.ResponseWriter, r *http.R
return ErrInternalServerError
}
// TODO: Add opt-in support for copying the request body to the forward auth request.
headerCopyIncluded(r.Header, req.Header, ar.options.RequestHeaders, true)
headerCookieRedact(r, ar.options.RequestIncludedCookies, false)
// TODO: Add support for upstream headers.
// TODO: Add support for headers from upstream proxies. This will likely involve implementing some form of
// proxy specific trust system within Zoraxy.
rSetForwardedHeaders(r, req)
// Make the Authz Request.
@@ -186,10 +199,7 @@ func (ar *AuthRouter) HandleAuthProviderRouting(w http.ResponseWriter, r *http.R
headerCopyIncluded(respForwarded.Header, w.Header(), ar.options.ResponseClientHeaders, false)
}
if len(ar.options.RequestExcludedCookies) != 0 {
// If the user has specified a list of cookies to be removed from the request, deterministically remove them.
headerCookieRedact(r, ar.options.RequestExcludedCookies)
}
headerCookieRedact(r, ar.options.RequestExcludedCookies, true)
if len(ar.options.ResponseHeaders) != 0 {
// Copy specific user-specified headers from the response of the forward auth request to the request sent to the
@@ -197,10 +207,11 @@ func (ar *AuthRouter) HandleAuthProviderRouting(w http.ResponseWriter, r *http.R
headerCopyIncluded(respForwarded.Header, w.Header(), ar.options.ResponseHeaders, false)
}
// Return the request to the proxy for forwarding to the backend.
return nil
}
// Copy the response.
// Copy the unsuccessful response.
headerCopyExcluded(respForwarded.Header, w.Header(), nil)
w.WriteHeader(respForwarded.StatusCode)
@@ -214,121 +225,3 @@ func (ar *AuthRouter) HandleAuthProviderRouting(w http.ResponseWriter, r *http.R
return ErrUnauthorized
}
func scheme(r *http.Request) string {
if r.TLS != nil {
return "https"
}
return "http"
}
func headerCookieRedact(r *http.Request, excluded []string) {
original := r.Cookies()
if len(original) == 0 {
return
}
var cookies []string
for _, cookie := range original {
if stringInSlice(cookie.Name, excluded) {
continue
}
cookies = append(cookies, cookie.String())
}
r.Header.Set(HeaderCookie, strings.Join(cookies, "; "))
}
func headerCopyExcluded(original, destination http.Header, excludedHeaders []string) {
for key, values := range original {
// We should never copy the headers in the below list.
if stringInSliceFold(key, doNotCopyHeaders) {
continue
}
if stringInSliceFold(key, excludedHeaders) {
continue
}
destination[key] = append(destination[key], values...)
}
}
func headerCopyIncluded(original, destination http.Header, includedHeaders []string, allIfEmpty bool) {
if allIfEmpty && len(includedHeaders) == 0 {
headerCopyAll(original, destination)
} else {
headerCopyIncludedExact(original, destination, includedHeaders)
}
}
func headerCopyAll(original, destination http.Header) {
for key, values := range original {
// We should never copy the headers in the below list, even if they're in the list provided by a user.
if stringInSliceFold(key, doNotCopyHeaders) {
continue
}
destination[key] = append(destination[key], values...)
}
}
func headerCopyIncludedExact(original, destination http.Header, keys []string) {
for _, key := range keys {
// We should never copy the headers in the below list, even if they're in the list provided by a user.
if stringInSliceFold(key, doNotCopyHeaders) {
continue
}
if values, ok := original[key]; ok {
destination[key] = append(destination[key], values...)
}
}
}
func stringInSlice(needle string, haystack []string) bool {
if len(haystack) == 0 {
return false
}
for _, v := range haystack {
if needle == v {
return true
}
}
return false
}
func stringInSliceFold(needle string, haystack []string) bool {
if len(haystack) == 0 {
return false
}
for _, v := range haystack {
if strings.EqualFold(needle, v) {
return true
}
}
return false
}
func rSetForwardedHeaders(r, req *http.Request) {
if r.RemoteAddr != "" {
before, _, _ := strings.Cut(r.RemoteAddr, ":")
if ip := net.ParseIP(before); ip != nil {
req.Header.Set(HeaderXForwardedFor, ip.String())
}
}
req.Header.Set(HeaderXForwardedMethod, r.Method)
req.Header.Set(HeaderXForwardedProto, scheme(r))
req.Header.Set(HeaderXForwardedHost, r.Host)
req.Header.Set(HeaderXForwardedURI, r.URL.Path)
}

View File

@@ -0,0 +1,137 @@
package forward
import (
"net"
"net/http"
"strings"
)
func scheme(r *http.Request) string {
if r.TLS != nil {
return "https"
}
return "http"
}
func headerCookieRedact(r *http.Request, names []string, exclude bool) {
if len(names) == 0 {
return
}
original := r.Cookies()
if len(original) == 0 {
return
}
var cookies []string
for _, cookie := range original {
if exclude && stringInSlice(cookie.Name, names) {
continue
} else if !exclude && !stringInSlice(cookie.Name, names) {
continue
}
cookies = append(cookies, cookie.String())
}
value := strings.Join(cookies, "; ")
r.Header.Set(HeaderCookie, value)
return
}
func headerCopyExcluded(original, destination http.Header, excludedHeaders []string) {
for key, values := range original {
// We should never copy the headers in the below list.
if stringInSliceFold(key, doNotCopyHeaders) {
continue
}
if stringInSliceFold(key, excludedHeaders) {
continue
}
destination[key] = append(destination[key], values...)
}
}
func headerCopyIncluded(original, destination http.Header, includedHeaders []string, allIfEmpty bool) {
if allIfEmpty && len(includedHeaders) == 0 {
headerCopyAll(original, destination)
} else {
headerCopyIncludedExact(original, destination, includedHeaders)
}
}
func headerCopyAll(original, destination http.Header) {
for key, values := range original {
// We should never copy the headers in the below list, even if they're in the list provided by a user.
if stringInSliceFold(key, doNotCopyHeaders) {
continue
}
destination[key] = append(destination[key], values...)
}
}
func headerCopyIncludedExact(original, destination http.Header, keys []string) {
for key, values := range original {
// We should never copy the headers in the below list, even if they're in the list provided by a user.
if stringInSliceFold(key, doNotCopyHeaders) {
continue
}
if !stringInSliceFold(key, keys) {
continue
}
destination[key] = append(destination[key], values...)
}
}
func stringInSlice(needle string, haystack []string) bool {
if len(haystack) == 0 {
return false
}
for _, v := range haystack {
if needle == v {
return true
}
}
return false
}
func stringInSliceFold(needle string, haystack []string) bool {
if len(haystack) == 0 {
return false
}
for _, v := range haystack {
if strings.EqualFold(needle, v) {
return true
}
}
return false
}
func rSetForwardedHeaders(r, req *http.Request) {
if r.RemoteAddr != "" {
before, _, _ := strings.Cut(r.RemoteAddr, ":")
if ip := net.ParseIP(before); ip != nil {
req.Header.Set(HeaderXForwardedFor, ip.String())
}
}
req.Header.Set(HeaderXForwardedMethod, r.Method)
req.Header.Set(HeaderXForwardedProto, scheme(r))
req.Header.Set(HeaderXForwardedHost, r.Host)
req.Header.Set(HeaderXForwardedURI, r.URL.Path)
}

View File

@@ -0,0 +1,217 @@
package forward
import (
"crypto/tls"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestScheme(t *testing.T) {
testCases := []struct {
name string
have *http.Request
expected string
}{
{
"ShouldHandleDefault",
&http.Request{},
"http",
},
{
"ShouldHandleExplicit",
&http.Request{
TLS: nil,
},
"http",
},
{
"ShouldHandleHTTPS",
&http.Request{
TLS: &tls.ConnectionState{},
},
"https",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, scheme(tc.have))
})
}
}
func TestHeaderCookieRedact(t *testing.T) {
testCases := []struct {
name string
have string
names []string
expectedInclude string
expectedExclude string
}{
{
"ShouldHandleIncludeEmptyWithoutSettings",
"",
nil,
"",
"",
},
{
"ShouldHandleIncludeEmptyWithSettings",
"",
[]string{"include"},
"",
"",
},
{
"ShouldHandleValueWithoutSettings",
"include=value; exclude=value",
nil,
"include=value; exclude=value",
"include=value; exclude=value",
},
{
"ShouldHandleValueWithSettings",
"include=value; exclude=value",
[]string{"include"},
"include=value",
"exclude=value",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var include, exclude *http.Request
include, exclude = &http.Request{Header: http.Header{}}, &http.Request{Header: http.Header{}}
if tc.have != "" {
include.Header.Set(HeaderCookie, tc.have)
exclude.Header.Set(HeaderCookie, tc.have)
}
headerCookieRedact(include, tc.names, false)
assert.Equal(t, tc.expectedInclude, include.Header.Get(HeaderCookie))
headerCookieRedact(exclude, tc.names, true)
assert.Equal(t, tc.expectedExclude, exclude.Header.Get(HeaderCookie))
})
}
}
func TestHeaderCopyExcluded(t *testing.T) {
testCases := []struct {
name string
original http.Header
excluded []string
expected http.Header
}{
{
"ShouldHandleNoSettingsNoHeaders",
http.Header{},
nil,
http.Header{},
},
{
"ShouldHandleNoSettingsWithHeaders",
http.Header{
"Example": []string{"value", "other"},
"Exclude": []string{"value", "other"},
HeaderUpgrade: []string{"do", "not", "copy"},
},
nil,
http.Header{
"Example": []string{"value", "other"},
"Exclude": []string{"value", "other"},
},
},
{
"ShouldHandleSettingsWithHeaders",
http.Header{
"Example": []string{"value", "other"},
"Exclude": []string{"value", "other"},
HeaderUpgrade: []string{"do", "not", "copy"},
},
[]string{"exclude"},
http.Header{
"Example": []string{"value", "other"},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
headers := http.Header{}
headerCopyExcluded(tc.original, headers, tc.excluded)
assert.Equal(t, tc.expected, headers)
})
}
}
func TestHeaderCopyIncluded(t *testing.T) {
testCases := []struct {
name string
original http.Header
included []string
expected http.Header
expectedAll http.Header
}{
{
"ShouldHandleNoSettingsNoHeaders",
http.Header{},
nil,
http.Header{},
http.Header{},
},
{
"ShouldHandleNoSettingsWithHeaders",
http.Header{
"Example": []string{"value", "other"},
"Include": []string{"value", "other"},
HeaderUpgrade: []string{"do", "not", "copy"},
},
nil,
http.Header{},
http.Header{
"Example": []string{"value", "other"},
"Include": []string{"value", "other"},
},
},
{
"ShouldHandleSettingsWithHeaders",
http.Header{
"Example": []string{"value", "other"},
"Include": []string{"value", "other"},
HeaderUpgrade: []string{"do", "not", "copy"},
},
[]string{"include"},
http.Header{
"Include": []string{"value", "other"},
},
http.Header{
"Include": []string{"value", "other"},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
headers := http.Header{}
headerCopyIncluded(tc.original, headers, tc.included, false)
assert.Equal(t, tc.expected, headers)
headers = http.Header{}
headerCopyIncluded(tc.original, headers, tc.included, true)
assert.Equal(t, tc.expectedAll, headers)
})
}
}