Optimized rate limiter implementation

- Moved rate limiter scope into proxy router
- Give IpTable a better name following clean code guideline
- Optimized client IP retrieval method
- Added stop channel for request counter ticker
- Fixed #199
- Optimized UI for rate limit
This commit is contained in:
Toby Chui
2024-06-14 23:42:52 +08:00
parent 85f9b297c4
commit 10048150bb
11 changed files with 160 additions and 81 deletions

View File

@@ -72,7 +72,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Rate Limit Check
// Rate Limit
if sep.RequireRateLimit {
err := h.handleRateLimitRouting(w, r, sep)
if err != nil {

View File

@@ -91,7 +91,6 @@ func addXForwardedForHeader(req *http.Request) {
req.Header.Set("X-Real-Ip", strings.TrimSpace(ips[0]))
}
}
}
}

View File

@@ -23,12 +23,12 @@ import (
func NewDynamicProxy(option RouterOption) (*Router, error) {
proxyMap := sync.Map{}
thisRouter := Router{
Option: &option,
ProxyEndpoints: &proxyMap,
Running: false,
server: nil,
routingRules: []*RoutingRule{},
tldMap: map[string]int{},
Option: &option,
ProxyEndpoints: &proxyMap,
Running: false,
server: nil,
routingRules: []*RoutingRule{},
rateLimitCounter: RequestCountPerIpTable{},
}
thisRouter.mux = &ProxyHandler{
@@ -85,6 +85,12 @@ func (router *Router) StartProxyService() error {
MinVersion: uint16(minVersion),
}
//Start rate limitor
err := router.startRateLimterCounterResetTicker()
if err != nil {
return err
}
if router.Option.UseTls {
router.server = &http.Server{
Addr: ":" + strconv.Itoa(router.Option.Port),
@@ -129,12 +135,12 @@ func (router *Router) StartProxyService() error {
}
}
// Rate Limit Check
// if sep.RequireBasicAuth {
if err := handleRateLimit(w, r, sep); err != nil {
return
// Rate Limit
if sep.RequireRateLimit {
if err := router.handleRateLimit(w, r, sep); err != nil {
return
}
}
// }
//Validate basic auth
if sep.RequireBasicAuth {
@@ -239,10 +245,23 @@ func (router *Router) StopProxyService() error {
return err
}
//Stop TLS listener
if router.tlsListener != nil {
router.tlsListener.Close()
}
//Stop rate limiter
if router.rateLimterStop != nil {
go func() {
// As the rate timer loop has a 1 sec ticker
// stop the rate limiter in go routine can prevent
// front end from freezing for 1 sec
router.rateLimterStop <- true
}()
}
//Stop TLS redirection (from port 80)
if router.tlsRedirectStop != nil {
router.tlsRedirectStop <- true
}

View File

@@ -2,27 +2,27 @@ package dynamicproxy
import (
"errors"
"log"
"net"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
)
// IpTable is a rate limiter implementation using sync.Map with atomic int64
type IpTable struct {
type RequestCountPerIpTable struct {
table sync.Map
}
// Increment the count of requests for a given IP
func (t *IpTable) Increment(ip string) {
func (t *RequestCountPerIpTable) Increment(ip string) {
v, _ := t.table.LoadOrStore(ip, new(int64))
atomic.AddInt64(v.(*int64), 1)
}
// Check if the IP is in the table and if it is, check if the count is less than the limit
func (t *IpTable) Exceeded(ip string, limit int64) bool {
func (t *RequestCountPerIpTable) Exceeded(ip string, limit int64) bool {
v, ok := t.table.Load(ip)
if !ok {
return false
@@ -32,7 +32,7 @@ func (t *IpTable) Exceeded(ip string, limit int64) bool {
}
// Get the count of requests for a given IP
func (t *IpTable) GetCount(ip string) int64 {
func (t *RequestCountPerIpTable) GetCount(ip string) int64 {
v, ok := t.table.Load(ip)
if !ok {
return 0
@@ -41,34 +41,50 @@ func (t *IpTable) GetCount(ip string) int64 {
}
// Clear the IP table
func (t *IpTable) Clear() {
func (t *RequestCountPerIpTable) Clear() {
t.table.Range(func(key, value interface{}) bool {
t.table.Delete(key)
return true
})
}
var ipTable = IpTable{}
func (h *ProxyHandler) handleRateLimitRouting(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
err := handleRateLimit(w, r, pe)
err := h.Parent.handleRateLimit(w, r, pe)
if err != nil {
h.logRequest(r, false, 429, "ratelimit", pe.Domain)
}
return err
}
func handleRateLimit(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
w.WriteHeader(500)
log.Println("Error resolving remote address", r.RemoteAddr, err)
return errors.New("internal server error")
func (router *Router) handleRateLimit(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
//Get the real client-ip from request header
clientIP := r.RemoteAddr
if r.Header.Get("X-Real-Ip") == "" {
CF_Connecting_IP := r.Header.Get("CF-Connecting-IP")
Fastly_Client_IP := r.Header.Get("Fastly-Client-IP")
if CF_Connecting_IP != "" {
//Use CF Connecting IP
clientIP = CF_Connecting_IP
} else if Fastly_Client_IP != "" {
//Use Fastly Client IP
clientIP = Fastly_Client_IP
} else {
ips := strings.Split(clientIP, ",")
if len(ips) > 0 {
clientIP = strings.TrimSpace(ips[0])
}
}
}
ipTable.Increment(ip)
ip, _, err := net.SplitHostPort(clientIP)
if err != nil {
//Default allow passthrough on error
return nil
}
if ipTable.Exceeded(ip, int64(pe.RateLimit)) {
router.rateLimitCounter.Increment(ip)
if router.rateLimitCounter.Exceeded(ip, int64(pe.RateLimit)) {
w.WriteHeader(429)
return errors.New("rate limit exceeded")
}
@@ -78,9 +94,26 @@ func handleRateLimit(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint)
return nil
}
func InitRateLimit() {
for {
ipTable.Clear()
time.Sleep(time.Second)
// Start the ticker routine for reseting the rate limit counter every seconds
func (r *Router) startRateLimterCounterResetTicker() error {
if r.rateLimterStop != nil {
return errors.New("another rate limiter ticker already running")
}
tickerStopChan := make(chan bool)
r.rateLimterStop = tickerStopChan
counterResetTicker := time.NewTicker(1 * time.Second)
go func() {
for {
select {
case <-tickerStopChan:
r.rateLimterStop = nil
return
case <-counterResetTicker.C:
r.rateLimitCounter.Clear()
}
}
}()
return nil
}

View File

@@ -51,8 +51,9 @@ type Router struct {
tlsListener net.Listener
routingRules []*RoutingRule
tlsRedirectStop chan bool //Stop channel for tls redirection server
tldMap map[string]int //Top level domain map, see tld.json
tlsRedirectStop chan bool //Stop channel for tls redirection server
rateLimterStop chan bool //Stop channel for rate limiter
rateLimitCounter RequestCountPerIpTable //Request counter for rate limter
}
// Auth credential for basic auth on certain endpoints