️ immediate return if single upstream

This commit is contained in:
bouroo 2024-07-22 23:39:47 +07:00 committed by Kawin Viriyaprasopsook
parent 0dddd1f9e3
commit bec363abab
3 changed files with 38 additions and 29 deletions

View File

@ -102,39 +102,48 @@ func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream)
/* Functions related to random upstream picking */ /* Functions related to random upstream picking */
// Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error // Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error
func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) {
var ret *Upstream totalUpstreams := len(upstreams)
sum := 0 if totalUpstreams == 1 {
for _, c := range upstreams { return upstreams[0], 0, nil
sum += c.Weight
} }
r, err := intRange(0, sum) if totalUpstreams == 0 {
if err != nil { return nil, -1, errors.New("no upstream servers available")
return ret, -1, err
}
counter := 0
for _, c := range upstreams {
r -= c.Weight
if r < 0 {
return c, counter, nil
}
counter++
} }
if ret == nil { totalWeight := 0
//All fallback fallbackUpstreams := make([]*Upstream, 0) // List of upstreams with weight 0
//use the first one that is with weight = 0 fallbackUpstreamsOriginalID := make([]int, 0)
fallbackUpstreams := []*Upstream{}
fallbackUpstreamsOriginalID := []int{} // Calculate total weight and gather fallbacks
for ix, upstream := range upstreams { for ix, upstream := range upstreams {
if upstream.Weight == 0 { totalWeight += upstream.Weight
fallbackUpstreams = append(fallbackUpstreams, upstream) if upstream.Weight == 0 {
fallbackUpstreamsOriginalID = append(fallbackUpstreamsOriginalID, ix) fallbackUpstreams = append(fallbackUpstreams, upstream)
} fallbackUpstreamsOriginalID = append(fallbackUpstreamsOriginalID, ix)
}
}
if totalWeight == 0 {
// If total weight is 0, fallback to a random upstream with weight 0
if len(fallbackUpstreams) == 0 {
return nil, -1, errors.New("no valid upstream servers available")
} }
upstreamID := rand.Intn(len(fallbackUpstreams)) upstreamID := rand.Intn(len(fallbackUpstreams))
return fallbackUpstreams[upstreamID], fallbackUpstreamsOriginalID[upstreamID], nil return fallbackUpstreams[upstreamID], fallbackUpstreamsOriginalID[upstreamID], nil
} }
return ret, -1, errors.New("failed to pick an upstream origin server")
// Generate a random number in the range of total weight
r := rand.Intn(totalWeight)
// Select upstream based on random number
for i, upstream := range upstreams {
r -= upstream.Weight
if r < 0 {
return upstream, i, nil
}
}
return nil, -1, errors.New("failed to pick an upstream origin server")
} }
// IntRange returns a random integer in the range from min to max. // IntRange returns a random integer in the range from min to max.

View File

@ -199,4 +199,4 @@ func ValidateListeningAddress(address string) bool {
} }
return true return true
} }

View File

@ -34,7 +34,7 @@ func ReverseProxyUpstreamList(w http.ResponseWriter, r *http.Request) {
activeUpstreams := targetEndpoint.ActiveOrigins activeUpstreams := targetEndpoint.ActiveOrigins
inactiveUpstreams := targetEndpoint.InactiveOrigins inactiveUpstreams := targetEndpoint.InactiveOrigins
slices.SortFunc(activeUpstreams, func(i, j *loadbalance.Upstream) int { slices.SortStableFunc(activeUpstreams, func(i, j *loadbalance.Upstream) int {
if i.Weight != j.Weight { if i.Weight != j.Weight {
// sort by weight DESC // sort by weight DESC
return cmp.Compare(j.Weight, i.Weight) return cmp.Compare(j.Weight, i.Weight)
@ -43,7 +43,7 @@ func ReverseProxyUpstreamList(w http.ResponseWriter, r *http.Request) {
return cmp.Compare(i.OriginIpOrDomain, j.OriginIpOrDomain) return cmp.Compare(i.OriginIpOrDomain, j.OriginIpOrDomain)
}) })
slices.SortFunc(inactiveUpstreams, func(i, j *loadbalance.Upstream) int { slices.SortStableFunc(inactiveUpstreams, func(i, j *loadbalance.Upstream) int {
if i.Weight != j.Weight { if i.Weight != j.Weight {
// sort by weight DESC // sort by weight DESC
return cmp.Compare(j.Weight, i.Weight) return cmp.Compare(j.Weight, i.Weight)