Updates 2.6.4

+ Added force TLS v1.2 above toggle
+ Added trace route
+ Added ICMP ping
+ Added special routing rules module for up-coming acme integration
+ Fixed IPv6 check bug in black/whitelist
+ Optimized UI for TCP Proxy
+
This commit is contained in:
Toby Chui
2023-06-16 00:48:39 +08:00
parent a73a7944ec
commit 48dc85ea3e
27 changed files with 1425 additions and 174 deletions

View File

@@ -115,7 +115,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.proxyRequest(w, r, targetProxyEndpoint)
} else if !strings.HasSuffix(proxyingPath, "/") {
potentialProxtEndpoint := h.Parent.getTargetProxyEndpointFromRequestURI(proxyingPath + "/")
if potentialProxtEndpoint != nil {
//Missing tailing slash. Redirect to target proxy endpoint
http.Redirect(w, r, r.RequestURI+"/", http.StatusTemporaryRedirect)

View File

@@ -45,6 +45,13 @@ func (router *Router) UpdateTLSSetting(tlsEnabled bool) {
router.Restart()
}
// Update TLS Version in runtime. Will restart proxy server if running.
// Set this to true to force TLS 1.2 or above
func (router *Router) UpdateTLSVersion(requireLatest bool) {
router.Option.ForceTLSLatest = requireLatest
router.Restart()
}
// Update https redirect, which will require updates
func (router *Router) UpdateHttpToHttpsRedirectSetting(useRedirect bool) {
router.Option.ForceHttpsRedirect = useRedirect
@@ -62,8 +69,13 @@ func (router *Router) StartProxyService() error {
return errors.New("Reverse proxy router root not set")
}
minVersion := tls.VersionTLS10
if router.Option.ForceTLSLatest {
minVersion = tls.VersionTLS12
}
config := &tls.Config{
GetCertificate: router.Option.TlsManager.GetCert,
MinVersion: uint16(minVersion),
}
if router.Option.UseTls {
@@ -171,18 +183,22 @@ func (router *Router) StopProxyService() error {
}
// Restart the current router if it is running.
// Startup the server if it is not running initially
func (router *Router) Restart() error {
//Stop the router if it is already running
var err error = nil
if router.Running {
err := router.StopProxyService()
if err != nil {
return err
}
// Start the server
err = router.StartProxyService()
if err != nil {
return err
}
}
//Start the server
err := router.StartProxyService()
return err
}

View File

@@ -15,12 +15,12 @@ import (
type RoutingRule struct {
ID string
MatchRule func(r *http.Request) bool
RoutingHandler http.Handler
RoutingHandler func(http.ResponseWriter, *http.Request)
Enabled bool
}
//Router functions
//Check if a routing rule exists given its id
// Router functions
// Check if a routing rule exists given its id
func (router *Router) GetRoutingRuleById(rrid string) (*RoutingRule, error) {
for _, rr := range router.routingRules {
if rr.ID == rrid {
@@ -31,19 +31,19 @@ func (router *Router) GetRoutingRuleById(rrid string) (*RoutingRule, error) {
return nil, errors.New("routing rule with given id not found")
}
//Add a routing rule to the router
// Add a routing rule to the router
func (router *Router) AddRoutingRules(rr *RoutingRule) error {
_, err := router.GetRoutingRuleById(rr.ID)
if err != nil {
if err == nil {
//routing rule with given id already exists
return err
return errors.New("routing rule with same id already exists")
}
router.routingRules = append(router.routingRules, rr)
return nil
}
//Remove a routing rule from the router
// Remove a routing rule from the router
func (router *Router) RemoveRoutingRule(rrid string) {
newRoutingRules := []*RoutingRule{}
for _, rr := range router.routingRules {
@@ -55,13 +55,13 @@ func (router *Router) RemoveRoutingRule(rrid string) {
router.routingRules = newRoutingRules
}
//Get all routing rules
// Get all routing rules
func (router *Router) GetAllRoutingRules() []*RoutingRule {
return router.routingRules
}
//Get the matching routing rule that describe this request.
//Return nil if no routing rule is match
// Get the matching routing rule that describe this request.
// Return nil if no routing rule is match
func (router *Router) GetMatchingRoutingRule(r *http.Request) *RoutingRule {
for _, thisRr := range router.routingRules {
if thisRr.IsMatch(r) {
@@ -71,8 +71,8 @@ func (router *Router) GetMatchingRoutingRule(r *http.Request) *RoutingRule {
return nil
}
//Routing Rule functions
//Check if a request object match the
// Routing Rule functions
// Check if a request object match the
func (e *RoutingRule) IsMatch(r *http.Request) bool {
if !e.Enabled {
return false
@@ -81,5 +81,5 @@ func (e *RoutingRule) IsMatch(r *http.Request) bool {
}
func (e *RoutingRule) Route(w http.ResponseWriter, r *http.Request) {
e.RoutingHandler.ServeHTTP(w, r)
e.RoutingHandler(w, r)
}

View File

@@ -22,13 +22,14 @@ type ProxyHandler struct {
}
type RouterOption struct {
HostUUID string
Port int
UseTls bool
ForceHttpsRedirect bool
HostUUID string //The UUID of Zoraxy, use for heading mod
Port int //Incoming port
UseTls bool //Use TLS to serve incoming requsts
ForceTLSLatest bool //Force TLS1.2 or above
ForceHttpsRedirect bool //Force redirection of http to https endpoint
TlsManager *tlscert.Manager
RedirectRuleTable *redirection.RuleTable
GeodbStore *geodb.Store
GeodbStore *geodb.Store //GeoIP blacklist and whitelist
StatisticCollector *statistic.Collector
}

16
src/mod/expose/expose.go Normal file
View File

@@ -0,0 +1,16 @@
package expose
/*
Service Expose Proxy
A tunnel for getting your local server online in one line
(No, this is not ngrok)
*/
type Router struct {
}
//Create a new service expose router
func NewServiceExposeRouter() {
}

111
src/mod/expose/security.go Normal file
View File

@@ -0,0 +1,111 @@
package expose
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha512"
"crypto/x509"
"encoding/pem"
"errors"
"log"
)
// GenerateKeyPair generates a new key pair
func GenerateKeyPair(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
privkey, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, nil, err
}
return privkey, &privkey.PublicKey, nil
}
// PrivateKeyToBytes private key to bytes
func PrivateKeyToBytes(priv *rsa.PrivateKey) []byte {
privBytes := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
},
)
return privBytes
}
// PublicKeyToBytes public key to bytes
func PublicKeyToBytes(pub *rsa.PublicKey) ([]byte, error) {
pubASN1, err := x509.MarshalPKIXPublicKey(pub)
if err != nil {
return []byte(""), err
}
pubBytes := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: pubASN1,
})
return pubBytes, nil
}
// BytesToPrivateKey bytes to private key
func BytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(priv)
enc := x509.IsEncryptedPEMBlock(block)
b := block.Bytes
var err error
if enc {
log.Println("is encrypted pem block")
b, err = x509.DecryptPEMBlock(block, nil)
if err != nil {
return nil, err
}
}
key, err := x509.ParsePKCS1PrivateKey(b)
if err != nil {
return nil, err
}
return key, nil
}
// BytesToPublicKey bytes to public key
func BytesToPublicKey(pub []byte) (*rsa.PublicKey, error) {
block, _ := pem.Decode(pub)
enc := x509.IsEncryptedPEMBlock(block)
b := block.Bytes
var err error
if enc {
log.Println("is encrypted pem block")
b, err = x509.DecryptPEMBlock(block, nil)
if err != nil {
return nil, err
}
}
ifc, err := x509.ParsePKIXPublicKey(b)
if err != nil {
return nil, err
}
key, ok := ifc.(*rsa.PublicKey)
if !ok {
return nil, errors.New("key not valid")
}
return key, nil
}
// EncryptWithPublicKey encrypts data with public key
func EncryptWithPublicKey(msg []byte, pub *rsa.PublicKey) ([]byte, error) {
hash := sha512.New()
ciphertext, err := rsa.EncryptOAEP(hash, rand.Reader, pub, msg, nil)
if err != nil {
return []byte(""), err
}
return ciphertext, nil
}
// DecryptWithPrivateKey decrypts data with private key
func DecryptWithPrivateKey(ciphertext []byte, priv *rsa.PrivateKey) ([]byte, error) {
hash := sha512.New()
plaintext, err := rsa.DecryptOAEP(hash, rand.Reader, priv, ciphertext, nil)
if err != nil {
return []byte(""), err
}
return plaintext, nil
}

View File

@@ -0,0 +1,69 @@
package netutils
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"imuslab.com/zoraxy/mod/utils"
)
/*
This script handles basic network utilities like
- traceroute
- ping
*/
func HandleTraceRoute(w http.ResponseWriter, r *http.Request) {
targetIpOrDomain, err := utils.GetPara(r, "target")
if err != nil {
utils.SendErrorResponse(w, "invalid target (domain or ip) address given")
return
}
maxhopsString, err := utils.GetPara(r, "maxhops")
if err != nil {
maxhopsString = "64"
}
maxHops, err := strconv.Atoi(maxhopsString)
if err != nil {
maxHops = 64
}
results, err := TraceRoute(targetIpOrDomain, maxHops)
if err != nil {
utils.SendErrorResponse(w, err.Error())
return
}
js, _ := json.Marshal(results)
utils.SendJSONResponse(w, string(js))
}
func TraceRoute(targetIpOrDomain string, maxHops int) ([]string, error) {
return traceroute(targetIpOrDomain, maxHops)
}
func HandlePing(w http.ResponseWriter, r *http.Request) {
targetIpOrDomain, err := utils.GetPara(r, "target")
if err != nil {
utils.SendErrorResponse(w, "invalid target (domain or ip) address given")
return
}
results := []string{}
for i := 0; i < 4; i++ {
realIP, pingTime, ttl, err := PingIP(targetIpOrDomain)
if err != nil {
results = append(results, "Reply from "+realIP+": "+err.Error())
} else {
results = append(results, fmt.Sprintf("Reply from %s: Time=%dms TTL=%d", realIP, pingTime.Milliseconds(), ttl))
}
}
js, _ := json.Marshal(results)
utils.SendJSONResponse(w, string(js))
}

View File

@@ -0,0 +1,28 @@
package netutils_test
import (
"testing"
"imuslab.com/zoraxy/mod/netutils"
)
func TestHandleTraceRoute(t *testing.T) {
results, err := netutils.TraceRoute("imuslab.com", 64)
if err != nil {
t.Fatal(err)
}
t.Log(results)
}
func TestHandlePing(t *testing.T) {
ipOrDomain := "example.com"
realIP, pingTime, ttl, err := netutils.PingIP(ipOrDomain)
if err != nil {
t.Fatal("Error:", err)
return
}
t.Log(realIP, pingTime, ttl)
}

View File

@@ -0,0 +1,48 @@
package netutils
import (
"fmt"
"net"
"time"
)
func PingIP(ipOrDomain string) (string, time.Duration, int, error) {
ipAddr, err := net.ResolveIPAddr("ip", ipOrDomain)
if err != nil {
return "", 0, 0, fmt.Errorf("failed to resolve IP address: %v", err)
}
ip := ipAddr.IP.String()
start := time.Now()
conn, err := net.Dial("ip:icmp", ip)
if err != nil {
return ip, 0, 0, fmt.Errorf("failed to establish ICMP connection: %v", err)
}
defer conn.Close()
icmpMsg := []byte{8, 0, 0, 0, 0, 1, 0, 0}
_, err = conn.Write(icmpMsg)
if err != nil {
return ip, 0, 0, fmt.Errorf("failed to send ICMP message: %v", err)
}
reply := make([]byte, 1500)
err = conn.SetReadDeadline(time.Now().Add(3 * time.Second))
if err != nil {
return ip, 0, 0, fmt.Errorf("failed to set read deadline: %v", err)
}
_, err = conn.Read(reply)
if err != nil {
return ip, 0, 0, fmt.Errorf("failed to read ICMP reply: %v", err)
}
elapsed := time.Since(start)
pingTime := elapsed.Round(time.Millisecond)
ttl := int(reply[8])
return ip, pingTime, ttl, nil
}

View File

@@ -0,0 +1,212 @@
package netutils
import (
"fmt"
"net"
"os"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
const (
protocolICMP = 1
)
// liveTraceRoute return realtime tracing information to live response handler
func liveTraceRoute(dst string, maxHops int, liveRespHandler func(string)) error {
timeout := time.Second * 3
// resolve the host name to an IP address
ipAddr, err := net.ResolveIPAddr("ip4", dst)
if err != nil {
return fmt.Errorf("failed to resolve IP address for %s: %v", dst, err)
}
// create a socket to listen for incoming ICMP packets
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
return fmt.Errorf("failed to create ICMP listener: %v", err)
}
defer conn.Close()
id := os.Getpid() & 0xffff
seq := 0
loop_ttl:
for ttl := 1; ttl <= maxHops; ttl++ {
// set the TTL on the socket
if err := conn.IPv4PacketConn().SetTTL(ttl); err != nil {
return fmt.Errorf("failed to set TTL: %v", err)
}
seq++
// create an ICMP message
msg := icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: id,
Seq: seq,
Data: []byte("zoraxy_trace"),
},
}
// serialize the ICMP message
msgBytes, err := msg.Marshal(nil)
if err != nil {
return fmt.Errorf("failed to serialize ICMP message: %v", err)
}
// send the ICMP message
start := time.Now()
if _, err := conn.WriteTo(msgBytes, ipAddr); err != nil {
//log.Printf("%d: %v", ttl, err)
liveRespHandler(fmt.Sprintf("%d: %v", ttl, err))
continue loop_ttl
}
// listen for the reply
replyBytes := make([]byte, 1500)
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return fmt.Errorf("failed to set read deadline: %v", err)
}
for i := 0; i < 3; i++ {
n, peer, err := conn.ReadFrom(replyBytes)
if err != nil {
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
//fmt.Printf("%d: *\n", ttl)
liveRespHandler(fmt.Sprintf("%d: *\n", ttl))
continue loop_ttl
} else {
liveRespHandler(fmt.Sprintf("%d: Failed to parse ICMP message: %v", ttl, err))
}
continue
}
// parse the ICMP message
replyMsg, err := icmp.ParseMessage(protocolICMP, replyBytes[:n])
if err != nil {
liveRespHandler(fmt.Sprintf("%d: Failed to parse ICMP message: %v", ttl, err))
continue
}
// check if the reply is an echo reply
if replyMsg.Type == ipv4.ICMPTypeEchoReply {
echoReply, ok := msg.Body.(*icmp.Echo)
if !ok || echoReply.ID != id || echoReply.Seq != seq {
continue
}
liveRespHandler(fmt.Sprintf("%d: %v %v\n", ttl, peer, time.Since(start)))
break loop_ttl
}
if replyMsg.Type == ipv4.ICMPTypeTimeExceeded {
echoReply, ok := msg.Body.(*icmp.Echo)
if !ok || echoReply.ID != id || echoReply.Seq != seq {
continue
}
var raddr = peer.String()
names, _ := net.LookupAddr(raddr)
if len(names) > 0 {
raddr = names[0] + " (" + raddr + ")"
} else {
raddr = raddr + " (" + raddr + ")"
}
liveRespHandler(fmt.Sprintf("%d: %v %v\n", ttl, raddr, time.Since(start)))
continue loop_ttl
}
}
}
return nil
}
// Standard traceroute, return results after complete
func traceroute(dst string, maxHops int) ([]string, error) {
results := []string{}
timeout := time.Second * 3
// resolve the host name to an IP address
ipAddr, err := net.ResolveIPAddr("ip4", dst)
if err != nil {
return results, fmt.Errorf("failed to resolve IP address for %s: %v", dst, err)
}
// create a socket to listen for incoming ICMP packets
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
return results, fmt.Errorf("failed to create ICMP listener: %v", err)
}
defer conn.Close()
id := os.Getpid() & 0xffff
seq := 0
loop_ttl:
for ttl := 1; ttl <= maxHops; ttl++ {
// set the TTL on the socket
if err := conn.IPv4PacketConn().SetTTL(ttl); err != nil {
return results, fmt.Errorf("failed to set TTL: %v", err)
}
seq++
// create an ICMP message
msg := icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: id,
Seq: seq,
Data: []byte("zoraxy_trace"),
},
}
// serialize the ICMP message
msgBytes, err := msg.Marshal(nil)
if err != nil {
return results, fmt.Errorf("failed to serialize ICMP message: %v", err)
}
// send the ICMP message
start := time.Now()
if _, err := conn.WriteTo(msgBytes, ipAddr); err != nil {
//log.Printf("%d: %v", ttl, err)
results = append(results, fmt.Sprintf("%d: %v", ttl, err))
continue loop_ttl
}
// listen for the reply
replyBytes := make([]byte, 1500)
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return results, fmt.Errorf("failed to set read deadline: %v", err)
}
for i := 0; i < 3; i++ {
n, peer, err := conn.ReadFrom(replyBytes)
if err != nil {
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
//fmt.Printf("%d: *\n", ttl)
results = append(results, fmt.Sprintf("%d: *", ttl))
continue loop_ttl
} else {
results = append(results, fmt.Sprintf("%d: Failed to parse ICMP message: %v", ttl, err))
}
continue
}
// parse the ICMP message
replyMsg, err := icmp.ParseMessage(protocolICMP, replyBytes[:n])
if err != nil {
results = append(results, fmt.Sprintf("%d: Failed to parse ICMP message: %v", ttl, err))
continue
}
// check if the reply is an echo reply
if replyMsg.Type == ipv4.ICMPTypeEchoReply {
echoReply, ok := msg.Body.(*icmp.Echo)
if !ok || echoReply.ID != id || echoReply.Seq != seq {
continue
}
results = append(results, fmt.Sprintf("%d: %v %v", ttl, peer, time.Since(start)))
break loop_ttl
}
if replyMsg.Type == ipv4.ICMPTypeTimeExceeded {
echoReply, ok := msg.Body.(*icmp.Echo)
if !ok || echoReply.ID != id || echoReply.Seq != seq {
continue
}
var raddr = peer.String()
names, _ := net.LookupAddr(raddr)
if len(names) > 0 {
raddr = names[0] + " (" + raddr + ")"
} else {
raddr = raddr + " (" + raddr + ")"
}
results = append(results, fmt.Sprintf("%d: %v %v", ttl, raddr, time.Since(start)))
continue loop_ttl
}
}
}
return results, nil
}

