Added Proxy Protocol V1 function.

- Added useProxyProtocol in ProxyRelayConfig
- Added writeProxyProtocolHeaderV1 function
This commit is contained in:
Jemmy
2025-07-02 16:50:14 +08:00
parent 8030f3d62a
commit b59ac47c8c
3 changed files with 51 additions and 13 deletions

View File

@@ -2,6 +2,7 @@ package streamproxy
import (
"errors"
"fmt"
"io"
"log"
"net"
@@ -43,6 +44,23 @@ func connCopy(conn1 net.Conn, conn2 net.Conn, wg *sync.WaitGroup, accumulator *a
wg.Done()
}
func writeProxyProtocolHeaderV1(dst net.Conn, src net.Conn) error {
clientAddr, ok1 := src.RemoteAddr().(*net.TCPAddr)
proxyAddr, ok2 := src.LocalAddr().(*net.TCPAddr)
if !ok1 || !ok2 {
return errors.New("invalid TCP address for proxy protocol")
}
header := fmt.Sprintf("PROXY TCP4 %s %s %d %d\r\n",
clientAddr.IP.String(),
proxyAddr.IP.String(),
clientAddr.Port,
proxyAddr.Port)
_, err := dst.Write([]byte(header))
return err
}
func forward(conn1 net.Conn, conn2 net.Conn, aTob *atomic.Int64, bToa *atomic.Int64) {
log.Printf("[+] start transmit. [%s],[%s] <-> [%s],[%s] \n", conn1.LocalAddr().String(), conn1.RemoteAddr().String(), conn2.LocalAddr().String(), conn2.RemoteAddr().String())
var wg sync.WaitGroup
@@ -127,7 +145,7 @@ func (c *ProxyRelayConfig) Port2host(allowPort string, targetAddress string, sto
//Connection error. Retry
continue
}
log.Println("[+]", "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
go func(targetAddress string) {
log.Println("[+]", "start connect host:["+targetAddress+"]")
target, err := net.Dial("tcp", targetAddress)
@@ -140,6 +158,20 @@ func (c *ProxyRelayConfig) Port2host(allowPort string, targetAddress string, sto
return
}
log.Println("[→]", "connect target address ["+targetAddress+"] success.")
if c.UseProxyProtocol {
log.Println("[+]", "write proxy protocol header to target address ["+targetAddress+"]")
err = writeProxyProtocolHeaderV1(target, conn)
if err != nil {
log.Println("[x]", "Write proxy protocol header faild: ", err)
target.Close()
conn.Close()
log.Println("[←]", "close the connect at local:["+conn.LocalAddr().String()+"] and remote:["+conn.RemoteAddr().String()+"]")
time.Sleep(time.Duration(c.Timeout) * time.Second)
return
}
}
forward(target, conn, &c.aTobAccumulatedByteTransfer, &c.bToaAccumulatedByteTransfer)
}(targetAddress)
}