Added load balance origin picker

+ Added load balance picker
+ Added fallback mode for upstream
+ Added stick session
This commit is contained in:
Toby Chui 2024-07-12 20:14:31 +08:00
parent 2aa35cbe6d
commit aca6e44b35
12 changed files with 266 additions and 109 deletions

View File

@ -53,13 +53,13 @@ var enableHighSpeedGeoIPLookup = flag.Bool("fastgeoip", false, "Enable high spee
var staticWebServerRoot = flag.String("webroot", "./www", "Static web server root folder. Only allow chnage in start paramters")
var allowWebFileManager = flag.Bool("webfm", true, "Enable web file manager for static web server root folder")
var logOutputToFile = flag.Bool("log", true, "Log terminal output to file")
var updateMode = flag.Int("update", 0, "Version number (usually the version before you update Zoraxy) to start accumulation update. To update v3.0.7 to latest, use -update=307")
var enableAutoUpdate = flag.Bool("cfgupgrade", true, "Enable auto config upgrade if breaking change is detected")
var (
name = "Zoraxy"
version = "3.0.8"
nodeUUID = "generic"
development = true //Set this to false to use embedded web fs
nodeUUID = "generic" //System uuid, in uuidv4 format
development = true //Set this to false to use embedded web fs
bootTime = time.Now().Unix()
/*
@ -123,8 +123,9 @@ func ShutdownSeq() {
// Stop the mdns service
mdnsTickerStop <- true
}
mdnsScanner.Close()
fmt.Println("- Shutting down load balancer")
loadBalancer.Close()
fmt.Println("- Closing Certificates Auto Renewer")
acmeAutoRenewer.Close()
//Remove the tmp folder
@ -147,17 +148,16 @@ func main() {
os.Exit(0)
}
if *updateMode > 306 {
fmt.Println("Entering Update Mode")
update.RunConfigUpdate(*updateMode, update.GetVersionIntFromVersionNumber(version))
os.Exit(0)
}
if !utils.ValidateListeningAddress(*webUIPort) {
fmt.Println("Malformed -port (listening address) paramter. Do you mean -port=:" + *webUIPort + "?")
os.Exit(0)
}
if *enableAutoUpdate {
log.Println("[INFO] Checking required config update")
update.RunConfigUpdate(0, update.GetVersionIntFromVersionNumber(version))
}
SetupCloseHandler()
//Read or create the system uuid

View File

@ -64,6 +64,7 @@ type ResponseRewriteRuleSet struct {
PathPrefix string //Vdir prefix for root, / will be rewrite to this
UpstreamHeaders [][]string
DownstreamHeaders [][]string
NoRemoveHopByHop bool //Do not remove hop-by-hop headers, dangerous
Version string //Version number of Zoraxy, use for X-Proxy-By
}

View File

@ -151,18 +151,19 @@ func (router *Router) StartProxyService() error {
}
}
selectedUpstream, err := router.loadBalancer.GetRequestUpstreamTarget(r, sep.ActiveOrigins)
selectedUpstream, err := router.loadBalancer.GetRequestUpstreamTarget(w, r, sep.ActiveOrigins, sep.UseStickySession)
if err != nil {
http.ServeFile(w, r, "./web/hosterror.html")
log.Println(err.Error())
router.logRequest(r, false, 404, "vdir-http", r.Host)
}
selectedUpstream.ServeHTTP(w, r, &dpcore.ResponseRewriteRuleSet{
ProxyDomain: selectedUpstream.OriginIpOrDomain,
OriginalHost: originalHostHeader,
UseTLS: selectedUpstream.RequireTLS,
PathPrefix: "",
Version: sep.parent.Option.HostVersion,
ProxyDomain: selectedUpstream.OriginIpOrDomain,
OriginalHost: originalHostHeader,
UseTLS: selectedUpstream.RequireTLS,
NoRemoveHopByHop: sep.DisableHopByHopHeaderRemoval,
PathPrefix: "",
Version: sep.parent.Option.HostVersion,
})
return
}

View File

@ -3,8 +3,9 @@ package loadbalance
import (
"strings"
"sync"
"sync/atomic"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"imuslab.com/zoraxy/mod/dynamicproxy/dpcore"
"imuslab.com/zoraxy/mod/geodb"
"imuslab.com/zoraxy/mod/info/logger"
@ -17,12 +18,14 @@ import (
*/
type Options struct {
SystemUUID string //Use for the session store
UseActiveHealthCheck bool //Use active health check, default to false
Geodb *geodb.Store //GeoIP resolver for checking incoming request origin country
Logger *logger.Logger
}
type RouteManager struct {
SessionStore *sessions.CookieStore
LoadBalanceMap sync.Map //Sync map to store the last load balance state of a given node
OnlineStatusMap sync.Map //Sync map to store the online status of a given ip address or domain name
onlineStatusTickerStop chan bool //Stopping channel for the online status pinger
@ -39,20 +42,26 @@ type Upstream struct {
//Load balancing configs
Weight int //Random weight for round robin, 0 for fallback only
MaxConn int //Maxmium connection to this server, 0 for unlimited
MaxConn int //TODO: Maxmium connection to this server, 0 for unlimited
currentConnectionCounts atomic.Uint64 //Counter for number of client currently connected
proxy *dpcore.ReverseProxy
//currentConnectionCounts atomic.Uint64 //Counter for number of client currently connected
proxy *dpcore.ReverseProxy
}
// Create a new load balancer
func NewLoadBalancer(options *Options) *RouteManager {
onlineStatusCheckerStopChan := make(chan bool)
if options.SystemUUID == "" {
//System UUID not passed in. Use random key
options.SystemUUID = uuid.New().String()
}
//Generate a session store for stickySession
store := sessions.NewCookieStore([]byte(options.SystemUUID))
return &RouteManager{
SessionStore: store,
LoadBalanceMap: sync.Map{},
OnlineStatusMap: sync.Map{},
onlineStatusTickerStop: onlineStatusCheckerStopChan,
onlineStatusTickerStop: nil,
Options: *options,
}
}

View File

@ -16,10 +16,6 @@ func (m *RouteManager) IsTargetOnline(matchingDomainOrIp string) bool {
return ok && isOnline
}
func (m *RouteManager) SetTargetOffline() {
}
// Ping a target to see if it is online
func PingTarget(targetMatchingDomainOrIp string, requireTLS bool) bool {
client := &http.Client{
@ -41,30 +37,3 @@ func PingTarget(targetMatchingDomainOrIp string, requireTLS bool) bool {
return resp.StatusCode >= 200 && resp.StatusCode <= 600
}
// StartHeartbeats start pinging each server every minutes to make sure all targets are online
// Active mode only
/*
func (m *RouteManager) StartHeartbeats(pingTargets []*FallbackProxyTarget) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
fmt.Println("Heartbeat started")
go func() {
for {
select {
case <-m.onlineStatusTickerStop:
ticker.Stop()
return
case <-ticker.C:
for _, target := range pingTargets {
go func(target *FallbackProxyTarget) {
isOnline := PingTarget(target.MatchingDomainOrIp, target.RequireTLS)
m.LoadBalanceMap.Store(target.MatchingDomainOrIp, isOnline)
}(target)
}
}
}
}()
}
*/

View File

@ -3,6 +3,8 @@ package loadbalance
import (
"errors"
"fmt"
"log"
"math/rand"
"net/http"
)
@ -15,12 +17,138 @@ import (
// GetRequestUpstreamTarget return the upstream target where this
// request should be routed
func (m *RouteManager) GetRequestUpstreamTarget(r *http.Request, origins []*Upstream) (*Upstream, error) {
func (m *RouteManager) GetRequestUpstreamTarget(w http.ResponseWriter, r *http.Request, origins []*Upstream, useStickySession bool) (*Upstream, error) {
if len(origins) == 0 {
return nil, errors.New("no upstream is defined for this host")
}
var targetOrigin = origins[0]
if useStickySession {
//Use stick session, check which origins this request previously used
targetOriginId, err := m.getSessionHandler(r, origins)
if err != nil {
//No valid session found. Assign a new upstream
targetOrigin, index, err := getRandomUpstreamByWeight(origins)
if err != nil {
fmt.Println("Oops. Unable to get random upstream")
targetOrigin = origins[0]
index = 0
}
m.setSessionHandler(w, r, targetOrigin.OriginIpOrDomain, index)
return targetOrigin, nil
}
//TODO: Add upstream picking algorithm here
fmt.Println("DEBUG: Picking origin " + origins[0].OriginIpOrDomain)
return origins[0], nil
//Valid session found. Resume the previous session
return origins[targetOriginId], nil
} else {
//Do not use stick session. Get a random one
var err error
targetOrigin, _, err = getRandomUpstreamByWeight(origins)
if err != nil {
log.Println(err)
targetOrigin = origins[0]
}
}
//fmt.Println("DEBUG: Picking origin " + targetOrigin.OriginIpOrDomain)
return targetOrigin, nil
}
/* Features related to session access */
//Set a new origin for this connection by session
func (m *RouteManager) setSessionHandler(w http.ResponseWriter, r *http.Request, originIpOrDomain string, index int) error {
session, err := m.SessionStore.Get(r, "STICKYSESSION")
if err != nil {
return err
}
session.Values["zr_sid_origin"] = originIpOrDomain
session.Values["zr_sid_index"] = index
session.Options.MaxAge = 86400 //1 day
session.Options.Path = "/"
err = session.Save(r, w)
if err != nil {
return err
}
return nil
}
// Get the previous connected origin from session
func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream) (int, error) {
// Get existing session
session, err := m.SessionStore.Get(r, "STICKYSESSION")
if err != nil {
return -1, err
}
// Retrieve session values for origin
originDomainRaw := session.Values["zr_sid_origin"]
originIDRaw := session.Values["zr_sid_index"]
if originDomainRaw == nil || originIDRaw == nil {
return -1, errors.New("no session has been set")
}
originDomain := originDomainRaw.(string)
originID := originIDRaw.(int)
//Check if it has been modified
if len(upstreams) < originID || upstreams[originID].OriginIpOrDomain != originDomain {
//Mismatch or upstreams has been updated
return -1, errors.New("upstreams has been changed")
}
return originID, nil
}
/* Functions related to random upstream picking */
// Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error
func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) {
var ret *Upstream
sum := 0
for _, c := range upstreams {
sum += c.Weight
}
r, err := intRange(0, sum)
if err != nil {
return ret, -1, err
}
counter := 0
for _, c := range upstreams {
r -= c.Weight
if r < 0 {
return c, counter, nil
}
counter++
}
if ret == nil {
//All fallback
//use the first one that is with weight = 0
fallbackUpstreams := []*Upstream{}
fallbackUpstreamsOriginalID := []int{}
for ix, upstream := range upstreams {
if upstream.Weight == 0 {
fallbackUpstreams = append(fallbackUpstreams, upstream)
fallbackUpstreamsOriginalID = append(fallbackUpstreamsOriginalID, ix)
}
}
upstreamID := rand.Intn(len(fallbackUpstreams))
return fallbackUpstreams[upstreamID], fallbackUpstreamsOriginalID[upstreamID], nil
}
return ret, -1, errors.New("failed to pick an upstream origin server")
}
// IntRange returns a random integer in the range from min to max.
func intRange(min, max int) (int, error) {
var result int
switch {
case min > max:
// Fail with error
return result, errors.New("min is greater than max")
case max == min:
result = max
case max > min:
b := rand.Intn(max-min) + min
result = min + int(b)
}
return result, nil
}

View File

@ -112,7 +112,7 @@ func (router *Router) rewriteURL(rooturl string, requestURL string) string {
func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, target *ProxyEndpoint) {
r.Header.Set("X-Forwarded-Host", r.Host)
r.Header.Set("X-Forwarded-Server", "zoraxy-"+h.Parent.Option.HostUUID)
selectedUpstream, err := h.Parent.loadBalancer.GetRequestUpstreamTarget(r, target.ActiveOrigins)
selectedUpstream, err := h.Parent.loadBalancer.GetRequestUpstreamTarget(w, r, target.ActiveOrigins, target.UseStickySession)
if err != nil {
http.ServeFile(w, r, "./web/rperror.html")
log.Println(err.Error())
@ -164,6 +164,7 @@ func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, targe
PathPrefix: "",
UpstreamHeaders: upstreamHeaders,
DownstreamHeaders: downstreamHeaders,
NoRemoveHopByHop: target.DisableHopByHopHeaderRemoval,
Version: target.parent.Option.HostVersion,
})

View File

@ -133,6 +133,7 @@ type ProxyEndpoint struct {
HSTSMaxAge int64 //HSTS max age, set to 0 for disable HSTS headers
EnablePermissionPolicyHeader bool //Enable injection of permission policy header
PermissionPolicy *permissionpolicy.PermissionsPolicy //Permission policy header
DisableHopByHopHeaderRemoval bool //TODO: Do not remove hop-by-hop headers
//Authentication
RequireBasicAuth bool //Set to true to request basic auth before proxy

View File

@ -9,21 +9,52 @@ package update
import (
"fmt"
"os"
"strconv"
"strings"
v308 "imuslab.com/zoraxy/mod/update/v308"
"imuslab.com/zoraxy/mod/utils"
)
// Run config update. Version numbers are int. For example
// to update 3.0.7 to 3.0.8, use RunConfigUpdate(307, 308)
// This function support cross versions updates (e.g. 307 -> 310)
func RunConfigUpdate(fromVersion int, toVersion int) {
versionFile := "./conf/version"
if fromVersion == 0 {
//Run auto previous version detection
fromVersion = 307
if utils.FileExists(versionFile) {
//Read the version file
previousVersionText, err := os.ReadFile(versionFile)
if err != nil {
panic("Unable to read version file at " + versionFile)
}
//Convert the version to int
versionInt, err := strconv.Atoi(strings.TrimSpace(string(previousVersionText)))
if err != nil {
panic("Unable to read version file at " + versionFile)
}
fromVersion = versionInt
}
if fromVersion == toVersion {
//No need to update
return
}
}
//Do iterate update
for i := fromVersion; i < toVersion; i++ {
oldVersion := fromVersion
newVersion := fromVersion + 1
fmt.Println("Updating from v", oldVersion, " to v", newVersion)
runUpdateRoutineWithVersion(oldVersion, newVersion)
//Write the updated version to file
os.WriteFile(versionFile, []byte(strconv.Itoa(newVersion)), 0775)
}
fmt.Println("Update completed")
}
@ -36,6 +67,7 @@ func GetVersionIntFromVersionNumber(version string) int {
func runUpdateRoutineWithVersion(fromVersion int, toVersion int) {
if fromVersion == 307 && toVersion == 308 {
//Updating from v3.0.7 to v3.0.8
err := v308.UpdateFrom307To308()
if err != nil {
panic(err)

View File

@ -207,12 +207,7 @@ func ReverseProxyHandleAddEndpoint(w http.ResponseWriter, r *http.Request) {
useBypassGlobalTLS := bypassGlobalTLS == "true"
//Enable TLS validation?
stv, _ := utils.PostPara(r, "tlsval")
if stv == "" {
stv = "false"
}
skipTlsValidation := (stv == "true")
skipTlsValidation, _ := utils.PostBool(r, "tlsval")
//Get access rule ID
accessRuleID, _ := utils.PostPara(r, "access")
@ -225,12 +220,10 @@ func ReverseProxyHandleAddEndpoint(w http.ResponseWriter, r *http.Request) {
}
// Require basic auth?
rba, _ := utils.PostPara(r, "bauth")
if rba == "" {
rba = "false"
}
requireBasicAuth, _ := utils.PostBool(r, "bauth")
requireBasicAuth := (rba == "true")
//Use sticky session?
useStickySession, _ := utils.PostBool(r, "stickysess")
// Require Rate Limiting?
requireRateLimit := false
@ -328,7 +321,7 @@ func ReverseProxyHandleAddEndpoint(w http.ResponseWriter, r *http.Request) {
},
},
InactiveOrigins: []*loadbalance.Upstream{},
UseStickySession: false, //TODO: Move options to webform
UseStickySession: useStickySession,
//TLS
BypassGlobalTLS: useBypassGlobalTLS,

View File

@ -104,8 +104,9 @@ func startupSequence() {
//Create a load balancer
loadBalancer = loadbalance.NewLoadBalancer(&loadbalance.Options{
Geodb: geodbStore,
Logger: SystemWideLogger,
SystemUUID: nodeUUID,
Geodb: geodbStore,
Logger: SystemWideLogger,
})
//Create the access controller

View File

@ -11,6 +11,17 @@
border-radius: 0.6em;
padding: 1em;
}
.descheader{
display:none !important;
}
@media (min-width: 1367px) {
.descheader{
display:auto !important;
}
}
</style>
<div class="standardContainer">
<div class="ui stackable grid">
@ -47,16 +58,14 @@
</div>
<div class="content">
<div class="field">
<label>Access Rule</label>
<div class="ui selection dropdown">
<input type="hidden" id="newProxyRuleAccessFilter" value="default">
<i class="dropdown icon"></i>
<div class="default text">Default</div>
<div class="menu" id="newProxyRuleAccessList">
<div class="item" data-value="default"><i class="ui yellow star icon"></i> Default</div>
</div>
<div class="ui checkbox">
<input type="checkbox" id="useStickySessionLB">
<label>Sticky Session<br><small>Enable stick session on upstream load balancing</small></label>
</div>
<small>Allow regional access control using blacklist or whitelist. Use "default" for "allow all".</small>
</div>
<div class="ui horizontal divider">
<i class="ui green lock icon"></i>
Security
</div>
<div class="field">
<div class="ui checkbox">
@ -76,27 +85,21 @@
<label>Allow plain HTTP access<br><small>Allow this subdomain to be connected without TLS (Require HTTP server enabled on port 80)</small></label>
</div>
</div>
<div class="field">
<div class="ui checkbox">
<input type="checkbox" id="useStickySessionLB">
<label>Sticky Session<br><small>Enable stick session on upstream load balancing</small></label>
</div>
<div class="ui horizontal divider">
<i class="ui red ban icon"></i>
Access Control
</div>
<div class="field">
<div class="ui checkbox">
<input type="checkbox" id="requireRateLimit">
<label>Require Rate Limit<br><small>This proxy endpoint will be rate limited.</small></label>
</div>
</div>
<div class="field">
<label>Rate Limit</label>
<div class="ui fluid right labeled input">
<input type="number" id="proxyRateLimit" placeholder="100" min="1" max="1000" value="100">
<div class="ui basic label">
req / sec / IP
<label>Access Rule</label>
<div class="ui selection dropdown">
<input type="hidden" id="newProxyRuleAccessFilter" value="default">
<i class="dropdown icon"></i>
<div class="default text">Default</div>
<div class="menu" id="newProxyRuleAccessList">
<div class="item" data-value="default"><i class="ui yellow star icon"></i> Default</div>
</div>
</div>
<small>Return a 429 error code if request rate exceed the rate limit.</small>
<small>Allow regional access control using blacklist or whitelist. Use "default" for "allow all".</small>
</div>
<div class="field">
<div class="ui checkbox">
@ -131,6 +134,22 @@
</div>
</div>
</div>
<div class="field">
<div class="ui checkbox">
<input type="checkbox" id="requireRateLimit">
<label>Require Rate Limit<br><small>This proxy endpoint will be rate limited.</small></label>
</div>
</div>
<div class="field">
<label>Rate Limit</label>
<div class="ui fluid right labeled input">
<input type="number" id="proxyRateLimit" placeholder="100" min="1" max="1000" value="100">
<div class="ui basic label">
req / sec / IP
</div>
</div>
<small>Return a 429 error code if request rate exceed the rate limit.</small>
</div>
</div>
</div>
</div>
@ -166,17 +185,18 @@
//New Proxy Endpoint
function newProxyEndpoint(){
var rootname = $("#rootname").val();
var proxyDomain = $("#proxyDomain").val();
var useTLS = $("#reqTls")[0].checked;
var skipTLSValidation = $("#skipTLSValidation")[0].checked;
var bypassGlobalTLS = $("#bypassGlobalTLS")[0].checked;
var requireBasicAuth = $("#requireBasicAuth")[0].checked;
var proxyRateLimit = $("#proxyRateLimit").val();
var requireRateLimit = $("#requireRateLimit")[0].checked;
var skipWebSocketOriginCheck = $("#skipWebsocketOriginCheck")[0].checked;
var accessRuleToUse = $("#newProxyRuleAccessFilter").val();
let rootname = $("#rootname").val();
let proxyDomain = $("#proxyDomain").val();
let useTLS = $("#reqTls")[0].checked;
let skipTLSValidation = $("#skipTLSValidation")[0].checked;
let bypassGlobalTLS = $("#bypassGlobalTLS")[0].checked;
let requireBasicAuth = $("#requireBasicAuth")[0].checked;
let proxyRateLimit = $("#proxyRateLimit").val();
let requireRateLimit = $("#requireRateLimit")[0].checked;
let skipWebSocketOriginCheck = $("#skipWebsocketOriginCheck")[0].checked;
let accessRuleToUse = $("#newProxyRuleAccessFilter").val();
let useStickySessionLB = $("#useStickySessionLB")[0].checked;
if (rootname.trim() == ""){
$("#rootname").parent().addClass("error");
return
@ -207,6 +227,7 @@
ratenum: proxyRateLimit,
cred: JSON.stringify(credentials),
access: accessRuleToUse,
stickysess: useStickySessionLB,
},
success: function(data){
if (data.error != undefined){