Merge pull request #258 from bouroo/perf/upstreams-sortfunc

weighted random upstream
This commit is contained in:
Toby Chui 2024-08-19 15:39:22 +08:00 committed by GitHub
commit b558bcbfcf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 84 additions and 53 deletions

View File

@ -102,39 +102,62 @@ func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream)
/* Functions related to random upstream picking */ /* Functions related to random upstream picking */
// Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error // Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error
func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) {
var ret *Upstream // If there is only one upstream, return it
sum := 0 if len(upstreams) == 1 {
for _, c := range upstreams { return upstreams[0], 0, nil
sum += c.Weight
}
r, err := intRange(0, sum)
if err != nil {
return ret, -1, err
}
counter := 0
for _, c := range upstreams {
r -= c.Weight
if r < 0 {
return c, counter, nil
}
counter++
} }
if ret == nil { // Preserve the index with upstreams
//All fallback type upstreamWithIndex struct {
//use the first one that is with weight = 0 Upstream *Upstream
fallbackUpstreams := []*Upstream{} Index int
fallbackUpstreamsOriginalID := []int{}
for ix, upstream := range upstreams {
if upstream.Weight == 0 {
fallbackUpstreams = append(fallbackUpstreams, upstream)
fallbackUpstreamsOriginalID = append(fallbackUpstreamsOriginalID, ix)
}
}
upstreamID := rand.Intn(len(fallbackUpstreams))
return fallbackUpstreams[upstreamID], fallbackUpstreamsOriginalID[upstreamID], nil
} }
return ret, -1, errors.New("failed to pick an upstream origin server")
// Calculate total weight for upstreams with weight > 0
totalWeight := 0
fallbackUpstreams := make([]upstreamWithIndex, 0, len(upstreams))
for index, upstream := range upstreams {
if upstream.Weight > 0 {
totalWeight += upstream.Weight
} else {
// Collect fallback upstreams
fallbackUpstreams = append(fallbackUpstreams, upstreamWithIndex{upstream, index})
}
}
// If there are no upstreams with weight > 0, return a fallback upstream if available
if totalWeight == 0 {
if len(fallbackUpstreams) > 0 {
// Randomly select one of the fallback upstreams
randIndex := rand.Intn(len(fallbackUpstreams))
return fallbackUpstreams[randIndex].Upstream, fallbackUpstreams[randIndex].Index, nil
}
// No upstreams available at all
return nil, -1, errors.New("no valid upstream servers available")
}
// Random weight between 0 and total weight
randomWeight := rand.Intn(totalWeight)
// Select an upstream based on the random weight
for index, upstream := range upstreams {
if upstream.Weight > 0 { // Only consider upstreams with weight > 0
if randomWeight < upstream.Weight {
// Return the selected upstream and its index
return upstream, index, nil
}
randomWeight -= upstream.Weight
}
}
// If we reach here, it means we should return a fallback upstream if available
if len(fallbackUpstreams) > 0 {
randIndex := rand.Intn(len(fallbackUpstreams))
return fallbackUpstreams[randIndex].Upstream, fallbackUpstreams[randIndex].Index, nil
}
return nil, -1, errors.New("failed to pick an upstream origin server")
} }
// IntRange returns a random integer in the range from min to max. // IntRange returns a random integer in the range from min to max.

View File

@ -41,12 +41,12 @@ func SendOK(w http.ResponseWriter) {
// Get GET parameter // Get GET parameter
func GetPara(r *http.Request, key string) (string, error) { func GetPara(r *http.Request, key string) (string, error) {
keys, ok := r.URL.Query()[key] // Get first value from the URL query
if !ok || len(keys[0]) < 1 { value := r.URL.Query().Get(key)
if len(value) == 0 {
return "", errors.New("invalid " + key + " given") return "", errors.New("invalid " + key + " given")
} else {
return keys[0], nil
} }
return value, nil
} }
// Get GET paramter as boolean, accept 1 or true // Get GET paramter as boolean, accept 1 or true
@ -56,26 +56,29 @@ func GetBool(r *http.Request, key string) (bool, error) {
return false, err return false, err
} }
x = strings.TrimSpace(x) // Convert to lowercase and trim spaces just once to compare
switch strings.ToLower(strings.TrimSpace(x)) {
if x == "1" || strings.ToLower(x) == "true" || strings.ToLower(x) == "on" { case "1", "true", "on":
return true, nil return true, nil
} else if x == "0" || strings.ToLower(x) == "false" || strings.ToLower(x) == "off" { case "0", "false", "off":
return false, nil return false, nil
} }
return false, errors.New("invalid boolean given") return false, errors.New("invalid boolean given")
} }
// Get POST paramter // Get POST parameter
func PostPara(r *http.Request, key string) (string, error) { func PostPara(r *http.Request, key string) (string, error) {
r.ParseForm() // Try to parse the form
x := r.Form.Get(key) if err := r.ParseForm(); err != nil {
if x == "" { return "", err
return "", errors.New("invalid " + key + " given")
} else {
return x, nil
} }
// Get first value from the form
x := r.Form.Get(key)
if len(x) == 0 {
return "", errors.New("invalid " + key + " given")
}
return x, nil
} }
// Get POST paramter as boolean, accept 1 or true // Get POST paramter as boolean, accept 1 or true
@ -85,11 +88,11 @@ func PostBool(r *http.Request, key string) (bool, error) {
return false, err return false, err
} }
x = strings.TrimSpace(x) // Convert to lowercase and trim spaces just once to compare
switch strings.ToLower(strings.TrimSpace(x)) {
if x == "1" || strings.ToLower(x) == "true" || strings.ToLower(x) == "on" { case "1", "true", "on":
return true, nil return true, nil
} else if x == "0" || strings.ToLower(x) == "false" || strings.ToLower(x) == "off" { case "0", "false", "off":
return false, nil return false, nil
} }
@ -114,14 +117,19 @@ func PostInt(r *http.Request, key string) (int, error) {
func FileExists(filename string) bool { func FileExists(filename string) bool {
_, err := os.Stat(filename) _, err := os.Stat(filename)
if os.IsNotExist(err) { if err == nil {
// File exists
return true
} else if errors.Is(err, os.ErrNotExist) {
// File does not exist
return false return false
} }
return true // Some other error
return false
} }
func IsDir(path string) bool { func IsDir(path string) bool {
if FileExists(path) == false { if !FileExists(path) {
return false return false
} }
fi, err := os.Stat(path) fi, err := os.Stat(path)
@ -191,4 +199,4 @@ func ValidateListeningAddress(address string) bool {
} }
return true return true
} }