diff --git a/src/mod/dynamicproxy/loadbalance/originPicker.go b/src/mod/dynamicproxy/loadbalance/originPicker.go index 8971184..da4423e 100644 --- a/src/mod/dynamicproxy/loadbalance/originPicker.go +++ b/src/mod/dynamicproxy/loadbalance/originPicker.go @@ -107,15 +107,22 @@ func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { return upstreams[0], 0, nil } + // Preserve the index with upstreams + type upstreamWithIndex struct { + Upstream *Upstream + Index int + } + // Calculate total weight for upstreams with weight > 0 totalWeight := 0 - fallbackUpstreams := make([]*Upstream, 0) + fallbackUpstreams := make([]upstreamWithIndex, 0, len(upstreams)) - for _, upstream := range upstreams { + for index, upstream := range upstreams { if upstream.Weight > 0 { totalWeight += upstream.Weight } else { - fallbackUpstreams = append(fallbackUpstreams, upstream) // Collect fallback upstreams + // Collect fallback upstreams + fallbackUpstreams = append(fallbackUpstreams, upstreamWithIndex{upstream, index}) } } @@ -124,7 +131,7 @@ func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { if len(fallbackUpstreams) > 0 { // Randomly select one of the fallback upstreams index := rand.Intn(len(fallbackUpstreams)) - return fallbackUpstreams[index], index, nil + return fallbackUpstreams[index].Upstream, fallbackUpstreams[index].Index, nil } // No upstreams available at all return nil, -1, errors.New("no valid upstream servers available") @@ -134,10 +141,11 @@ func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { randomWeight := rand.Intn(totalWeight) // Select an upstream based on the random weight - for i, upstream := range upstreams { + for index, upstream := range upstreams { if upstream.Weight > 0 { // Only consider upstreams with weight > 0 if randomWeight < upstream.Weight { - return upstream, i, nil // Return the selected upstream and its index + // Return the selected upstream and its index + return upstream, index, nil } randomWeight -= upstream.Weight } @@ -146,7 +154,7 @@ func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { // If we reach here, it means we should return a fallback upstream if available if len(fallbackUpstreams) > 0 { index := rand.Intn(len(fallbackUpstreams)) - return fallbackUpstreams[index], index, nil + return fallbackUpstreams[index].Upstream, fallbackUpstreams[index].Index, nil } return nil, -1, errors.New("failed to pick an upstream origin server")