diff --git a/src/mod/auth/sso/authelia/authelia.go b/src/mod/auth/sso/authelia/authelia.go index 075e97f..d86f374 100644 --- a/src/mod/auth/sso/authelia/authelia.go +++ b/src/mod/auth/sso/authelia/authelia.go @@ -3,9 +3,10 @@ package authelia import ( "encoding/json" "errors" - "fmt" + "net" "net/http" "net/url" + "strings" "imuslab.com/zoraxy/mod/database" "imuslab.com/zoraxy/mod/info/logger" @@ -93,25 +94,20 @@ func (ar *AutheliaRouter) HandleAutheliaAuth(w http.ResponseWriter, r *http.Requ protocol = "https" } - autheliaBaseURL := protocol + "://" + ar.options.AutheliaURL - //Remove tailing slash if any - if autheliaBaseURL[len(autheliaBaseURL)-1] == '/' { - autheliaBaseURL = autheliaBaseURL[:len(autheliaBaseURL)-1] + autheliaURL := &url.URL{ + Scheme: protocol, + Host: ar.options.AutheliaURL, } //Make a request to Authelia to verify the request - req, err := http.NewRequest("POST", autheliaBaseURL+"/api/verify", nil) + req, err := http.NewRequest("POST", autheliaURL.JoinPath("api", "verify").String(), nil) if err != nil { ar.options.Logger.PrintAndLog("Authelia", "Unable to create request", err) w.WriteHeader(401) return errors.New("unauthorized") } - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - req.Header.Add("X-Original-URL", fmt.Sprintf("%s://%s", scheme, r.Host)) + originalURL := rOriginalHeaders(r, req) // Copy cookies from the incoming request for _, cookie := range r.Cookies() { @@ -127,10 +123,42 @@ func (ar *AutheliaRouter) HandleAutheliaAuth(w http.ResponseWriter, r *http.Requ } if resp.StatusCode != 200 { - redirectURL := autheliaBaseURL + "/?rd=" + url.QueryEscape(scheme+"://"+r.Host+r.URL.String()) + "&rm=" + r.Method - http.Redirect(w, r, redirectURL, http.StatusSeeOther) + redirectURL := autheliaURL.JoinPath() + + query := redirectURL.Query() + + query.Set("rd", originalURL.String()) + query.Set("rm", r.Method) + + http.Redirect(w, r, redirectURL.String(), http.StatusSeeOther) return errors.New("unauthorized") } return nil } + +func rOriginalHeaders(r, req *http.Request) *url.URL { + if r.RemoteAddr != "" { + before, _, _ := strings.Cut(r.RemoteAddr, ":") + + if ip := net.ParseIP(before); ip != nil { + req.Header.Set("X-Forwarded-For", ip.String()) + } + } + + originalURL := &url.URL{ + Scheme: "http", + Host: r.Host, + Path: r.URL.Path, + RawPath: r.URL.RawPath, + } + + if r.TLS != nil { + originalURL.Scheme = "https" + } + + req.Header.Add("X-Forwarded-Method", r.Method) + req.Header.Add("X-Original-URL", originalURL.String()) + + return originalURL +}