diff --git a/src/mod/dynamicproxy/loadbalance/originPicker.go b/src/mod/dynamicproxy/loadbalance/originPicker.go index 51bcc2b..ad5ddc4 100644 --- a/src/mod/dynamicproxy/loadbalance/originPicker.go +++ b/src/mod/dynamicproxy/loadbalance/originPicker.go @@ -102,39 +102,48 @@ func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream) /* Functions related to random upstream picking */ // Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { - var ret *Upstream - sum := 0 - for _, c := range upstreams { - sum += c.Weight + totalUpstreams := len(upstreams) + if totalUpstreams == 1 { + return upstreams[0], 0, nil } - r, err := intRange(0, sum) - if err != nil { - return ret, -1, err - } - counter := 0 - for _, c := range upstreams { - r -= c.Weight - if r < 0 { - return c, counter, nil - } - counter++ + if totalUpstreams == 0 { + return nil, -1, errors.New("no upstream servers available") } - if ret == nil { - //All fallback - //use the first one that is with weight = 0 - fallbackUpstreams := []*Upstream{} - fallbackUpstreamsOriginalID := []int{} - for ix, upstream := range upstreams { - if upstream.Weight == 0 { - fallbackUpstreams = append(fallbackUpstreams, upstream) - fallbackUpstreamsOriginalID = append(fallbackUpstreamsOriginalID, ix) - } + totalWeight := 0 + fallbackUpstreams := make([]*Upstream, 0) // List of upstreams with weight 0 + fallbackUpstreamsOriginalID := make([]int, 0) + + // Calculate total weight and gather fallbacks + for ix, upstream := range upstreams { + totalWeight += upstream.Weight + if upstream.Weight == 0 { + fallbackUpstreams = append(fallbackUpstreams, upstream) + fallbackUpstreamsOriginalID = append(fallbackUpstreamsOriginalID, ix) + } + } + + if totalWeight == 0 { + // If total weight is 0, fallback to a random upstream with weight 0 + if len(fallbackUpstreams) == 0 { + return nil, -1, errors.New("no valid upstream servers available") } upstreamID := rand.Intn(len(fallbackUpstreams)) return fallbackUpstreams[upstreamID], fallbackUpstreamsOriginalID[upstreamID], nil } - return ret, -1, errors.New("failed to pick an upstream origin server") + + // Generate a random number in the range of total weight + r := rand.Intn(totalWeight) + + // Select upstream based on random number + for i, upstream := range upstreams { + r -= upstream.Weight + if r < 0 { + return upstream, i, nil + } + } + + return nil, -1, errors.New("failed to pick an upstream origin server") } // IntRange returns a random integer in the range from min to max. diff --git a/src/mod/utils/utils.go b/src/mod/utils/utils.go index 21d2e40..2fe1ffd 100644 --- a/src/mod/utils/utils.go +++ b/src/mod/utils/utils.go @@ -199,4 +199,4 @@ func ValidateListeningAddress(address string) bool { } return true -} +} \ No newline at end of file diff --git a/src/upstreams.go b/src/upstreams.go index dd53c83..8241f04 100644 --- a/src/upstreams.go +++ b/src/upstreams.go @@ -34,7 +34,7 @@ func ReverseProxyUpstreamList(w http.ResponseWriter, r *http.Request) { activeUpstreams := targetEndpoint.ActiveOrigins inactiveUpstreams := targetEndpoint.InactiveOrigins - slices.SortFunc(activeUpstreams, func(i, j *loadbalance.Upstream) int { + slices.SortStableFunc(activeUpstreams, func(i, j *loadbalance.Upstream) int { if i.Weight != j.Weight { // sort by weight DESC return cmp.Compare(j.Weight, i.Weight) @@ -43,7 +43,7 @@ func ReverseProxyUpstreamList(w http.ResponseWriter, r *http.Request) { return cmp.Compare(i.OriginIpOrDomain, j.OriginIpOrDomain) }) - slices.SortFunc(inactiveUpstreams, func(i, j *loadbalance.Upstream) int { + slices.SortStableFunc(inactiveUpstreams, func(i, j *loadbalance.Upstream) int { if i.Weight != j.Weight { // sort by weight DESC return cmp.Compare(j.Weight, i.Weight)