diff --git a/src/mod/streamproxy/handler.go b/src/mod/streamproxy/handler.go index bc78148..1d2aaf9 100644 --- a/src/mod/streamproxy/handler.go +++ b/src/mod/streamproxy/handler.go @@ -47,15 +47,18 @@ func (m *Manager) HandleAddProxyConfig(w http.ResponseWriter, r *http.Request) { useTCP, _ := utils.PostBool(r, "useTCP") useUDP, _ := utils.PostBool(r, "useUDP") + // useProxyProtocol, _ := utils.PostBool(r, "useProxyProtocol") + useProxyProtocol := true //Create the target config newConfigUUID := m.NewConfig(&ProxyRelayOptions{ - Name: name, - ListeningAddr: strings.TrimSpace(listenAddr), - ProxyAddr: strings.TrimSpace(proxyAddr), - Timeout: timeout, - UseTCP: useTCP, - UseUDP: useUDP, + Name: name, + ListeningAddr: strings.TrimSpace(listenAddr), + ProxyAddr: strings.TrimSpace(proxyAddr), + Timeout: timeout, + UseTCP: useTCP, + UseUDP: useUDP, + UseProxyProtocol: useProxyProtocol, }) js, _ := json.Marshal(newConfigUUID) diff --git a/src/mod/streamproxy/streamproxy.go b/src/mod/streamproxy/streamproxy.go index 36155a3..e3f8057 100644 --- a/src/mod/streamproxy/streamproxy.go +++ b/src/mod/streamproxy/streamproxy.go @@ -24,12 +24,13 @@ import ( */ type ProxyRelayOptions struct { - Name string - ListeningAddr string - ProxyAddr string - Timeout int - UseTCP bool - UseUDP bool + Name string + ListeningAddr string + ProxyAddr string + Timeout int + UseTCP bool + UseUDP bool + UseProxyProtocol bool } type ProxyRelayConfig struct { @@ -41,6 +42,7 @@ type ProxyRelayConfig struct { ProxyTargetAddr string //Proxy target address UseTCP bool //Enable TCP proxy UseUDP bool //Enable UDP proxy + UseProxyProtocol bool //Enable Proxy Protocol Timeout int //Timeout for connection in sec tcpStopChan chan bool //Stop channel for TCP listener udpStopChan chan bool //Stop channel for UDP listener @@ -157,6 +159,7 @@ func (m *Manager) NewConfig(config *ProxyRelayOptions) string { ProxyTargetAddr: config.ProxyAddr, UseTCP: config.UseTCP, UseUDP: config.UseUDP, + UseProxyProtocol: config.UseProxyProtocol, Timeout: config.Timeout, tcpStopChan: nil, udpStopChan: nil, diff --git a/src/mod/streamproxy/tcpprox.go b/src/mod/streamproxy/tcpprox.go index 6fcaed0..439a8ee 100644 --- a/src/mod/streamproxy/tcpprox.go +++ b/src/mod/streamproxy/tcpprox.go @@ -2,6 +2,7 @@ package streamproxy import ( "errors" + "fmt" "io" "log" "net" @@ -43,6 +44,23 @@ func connCopy(conn1 net.Conn, conn2 net.Conn, wg *sync.WaitGroup, accumulator *a wg.Done() } +func writeProxyProtocolHeaderV1(dst net.Conn, src net.Conn) error { + clientAddr, ok1 := src.RemoteAddr().(*net.TCPAddr) + proxyAddr, ok2 := src.LocalAddr().(*net.TCPAddr) + if !ok1 || !ok2 { + return errors.New("invalid TCP address for proxy protocol") + } + + header := fmt.Sprintf("PROXY TCP4 %s %s %d %d\r\n", + clientAddr.IP.String(), + proxyAddr.IP.String(), + clientAddr.Port, + proxyAddr.Port) + + _, err := dst.Write([]byte(header)) + return err +} + func forward(conn1 net.Conn, conn2 net.Conn, aTob *atomic.Int64, bToa *atomic.Int64) { log.Printf("[+] start transmit. [%s],[%s] <-> [%s],[%s] \n", conn1.LocalAddr().String(), conn1.RemoteAddr().String(), conn2.LocalAddr().String(), conn2.RemoteAddr().String()) var wg sync.WaitGroup @@ -127,7 +145,7 @@ func (c *ProxyRelayConfig) Port2host(allowPort string, targetAddress string, sto //Connection error. Retry continue } - + log.Println("[+]", "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") go func(targetAddress string) { log.Println("[+]", "start connect host:["+targetAddress+"]") target, err := net.Dial("tcp", targetAddress) @@ -140,6 +158,20 @@ func (c *ProxyRelayConfig) Port2host(allowPort string, targetAddress string, sto return } log.Println("[→]", "connect target address ["+targetAddress+"] success.") + + if c.UseProxyProtocol { + log.Println("[+]", "write proxy protocol header to target address ["+targetAddress+"]") + err = writeProxyProtocolHeaderV1(target, conn) + if err != nil { + log.Println("[x]", "Write proxy protocol header faild: ", err) + target.Close() + conn.Close() + log.Println("[←]", "close the connect at local:["+conn.LocalAddr().String()+"] and remote:["+conn.RemoteAddr().String()+"]") + time.Sleep(time.Duration(c.Timeout) * time.Second) + return + } + } + forward(target, conn, &c.aTobAccumulatedByteTransfer, &c.bToaAccumulatedByteTransfer) }(targetAddress) }