Added load balance origin picker

+ Added load balance picker
+ Added fallback mode for upstream
+ Added stick session
This commit is contained in:
Toby Chui
2024-07-12 20:14:31 +08:00
parent 2aa35cbe6d
commit aca6e44b35
12 changed files with 266 additions and 109 deletions

View File

@@ -64,6 +64,7 @@ type ResponseRewriteRuleSet struct {
PathPrefix string //Vdir prefix for root, / will be rewrite to this
UpstreamHeaders [][]string
DownstreamHeaders [][]string
NoRemoveHopByHop bool //Do not remove hop-by-hop headers, dangerous
Version string //Version number of Zoraxy, use for X-Proxy-By
}

View File

@@ -151,18 +151,19 @@ func (router *Router) StartProxyService() error {
}
}
selectedUpstream, err := router.loadBalancer.GetRequestUpstreamTarget(r, sep.ActiveOrigins)
selectedUpstream, err := router.loadBalancer.GetRequestUpstreamTarget(w, r, sep.ActiveOrigins, sep.UseStickySession)
if err != nil {
http.ServeFile(w, r, "./web/hosterror.html")
log.Println(err.Error())
router.logRequest(r, false, 404, "vdir-http", r.Host)
}
selectedUpstream.ServeHTTP(w, r, &dpcore.ResponseRewriteRuleSet{
ProxyDomain: selectedUpstream.OriginIpOrDomain,
OriginalHost: originalHostHeader,
UseTLS: selectedUpstream.RequireTLS,
PathPrefix: "",
Version: sep.parent.Option.HostVersion,
ProxyDomain: selectedUpstream.OriginIpOrDomain,
OriginalHost: originalHostHeader,
UseTLS: selectedUpstream.RequireTLS,
NoRemoveHopByHop: sep.DisableHopByHopHeaderRemoval,
PathPrefix: "",
Version: sep.parent.Option.HostVersion,
})
return
}

View File