100
src/mod/pathrule/handler.go Normal file
View File

@@ -0,0 +1,100 @@
package pathrule
import (
"encoding/json"
"net/http"
"strconv"
uuid "github.com/satori/go.uuid"
"imuslab.com/zoraxy/mod/utils"
)
/*
handler.go
This script handles pathblock api
*/
func (h *Handler) HandleListBlockingPath(w http.ResponseWriter, r *http.Request) {
js, _ := json.Marshal(h.BlockingPaths)
utils.SendJSONResponse(w, string(js))
}
func (h *Handler) HandleAddBlockingPath(w http.ResponseWriter, r *http.Request) {
matchingPath, err := utils.PostPara(r, "matchingPath")
if err != nil {
utils.SendErrorResponse(w, "invalid matching path given")
return
}
exactMatch, err := utils.PostPara(r, "exactMatch")
if err != nil {
utils.SendErrorResponse(w, "invalid exact match value given")
return
}
statusCodeString, err := utils.PostPara(r, "statusCode")
if err != nil {
utils.SendErrorResponse(w, "invalid status code given")
return
}
statusCode, err := strconv.Atoi(statusCodeString)
if err != nil {
utils.SendErrorResponse(w, "invalid status code given")
return
}
enabled, err := utils.PostPara(r, "enabled")
if err != nil {
utils.SendErrorResponse(w, "invalid enabled value given")
return
}
caseSensitive, err := utils.PostPara(r, "caseSensitive")
if err != nil {
utils.SendErrorResponse(w, "invalid case sensitive value given")
return
}
targetBlockingPath := BlockingPath{
UUID: uuid.NewV4().String(),
MatchingPath: matchingPath,
ExactMatch: exactMatch == "true",
StatusCode: statusCode,
CustomHeaders: http.Header{},
CustomHTML: []byte(""),
Enabled: enabled == "true",
CaseSenitive: caseSensitive == "true",
}
err = h.AddBlockingPath(&targetBlockingPath)
if err != nil {
utils.SendErrorResponse(w, err.Error())
return
}
utils.SendOK(w)
}
func (h *Handler) HandleRemoveBlockingPath(w http.ResponseWriter, r *http.Request) {
blockerUUID, err := utils.PostPara(r, "uuid")
if err != nil {
utils.SendErrorResponse(w, "invalid uuid given")
return
}
targetRule := h.GetPathBlockerFromUUID(blockerUUID)
if targetRule == nil {
//Not found
utils.SendErrorResponse(w, "target path blocker not found")
return
}
err = h.RemoveBlockingPathByUUID(blockerUUID)
if err != nil {
utils.SendErrorResponse(w, err.Error())
return
}
utils.SendOK(w)
}

