This commit is contained in:
Toby Chui
2025-10-16 21:50:04 +08:00
29 changed files with 652 additions and 259 deletions

View File

@@ -27,6 +27,7 @@ import (
"github.com/go-acme/lego/v4/registration"
"imuslab.com/zoraxy/mod/database"
"imuslab.com/zoraxy/mod/info/logger"
"imuslab.com/zoraxy/mod/netutils"
"imuslab.com/zoraxy/mod/utils"
)
@@ -432,18 +433,18 @@ func (a *ACMEHandler) HandleGetExpiredDomains(w http.ResponseWriter, r *http.Req
// to renew the certificate, and sends a JSON response indicating the result of the renewal process.
func (a *ACMEHandler) HandleRenewCertificate(w http.ResponseWriter, r *http.Request) {
domainPara, err := utils.PostPara(r, "domains")
//Clean each domain
cleanedDomains := []string{}
if (domainPara != "") {
if domainPara != "" {
for _, d := range strings.Split(domainPara, ",") {
// Apply normalization on each domain
nd, err := NormalizeDomain(d)
nd, err := netutils.NormalizeDomain(d)
if err != nil {
utils.SendErrorResponse(w, jsonEscape(err.Error()))
return
}
cleanedDomains = append(cleanedDomains, nd)
}
cleanedDomains = append(cleanedDomains, nd)
}
}
@@ -507,7 +508,6 @@ func (a *ACMEHandler) HandleRenewCertificate(w http.ResponseWriter, r *http.Requ
dns = true
}
// Default propagation timeout is 300 seconds
propagationTimeout := 300
if dns {
@@ -549,7 +549,6 @@ func (a *ACMEHandler) HandleRenewCertificate(w http.ResponseWriter, r *http.Requ
a.Logf("Could not extract SANs from PEM, using domainPara only", err)
}
// Extract DNS servers from the request
var dnsServers []string
dnsServersPara, err := utils.PostPara(r, "dnsServers")

View File

@@ -7,8 +7,6 @@ import (
"fmt"
"os"
"time"
"strings"
"unicode"
)
// Get the issuer name from pem file
@@ -42,8 +40,6 @@ func ExtractDomains(certBytes []byte) ([]string, error) {
return []string{}, errors.New("decode cert bytes failed")
}
func ExtractIssuerName(certBytes []byte) (string, error) {
// Parse the PEM block
block, _ := pem.Decode(certBytes)
@@ -73,9 +69,9 @@ func ExtractDomainsFromPEM(pemFilePath string) ([]string, error) {
certBytes, err := os.ReadFile(pemFilePath)
if err != nil {
return nil, err
return nil, err
}
domains,err := ExtractDomains(certBytes)
domains, err := ExtractDomains(certBytes)
if err != nil {
return nil, err
}
@@ -116,48 +112,3 @@ func CertExpireSoon(certBytes []byte, numberOfDays int) bool {
}
return false
}
// NormalizeDomain cleans and validates a domain string.
// - Trims spaces around the domain
// - Converts to lowercase
// - Removes trailing dot (FQDN canonicalization)
// - Checks that the domain conforms to standard rules:
// * Each label ≤ 63 characters
// * Only letters, digits, and hyphens
// * Labels do not start or end with a hyphen
// * Full domain ≤ 253 characters
// Returns an empty string if the domain is invalid.
func NormalizeDomain(d string) (string, error) {
d = strings.TrimSpace(d)
d = strings.ToLower(d)
d = strings.TrimSuffix(d, ".")
if len(d) == 0 {
return "", errors.New("domain is empty")
}
if len(d) > 253 {
return "", errors.New("domain exceeds 253 characters")
}
labels := strings.Split(d, ".")
for _, label := range labels {
if len(label) == 0 {
return "", errors.New("Domain '" + d + "' not valid: Empty label")
}
if len(label) > 63 {
return "", errors.New("Domain not valid: label exceeds 63 characters")
}
for i, r := range label {
if !(unicode.IsLetter(r) || unicode.IsDigit(r) || r == '-') {
return "", errors.New("Domain '" + d + "' not valid: Invalid character '" + string(r) + "' in label")
}
if (i == 0 || i == len(label)-1) && r == '-' {
return "", errors.New("Domain '" + d + "' not valid: label '"+ label +"' starts or ends with hyphen")
}
}
}
return d, nil
}

View File

@@ -92,7 +92,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
//Plugin routing
if h.Parent.Option.PluginManager != nil && h.Parent.Option.PluginManager.HandleRoute(w, r, sep.Tags) {
//Request handled by subroute
return

View File

@@ -438,7 +438,15 @@ func (p *ReverseProxy) ProxyHTTPS(rw http.ResponseWriter, req *http.Request) (in
if !strings.Contains(host, ":") {
host += ":443"
}
serverName := req.URL.Hostname()
serverName := ""
//if p.Transport != nil {
// if tr, ok := p.Transport.(*http.Transport); ok && tr.TLSClientConfig != nil && tr.TLSClientConfig.ServerName != "" {
// serverName = tr.TLSClientConfig.ServerName
// }
//}
if serverName == "" {
serverName = req.URL.Hostname()
}
// Connect with SNI offload
tlsConfig := &tls.Config{

View File

@@ -48,8 +48,8 @@ func (router *Router) UpdateTLSSetting(tlsEnabled bool) {
// Update TLS Version in runtime. Will restart proxy server if running.
// Set this to true to force TLS 1.2 or above
func (router *Router) UpdateTLSVersion(requireLatest bool) {
router.Option.ForceTLSLatest = requireLatest
func (router *Router) SetTlsMinVersion(minTlsVersion uint16) {
router.Option.MinTLSVersion = minTlsVersion
router.Restart()
}
@@ -77,9 +77,9 @@ func (router *Router) StartProxyService() error {
return errors.New("reverse proxy router root not set")
}
minVersion := tls.VersionTLS10
if router.Option.ForceTLSLatest {
minVersion = tls.VersionTLS12
minVersion := tls.VersionTLS12 //Default to TLS 1.2
if router.Option.MinTLSVersion != 0 {
minVersion = int(router.Option.MinTLSVersion)
}
config := &tls.Config{

View File

@@ -272,6 +272,11 @@ func (ep *ProxyEndpoint) Remove() error {
return nil
}
// Check if the proxy endpoint is enabled
func (ep *ProxyEndpoint) IsEnabled() bool {
return !ep.Disabled
}
// Write changes to runtime without respawning the proxy handler
// use prepare -> remove -> add if you change anything in the endpoint
// that effects the proxy routing src / dest

View File

@@ -12,6 +12,7 @@ import (
"strings"
"imuslab.com/zoraxy/mod/dynamicproxy/dpcore"
"imuslab.com/zoraxy/mod/dynamicproxy/loadbalance"
"imuslab.com/zoraxy/mod/dynamicproxy/rewrite"
"imuslab.com/zoraxy/mod/netutils"
"imuslab.com/zoraxy/mod/statistic"
@@ -95,27 +96,47 @@ func (router *Router) GetProxyEndpointFromHostname(hostname string) *ProxyEndpoi
return targetSubdomainEndpoint
}
// Clearn URL Path (without the http:// part) replaces // in a URL to /
func (router *Router) clearnURL(targetUrlOPath string) string {
return strings.ReplaceAll(targetUrlOPath, "//", "/")
}
// Rewrite URL rewrite the prefix part of a virtual directory URL with /
func (router *Router) rewriteURL(rooturl string, requestURL string) string {
rewrittenURL := requestURL
rewrittenURL = strings.TrimPrefix(rewrittenURL, strings.TrimSuffix(rooturl, "/"))
if strings.Contains(rewrittenURL, "//") {
rewrittenURL = router.clearnURL(rewrittenURL)
rewrittenURL = strings.ReplaceAll(rewrittenURL, "//", "/")
}
return rewrittenURL
}
// upstreamHostSwap check if this loopback to one of the proxy rule in the system. If yes, do a shortcut target swap
// this prevents unnecessary external DNS lookup and connection, return true if swapped and request is already handled
// by the loopback handler. Only continue if return is false
func (h *ProxyHandler) upstreamHostSwap(w http.ResponseWriter, r *http.Request, selectedUpstream *loadbalance.Upstream) bool {
upstreamHostname := selectedUpstream.OriginIpOrDomain
if strings.Contains(upstreamHostname, ":") {
upstreamHostname = strings.Split(upstreamHostname, ":")[0]
}
loopbackProxyEndpoint := h.Parent.GetProxyEndpointFromHostname(upstreamHostname)
if loopbackProxyEndpoint != nil {
//This is a loopback request. Swap the target to the loopback target
//h.Parent.Option.Logger.PrintAndLog("proxy", "Detected a loopback request to self. Swap the target to "+loopbackProxyEndpoint.RootOrMatchingDomain, nil)
if loopbackProxyEndpoint.IsEnabled() {
h.hostRequest(w, r, loopbackProxyEndpoint)
} else {
//Endpoint disabled, return 503
http.ServeFile(w, r, "./web/rperror.html")
h.Parent.logRequest(r, false, 521, "host-http", r.Host, upstreamHostname)
}
return true
}
return false
}
// Handle host request
func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, target *ProxyEndpoint) {
r.Header.Set("X-Forwarded-Host", r.Host)
r.Header.Set("X-Forwarded-Server", "zoraxy-"+h.Parent.Option.HostUUID)
reqHostname := r.Host
/* Load balancing */
selectedUpstream, err := h.Parent.loadBalancer.GetRequestUpstreamTarget(w, r, target.ActiveOrigins, target.UseStickySession)
if err != nil {
@@ -125,6 +146,12 @@ func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, targe
return
}
/* Upstream Host Swap (use to detect loopback to self) */
if h.upstreamHostSwap(w, r, selectedUpstream) {
//Request handled by the loopback handler
return
}
/* WebSocket automatic proxy */
requestURL := r.URL.String()
if r.Header["Upgrade"] != nil && strings.ToLower(r.Header["Upgrade"][0]) == "websocket" {

View File

@@ -49,7 +49,7 @@ type RouterOption struct {
HostVersion string //The version of Zoraxy, use for heading mod
Port int //Incoming port
UseTls bool //Use TLS to serve incoming requsts
ForceTLSLatest bool //Force TLS1.2 or above
MinTLSVersion uint16 //Minimum TLS version
NoCache bool //Force set Cache-Control: no-store
ListenOnPort80 bool //Enable port 80 http listener
ForceHttpsRedirect bool //Force redirection of http to https endpoint

View File

@@ -13,6 +13,25 @@ import (
CIDR and IPv4 / v6 validations
*/
// Get the requester IP without trusting any proxy headers
func GetRequesterIPUntrusted(r *http.Request) string {
// If the request is from an untrusted IP, we should not trust the X-Real-IP and X-Forwarded-For headers
ip := r.RemoteAddr
// Trim away the port number
reqHost, _, err := net.SplitHostPort(ip)
if err == nil {
ip = reqHost
}
// Check if the IP is a valid IPv4 or IPv6 address
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return ""
}
return ip
}
// Get the requester IP, trust the X-Real-IP and X-Forwarded-For headers
func GetRequesterIP(r *http.Request) string {
ip := r.Header.Get("X-Real-Ip")
if ip == "" {

View File

@@ -2,10 +2,13 @@ package netutils
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"unicode"
"github.com/likexian/whois"
"imuslab.com/zoraxy/mod/utils"
@@ -167,3 +170,53 @@ func CheckIfPortOccupied(portNumber int) bool {
listener.Close()
return false
}
// NormalizeDomain cleans and validates a domain string.
// - Trims spaces around the domain
// - Converts to lowercase
// - Removes trailing dot (FQDN canonicalization)
// - Checks that the domain conforms to standard rules:
// - Each label ≤ 63 characters
// - Only letters, digits, and hyphens
// - Labels do not start or end with a hyphen
// - Full domain ≤ 253 characters
//
// Returns an empty string if the domain is invalid.
func NormalizeDomain(d string) (string, error) {
d = strings.TrimSpace(d)
d = strings.ToLower(d)
d = strings.TrimSuffix(d, ".")
if len(d) == 0 {
return "", errors.New("domain is empty")
}
if len(d) > 253 {
return "", errors.New("domain exceeds 253 characters")
}
labels := strings.Split(d, ".")
for index, label := range labels {
if index == 0 {
if len(label) == 1 && label == "*" {
continue
}
}
if len(label) == 0 {
return "", errors.New("Domain '" + d + "' not valid: Empty label")
}
if len(label) > 63 {
return "", errors.New("Domain not valid: label exceeds 63 characters")
}
for i, r := range label {
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '-' {
return "", errors.New("Domain '" + d + "' not valid: Invalid character '" + string(r) + "' in label")
}
if (i == 0 || i == len(label)-1) && r == '-' {
return "", errors.New("Domain '" + d + "' not valid: label '" + label + "' starts or ends with hyphen")
}
}
}
return d, nil
}

View File

@@ -47,19 +47,19 @@ func (m *Manager) HandleAddProxyConfig(w http.ResponseWriter, r *http.Request) {
useTCP, _ := utils.PostBool(r, "useTCP")
useUDP, _ := utils.PostBool(r, "useUDP")
useProxyProtocol, _ := utils.PostBool(r, "useProxyProtocol")
ProxyProtocolVersion, _ := utils.PostInt(r, "proxyProtocolVersion")
enableLogging, _ := utils.PostBool(r, "enableLogging")
//Create the target config
newConfigUUID := m.NewConfig(&ProxyRelayOptions{
Name: name,
ListeningAddr: strings.TrimSpace(listenAddr),
ProxyAddr: strings.TrimSpace(proxyAddr),
Timeout: timeout,
UseTCP: useTCP,
UseUDP: useUDP,
UseProxyProtocol: useProxyProtocol,
EnableLogging: enableLogging,
Name: name,
ListeningAddr: strings.TrimSpace(listenAddr),
ProxyAddr: strings.TrimSpace(proxyAddr),
Timeout: timeout,
UseTCP: useTCP,
UseUDP: useUDP,
ProxyProtocolVersion: convertIntToProxyProtocolVersion(ProxyProtocolVersion),
EnableLogging: enableLogging,
})
js, _ := json.Marshal(newConfigUUID)
@@ -79,7 +79,7 @@ func (m *Manager) HandleEditProxyConfigs(w http.ResponseWriter, r *http.Request)
proxyAddr, _ := utils.PostPara(r, "proxyAddr")
useTCP, _ := utils.PostBool(r, "useTCP")
useUDP, _ := utils.PostBool(r, "useUDP")
useProxyProtocol, _ := utils.PostBool(r, "useProxyProtocol")
proxyProtocolVersion, _ := utils.PostInt(r, "proxyProtocolVersion")
enableLogging, _ := utils.PostBool(r, "enableLogging")
newTimeoutStr, _ := utils.PostPara(r, "timeout")
@@ -94,15 +94,15 @@ func (m *Manager) HandleEditProxyConfigs(w http.ResponseWriter, r *http.Request)
// Create a new ProxyRuleUpdateConfig with the extracted parameters
newConfig := &ProxyRuleUpdateConfig{
InstanceUUID: configUUID,
NewName: newName,
NewListeningAddr: listenAddr,
NewProxyAddr: proxyAddr,
UseTCP: useTCP,
UseUDP: useUDP,
UseProxyProtocol: useProxyProtocol,
EnableLogging: enableLogging,
NewTimeout: newTimeout,
InstanceUUID: configUUID,
NewName: newName,
NewListeningAddr: listenAddr,
NewProxyAddr: proxyAddr,
UseTCP: useTCP,
UseUDP: useUDP,
ProxyProtocolVersion: proxyProtocolVersion,
EnableLogging: enableLogging,
NewTimeout: newTimeout,
}
// Call the EditConfig method to modify the configuration

View File

@@ -15,50 +15,59 @@ import (
)
/*
TCP Proxy
Stream Proxy
Forward port from one port to another
Also accept active connection and passive
connection
*/
// ProxyProtocolVersion enum type
type ProxyProtocolVersion int
const (
ProxyProtocolDisabled ProxyProtocolVersion = 0
ProxyProtocolV1 ProxyProtocolVersion = 1
ProxyProtocolV2 ProxyProtocolVersion = 2
)
type ProxyRelayOptions struct {
Name string
ListeningAddr string
ProxyAddr string
Timeout int
UseTCP bool
UseUDP bool
UseProxyProtocol bool
EnableLogging bool
Name string
ListeningAddr string
ProxyAddr string
Timeout int
UseTCP bool
UseUDP bool
ProxyProtocolVersion ProxyProtocolVersion
EnableLogging bool
}
// ProxyRuleUpdateConfig is used to update the proxy rule config
type ProxyRuleUpdateConfig struct {
InstanceUUID string //The target instance UUID to update
NewName string //New name for the instance, leave empty for no change
NewListeningAddr string //New listening address, leave empty for no change
NewProxyAddr string //New proxy target address, leave empty for no change
UseTCP bool //Enable TCP proxy, default to false
UseUDP bool //Enable UDP proxy, default to false
UseProxyProtocol bool //Enable Proxy Protocol, default to false
EnableLogging bool //Enable Logging TCP/UDP Message, default to true
NewTimeout int //New timeout for the connection, leave -1 for no change
InstanceUUID string //The target instance UUID to update
NewName string //New name for the instance, leave empty for no change
NewListeningAddr string //New listening address, leave empty for no change
NewProxyAddr string //New proxy target address, leave empty for no change
UseTCP bool //Enable TCP proxy, default to false
UseUDP bool //Enable UDP proxy, default to false
ProxyProtocolVersion int //Enable Proxy Protocol v1/v2, default to disabled
EnableLogging bool //Enable Logging TCP/UDP Message, default to true
NewTimeout int //New timeout for the connection, leave -1 for no change
}
type ProxyRelayInstance struct {
/* Runtime Config */
UUID string //A UUIDv4 representing this config
Name string //Name of the config
Running bool //Status, read only
AutoStart bool //If the service suppose to started automatically
ListeningAddress string //Listening Address, usually 127.0.0.1:port
ProxyTargetAddr string //Proxy target address
UseTCP bool //Enable TCP proxy
UseUDP bool //Enable UDP proxy
UseProxyProtocol bool //Enable Proxy Protocol
EnableLogging bool //Enable logging for ProxyInstance
Timeout int //Timeout for connection in sec
UUID string //A UUIDv4 representing this config
Name string //Name of the config
Running bool //Status, read only
AutoStart bool //If the service suppose to started automatically
ListeningAddress string //Listening Address, usually 127.0.0.1:port
ProxyTargetAddr string //Proxy target address
UseTCP bool //Enable TCP proxy
UseUDP bool //Enable UDP proxy
ProxyProtocolVersion ProxyProtocolVersion //Proxy Protocol v1/v2
EnableLogging bool //Enable logging for ProxyInstance
Timeout int //Timeout for connection in sec
/* Internal */
tcpStopChan chan bool //Stop channel for TCP listener
@@ -178,7 +187,7 @@ func (m *Manager) NewConfig(config *ProxyRelayOptions) string {
ProxyTargetAddr: config.ProxyAddr,
UseTCP: config.UseTCP,
UseUDP: config.UseUDP,
UseProxyProtocol: config.UseProxyProtocol,
ProxyProtocolVersion: config.ProxyProtocolVersion,
EnableLogging: config.EnableLogging,
Timeout: config.Timeout,
tcpStopChan: nil,
@@ -203,6 +212,30 @@ func (m *Manager) GetConfigByUUID(configUUID string) (*ProxyRelayInstance, error
return nil, errors.New("config not found")
}
// ConvertIntToProxyProtocolVersion converts an int to ProxyProtocolVersion type
func convertIntToProxyProtocolVersion(v int) ProxyProtocolVersion {
switch v {
case 1:
return ProxyProtocolV1
case 2:
return ProxyProtocolV2
default:
return ProxyProtocolDisabled
}
}
// convertProxyProtocolVersionToInt converts ProxyProtocolVersion type back to int
func convertProxyProtocolVersionToInt(v ProxyProtocolVersion) int {
switch v {
case ProxyProtocolV1:
return 1
case ProxyProtocolV2:
return 2
default:
return 0
}
}
// Edit the config based on config UUID, leave empty for unchange fields
func (m *Manager) EditConfig(newConfig *ProxyRuleUpdateConfig) error {
// Find the config with the specified UUID
@@ -224,7 +257,7 @@ func (m *Manager) EditConfig(newConfig *ProxyRuleUpdateConfig) error {
foundConfig.UseTCP = newConfig.UseTCP
foundConfig.UseUDP = newConfig.UseUDP
foundConfig.UseProxyProtocol = newConfig.UseProxyProtocol
foundConfig.ProxyProtocolVersion = convertIntToProxyProtocolVersion(newConfig.ProxyProtocolVersion)
foundConfig.EnableLogging = newConfig.EnableLogging
if newConfig.NewTimeout != -1 {

View File

@@ -11,6 +11,8 @@ import (
"sync"
"sync/atomic"
"time"
proxyproto "github.com/pires/go-proxyproto"
)
func isValidIP(ip string) bool {
@@ -44,20 +46,22 @@ func (c *ProxyRelayInstance) connCopy(conn1 net.Conn, conn2 net.Conn, wg *sync.W
wg.Done()
}
func writeProxyProtocolHeaderV1(dst net.Conn, src net.Conn) error {
func WriteProxyProtocolHeader(dst net.Conn, src net.Conn, version ProxyProtocolVersion) error {
clientAddr, ok1 := src.RemoteAddr().(*net.TCPAddr)
proxyAddr, ok2 := src.LocalAddr().(*net.TCPAddr)
if !ok1 || !ok2 {
return errors.New("invalid TCP address for proxy protocol")
}
header := fmt.Sprintf("PROXY TCP4 %s %s %d %d\r\n",
clientAddr.IP.String(),
proxyAddr.IP.String(),
clientAddr.Port,
proxyAddr.Port)
header := proxyproto.Header{
Version: byte(convertProxyProtocolVersionToInt(version)),
Command: proxyproto.PROXY,
TransportProtocol: proxyproto.TCPv4,
SourceAddr: clientAddr,
DestinationAddr: proxyAddr,
}
_, err := dst.Write([]byte(header))
_, err := header.WriteTo(dst)
return err
}
@@ -161,9 +165,9 @@ func (c *ProxyRelayInstance) Port2host(allowPort string, targetAddress string, s
}
c.LogMsg("[→] connect target address ["+targetAddress+"] success.", nil)
if c.UseProxyProtocol {
if c.ProxyProtocolVersion != ProxyProtocolDisabled {
c.LogMsg("[+] write proxy protocol header to target address ["+targetAddress+"]", nil)
err = writeProxyProtocolHeaderV1(target, conn)
err = WriteProxyProtocolHeader(target, conn, c.ProxyProtocolVersion)
if err != nil {
c.LogMsg("[x] Write proxy protocol header failed: "+err.Error(), nil)
target.Close()

View File

@@ -1,11 +1,14 @@
package streamproxy
import (
"bytes"
"errors"
"log"
"net"
"strings"
"time"
proxyproto "github.com/pires/go-proxyproto"
)
/*
@@ -82,6 +85,24 @@ func (c *ProxyRelayInstance) CloseAllUDPConnections() {
})
}
// Write Proxy Protocol v2 header to UDP connection
func WriteProxyProtocolHeaderUDP(conn *net.UDPConn, srcAddr, dstAddr *net.UDPAddr) error {
header := proxyproto.Header{
Version: byte(convertProxyProtocolVersionToInt(ProxyProtocolV2)),
Command: proxyproto.PROXY,
TransportProtocol: proxyproto.UDPv4,
SourceAddr: srcAddr,
DestinationAddr: dstAddr,
}
var buf bytes.Buffer
_, err := header.WriteTo(&buf)
if err != nil {
return err
}
_, err = conn.Write(buf.Bytes())
return err
}
func (c *ProxyRelayInstance) ForwardUDP(address1, address2 string, stopChan chan bool) error {
//By default the incoming listen Address is int
//We need to add the loopback address into it
@@ -142,6 +163,10 @@ func (c *ProxyRelayInstance) ForwardUDP(address1, address2 string, stopChan chan
// Fire up routine to manage new connection
go c.RunUDPConnectionRelay(conn, lisener)
// Send Proxy Protocol header if enabled
if c.ProxyProtocolVersion == ProxyProtocolV2 {
_ = WriteProxyProtocolHeaderUDP(conn.ServerConn, cliaddr, targetAddr)
}
} else {
c.LogMsg("[UDP] Found connection for client "+saddr, nil)
conn = rawConn.(*udpClientServerConn)

View File

@@ -75,6 +75,50 @@ func (m *Manager) HandleCertDownload(w http.ResponseWriter, r *http.Request) {
}
}
// Set the selected certificate as the default / fallback certificate
func (m *Manager) SetCertAsDefault(w http.ResponseWriter, r *http.Request) {
certname, err := utils.PostPara(r, "certname")
if err != nil {
utils.SendErrorResponse(w, "invalid certname given")
return
}
//Check if the previous default cert exists. If yes, get its hostname from cert contents
defaultPubKey := filepath.Join(m.CertStore, "default.key")
defaultPriKey := filepath.Join(m.CertStore, "default.pem")
if utils.FileExists(defaultPubKey) && utils.FileExists(defaultPriKey) {
//Move the existing default cert to its original name
certBytes, err := os.ReadFile(defaultPriKey)
if err == nil {
block, _ := pem.Decode(certBytes)
if block != nil {
cert, err := x509.ParseCertificate(block.Bytes)
if err == nil {
os.Rename(defaultPubKey, filepath.Join(m.CertStore, domainToFilename(cert.Subject.CommonName, "key")))
os.Rename(defaultPriKey, filepath.Join(m.CertStore, domainToFilename(cert.Subject.CommonName, "pem")))
}
}
}
}
//Check if the cert exists
certname = filepath.Base(certname) //prevent path escape
pubKey := filepath.Join(filepath.Join(m.CertStore), certname+".key")
priKey := filepath.Join(filepath.Join(m.CertStore), certname+".pem")
if utils.FileExists(pubKey) && utils.FileExists(priKey) {
os.Rename(pubKey, filepath.Join(m.CertStore, "default.key"))
os.Rename(priKey, filepath.Join(m.CertStore, "default.pem"))
utils.SendOK(w)
//Update cert list
m.UpdateLoadedCertList()
} else {
utils.SendErrorResponse(w, "invalid key-pairs: private key or public key not found in key store")
return
}
}
// Handle upload of the certificate
func (m *Manager) HandleCertUpload(w http.ResponseWriter, r *http.Request) {
// check if request method is POST
@@ -124,6 +168,13 @@ func (m *Manager) HandleCertUpload(w http.ResponseWriter, r *http.Request) {
defer file.Close()
// create file in upload directory
// Read file contents for validation
fileBytes, err := io.ReadAll(file)
if err != nil {
http.Error(w, "Failed to read file", http.StatusBadRequest)
return
}
os.MkdirAll(m.CertStore, 0775)
f, err := os.Create(filepath.Join(m.CertStore, overWriteFilename))
if err != nil {
@@ -138,6 +189,11 @@ func (m *Manager) HandleCertUpload(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Failed to save file", http.StatusInternalServerError)
return
}
_, err = f.Write(fileBytes)
if err != nil {
http.Error(w, "Failed to save file", http.StatusInternalServerError)
return
}
//Update cert list
m.UpdateLoadedCertList()
@@ -215,11 +271,13 @@ func (m *Manager) HandleListCertificate(w http.ResponseWriter, r *http.Request)
showDate, _ := utils.GetBool(r, "date")
if showDate {
type CertInfo struct {
Domain string
Domain string // Domain name of the certificate
Filename string // Filename that stores the certificate
LastModifiedDate string
ExpireDate string
RemainingDays int
UseDNS bool
UseDNS bool // Whether this cert is obtained via DNS challenge
IsFallback bool // Whether this cert is the fallback/default cert
}
results := []*CertInfo{}
@@ -248,7 +306,7 @@ func (m *Manager) HandleListCertificate(w http.ResponseWriter, r *http.Request)
if err == nil {
certExpireTime = cert.NotAfter.Format("2006-01-02 15:04:05")
duration := cert.NotAfter.Sub(time.Now())
duration := time.Until(cert.NotAfter)
// Convert the duration to days
expiredIn = int(duration.Hours() / 24)
@@ -262,12 +320,23 @@ func (m *Manager) HandleListCertificate(w http.ResponseWriter, r *http.Request)
useDNSValidation = certInfo.UseDNS
}
certDomain := ""
block, _ := pem.Decode(certBtyes)
if block != nil {
cert, err := x509.ParseCertificate(block.Bytes)
if err == nil {
certDomain = cert.Subject.CommonName
}
}
thisCertInfo := CertInfo{
Domain: filename,
Domain: certDomain,
Filename: filename,
LastModifiedDate: modifiedTime,
ExpireDate: certExpireTime,
RemainingDays: expiredIn,
UseDNS: useDNSValidation,
IsFallback: (filename == "default"), //TODO: figure out a better implementation
}
results = append(results, &thisCertInfo)
@@ -350,3 +419,25 @@ func (m *Manager) HandleSelfSignCertGenerate(w http.ResponseWriter, r *http.Requ
}
utils.SendOK(w)
}
// Extract the common name from a PEM encoded certificate
func (m *Manager) HandleGetCertCommonName(w http.ResponseWriter, r *http.Request) {
certContents, err := utils.PostPara(r, "cert")
if err != nil {
utils.SendErrorResponse(w, "Certificate content not provided")
return
}
block, _ := pem.Decode([]byte(certContents))
if block == nil {
utils.SendErrorResponse(w, "Failed to decode PEM block")
return
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
utils.SendErrorResponse(w, "Failed to parse certificate: "+err.Error())
return
}
js, _ := json.Marshal(cert.Subject.CommonName)
utils.SendJSONResponse(w, string(js))
}

View File

@@ -29,21 +29,6 @@ func getCertPairs(certFiles []string) []string {
return result
}
// Get the cloest subdomain certificate from a list of domains
func matchClosestDomainCertificate(subdomain string, domains []string) string {
var matchingDomain string = ""
maxLength := 0
for _, domain := range domains {
if strings.HasSuffix(subdomain, "."+domain) && len(domain) > maxLength {
matchingDomain = domain
maxLength = len(domain)
}
}
return matchingDomain
}
// Convert a domain name to a filename format
func domainToFilename(domain string, ext string) string {
// Replace wildcard '*' with '_'
@@ -52,6 +37,10 @@ func domainToFilename(domain string, ext string) string {
domain = "_" + strings.TrimPrefix(domain, "*")
}
if strings.HasPrefix(".", ext) {
ext = strings.TrimPrefix(ext, ".")
}
// Add .pem extension
ext = strings.TrimPrefix(ext, ".") // Ensure ext does not start with a dot
return domain + "." + ext

View File

@@ -211,7 +211,6 @@ func getWebsiteStatus(url string) (int, error) {
}
resp, err := client.Do(req)
//resp, err := client.Get(url)
if err != nil {
//Try replace the http with https and vise versa
rewriteURL := ""

View File

@@ -199,4 +199,4 @@ func ValidateListeningAddress(address string) bool {
}
return true
}
}