diff --git a/src/mod/dynamicproxy/loadbalance/originPicker.go b/src/mod/dynamicproxy/loadbalance/originPicker.go index 51bcc2b..ad77472 100644 --- a/src/mod/dynamicproxy/loadbalance/originPicker.go +++ b/src/mod/dynamicproxy/loadbalance/originPicker.go @@ -102,39 +102,62 @@ func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream) /* Functions related to random upstream picking */ // Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { - var ret *Upstream - sum := 0 - for _, c := range upstreams { - sum += c.Weight - } - r, err := intRange(0, sum) - if err != nil { - return ret, -1, err - } - counter := 0 - for _, c := range upstreams { - r -= c.Weight - if r < 0 { - return c, counter, nil - } - counter++ + // If there is only one upstream, return it + if len(upstreams) == 1 { + return upstreams[0], 0, nil } - if ret == nil { - //All fallback - //use the first one that is with weight = 0 - fallbackUpstreams := []*Upstream{} - fallbackUpstreamsOriginalID := []int{} - for ix, upstream := range upstreams { - if upstream.Weight == 0 { - fallbackUpstreams = append(fallbackUpstreams, upstream) - fallbackUpstreamsOriginalID = append(fallbackUpstreamsOriginalID, ix) - } - } - upstreamID := rand.Intn(len(fallbackUpstreams)) - return fallbackUpstreams[upstreamID], fallbackUpstreamsOriginalID[upstreamID], nil + // Preserve the index with upstreams + type upstreamWithIndex struct { + Upstream *Upstream + Index int } - return ret, -1, errors.New("failed to pick an upstream origin server") + + // Calculate total weight for upstreams with weight > 0 + totalWeight := 0 + fallbackUpstreams := make([]upstreamWithIndex, 0, len(upstreams)) + + for index, upstream := range upstreams { + if upstream.Weight > 0 { + totalWeight += upstream.Weight + } else { + // Collect fallback upstreams + fallbackUpstreams = append(fallbackUpstreams, upstreamWithIndex{upstream, index}) + } + } + + // If there are no upstreams with weight > 0, return a fallback upstream if available + if totalWeight == 0 { + if len(fallbackUpstreams) > 0 { + // Randomly select one of the fallback upstreams + randIndex := rand.Intn(len(fallbackUpstreams)) + return fallbackUpstreams[randIndex].Upstream, fallbackUpstreams[randIndex].Index, nil + } + // No upstreams available at all + return nil, -1, errors.New("no valid upstream servers available") + } + + // Random weight between 0 and total weight + randomWeight := rand.Intn(totalWeight) + + // Select an upstream based on the random weight + for index, upstream := range upstreams { + if upstream.Weight > 0 { // Only consider upstreams with weight > 0 + if randomWeight < upstream.Weight { + // Return the selected upstream and its index + return upstream, index, nil + } + randomWeight -= upstream.Weight + } + } + + // If we reach here, it means we should return a fallback upstream if available + if len(fallbackUpstreams) > 0 { + randIndex := rand.Intn(len(fallbackUpstreams)) + return fallbackUpstreams[randIndex].Upstream, fallbackUpstreams[randIndex].Index, nil + } + + return nil, -1, errors.New("failed to pick an upstream origin server") } // IntRange returns a random integer in the range from min to max. diff --git a/src/mod/utils/utils.go b/src/mod/utils/utils.go index a61d5ed..2fe1ffd 100644 --- a/src/mod/utils/utils.go +++ b/src/mod/utils/utils.go @@ -41,12 +41,12 @@ func SendOK(w http.ResponseWriter) { // Get GET parameter func GetPara(r *http.Request, key string) (string, error) { - keys, ok := r.URL.Query()[key] - if !ok || len(keys[0]) < 1 { + // Get first value from the URL query + value := r.URL.Query().Get(key) + if len(value) == 0 { return "", errors.New("invalid " + key + " given") - } else { - return keys[0], nil } + return value, nil } // Get GET paramter as boolean, accept 1 or true @@ -56,26 +56,29 @@ func GetBool(r *http.Request, key string) (bool, error) { return false, err } - x = strings.TrimSpace(x) - - if x == "1" || strings.ToLower(x) == "true" || strings.ToLower(x) == "on" { + // Convert to lowercase and trim spaces just once to compare + switch strings.ToLower(strings.TrimSpace(x)) { + case "1", "true", "on": return true, nil - } else if x == "0" || strings.ToLower(x) == "false" || strings.ToLower(x) == "off" { + case "0", "false", "off": return false, nil } return false, errors.New("invalid boolean given") } -// Get POST paramter +// Get POST parameter func PostPara(r *http.Request, key string) (string, error) { - r.ParseForm() - x := r.Form.Get(key) - if x == "" { - return "", errors.New("invalid " + key + " given") - } else { - return x, nil + // Try to parse the form + if err := r.ParseForm(); err != nil { + return "", err } + // Get first value from the form + x := r.Form.Get(key) + if len(x) == 0 { + return "", errors.New("invalid " + key + " given") + } + return x, nil } // Get POST paramter as boolean, accept 1 or true @@ -85,11 +88,11 @@ func PostBool(r *http.Request, key string) (bool, error) { return false, err } - x = strings.TrimSpace(x) - - if x == "1" || strings.ToLower(x) == "true" || strings.ToLower(x) == "on" { + // Convert to lowercase and trim spaces just once to compare + switch strings.ToLower(strings.TrimSpace(x)) { + case "1", "true", "on": return true, nil - } else if x == "0" || strings.ToLower(x) == "false" || strings.ToLower(x) == "off" { + case "0", "false", "off": return false, nil } @@ -114,14 +117,19 @@ func PostInt(r *http.Request, key string) (int, error) { func FileExists(filename string) bool { _, err := os.Stat(filename) - if os.IsNotExist(err) { + if err == nil { + // File exists + return true + } else if errors.Is(err, os.ErrNotExist) { + // File does not exist return false } - return true + // Some other error + return false } func IsDir(path string) bool { - if FileExists(path) == false { + if !FileExists(path) { return false } fi, err := os.Stat(path) @@ -191,4 +199,4 @@ func ValidateListeningAddress(address string) bool { } return true -} +} \ No newline at end of file