View File

@@ -0,0 +1,175 @@
package pathrule
import (
"encoding/json"
"errors"
"net/http"
"os"
"path/filepath"
"strings"
"imuslab.com/zoraxy/mod/utils"
)
/*
Pathblock.go
This script block off some of the specific pathname in access
For example, this module can help you block request for a particular
apache directory or functional endpoints like /.well-known/ when you
are not using it
*/
type Options struct {
ConfigFolder string //The folder to store the path blocking config files
}
type BlockingPath struct {
UUID string
MatchingPath string
ExactMatch bool
StatusCode int
CustomHeaders http.Header
CustomHTML []byte
Enabled bool
CaseSenitive bool
}
type Handler struct {
Options *Options
BlockingPaths []*BlockingPath
}
// Create a new path blocker handler
func NewPathBlocker(options *Options) *Handler {
//Create folder if not exists
if !utils.FileExists(options.ConfigFolder) {
os.Mkdir(options.ConfigFolder, 0775)
}
//Load the configs from file
//TODO
return &Handler{
Options: options,
BlockingPaths: []*BlockingPath{},
}
}
func (h *Handler) ListBlockingPath() []*BlockingPath {
return h.BlockingPaths
}
// Get the blocker from matching path (path match, ignore tailing slash)
func (h *Handler) GetPathBlockerFromMatchingPath(matchingPath string) *BlockingPath {
for _, blocker := range h.BlockingPaths {
if blocker.MatchingPath == matchingPath {
return blocker
} else if strings.TrimSuffix(blocker.MatchingPath, "/") == strings.TrimSuffix(matchingPath, "/") {
return blocker
}
}
return nil
}
func (h *Handler) GetPathBlockerFromUUID(UUID string) *BlockingPath {
for _, blocker := range h.BlockingPaths {
if blocker.UUID == UUID {
return blocker
}
}
return nil
}
func (h *Handler) AddBlockingPath(pathBlocker *BlockingPath) error {
//Check if the blocker exists
blockerPath := pathBlocker.MatchingPath
targetBlocker := h.GetPathBlockerFromMatchingPath(blockerPath)
if targetBlocker != nil {
//Blocker with the same matching path already exists
return errors.New("path blocker with the same path already exists")
}
h.BlockingPaths = append(h.BlockingPaths, pathBlocker)
//Write the new config to file
return h.SaveBlockerToFile(pathBlocker)
}
func (h *Handler) RemoveBlockingPathByUUID(uuid string) error {
newBlockingList := []*BlockingPath{}
for _, thisBlocker := range h.BlockingPaths {
if thisBlocker.UUID != uuid {
newBlockingList = append(newBlockingList, thisBlocker)
}
}
if len(h.BlockingPaths) == len(newBlockingList) {
//Nothing is removed
return errors.New("given matching path blocker not exists")
}
h.BlockingPaths = newBlockingList
return h.RemoveBlockerFromFile(uuid)
}
func (h *Handler) SaveBlockerToFile(pathBlocker *BlockingPath) error {
saveFilename := filepath.Join(h.Options.ConfigFolder, pathBlocker.UUID)
js, _ := json.MarshalIndent(pathBlocker, "", " ")
return os.WriteFile(saveFilename, js, 0775)
}
func (h *Handler) RemoveBlockerFromFile(uuid string) error {
expectedConfigFile := filepath.Join(h.Options.ConfigFolder, uuid)
if !utils.FileExists(expectedConfigFile) {
return errors.New("config file not found on disk")
}
return os.Remove(expectedConfigFile)
}
// Get all the matching blockers for the given URL path
// return all the path blockers and the max length matching rule
func (h *Handler) GetMatchingBlockers(urlPath string) ([]*BlockingPath, *BlockingPath) {
urlPath = strings.TrimSuffix(urlPath, "/")
matchingBlockers := []*BlockingPath{}
var longestMatchingPrefix *BlockingPath = nil
for _, thisBlocker := range h.BlockingPaths {
if thisBlocker.Enabled == false {
//This blocker is not enabled. Ignore this
continue
}
incomingURLPath := urlPath
matchingPath := strings.TrimSuffix(thisBlocker.MatchingPath, "/")
if !thisBlocker.CaseSenitive {
//This is not case sensitive
incomingURLPath = strings.ToLower(incomingURLPath)
matchingPath = strings.ToLower(matchingPath)
}
if matchingPath == incomingURLPath {
//This blocker have exact url path match
matchingBlockers = append(matchingBlockers, thisBlocker)
if longestMatchingPrefix == nil || len(thisBlocker.MatchingPath) > len(longestMatchingPrefix.MatchingPath) {
longestMatchingPrefix = thisBlocker
}
continue
}
if !thisBlocker.ExactMatch && strings.HasPrefix(incomingURLPath, matchingPath) {
//This blocker have prefix url match
matchingBlockers = append(matchingBlockers, thisBlocker)
if longestMatchingPrefix == nil || len(thisBlocker.MatchingPath) > len(longestMatchingPrefix.MatchingPath) {
longestMatchingPrefix = thisBlocker
}
continue
}
}
return matchingBlockers, longestMatchingPrefix
}