mirror of
https://github.com/tobychui/zoraxy.git
synced 2025-08-11 07:37:51 +02:00
Removed alpha prototype source
This commit is contained in:
@@ -1,76 +0,0 @@
|
||||
package aroz
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
)
|
||||
|
||||
//To be used with arozos system
|
||||
type ArozHandler struct {
|
||||
Port string
|
||||
restfulEndpoint string
|
||||
}
|
||||
|
||||
//Information required for registering this subservice to arozos
|
||||
type ServiceInfo struct {
|
||||
Name string //Name of this module. e.g. "Audio"
|
||||
Desc string //Description for this module
|
||||
Group string //Group of the module, e.g. "system" / "media" etc
|
||||
IconPath string //Module icon image path e.g. "Audio/img/function_icon.png"
|
||||
Version string //Version of the module. Format: [0-9]*.[0-9][0-9].[0-9]
|
||||
StartDir string //Default starting dir, e.g. "Audio/index.html"
|
||||
SupportFW bool //Support floatWindow. If yes, floatWindow dir will be loaded
|
||||
LaunchFWDir string //This link will be launched instead of 'StartDir' if fw mode
|
||||
SupportEmb bool //Support embedded mode
|
||||
LaunchEmb string //This link will be launched instead of StartDir / Fw if a file is opened with this module
|
||||
InitFWSize []int //Floatwindow init size. [0] => Width, [1] => Height
|
||||
InitEmbSize []int //Embedded mode init size. [0] => Width, [1] => Height
|
||||
SupportedExt []string //Supported File Extensions. e.g. ".mp3", ".flac", ".wav"
|
||||
}
|
||||
|
||||
//This function will request the required flag from the startup paramters and parse it to the need of the arozos.
|
||||
func HandleFlagParse(info ServiceInfo) *ArozHandler {
|
||||
var infoRequestMode = flag.Bool("info", false, "Show information about this program in JSON")
|
||||
var port = flag.String("port", ":8000", "Management web interface listening port")
|
||||
var restful = flag.String("rpt", "", "Reserved")
|
||||
//Parse the flags
|
||||
flag.Parse()
|
||||
if *infoRequestMode {
|
||||
//Information request mode
|
||||
jsonString, _ := json.MarshalIndent(info, "", " ")
|
||||
fmt.Println(string(jsonString))
|
||||
os.Exit(0)
|
||||
}
|
||||
return &ArozHandler{
|
||||
Port: *port,
|
||||
restfulEndpoint: *restful,
|
||||
}
|
||||
}
|
||||
|
||||
//Get the username and resources access token from the request, return username, token
|
||||
func (a *ArozHandler) GetUserInfoFromRequest(w http.ResponseWriter, r *http.Request) (string, string) {
|
||||
username := r.Header.Get("aouser")
|
||||
token := r.Header.Get("aotoken")
|
||||
|
||||
return username, token
|
||||
}
|
||||
|
||||
func (a *ArozHandler) IsUsingExternalPermissionManager() bool {
|
||||
return !(a.restfulEndpoint == "")
|
||||
}
|
||||
|
||||
//Request gateway interface for advance permission sandbox control
|
||||
func (a *ArozHandler) RequestGatewayInterface(token string, script string) (*http.Response, error) {
|
||||
resp, err := http.PostForm(a.restfulEndpoint,
|
||||
url.Values{"token": {token}, "script": {script}})
|
||||
if err != nil {
|
||||
// handle error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
Binary file not shown.
@@ -1,478 +0,0 @@
|
||||
package auth
|
||||
|
||||
/*
|
||||
|
||||
author: tobychui
|
||||
*/
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"strings"
|
||||
|
||||
"encoding/hex"
|
||||
"log"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
db "imuslab.com/zoraxy/mod/database"
|
||||
"imuslab.com/zoraxy/mod/utils"
|
||||
)
|
||||
|
||||
type AuthAgent struct {
|
||||
//Session related
|
||||
SessionName string
|
||||
SessionStore *sessions.CookieStore
|
||||
Database *db.Database
|
||||
LoginRedirectionHandler func(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
type AuthEndpoints struct {
|
||||
Login string
|
||||
Logout string
|
||||
Register string
|
||||
CheckLoggedIn string
|
||||
Autologin string
|
||||
}
|
||||
|
||||
//Constructor
|
||||
func NewAuthenticationAgent(sessionName string, key []byte, sysdb *db.Database, allowReg bool, loginRedirectionHandler func(http.ResponseWriter, *http.Request)) *AuthAgent {
|
||||
store := sessions.NewCookieStore(key)
|
||||
err := sysdb.NewTable("auth")
|
||||
if err != nil {
|
||||
log.Println("Failed to create auth database. Terminating.")
|
||||
panic(err)
|
||||
}
|
||||
|
||||
//Create a new AuthAgent object
|
||||
newAuthAgent := AuthAgent{
|
||||
SessionName: sessionName,
|
||||
SessionStore: store,
|
||||
Database: sysdb,
|
||||
LoginRedirectionHandler: loginRedirectionHandler,
|
||||
}
|
||||
|
||||
//Return the authAgent
|
||||
return &newAuthAgent
|
||||
}
|
||||
|
||||
func GetSessionKey(sysdb *db.Database) (string, error) {
|
||||
sysdb.NewTable("auth")
|
||||
sessionKey := ""
|
||||
if !sysdb.KeyExists("auth", "sessionkey") {
|
||||
key := make([]byte, 32)
|
||||
rand.Read(key)
|
||||
sessionKey = string(key)
|
||||
sysdb.Write("auth", "sessionkey", sessionKey)
|
||||
log.Println("[Auth] New authentication session key generated")
|
||||
} else {
|
||||
log.Println("[Auth] Authentication session key loaded from database")
|
||||
err := sysdb.Read("auth", "sessionkey", &sessionKey)
|
||||
if err != nil {
|
||||
return "", errors.New("database read error. Is the database file corrupted?")
|
||||
}
|
||||
}
|
||||
return sessionKey, nil
|
||||
}
|
||||
|
||||
//This function will handle an http request and redirect to the given login address if not logged in
|
||||
func (a *AuthAgent) HandleCheckAuth(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request)) {
|
||||
if a.CheckAuth(r) {
|
||||
//User already logged in
|
||||
handler(w, r)
|
||||
} else {
|
||||
//User not logged in
|
||||
a.LoginRedirectionHandler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
//Handle login request, require POST username and password
|
||||
func (a *AuthAgent) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
//Get username from request using POST mode
|
||||
username, err := utils.PostPara(r, "username")
|
||||
if err != nil {
|
||||
//Username not defined
|
||||
log.Println("[Auth] " + r.RemoteAddr + " trying to login with username: " + username)
|
||||
utils.SendErrorResponse(w, "Username not defined or empty.")
|
||||
return
|
||||
}
|
||||
|
||||
//Get password from request using POST mode
|
||||
password, err := utils.PostPara(r, "password")
|
||||
if err != nil {
|
||||
//Password not defined
|
||||
utils.SendErrorResponse(w, "Password not defined or empty.")
|
||||
return
|
||||
}
|
||||
|
||||
//Get rememberme settings
|
||||
rememberme := false
|
||||
rmbme, _ := utils.PostPara(r, "rmbme")
|
||||
if rmbme == "true" {
|
||||
rememberme = true
|
||||
}
|
||||
|
||||
//Check the database and see if this user is in the database
|
||||
passwordCorrect, rejectionReason := a.ValidateUsernameAndPasswordWithReason(username, password)
|
||||
//The database contain this user information. Check its password if it is correct
|
||||
if passwordCorrect {
|
||||
//Password correct
|
||||
// Set user as authenticated
|
||||
a.LoginUserByRequest(w, r, username, rememberme)
|
||||
|
||||
//Print the login message to console
|
||||
log.Println(username + " logged in.")
|
||||
utils.SendOK(w)
|
||||
} else {
|
||||
//Password incorrect
|
||||
log.Println(username + " login request rejected: " + rejectionReason)
|
||||
|
||||
utils.SendErrorResponse(w, rejectionReason)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthAgent) ValidateUsernameAndPassword(username string, password string) bool {
|
||||
succ, _ := a.ValidateUsernameAndPasswordWithReason(username, password)
|
||||
return succ
|
||||
}
|
||||
|
||||
//validate the username and password, return reasons if the auth failed
|
||||
func (a *AuthAgent) ValidateUsernameAndPasswordWithReason(username string, password string) (bool, string) {
|
||||
hashedPassword := Hash(password)
|
||||
var passwordInDB string
|
||||
err := a.Database.Read("auth", "passhash/"+username, &passwordInDB)
|
||||
if err != nil {
|
||||
//User not found or db exception
|
||||
log.Println("[Auth] " + username + " login with incorrect password")
|
||||
return false, "Invalid username or password"
|
||||
}
|
||||
|
||||
if passwordInDB == hashedPassword {
|
||||
return true, ""
|
||||
} else {
|
||||
return false, "Invalid username or password"
|
||||
}
|
||||
}
|
||||
|
||||
//Login the user by creating a valid session for this user
|
||||
func (a *AuthAgent) LoginUserByRequest(w http.ResponseWriter, r *http.Request, username string, rememberme bool) {
|
||||
session, _ := a.SessionStore.Get(r, a.SessionName)
|
||||
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["username"] = username
|
||||
session.Values["rememberMe"] = rememberme
|
||||
|
||||
//Check if remember me is clicked. If yes, set the maxage to 1 week.
|
||||
if rememberme {
|
||||
session.Options = &sessions.Options{
|
||||
MaxAge: 3600 * 24 * 7, //One week
|
||||
Path: "/",
|
||||
}
|
||||
} else {
|
||||
session.Options = &sessions.Options{
|
||||
MaxAge: 3600 * 1, //One hour
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
session.Save(r, w)
|
||||
}
|
||||
|
||||
//Handle logout, reply OK after logged out. WILL NOT DO REDIRECTION
|
||||
func (a *AuthAgent) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
username, err := a.GetUserName(w, r)
|
||||
if username != "" {
|
||||
log.Println(username + " logged out.")
|
||||
}
|
||||
// Revoke users authentication
|
||||
err = a.Logout(w, r)
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, "Logout failed")
|
||||
return
|
||||
}
|
||||
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
func (a *AuthAgent) Logout(w http.ResponseWriter, r *http.Request) error {
|
||||
session, err := a.SessionStore.Get(r, a.SessionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
session.Values["authenticated"] = false
|
||||
session.Values["username"] = nil
|
||||
session.Save(r, w)
|
||||
return nil
|
||||
}
|
||||
|
||||
//Get the current session username from request
|
||||
func (a *AuthAgent) GetUserName(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
if a.CheckAuth(r) {
|
||||
//This user has logged in.
|
||||
session, _ := a.SessionStore.Get(r, a.SessionName)
|
||||
return session.Values["username"].(string), nil
|
||||
} else {
|
||||
//This user has not logged in.
|
||||
return "", errors.New("user not logged in")
|
||||
}
|
||||
}
|
||||
|
||||
//Get the current session user email from request
|
||||
func (a *AuthAgent) GetUserEmail(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
if a.CheckAuth(r) {
|
||||
//This user has logged in.
|
||||
session, _ := a.SessionStore.Get(r, a.SessionName)
|
||||
username := session.Values["username"].(string)
|
||||
userEmail := ""
|
||||
err := a.Database.Read("auth", "email/"+username, &userEmail)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return userEmail, nil
|
||||
} else {
|
||||
//This user has not logged in.
|
||||
return "", errors.New("user not logged in")
|
||||
}
|
||||
}
|
||||
|
||||
//Check if the user has logged in, return true / false in JSON
|
||||
func (a *AuthAgent) CheckLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if a.CheckAuth(r) {
|
||||
utils.SendJSONResponse(w, "true")
|
||||
} else {
|
||||
utils.SendJSONResponse(w, "false")
|
||||
}
|
||||
}
|
||||
|
||||
//Handle new user register. Require POST username, password, group.
|
||||
func (a *AuthAgent) HandleRegister(w http.ResponseWriter, r *http.Request, callback func(string, string)) {
|
||||
//Get username from request
|
||||
newusername, err := utils.PostPara(r, "username")
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, "Missing 'username' paramter")
|
||||
return
|
||||
}
|
||||
|
||||
//Get password from request
|
||||
password, err := utils.PostPara(r, "password")
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, "Missing 'password' paramter")
|
||||
return
|
||||
}
|
||||
|
||||
//Get email from request
|
||||
email, err := utils.PostPara(r, "email")
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, "Missing 'email' paramter")
|
||||
return
|
||||
}
|
||||
|
||||
_, err = mail.ParseAddress(email)
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, "Invalid or malformed email")
|
||||
return
|
||||
}
|
||||
|
||||
//Ok to proceed create this user
|
||||
err = a.CreateUserAccount(newusername, password, email)
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
//Do callback if exists
|
||||
if callback != nil {
|
||||
callback(newusername, email)
|
||||
}
|
||||
|
||||
//Return to the client with OK
|
||||
utils.SendOK(w)
|
||||
log.Println("[Auth] New user " + newusername + " added to system.")
|
||||
}
|
||||
|
||||
//Handle new user register without confirmation email. Require POST username, password, group.
|
||||
func (a *AuthAgent) HandleRegisterWithoutEmail(w http.ResponseWriter, r *http.Request, callback func(string, string)) {
|
||||
//Get username from request
|
||||
newusername, err := utils.PostPara(r, "username")
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, "Missing 'username' paramter")
|
||||
return
|
||||
}
|
||||
|
||||
//Get password from request
|
||||
password, err := utils.PostPara(r, "password")
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, "Missing 'password' paramter")
|
||||
return
|
||||
}
|
||||
|
||||
//Ok to proceed create this user
|
||||
err = a.CreateUserAccount(newusername, password, "")
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
//Do callback if exists
|
||||
if callback != nil {
|
||||
callback(newusername, "")
|
||||
}
|
||||
|
||||
//Return to the client with OK
|
||||
utils.SendOK(w)
|
||||
log.Println("[Auth] Admin account created: " + newusername)
|
||||
}
|
||||
|
||||
//Check authentication from request header's session value
|
||||
func (a *AuthAgent) CheckAuth(r *http.Request) bool {
|
||||
session, err := a.SessionStore.Get(r, a.SessionName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// Check if user is authenticated
|
||||
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
//Handle de-register of users. Require POST username.
|
||||
//THIS FUNCTION WILL NOT CHECK FOR PERMISSION. PLEASE USE WITH PERMISSION HANDLER
|
||||
func (a *AuthAgent) HandleUnregister(w http.ResponseWriter, r *http.Request) {
|
||||
//Check if the user is logged in
|
||||
if !a.CheckAuth(r) {
|
||||
//This user has not logged in
|
||||
utils.SendErrorResponse(w, "Login required to remove user from the system.")
|
||||
return
|
||||
}
|
||||
|
||||
//Get username from request
|
||||
username, err := utils.PostPara(r, "username")
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, "Missing 'username' paramter")
|
||||
return
|
||||
}
|
||||
|
||||
err = a.UnregisterUser(username)
|
||||
if err != nil {
|
||||
utils.SendErrorResponse(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
//Return to the client with OK
|
||||
utils.SendOK(w)
|
||||
log.Println("[Auth] User " + username + " has been removed from the system.")
|
||||
}
|
||||
|
||||
func (a *AuthAgent) UnregisterUser(username string) error {
|
||||
//Check if the user exists in the system database.
|
||||
if !a.Database.KeyExists("auth", "passhash/"+username) {
|
||||
//This user do not exists.
|
||||
return errors.New("this user does not exists")
|
||||
}
|
||||
|
||||
//OK! Remove the user from the database
|
||||
a.Database.Delete("auth", "passhash/"+username)
|
||||
a.Database.Delete("auth", "email/"+username)
|
||||
return nil
|
||||
}
|
||||
|
||||
//Get the number of users in the system
|
||||
func (a *AuthAgent) GetUserCounts() int {
|
||||
entries, _ := a.Database.ListTable("auth")
|
||||
usercount := 0
|
||||
for _, keypairs := range entries {
|
||||
if strings.Contains(string(keypairs[0]), "passhash/") {
|
||||
//This is a user registry
|
||||
usercount++
|
||||
}
|
||||
}
|
||||
|
||||
if usercount == 0 {
|
||||
log.Println("There are no user in the database.")
|
||||
}
|
||||
return usercount
|
||||
}
|
||||
|
||||
//List all username within the system
|
||||
func (a *AuthAgent) ListUsers() []string {
|
||||
entries, _ := a.Database.ListTable("auth")
|
||||
results := []string{}
|
||||
for _, keypairs := range entries {
|
||||
if strings.Contains(string(keypairs[0]), "passhash/") {
|
||||
username := strings.Split(string(keypairs[0]), "/")[1]
|
||||
results = append(results, username)
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
//Check if the given username exists
|
||||
func (a *AuthAgent) UserExists(username string) bool {
|
||||
userpasswordhash := ""
|
||||
err := a.Database.Read("auth", "passhash/"+username, &userpasswordhash)
|
||||
if err != nil || userpasswordhash == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
//Update the session expire time given the request header.
|
||||
func (a *AuthAgent) UpdateSessionExpireTime(w http.ResponseWriter, r *http.Request) bool {
|
||||
session, _ := a.SessionStore.Get(r, a.SessionName)
|
||||
if session.Values["authenticated"].(bool) {
|
||||
//User authenticated. Extend its expire time
|
||||
rememberme := session.Values["rememberMe"].(bool)
|
||||
//Extend the session expire time
|
||||
if rememberme {
|
||||
session.Options = &sessions.Options{
|
||||
MaxAge: 3600 * 24 * 7, //One week
|
||||
Path: "/",
|
||||
}
|
||||
} else {
|
||||
session.Options = &sessions.Options{
|
||||
MaxAge: 3600 * 1, //One hour
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
session.Save(r, w)
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
//Create user account
|
||||
func (a *AuthAgent) CreateUserAccount(newusername string, password string, email string) error {
|
||||
//Check user already exists
|
||||
if a.UserExists(newusername) {
|
||||
return errors.New("user with same name already exists")
|
||||
}
|
||||
|
||||
key := newusername
|
||||
hashedPassword := Hash(password)
|
||||
err := a.Database.Write("auth", "passhash/"+key, hashedPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if email != "" {
|
||||
err = a.Database.Write("auth", "email/"+key, email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//Hash the given raw string into sha512 hash
|
||||
func Hash(raw string) string {
|
||||
h := sha512.New()
|
||||
h.Write([]byte(raw))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
@@ -1,53 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type RouterOption struct {
|
||||
AuthAgent *AuthAgent
|
||||
RequireAuth bool //This router require authentication
|
||||
DeniedHandler func(http.ResponseWriter, *http.Request) //Things to do when request is rejected
|
||||
|
||||
}
|
||||
|
||||
type RouterDef struct {
|
||||
option RouterOption
|
||||
endpoints map[string]func(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
func NewManagedHTTPRouter(option RouterOption) *RouterDef {
|
||||
return &RouterDef{
|
||||
option: option,
|
||||
endpoints: map[string]func(http.ResponseWriter, *http.Request){},
|
||||
}
|
||||
}
|
||||
|
||||
func (router *RouterDef) HandleFunc(endpoint string, handler func(http.ResponseWriter, *http.Request)) error {
|
||||
//Check if the endpoint already registered
|
||||
if _, exist := router.endpoints[endpoint]; exist {
|
||||
log.Println("WARNING! Duplicated registering of web endpoint: " + endpoint)
|
||||
return errors.New("endpoint register duplicated")
|
||||
}
|
||||
|
||||
authAgent := router.option.AuthAgent
|
||||
|
||||
//OK. Register handler
|
||||
http.HandleFunc(endpoint, func(w http.ResponseWriter, r *http.Request) {
|
||||
//Check authentication of the user
|
||||
if router.option.RequireAuth {
|
||||
authAgent.HandleCheckAuth(w, r, func(w http.ResponseWriter, r *http.Request) {
|
||||
handler(w, r)
|
||||
})
|
||||
} else {
|
||||
handler(w, r)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
router.endpoints[endpoint] = handler
|
||||
|
||||
return nil
|
||||
}
|
@@ -1,120 +0,0 @@
|
||||
package database
|
||||
|
||||
/*
|
||||
ArOZ Online Database Access Module
|
||||
author: tobychui
|
||||
|
||||
This is an improved Object oriented base solution to the original
|
||||
aroz online database script.
|
||||
*/
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
Db interface{} //This will be nil on openwrt and *bolt.DB in the rest of the systems
|
||||
Tables sync.Map
|
||||
ReadOnly bool
|
||||
}
|
||||
|
||||
func NewDatabase(dbfile string, readOnlyMode bool) (*Database, error) {
|
||||
return newDatabase(dbfile, readOnlyMode)
|
||||
}
|
||||
|
||||
/*
|
||||
Create / Drop a table
|
||||
Usage:
|
||||
err := sysdb.NewTable("MyTable")
|
||||
err := sysdb.DropTable("MyTable")
|
||||
*/
|
||||
|
||||
func (d *Database) UpdateReadWriteMode(readOnly bool) {
|
||||
d.ReadOnly = readOnly
|
||||
}
|
||||
|
||||
//Dump the whole db into a log file
|
||||
func (d *Database) Dump(filename string) ([]string, error) {
|
||||
return d.dump(filename)
|
||||
}
|
||||
|
||||
//Create a new table
|
||||
func (d *Database) NewTable(tableName string) error {
|
||||
return d.newTable(tableName)
|
||||
}
|
||||
|
||||
//Check is table exists
|
||||
func (d *Database) TableExists(tableName string) bool {
|
||||
return d.tableExists(tableName)
|
||||
}
|
||||
|
||||
//Drop the given table
|
||||
func (d *Database) DropTable(tableName string) error {
|
||||
return d.dropTable(tableName)
|
||||
}
|
||||
|
||||
/*
|
||||
Write to database with given tablename and key. Example Usage:
|
||||
type demo struct{
|
||||
content string
|
||||
}
|
||||
thisDemo := demo{
|
||||
content: "Hello World",
|
||||
}
|
||||
err := sysdb.Write("MyTable", "username/message",thisDemo);
|
||||
*/
|
||||
func (d *Database) Write(tableName string, key string, value interface{}) error {
|
||||
return d.write(tableName, key, value)
|
||||
}
|
||||
|
||||
/*
|
||||
Read from database and assign the content to a given datatype. Example Usage:
|
||||
|
||||
type demo struct{
|
||||
content string
|
||||
}
|
||||
thisDemo := new(demo)
|
||||
err := sysdb.Read("MyTable", "username/message",&thisDemo);
|
||||
*/
|
||||
|
||||
func (d *Database) Read(tableName string, key string, assignee interface{}) error {
|
||||
return d.read(tableName, key, assignee)
|
||||
}
|
||||
|
||||
func (d *Database) KeyExists(tableName string, key string) bool {
|
||||
return d.keyExists(tableName, key)
|
||||
}
|
||||
|
||||
/*
|
||||
Delete a value from the database table given tablename and key
|
||||
|
||||
err := sysdb.Delete("MyTable", "username/message");
|
||||
*/
|
||||
func (d *Database) Delete(tableName string, key string) error {
|
||||
return d.delete(tableName, key)
|
||||
}
|
||||
|
||||
/*
|
||||
//List table example usage
|
||||
//Assume the value is stored as a struct named "groupstruct"
|
||||
|
||||
entries, err := sysdb.ListTable("test")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, keypairs := range entries{
|
||||
log.Println(string(keypairs[0]))
|
||||
group := new(groupstruct)
|
||||
json.Unmarshal(keypairs[1], &group)
|
||||
log.Println(group);
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
func (d *Database) ListTable(tableName string) ([][][]byte, error) {
|
||||
return d.listTable(tableName)
|
||||
}
|
||||
|
||||
func (d *Database) Close() {
|
||||
d.close()
|
||||
}
|
@@ -1,186 +0,0 @@
|
||||
//go:build !mipsle && !riscv64
|
||||
// +build !mipsle,!riscv64
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
)
|
||||
|
||||
func newDatabase(dbfile string, readOnlyMode bool) (*Database, error) {
|
||||
db, err := bolt.Open(dbfile, 0600, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tableMap := sync.Map{}
|
||||
//Build the table list from database
|
||||
err = db.View(func(tx *bolt.Tx) error {
|
||||
return tx.ForEach(func(name []byte, _ *bolt.Bucket) error {
|
||||
tableMap.Store(string(name), "")
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
return &Database{
|
||||
Db: db,
|
||||
Tables: tableMap,
|
||||
ReadOnly: readOnlyMode,
|
||||
}, err
|
||||
}
|
||||
|
||||
//Dump the whole db into a log file
|
||||
func (d *Database) dump(filename string) ([]string, error) {
|
||||
results := []string{}
|
||||
|
||||
d.Tables.Range(func(tableName, v interface{}) bool {
|
||||
entries, err := d.ListTable(tableName.(string))
|
||||
if err != nil {
|
||||
log.Println("Reading table " + tableName.(string) + " failed: " + err.Error())
|
||||
return false
|
||||
}
|
||||
for _, keypairs := range entries {
|
||||
results = append(results, string(keypairs[0])+":"+string(keypairs[1])+"\n")
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
//Create a new table
|
||||
func (d *Database) newTable(tableName string) error {
|
||||
if d.ReadOnly == true {
|
||||
return errors.New("Operation rejected in ReadOnly mode")
|
||||
}
|
||||
|
||||
err := d.Db.(*bolt.DB).Update(func(tx *bolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists([]byte(tableName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
d.Tables.Store(tableName, "")
|
||||
return err
|
||||
}
|
||||
|
||||
//Check is table exists
|
||||
func (d *Database) tableExists(tableName string) bool {
|
||||
if _, ok := d.Tables.Load(tableName); ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
//Drop the given table
|
||||
func (d *Database) dropTable(tableName string) error {
|
||||
if d.ReadOnly == true {
|
||||
return errors.New("Operation rejected in ReadOnly mode")
|
||||
}
|
||||
|
||||
err := d.Db.(*bolt.DB).Update(func(tx *bolt.Tx) error {
|
||||
err := tx.DeleteBucket([]byte(tableName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
//Write to table
|
||||
func (d *Database) write(tableName string, key string, value interface{}) error {
|
||||
if d.ReadOnly {
|
||||
return errors.New("Operation rejected in ReadOnly mode")
|
||||
}
|
||||
|
||||
jsonString, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = d.Db.(*bolt.DB).Update(func(tx *bolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists([]byte(tableName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b := tx.Bucket([]byte(tableName))
|
||||
err = b.Put([]byte(key), jsonString)
|
||||
return err
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Database) read(tableName string, key string, assignee interface{}) error {
|
||||
err := d.Db.(*bolt.DB).View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte(tableName))
|
||||
v := b.Get([]byte(key))
|
||||
json.Unmarshal(v, &assignee)
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Database) keyExists(tableName string, key string) bool {
|
||||
resultIsNil := false
|
||||
if !d.TableExists(tableName) {
|
||||
//Table not exists. Do not proceed accessing key
|
||||
log.Println("[DB] ERROR: Requesting key from table that didn't exist!!!")
|
||||
return false
|
||||
}
|
||||
err := d.Db.(*bolt.DB).View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte(tableName))
|
||||
v := b.Get([]byte(key))
|
||||
if v == nil {
|
||||
resultIsNil = true
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
} else {
|
||||
if resultIsNil {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Database) delete(tableName string, key string) error {
|
||||
if d.ReadOnly {
|
||||
return errors.New("Operation rejected in ReadOnly mode")
|
||||
}
|
||||
|
||||
err := d.Db.(*bolt.DB).Update(func(tx *bolt.Tx) error {
|
||||
tx.Bucket([]byte(tableName)).Delete([]byte(key))
|
||||
return nil
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Database) listTable(tableName string) ([][][]byte, error) {
|
||||
var results [][][]byte
|
||||
err := d.Db.(*bolt.DB).View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte(tableName))
|
||||
c := b.Cursor()
|
||||
|
||||
for k, v := c.First(); k != nil; k, v = c.Next() {
|
||||
results = append(results, [][]byte{k, v})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (d *Database) close() {
|
||||
d.Db.(*bolt.DB).Close()
|
||||
}
|
@@ -1,208 +0,0 @@
|
||||
//go:build mipsle || riscv64
|
||||
// +build mipsle riscv64
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func newDatabase(dbfile string, readOnlyMode bool) (*Database, error) {
|
||||
dbRootPath := filepath.ToSlash(filepath.Clean(dbfile))
|
||||
dbRootPath = "fsdb/" + dbRootPath
|
||||
err := os.MkdirAll(dbRootPath, 0755)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tableMap := sync.Map{}
|
||||
//build the table list from file system
|
||||
files, err := filepath.Glob(filepath.Join(dbRootPath, "/*"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if isDirectory(file) {
|
||||
tableMap.Store(filepath.Base(file), "")
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("Filesystem Emulated Key-value Database Service Started: " + dbRootPath)
|
||||
return &Database{
|
||||
Db: dbRootPath,
|
||||
Tables: tableMap,
|
||||
ReadOnly: readOnlyMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Database) dump(filename string) ([]string, error) {
|
||||
//Get all file objects from root
|
||||
rootfiles, err := filepath.Glob(filepath.Join(d.Db.(string), "/*"))
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
//Filter out the folders
|
||||
rootFolders := []string{}
|
||||
for _, file := range rootfiles {
|
||||
if !isDirectory(file) {
|
||||
rootFolders = append(rootFolders, filepath.Base(file))
|
||||
}
|
||||
}
|
||||
|
||||
return rootFolders, nil
|
||||
}
|
||||
|
||||
func (d *Database) newTable(tableName string) error {
|
||||
if d.ReadOnly {
|
||||
return errors.New("Operation rejected in ReadOnly mode")
|
||||
}
|
||||
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
|
||||
if !fileExists(tablePath) {
|
||||
return os.MkdirAll(tablePath, 0755)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) tableExists(tableName string) bool {
|
||||
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
|
||||
if _, err := os.Stat(tablePath); errors.Is(err, os.ErrNotExist) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !isDirectory(tablePath) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *Database) dropTable(tableName string) error {
|
||||
if d.ReadOnly {
|
||||
return errors.New("Operation rejected in ReadOnly mode")
|
||||
}
|
||||
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
|
||||
if d.tableExists(tableName) {
|
||||
return os.RemoveAll(tablePath)
|
||||
} else {
|
||||
return errors.New("table not exists")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (d *Database) write(tableName string, key string, value interface{}) error {
|
||||
if d.ReadOnly {
|
||||
return errors.New("Operation rejected in ReadOnly mode")
|
||||
}
|
||||
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
|
||||
js, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key = strings.ReplaceAll(key, "/", "-SLASH_SIGN-")
|
||||
|
||||
return os.WriteFile(filepath.Join(tablePath, key+".entry"), js, 0755)
|
||||
}
|
||||
|
||||
func (d *Database) read(tableName string, key string, assignee interface{}) error {
|
||||
if !d.keyExists(tableName, key) {
|
||||
return errors.New("key not exists")
|
||||
}
|
||||
|
||||
key = strings.ReplaceAll(key, "/", "-SLASH_SIGN-")
|
||||
|
||||
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
|
||||
entryPath := filepath.Join(tablePath, key+".entry")
|
||||
content, err := os.ReadFile(entryPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(content, &assignee)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Database) keyExists(tableName string, key string) bool {
|
||||
key = strings.ReplaceAll(key, "/", "-SLASH_SIGN-")
|
||||
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
|
||||
entryPath := filepath.Join(tablePath, key+".entry")
|
||||
return fileExists(entryPath)
|
||||
}
|
||||
|
||||
func (d *Database) delete(tableName string, key string) error {
|
||||
if d.ReadOnly {
|
||||
return errors.New("Operation rejected in ReadOnly mode")
|
||||
}
|
||||
if !d.keyExists(tableName, key) {
|
||||
return errors.New("key not exists")
|
||||
}
|
||||
key = strings.ReplaceAll(key, "/", "-SLASH_SIGN-")
|
||||
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
|
||||
entryPath := filepath.Join(tablePath, key+".entry")
|
||||
|
||||
return os.Remove(entryPath)
|
||||
}
|
||||
|
||||
func (d *Database) listTable(tableName string) ([][][]byte, error) {
|
||||
if !d.tableExists(tableName) {
|
||||
return [][][]byte{}, errors.New("table not exists")
|
||||
}
|
||||
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
|
||||
entries, err := filepath.Glob(filepath.Join(tablePath, "/*.entry"))
|
||||
if err != nil {
|
||||
return [][][]byte{}, err
|
||||
}
|
||||
|
||||
var results [][][]byte = [][][]byte{}
|
||||
for _, entry := range entries {
|
||||
if !isDirectory(entry) {
|
||||
//Read it
|
||||
key := filepath.Base(entry)
|
||||
key = strings.TrimSuffix(key, filepath.Ext(key))
|
||||
key = strings.ReplaceAll(key, "-SLASH_SIGN-", "/")
|
||||
|
||||
bkey := []byte(key)
|
||||
bval := []byte("")
|
||||
c, err := os.ReadFile(entry)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
bval = c
|
||||
results = append(results, [][]byte{bkey, bval})
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (d *Database) close() {
|
||||
//Nothing to close as it is file system
|
||||
}
|
||||
|
||||
func isDirectory(path string) bool {
|
||||
fileInfo, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return fileInfo.IsDir()
|
||||
}
|
||||
|
||||
func fileExists(name string) bool {
|
||||
_, err := os.Stat(name)
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
@@ -1,82 +0,0 @@
|
||||
package dynamicproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"imuslab.com/zoraxy/mod/geodb"
|
||||
)
|
||||
|
||||
/*
|
||||
Server.go
|
||||
|
||||
Main server for dynamic proxy core
|
||||
*/
|
||||
|
||||
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
//Check if this ip is in blacklist
|
||||
clientIpAddr := geodb.GetRequesterIP(r)
|
||||
if h.Parent.Option.GeodbStore.IsBlacklisted(clientIpAddr) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
template, err := os.ReadFile("./web/forbidden.html")
|
||||
if err != nil {
|
||||
w.Write([]byte("403 - Forbidden"))
|
||||
} else {
|
||||
w.Write(template)
|
||||
}
|
||||
h.logRequest(r, false, 403, "blacklist")
|
||||
return
|
||||
}
|
||||
|
||||
//Check if this is a redirection url
|
||||
if h.Parent.Option.RedirectRuleTable.IsRedirectable(r) {
|
||||
statusCode := h.Parent.Option.RedirectRuleTable.HandleRedirect(w, r)
|
||||
h.logRequest(r, statusCode != 500, statusCode, "redirect")
|
||||
return
|
||||
}
|
||||
|
||||
//Check if there are external routing rule matches.
|
||||
//If yes, route them via external rr
|
||||
matchedRoutingRule := h.Parent.GetMatchingRoutingRule(r)
|
||||
if matchedRoutingRule != nil {
|
||||
//Matching routing rule found. Let the sub-router handle it
|
||||
matchedRoutingRule.Route(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
//Extract request host to see if it is virtual directory or subdomain
|
||||
domainOnly := r.Host
|
||||
if strings.Contains(r.Host, ":") {
|
||||
hostPath := strings.Split(r.Host, ":")
|
||||
domainOnly = hostPath[0]
|
||||
}
|
||||
|
||||
if strings.Contains(r.Host, ".") {
|
||||
//This might be a subdomain. See if there are any subdomain proxy router for this
|
||||
//Remove the port if any
|
||||
|
||||
sep := h.Parent.getSubdomainProxyEndpointFromHostname(domainOnly)
|
||||
if sep != nil {
|
||||
h.subdomainRequest(w, r, sep)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
//Clean up the request URI
|
||||
proxyingPath := strings.TrimSpace(r.RequestURI)
|
||||
targetProxyEndpoint := h.Parent.getTargetProxyEndpointFromRequestURI(proxyingPath)
|
||||
if targetProxyEndpoint != nil {
|
||||
h.proxyRequest(w, r, targetProxyEndpoint)
|
||||
} else if !strings.HasSuffix(proxyingPath, "/") {
|
||||
potentialProxtEndpoint := h.Parent.getTargetProxyEndpointFromRequestURI(proxyingPath + "/")
|
||||
if potentialProxtEndpoint != nil {
|
||||
h.proxyRequest(w, r, potentialProxtEndpoint)
|
||||
} else {
|
||||
h.proxyRequest(w, r, h.Parent.Root)
|
||||
}
|
||||
} else {
|
||||
h.proxyRequest(w, r, h.Parent.Root)
|
||||
}
|
||||
}
|
@@ -1,23 +0,0 @@
|
||||
package domainsniff
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
//Check if the domain is reachable and return err if not reachable
|
||||
func DomainReachableWithError(domain string) error {
|
||||
timeout := 1 * time.Second
|
||||
conn, err := net.DialTimeout("tcp", domain, timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
//Check if domain reachable
|
||||
func DomainReachable(domain string) bool {
|
||||
return DomainReachableWithError(domain) == nil
|
||||
}
|
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2018-present tobychui
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@@ -1,414 +0,0 @@
|
||||
package dpcore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var onExitFlushLoop func()
|
||||
|
||||
const (
|
||||
defaultTimeout = time.Minute * 5
|
||||
)
|
||||
|
||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||
// sends it to another server, proxying the response back to the
|
||||
// client, support http, also support https tunnel using http.hijacker
|
||||
type ReverseProxy struct {
|
||||
// Set the timeout of the proxy server, default is 5 minutes
|
||||
Timeout time.Duration
|
||||
|
||||
// Director must be a function which modifies
|
||||
// the request into a new request to be sent
|
||||
// using Transport. Its response is then copied
|
||||
// back to the original client unmodified.
|
||||
// Director must not access the provided Request
|
||||
// after returning.
|
||||
Director func(*http.Request)
|
||||
|
||||
// The transport used to perform proxy requests.
|
||||
// default is http.DefaultTransport.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// FlushInterval specifies the flush interval
|
||||
// to flush to the client while copying the
|
||||
// response body. If zero, no periodic flushing is done.
|
||||
FlushInterval time.Duration
|
||||
|
||||
// ErrorLog specifies an optional logger for errors
|
||||
// that occur when attempting to proxy the request.
|
||||
// If nil, logging goes to os.Stderr via the log package's
|
||||
// standard logger.
|
||||
ErrorLog *log.Logger
|
||||
|
||||
// ModifyResponse is an optional function that
|
||||
// modifies the Response from the backend.
|
||||
// If it returns an error, the proxy returns a StatusBadGateway error.
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
||||
//Prepender is an optional prepend text for URL rewrite
|
||||
//
|
||||
Prepender string
|
||||
|
||||
Verbal bool
|
||||
}
|
||||
|
||||
type requestCanceler interface {
|
||||
CancelRequest(req *http.Request)
|
||||
}
|
||||
|
||||
func NewDynamicProxyCore(target *url.URL, prepender string) *ReverseProxy {
|
||||
targetQuery := target.RawQuery
|
||||
director := func(req *http.Request) {
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
||||
|
||||
// If Host is empty, the Request.Write method uses
|
||||
// the value of URL.Host.
|
||||
// force use URL.Host
|
||||
req.Host = req.URL.Host
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
||||
}
|
||||
|
||||
if _, ok := req.Header["User-Agent"]; !ok {
|
||||
req.Header.Set("User-Agent", "")
|
||||
}
|
||||
}
|
||||
|
||||
return &ReverseProxy{
|
||||
Director: director,
|
||||
Prepender: prepender,
|
||||
Verbal: false,
|
||||
}
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||
var hopHeaders = []string{
|
||||
//"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
//"Upgrade",
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||
if p.FlushInterval != 0 {
|
||||
if wf, ok := dst.(writeFlusher); ok {
|
||||
mlw := &maxLatencyWriter{
|
||||
dst: wf,
|
||||
latency: p.FlushInterval,
|
||||
done: make(chan bool),
|
||||
}
|
||||
|
||||
go mlw.flushLoop()
|
||||
defer mlw.stop()
|
||||
dst = mlw
|
||||
}
|
||||
}
|
||||
|
||||
io.Copy(dst, src)
|
||||
}
|
||||
|
||||
type writeFlusher interface {
|
||||
io.Writer
|
||||
http.Flusher
|
||||
}
|
||||
|
||||
type maxLatencyWriter struct {
|
||||
dst writeFlusher
|
||||
latency time.Duration
|
||||
mu sync.Mutex
|
||||
done chan bool
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) Write(b []byte) (int, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.dst.Write(b)
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) flushLoop() {
|
||||
t := time.NewTicker(m.latency)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.done:
|
||||
if onExitFlushLoop != nil {
|
||||
onExitFlushLoop()
|
||||
}
|
||||
return
|
||||
case <-t.C:
|
||||
m.mu.Lock()
|
||||
m.dst.Flush()
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) stop() {
|
||||
m.done <- true
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) logf(format string, args ...interface{}) {
|
||||
if p.ErrorLog != nil {
|
||||
p.ErrorLog.Printf(format, args...)
|
||||
} else {
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func removeHeaders(header http.Header) {
|
||||
// Remove hop-by-hop headers listed in the "Connection" header.
|
||||
if c := header.Get("Connection"); c != "" {
|
||||
for _, f := range strings.Split(c, ",") {
|
||||
if f = strings.TrimSpace(f); f != "" {
|
||||
header.Del(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove hop-by-hop headers
|
||||
for _, h := range hopHeaders {
|
||||
if header.Get(h) != "" {
|
||||
header.Del(h)
|
||||
}
|
||||
}
|
||||
|
||||
if header.Get("A-Upgrade") != "" {
|
||||
header.Set("Upgrade", header.Get("A-Upgrade"))
|
||||
header.Del("A-Upgrade")
|
||||
}
|
||||
}
|
||||
|
||||
func addXForwardedForHeader(req *http.Request) {
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
// If we aren't the first proxy retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
if prior, ok := req.Header["X-Forwarded-For"]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
req.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ProxyHTTP(rw http.ResponseWriter, req *http.Request) error {
|
||||
transport := p.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
outreq := new(http.Request)
|
||||
// Shallow copies of maps, like header
|
||||
*outreq = *req
|
||||
|
||||
if cn, ok := rw.(http.CloseNotifier); ok {
|
||||
if requestCanceler, ok := transport.(requestCanceler); ok {
|
||||
// After the Handler has returned, there is no guarantee
|
||||
// that the channel receives a value, so to make sure
|
||||
reqDone := make(chan struct{})
|
||||
defer close(reqDone)
|
||||
clientGone := cn.CloseNotify()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-clientGone:
|
||||
requestCanceler.CancelRequest(outreq)
|
||||
case <-reqDone:
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
p.Director(outreq)
|
||||
outreq.Close = false
|
||||
|
||||
// We may modify the header (shallow copied above), so we only copy it.
|
||||
outreq.Header = make(http.Header)
|
||||
copyHeader(outreq.Header, req.Header)
|
||||
|
||||
// Remove hop-by-hop headers listed in the "Connection" header, Remove hop-by-hop headers.
|
||||
removeHeaders(outreq.Header)
|
||||
|
||||
// Add X-Forwarded-For Header.
|
||||
addXForwardedForHeader(outreq)
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
|
||||
//rw.WriteHeader(http.StatusBadGateway)
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove hop-by-hop headers listed in the "Connection" header of the response, Remove hop-by-hop headers.
|
||||
removeHeaders(res.Header)
|
||||
|
||||
if p.ModifyResponse != nil {
|
||||
if err := p.ModifyResponse(res); err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
|
||||
//rw.WriteHeader(http.StatusBadGateway)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
//Custom header rewriter functions
|
||||
if res.Header.Get("Location") != "" {
|
||||
//Custom redirection to this rproxy relative path
|
||||
res.Header.Set("Location", filepath.ToSlash(filepath.Join(p.Prepender, res.Header.Get("Location"))))
|
||||
}
|
||||
// Copy header from response to client.
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
// The "Trailer" header isn't included in the Transport's response, Build it up from Trailer.
|
||||
if len(res.Trailer) > 0 {
|
||||
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||
for k := range res.Trailer {
|
||||
trailerKeys = append(trailerKeys, k)
|
||||
}
|
||||
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||
}
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
if len(res.Trailer) > 0 {
|
||||
// Force chunking if we saw a response trailer.
|
||||
// This prevents net/http from calculating the length for short
|
||||
// bodies and adding a Content-Length.
|
||||
if fl, ok := rw.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
p.copyResponse(rw, res.Body)
|
||||
// close now, instead of defer, to populate res.Trailer
|
||||
res.Body.Close()
|
||||
copyHeader(rw.Header(), res.Trailer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ProxyHTTPS(rw http.ResponseWriter, req *http.Request) error {
|
||||
hij, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
p.logf("http server does not support hijacker")
|
||||
return errors.New("http server does not support hijacker")
|
||||
}
|
||||
|
||||
clientConn, _, err := hij.Hijack()
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
proxyConn, err := net.Dial("tcp", req.URL.Host)
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// The returned net.Conn may have read or write deadlines
|
||||
// already set, depending on the configuration of the
|
||||
// Server, to set or clear those deadlines as needed
|
||||
// we set timeout to 5 minutes
|
||||
deadline := time.Now()
|
||||
if p.Timeout == 0 {
|
||||
deadline = deadline.Add(time.Minute * 5)
|
||||
} else {
|
||||
deadline = deadline.Add(p.Timeout)
|
||||
}
|
||||
|
||||
err = clientConn.SetDeadline(deadline)
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
err = proxyConn.SetDeadline(deadline)
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = clientConn.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
io.Copy(clientConn, proxyConn)
|
||||
clientConn.Close()
|
||||
proxyConn.Close()
|
||||
}()
|
||||
|
||||
io.Copy(proxyConn, clientConn)
|
||||
proxyConn.Close()
|
||||
clientConn.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) error {
|
||||
if req.Method == "CONNECT" {
|
||||
err := p.ProxyHTTPS(rw, req)
|
||||
return err
|
||||
} else {
|
||||
err := p.ProxyHTTP(rw, req)
|
||||
return err
|
||||
}
|
||||
}
|
@@ -1,347 +0,0 @@
|
||||
package dynamicproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"imuslab.com/zoraxy/mod/dynamicproxy/dpcore"
|
||||
"imuslab.com/zoraxy/mod/dynamicproxy/redirection"
|
||||
"imuslab.com/zoraxy/mod/geodb"
|
||||
"imuslab.com/zoraxy/mod/reverseproxy"
|
||||
"imuslab.com/zoraxy/mod/statistic"
|
||||
"imuslab.com/zoraxy/mod/tlscert"
|
||||
)
|
||||
|
||||
/*
|
||||
Zoraxy Dynamic Proxy
|
||||
*/
|
||||
type RouterOption struct {
|
||||
Port int
|
||||
UseTls bool
|
||||
ForceHttpsRedirect bool
|
||||
TlsManager *tlscert.Manager
|
||||
RedirectRuleTable *redirection.RuleTable
|
||||
GeodbStore *geodb.Store
|
||||
StatisticCollector *statistic.Collector
|
||||
}
|
||||
|
||||
type Router struct {
|
||||
Option *RouterOption
|
||||
ProxyEndpoints *sync.Map
|
||||
SubdomainEndpoint *sync.Map
|
||||
Running bool
|
||||
Root *ProxyEndpoint
|
||||
mux http.Handler
|
||||
server *http.Server
|
||||
tlsListener net.Listener
|
||||
routingRules []*RoutingRule
|
||||
}
|
||||
|
||||
type ProxyEndpoint struct {
|
||||
Root string
|
||||
Domain string
|
||||
RequireTLS bool
|
||||
Proxy *dpcore.ReverseProxy `json:"-"`
|
||||
}
|
||||
|
||||
type SubdomainEndpoint struct {
|
||||
MatchingDomain string
|
||||
Domain string
|
||||
RequireTLS bool
|
||||
Proxy *reverseproxy.ReverseProxy `json:"-"`
|
||||
}
|
||||
|
||||
type ProxyHandler struct {
|
||||
Parent *Router
|
||||
}
|
||||
|
||||
func NewDynamicProxy(option RouterOption) (*Router, error) {
|
||||
proxyMap := sync.Map{}
|
||||
domainMap := sync.Map{}
|
||||
thisRouter := Router{
|
||||
Option: &option,
|
||||
ProxyEndpoints: &proxyMap,
|
||||
SubdomainEndpoint: &domainMap,
|
||||
Running: false,
|
||||
server: nil,
|
||||
routingRules: []*RoutingRule{},
|
||||
}
|
||||
|
||||
thisRouter.mux = &ProxyHandler{
|
||||
Parent: &thisRouter,
|
||||
}
|
||||
|
||||
return &thisRouter, nil
|
||||
}
|
||||
|
||||
// Update TLS setting in runtime. Will restart the proxy server
|
||||
// if it is already running in the background
|
||||
func (router *Router) UpdateTLSSetting(tlsEnabled bool) {
|
||||
router.Option.UseTls = tlsEnabled
|
||||
router.Restart()
|
||||
}
|
||||
|
||||
// Update https redirect, which will require updates
|
||||
func (router *Router) UpdateHttpToHttpsRedirectSetting(useRedirect bool) {
|
||||
router.Option.ForceHttpsRedirect = useRedirect
|
||||
router.Restart()
|
||||
}
|
||||
|
||||
// Start the dynamic routing
|
||||
func (router *Router) StartProxyService() error {
|
||||
//Create a new server object
|
||||
if router.server != nil {
|
||||
return errors.New("Reverse proxy server already running")
|
||||
}
|
||||
|
||||
if router.Root == nil {
|
||||
return errors.New("Reverse proxy router root not set")
|
||||
}
|
||||
|
||||
config := &tls.Config{
|
||||
GetCertificate: router.Option.TlsManager.GetCert,
|
||||
}
|
||||
|
||||
if router.Option.UseTls {
|
||||
//Serve with TLS mode
|
||||
ln, err := tls.Listen("tcp", ":"+strconv.Itoa(router.Option.Port), config)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
router.tlsListener = ln
|
||||
router.server = &http.Server{Addr: ":" + strconv.Itoa(router.Option.Port), Handler: router.mux}
|
||||
router.Running = true
|
||||
|
||||
if router.Option.Port == 443 && router.Option.ForceHttpsRedirect {
|
||||
//Add a 80 to 443 redirector
|
||||
httpServer := &http.Server{
|
||||
Addr: ":80",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusTemporaryRedirect)
|
||||
}),
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
log.Println("Starting HTTP-to-HTTPS redirector (port 80)")
|
||||
go func() {
|
||||
//Start another router to check if the router.server is killed. If yes, kill this server as well
|
||||
go func() {
|
||||
for router.server != nil {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
httpServer.Shutdown(ctx)
|
||||
log.Println(":80 to :433 redirection listener stopped")
|
||||
}()
|
||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Could not start server: %v\n", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
log.Println("Reverse proxy service started in the background (TLS mode)")
|
||||
go func() {
|
||||
if err := router.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Could not start server: %v\n", err)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
//Serve with non TLS mode
|
||||
router.tlsListener = nil
|
||||
router.server = &http.Server{Addr: ":" + strconv.Itoa(router.Option.Port), Handler: router.mux}
|
||||
router.Running = true
|
||||
log.Println("Reverse proxy service started in the background (Plain HTTP mode)")
|
||||
go func() {
|
||||
router.server.ListenAndServe()
|
||||
//log.Println("[DynamicProxy] " + err.Error())
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (router *Router) StopProxyService() error {
|
||||
if router.server == nil {
|
||||
return errors.New("Reverse proxy server already stopped")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := router.server.Shutdown(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if router.tlsListener != nil {
|
||||
router.tlsListener.Close()
|
||||
}
|
||||
|
||||
//Discard the server object
|
||||
router.tlsListener = nil
|
||||
router.server = nil
|
||||
router.Running = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restart the current router if it is running.
|
||||
// Startup the server if it is not running initially
|
||||
func (router *Router) Restart() error {
|
||||
//Stop the router if it is already running
|
||||
if router.Running {
|
||||
err := router.StopProxyService()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
//Start the server
|
||||
err := router.StartProxyService()
|
||||
return err
|
||||
}
|
||||
|
||||
/*
|
||||
Check if a given request is accessed via a proxied subdomain
|
||||
*/
|
||||
|
||||
func (router *Router) IsProxiedSubdomain(r *http.Request) bool {
|
||||
hostname := r.Header.Get("X-Forwarded-Host")
|
||||
if hostname == "" {
|
||||
hostname = r.Host
|
||||
}
|
||||
hostname = strings.Split(hostname, ":")[0]
|
||||
subdEndpoint := router.getSubdomainProxyEndpointFromHostname(hostname)
|
||||
return subdEndpoint != nil
|
||||
}
|
||||
|
||||
/*
|
||||
Add an URL into a custom proxy services
|
||||
*/
|
||||
func (router *Router) AddVirtualDirectoryProxyService(rootname string, domain string, requireTLS bool) error {
|
||||
if domain[len(domain)-1:] == "/" {
|
||||
domain = domain[:len(domain)-1]
|
||||
}
|
||||
|
||||
if rootname[len(rootname)-1:] == "/" {
|
||||
rootname = rootname[:len(rootname)-1]
|
||||
}
|
||||
|
||||
webProxyEndpoint := domain
|
||||
if requireTLS {
|
||||
webProxyEndpoint = "https://" + webProxyEndpoint
|
||||
} else {
|
||||
webProxyEndpoint = "http://" + webProxyEndpoint
|
||||
}
|
||||
//Create a new proxy agent for this root
|
||||
path, err := url.Parse(webProxyEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proxy := dpcore.NewDynamicProxyCore(path, rootname)
|
||||
|
||||
endpointObject := ProxyEndpoint{
|
||||
Root: rootname,
|
||||
Domain: domain,
|
||||
RequireTLS: requireTLS,
|
||||
Proxy: proxy,
|
||||
}
|
||||
|
||||
router.ProxyEndpoints.Store(rootname, &endpointObject)
|
||||
|
||||
log.Println("Adding Proxy Rule: ", rootname+" to "+domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
/*
|
||||
Remove routing from RP
|
||||
*/
|
||||
func (router *Router) RemoveProxy(ptype string, key string) error {
|
||||
//fmt.Println(ptype, key)
|
||||
if ptype == "vdir" {
|
||||
router.ProxyEndpoints.Delete(key)
|
||||
return nil
|
||||
} else if ptype == "subd" {
|
||||
router.SubdomainEndpoint.Delete(key)
|
||||
return nil
|
||||
}
|
||||
return errors.New("invalid ptype")
|
||||
}
|
||||
|
||||
/*
|
||||
Add an default router for the proxy server
|
||||
*/
|
||||
func (router *Router) SetRootProxy(proxyLocation string, requireTLS bool) error {
|
||||
if proxyLocation[len(proxyLocation)-1:] == "/" {
|
||||
proxyLocation = proxyLocation[:len(proxyLocation)-1]
|
||||
}
|
||||
|
||||
webProxyEndpoint := proxyLocation
|
||||
if requireTLS {
|
||||
webProxyEndpoint = "https://" + webProxyEndpoint
|
||||
} else {
|
||||
webProxyEndpoint = "http://" + webProxyEndpoint
|
||||
}
|
||||
//Create a new proxy agent for this root
|
||||
path, err := url.Parse(webProxyEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proxy := dpcore.NewDynamicProxyCore(path, "")
|
||||
|
||||
rootEndpoint := ProxyEndpoint{
|
||||
Root: "/",
|
||||
Domain: proxyLocation,
|
||||
RequireTLS: requireTLS,
|
||||
Proxy: proxy,
|
||||
}
|
||||
|
||||
router.Root = &rootEndpoint
|
||||
return nil
|
||||
}
|
||||
|
||||
//Helpers to export the syncmap for easier processing
|
||||
func (r *Router) GetSDProxyEndpointsAsMap() map[string]*SubdomainEndpoint {
|
||||
m := make(map[string]*SubdomainEndpoint)
|
||||
r.SubdomainEndpoint.Range(func(key, value interface{}) bool {
|
||||
k, ok := key.(string)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
v, ok := value.(*SubdomainEndpoint)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
m[k] = v
|
||||
return true
|
||||
})
|
||||
return m
|
||||
}
|
||||
|
||||
func (r *Router) GetVDProxyEndpointsAsMap() map[string]*ProxyEndpoint {
|
||||
m := make(map[string]*ProxyEndpoint)
|
||||
r.ProxyEndpoints.Range(func(key, value interface{}) bool {
|
||||
k, ok := key.(string)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
v, ok := value.(*ProxyEndpoint)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
m[k] = v
|
||||
return true
|
||||
})
|
||||
return m
|
||||
}
|
@@ -1,149 +0,0 @@
|
||||
package dynamicproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"imuslab.com/zoraxy/mod/geodb"
|
||||
"imuslab.com/zoraxy/mod/statistic"
|
||||
"imuslab.com/zoraxy/mod/websocketproxy"
|
||||
)
|
||||
|
||||
func (router *Router) getTargetProxyEndpointFromRequestURI(requestURI string) *ProxyEndpoint {
|
||||
var targetProxyEndpoint *ProxyEndpoint = nil
|
||||
router.ProxyEndpoints.Range(func(key, value interface{}) bool {
|
||||
rootname := key.(string)
|
||||
if strings.HasPrefix(requestURI, rootname) {
|
||||
thisProxyEndpoint := value.(*ProxyEndpoint)
|
||||
targetProxyEndpoint = thisProxyEndpoint
|
||||
}
|
||||
/*
|
||||
if len(requestURI) >= len(rootname) && requestURI[:len(rootname)] == rootname {
|
||||
thisProxyEndpoint := value.(*ProxyEndpoint)
|
||||
targetProxyEndpoint = thisProxyEndpoint
|
||||
}
|
||||
*/
|
||||
return true
|
||||
})
|
||||
|
||||
return targetProxyEndpoint
|
||||
}
|
||||
|
||||
func (router *Router) getSubdomainProxyEndpointFromHostname(hostname string) *SubdomainEndpoint {
|
||||
var targetSubdomainEndpoint *SubdomainEndpoint = nil
|
||||
ep, ok := router.SubdomainEndpoint.Load(hostname)
|
||||
if ok {
|
||||
targetSubdomainEndpoint = ep.(*SubdomainEndpoint)
|
||||
}
|
||||
|
||||
return targetSubdomainEndpoint
|
||||
}
|
||||
|
||||
func (router *Router) rewriteURL(rooturl string, requestURL string) string {
|
||||
if len(requestURL) > len(rooturl) {
|
||||
return requestURL[len(rooturl):]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) subdomainRequest(w http.ResponseWriter, r *http.Request, target *SubdomainEndpoint) {
|
||||
r.Header.Set("X-Forwarded-Host", r.Host)
|
||||
requestURL := r.URL.String()
|
||||
if r.Header["Upgrade"] != nil && strings.ToLower(r.Header["Upgrade"][0]) == "websocket" {
|
||||
//Handle WebSocket request. Forward the custom Upgrade header and rewrite origin
|
||||
r.Header.Set("A-Upgrade", "websocket")
|
||||
wsRedirectionEndpoint := target.Domain
|
||||
if wsRedirectionEndpoint[len(wsRedirectionEndpoint)-1:] != "/" {
|
||||
//Append / to the end of the redirection endpoint if not exists
|
||||
wsRedirectionEndpoint = wsRedirectionEndpoint + "/"
|
||||
}
|
||||
if len(requestURL) > 0 && requestURL[:1] == "/" {
|
||||
//Remove starting / from request URL if exists
|
||||
requestURL = requestURL[1:]
|
||||
}
|
||||
u, _ := url.Parse("ws://" + wsRedirectionEndpoint + requestURL)
|
||||
if target.RequireTLS {
|
||||
u, _ = url.Parse("wss://" + wsRedirectionEndpoint + requestURL)
|
||||
}
|
||||
h.logRequest(r, true, 101, "subdomain-websocket")
|
||||
wspHandler := websocketproxy.NewProxy(u)
|
||||
wspHandler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
r.Host = r.URL.Host
|
||||
err := target.Proxy.ServeHTTP(w, r)
|
||||
var dnsError *net.DNSError
|
||||
if err != nil {
|
||||
if errors.As(err, &dnsError) {
|
||||
http.ServeFile(w, r, "./web/hosterror.html")
|
||||
log.Println(err.Error())
|
||||
h.logRequest(r, false, 404, "subdomain-http")
|
||||
} else {
|
||||
http.ServeFile(w, r, "./web/rperror.html")
|
||||
log.Println(err.Error())
|
||||
h.logRequest(r, false, 521, "subdomain-http")
|
||||
}
|
||||
}
|
||||
|
||||
h.logRequest(r, true, 200, "subdomain-http")
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) proxyRequest(w http.ResponseWriter, r *http.Request, target *ProxyEndpoint) {
|
||||
rewriteURL := h.Parent.rewriteURL(target.Root, r.RequestURI)
|
||||
r.URL, _ = url.Parse(rewriteURL)
|
||||
r.Header.Set("X-Forwarded-Host", r.Host)
|
||||
if r.Header["Upgrade"] != nil && strings.ToLower(r.Header["Upgrade"][0]) == "websocket" {
|
||||
//Handle WebSocket request. Forward the custom Upgrade header and rewrite origin
|
||||
r.Header.Set("A-Upgrade", "websocket")
|
||||
wsRedirectionEndpoint := target.Domain
|
||||
if wsRedirectionEndpoint[len(wsRedirectionEndpoint)-1:] != "/" {
|
||||
wsRedirectionEndpoint = wsRedirectionEndpoint + "/"
|
||||
}
|
||||
u, _ := url.Parse("ws://" + wsRedirectionEndpoint + r.URL.String())
|
||||
if target.RequireTLS {
|
||||
u, _ = url.Parse("wss://" + wsRedirectionEndpoint + r.URL.String())
|
||||
}
|
||||
h.logRequest(r, true, 101, "vdir-websocket")
|
||||
wspHandler := websocketproxy.NewProxy(u)
|
||||
wspHandler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
r.Host = r.URL.Host
|
||||
err := target.Proxy.ServeHTTP(w, r)
|
||||
var dnsError *net.DNSError
|
||||
if err != nil {
|
||||
if errors.As(err, &dnsError) {
|
||||
http.ServeFile(w, r, "./web/hosterror.html")
|
||||
log.Println(err.Error())
|
||||
h.logRequest(r, false, 404, "vdir-http")
|
||||
} else {
|
||||
http.ServeFile(w, r, "./web/rperror.html")
|
||||
log.Println(err.Error())
|
||||
h.logRequest(r, false, 521, "vdir-http")
|
||||
}
|
||||
}
|
||||
h.logRequest(r, true, 200, "vdir-http")
|
||||
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) logRequest(r *http.Request, succ bool, statusCode int, forwardType string) {
|
||||
if h.Parent.Option.StatisticCollector != nil {
|
||||
go func() {
|
||||
requestInfo := statistic.RequestInfo{
|
||||
IpAddr: geodb.GetRequesterIP(r),
|
||||
RequestOriginalCountryISOCode: h.Parent.Option.GeodbStore.GetRequesterCountryISOCode(r),
|
||||
Succ: succ,
|
||||
StatusCode: statusCode,
|
||||
ForwardType: forwardType,
|
||||
}
|
||||
h.Parent.Option.StatisticCollector.RecordRequest(requestInfo)
|
||||
}()
|
||||
|
||||
}
|
||||
}
|
@@ -1,53 +0,0 @@
|
||||
package redirection
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
/*
|
||||
handler.go
|
||||
|
||||
This script store the handlers use for handling
|
||||
redirection request
|
||||
*/
|
||||
|
||||
//Check if a request URL is a redirectable URI
|
||||
func (t *RuleTable) IsRedirectable(r *http.Request) bool {
|
||||
requestPath := r.Host + r.URL.Path
|
||||
rr := t.MatchRedirectRule(requestPath)
|
||||
return rr != nil
|
||||
}
|
||||
|
||||
//Handle the redirect request, return after calling this function to prevent
|
||||
//multiple write to the response writer
|
||||
//Return the status code of the redirection handling
|
||||
func (t *RuleTable) HandleRedirect(w http.ResponseWriter, r *http.Request) int {
|
||||
requestPath := r.Host + r.URL.Path
|
||||
rr := t.MatchRedirectRule(requestPath)
|
||||
if rr != nil {
|
||||
redirectTarget := rr.TargetURL
|
||||
//Always pad a / at the back of the target URL
|
||||
if redirectTarget[len(redirectTarget)-1:] != "/" {
|
||||
redirectTarget += "/"
|
||||
}
|
||||
if rr.ForwardChildpath {
|
||||
//Remove the first / in the path
|
||||
redirectTarget += r.URL.Path[1:] + "?" + r.URL.RawQuery
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(redirectTarget, "http://") && !strings.HasPrefix(redirectTarget, "https://") {
|
||||
redirectTarget = "http://" + redirectTarget
|
||||
}
|
||||
|
||||
http.Redirect(w, r, redirectTarget, rr.StatusCode)
|
||||
return rr.StatusCode
|
||||
} else {
|
||||
//Invalid usage
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("500 - Internal Server Error"))
|
||||
log.Println("Target request URL do not have matching redirect rule. Check with IsRedirectable before calling HandleRedirect!")
|
||||
return 500
|
||||
}
|
||||
}
|
@@ -1,162 +0,0 @@
|
||||
package redirection
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"imuslab.com/zoraxy/mod/utils"
|
||||
)
|
||||
|
||||
type RuleTable struct {
|
||||
configPath string //The location where the redirection rules is stored
|
||||
rules sync.Map //Store the redirection rules for this reverse proxy instance
|
||||
}
|
||||
|
||||
type RedirectRules struct {
|
||||
RedirectURL string //The matching URL to redirect
|
||||
TargetURL string //The destination redirection url
|
||||
ForwardChildpath bool //Also redirect the pathname
|
||||
StatusCode int //Status Code for redirection
|
||||
}
|
||||
|
||||
func NewRuleTable(configPath string) (*RuleTable, error) {
|
||||
thisRuleTable := RuleTable{
|
||||
rules: sync.Map{},
|
||||
configPath: configPath,
|
||||
}
|
||||
//Load all the rules from the config path
|
||||
if !utils.FileExists(configPath) {
|
||||
os.MkdirAll(configPath, 0775)
|
||||
}
|
||||
|
||||
// Load all the *.json from the configPath
|
||||
files, err := filepath.Glob(filepath.Join(configPath, "*.json"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the json content into RedirectRules
|
||||
var rules []*RedirectRules
|
||||
for _, file := range files {
|
||||
b, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
thisRule := RedirectRules{}
|
||||
|
||||
err = json.Unmarshal(b, &thisRule)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, &thisRule)
|
||||
}
|
||||
|
||||
//Map the rules into the sync map
|
||||
for _, rule := range rules {
|
||||
log.Println("Redirection rule added: " + rule.RedirectURL + " -> " + rule.TargetURL)
|
||||
thisRuleTable.rules.Store(rule.RedirectURL, rule)
|
||||
}
|
||||
|
||||
return &thisRuleTable, nil
|
||||
}
|
||||
|
||||
func (t *RuleTable) AddRedirectRule(redirectURL string, destURL string, forwardPathname bool, statusCode int) error {
|
||||
// Create a new RedirectRules object with the given parameters
|
||||
newRule := &RedirectRules{
|
||||
RedirectURL: redirectURL,
|
||||
TargetURL: destURL,
|
||||
ForwardChildpath: forwardPathname,
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
|
||||
// Convert the redirectURL to a valid filename by replacing "/" with "-" and "." with "_"
|
||||
filename := strings.ReplaceAll(strings.ReplaceAll(redirectURL, "/", "-"), ".", "_") + ".json"
|
||||
|
||||
// Create the full file path by joining the t.configPath with the filename
|
||||
filepath := path.Join(t.configPath, filename)
|
||||
|
||||
// Create a new file for writing the JSON data
|
||||
file, err := os.Create(filepath)
|
||||
if err != nil {
|
||||
log.Printf("Error creating file %s: %s", filepath, err)
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Encode the RedirectRules object to JSON and write it to the file
|
||||
err = json.NewEncoder(file).Encode(newRule)
|
||||
if err != nil {
|
||||
log.Printf("Error encoding JSON to file %s: %s", filepath, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the RedirectRules object in the sync.Map
|
||||
t.rules.Store(redirectURL, newRule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *RuleTable) DeleteRedirectRule(redirectURL string) error {
|
||||
// Convert the redirectURL to a valid filename by replacing "/" with "-" and "." with "_"
|
||||
filename := strings.ReplaceAll(strings.ReplaceAll(redirectURL, "/", "-"), ".", "_") + ".json"
|
||||
|
||||
// Create the full file path by joining the t.configPath with the filename
|
||||
filepath := path.Join(t.configPath, filename)
|
||||
|
||||
// Check if the file exists
|
||||
if _, err := os.Stat(filepath); os.IsNotExist(err) {
|
||||
return nil // File doesn't exist, nothing to delete
|
||||
}
|
||||
|
||||
// Delete the file
|
||||
if err := os.Remove(filepath); err != nil {
|
||||
log.Printf("Error deleting file %s: %s", filepath, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete the key-value pair from the sync.Map
|
||||
t.rules.Delete(redirectURL)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get a list of all the redirection rules
|
||||
func (t *RuleTable) GetAllRedirectRules() []*RedirectRules {
|
||||
rules := []*RedirectRules{}
|
||||
t.rules.Range(func(key, value interface{}) bool {
|
||||
r, ok := value.(*RedirectRules)
|
||||
if ok {
|
||||
rules = append(rules, r)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return rules
|
||||
}
|
||||
|
||||
// Check if a given request URL matched any of the redirection rule
|
||||
func (t *RuleTable) MatchRedirectRule(requestedURL string) *RedirectRules {
|
||||
// Iterate through all the keys in the rules map
|
||||
var targetRedirectionRule *RedirectRules = nil
|
||||
var maxMatch int = 0
|
||||
|
||||
t.rules.Range(func(key interface{}, value interface{}) bool {
|
||||
// Check if the requested URL starts with the key as a prefix
|
||||
if strings.HasPrefix(requestedURL, key.(string)) {
|
||||
// This request URL matched the domain
|
||||
if len(key.(string)) > maxMatch {
|
||||
maxMatch = len(key.(string))
|
||||
targetRedirectionRule = value.(*RedirectRules)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return targetRedirectionRule
|
||||
}
|
@@ -1,85 +0,0 @@
|
||||
package dynamicproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
/*
|
||||
Special.go
|
||||
|
||||
This script handle special routing rules
|
||||
by external modules
|
||||
*/
|
||||
|
||||
type RoutingRule struct {
|
||||
ID string
|
||||
MatchRule func(r *http.Request) bool
|
||||
RoutingHandler http.Handler
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
//Router functions
|
||||
//Check if a routing rule exists given its id
|
||||
func (router *Router) GetRoutingRuleById(rrid string) (*RoutingRule, error) {
|
||||
for _, rr := range router.routingRules {
|
||||
if rr.ID == rrid {
|
||||
return rr, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("routing rule with given id not found")
|
||||
}
|
||||
|
||||
//Add a routing rule to the router
|
||||
func (router *Router) AddRoutingRules(rr *RoutingRule) error {
|
||||
_, err := router.GetRoutingRuleById(rr.ID)
|
||||
if err != nil {
|
||||
//routing rule with given id already exists
|
||||
return err
|
||||
}
|
||||
|
||||
router.routingRules = append(router.routingRules, rr)
|
||||
return nil
|
||||
}
|
||||
|
||||
//Remove a routing rule from the router
|
||||
func (router *Router) RemoveRoutingRule(rrid string) {
|
||||
newRoutingRules := []*RoutingRule{}
|
||||
for _, rr := range router.routingRules {
|
||||
if rr.ID != rrid {
|
||||
newRoutingRules = append(newRoutingRules, rr)
|
||||
}
|
||||
}
|
||||
|
||||
router.routingRules = newRoutingRules
|
||||
}
|
||||
|
||||
//Get all routing rules
|
||||
func (router *Router) GetAllRoutingRules() []*RoutingRule {
|
||||
return router.routingRules
|
||||
}
|
||||
|
||||
//Get the matching routing rule that describe this request.
|
||||
//Return nil if no routing rule is match
|
||||
func (router *Router) GetMatchingRoutingRule(r *http.Request) *RoutingRule {
|
||||
for _, thisRr := range router.routingRules {
|
||||
if thisRr.IsMatch(r) {
|
||||
return thisRr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//Routing Rule functions
|
||||
//Check if a request object match the
|
||||
func (e *RoutingRule) IsMatch(r *http.Request) bool {
|
||||
if !e.Enabled {
|
||||
return false
|
||||
}
|
||||
return e.MatchRule(r)
|
||||
}
|
||||
|
||||
func (e *RoutingRule) Route(w http.ResponseWriter, r *http.Request) {
|
||||
e.RoutingHandler.ServeHTTP(w, r)
|
||||
}
|
@@ -1,44 +0,0 @@
|
||||
package dynamicproxy
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/url"
|
||||
|
||||
"imuslab.com/zoraxy/mod/reverseproxy"
|
||||
)
|
||||
|
||||
/*
|
||||
Add an URL intoa custom subdomain service
|
||||
|
||||
*/
|
||||
|
||||
func (router *Router) AddSubdomainRoutingService(hostnameWithSubdomain string, domain string, requireTLS bool) error {
|
||||
if domain[len(domain)-1:] == "/" {
|
||||
domain = domain[:len(domain)-1]
|
||||
}
|
||||
|
||||
webProxyEndpoint := domain
|
||||
if requireTLS {
|
||||
webProxyEndpoint = "https://" + webProxyEndpoint
|
||||
} else {
|
||||
webProxyEndpoint = "http://" + webProxyEndpoint
|
||||
}
|
||||
|
||||
//Create a new proxy agent for this root
|
||||
path, err := url.Parse(webProxyEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proxy := reverseproxy.NewReverseProxy(path)
|
||||
|
||||
router.SubdomainEndpoint.Store(hostnameWithSubdomain, &SubdomainEndpoint{
|
||||
MatchingDomain: hostnameWithSubdomain,
|
||||
Domain: domain,
|
||||
RequireTLS: requireTLS,
|
||||
Proxy: proxy,
|
||||
})
|
||||
|
||||
log.Println("Adding Subdomain Rule: ", hostnameWithSubdomain+" to "+domain)
|
||||
return nil
|
||||
}
|
@@ -1,244 +0,0 @@
|
||||
package geodb
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"imuslab.com/zoraxy/mod/database"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
Enabled bool
|
||||
geodb *geoip2.Reader
|
||||
sysdb *database.Database
|
||||
}
|
||||
|
||||
type CountryInfo struct {
|
||||
CountryIsoCode string
|
||||
ContinetCode string
|
||||
}
|
||||
|
||||
func NewGeoDb(sysdb *database.Database, dbfile string) (*Store, error) {
|
||||
db, err := geoip2.Open(dbfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = sysdb.NewTable("blacklist-cn")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = sysdb.NewTable("blacklist-ip")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = sysdb.NewTable("blacklist")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blacklistEnabled := false
|
||||
sysdb.Read("blacklist", "enabled", &blacklistEnabled)
|
||||
|
||||
return &Store{
|
||||
Enabled: blacklistEnabled,
|
||||
geodb: db,
|
||||
sysdb: sysdb,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) ToggleBlacklist(enabled bool) {
|
||||
s.sysdb.Write("blacklist", "enabled", enabled)
|
||||
s.Enabled = enabled
|
||||
}
|
||||
|
||||
func (s *Store) ResolveCountryCodeFromIP(ipstring string) (*CountryInfo, error) {
|
||||
// If you are using strings that may be invalid, check that ip is not nil
|
||||
ip := net.ParseIP(ipstring)
|
||||
record, err := s.geodb.City(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CountryInfo{
|
||||
record.Country.IsoCode,
|
||||
record.Continent.Code,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() {
|
||||
s.geodb.Close()
|
||||
}
|
||||
|
||||
func (s *Store) AddCountryCodeToBlackList(countryCode string) {
|
||||
countryCode = strings.ToLower(countryCode)
|
||||
s.sysdb.Write("blacklist-cn", countryCode, true)
|
||||
}
|
||||
|
||||
func (s *Store) RemoveCountryCodeFromBlackList(countryCode string) {
|
||||
countryCode = strings.ToLower(countryCode)
|
||||
s.sysdb.Delete("blacklist-cn", countryCode)
|
||||
}
|
||||
|
||||
func (s *Store) IsCountryCodeBlacklisted(countryCode string) bool {
|
||||
countryCode = strings.ToLower(countryCode)
|
||||
var isBlacklisted bool = false
|
||||
s.sysdb.Read("blacklist-cn", countryCode, &isBlacklisted)
|
||||
return isBlacklisted
|
||||
}
|
||||
|
||||
func (s *Store) GetAllBlacklistedCountryCode() []string {
|
||||
bannedCountryCodes := []string{}
|
||||
entries, err := s.sysdb.ListTable("blacklist-cn")
|
||||
if err != nil {
|
||||
return bannedCountryCodes
|
||||
}
|
||||
for _, keypairs := range entries {
|
||||
ip := string(keypairs[0])
|
||||
bannedCountryCodes = append(bannedCountryCodes, ip)
|
||||
}
|
||||
|
||||
return bannedCountryCodes
|
||||
}
|
||||
|
||||
func (s *Store) AddIPToBlackList(ipAddr string) {
|
||||
s.sysdb.Write("blacklist-ip", ipAddr, true)
|
||||
}
|
||||
|
||||
func (s *Store) RemoveIPFromBlackList(ipAddr string) {
|
||||
s.sysdb.Delete("blacklist-ip", ipAddr)
|
||||
}
|
||||
|
||||
func (s *Store) IsIPBlacklisted(ipAddr string) bool {
|
||||
var isBlacklisted bool = false
|
||||
s.sysdb.Read("blacklist-ip", ipAddr, &isBlacklisted)
|
||||
if isBlacklisted {
|
||||
return true
|
||||
}
|
||||
|
||||
//Check for IP wildcard and CIRD rules
|
||||
AllBlacklistedIps := s.GetAllBlacklistedIp()
|
||||
for _, blacklistRule := range AllBlacklistedIps {
|
||||
wildcardMatch := MatchIpWildcard(ipAddr, blacklistRule)
|
||||
if wildcardMatch {
|
||||
return true
|
||||
}
|
||||
|
||||
cidrMatch := MatchIpCIDR(ipAddr, blacklistRule)
|
||||
if cidrMatch {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Store) GetAllBlacklistedIp() []string {
|
||||
bannedIps := []string{}
|
||||
entries, err := s.sysdb.ListTable("blacklist-ip")
|
||||
if err != nil {
|
||||
return bannedIps
|
||||
}
|
||||
|
||||
for _, keypairs := range entries {
|
||||
ip := string(keypairs[0])
|
||||
bannedIps = append(bannedIps, ip)
|
||||
}
|
||||
|
||||
return bannedIps
|
||||
}
|
||||
|
||||
//Check if a IP address is blacklisted, in either country or IP blacklist
|
||||
func (s *Store) IsBlacklisted(ipAddr string) bool {
|
||||
if !s.Enabled {
|
||||
//Blacklist not enabled. Always return false
|
||||
return false
|
||||
}
|
||||
|
||||
if ipAddr == "" {
|
||||
//Unable to get the target IP address
|
||||
return false
|
||||
}
|
||||
|
||||
countryCode, err := s.ResolveCountryCodeFromIP(ipAddr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.IsCountryCodeBlacklisted(countryCode.CountryIsoCode) {
|
||||
return true
|
||||
}
|
||||
|
||||
if s.IsIPBlacklisted(ipAddr) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Store) GetRequesterCountryISOCode(r *http.Request) string {
|
||||
ipAddr := GetRequesterIP(r)
|
||||
if ipAddr == "" {
|
||||
return ""
|
||||
}
|
||||
countryCode, err := s.ResolveCountryCodeFromIP(ipAddr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return countryCode.CountryIsoCode
|
||||
}
|
||||
|
||||
//Utilities function
|
||||
func GetRequesterIP(r *http.Request) string {
|
||||
ip := r.Header.Get("X-Forwarded-For")
|
||||
if ip == "" {
|
||||
ip = r.Header.Get("X-Real-IP")
|
||||
if ip == "" {
|
||||
ip = strings.Split(r.RemoteAddr, ":")[0]
|
||||
}
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
//Match the IP address with a wildcard string
|
||||
func MatchIpWildcard(ipAddress, wildcard string) bool {
|
||||
// Split IP address and wildcard into octets
|
||||
ipOctets := strings.Split(ipAddress, ".")
|
||||
wildcardOctets := strings.Split(wildcard, ".")
|
||||
|
||||
// Check that both have 4 octets
|
||||
if len(ipOctets) != 4 || len(wildcardOctets) != 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check each octet to see if it matches the wildcard or is an exact match
|
||||
for i := 0; i < 4; i++ {
|
||||
if wildcardOctets[i] == "*" {
|
||||
continue
|
||||
}
|
||||
if ipOctets[i] != wildcardOctets[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
//Match ip address with CIDR
|
||||
func MatchIpCIDR(ip string, cidr string) bool {
|
||||
// parse the CIDR string
|
||||
_, cidrnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// parse the IP address
|
||||
ipAddr := net.ParseIP(ip)
|
||||
|
||||
// check if the IP address is within the CIDR range
|
||||
return cidrnet.Contains(ipAddr)
|
||||
}
|
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2018-present tobychui
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@@ -1,68 +0,0 @@
|
||||
# Introduction
|
||||
A minimalist proxy library for go, inspired by `net/http/httputil` and add support for HTTPS using HTTP Tunnel
|
||||
|
||||
Support cancels an in-flight request by closing it's connection
|
||||
|
||||
# Installation
|
||||
```sh
|
||||
go get github.com/cssivision/reverseproxy
|
||||
```
|
||||
|
||||
# Usage
|
||||
|
||||
## A simple proxy
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"github.com/cssivision/reverseproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
path, err := url.Parse("https://github.com")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return
|
||||
}
|
||||
proxy := reverseproxy.NewReverseProxy(path)
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
## Use as a proxy server
|
||||
|
||||
To use proxy server, you should set browser to use the proxy server as an HTTP proxy.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"github.com/cssivision/reverseproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
path, err := url.Parse("http://" + r.Host)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return
|
||||
}
|
||||
|
||||
proxy := reverseproxy.NewReverseProxy(path)
|
||||
proxy.ServeHTTP(w, r)
|
||||
|
||||
// Specific for HTTP and HTTPS
|
||||
// if r.Method == "CONNECT" {
|
||||
// proxy.ProxyHTTPS(w, r)
|
||||
// } else {
|
||||
// proxy.ProxyHTTP(w, r)
|
||||
// }
|
||||
}))
|
||||
}
|
||||
```
|
@@ -1,405 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var onExitFlushLoop func()
|
||||
|
||||
const (
|
||||
defaultTimeout = time.Minute * 5
|
||||
)
|
||||
|
||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||
// sends it to another server, proxying the response back to the
|
||||
// client, support http, also support https tunnel using http.hijacker
|
||||
type ReverseProxy struct {
|
||||
// Set the timeout of the proxy server, default is 5 minutes
|
||||
Timeout time.Duration
|
||||
|
||||
// Director must be a function which modifies
|
||||
// the request into a new request to be sent
|
||||
// using Transport. Its response is then copied
|
||||
// back to the original client unmodified.
|
||||
// Director must not access the provided Request
|
||||
// after returning.
|
||||
Director func(*http.Request)
|
||||
|
||||
// The transport used to perform proxy requests.
|
||||
// default is http.DefaultTransport.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// FlushInterval specifies the flush interval
|
||||
// to flush to the client while copying the
|
||||
// response body. If zero, no periodic flushing is done.
|
||||
FlushInterval time.Duration
|
||||
|
||||
// ErrorLog specifies an optional logger for errors
|
||||
// that occur when attempting to proxy the request.
|
||||
// If nil, logging goes to os.Stderr via the log package's
|
||||
// standard logger.
|
||||
ErrorLog *log.Logger
|
||||
|
||||
// ModifyResponse is an optional function that
|
||||
// modifies the Response from the backend.
|
||||
// If it returns an error, the proxy returns a StatusBadGateway error.
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
||||
Verbal bool
|
||||
}
|
||||
|
||||
type requestCanceler interface {
|
||||
CancelRequest(req *http.Request)
|
||||
}
|
||||
|
||||
// NewReverseProxy returns a new ReverseProxy that routes
|
||||
// URLs to the scheme, host, and base path provided in target. If the
|
||||
// target's path is "/base" and the incoming request was for "/dir",
|
||||
// the target request will be for /base/dir. if the target's query is a=10
|
||||
// and the incoming request's query is b=100, the target's request's query
|
||||
// will be a=10&b=100.
|
||||
// NewReverseProxy does not rewrite the Host header.
|
||||
// To rewrite Host headers, use ReverseProxy directly with a custom
|
||||
// Director policy.
|
||||
func NewReverseProxy(target *url.URL) *ReverseProxy {
|
||||
targetQuery := target.RawQuery
|
||||
director := func(req *http.Request) {
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
||||
|
||||
// If Host is empty, the Request.Write method uses
|
||||
// the value of URL.Host.
|
||||
// force use URL.Host
|
||||
req.Host = req.URL.Host
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
||||
}
|
||||
|
||||
if _, ok := req.Header["User-Agent"]; !ok {
|
||||
req.Header.Set("User-Agent", "")
|
||||
}
|
||||
}
|
||||
|
||||
return &ReverseProxy{Director: director, Verbal: false}
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||
var hopHeaders = []string{
|
||||
//"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
//"Upgrade",
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||
if p.FlushInterval != 0 {
|
||||
if wf, ok := dst.(writeFlusher); ok {
|
||||
mlw := &maxLatencyWriter{
|
||||
dst: wf,
|
||||
latency: p.FlushInterval,
|
||||
done: make(chan bool),
|
||||
}
|
||||
|
||||
go mlw.flushLoop()
|
||||
defer mlw.stop()
|
||||
dst = mlw
|
||||
}
|
||||
}
|
||||
|
||||
io.Copy(dst, src)
|
||||
}
|
||||
|
||||
type writeFlusher interface {
|
||||
io.Writer
|
||||
http.Flusher
|
||||
}
|
||||
|
||||
type maxLatencyWriter struct {
|
||||
dst writeFlusher
|
||||
latency time.Duration
|
||||
mu sync.Mutex
|
||||
done chan bool
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) Write(b []byte) (int, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.dst.Write(b)
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) flushLoop() {
|
||||
t := time.NewTicker(m.latency)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.done:
|
||||
if onExitFlushLoop != nil {
|
||||
onExitFlushLoop()
|
||||
}
|
||||
return
|
||||
case <-t.C:
|
||||
m.mu.Lock()
|
||||
m.dst.Flush()
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) stop() {
|
||||
m.done <- true
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) logf(format string, args ...interface{}) {
|
||||
if p.ErrorLog != nil {
|
||||
p.ErrorLog.Printf(format, args...)
|
||||
} else {
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func removeHeaders(header http.Header) {
|
||||
// Remove hop-by-hop headers listed in the "Connection" header.
|
||||
if c := header.Get("Connection"); c != "" {
|
||||
for _, f := range strings.Split(c, ",") {
|
||||
if f = strings.TrimSpace(f); f != "" {
|
||||
header.Del(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove hop-by-hop headers
|
||||
for _, h := range hopHeaders {
|
||||
if header.Get(h) != "" {
|
||||
header.Del(h)
|
||||
}
|
||||
}
|
||||
|
||||
if header.Get("A-Upgrade") != "" {
|
||||
header.Set("Upgrade", header.Get("A-Upgrade"))
|
||||
header.Del("A-Upgrade")
|
||||
}
|
||||
}
|
||||
|
||||
func addXForwardedForHeader(req *http.Request) {
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
// If we aren't the first proxy retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
if prior, ok := req.Header["X-Forwarded-For"]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
req.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ProxyHTTP(rw http.ResponseWriter, req *http.Request) error {
|
||||
transport := p.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
outreq := new(http.Request)
|
||||
// Shallow copies of maps, like header
|
||||
*outreq = *req
|
||||
|
||||
if cn, ok := rw.(http.CloseNotifier); ok {
|
||||
if requestCanceler, ok := transport.(requestCanceler); ok {
|
||||
// After the Handler has returned, there is no guarantee
|
||||
// that the channel receives a value, so to make sure
|
||||
reqDone := make(chan struct{})
|
||||
defer close(reqDone)
|
||||
clientGone := cn.CloseNotify()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-clientGone:
|
||||
requestCanceler.CancelRequest(outreq)
|
||||
case <-reqDone:
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
p.Director(outreq)
|
||||
outreq.Close = false
|
||||
|
||||
// We may modify the header (shallow copied above), so we only copy it.
|
||||
outreq.Header = make(http.Header)
|
||||
copyHeader(outreq.Header, req.Header)
|
||||
|
||||
// Remove hop-by-hop headers listed in the "Connection" header, Remove hop-by-hop headers.
|
||||
removeHeaders(outreq.Header)
|
||||
|
||||
// Add X-Forwarded-For Header.
|
||||
addXForwardedForHeader(outreq)
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
|
||||
//rw.WriteHeader(http.StatusBadGateway)
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove hop-by-hop headers listed in the "Connection" header of the response, Remove hop-by-hop headers.
|
||||
removeHeaders(res.Header)
|
||||
|
||||
if p.ModifyResponse != nil {
|
||||
if err := p.ModifyResponse(res); err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
//rw.WriteHeader(http.StatusBadGateway)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Copy header from response to client.
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
// The "Trailer" header isn't included in the Transport's response, Build it up from Trailer.
|
||||
if len(res.Trailer) > 0 {
|
||||
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||
for k := range res.Trailer {
|
||||
trailerKeys = append(trailerKeys, k)
|
||||
}
|
||||
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||
}
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
if len(res.Trailer) > 0 {
|
||||
// Force chunking if we saw a response trailer.
|
||||
// This prevents net/http from calculating the length for short
|
||||
// bodies and adding a Content-Length.
|
||||
if fl, ok := rw.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
p.copyResponse(rw, res.Body)
|
||||
// close now, instead of defer, to populate res.Trailer
|
||||
res.Body.Close()
|
||||
copyHeader(rw.Header(), res.Trailer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ProxyHTTPS(rw http.ResponseWriter, req *http.Request) error {
|
||||
hij, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
p.logf("http server does not support hijacker")
|
||||
return errors.New("http server does not support hijacker")
|
||||
}
|
||||
|
||||
clientConn, _, err := hij.Hijack()
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
proxyConn, err := net.Dial("tcp", req.URL.Host)
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// The returned net.Conn may have read or write deadlines
|
||||
// already set, depending on the configuration of the
|
||||
// Server, to set or clear those deadlines as needed
|
||||
// we set timeout to 5 minutes
|
||||
deadline := time.Now()
|
||||
if p.Timeout == 0 {
|
||||
deadline = deadline.Add(time.Minute * 5)
|
||||
} else {
|
||||
deadline = deadline.Add(p.Timeout)
|
||||
}
|
||||
|
||||
err = clientConn.SetDeadline(deadline)
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
err = proxyConn.SetDeadline(deadline)
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = clientConn.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
|
||||
if err != nil {
|
||||
if p.Verbal {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
io.Copy(clientConn, proxyConn)
|
||||
clientConn.Close()
|
||||
proxyConn.Close()
|
||||
}()
|
||||
|
||||
io.Copy(proxyConn, clientConn)
|
||||
proxyConn.Close()
|
||||
clientConn.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) error {
|
||||
if req.Method == "CONNECT" {
|
||||
err := p.ProxyHTTPS(rw, req)
|
||||
return err
|
||||
} else {
|
||||
err := p.ProxyHTTP(rw, req)
|
||||
return err
|
||||
}
|
||||
}
|
@@ -1,67 +0,0 @@
|
||||
package statistic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"imuslab.com/zoraxy/mod/utils"
|
||||
)
|
||||
|
||||
/*
|
||||
Handler.go
|
||||
|
||||
This script handles incoming request for loading the statistic of the day
|
||||
|
||||
*/
|
||||
|
||||
func (c *Collector) HandleTodayStatLoad(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
fast, err := utils.GetPara(r, "fast")
|
||||
if err != nil {
|
||||
fast = "false"
|
||||
}
|
||||
d := c.DailySummary
|
||||
if fast == "true" {
|
||||
//Only return the counter
|
||||
exported := DailySummaryExport{
|
||||
TotalRequest: d.TotalRequest,
|
||||
ErrorRequest: d.ErrorRequest,
|
||||
ValidRequest: d.ValidRequest,
|
||||
}
|
||||
js, _ := json.Marshal(exported)
|
||||
utils.SendJSONResponse(w, string(js))
|
||||
} else {
|
||||
//Return everything
|
||||
exported := DailySummaryExport{
|
||||
TotalRequest: d.TotalRequest,
|
||||
ErrorRequest: d.ErrorRequest,
|
||||
ValidRequest: d.ValidRequest,
|
||||
ForwardTypes: make(map[string]int),
|
||||
RequestOrigin: make(map[string]int),
|
||||
RequestClientIp: make(map[string]int),
|
||||
}
|
||||
|
||||
// Export ForwardTypes sync.Map
|
||||
d.ForwardTypes.Range(func(key, value interface{}) bool {
|
||||
exported.ForwardTypes[key.(string)] = value.(int)
|
||||
return true
|
||||
})
|
||||
|
||||
// Export RequestOrigin sync.Map
|
||||
d.RequestOrigin.Range(func(key, value interface{}) bool {
|
||||
exported.RequestOrigin[key.(string)] = value.(int)
|
||||
return true
|
||||
})
|
||||
|
||||
// Export RequestClientIp sync.Map
|
||||
d.RequestClientIp.Range(func(key, value interface{}) bool {
|
||||
exported.RequestClientIp[key.(string)] = value.(int)
|
||||
return true
|
||||
})
|
||||
|
||||
js, _ := json.Marshal(exported)
|
||||
|
||||
utils.SendJSONResponse(w, string(js))
|
||||
}
|
||||
|
||||
}
|
@@ -1,186 +0,0 @@
|
||||
package statistic
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"imuslab.com/zoraxy/mod/database"
|
||||
)
|
||||
|
||||
/*
|
||||
Statistic Package
|
||||
|
||||
This packet is designed to collection information
|
||||
and store them for future analysis
|
||||
*/
|
||||
|
||||
//Faststat, a interval summary for all collected data and avoid
|
||||
//looping through every data everytime a overview is needed
|
||||
type DailySummary struct {
|
||||
TotalRequest int64 //Total request of the day
|
||||
ErrorRequest int64 //Invalid request of the day, including error or not found
|
||||
ValidRequest int64 //Valid request of the day
|
||||
//Type counters
|
||||
ForwardTypes *sync.Map //Map that hold the forward types
|
||||
RequestOrigin *sync.Map //Map that hold [country ISO code]: visitor counter
|
||||
RequestClientIp *sync.Map //Map that hold all unique request IPs
|
||||
}
|
||||
|
||||
type RequestInfo struct {
|
||||
IpAddr string
|
||||
RequestOriginalCountryISOCode string
|
||||
Succ bool
|
||||
StatusCode int
|
||||
ForwardType string
|
||||
}
|
||||
|
||||
type CollectorOption struct {
|
||||
Database *database.Database
|
||||
}
|
||||
|
||||
type Collector struct {
|
||||
rtdataStopChan chan bool
|
||||
DailySummary *DailySummary
|
||||
Option *CollectorOption
|
||||
}
|
||||
|
||||
func NewStatisticCollector(option CollectorOption) (*Collector, error) {
|
||||
option.Database.NewTable("stats")
|
||||
|
||||
//Create the collector object
|
||||
thisCollector := Collector{
|
||||
DailySummary: newDailySummary(),
|
||||
Option: &option,
|
||||
}
|
||||
|
||||
//Load the stat if exists for today
|
||||
//This will exists if the program was forcefully restarted
|
||||
year, month, day := time.Now().Date()
|
||||
summary := thisCollector.LoadSummaryOfDay(year, month, day)
|
||||
if summary != nil {
|
||||
thisCollector.DailySummary = summary
|
||||
}
|
||||
|
||||
//Schedule the realtime statistic clearing at midnight everyday
|
||||
rtstatStopChan := thisCollector.ScheduleResetRealtimeStats()
|
||||
thisCollector.rtdataStopChan = rtstatStopChan
|
||||
|
||||
return &thisCollector, nil
|
||||
}
|
||||
|
||||
//Write the current in-memory summary to database file
|
||||
func (c *Collector) SaveSummaryOfDay() {
|
||||
//When it is called in 0:00am, make sure it is stored as yesterday key
|
||||
t := time.Now().Add(-30 * time.Second)
|
||||
summaryKey := t.Format("02_01_2006")
|
||||
saveData := DailySummaryToExport(*c.DailySummary)
|
||||
c.Option.Database.Write("stats", summaryKey, saveData)
|
||||
}
|
||||
|
||||
//Load the summary of a day given
|
||||
func (c *Collector) LoadSummaryOfDay(year int, month time.Month, day int) *DailySummary {
|
||||
date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.Local)
|
||||
summaryKey := date.Format("02_01_2006")
|
||||
targetSummaryExport := DailySummaryExport{}
|
||||
c.Option.Database.Read("stats", summaryKey, &targetSummaryExport)
|
||||
targetSummary := DailySummaryExportToSummary(targetSummaryExport)
|
||||
return &targetSummary
|
||||
}
|
||||
|
||||
//This function gives the current slot in the 288- 5 minutes interval of the day
|
||||
func (c *Collector) GetCurrentRealtimeStatIntervalId() int {
|
||||
now := time.Now()
|
||||
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local).Unix()
|
||||
secondsSinceStartOfDay := now.Unix() - startOfDay
|
||||
interval := secondsSinceStartOfDay / (5 * 60)
|
||||
return int(interval)
|
||||
}
|
||||
|
||||
func (c *Collector) Close() {
|
||||
//Stop the ticker
|
||||
c.rtdataStopChan <- true
|
||||
|
||||
//Write the buffered data into database
|
||||
c.SaveSummaryOfDay()
|
||||
|
||||
}
|
||||
|
||||
//Main function to record all the inbound traffics
|
||||
//Note that this function run in go routine and might have concurrent R/W issue
|
||||
//Please make sure there is no racing paramters in this function
|
||||
func (c *Collector) RecordRequest(ri RequestInfo) {
|
||||
go func() {
|
||||
c.DailySummary.TotalRequest++
|
||||
if ri.Succ {
|
||||
c.DailySummary.ValidRequest++
|
||||
} else {
|
||||
c.DailySummary.ErrorRequest++
|
||||
}
|
||||
|
||||
//Store the request info into correct types of maps
|
||||
ft, ok := c.DailySummary.ForwardTypes.Load(ri.ForwardType)
|
||||
if !ok {
|
||||
c.DailySummary.ForwardTypes.Store(ri.ForwardType, 1)
|
||||
} else {
|
||||
c.DailySummary.ForwardTypes.Store(ri.ForwardType, ft.(int)+1)
|
||||
}
|
||||
|
||||
originISO := strings.ToLower(ri.RequestOriginalCountryISOCode)
|
||||
fo, ok := c.DailySummary.RequestOrigin.Load(originISO)
|
||||
if !ok {
|
||||
c.DailySummary.RequestOrigin.Store(originISO, 1)
|
||||
} else {
|
||||
c.DailySummary.RequestOrigin.Store(originISO, fo.(int)+1)
|
||||
}
|
||||
|
||||
fi, ok := c.DailySummary.RequestClientIp.Load(ri.IpAddr)
|
||||
if !ok {
|
||||
c.DailySummary.RequestClientIp.Store(ri.IpAddr, 1)
|
||||
} else {
|
||||
c.DailySummary.RequestClientIp.Store(ri.IpAddr, fi.(int)+1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
//nightly task
|
||||
func (c *Collector) ScheduleResetRealtimeStats() chan bool {
|
||||
doneCh := make(chan bool)
|
||||
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
|
||||
for {
|
||||
// calculate duration until next midnight
|
||||
now := time.Now()
|
||||
|
||||
// Get midnight of the next day in the local time zone
|
||||
midnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
|
||||
|
||||
// Calculate the duration until midnight
|
||||
duration := midnight.Sub(now)
|
||||
select {
|
||||
case <-time.After(duration):
|
||||
// store daily summary to database and reset summary
|
||||
c.SaveSummaryOfDay()
|
||||
c.DailySummary = newDailySummary()
|
||||
case <-doneCh:
|
||||
// stop the routine
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return doneCh
|
||||
}
|
||||
|
||||
func newDailySummary() *DailySummary {
|
||||
return &DailySummary{
|
||||
TotalRequest: 0,
|
||||
ErrorRequest: 0,
|
||||
ValidRequest: 0,
|
||||
ForwardTypes: &sync.Map{},
|
||||
RequestOrigin: &sync.Map{},
|
||||
RequestClientIp: &sync.Map{},
|
||||
}
|
||||
}
|
@@ -1,66 +0,0 @@
|
||||
package statistic
|
||||
|
||||
import "sync"
|
||||
|
||||
type DailySummaryExport struct {
|
||||
TotalRequest int64 //Total request of the day
|
||||
ErrorRequest int64 //Invalid request of the day, including error or not found
|
||||
ValidRequest int64 //Valid request of the day
|
||||
|
||||
ForwardTypes map[string]int
|
||||
RequestOrigin map[string]int
|
||||
RequestClientIp map[string]int
|
||||
}
|
||||
|
||||
func DailySummaryToExport(summary DailySummary) DailySummaryExport {
|
||||
export := DailySummaryExport{
|
||||
TotalRequest: summary.TotalRequest,
|
||||
ErrorRequest: summary.ErrorRequest,
|
||||
ValidRequest: summary.ValidRequest,
|
||||
ForwardTypes: make(map[string]int),
|
||||
RequestOrigin: make(map[string]int),
|
||||
RequestClientIp: make(map[string]int),
|
||||
}
|
||||
|
||||
summary.ForwardTypes.Range(func(key, value interface{}) bool {
|
||||
export.ForwardTypes[key.(string)] = value.(int)
|
||||
return true
|
||||
})
|
||||
|
||||
summary.RequestOrigin.Range(func(key, value interface{}) bool {
|
||||
export.RequestOrigin[key.(string)] = value.(int)
|
||||
return true
|
||||
})
|
||||
|
||||
summary.RequestClientIp.Range(func(key, value interface{}) bool {
|
||||
export.RequestClientIp[key.(string)] = value.(int)
|
||||
return true
|
||||
})
|
||||
|
||||
return export
|
||||
}
|
||||
|
||||
func DailySummaryExportToSummary(export DailySummaryExport) DailySummary {
|
||||
summary := DailySummary{
|
||||
TotalRequest: export.TotalRequest,
|
||||
ErrorRequest: export.ErrorRequest,
|
||||
ValidRequest: export.ValidRequest,
|
||||
ForwardTypes: &sync.Map{},
|
||||
RequestOrigin: &sync.Map{},
|
||||
RequestClientIp: &sync.Map{},
|
||||
}
|
||||
|
||||
for k, v := range export.ForwardTypes {
|
||||
summary.ForwardTypes.Store(k, v)
|
||||
}
|
||||
|
||||
for k, v := range export.RequestOrigin {
|
||||
summary.RequestOrigin.Store(k, v)
|
||||
}
|
||||
|
||||
for k, v := range export.RequestClientIp {
|
||||
summary.RequestClientIp.Store(k, v)
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
@@ -1,60 +0,0 @@
|
||||
package tlscert
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//This remove the certificates in the list where either the
|
||||
//public key or the private key is missing
|
||||
func getCertPairs(certFiles []string) []string {
|
||||
crtMap := make(map[string]bool)
|
||||
keyMap := make(map[string]bool)
|
||||
|
||||
for _, filename := range certFiles {
|
||||
if filepath.Ext(filename) == ".crt" {
|
||||
crtMap[strings.TrimSuffix(filename, ".crt")] = true
|
||||
} else if filepath.Ext(filename) == ".key" {
|
||||
keyMap[strings.TrimSuffix(filename, ".key")] = true
|
||||
}
|
||||
}
|
||||
|
||||
var result []string
|
||||
for domain := range crtMap {
|
||||
if keyMap[domain] {
|
||||
result = append(result, domain)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
//Get the cloest subdomain certificate from a list of domains
|
||||
func matchClosestDomainCertificate(subdomain string, domains []string) string {
|
||||
var matchingDomain string = ""
|
||||
maxLength := 0
|
||||
|
||||
for _, domain := range domains {
|
||||
if strings.HasSuffix(subdomain, "."+domain) && len(domain) > maxLength {
|
||||
matchingDomain = domain
|
||||
maxLength = len(domain)
|
||||
}
|
||||
}
|
||||
|
||||
return matchingDomain
|
||||
}
|
||||
|
||||
//Check if a requesting domain is a subdomain of a given domain
|
||||
func isSubdomain(subdomain, domain string) bool {
|
||||
subdomainParts := strings.Split(subdomain, ".")
|
||||
domainParts := strings.Split(domain, ".")
|
||||
if len(subdomainParts) < len(domainParts) {
|
||||
return false
|
||||
}
|
||||
for i := range domainParts {
|
||||
if subdomainParts[len(subdomainParts)-1-i] != domainParts[len(domainParts)-1-i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
@@ -1,172 +0,0 @@
|
||||
package tlscert
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"imuslab.com/zoraxy/mod/utils"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
CertStore string
|
||||
verbal bool
|
||||
}
|
||||
|
||||
func NewManager(certStore string) (*Manager, error) {
|
||||
if !utils.FileExists(certStore) {
|
||||
os.MkdirAll(certStore, 0775)
|
||||
}
|
||||
|
||||
thisManager := Manager{
|
||||
CertStore: certStore,
|
||||
verbal: true,
|
||||
}
|
||||
|
||||
return &thisManager, nil
|
||||
}
|
||||
|
||||
func (m *Manager) ListCertDomains() ([]string, error) {
|
||||
filenames, err := m.ListCerts()
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
//Remove certificates where there are missing public key or private key
|
||||
filenames = getCertPairs(filenames)
|
||||
|
||||
return filenames, nil
|
||||
}
|
||||
|
||||
func (m *Manager) ListCerts() ([]string, error) {
|
||||
certs, err := ioutil.ReadDir(m.CertStore)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
filenames := make([]string, 0, len(certs))
|
||||
for _, cert := range certs {
|
||||
if !cert.IsDir() {
|
||||
filenames = append(filenames, cert.Name())
|
||||
}
|
||||
}
|
||||
|
||||
return filenames, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetCert(helloInfo *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
//Check if the domain corrisponding cert exists
|
||||
pubKey := "./system/localhost.crt"
|
||||
priKey := "./system/localhost.key"
|
||||
|
||||
if utils.FileExists(filepath.Join(m.CertStore, helloInfo.ServerName+".crt")) && utils.FileExists(filepath.Join(m.CertStore, helloInfo.ServerName+".key")) {
|
||||
pubKey = filepath.Join(m.CertStore, helloInfo.ServerName+".crt")
|
||||
priKey = filepath.Join(m.CertStore, helloInfo.ServerName+".key")
|
||||
|
||||
} else {
|
||||
domainCerts, _ := m.ListCertDomains()
|
||||
cloestDomainCert := matchClosestDomainCertificate(helloInfo.ServerName, domainCerts)
|
||||
if cloestDomainCert != "" {
|
||||
//There is a matching parent domain for this subdomain. Use this instead.
|
||||
pubKey = filepath.Join(m.CertStore, cloestDomainCert+".crt")
|
||||
priKey = filepath.Join(m.CertStore, cloestDomainCert+".key")
|
||||
} else if m.DefaultCertExists() {
|
||||
//Use default.crt and default.key
|
||||
pubKey = filepath.Join(m.CertStore, "default.crt")
|
||||
priKey = filepath.Join(m.CertStore, "default.key")
|
||||
if m.verbal {
|
||||
log.Println("No matching certificate found. Serving with default")
|
||||
}
|
||||
} else {
|
||||
if m.verbal {
|
||||
log.Println("Matching certificate not found. Serving with default. Requesting server name: ", helloInfo.ServerName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Load the cert and serve it
|
||||
cer, err := tls.LoadX509KeyPair(pubKey, priKey)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &cer, nil
|
||||
}
|
||||
|
||||
//Check if both the default cert public key and private key exists
|
||||
func (m *Manager) DefaultCertExists() bool {
|
||||
return utils.FileExists(filepath.Join(m.CertStore, "default.crt")) && utils.FileExists(filepath.Join(m.CertStore, "default.key"))
|
||||
}
|
||||
|
||||
//Check if the default cert exists returning seperate results for pubkey and prikey
|
||||
func (m *Manager) DefaultCertExistsSep() (bool, bool) {
|
||||
return utils.FileExists(filepath.Join(m.CertStore, "default.crt")), utils.FileExists(filepath.Join(m.CertStore, "default.key"))
|
||||
}
|
||||
|
||||
//Delete the cert if exists
|
||||
func (m *Manager) RemoveCert(domain string) error {
|
||||
pubKey := filepath.Join(m.CertStore, domain+".crt")
|
||||
priKey := filepath.Join(m.CertStore, domain+".key")
|
||||
if utils.FileExists(pubKey) {
|
||||
err := os.Remove(pubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if utils.FileExists(priKey) {
|
||||
err := os.Remove(priKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//Check if the given file is a valid TLS file
|
||||
func IsValidTLSFile(file io.Reader) bool {
|
||||
// Read the contents of the uploaded file
|
||||
contents, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
// Handle the error
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse the contents of the file as a PEM-encoded certificate or key
|
||||
block, _ := pem.Decode(contents)
|
||||
if block == nil {
|
||||
// The file is not a valid PEM-encoded certificate or key
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse the certificate or key
|
||||
if strings.Contains(block.Type, "CERTIFICATE") {
|
||||
// The file contains a certificate
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
// Handle the error
|
||||
return false
|
||||
}
|
||||
// Check if the certificate is a valid TLS/SSL certificate
|
||||
return cert.IsCA == false && cert.KeyUsage&x509.KeyUsageDigitalSignature != 0 && cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0
|
||||
} else if strings.Contains(block.Type, "PRIVATE KEY") {
|
||||
// The file contains a private key
|
||||
_, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
// Handle the error
|
||||
return false
|
||||
}
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
||||
}
|
@@ -1,127 +0,0 @@
|
||||
package upnp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitlab.com/NebulousLabs/go-upnp"
|
||||
)
|
||||
|
||||
/*
|
||||
uPNP Module
|
||||
|
||||
This module handles uPNP Connections to the gateway router and create a port forward entry
|
||||
for the host system at the given port (set with -port paramter)
|
||||
*/
|
||||
|
||||
type UPnPClient struct {
|
||||
Connection *upnp.IGD //UPnP conenction object
|
||||
ExternalIP string //Storage of external IP address
|
||||
RequiredPorts []int //All the required ports will be recored
|
||||
PolicyNames sync.Map //Name for the required port nubmer
|
||||
}
|
||||
|
||||
func NewUPNPClient() (*UPnPClient, error) {
|
||||
//Create uPNP forwarding in the NAT router
|
||||
log.Println("Discovering UPnP router in Local Area Network...")
|
||||
d, err := upnp.Discover()
|
||||
if err != nil {
|
||||
return &UPnPClient{}, err
|
||||
}
|
||||
|
||||
// discover external IP
|
||||
ip, err := d.ExternalIP()
|
||||
if err != nil {
|
||||
return &UPnPClient{}, err
|
||||
}
|
||||
|
||||
//Create the final obejcts
|
||||
newUPnPObject := &UPnPClient{
|
||||
Connection: d,
|
||||
ExternalIP: ip,
|
||||
RequiredPorts: []int{},
|
||||
}
|
||||
|
||||
return newUPnPObject, nil
|
||||
}
|
||||
|
||||
func (u *UPnPClient) ForwardPort(portNumber int, ruleName string) error {
|
||||
log.Println("UPnP forwarding new port: ", portNumber, "for "+ruleName+" service")
|
||||
|
||||
//Check if port already forwarded
|
||||
_, ok := u.PolicyNames.Load(portNumber)
|
||||
if ok {
|
||||
//Port already forward. Ignore this request
|
||||
return errors.New("Port already forwarded")
|
||||
}
|
||||
|
||||
// forward a port
|
||||
err := u.Connection.Forward(uint16(portNumber), ruleName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.RequiredPorts = append(u.RequiredPorts, portNumber)
|
||||
u.PolicyNames.Store(portNumber, ruleName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UPnPClient) ClosePort(portNumber int) error {
|
||||
//Check if port is opened
|
||||
portOpened := false
|
||||
newRequiredPort := []int{}
|
||||
for _, thisPort := range u.RequiredPorts {
|
||||
if thisPort != portNumber {
|
||||
newRequiredPort = append(newRequiredPort, thisPort)
|
||||
} else {
|
||||
portOpened = true
|
||||
}
|
||||
}
|
||||
|
||||
if portOpened {
|
||||
//Update the port list
|
||||
u.RequiredPorts = newRequiredPort
|
||||
|
||||
// Close the port
|
||||
log.Println("Closing UPnP Port Forward: ", portNumber)
|
||||
err := u.Connection.Clear(uint16(portNumber))
|
||||
|
||||
//Delete the name registry
|
||||
u.PolicyNames.Delete(portNumber)
|
||||
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Renew forward rules, prevent router lease time from flushing the Upnp config
|
||||
func (u *UPnPClient) RenewForwardRules() {
|
||||
portsToRenew := u.RequiredPorts
|
||||
for _, thisPort := range portsToRenew {
|
||||
ruleName, ok := u.PolicyNames.Load(thisPort)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
u.ClosePort(thisPort)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
u.ForwardPort(thisPort, ruleName.(string))
|
||||
}
|
||||
log.Println("UPnP Port Forward rule renew completed")
|
||||
}
|
||||
|
||||
func (u *UPnPClient) Close() {
|
||||
//Shutdown the default UPnP Object
|
||||
if u != nil {
|
||||
for _, portNumber := range u.RequiredPorts {
|
||||
err := u.Connection.Clear(uint16(portNumber))
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,227 +0,0 @@
|
||||
package uptime
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"imuslab.com/zoraxy/mod/utils"
|
||||
)
|
||||
|
||||
type Record struct {
|
||||
Timestamp int64
|
||||
ID string
|
||||
Name string
|
||||
URL string
|
||||
Protocol string
|
||||
Online bool
|
||||
StatusCode int
|
||||
Latency int64
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
ID string
|
||||
Name string
|
||||
URL string
|
||||
Protocol string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Targets []*Target
|
||||
Interval int
|
||||
MaxRecordsStore int
|
||||
}
|
||||
|
||||
type Monitor struct {
|
||||
Config *Config
|
||||
OnlineStatusLog map[string][]*Record
|
||||
}
|
||||
|
||||
// Default configs
|
||||
var exampleTarget = Target{
|
||||
ID: "example",
|
||||
Name: "Example",
|
||||
URL: "example.com",
|
||||
Protocol: "https",
|
||||
}
|
||||
|
||||
//Create a new uptime monitor
|
||||
func NewUptimeMonitor(config *Config) (*Monitor, error) {
|
||||
//Create new monitor object
|
||||
thisMonitor := Monitor{
|
||||
Config: config,
|
||||
OnlineStatusLog: map[string][]*Record{},
|
||||
}
|
||||
//Start the endpoint listener
|
||||
ticker := time.NewTicker(time.Duration(config.Interval) * time.Second)
|
||||
done := make(chan bool)
|
||||
|
||||
//Start the uptime check once first before entering loop
|
||||
thisMonitor.ExecuteUptimeCheck()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case t := <-ticker.C:
|
||||
log.Println("Uptime updated - ", t.Unix())
|
||||
thisMonitor.ExecuteUptimeCheck()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return &thisMonitor, nil
|
||||
}
|
||||
|
||||
func (m *Monitor) ExecuteUptimeCheck() {
|
||||
for _, target := range m.Config.Targets {
|
||||
//For each target to check online, do the following
|
||||
var thisRecord Record
|
||||
if target.Protocol == "http" || target.Protocol == "https" {
|
||||
online, laterncy, statusCode := getWebsiteStatusWithLatency(target.URL)
|
||||
thisRecord = Record{
|
||||
Timestamp: time.Now().Unix(),
|
||||
ID: target.ID,
|
||||
Name: target.Name,
|
||||
URL: target.URL,
|
||||
Protocol: target.Protocol,
|
||||
Online: online,
|
||||
StatusCode: statusCode,
|
||||
Latency: laterncy,
|
||||
}
|
||||
|
||||
//fmt.Println(thisRecord)
|
||||
|
||||
} else {
|
||||
log.Println("Unknown protocol: " + target.Protocol + ". Skipping")
|
||||
continue
|
||||
}
|
||||
|
||||
thisRecords, ok := m.OnlineStatusLog[target.ID]
|
||||
if !ok {
|
||||
//First record. Create the array
|
||||
m.OnlineStatusLog[target.ID] = []*Record{&thisRecord}
|
||||
} else {
|
||||
//Append to the previous record
|
||||
thisRecords = append(thisRecords, &thisRecord)
|
||||
|
||||
//Check if the record is longer than the logged record. If yes, clear out the old records
|
||||
if len(thisRecords) > m.Config.MaxRecordsStore {
|
||||
thisRecords = thisRecords[1:]
|
||||
}
|
||||
|
||||
m.OnlineStatusLog[target.ID] = thisRecords
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: Write results to db
|
||||
}
|
||||
|
||||
func (m *Monitor) AddTargetToMonitor(target *Target) {
|
||||
// Add target to Config
|
||||
m.Config.Targets = append(m.Config.Targets, target)
|
||||
|
||||
// Add target to OnlineStatusLog
|
||||
m.OnlineStatusLog[target.ID] = []*Record{}
|
||||
}
|
||||
|
||||
func (m *Monitor) RemoveTargetFromMonitor(targetId string) {
|
||||
// Remove target from Config
|
||||
for i, target := range m.Config.Targets {
|
||||
if target.ID == targetId {
|
||||
m.Config.Targets = append(m.Config.Targets[:i], m.Config.Targets[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Remove target from OnlineStatusLog
|
||||
delete(m.OnlineStatusLog, targetId)
|
||||
}
|
||||
|
||||
//Scan the config target. If a target exists in m.OnlineStatusLog no longer
|
||||
//exists in m.Monitor.Config.Targets, it remove it from the log as well.
|
||||
func (m *Monitor) CleanRecords() {
|
||||
// Create a set of IDs for all targets in the config
|
||||
targetIDs := make(map[string]bool)
|
||||
for _, target := range m.Config.Targets {
|
||||
targetIDs[target.ID] = true
|
||||
}
|
||||
|
||||
// Iterate over all log entries and remove any that have a target ID that
|
||||
// is not in the set of current target IDs
|
||||
newStatusLog := m.OnlineStatusLog
|
||||
for id, _ := range m.OnlineStatusLog {
|
||||
_, idExistsInTargets := targetIDs[id]
|
||||
if !idExistsInTargets {
|
||||
delete(newStatusLog, id)
|
||||
}
|
||||
}
|
||||
|
||||
m.OnlineStatusLog = newStatusLog
|
||||
}
|
||||
|
||||
/*
|
||||
Web Interface Handler
|
||||
*/
|
||||
|
||||
func (m *Monitor) HandleUptimeLogRead(w http.ResponseWriter, r *http.Request) {
|
||||
id, _ := utils.GetPara(r, "id")
|
||||
if id == "" {
|
||||
js, _ := json.Marshal(m.OnlineStatusLog)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(js)
|
||||
} else {
|
||||
//Check if that id exists
|
||||
log, ok := m.OnlineStatusLog[id]
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
js, _ := json.MarshalIndent(log, "", " ")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(js)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
Utilities
|
||||
*/
|
||||
|
||||
// Get website stauts with latency given URL, return is conn succ and its latency and status code
|
||||
func getWebsiteStatusWithLatency(url string) (bool, int64, int) {
|
||||
start := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
statusCode, err := getWebsiteStatus(url)
|
||||
end := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
if err != nil {
|
||||
log.Println(err.Error())
|
||||
return false, 0, 0
|
||||
} else {
|
||||
diff := end - start
|
||||
succ := false
|
||||
if statusCode >= 200 && statusCode < 300 {
|
||||
//OK
|
||||
succ = true
|
||||
} else if statusCode >= 300 && statusCode < 400 {
|
||||
//Redirection code
|
||||
succ = true
|
||||
} else {
|
||||
succ = false
|
||||
}
|
||||
|
||||
return succ, diff, statusCode
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func getWebsiteStatus(url string) (int, error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
status_code := resp.StatusCode
|
||||
return status_code, nil
|
||||
}
|
@@ -1,16 +0,0 @@
|
||||
package utils
|
||||
|
||||
import "strconv"
|
||||
|
||||
func StringToInt64(number string) (int64, error) {
|
||||
i, err := strconv.ParseInt(number, 10, 64)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func Int64ToString(number int64) string {
|
||||
convedNumber := strconv.FormatInt(number, 10)
|
||||
return convedNumber
|
||||
}
|
@@ -1,19 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
/*
|
||||
Web Template Generator
|
||||
|
||||
This is the main system core module that perform function similar to what PHP did.
|
||||
To replace part of the content of any file, use {{paramter}} to replace it.
|
||||
|
||||
|
||||
*/
|
||||
|
||||
func SendHTMLResponse(w http.ResponseWriter, msg string) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Write([]byte(msg))
|
||||
}
|
@@ -1,175 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*
|
||||
Common
|
||||
|
||||
Some commonly used functions in ArozOS
|
||||
|
||||
*/
|
||||
|
||||
// Response related
|
||||
func SendTextResponse(w http.ResponseWriter, msg string) {
|
||||
w.Write([]byte(msg))
|
||||
}
|
||||
|
||||
// Send JSON response, with an extra json header
|
||||
func SendJSONResponse(w http.ResponseWriter, json string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(json))
|
||||
}
|
||||
|
||||
func SendErrorResponse(w http.ResponseWriter, errMsg string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte("{\"error\":\"" + errMsg + "\"}"))
|
||||
}
|
||||
|
||||
func SendOK(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte("\"OK\""))
|
||||
}
|
||||
|
||||
/*
|
||||
The paramter move function (mv)
|
||||
|
||||
You can find similar things in the PHP version of ArOZ Online Beta. You need to pass in
|
||||
r (HTTP Request Object)
|
||||
getParamter (string, aka $_GET['This string])
|
||||
|
||||
Will return
|
||||
Paramter string (if any)
|
||||
Error (if error)
|
||||
|
||||
*/
|
||||
/*
|
||||
func Mv(r *http.Request, getParamter string, postMode bool) (string, error) {
|
||||
if postMode == false {
|
||||
//Access the paramter via GET
|
||||
keys, ok := r.URL.Query()[getParamter]
|
||||
|
||||
if !ok || len(keys[0]) < 1 {
|
||||
//log.Println("Url Param " + getParamter +" is missing")
|
||||
return "", errors.New("GET paramter " + getParamter + " not found or it is empty")
|
||||
}
|
||||
|
||||
// Query()["key"] will return an array of items,
|
||||
// we only want the single item.
|
||||
key := keys[0]
|
||||
return string(key), nil
|
||||
} else {
|
||||
//Access the parameter via POST
|
||||
r.ParseForm()
|
||||
x := r.Form.Get(getParamter)
|
||||
if len(x) == 0 || x == "" {
|
||||
return "", errors.New("POST paramter " + getParamter + " not found or it is empty")
|
||||
}
|
||||
return string(x), nil
|
||||
}
|
||||
|
||||
}
|
||||
*/
|
||||
|
||||
// Get GET parameter
|
||||
func GetPara(r *http.Request, key string) (string, error) {
|
||||
keys, ok := r.URL.Query()[key]
|
||||
if !ok || len(keys[0]) < 1 {
|
||||
return "", errors.New("invalid " + key + " given")
|
||||
} else {
|
||||
return keys[0], nil
|
||||
}
|
||||
}
|
||||
|
||||
// Get POST paramter
|
||||
func PostPara(r *http.Request, key string) (string, error) {
|
||||
r.ParseForm()
|
||||
x := r.Form.Get(key)
|
||||
if x == "" {
|
||||
return "", errors.New("invalid " + key + " given")
|
||||
} else {
|
||||
return x, nil
|
||||
}
|
||||
}
|
||||
|
||||
func FileExists(filename string) bool {
|
||||
_, err := os.Stat(filename)
|
||||
if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func IsDir(path string) bool {
|
||||
if FileExists(path) == false {
|
||||
return false
|
||||
}
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return false
|
||||
}
|
||||
switch mode := fi.Mode(); {
|
||||
case mode.IsDir():
|
||||
return true
|
||||
case mode.IsRegular():
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TimeToString(targetTime time.Time) string {
|
||||
return targetTime.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
func LoadImageAsBase64(filepath string) (string, error) {
|
||||
if !FileExists(filepath) {
|
||||
return "", errors.New("File not exists")
|
||||
}
|
||||
f, _ := os.Open(filepath)
|
||||
reader := bufio.NewReader(f)
|
||||
content, _ := io.ReadAll(reader)
|
||||
encoded := base64.StdEncoding.EncodeToString(content)
|
||||
return string(encoded), nil
|
||||
}
|
||||
|
||||
// Use for redirections
|
||||
func ConstructRelativePathFromRequestURL(requestURI string, redirectionLocation string) string {
|
||||
if strings.Count(requestURI, "/") == 1 {
|
||||
//Already root level
|
||||
return redirectionLocation
|
||||
}
|
||||
for i := 0; i < strings.Count(requestURI, "/")-1; i++ {
|
||||
redirectionLocation = "../" + redirectionLocation
|
||||
}
|
||||
|
||||
return redirectionLocation
|
||||
}
|
||||
|
||||
// Check if given string in a given slice
|
||||
func StringInArray(arr []string, str string) bool {
|
||||
for _, a := range arr {
|
||||
if a == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func StringInArrayIgnoreCase(arr []string, str string) bool {
|
||||
smallArray := []string{}
|
||||
for _, item := range arr {
|
||||
smallArray = append(smallArray, strings.ToLower(item))
|
||||
}
|
||||
|
||||
return StringInArray(smallArray, strings.ToLower(str))
|
||||
}
|
@@ -1,20 +0,0 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Koding, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@@ -1,54 +0,0 @@
|
||||
# WebsocketProxy [](https://godoc.org/github.com/koding/websocketproxy) [](https://travis-ci.org/koding/websocketproxy)
|
||||
|
||||
WebsocketProxy is an http.Handler interface build on top of
|
||||
[gorilla/websocket](https://github.com/gorilla/websocket) that you can plug
|
||||
into your existing Go webserver to provide WebSocket reverse proxy.
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
go get github.com/koding/websocketproxy
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
Below is a simple server that proxies to the given backend URL
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/koding/websocketproxy"
|
||||
)
|
||||
|
||||
var (
|
||||
flagBackend = flag.String("backend", "", "Backend URL for proxying")
|
||||
)
|
||||
|
||||
func main() {
|
||||
u, err := url.Parse(*flagBackend)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
err = http.ListenAndServe(":80", websocketproxy.NewProxy(u))
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Save it as `proxy.go` and run as:
|
||||
|
||||
```bash
|
||||
go run proxy.go -backend ws://example.com:3000
|
||||
```
|
||||
|
||||
Now all incoming WebSocket requests coming to this server will be proxied to
|
||||
`ws://example.com:3000`
|
||||
|
||||
|
@@ -1,239 +0,0 @@
|
||||
// Package websocketproxy is a reverse proxy for WebSocket connections.
|
||||
package websocketproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultUpgrader specifies the parameters for upgrading an HTTP
|
||||
// connection to a WebSocket connection.
|
||||
DefaultUpgrader = &websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
// DefaultDialer is a dialer with all fields set to the default zero values.
|
||||
DefaultDialer = websocket.DefaultDialer
|
||||
)
|
||||
|
||||
// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket
|
||||
// connection and proxies it to another server.
|
||||
type WebsocketProxy struct {
|
||||
// Director, if non-nil, is a function that may copy additional request
|
||||
// headers from the incoming WebSocket connection into the output headers
|
||||
// which will be forwarded to another server.
|
||||
Director func(incoming *http.Request, out http.Header)
|
||||
|
||||
// Backend returns the backend URL which the proxy uses to reverse proxy
|
||||
// the incoming WebSocket connection. Request is the initial incoming and
|
||||
// unmodified request.
|
||||
Backend func(*http.Request) *url.URL
|
||||
|
||||
// Upgrader specifies the parameters for upgrading a incoming HTTP
|
||||
// connection to a WebSocket connection. If nil, DefaultUpgrader is used.
|
||||
Upgrader *websocket.Upgrader
|
||||
|
||||
// Dialer contains options for connecting to the backend WebSocket server.
|
||||
// If nil, DefaultDialer is used.
|
||||
Dialer *websocket.Dialer
|
||||
|
||||
Verbal bool
|
||||
}
|
||||
|
||||
// ProxyHandler returns a new http.Handler interface that reverse proxies the
|
||||
// request to the given target.
|
||||
func ProxyHandler(target *url.URL) http.Handler { return NewProxy(target) }
|
||||
|
||||
// NewProxy returns a new Websocket reverse proxy that rewrites the
|
||||
// URL's to the scheme, host and base path provider in target.
|
||||
func NewProxy(target *url.URL) *WebsocketProxy {
|
||||
backend := func(r *http.Request) *url.URL {
|
||||
// Shallow copy
|
||||
u := *target
|
||||
u.Fragment = r.URL.Fragment
|
||||
u.Path = r.URL.Path
|
||||
u.RawQuery = r.URL.RawQuery
|
||||
return &u
|
||||
}
|
||||
return &WebsocketProxy{Backend: backend, Verbal: false}
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler that proxies WebSocket connections.
|
||||
func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if w.Backend == nil {
|
||||
log.Println("websocketproxy: backend function is not defined")
|
||||
http.Error(rw, "internal server error (code: 1)", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
backendURL := w.Backend(req)
|
||||
if backendURL == nil {
|
||||
log.Println("websocketproxy: backend URL is nil")
|
||||
http.Error(rw, "internal server error (code: 2)", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
dialer := w.Dialer
|
||||
if w.Dialer == nil {
|
||||
dialer = DefaultDialer
|
||||
}
|
||||
|
||||
// Pass headers from the incoming request to the dialer to forward them to
|
||||
// the final destinations.
|
||||
requestHeader := http.Header{}
|
||||
if origin := req.Header.Get("Origin"); origin != "" {
|
||||
requestHeader.Add("Origin", origin)
|
||||
}
|
||||
for _, prot := range req.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] {
|
||||
requestHeader.Add("Sec-WebSocket-Protocol", prot)
|
||||
}
|
||||
for _, cookie := range req.Header[http.CanonicalHeaderKey("Cookie")] {
|
||||
requestHeader.Add("Cookie", cookie)
|
||||
}
|
||||
if req.Host != "" {
|
||||
requestHeader.Set("Host", req.Host)
|
||||
}
|
||||
|
||||
// Pass X-Forwarded-For headers too, code below is a part of
|
||||
// httputil.ReverseProxy. See http://en.wikipedia.org/wiki/X-Forwarded-For
|
||||
// for more information
|
||||
// TODO: use RFC7239 http://tools.ietf.org/html/rfc7239
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
// If we aren't the first proxy retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
if prior, ok := req.Header["X-Forwarded-For"]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
requestHeader.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
|
||||
// Set the originating protocol of the incoming HTTP request. The SSL might
|
||||
// be terminated on our site and because we doing proxy adding this would
|
||||
// be helpful for applications on the backend.
|
||||
requestHeader.Set("X-Forwarded-Proto", "http")
|
||||
if req.TLS != nil {
|
||||
requestHeader.Set("X-Forwarded-Proto", "https")
|
||||
}
|
||||
|
||||
// Enable the director to copy any additional headers it desires for
|
||||
// forwarding to the remote server.
|
||||
if w.Director != nil {
|
||||
w.Director(req, requestHeader)
|
||||
}
|
||||
|
||||
// Connect to the backend URL, also pass the headers we get from the requst
|
||||
// together with the Forwarded headers we prepared above.
|
||||
// TODO: support multiplexing on the same backend connection instead of
|
||||
// opening a new TCP connection time for each request. This should be
|
||||
// optional:
|
||||
// http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-01
|
||||
connBackend, resp, err := dialer.Dial(backendURL.String(), requestHeader)
|
||||
if err != nil {
|
||||
log.Printf("websocketproxy: couldn't dial to remote backend url %s", err)
|
||||
if resp != nil {
|
||||
// If the WebSocket handshake fails, ErrBadHandshake is returned
|
||||
// along with a non-nil *http.Response so that callers can handle
|
||||
// redirects, authentication, etcetera.
|
||||
if err := copyResponse(rw, resp); err != nil {
|
||||
log.Printf("websocketproxy: couldn't write response after failed remote backend handshake: %s", err)
|
||||
}
|
||||
} else {
|
||||
http.Error(rw, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer connBackend.Close()
|
||||
|
||||
upgrader := w.Upgrader
|
||||
if w.Upgrader == nil {
|
||||
upgrader = DefaultUpgrader
|
||||
}
|
||||
|
||||
// Only pass those headers to the upgrader.
|
||||
upgradeHeader := http.Header{}
|
||||
if hdr := resp.Header.Get("Sec-Websocket-Protocol"); hdr != "" {
|
||||
upgradeHeader.Set("Sec-Websocket-Protocol", hdr)
|
||||
}
|
||||
if hdr := resp.Header.Get("Set-Cookie"); hdr != "" {
|
||||
upgradeHeader.Set("Set-Cookie", hdr)
|
||||
}
|
||||
|
||||
// Now upgrade the existing incoming request to a WebSocket connection.
|
||||
// Also pass the header that we gathered from the Dial handshake.
|
||||
connPub, err := upgrader.Upgrade(rw, req, upgradeHeader)
|
||||
if err != nil {
|
||||
log.Printf("websocketproxy: couldn't upgrade %s", err)
|
||||
return
|
||||
}
|
||||
defer connPub.Close()
|
||||
|
||||
errClient := make(chan error, 1)
|
||||
errBackend := make(chan error, 1)
|
||||
replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) {
|
||||
for {
|
||||
msgType, msg, err := src.ReadMessage()
|
||||
if err != nil {
|
||||
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
|
||||
if e, ok := err.(*websocket.CloseError); ok {
|
||||
if e.Code != websocket.CloseNoStatusReceived {
|
||||
m = websocket.FormatCloseMessage(e.Code, e.Text)
|
||||
}
|
||||
}
|
||||
errc <- err
|
||||
dst.WriteMessage(websocket.CloseMessage, m)
|
||||
break
|
||||
}
|
||||
err = dst.WriteMessage(msgType, msg)
|
||||
if err != nil {
|
||||
errc <- err
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go replicateWebsocketConn(connPub, connBackend, errClient)
|
||||
go replicateWebsocketConn(connBackend, connPub, errBackend)
|
||||
|
||||
var message string
|
||||
select {
|
||||
case err = <-errClient:
|
||||
message = "websocketproxy: Error when copying from backend to client: %v"
|
||||
case err = <-errBackend:
|
||||
message = "websocketproxy: Error when copying from client to backend: %v"
|
||||
|
||||
}
|
||||
if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
|
||||
if w.Verbal {
|
||||
//Only print message on verbal mode
|
||||
log.Printf(message, err)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyResponse(rw http.ResponseWriter, resp *http.Response) error {
|
||||
copyHeader(rw.Header(), resp.Header)
|
||||
rw.WriteHeader(resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
|
||||
_, err := io.Copy(rw, resp.Body)
|
||||
return err
|
||||
}
|
@@ -1,130 +0,0 @@
|
||||
package websocketproxy
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
serverURL = "ws://127.0.0.1:7777"
|
||||
backendURL = "ws://127.0.0.1:8888"
|
||||
)
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
// websocket proxy
|
||||
supportedSubProtocols := []string{"test-protocol"}
|
||||
upgrader := &websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
Subprotocols: supportedSubProtocols,
|
||||
}
|
||||
|
||||
u, _ := url.Parse(backendURL)
|
||||
proxy := NewProxy(u)
|
||||
proxy.Upgrader = upgrader
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/proxy", proxy)
|
||||
go func() {
|
||||
if err := http.ListenAndServe(":7777", mux); err != nil {
|
||||
t.Fatal("ListenAndServe: ", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
// backend echo server
|
||||
go func() {
|
||||
mux2 := http.NewServeMux()
|
||||
mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Don't upgrade if original host header isn't preserved
|
||||
if r.Host != "127.0.0.1:7777" {
|
||||
log.Printf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", r.Host)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
messageType, p, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = conn.WriteMessage(messageType, p); err != nil {
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
err := http.ListenAndServe(":8888", mux2)
|
||||
if err != nil {
|
||||
t.Fatal("ListenAndServe: ", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
// let's us define two subprotocols, only one is supported by the server
|
||||
clientSubProtocols := []string{"test-protocol", "test-notsupported"}
|
||||
h := http.Header{}
|
||||
for _, subprot := range clientSubProtocols {
|
||||
h.Add("Sec-WebSocket-Protocol", subprot)
|
||||
}
|
||||
|
||||
// frontend server, dial now our proxy, which will reverse proxy our
|
||||
// message to the backend websocket server.
|
||||
conn, resp, err := websocket.DefaultDialer.Dial(serverURL+"/proxy", h)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check if the server really accepted only the first one
|
||||
in := func(desired string) bool {
|
||||
for _, prot := range resp.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] {
|
||||
if desired == prot {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if !in("test-protocol") {
|
||||
t.Error("test-protocol should be available")
|
||||
}
|
||||
|
||||
if in("test-notsupported") {
|
||||
t.Error("test-notsupported should be not recevied from the server.")
|
||||
}
|
||||
|
||||
// now write a message and send it to the backend server (which goes trough
|
||||
// proxy..)
|
||||
msg := "hello kite"
|
||||
err = conn.WriteMessage(websocket.TextMessage, []byte(msg))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
messageType, p, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if messageType != websocket.TextMessage {
|
||||
t.Error("incoming message type is not Text")
|
||||
}
|
||||
|
||||
if msg != string(p) {
|
||||
t.Errorf("expecting: %s, got: %s", msg, string(p))
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user