diff --git a/src/main.go b/src/main.go index cff24a3..0cf1dd3 100644 --- a/src/main.go +++ b/src/main.go @@ -53,13 +53,13 @@ var enableHighSpeedGeoIPLookup = flag.Bool("fastgeoip", false, "Enable high spee var staticWebServerRoot = flag.String("webroot", "./www", "Static web server root folder. Only allow chnage in start paramters") var allowWebFileManager = flag.Bool("webfm", true, "Enable web file manager for static web server root folder") var logOutputToFile = flag.Bool("log", true, "Log terminal output to file") -var updateMode = flag.Int("update", 0, "Version number (usually the version before you update Zoraxy) to start accumulation update. To update v3.0.7 to latest, use -update=307") +var enableAutoUpdate = flag.Bool("cfgupgrade", true, "Enable auto config upgrade if breaking change is detected") var ( name = "Zoraxy" version = "3.0.8" - nodeUUID = "generic" - development = true //Set this to false to use embedded web fs + nodeUUID = "generic" //System uuid, in uuidv4 format + development = true //Set this to false to use embedded web fs bootTime = time.Now().Unix() /* @@ -123,8 +123,9 @@ func ShutdownSeq() { // Stop the mdns service mdnsTickerStop <- true } - mdnsScanner.Close() + fmt.Println("- Shutting down load balancer") + loadBalancer.Close() fmt.Println("- Closing Certificates Auto Renewer") acmeAutoRenewer.Close() //Remove the tmp folder @@ -147,17 +148,16 @@ func main() { os.Exit(0) } - if *updateMode > 306 { - fmt.Println("Entering Update Mode") - update.RunConfigUpdate(*updateMode, update.GetVersionIntFromVersionNumber(version)) - os.Exit(0) - } - if !utils.ValidateListeningAddress(*webUIPort) { fmt.Println("Malformed -port (listening address) paramter. Do you mean -port=:" + *webUIPort + "?") os.Exit(0) } + if *enableAutoUpdate { + log.Println("[INFO] Checking required config update") + update.RunConfigUpdate(0, update.GetVersionIntFromVersionNumber(version)) + } + SetupCloseHandler() //Read or create the system uuid diff --git a/src/mod/dynamicproxy/dpcore/dpcore.go b/src/mod/dynamicproxy/dpcore/dpcore.go index bca450e..2c64eaa 100644 --- a/src/mod/dynamicproxy/dpcore/dpcore.go +++ b/src/mod/dynamicproxy/dpcore/dpcore.go @@ -64,6 +64,7 @@ type ResponseRewriteRuleSet struct { PathPrefix string //Vdir prefix for root, / will be rewrite to this UpstreamHeaders [][]string DownstreamHeaders [][]string + NoRemoveHopByHop bool //Do not remove hop-by-hop headers, dangerous Version string //Version number of Zoraxy, use for X-Proxy-By } diff --git a/src/mod/dynamicproxy/dynamicproxy.go b/src/mod/dynamicproxy/dynamicproxy.go index 2eb5fca..40ff4e3 100644 --- a/src/mod/dynamicproxy/dynamicproxy.go +++ b/src/mod/dynamicproxy/dynamicproxy.go @@ -151,18 +151,19 @@ func (router *Router) StartProxyService() error { } } - selectedUpstream, err := router.loadBalancer.GetRequestUpstreamTarget(r, sep.ActiveOrigins) + selectedUpstream, err := router.loadBalancer.GetRequestUpstreamTarget(w, r, sep.ActiveOrigins, sep.UseStickySession) if err != nil { http.ServeFile(w, r, "./web/hosterror.html") log.Println(err.Error()) router.logRequest(r, false, 404, "vdir-http", r.Host) } selectedUpstream.ServeHTTP(w, r, &dpcore.ResponseRewriteRuleSet{ - ProxyDomain: selectedUpstream.OriginIpOrDomain, - OriginalHost: originalHostHeader, - UseTLS: selectedUpstream.RequireTLS, - PathPrefix: "", - Version: sep.parent.Option.HostVersion, + ProxyDomain: selectedUpstream.OriginIpOrDomain, + OriginalHost: originalHostHeader, + UseTLS: selectedUpstream.RequireTLS, + NoRemoveHopByHop: sep.DisableHopByHopHeaderRemoval, + PathPrefix: "", + Version: sep.parent.Option.HostVersion, }) return } diff --git a/src/mod/dynamicproxy/loadbalance/loadbalance.go b/src/mod/dynamicproxy/loadbalance/loadbalance.go index 7673a7a..726696f 100644 --- a/src/mod/dynamicproxy/loadbalance/loadbalance.go +++ b/src/mod/dynamicproxy/loadbalance/loadbalance.go @@ -3,8 +3,9 @@ package loadbalance import ( "strings" "sync" - "sync/atomic" + "github.com/google/uuid" + "github.com/gorilla/sessions" "imuslab.com/zoraxy/mod/dynamicproxy/dpcore" "imuslab.com/zoraxy/mod/geodb" "imuslab.com/zoraxy/mod/info/logger" @@ -17,12 +18,14 @@ import ( */ type Options struct { + SystemUUID string //Use for the session store UseActiveHealthCheck bool //Use active health check, default to false Geodb *geodb.Store //GeoIP resolver for checking incoming request origin country Logger *logger.Logger } type RouteManager struct { + SessionStore *sessions.CookieStore LoadBalanceMap sync.Map //Sync map to store the last load balance state of a given node OnlineStatusMap sync.Map //Sync map to store the online status of a given ip address or domain name onlineStatusTickerStop chan bool //Stopping channel for the online status pinger @@ -39,20 +42,26 @@ type Upstream struct { //Load balancing configs Weight int //Random weight for round robin, 0 for fallback only - MaxConn int //Maxmium connection to this server, 0 for unlimited + MaxConn int //TODO: Maxmium connection to this server, 0 for unlimited - currentConnectionCounts atomic.Uint64 //Counter for number of client currently connected - proxy *dpcore.ReverseProxy + //currentConnectionCounts atomic.Uint64 //Counter for number of client currently connected + proxy *dpcore.ReverseProxy } // Create a new load balancer func NewLoadBalancer(options *Options) *RouteManager { - onlineStatusCheckerStopChan := make(chan bool) + if options.SystemUUID == "" { + //System UUID not passed in. Use random key + options.SystemUUID = uuid.New().String() + } + //Generate a session store for stickySession + store := sessions.NewCookieStore([]byte(options.SystemUUID)) return &RouteManager{ + SessionStore: store, LoadBalanceMap: sync.Map{}, OnlineStatusMap: sync.Map{}, - onlineStatusTickerStop: onlineStatusCheckerStopChan, + onlineStatusTickerStop: nil, Options: *options, } } diff --git a/src/mod/dynamicproxy/loadbalance/onlineStatus.go b/src/mod/dynamicproxy/loadbalance/onlineStatus.go index b63681d..2dcd4e3 100644 --- a/src/mod/dynamicproxy/loadbalance/onlineStatus.go +++ b/src/mod/dynamicproxy/loadbalance/onlineStatus.go @@ -16,10 +16,6 @@ func (m *RouteManager) IsTargetOnline(matchingDomainOrIp string) bool { return ok && isOnline } -func (m *RouteManager) SetTargetOffline() { - -} - // Ping a target to see if it is online func PingTarget(targetMatchingDomainOrIp string, requireTLS bool) bool { client := &http.Client{ @@ -41,30 +37,3 @@ func PingTarget(targetMatchingDomainOrIp string, requireTLS bool) bool { return resp.StatusCode >= 200 && resp.StatusCode <= 600 } - -// StartHeartbeats start pinging each server every minutes to make sure all targets are online -// Active mode only -/* -func (m *RouteManager) StartHeartbeats(pingTargets []*FallbackProxyTarget) { - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - - fmt.Println("Heartbeat started") - go func() { - for { - select { - case <-m.onlineStatusTickerStop: - ticker.Stop() - return - case <-ticker.C: - for _, target := range pingTargets { - go func(target *FallbackProxyTarget) { - isOnline := PingTarget(target.MatchingDomainOrIp, target.RequireTLS) - m.LoadBalanceMap.Store(target.MatchingDomainOrIp, isOnline) - }(target) - } - } - } - }() -} -*/ diff --git a/src/mod/dynamicproxy/loadbalance/originPicker.go b/src/mod/dynamicproxy/loadbalance/originPicker.go index 63ae5d9..51bcc2b 100644 --- a/src/mod/dynamicproxy/loadbalance/originPicker.go +++ b/src/mod/dynamicproxy/loadbalance/originPicker.go @@ -3,6 +3,8 @@ package loadbalance import ( "errors" "fmt" + "log" + "math/rand" "net/http" ) @@ -15,12 +17,138 @@ import ( // GetRequestUpstreamTarget return the upstream target where this // request should be routed -func (m *RouteManager) GetRequestUpstreamTarget(r *http.Request, origins []*Upstream) (*Upstream, error) { +func (m *RouteManager) GetRequestUpstreamTarget(w http.ResponseWriter, r *http.Request, origins []*Upstream, useStickySession bool) (*Upstream, error) { if len(origins) == 0 { return nil, errors.New("no upstream is defined for this host") } + var targetOrigin = origins[0] + if useStickySession { + //Use stick session, check which origins this request previously used + targetOriginId, err := m.getSessionHandler(r, origins) + if err != nil { + //No valid session found. Assign a new upstream + targetOrigin, index, err := getRandomUpstreamByWeight(origins) + if err != nil { + fmt.Println("Oops. Unable to get random upstream") + targetOrigin = origins[0] + index = 0 + } + m.setSessionHandler(w, r, targetOrigin.OriginIpOrDomain, index) + return targetOrigin, nil + } - //TODO: Add upstream picking algorithm here - fmt.Println("DEBUG: Picking origin " + origins[0].OriginIpOrDomain) - return origins[0], nil + //Valid session found. Resume the previous session + return origins[targetOriginId], nil + } else { + //Do not use stick session. Get a random one + var err error + targetOrigin, _, err = getRandomUpstreamByWeight(origins) + if err != nil { + log.Println(err) + targetOrigin = origins[0] + } + + } + + //fmt.Println("DEBUG: Picking origin " + targetOrigin.OriginIpOrDomain) + return targetOrigin, nil +} + +/* Features related to session access */ +//Set a new origin for this connection by session +func (m *RouteManager) setSessionHandler(w http.ResponseWriter, r *http.Request, originIpOrDomain string, index int) error { + session, err := m.SessionStore.Get(r, "STICKYSESSION") + if err != nil { + return err + } + session.Values["zr_sid_origin"] = originIpOrDomain + session.Values["zr_sid_index"] = index + session.Options.MaxAge = 86400 //1 day + session.Options.Path = "/" + err = session.Save(r, w) + if err != nil { + return err + } + return nil +} + +// Get the previous connected origin from session +func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream) (int, error) { + // Get existing session + session, err := m.SessionStore.Get(r, "STICKYSESSION") + if err != nil { + return -1, err + } + + // Retrieve session values for origin + originDomainRaw := session.Values["zr_sid_origin"] + originIDRaw := session.Values["zr_sid_index"] + + if originDomainRaw == nil || originIDRaw == nil { + return -1, errors.New("no session has been set") + } + originDomain := originDomainRaw.(string) + originID := originIDRaw.(int) + + //Check if it has been modified + if len(upstreams) < originID || upstreams[originID].OriginIpOrDomain != originDomain { + //Mismatch or upstreams has been updated + return -1, errors.New("upstreams has been changed") + } + + return originID, nil +} + +/* Functions related to random upstream picking */ +// Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error +func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { + var ret *Upstream + sum := 0 + for _, c := range upstreams { + sum += c.Weight + } + r, err := intRange(0, sum) + if err != nil { + return ret, -1, err + } + counter := 0 + for _, c := range upstreams { + r -= c.Weight + if r < 0 { + return c, counter, nil + } + counter++ + } + + if ret == nil { + //All fallback + //use the first one that is with weight = 0 + fallbackUpstreams := []*Upstream{} + fallbackUpstreamsOriginalID := []int{} + for ix, upstream := range upstreams { + if upstream.Weight == 0 { + fallbackUpstreams = append(fallbackUpstreams, upstream) + fallbackUpstreamsOriginalID = append(fallbackUpstreamsOriginalID, ix) + } + } + upstreamID := rand.Intn(len(fallbackUpstreams)) + return fallbackUpstreams[upstreamID], fallbackUpstreamsOriginalID[upstreamID], nil + } + return ret, -1, errors.New("failed to pick an upstream origin server") +} + +// IntRange returns a random integer in the range from min to max. +func intRange(min, max int) (int, error) { + var result int + switch { + case min > max: + // Fail with error + return result, errors.New("min is greater than max") + case max == min: + result = max + case max > min: + b := rand.Intn(max-min) + min + result = min + int(b) + } + return result, nil } diff --git a/src/mod/dynamicproxy/proxyRequestHandler.go b/src/mod/dynamicproxy/proxyRequestHandler.go index 3426731..0bc4bb4 100644 --- a/src/mod/dynamicproxy/proxyRequestHandler.go +++ b/src/mod/dynamicproxy/proxyRequestHandler.go @@ -112,7 +112,7 @@ func (router *Router) rewriteURL(rooturl string, requestURL string) string { func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, target *ProxyEndpoint) { r.Header.Set("X-Forwarded-Host", r.Host) r.Header.Set("X-Forwarded-Server", "zoraxy-"+h.Parent.Option.HostUUID) - selectedUpstream, err := h.Parent.loadBalancer.GetRequestUpstreamTarget(r, target.ActiveOrigins) + selectedUpstream, err := h.Parent.loadBalancer.GetRequestUpstreamTarget(w, r, target.ActiveOrigins, target.UseStickySession) if err != nil { http.ServeFile(w, r, "./web/rperror.html") log.Println(err.Error()) @@ -164,6 +164,7 @@ func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, targe PathPrefix: "", UpstreamHeaders: upstreamHeaders, DownstreamHeaders: downstreamHeaders, + NoRemoveHopByHop: target.DisableHopByHopHeaderRemoval, Version: target.parent.Option.HostVersion, }) diff --git a/src/mod/dynamicproxy/typedef.go b/src/mod/dynamicproxy/typedef.go index 1bcfa78..908fc70 100644 --- a/src/mod/dynamicproxy/typedef.go +++ b/src/mod/dynamicproxy/typedef.go @@ -133,6 +133,7 @@ type ProxyEndpoint struct { HSTSMaxAge int64 //HSTS max age, set to 0 for disable HSTS headers EnablePermissionPolicyHeader bool //Enable injection of permission policy header PermissionPolicy *permissionpolicy.PermissionsPolicy //Permission policy header + DisableHopByHopHeaderRemoval bool //TODO: Do not remove hop-by-hop headers //Authentication RequireBasicAuth bool //Set to true to request basic auth before proxy diff --git a/src/mod/update/update.go b/src/mod/update/update.go index 8c671b9..c5a85f7 100644 --- a/src/mod/update/update.go +++ b/src/mod/update/update.go @@ -9,21 +9,52 @@ package update import ( "fmt" + "os" "strconv" "strings" v308 "imuslab.com/zoraxy/mod/update/v308" + "imuslab.com/zoraxy/mod/utils" ) // Run config update. Version numbers are int. For example // to update 3.0.7 to 3.0.8, use RunConfigUpdate(307, 308) // This function support cross versions updates (e.g. 307 -> 310) func RunConfigUpdate(fromVersion int, toVersion int) { + versionFile := "./conf/version" + if fromVersion == 0 { + //Run auto previous version detection + fromVersion = 307 + if utils.FileExists(versionFile) { + //Read the version file + previousVersionText, err := os.ReadFile(versionFile) + if err != nil { + panic("Unable to read version file at " + versionFile) + } + + //Convert the version to int + versionInt, err := strconv.Atoi(strings.TrimSpace(string(previousVersionText))) + if err != nil { + panic("Unable to read version file at " + versionFile) + } + + fromVersion = versionInt + } + + if fromVersion == toVersion { + //No need to update + return + } + } + + //Do iterate update for i := fromVersion; i < toVersion; i++ { oldVersion := fromVersion newVersion := fromVersion + 1 fmt.Println("Updating from v", oldVersion, " to v", newVersion) runUpdateRoutineWithVersion(oldVersion, newVersion) + //Write the updated version to file + os.WriteFile(versionFile, []byte(strconv.Itoa(newVersion)), 0775) } fmt.Println("Update completed") } @@ -36,6 +67,7 @@ func GetVersionIntFromVersionNumber(version string) int { func runUpdateRoutineWithVersion(fromVersion int, toVersion int) { if fromVersion == 307 && toVersion == 308 { + //Updating from v3.0.7 to v3.0.8 err := v308.UpdateFrom307To308() if err != nil { panic(err) diff --git a/src/reverseproxy.go b/src/reverseproxy.go index b058db0..1163942 100644 --- a/src/reverseproxy.go +++ b/src/reverseproxy.go @@ -207,12 +207,7 @@ func ReverseProxyHandleAddEndpoint(w http.ResponseWriter, r *http.Request) { useBypassGlobalTLS := bypassGlobalTLS == "true" //Enable TLS validation? - stv, _ := utils.PostPara(r, "tlsval") - if stv == "" { - stv = "false" - } - - skipTlsValidation := (stv == "true") + skipTlsValidation, _ := utils.PostBool(r, "tlsval") //Get access rule ID accessRuleID, _ := utils.PostPara(r, "access") @@ -225,12 +220,10 @@ func ReverseProxyHandleAddEndpoint(w http.ResponseWriter, r *http.Request) { } // Require basic auth? - rba, _ := utils.PostPara(r, "bauth") - if rba == "" { - rba = "false" - } + requireBasicAuth, _ := utils.PostBool(r, "bauth") - requireBasicAuth := (rba == "true") + //Use sticky session? + useStickySession, _ := utils.PostBool(r, "stickysess") // Require Rate Limiting? requireRateLimit := false @@ -328,7 +321,7 @@ func ReverseProxyHandleAddEndpoint(w http.ResponseWriter, r *http.Request) { }, }, InactiveOrigins: []*loadbalance.Upstream{}, - UseStickySession: false, //TODO: Move options to webform + UseStickySession: useStickySession, //TLS BypassGlobalTLS: useBypassGlobalTLS, diff --git a/src/start.go b/src/start.go index b7895cf..73da080 100644 --- a/src/start.go +++ b/src/start.go @@ -104,8 +104,9 @@ func startupSequence() { //Create a load balancer loadBalancer = loadbalance.NewLoadBalancer(&loadbalance.Options{ - Geodb: geodbStore, - Logger: SystemWideLogger, + SystemUUID: nodeUUID, + Geodb: geodbStore, + Logger: SystemWideLogger, }) //Create the access controller diff --git a/src/web/components/rules.html b/src/web/components/rules.html index 6df7fa6..d175a98 100644 --- a/src/web/components/rules.html +++ b/src/web/components/rules.html @@ -11,6 +11,17 @@ border-radius: 0.6em; padding: 1em; } + + .descheader{ + display:none !important; + } + + @media (min-width: 1367px) { + .descheader{ + display:auto !important; + + } + }
@@ -47,16 +58,14 @@
- - +
+ + Security
@@ -76,27 +85,21 @@
-
-
- - -
+
+ + Access Control
-
- - -
-
-
- -
- -
- req / sec / IP + + - Return a 429 error code if request rate exceed the rate limit. + Allow regional access control using blacklist or whitelist. Use "default" for "allow all".
@@ -131,6 +134,22 @@
+
+
+ + +
+
+
+ +
+ +
+ req / sec / IP +
+
+ Return a 429 error code if request rate exceed the rate limit. +
@@ -166,17 +185,18 @@ //New Proxy Endpoint function newProxyEndpoint(){ - var rootname = $("#rootname").val(); - var proxyDomain = $("#proxyDomain").val(); - var useTLS = $("#reqTls")[0].checked; - var skipTLSValidation = $("#skipTLSValidation")[0].checked; - var bypassGlobalTLS = $("#bypassGlobalTLS")[0].checked; - var requireBasicAuth = $("#requireBasicAuth")[0].checked; - var proxyRateLimit = $("#proxyRateLimit").val(); - var requireRateLimit = $("#requireRateLimit")[0].checked; - var skipWebSocketOriginCheck = $("#skipWebsocketOriginCheck")[0].checked; - var accessRuleToUse = $("#newProxyRuleAccessFilter").val(); - + let rootname = $("#rootname").val(); + let proxyDomain = $("#proxyDomain").val(); + let useTLS = $("#reqTls")[0].checked; + let skipTLSValidation = $("#skipTLSValidation")[0].checked; + let bypassGlobalTLS = $("#bypassGlobalTLS")[0].checked; + let requireBasicAuth = $("#requireBasicAuth")[0].checked; + let proxyRateLimit = $("#proxyRateLimit").val(); + let requireRateLimit = $("#requireRateLimit")[0].checked; + let skipWebSocketOriginCheck = $("#skipWebsocketOriginCheck")[0].checked; + let accessRuleToUse = $("#newProxyRuleAccessFilter").val(); + let useStickySessionLB = $("#useStickySessionLB")[0].checked; + if (rootname.trim() == ""){ $("#rootname").parent().addClass("error"); return @@ -207,6 +227,7 @@ ratenum: proxyRateLimit, cred: JSON.stringify(credentials), access: accessRuleToUse, + stickysess: useStickySessionLB, }, success: function(data){ if (data.error != undefined){