@@ -3,8 +3,9 @@ package loadbalance
import (
"strings"
"sync"
"sync/atomic"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"imuslab.com/zoraxy/mod/dynamicproxy/dpcore"
"imuslab.com/zoraxy/mod/geodb"
"imuslab.com/zoraxy/mod/info/logger"
@@ -17,12 +18,14 @@ import (
*/
type Options struct {
SystemUUID string //Use for the session store
UseActiveHealthCheck bool //Use active health check, default to false
Geodb *geodb.Store //GeoIP resolver for checking incoming request origin country
Logger *logger.Logger
}
type RouteManager struct {
SessionStore *sessions.CookieStore
LoadBalanceMap sync.Map //Sync map to store the last load balance state of a given node
OnlineStatusMap sync.Map //Sync map to store the online status of a given ip address or domain name
onlineStatusTickerStop chan bool //Stopping channel for the online status pinger
@@ -39,20 +42,26 @@ type Upstream struct {
//Load balancing configs
Weight int //Random weight for round robin, 0 for fallback only
MaxConn int //Maxmium connection to this server, 0 for unlimited
MaxConn int //TODO: Maxmium connection to this server, 0 for unlimited
currentConnectionCounts atomic.Uint64 //Counter for number of client currently connected
proxy *dpcore.ReverseProxy
//currentConnectionCounts atomic.Uint64 //Counter for number of client currently connected
proxy *dpcore.ReverseProxy
}
// Create a new load balancer
func NewLoadBalancer(options *Options) *RouteManager {
onlineStatusCheckerStopChan := make(chan bool)
if options.SystemUUID == "" {
//System UUID not passed in. Use random key
options.SystemUUID = uuid.New().String()
}
//Generate a session store for stickySession
store := sessions.NewCookieStore([]byte(options.SystemUUID))
return &RouteManager{
SessionStore: store,
LoadBalanceMap: sync.Map{},
OnlineStatusMap: sync.Map{},
onlineStatusTickerStop: onlineStatusCheckerStopChan,
onlineStatusTickerStop: nil,
Options: *options,
}
}

View File

@@ -16,10 +16,6 @@ func (m *RouteManager) IsTargetOnline(matchingDomainOrIp string) bool {
return ok && isOnline
}
func (m *RouteManager) SetTargetOffline() {
}
// Ping a target to see if it is online
func PingTarget(targetMatchingDomainOrIp string, requireTLS bool) bool {
client := &http.Client{
@@ -41,30 +37,3 @@ func PingTarget(targetMatchingDomainOrIp string, requireTLS bool) bool {
return resp.StatusCode >= 200 && resp.StatusCode <= 600
}
// StartHeartbeats start pinging each server every minutes to make sure all targets are online
// Active mode only
/*
func (m *RouteManager) StartHeartbeats(pingTargets []*FallbackProxyTarget) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
fmt.Println("Heartbeat started")
go func() {
for {
select {
case <-m.onlineStatusTickerStop:
ticker.Stop()
return
case <-ticker.C:
for _, target := range pingTargets {
go func(target *FallbackProxyTarget) {
isOnline := PingTarget(target.MatchingDomainOrIp, target.RequireTLS)
m.LoadBalanceMap.Store(target.MatchingDomainOrIp, isOnline)
}(target)
}
}
}
}()
}
*/

View File

@@ -3,6 +3,8 @@ package loadbalance
import (
"errors"
"fmt"
"log"
"math/rand"
"net/http"
)
@@ -15,12 +17,138 @@ import (
// GetRequestUpstreamTarget return the upstream target where this
// request should be routed
func (m *RouteManager) GetRequestUpstreamTarget(r *http.Request, origins []*Upstream) (*Upstream, error) {
func (m *RouteManager) GetRequestUpstreamTarget(w http.ResponseWriter, r *http.Request, origins []*Upstream, useStickySession bool) (*Upstream, error) {
if len(origins) == 0 {
return nil, errors.New("no upstream is defined for this host")
}
var targetOrigin = origins[0]
if useStickySession {
//Use stick session, check which origins this request previously used
targetOriginId, err := m.getSessionHandler(r, origins)
if err != nil {
//No valid session found. Assign a new upstream
targetOrigin, index, err := getRandomUpstreamByWeight(origins)
if err != nil {
fmt.Println("Oops. Unable to get random upstream")
targetOrigin = origins[0]
index = 0
}
m.setSessionHandler(w, r, targetOrigin.OriginIpOrDomain, index)
return targetOrigin, nil
}
//TODO: Add upstream picking algorithm here
fmt.Println("DEBUG: Picking origin " + origins[0].OriginIpOrDomain)
return origins[0], nil
//Valid session found. Resume the previous session
return origins[targetOriginId], nil
} else {
//Do not use stick session. Get a random one
var err error
targetOrigin, _, err = getRandomUpstreamByWeight(origins)
if err != nil {
log.Println(err)
targetOrigin = origins[0]
}
}
//fmt.Println("DEBUG: Picking origin " + targetOrigin.OriginIpOrDomain)
return targetOrigin, nil
}
/* Features related to session access */
//Set a new origin for this connection by session
func (m *RouteManager) setSessionHandler(w http.ResponseWriter, r *http.Request, originIpOrDomain string, index int) error {
session, err := m.SessionStore.Get(r, "STICKYSESSION")
if err != nil {
return err
}
session.Values["zr_sid_origin"] = originIpOrDomain
session.Values["zr_sid_index"] = index
session.Options.MaxAge = 86400 //1 day
session.Options.Path = "/"
err = session.Save(r, w)
if err != nil {
return err
}
return nil
}
// Get the previous connected origin from session
func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream) (int, error) {
// Get existing session
session, err := m.SessionStore.Get(r, "STICKYSESSION")
if err != nil {
return -1, err
}
// Retrieve session values for origin
originDomainRaw := session.Values["zr_sid_origin"]
originIDRaw := session.Values["zr_sid_index"]
if originDomainRaw == nil || originIDRaw == nil {
return -1, errors.New("no session has been set")
}
originDomain := originDomainRaw.(string)
originID := originIDRaw.(int)
//Check if it has been modified
if len(upstreams) < originID || upstreams[originID].OriginIpOrDomain != originDomain {
//Mismatch or upstreams has been updated
return -1, errors.New("upstreams has been changed")
}
return originID, nil
}
/* Functions related to random upstream picking */
// Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error
func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) {
var ret *Upstream
sum := 0
for _, c := range upstreams {
sum += c.Weight
}
r, err := intRange(0, sum)
if err != nil {
return ret, -1, err
}
counter := 0
for _, c := range upstreams {
r -= c.Weight
if r < 0 {
return c, counter, nil
}
counter++
}
if ret == nil {
//All fallback
//use the first one that is with weight = 0
fallbackUpstreams := []*Upstream{}
fallbackUpstreamsOriginalID := []int{}
for ix, upstream := range upstreams {
if upstream.Weight == 0 {
fallbackUpstreams = append(fallbackUpstreams, upstream)
fallbackUpstreamsOriginalID = append(fallbackUpstreamsOriginalID, ix)
}
}
upstreamID := rand.Intn(len(fallbackUpstreams))
return fallbackUpstreams[upstreamID], fallbackUpstreamsOriginalID[upstreamID], nil
}
return ret, -1, errors.New("failed to pick an upstream origin server")
}
// IntRange returns a random integer in the range from min to max.
func intRange(min, max int) (int, error) {
var result int
switch {
case min > max:
// Fail with error
return result, errors.New("min is greater than max")
case max == min:
result = max
case max > min:
b := rand.Intn(max-min) + min
result = min + int(b)
}
return result, nil
}

View File

@@ -112,7 +112,7 @@ func (router *Router) rewriteURL(rooturl string, requestURL string) string {
func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, target *ProxyEndpoint) {
r.Header.Set("X-Forwarded-Host", r.Host)
r.Header.Set("X-Forwarded-Server", "zoraxy-"+h.Parent.Option.HostUUID)
selectedUpstream, err := h.Parent.loadBalancer.GetRequestUpstreamTarget(r, target.ActiveOrigins)
selectedUpstream, err := h.Parent.loadBalancer.GetRequestUpstreamTarget(w, r, target.ActiveOrigins, target.UseStickySession)
if err != nil {
http.ServeFile(w, r, "./web/rperror.html")
log.Println(err.Error())
@@ -164,6 +164,7 @@ func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, targe
PathPrefix: "",
UpstreamHeaders: upstreamHeaders,
DownstreamHeaders: downstreamHeaders,
NoRemoveHopByHop: target.DisableHopByHopHeaderRemoval,
Version: target.parent.Option.HostVersion,
})

View File

@@ -133,6 +133,7 @@ type ProxyEndpoint struct {
HSTSMaxAge int64 //HSTS max age, set to 0 for disable HSTS headers
EnablePermissionPolicyHeader bool //Enable injection of permission policy header
PermissionPolicy *permissionpolicy.PermissionsPolicy //Permission policy header
DisableHopByHopHeaderRemoval bool //TODO: Do not remove hop-by-hop headers
//Authentication
RequireBasicAuth bool //Set to true to request basic auth before proxy