Removed alpha prototype source

This commit is contained in:
Toby Chui
2023-05-04 20:42:35 +08:00
parent 2c586aee32
commit a1d779a0ce
275 changed files with 0 additions and 150920 deletions

View File

@@ -1,76 +0,0 @@
package aroz
import (
"encoding/json"
"flag"
"fmt"
"net/http"
"net/url"
"os"
)
//To be used with arozos system
type ArozHandler struct {
Port string
restfulEndpoint string
}
//Information required for registering this subservice to arozos
type ServiceInfo struct {
Name string //Name of this module. e.g. "Audio"
Desc string //Description for this module
Group string //Group of the module, e.g. "system" / "media" etc
IconPath string //Module icon image path e.g. "Audio/img/function_icon.png"
Version string //Version of the module. Format: [0-9]*.[0-9][0-9].[0-9]
StartDir string //Default starting dir, e.g. "Audio/index.html"
SupportFW bool //Support floatWindow. If yes, floatWindow dir will be loaded
LaunchFWDir string //This link will be launched instead of 'StartDir' if fw mode
SupportEmb bool //Support embedded mode
LaunchEmb string //This link will be launched instead of StartDir / Fw if a file is opened with this module
InitFWSize []int //Floatwindow init size. [0] => Width, [1] => Height
InitEmbSize []int //Embedded mode init size. [0] => Width, [1] => Height
SupportedExt []string //Supported File Extensions. e.g. ".mp3", ".flac", ".wav"
}
//This function will request the required flag from the startup paramters and parse it to the need of the arozos.
func HandleFlagParse(info ServiceInfo) *ArozHandler {
var infoRequestMode = flag.Bool("info", false, "Show information about this program in JSON")
var port = flag.String("port", ":8000", "Management web interface listening port")
var restful = flag.String("rpt", "", "Reserved")
//Parse the flags
flag.Parse()
if *infoRequestMode {
//Information request mode
jsonString, _ := json.MarshalIndent(info, "", " ")
fmt.Println(string(jsonString))
os.Exit(0)
}
return &ArozHandler{
Port: *port,
restfulEndpoint: *restful,
}
}
//Get the username and resources access token from the request, return username, token
func (a *ArozHandler) GetUserInfoFromRequest(w http.ResponseWriter, r *http.Request) (string, string) {
username := r.Header.Get("aouser")
token := r.Header.Get("aotoken")
return username, token
}
func (a *ArozHandler) IsUsingExternalPermissionManager() bool {
return !(a.restfulEndpoint == "")
}
//Request gateway interface for advance permission sandbox control
func (a *ArozHandler) RequestGatewayInterface(token string, script string) (*http.Response, error) {
resp, err := http.PostForm(a.restfulEndpoint,
url.Values{"token": {token}, "script": {script}})
if err != nil {
// handle error
return nil, err
}
return resp, nil
}

Binary file not shown.

View File

@@ -1,478 +0,0 @@
package auth
/*
author: tobychui
*/
import (
"crypto/rand"
"crypto/sha512"
"errors"
"net/http"
"net/mail"
"strings"
"encoding/hex"
"log"
"github.com/gorilla/sessions"
db "imuslab.com/zoraxy/mod/database"
"imuslab.com/zoraxy/mod/utils"
)
type AuthAgent struct {
//Session related
SessionName string
SessionStore *sessions.CookieStore
Database *db.Database
LoginRedirectionHandler func(http.ResponseWriter, *http.Request)
}
type AuthEndpoints struct {
Login string
Logout string
Register string
CheckLoggedIn string
Autologin string
}
//Constructor
func NewAuthenticationAgent(sessionName string, key []byte, sysdb *db.Database, allowReg bool, loginRedirectionHandler func(http.ResponseWriter, *http.Request)) *AuthAgent {
store := sessions.NewCookieStore(key)
err := sysdb.NewTable("auth")
if err != nil {
log.Println("Failed to create auth database. Terminating.")
panic(err)
}
//Create a new AuthAgent object
newAuthAgent := AuthAgent{
SessionName: sessionName,
SessionStore: store,
Database: sysdb,
LoginRedirectionHandler: loginRedirectionHandler,
}
//Return the authAgent
return &newAuthAgent
}
func GetSessionKey(sysdb *db.Database) (string, error) {
sysdb.NewTable("auth")
sessionKey := ""
if !sysdb.KeyExists("auth", "sessionkey") {
key := make([]byte, 32)
rand.Read(key)
sessionKey = string(key)
sysdb.Write("auth", "sessionkey", sessionKey)
log.Println("[Auth] New authentication session key generated")
} else {
log.Println("[Auth] Authentication session key loaded from database")
err := sysdb.Read("auth", "sessionkey", &sessionKey)
if err != nil {
return "", errors.New("database read error. Is the database file corrupted?")
}
}
return sessionKey, nil
}
//This function will handle an http request and redirect to the given login address if not logged in
func (a *AuthAgent) HandleCheckAuth(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request)) {
if a.CheckAuth(r) {
//User already logged in
handler(w, r)
} else {
//User not logged in
a.LoginRedirectionHandler(w, r)
}
}
//Handle login request, require POST username and password
func (a *AuthAgent) HandleLogin(w http.ResponseWriter, r *http.Request) {
//Get username from request using POST mode
username, err := utils.PostPara(r, "username")
if err != nil {
//Username not defined
log.Println("[Auth] " + r.RemoteAddr + " trying to login with username: " + username)
utils.SendErrorResponse(w, "Username not defined or empty.")
return
}
//Get password from request using POST mode
password, err := utils.PostPara(r, "password")
if err != nil {
//Password not defined
utils.SendErrorResponse(w, "Password not defined or empty.")
return
}
//Get rememberme settings
rememberme := false
rmbme, _ := utils.PostPara(r, "rmbme")
if rmbme == "true" {
rememberme = true
}
//Check the database and see if this user is in the database
passwordCorrect, rejectionReason := a.ValidateUsernameAndPasswordWithReason(username, password)
//The database contain this user information. Check its password if it is correct
if passwordCorrect {
//Password correct
// Set user as authenticated
a.LoginUserByRequest(w, r, username, rememberme)
//Print the login message to console
log.Println(username + " logged in.")
utils.SendOK(w)
} else {
//Password incorrect
log.Println(username + " login request rejected: " + rejectionReason)
utils.SendErrorResponse(w, rejectionReason)
return
}
}
func (a *AuthAgent) ValidateUsernameAndPassword(username string, password string) bool {
succ, _ := a.ValidateUsernameAndPasswordWithReason(username, password)
return succ
}
//validate the username and password, return reasons if the auth failed
func (a *AuthAgent) ValidateUsernameAndPasswordWithReason(username string, password string) (bool, string) {
hashedPassword := Hash(password)
var passwordInDB string
err := a.Database.Read("auth", "passhash/"+username, &passwordInDB)
if err != nil {
//User not found or db exception
log.Println("[Auth] " + username + " login with incorrect password")
return false, "Invalid username or password"
}
if passwordInDB == hashedPassword {
return true, ""
} else {
return false, "Invalid username or password"
}
}
//Login the user by creating a valid session for this user
func (a *AuthAgent) LoginUserByRequest(w http.ResponseWriter, r *http.Request, username string, rememberme bool) {
session, _ := a.SessionStore.Get(r, a.SessionName)
session.Values["authenticated"] = true
session.Values["username"] = username
session.Values["rememberMe"] = rememberme
//Check if remember me is clicked. If yes, set the maxage to 1 week.
if rememberme {
session.Options = &sessions.Options{
MaxAge: 3600 * 24 * 7, //One week
Path: "/",
}
} else {
session.Options = &sessions.Options{
MaxAge: 3600 * 1, //One hour
Path: "/",
}
}
session.Save(r, w)
}
//Handle logout, reply OK after logged out. WILL NOT DO REDIRECTION
func (a *AuthAgent) HandleLogout(w http.ResponseWriter, r *http.Request) {
username, err := a.GetUserName(w, r)
if username != "" {
log.Println(username + " logged out.")
}
// Revoke users authentication
err = a.Logout(w, r)
if err != nil {
utils.SendErrorResponse(w, "Logout failed")
return
}
w.Write([]byte("OK"))
}
func (a *AuthAgent) Logout(w http.ResponseWriter, r *http.Request) error {
session, err := a.SessionStore.Get(r, a.SessionName)
if err != nil {
return err
}
session.Values["authenticated"] = false
session.Values["username"] = nil
session.Save(r, w)
return nil
}
//Get the current session username from request
func (a *AuthAgent) GetUserName(w http.ResponseWriter, r *http.Request) (string, error) {
if a.CheckAuth(r) {
//This user has logged in.
session, _ := a.SessionStore.Get(r, a.SessionName)
return session.Values["username"].(string), nil
} else {
//This user has not logged in.
return "", errors.New("user not logged in")
}
}
//Get the current session user email from request
func (a *AuthAgent) GetUserEmail(w http.ResponseWriter, r *http.Request) (string, error) {
if a.CheckAuth(r) {
//This user has logged in.
session, _ := a.SessionStore.Get(r, a.SessionName)
username := session.Values["username"].(string)
userEmail := ""
err := a.Database.Read("auth", "email/"+username, &userEmail)
if err != nil {
return "", err
}
return userEmail, nil
} else {
//This user has not logged in.
return "", errors.New("user not logged in")
}
}
//Check if the user has logged in, return true / false in JSON
func (a *AuthAgent) CheckLogin(w http.ResponseWriter, r *http.Request) {
if a.CheckAuth(r) {
utils.SendJSONResponse(w, "true")
} else {
utils.SendJSONResponse(w, "false")
}
}
//Handle new user register. Require POST username, password, group.
func (a *AuthAgent) HandleRegister(w http.ResponseWriter, r *http.Request, callback func(string, string)) {
//Get username from request
newusername, err := utils.PostPara(r, "username")
if err != nil {
utils.SendErrorResponse(w, "Missing 'username' paramter")
return
}
//Get password from request
password, err := utils.PostPara(r, "password")
if err != nil {
utils.SendErrorResponse(w, "Missing 'password' paramter")
return
}
//Get email from request
email, err := utils.PostPara(r, "email")
if err != nil {
utils.SendErrorResponse(w, "Missing 'email' paramter")
return
}
_, err = mail.ParseAddress(email)
if err != nil {
utils.SendErrorResponse(w, "Invalid or malformed email")
return
}
//Ok to proceed create this user
err = a.CreateUserAccount(newusername, password, email)
if err != nil {
utils.SendErrorResponse(w, err.Error())
return
}
//Do callback if exists
if callback != nil {
callback(newusername, email)
}
//Return to the client with OK
utils.SendOK(w)
log.Println("[Auth] New user " + newusername + " added to system.")
}
//Handle new user register without confirmation email. Require POST username, password, group.
func (a *AuthAgent) HandleRegisterWithoutEmail(w http.ResponseWriter, r *http.Request, callback func(string, string)) {
//Get username from request
newusername, err := utils.PostPara(r, "username")
if err != nil {
utils.SendErrorResponse(w, "Missing 'username' paramter")
return
}
//Get password from request
password, err := utils.PostPara(r, "password")
if err != nil {
utils.SendErrorResponse(w, "Missing 'password' paramter")
return
}
//Ok to proceed create this user
err = a.CreateUserAccount(newusername, password, "")
if err != nil {
utils.SendErrorResponse(w, err.Error())
return
}
//Do callback if exists
if callback != nil {
callback(newusername, "")
}
//Return to the client with OK
utils.SendOK(w)
log.Println("[Auth] Admin account created: " + newusername)
}
//Check authentication from request header's session value
func (a *AuthAgent) CheckAuth(r *http.Request) bool {
session, err := a.SessionStore.Get(r, a.SessionName)
if err != nil {
return false
}
// Check if user is authenticated
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
return false
}
return true
}
//Handle de-register of users. Require POST username.
//THIS FUNCTION WILL NOT CHECK FOR PERMISSION. PLEASE USE WITH PERMISSION HANDLER
func (a *AuthAgent) HandleUnregister(w http.ResponseWriter, r *http.Request) {
//Check if the user is logged in
if !a.CheckAuth(r) {
//This user has not logged in
utils.SendErrorResponse(w, "Login required to remove user from the system.")
return
}
//Get username from request
username, err := utils.PostPara(r, "username")
if err != nil {
utils.SendErrorResponse(w, "Missing 'username' paramter")
return
}
err = a.UnregisterUser(username)
if err != nil {
utils.SendErrorResponse(w, err.Error())
return
}
//Return to the client with OK
utils.SendOK(w)
log.Println("[Auth] User " + username + " has been removed from the system.")
}
func (a *AuthAgent) UnregisterUser(username string) error {
//Check if the user exists in the system database.
if !a.Database.KeyExists("auth", "passhash/"+username) {
//This user do not exists.
return errors.New("this user does not exists")
}
//OK! Remove the user from the database
a.Database.Delete("auth", "passhash/"+username)
a.Database.Delete("auth", "email/"+username)
return nil
}
//Get the number of users in the system
func (a *AuthAgent) GetUserCounts() int {
entries, _ := a.Database.ListTable("auth")
usercount := 0
for _, keypairs := range entries {
if strings.Contains(string(keypairs[0]), "passhash/") {
//This is a user registry
usercount++
}
}
if usercount == 0 {
log.Println("There are no user in the database.")
}
return usercount
}
//List all username within the system
func (a *AuthAgent) ListUsers() []string {
entries, _ := a.Database.ListTable("auth")
results := []string{}
for _, keypairs := range entries {
if strings.Contains(string(keypairs[0]), "passhash/") {
username := strings.Split(string(keypairs[0]), "/")[1]
results = append(results, username)
}
}
return results
}
//Check if the given username exists
func (a *AuthAgent) UserExists(username string) bool {
userpasswordhash := ""
err := a.Database.Read("auth", "passhash/"+username, &userpasswordhash)
if err != nil || userpasswordhash == "" {
return false
}
return true
}
//Update the session expire time given the request header.
func (a *AuthAgent) UpdateSessionExpireTime(w http.ResponseWriter, r *http.Request) bool {
session, _ := a.SessionStore.Get(r, a.SessionName)
if session.Values["authenticated"].(bool) {
//User authenticated. Extend its expire time
rememberme := session.Values["rememberMe"].(bool)
//Extend the session expire time
if rememberme {
session.Options = &sessions.Options{
MaxAge: 3600 * 24 * 7, //One week
Path: "/",
}
} else {
session.Options = &sessions.Options{
MaxAge: 3600 * 1, //One hour
Path: "/",
}
}
session.Save(r, w)
return true
} else {
return false
}
}
//Create user account
func (a *AuthAgent) CreateUserAccount(newusername string, password string, email string) error {
//Check user already exists
if a.UserExists(newusername) {
return errors.New("user with same name already exists")
}
key := newusername
hashedPassword := Hash(password)
err := a.Database.Write("auth", "passhash/"+key, hashedPassword)
if err != nil {
return err
}
if email != "" {
err = a.Database.Write("auth", "email/"+key, email)
if err != nil {
return err
}
}
return nil
}
//Hash the given raw string into sha512 hash
func Hash(raw string) string {
h := sha512.New()
h.Write([]byte(raw))
return hex.EncodeToString(h.Sum(nil))
}

View File

@@ -1,53 +0,0 @@
package auth
import (
"errors"
"log"
"net/http"
)
type RouterOption struct {
AuthAgent *AuthAgent
RequireAuth bool //This router require authentication
DeniedHandler func(http.ResponseWriter, *http.Request) //Things to do when request is rejected
}
type RouterDef struct {
option RouterOption
endpoints map[string]func(http.ResponseWriter, *http.Request)
}
func NewManagedHTTPRouter(option RouterOption) *RouterDef {
return &RouterDef{
option: option,
endpoints: map[string]func(http.ResponseWriter, *http.Request){},
}
}
func (router *RouterDef) HandleFunc(endpoint string, handler func(http.ResponseWriter, *http.Request)) error {
//Check if the endpoint already registered
if _, exist := router.endpoints[endpoint]; exist {
log.Println("WARNING! Duplicated registering of web endpoint: " + endpoint)
return errors.New("endpoint register duplicated")
}
authAgent := router.option.AuthAgent
//OK. Register handler
http.HandleFunc(endpoint, func(w http.ResponseWriter, r *http.Request) {
//Check authentication of the user
if router.option.RequireAuth {
authAgent.HandleCheckAuth(w, r, func(w http.ResponseWriter, r *http.Request) {
handler(w, r)
})
} else {
handler(w, r)
}
})
router.endpoints[endpoint] = handler
return nil
}

View File

@@ -1,120 +0,0 @@
package database
/*
ArOZ Online Database Access Module
author: tobychui
This is an improved Object oriented base solution to the original
aroz online database script.
*/
import (
"sync"
)
type Database struct {
Db interface{} //This will be nil on openwrt and *bolt.DB in the rest of the systems
Tables sync.Map
ReadOnly bool
}
func NewDatabase(dbfile string, readOnlyMode bool) (*Database, error) {
return newDatabase(dbfile, readOnlyMode)
}
/*
Create / Drop a table
Usage:
err := sysdb.NewTable("MyTable")
err := sysdb.DropTable("MyTable")
*/
func (d *Database) UpdateReadWriteMode(readOnly bool) {
d.ReadOnly = readOnly
}
//Dump the whole db into a log file
func (d *Database) Dump(filename string) ([]string, error) {
return d.dump(filename)
}
//Create a new table
func (d *Database) NewTable(tableName string) error {
return d.newTable(tableName)
}
//Check is table exists
func (d *Database) TableExists(tableName string) bool {
return d.tableExists(tableName)
}
//Drop the given table
func (d *Database) DropTable(tableName string) error {
return d.dropTable(tableName)
}
/*
Write to database with given tablename and key. Example Usage:
type demo struct{
content string
}
thisDemo := demo{
content: "Hello World",
}
err := sysdb.Write("MyTable", "username/message",thisDemo);
*/
func (d *Database) Write(tableName string, key string, value interface{}) error {
return d.write(tableName, key, value)
}
/*
Read from database and assign the content to a given datatype. Example Usage:
type demo struct{
content string
}
thisDemo := new(demo)
err := sysdb.Read("MyTable", "username/message",&thisDemo);
*/
func (d *Database) Read(tableName string, key string, assignee interface{}) error {
return d.read(tableName, key, assignee)
}
func (d *Database) KeyExists(tableName string, key string) bool {
return d.keyExists(tableName, key)
}
/*
Delete a value from the database table given tablename and key
err := sysdb.Delete("MyTable", "username/message");
*/
func (d *Database) Delete(tableName string, key string) error {
return d.delete(tableName, key)
}
/*
//List table example usage
//Assume the value is stored as a struct named "groupstruct"
entries, err := sysdb.ListTable("test")
if err != nil {
panic(err)
}
for _, keypairs := range entries{
log.Println(string(keypairs[0]))
group := new(groupstruct)
json.Unmarshal(keypairs[1], &group)
log.Println(group);
}
*/
func (d *Database) ListTable(tableName string) ([][][]byte, error) {
return d.listTable(tableName)
}
func (d *Database) Close() {
d.close()
}

View File

@@ -1,186 +0,0 @@
//go:build !mipsle && !riscv64
// +build !mipsle,!riscv64
package database
import (
"encoding/json"
"errors"
"log"
"sync"
"github.com/boltdb/bolt"
)
func newDatabase(dbfile string, readOnlyMode bool) (*Database, error) {
db, err := bolt.Open(dbfile, 0600, nil)
if err != nil {
return nil, err
}
tableMap := sync.Map{}
//Build the table list from database
err = db.View(func(tx *bolt.Tx) error {
return tx.ForEach(func(name []byte, _ *bolt.Bucket) error {
tableMap.Store(string(name), "")
return nil
})
})
return &Database{
Db: db,
Tables: tableMap,
ReadOnly: readOnlyMode,
}, err
}
//Dump the whole db into a log file
func (d *Database) dump(filename string) ([]string, error) {
results := []string{}
d.Tables.Range(func(tableName, v interface{}) bool {
entries, err := d.ListTable(tableName.(string))
if err != nil {
log.Println("Reading table " + tableName.(string) + " failed: " + err.Error())
return false
}
for _, keypairs := range entries {
results = append(results, string(keypairs[0])+":"+string(keypairs[1])+"\n")
}
return true
})
return results, nil
}
//Create a new table
func (d *Database) newTable(tableName string) error {
if d.ReadOnly == true {
return errors.New("Operation rejected in ReadOnly mode")
}
err := d.Db.(*bolt.DB).Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte(tableName))
if err != nil {
return err
}
return nil
})
d.Tables.Store(tableName, "")
return err
}
//Check is table exists
func (d *Database) tableExists(tableName string) bool {
if _, ok := d.Tables.Load(tableName); ok {
return true
}
return false
}
//Drop the given table
func (d *Database) dropTable(tableName string) error {
if d.ReadOnly == true {
return errors.New("Operation rejected in ReadOnly mode")
}
err := d.Db.(*bolt.DB).Update(func(tx *bolt.Tx) error {
err := tx.DeleteBucket([]byte(tableName))
if err != nil {
return err
}
return nil
})
return err
}
//Write to table
func (d *Database) write(tableName string, key string, value interface{}) error {
if d.ReadOnly {
return errors.New("Operation rejected in ReadOnly mode")
}
jsonString, err := json.Marshal(value)
if err != nil {
return err
}
err = d.Db.(*bolt.DB).Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte(tableName))
if err != nil {
return err
}
b := tx.Bucket([]byte(tableName))
err = b.Put([]byte(key), jsonString)
return err
})
return err
}
func (d *Database) read(tableName string, key string, assignee interface{}) error {
err := d.Db.(*bolt.DB).View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte(tableName))
v := b.Get([]byte(key))
json.Unmarshal(v, &assignee)
return nil
})
return err
}
func (d *Database) keyExists(tableName string, key string) bool {
resultIsNil := false
if !d.TableExists(tableName) {
//Table not exists. Do not proceed accessing key
log.Println("[DB] ERROR: Requesting key from table that didn't exist!!!")
return false
}
err := d.Db.(*bolt.DB).View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte(tableName))
v := b.Get([]byte(key))
if v == nil {
resultIsNil = true
}
return nil
})
if err != nil {
return false
} else {
if resultIsNil {
return false
} else {
return true
}
}
}
func (d *Database) delete(tableName string, key string) error {
if d.ReadOnly {
return errors.New("Operation rejected in ReadOnly mode")
}
err := d.Db.(*bolt.DB).Update(func(tx *bolt.Tx) error {
tx.Bucket([]byte(tableName)).Delete([]byte(key))
return nil
})
return err
}
func (d *Database) listTable(tableName string) ([][][]byte, error) {
var results [][][]byte
err := d.Db.(*bolt.DB).View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte(tableName))
c := b.Cursor()
for k, v := c.First(); k != nil; k, v = c.Next() {
results = append(results, [][]byte{k, v})
}
return nil
})
return results, err
}
func (d *Database) close() {
d.Db.(*bolt.DB).Close()
}

View File

@@ -1,208 +0,0 @@
//go:build mipsle || riscv64
// +build mipsle riscv64
package database
import (
"encoding/json"
"errors"
"log"
"os"
"path/filepath"
"strings"
"sync"
)
func newDatabase(dbfile string, readOnlyMode bool) (*Database, error) {
dbRootPath := filepath.ToSlash(filepath.Clean(dbfile))
dbRootPath = "fsdb/" + dbRootPath
err := os.MkdirAll(dbRootPath, 0755)
if err != nil {
return nil, err
}
tableMap := sync.Map{}
//build the table list from file system
files, err := filepath.Glob(filepath.Join(dbRootPath, "/*"))
if err != nil {
return nil, err
}
for _, file := range files {
if isDirectory(file) {
tableMap.Store(filepath.Base(file), "")
}
}
log.Println("Filesystem Emulated Key-value Database Service Started: " + dbRootPath)
return &Database{
Db: dbRootPath,
Tables: tableMap,
ReadOnly: readOnlyMode,
}, nil
}
func (d *Database) dump(filename string) ([]string, error) {
//Get all file objects from root
rootfiles, err := filepath.Glob(filepath.Join(d.Db.(string), "/*"))
if err != nil {
return []string{}, err
}
//Filter out the folders
rootFolders := []string{}
for _, file := range rootfiles {
if !isDirectory(file) {
rootFolders = append(rootFolders, filepath.Base(file))
}
}
return rootFolders, nil
}
func (d *Database) newTable(tableName string) error {
if d.ReadOnly {
return errors.New("Operation rejected in ReadOnly mode")
}
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
if !fileExists(tablePath) {
return os.MkdirAll(tablePath, 0755)
}
return nil
}
func (d *Database) tableExists(tableName string) bool {
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
if _, err := os.Stat(tablePath); errors.Is(err, os.ErrNotExist) {
return false
}
if !isDirectory(tablePath) {
return false
}
return true
}
func (d *Database) dropTable(tableName string) error {
if d.ReadOnly {
return errors.New("Operation rejected in ReadOnly mode")
}
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
if d.tableExists(tableName) {
return os.RemoveAll(tablePath)
} else {
return errors.New("table not exists")
}
}
func (d *Database) write(tableName string, key string, value interface{}) error {
if d.ReadOnly {
return errors.New("Operation rejected in ReadOnly mode")
}
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
js, err := json.Marshal(value)
if err != nil {
return err
}
key = strings.ReplaceAll(key, "/", "-SLASH_SIGN-")
return os.WriteFile(filepath.Join(tablePath, key+".entry"), js, 0755)
}
func (d *Database) read(tableName string, key string, assignee interface{}) error {
if !d.keyExists(tableName, key) {
return errors.New("key not exists")
}
key = strings.ReplaceAll(key, "/", "-SLASH_SIGN-")
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
entryPath := filepath.Join(tablePath, key+".entry")
content, err := os.ReadFile(entryPath)
if err != nil {
return err
}
err = json.Unmarshal(content, &assignee)
return err
}
func (d *Database) keyExists(tableName string, key string) bool {
key = strings.ReplaceAll(key, "/", "-SLASH_SIGN-")
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
entryPath := filepath.Join(tablePath, key+".entry")
return fileExists(entryPath)
}
func (d *Database) delete(tableName string, key string) error {
if d.ReadOnly {
return errors.New("Operation rejected in ReadOnly mode")
}
if !d.keyExists(tableName, key) {
return errors.New("key not exists")
}
key = strings.ReplaceAll(key, "/", "-SLASH_SIGN-")
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
entryPath := filepath.Join(tablePath, key+".entry")
return os.Remove(entryPath)
}
func (d *Database) listTable(tableName string) ([][][]byte, error) {
if !d.tableExists(tableName) {
return [][][]byte{}, errors.New("table not exists")
}
tablePath := filepath.Join(d.Db.(string), filepath.Base(tableName))
entries, err := filepath.Glob(filepath.Join(tablePath, "/*.entry"))
if err != nil {
return [][][]byte{}, err
}
var results [][][]byte = [][][]byte{}
for _, entry := range entries {
if !isDirectory(entry) {
//Read it
key := filepath.Base(entry)
key = strings.TrimSuffix(key, filepath.Ext(key))
key = strings.ReplaceAll(key, "-SLASH_SIGN-", "/")
bkey := []byte(key)
bval := []byte("")
c, err := os.ReadFile(entry)
if err != nil {
break
}
bval = c
results = append(results, [][]byte{bkey, bval})
}
}
return results, nil
}
func (d *Database) close() {
//Nothing to close as it is file system
}
func isDirectory(path string) bool {
fileInfo, err := os.Stat(path)
if err != nil {
return false
}
return fileInfo.IsDir()
}
func fileExists(name string) bool {
_, err := os.Stat(name)
if err == nil {
return true
}
if errors.Is(err, os.ErrNotExist) {
return false
}
return false
}

View File

@@ -1,82 +0,0 @@
package dynamicproxy
import (
"net/http"
"os"
"strings"
"imuslab.com/zoraxy/mod/geodb"
)
/*
Server.go
Main server for dynamic proxy core
*/
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
//Check if this ip is in blacklist
clientIpAddr := geodb.GetRequesterIP(r)
if h.Parent.Option.GeodbStore.IsBlacklisted(clientIpAddr) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusForbidden)
template, err := os.ReadFile("./web/forbidden.html")
if err != nil {
w.Write([]byte("403 - Forbidden"))
} else {
w.Write(template)
}
h.logRequest(r, false, 403, "blacklist")
return
}
//Check if this is a redirection url
if h.Parent.Option.RedirectRuleTable.IsRedirectable(r) {
statusCode := h.Parent.Option.RedirectRuleTable.HandleRedirect(w, r)
h.logRequest(r, statusCode != 500, statusCode, "redirect")
return
}
//Check if there are external routing rule matches.
//If yes, route them via external rr
matchedRoutingRule := h.Parent.GetMatchingRoutingRule(r)
if matchedRoutingRule != nil {
//Matching routing rule found. Let the sub-router handle it
matchedRoutingRule.Route(w, r)
return
}
//Extract request host to see if it is virtual directory or subdomain
domainOnly := r.Host
if strings.Contains(r.Host, ":") {
hostPath := strings.Split(r.Host, ":")
domainOnly = hostPath[0]
}
if strings.Contains(r.Host, ".") {
//This might be a subdomain. See if there are any subdomain proxy router for this
//Remove the port if any
sep := h.Parent.getSubdomainProxyEndpointFromHostname(domainOnly)
if sep != nil {
h.subdomainRequest(w, r, sep)
return
}
}
//Clean up the request URI
proxyingPath := strings.TrimSpace(r.RequestURI)
targetProxyEndpoint := h.Parent.getTargetProxyEndpointFromRequestURI(proxyingPath)
if targetProxyEndpoint != nil {
h.proxyRequest(w, r, targetProxyEndpoint)
} else if !strings.HasSuffix(proxyingPath, "/") {
potentialProxtEndpoint := h.Parent.getTargetProxyEndpointFromRequestURI(proxyingPath + "/")
if potentialProxtEndpoint != nil {
h.proxyRequest(w, r, potentialProxtEndpoint)
} else {
h.proxyRequest(w, r, h.Parent.Root)
}
} else {
h.proxyRequest(w, r, h.Parent.Root)
}
}

View File

@@ -1,23 +0,0 @@
package domainsniff
import (
"net"
"time"
)
//Check if the domain is reachable and return err if not reachable
func DomainReachableWithError(domain string) error {
timeout := 1 * time.Second
conn, err := net.DialTimeout("tcp", domain, timeout)
if err != nil {
return err
}
conn.Close()
return nil
}
//Check if domain reachable
func DomainReachable(domain string) bool {
return DomainReachableWithError(domain) == nil
}

View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2018-present tobychui
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,414 +0,0 @@
package dpcore
import (
"errors"
"io"
"log"
"net"
"net/http"
"net/url"
"path/filepath"
"strings"
"sync"
"time"
)
var onExitFlushLoop func()
const (
defaultTimeout = time.Minute * 5
)
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client, support http, also support https tunnel using http.hijacker
type ReverseProxy struct {
// Set the timeout of the proxy server, default is 5 minutes
Timeout time.Duration
// Director must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// Director must not access the provided Request
// after returning.
Director func(*http.Request)
// The transport used to perform proxy requests.
// default is http.DefaultTransport.
Transport http.RoundTripper
// FlushInterval specifies the flush interval
// to flush to the client while copying the
// response body. If zero, no periodic flushing is done.
FlushInterval time.Duration
// ErrorLog specifies an optional logger for errors
// that occur when attempting to proxy the request.
// If nil, logging goes to os.Stderr via the log package's
// standard logger.
ErrorLog *log.Logger
// ModifyResponse is an optional function that
// modifies the Response from the backend.
// If it returns an error, the proxy returns a StatusBadGateway error.
ModifyResponse func(*http.Response) error
//Prepender is an optional prepend text for URL rewrite
//
Prepender string
Verbal bool
}
type requestCanceler interface {
CancelRequest(req *http.Request)
}
func NewDynamicProxyCore(target *url.URL, prepender string) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
// If Host is empty, the Request.Write method uses
// the value of URL.Host.
// force use URL.Host
req.Host = req.URL.Host
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
if _, ok := req.Header["User-Agent"]; !ok {
req.Header.Set("User-Agent", "")
}
}
return &ReverseProxy{
Director: director,
Prepender: prepender,
Verbal: false,
}
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
//"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
//"Upgrade",
}
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
if p.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: p.FlushInterval,
done: make(chan bool),
}
go mlw.flushLoop()
defer mlw.stop()
dst = mlw
}
}
io.Copy(dst, src)
}
type writeFlusher interface {
io.Writer
http.Flusher
}
type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
mu sync.Mutex
done chan bool
}
func (m *maxLatencyWriter) Write(b []byte) (int, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.dst.Write(b)
}
func (m *maxLatencyWriter) flushLoop() {
t := time.NewTicker(m.latency)
defer t.Stop()
for {
select {
case <-m.done:
if onExitFlushLoop != nil {
onExitFlushLoop()
}
return
case <-t.C:
m.mu.Lock()
m.dst.Flush()
m.mu.Unlock()
}
}
}
func (m *maxLatencyWriter) stop() {
m.done <- true
}
func (p *ReverseProxy) logf(format string, args ...interface{}) {
if p.ErrorLog != nil {
p.ErrorLog.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
func removeHeaders(header http.Header) {
// Remove hop-by-hop headers listed in the "Connection" header.
if c := header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
header.Del(f)
}
}
}
// Remove hop-by-hop headers
for _, h := range hopHeaders {
if header.Get(h) != "" {
header.Del(h)
}
}
if header.Get("A-Upgrade") != "" {
header.Set("Upgrade", header.Get("A-Upgrade"))
header.Del("A-Upgrade")
}
}
func addXForwardedForHeader(req *http.Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
if prior, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
req.Header.Set("X-Forwarded-For", clientIP)
}
}
func (p *ReverseProxy) ProxyHTTP(rw http.ResponseWriter, req *http.Request) error {
transport := p.Transport
if transport == nil {
transport = http.DefaultTransport
}
outreq := new(http.Request)
// Shallow copies of maps, like header
*outreq = *req
if cn, ok := rw.(http.CloseNotifier); ok {
if requestCanceler, ok := transport.(requestCanceler); ok {
// After the Handler has returned, there is no guarantee
// that the channel receives a value, so to make sure
reqDone := make(chan struct{})
defer close(reqDone)
clientGone := cn.CloseNotify()
go func() {
select {
case <-clientGone:
requestCanceler.CancelRequest(outreq)
case <-reqDone:
}
}()
}
}
p.Director(outreq)
outreq.Close = false
// We may modify the header (shallow copied above), so we only copy it.
outreq.Header = make(http.Header)
copyHeader(outreq.Header, req.Header)
// Remove hop-by-hop headers listed in the "Connection" header, Remove hop-by-hop headers.
removeHeaders(outreq.Header)
// Add X-Forwarded-For Header.
addXForwardedForHeader(outreq)
res, err := transport.RoundTrip(outreq)
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
//rw.WriteHeader(http.StatusBadGateway)
return err
}
// Remove hop-by-hop headers listed in the "Connection" header of the response, Remove hop-by-hop headers.
removeHeaders(res.Header)
if p.ModifyResponse != nil {
if err := p.ModifyResponse(res); err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
//rw.WriteHeader(http.StatusBadGateway)
return err
}
}
//Custom header rewriter functions
if res.Header.Get("Location") != "" {
//Custom redirection to this rproxy relative path
res.Header.Set("Location", filepath.ToSlash(filepath.Join(p.Prepender, res.Header.Get("Location"))))
}
// Copy header from response to client.
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response, Build it up from Trailer.
if len(res.Trailer) > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode)
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
if fl, ok := rw.(http.Flusher); ok {
fl.Flush()
}
}
p.copyResponse(rw, res.Body)
// close now, instead of defer, to populate res.Trailer
res.Body.Close()
copyHeader(rw.Header(), res.Trailer)
return nil
}
func (p *ReverseProxy) ProxyHTTPS(rw http.ResponseWriter, req *http.Request) error {
hij, ok := rw.(http.Hijacker)
if !ok {
p.logf("http server does not support hijacker")
return errors.New("http server does not support hijacker")
}
clientConn, _, err := hij.Hijack()
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
return err
}
proxyConn, err := net.Dial("tcp", req.URL.Host)
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
return err
}
// The returned net.Conn may have read or write deadlines
// already set, depending on the configuration of the
// Server, to set or clear those deadlines as needed
// we set timeout to 5 minutes
deadline := time.Now()
if p.Timeout == 0 {
deadline = deadline.Add(time.Minute * 5)
} else {
deadline = deadline.Add(p.Timeout)
}
err = clientConn.SetDeadline(deadline)
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
return err
}
err = proxyConn.SetDeadline(deadline)
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
return err
}
_, err = clientConn.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
return err
}
go func() {
io.Copy(clientConn, proxyConn)
clientConn.Close()
proxyConn.Close()
}()
io.Copy(proxyConn, clientConn)
proxyConn.Close()
clientConn.Close()
return nil
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) error {
if req.Method == "CONNECT" {
err := p.ProxyHTTPS(rw, req)
return err
} else {
err := p.ProxyHTTP(rw, req)
return err
}
}

View File

@@ -1,347 +0,0 @@
package dynamicproxy
import (
"context"
"crypto/tls"
"errors"
"log"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"imuslab.com/zoraxy/mod/dynamicproxy/dpcore"
"imuslab.com/zoraxy/mod/dynamicproxy/redirection"
"imuslab.com/zoraxy/mod/geodb"
"imuslab.com/zoraxy/mod/reverseproxy"
"imuslab.com/zoraxy/mod/statistic"
"imuslab.com/zoraxy/mod/tlscert"
)
/*
Zoraxy Dynamic Proxy
*/
type RouterOption struct {
Port int
UseTls bool
ForceHttpsRedirect bool
TlsManager *tlscert.Manager
RedirectRuleTable *redirection.RuleTable
GeodbStore *geodb.Store
StatisticCollector *statistic.Collector
}
type Router struct {
Option *RouterOption
ProxyEndpoints *sync.Map
SubdomainEndpoint *sync.Map
Running bool
Root *ProxyEndpoint
mux http.Handler
server *http.Server
tlsListener net.Listener
routingRules []*RoutingRule
}
type ProxyEndpoint struct {
Root string
Domain string
RequireTLS bool
Proxy *dpcore.ReverseProxy `json:"-"`
}
type SubdomainEndpoint struct {
MatchingDomain string
Domain string
RequireTLS bool
Proxy *reverseproxy.ReverseProxy `json:"-"`
}
type ProxyHandler struct {
Parent *Router
}
func NewDynamicProxy(option RouterOption) (*Router, error) {
proxyMap := sync.Map{}
domainMap := sync.Map{}
thisRouter := Router{
Option: &option,
ProxyEndpoints: &proxyMap,
SubdomainEndpoint: &domainMap,
Running: false,
server: nil,
routingRules: []*RoutingRule{},
}
thisRouter.mux = &ProxyHandler{
Parent: &thisRouter,
}
return &thisRouter, nil
}
// Update TLS setting in runtime. Will restart the proxy server
// if it is already running in the background
func (router *Router) UpdateTLSSetting(tlsEnabled bool) {
router.Option.UseTls = tlsEnabled
router.Restart()
}
// Update https redirect, which will require updates
func (router *Router) UpdateHttpToHttpsRedirectSetting(useRedirect bool) {
router.Option.ForceHttpsRedirect = useRedirect
router.Restart()
}
// Start the dynamic routing
func (router *Router) StartProxyService() error {
//Create a new server object
if router.server != nil {
return errors.New("Reverse proxy server already running")
}
if router.Root == nil {
return errors.New("Reverse proxy router root not set")
}
config := &tls.Config{
GetCertificate: router.Option.TlsManager.GetCert,
}
if router.Option.UseTls {
//Serve with TLS mode
ln, err := tls.Listen("tcp", ":"+strconv.Itoa(router.Option.Port), config)
if err != nil {
log.Println(err)
return err
}
router.tlsListener = ln
router.server = &http.Server{Addr: ":" + strconv.Itoa(router.Option.Port), Handler: router.mux}
router.Running = true
if router.Option.Port == 443 && router.Option.ForceHttpsRedirect {
//Add a 80 to 443 redirector
httpServer := &http.Server{
Addr: ":80",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusTemporaryRedirect)
}),
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
IdleTimeout: 120 * time.Second,
}
log.Println("Starting HTTP-to-HTTPS redirector (port 80)")
go func() {
//Start another router to check if the router.server is killed. If yes, kill this server as well
go func() {
for router.server != nil {
time.Sleep(100 * time.Millisecond)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
log.Println(":80 to :433 redirection listener stopped")
}()
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Could not start server: %v\n", err)
}
}()
}
log.Println("Reverse proxy service started in the background (TLS mode)")
go func() {
if err := router.server.Serve(ln); err != nil && err != http.ErrServerClosed {
log.Fatalf("Could not start server: %v\n", err)
}
}()
} else {
//Serve with non TLS mode
router.tlsListener = nil
router.server = &http.Server{Addr: ":" + strconv.Itoa(router.Option.Port), Handler: router.mux}
router.Running = true
log.Println("Reverse proxy service started in the background (Plain HTTP mode)")
go func() {
router.server.ListenAndServe()
//log.Println("[DynamicProxy] " + err.Error())
}()
}
return nil
}
func (router *Router) StopProxyService() error {
if router.server == nil {
return errors.New("Reverse proxy server already stopped")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := router.server.Shutdown(ctx)
if err != nil {
return err
}
if router.tlsListener != nil {
router.tlsListener.Close()
}
//Discard the server object
router.tlsListener = nil
router.server = nil
router.Running = false
return nil
}
// Restart the current router if it is running.
// Startup the server if it is not running initially
func (router *Router) Restart() error {
//Stop the router if it is already running
if router.Running {
err := router.StopProxyService()
if err != nil {
return err
}
}
//Start the server
err := router.StartProxyService()
return err
}
/*
Check if a given request is accessed via a proxied subdomain
*/
func (router *Router) IsProxiedSubdomain(r *http.Request) bool {
hostname := r.Header.Get("X-Forwarded-Host")
if hostname == "" {
hostname = r.Host
}
hostname = strings.Split(hostname, ":")[0]
subdEndpoint := router.getSubdomainProxyEndpointFromHostname(hostname)
return subdEndpoint != nil
}
/*
Add an URL into a custom proxy services
*/
func (router *Router) AddVirtualDirectoryProxyService(rootname string, domain string, requireTLS bool) error {
if domain[len(domain)-1:] == "/" {
domain = domain[:len(domain)-1]
}
if rootname[len(rootname)-1:] == "/" {
rootname = rootname[:len(rootname)-1]
}
webProxyEndpoint := domain
if requireTLS {
webProxyEndpoint = "https://" + webProxyEndpoint
} else {
webProxyEndpoint = "http://" + webProxyEndpoint
}
//Create a new proxy agent for this root
path, err := url.Parse(webProxyEndpoint)
if err != nil {
return err
}
proxy := dpcore.NewDynamicProxyCore(path, rootname)
endpointObject := ProxyEndpoint{
Root: rootname,
Domain: domain,
RequireTLS: requireTLS,
Proxy: proxy,
}
router.ProxyEndpoints.Store(rootname, &endpointObject)
log.Println("Adding Proxy Rule: ", rootname+" to "+domain)
return nil
}
/*
Remove routing from RP
*/
func (router *Router) RemoveProxy(ptype string, key string) error {
//fmt.Println(ptype, key)
if ptype == "vdir" {
router.ProxyEndpoints.Delete(key)
return nil
} else if ptype == "subd" {
router.SubdomainEndpoint.Delete(key)
return nil
}
return errors.New("invalid ptype")
}
/*
Add an default router for the proxy server
*/
func (router *Router) SetRootProxy(proxyLocation string, requireTLS bool) error {
if proxyLocation[len(proxyLocation)-1:] == "/" {
proxyLocation = proxyLocation[:len(proxyLocation)-1]
}
webProxyEndpoint := proxyLocation
if requireTLS {
webProxyEndpoint = "https://" + webProxyEndpoint
} else {
webProxyEndpoint = "http://" + webProxyEndpoint
}
//Create a new proxy agent for this root
path, err := url.Parse(webProxyEndpoint)
if err != nil {
return err
}
proxy := dpcore.NewDynamicProxyCore(path, "")
rootEndpoint := ProxyEndpoint{
Root: "/",
Domain: proxyLocation,
RequireTLS: requireTLS,
Proxy: proxy,
}
router.Root = &rootEndpoint
return nil
}
//Helpers to export the syncmap for easier processing
func (r *Router) GetSDProxyEndpointsAsMap() map[string]*SubdomainEndpoint {
m := make(map[string]*SubdomainEndpoint)
r.SubdomainEndpoint.Range(func(key, value interface{}) bool {
k, ok := key.(string)
if !ok {
return true
}
v, ok := value.(*SubdomainEndpoint)
if !ok {
return true
}
m[k] = v
return true
})
return m
}
func (r *Router) GetVDProxyEndpointsAsMap() map[string]*ProxyEndpoint {
m := make(map[string]*ProxyEndpoint)
r.ProxyEndpoints.Range(func(key, value interface{}) bool {
k, ok := key.(string)
if !ok {
return true
}
v, ok := value.(*ProxyEndpoint)
if !ok {
return true
}
m[k] = v
return true
})
return m
}

View File

@@ -1,149 +0,0 @@
package dynamicproxy
import (
"errors"
"log"
"net"
"net/http"
"net/url"
"strings"
"imuslab.com/zoraxy/mod/geodb"
"imuslab.com/zoraxy/mod/statistic"
"imuslab.com/zoraxy/mod/websocketproxy"
)
func (router *Router) getTargetProxyEndpointFromRequestURI(requestURI string) *ProxyEndpoint {
var targetProxyEndpoint *ProxyEndpoint = nil
router.ProxyEndpoints.Range(func(key, value interface{}) bool {
rootname := key.(string)
if strings.HasPrefix(requestURI, rootname) {
thisProxyEndpoint := value.(*ProxyEndpoint)
targetProxyEndpoint = thisProxyEndpoint
}
/*
if len(requestURI) >= len(rootname) && requestURI[:len(rootname)] == rootname {
thisProxyEndpoint := value.(*ProxyEndpoint)
targetProxyEndpoint = thisProxyEndpoint
}
*/
return true
})
return targetProxyEndpoint
}
func (router *Router) getSubdomainProxyEndpointFromHostname(hostname string) *SubdomainEndpoint {
var targetSubdomainEndpoint *SubdomainEndpoint = nil
ep, ok := router.SubdomainEndpoint.Load(hostname)
if ok {
targetSubdomainEndpoint = ep.(*SubdomainEndpoint)
}
return targetSubdomainEndpoint
}
func (router *Router) rewriteURL(rooturl string, requestURL string) string {
if len(requestURL) > len(rooturl) {
return requestURL[len(rooturl):]
}
return ""
}
func (h *ProxyHandler) subdomainRequest(w http.ResponseWriter, r *http.Request, target *SubdomainEndpoint) {
r.Header.Set("X-Forwarded-Host", r.Host)
requestURL := r.URL.String()
if r.Header["Upgrade"] != nil && strings.ToLower(r.Header["Upgrade"][0]) == "websocket" {
//Handle WebSocket request. Forward the custom Upgrade header and rewrite origin
r.Header.Set("A-Upgrade", "websocket")
wsRedirectionEndpoint := target.Domain
if wsRedirectionEndpoint[len(wsRedirectionEndpoint)-1:] != "/" {
//Append / to the end of the redirection endpoint if not exists
wsRedirectionEndpoint = wsRedirectionEndpoint + "/"
}
if len(requestURL) > 0 && requestURL[:1] == "/" {
//Remove starting / from request URL if exists
requestURL = requestURL[1:]
}
u, _ := url.Parse("ws://" + wsRedirectionEndpoint + requestURL)
if target.RequireTLS {
u, _ = url.Parse("wss://" + wsRedirectionEndpoint + requestURL)
}
h.logRequest(r, true, 101, "subdomain-websocket")
wspHandler := websocketproxy.NewProxy(u)
wspHandler.ServeHTTP(w, r)
return
}
r.Host = r.URL.Host
err := target.Proxy.ServeHTTP(w, r)
var dnsError *net.DNSError
if err != nil {
if errors.As(err, &dnsError) {
http.ServeFile(w, r, "./web/hosterror.html")
log.Println(err.Error())
h.logRequest(r, false, 404, "subdomain-http")
} else {
http.ServeFile(w, r, "./web/rperror.html")
log.Println(err.Error())
h.logRequest(r, false, 521, "subdomain-http")
}
}
h.logRequest(r, true, 200, "subdomain-http")
}
func (h *ProxyHandler) proxyRequest(w http.ResponseWriter, r *http.Request, target *ProxyEndpoint) {
rewriteURL := h.Parent.rewriteURL(target.Root, r.RequestURI)
r.URL, _ = url.Parse(rewriteURL)
r.Header.Set("X-Forwarded-Host", r.Host)
if r.Header["Upgrade"] != nil && strings.ToLower(r.Header["Upgrade"][0]) == "websocket" {
//Handle WebSocket request. Forward the custom Upgrade header and rewrite origin
r.Header.Set("A-Upgrade", "websocket")
wsRedirectionEndpoint := target.Domain
if wsRedirectionEndpoint[len(wsRedirectionEndpoint)-1:] != "/" {
wsRedirectionEndpoint = wsRedirectionEndpoint + "/"
}
u, _ := url.Parse("ws://" + wsRedirectionEndpoint + r.URL.String())
if target.RequireTLS {
u, _ = url.Parse("wss://" + wsRedirectionEndpoint + r.URL.String())
}
h.logRequest(r, true, 101, "vdir-websocket")
wspHandler := websocketproxy.NewProxy(u)
wspHandler.ServeHTTP(w, r)
return
}
r.Host = r.URL.Host
err := target.Proxy.ServeHTTP(w, r)
var dnsError *net.DNSError
if err != nil {
if errors.As(err, &dnsError) {
http.ServeFile(w, r, "./web/hosterror.html")
log.Println(err.Error())
h.logRequest(r, false, 404, "vdir-http")
} else {
http.ServeFile(w, r, "./web/rperror.html")
log.Println(err.Error())
h.logRequest(r, false, 521, "vdir-http")
}
}
h.logRequest(r, true, 200, "vdir-http")
}
func (h *ProxyHandler) logRequest(r *http.Request, succ bool, statusCode int, forwardType string) {
if h.Parent.Option.StatisticCollector != nil {
go func() {
requestInfo := statistic.RequestInfo{
IpAddr: geodb.GetRequesterIP(r),
RequestOriginalCountryISOCode: h.Parent.Option.GeodbStore.GetRequesterCountryISOCode(r),
Succ: succ,
StatusCode: statusCode,
ForwardType: forwardType,
}
h.Parent.Option.StatisticCollector.RecordRequest(requestInfo)
}()
}
}

View File

@@ -1,53 +0,0 @@
package redirection
import (
"log"
"net/http"
"strings"
)
/*
handler.go
This script store the handlers use for handling
redirection request
*/
//Check if a request URL is a redirectable URI
func (t *RuleTable) IsRedirectable(r *http.Request) bool {
requestPath := r.Host + r.URL.Path
rr := t.MatchRedirectRule(requestPath)
return rr != nil
}
//Handle the redirect request, return after calling this function to prevent
//multiple write to the response writer
//Return the status code of the redirection handling
func (t *RuleTable) HandleRedirect(w http.ResponseWriter, r *http.Request) int {
requestPath := r.Host + r.URL.Path
rr := t.MatchRedirectRule(requestPath)
if rr != nil {
redirectTarget := rr.TargetURL
//Always pad a / at the back of the target URL
if redirectTarget[len(redirectTarget)-1:] != "/" {
redirectTarget += "/"
}
if rr.ForwardChildpath {
//Remove the first / in the path
redirectTarget += r.URL.Path[1:] + "?" + r.URL.RawQuery
}
if !strings.HasPrefix(redirectTarget, "http://") && !strings.HasPrefix(redirectTarget, "https://") {
redirectTarget = "http://" + redirectTarget
}
http.Redirect(w, r, redirectTarget, rr.StatusCode)
return rr.StatusCode
} else {
//Invalid usage
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("500 - Internal Server Error"))
log.Println("Target request URL do not have matching redirect rule. Check with IsRedirectable before calling HandleRedirect!")
return 500
}
}

View File

@@ -1,162 +0,0 @@
package redirection
import (
"encoding/json"
"log"
"os"
"path"
"path/filepath"
"strings"
"sync"
"imuslab.com/zoraxy/mod/utils"
)
type RuleTable struct {
configPath string //The location where the redirection rules is stored
rules sync.Map //Store the redirection rules for this reverse proxy instance
}
type RedirectRules struct {
RedirectURL string //The matching URL to redirect
TargetURL string //The destination redirection url
ForwardChildpath bool //Also redirect the pathname
StatusCode int //Status Code for redirection
}
func NewRuleTable(configPath string) (*RuleTable, error) {
thisRuleTable := RuleTable{
rules: sync.Map{},
configPath: configPath,
}
//Load all the rules from the config path
if !utils.FileExists(configPath) {
os.MkdirAll(configPath, 0775)
}
// Load all the *.json from the configPath
files, err := filepath.Glob(filepath.Join(configPath, "*.json"))
if err != nil {
return nil, err
}
// Parse the json content into RedirectRules
var rules []*RedirectRules
for _, file := range files {
b, err := os.ReadFile(file)
if err != nil {
continue
}
thisRule := RedirectRules{}
err = json.Unmarshal(b, &thisRule)
if err != nil {
continue
}
rules = append(rules, &thisRule)
}
//Map the rules into the sync map
for _, rule := range rules {
log.Println("Redirection rule added: " + rule.RedirectURL + " -> " + rule.TargetURL)
thisRuleTable.rules.Store(rule.RedirectURL, rule)
}
return &thisRuleTable, nil
}
func (t *RuleTable) AddRedirectRule(redirectURL string, destURL string, forwardPathname bool, statusCode int) error {
// Create a new RedirectRules object with the given parameters
newRule := &RedirectRules{
RedirectURL: redirectURL,
TargetURL: destURL,
ForwardChildpath: forwardPathname,
StatusCode: statusCode,
}
// Convert the redirectURL to a valid filename by replacing "/" with "-" and "." with "_"
filename := strings.ReplaceAll(strings.ReplaceAll(redirectURL, "/", "-"), ".", "_") + ".json"
// Create the full file path by joining the t.configPath with the filename
filepath := path.Join(t.configPath, filename)
// Create a new file for writing the JSON data
file, err := os.Create(filepath)
if err != nil {
log.Printf("Error creating file %s: %s", filepath, err)
return err
}
defer file.Close()
// Encode the RedirectRules object to JSON and write it to the file
err = json.NewEncoder(file).Encode(newRule)
if err != nil {
log.Printf("Error encoding JSON to file %s: %s", filepath, err)
return err
}
// Store the RedirectRules object in the sync.Map
t.rules.Store(redirectURL, newRule)
return nil
}
func (t *RuleTable) DeleteRedirectRule(redirectURL string) error {
// Convert the redirectURL to a valid filename by replacing "/" with "-" and "." with "_"
filename := strings.ReplaceAll(strings.ReplaceAll(redirectURL, "/", "-"), ".", "_") + ".json"
// Create the full file path by joining the t.configPath with the filename
filepath := path.Join(t.configPath, filename)
// Check if the file exists
if _, err := os.Stat(filepath); os.IsNotExist(err) {
return nil // File doesn't exist, nothing to delete
}
// Delete the file
if err := os.Remove(filepath); err != nil {
log.Printf("Error deleting file %s: %s", filepath, err)
return err
}
// Delete the key-value pair from the sync.Map
t.rules.Delete(redirectURL)
return nil
}
// Get a list of all the redirection rules
func (t *RuleTable) GetAllRedirectRules() []*RedirectRules {
rules := []*RedirectRules{}
t.rules.Range(func(key, value interface{}) bool {
r, ok := value.(*RedirectRules)
if ok {
rules = append(rules, r)
}
return true
})
return rules
}
// Check if a given request URL matched any of the redirection rule
func (t *RuleTable) MatchRedirectRule(requestedURL string) *RedirectRules {
// Iterate through all the keys in the rules map
var targetRedirectionRule *RedirectRules = nil
var maxMatch int = 0
t.rules.Range(func(key interface{}, value interface{}) bool {
// Check if the requested URL starts with the key as a prefix
if strings.HasPrefix(requestedURL, key.(string)) {
// This request URL matched the domain
if len(key.(string)) > maxMatch {
maxMatch = len(key.(string))
targetRedirectionRule = value.(*RedirectRules)
}
}
return true
})
return targetRedirectionRule
}

View File

@@ -1,85 +0,0 @@
package dynamicproxy
import (
"errors"
"net/http"
)
/*
Special.go
This script handle special routing rules
by external modules
*/
type RoutingRule struct {
ID string
MatchRule func(r *http.Request) bool
RoutingHandler http.Handler
Enabled bool
}
//Router functions
//Check if a routing rule exists given its id
func (router *Router) GetRoutingRuleById(rrid string) (*RoutingRule, error) {
for _, rr := range router.routingRules {
if rr.ID == rrid {
return rr, nil
}
}
return nil, errors.New("routing rule with given id not found")
}
//Add a routing rule to the router
func (router *Router) AddRoutingRules(rr *RoutingRule) error {
_, err := router.GetRoutingRuleById(rr.ID)
if err != nil {
//routing rule with given id already exists
return err
}
router.routingRules = append(router.routingRules, rr)
return nil
}
//Remove a routing rule from the router
func (router *Router) RemoveRoutingRule(rrid string) {
newRoutingRules := []*RoutingRule{}
for _, rr := range router.routingRules {
if rr.ID != rrid {
newRoutingRules = append(newRoutingRules, rr)
}
}
router.routingRules = newRoutingRules
}
//Get all routing rules
func (router *Router) GetAllRoutingRules() []*RoutingRule {
return router.routingRules
}
//Get the matching routing rule that describe this request.
//Return nil if no routing rule is match
func (router *Router) GetMatchingRoutingRule(r *http.Request) *RoutingRule {
for _, thisRr := range router.routingRules {
if thisRr.IsMatch(r) {
return thisRr
}
}
return nil
}
//Routing Rule functions
//Check if a request object match the
func (e *RoutingRule) IsMatch(r *http.Request) bool {
if !e.Enabled {
return false
}
return e.MatchRule(r)
}
func (e *RoutingRule) Route(w http.ResponseWriter, r *http.Request) {
e.RoutingHandler.ServeHTTP(w, r)
}

View File

@@ -1,44 +0,0 @@
package dynamicproxy
import (
"log"
"net/url"
"imuslab.com/zoraxy/mod/reverseproxy"
)
/*
Add an URL intoa custom subdomain service
*/
func (router *Router) AddSubdomainRoutingService(hostnameWithSubdomain string, domain string, requireTLS bool) error {
if domain[len(domain)-1:] == "/" {
domain = domain[:len(domain)-1]
}
webProxyEndpoint := domain
if requireTLS {
webProxyEndpoint = "https://" + webProxyEndpoint
} else {
webProxyEndpoint = "http://" + webProxyEndpoint
}
//Create a new proxy agent for this root
path, err := url.Parse(webProxyEndpoint)
if err != nil {
return err
}
proxy := reverseproxy.NewReverseProxy(path)
router.SubdomainEndpoint.Store(hostnameWithSubdomain, &SubdomainEndpoint{
MatchingDomain: hostnameWithSubdomain,
Domain: domain,
RequireTLS: requireTLS,
Proxy: proxy,
})
log.Println("Adding Subdomain Rule: ", hostnameWithSubdomain+" to "+domain)
return nil
}

View File

@@ -1,244 +0,0 @@
package geodb
import (
"net"
"net/http"
"strings"
"github.com/oschwald/geoip2-golang"
"imuslab.com/zoraxy/mod/database"
)
type Store struct {
Enabled bool
geodb *geoip2.Reader
sysdb *database.Database
}
type CountryInfo struct {
CountryIsoCode string
ContinetCode string
}
func NewGeoDb(sysdb *database.Database, dbfile string) (*Store, error) {
db, err := geoip2.Open(dbfile)
if err != nil {
return nil, err
}
err = sysdb.NewTable("blacklist-cn")
if err != nil {
return nil, err
}
err = sysdb.NewTable("blacklist-ip")
if err != nil {
return nil, err
}
err = sysdb.NewTable("blacklist")
if err != nil {
return nil, err
}
blacklistEnabled := false
sysdb.Read("blacklist", "enabled", &blacklistEnabled)
return &Store{
Enabled: blacklistEnabled,
geodb: db,
sysdb: sysdb,
}, nil
}
func (s *Store) ToggleBlacklist(enabled bool) {
s.sysdb.Write("blacklist", "enabled", enabled)
s.Enabled = enabled
}
func (s *Store) ResolveCountryCodeFromIP(ipstring string) (*CountryInfo, error) {
// If you are using strings that may be invalid, check that ip is not nil
ip := net.ParseIP(ipstring)
record, err := s.geodb.City(ip)
if err != nil {
return nil, err
}
return &CountryInfo{
record.Country.IsoCode,
record.Continent.Code,
}, nil
}
func (s *Store) Close() {
s.geodb.Close()
}
func (s *Store) AddCountryCodeToBlackList(countryCode string) {
countryCode = strings.ToLower(countryCode)
s.sysdb.Write("blacklist-cn", countryCode, true)
}
func (s *Store) RemoveCountryCodeFromBlackList(countryCode string) {
countryCode = strings.ToLower(countryCode)
s.sysdb.Delete("blacklist-cn", countryCode)
}
func (s *Store) IsCountryCodeBlacklisted(countryCode string) bool {
countryCode = strings.ToLower(countryCode)
var isBlacklisted bool = false
s.sysdb.Read("blacklist-cn", countryCode, &isBlacklisted)
return isBlacklisted
}
func (s *Store) GetAllBlacklistedCountryCode() []string {
bannedCountryCodes := []string{}
entries, err := s.sysdb.ListTable("blacklist-cn")
if err != nil {
return bannedCountryCodes
}
for _, keypairs := range entries {
ip := string(keypairs[0])
bannedCountryCodes = append(bannedCountryCodes, ip)
}
return bannedCountryCodes
}
func (s *Store) AddIPToBlackList(ipAddr string) {
s.sysdb.Write("blacklist-ip", ipAddr, true)
}
func (s *Store) RemoveIPFromBlackList(ipAddr string) {
s.sysdb.Delete("blacklist-ip", ipAddr)
}
func (s *Store) IsIPBlacklisted(ipAddr string) bool {
var isBlacklisted bool = false
s.sysdb.Read("blacklist-ip", ipAddr, &isBlacklisted)
if isBlacklisted {
return true
}
//Check for IP wildcard and CIRD rules
AllBlacklistedIps := s.GetAllBlacklistedIp()
for _, blacklistRule := range AllBlacklistedIps {
wildcardMatch := MatchIpWildcard(ipAddr, blacklistRule)
if wildcardMatch {
return true
}
cidrMatch := MatchIpCIDR(ipAddr, blacklistRule)
if cidrMatch {
return true
}
}
return false
}
func (s *Store) GetAllBlacklistedIp() []string {
bannedIps := []string{}
entries, err := s.sysdb.ListTable("blacklist-ip")
if err != nil {
return bannedIps
}
for _, keypairs := range entries {
ip := string(keypairs[0])
bannedIps = append(bannedIps, ip)
}
return bannedIps
}
//Check if a IP address is blacklisted, in either country or IP blacklist
func (s *Store) IsBlacklisted(ipAddr string) bool {
if !s.Enabled {
//Blacklist not enabled. Always return false
return false
}
if ipAddr == "" {
//Unable to get the target IP address
return false
}
countryCode, err := s.ResolveCountryCodeFromIP(ipAddr)
if err != nil {
return false
}
if s.IsCountryCodeBlacklisted(countryCode.CountryIsoCode) {
return true
}
if s.IsIPBlacklisted(ipAddr) {
return true
}
return false
}
func (s *Store) GetRequesterCountryISOCode(r *http.Request) string {
ipAddr := GetRequesterIP(r)
if ipAddr == "" {
return ""
}
countryCode, err := s.ResolveCountryCodeFromIP(ipAddr)
if err != nil {
return ""
}
return countryCode.CountryIsoCode
}
//Utilities function
func GetRequesterIP(r *http.Request) string {
ip := r.Header.Get("X-Forwarded-For")
if ip == "" {
ip = r.Header.Get("X-Real-IP")
if ip == "" {
ip = strings.Split(r.RemoteAddr, ":")[0]
}
}
return ip
}
//Match the IP address with a wildcard string
func MatchIpWildcard(ipAddress, wildcard string) bool {
// Split IP address and wildcard into octets
ipOctets := strings.Split(ipAddress, ".")
wildcardOctets := strings.Split(wildcard, ".")
// Check that both have 4 octets
if len(ipOctets) != 4 || len(wildcardOctets) != 4 {
return false
}
// Check each octet to see if it matches the wildcard or is an exact match
for i := 0; i < 4; i++ {
if wildcardOctets[i] == "*" {
continue
}
if ipOctets[i] != wildcardOctets[i] {
return false
}
}
return true
}
//Match ip address with CIDR
func MatchIpCIDR(ip string, cidr string) bool {
// parse the CIDR string
_, cidrnet, err := net.ParseCIDR(cidr)
if err != nil {
return false
}
// parse the IP address
ipAddr := net.ParseIP(ip)
// check if the IP address is within the CIDR range
return cidrnet.Contains(ipAddr)
}

View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2018-present tobychui
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,68 +0,0 @@
# Introduction
A minimalist proxy library for go, inspired by `net/http/httputil` and add support for HTTPS using HTTP Tunnel
Support cancels an in-flight request by closing it's connection
# Installation
```sh
go get github.com/cssivision/reverseproxy
```
# Usage
## A simple proxy
```go
package main
import (
"net/http"
"net/url"
"github.com/cssivision/reverseproxy"
)
func main() {
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path, err := url.Parse("https://github.com")
if err != nil {
panic(err)
return
}
proxy := reverseproxy.NewReverseProxy(path)
proxy.ServeHTTP(w, r)
}))
}
```
## Use as a proxy server
To use proxy server, you should set browser to use the proxy server as an HTTP proxy.
```go
package main
import (
"net/http"
"net/url"
"github.com/cssivision/reverseproxy"
)
func main() {
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path, err := url.Parse("http://" + r.Host)
if err != nil {
panic(err)
return
}
proxy := reverseproxy.NewReverseProxy(path)
proxy.ServeHTTP(w, r)
// Specific for HTTP and HTTPS
// if r.Method == "CONNECT" {
// proxy.ProxyHTTPS(w, r)
// } else {
// proxy.ProxyHTTP(w, r)
// }
}))
}
```

View File

@@ -1,405 +0,0 @@
package reverseproxy
import (
"errors"
"io"
"log"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
var onExitFlushLoop func()
const (
defaultTimeout = time.Minute * 5
)
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client, support http, also support https tunnel using http.hijacker
type ReverseProxy struct {
// Set the timeout of the proxy server, default is 5 minutes
Timeout time.Duration
// Director must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// Director must not access the provided Request
// after returning.
Director func(*http.Request)
// The transport used to perform proxy requests.
// default is http.DefaultTransport.
Transport http.RoundTripper
// FlushInterval specifies the flush interval
// to flush to the client while copying the
// response body. If zero, no periodic flushing is done.
FlushInterval time.Duration
// ErrorLog specifies an optional logger for errors
// that occur when attempting to proxy the request.
// If nil, logging goes to os.Stderr via the log package's
// standard logger.
ErrorLog *log.Logger
// ModifyResponse is an optional function that
// modifies the Response from the backend.
// If it returns an error, the proxy returns a StatusBadGateway error.
ModifyResponse func(*http.Response) error
Verbal bool
}
type requestCanceler interface {
CancelRequest(req *http.Request)
}
// NewReverseProxy returns a new ReverseProxy that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir. if the target's query is a=10
// and the incoming request's query is b=100, the target's request's query
// will be a=10&b=100.
// NewReverseProxy does not rewrite the Host header.
// To rewrite Host headers, use ReverseProxy directly with a custom
// Director policy.
func NewReverseProxy(target *url.URL) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
// If Host is empty, the Request.Write method uses
// the value of URL.Host.
// force use URL.Host
req.Host = req.URL.Host
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
if _, ok := req.Header["User-Agent"]; !ok {
req.Header.Set("User-Agent", "")
}
}
return &ReverseProxy{Director: director, Verbal: false}
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
//"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
//"Upgrade",
}
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
if p.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: p.FlushInterval,
done: make(chan bool),
}
go mlw.flushLoop()
defer mlw.stop()
dst = mlw
}
}
io.Copy(dst, src)
}
type writeFlusher interface {
io.Writer
http.Flusher
}
type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
mu sync.Mutex
done chan bool
}
func (m *maxLatencyWriter) Write(b []byte) (int, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.dst.Write(b)
}
func (m *maxLatencyWriter) flushLoop() {
t := time.NewTicker(m.latency)
defer t.Stop()
for {
select {
case <-m.done:
if onExitFlushLoop != nil {
onExitFlushLoop()
}
return
case <-t.C:
m.mu.Lock()
m.dst.Flush()
m.mu.Unlock()
}
}
}
func (m *maxLatencyWriter) stop() {
m.done <- true
}
func (p *ReverseProxy) logf(format string, args ...interface{}) {
if p.ErrorLog != nil {
p.ErrorLog.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
func removeHeaders(header http.Header) {
// Remove hop-by-hop headers listed in the "Connection" header.
if c := header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
header.Del(f)
}
}
}
// Remove hop-by-hop headers
for _, h := range hopHeaders {
if header.Get(h) != "" {
header.Del(h)
}
}
if header.Get("A-Upgrade") != "" {
header.Set("Upgrade", header.Get("A-Upgrade"))
header.Del("A-Upgrade")
}
}
func addXForwardedForHeader(req *http.Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
if prior, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
req.Header.Set("X-Forwarded-For", clientIP)
}
}
func (p *ReverseProxy) ProxyHTTP(rw http.ResponseWriter, req *http.Request) error {
transport := p.Transport
if transport == nil {
transport = http.DefaultTransport
}
outreq := new(http.Request)
// Shallow copies of maps, like header
*outreq = *req
if cn, ok := rw.(http.CloseNotifier); ok {
if requestCanceler, ok := transport.(requestCanceler); ok {
// After the Handler has returned, there is no guarantee
// that the channel receives a value, so to make sure
reqDone := make(chan struct{})
defer close(reqDone)
clientGone := cn.CloseNotify()
go func() {
select {
case <-clientGone:
requestCanceler.CancelRequest(outreq)
case <-reqDone:
}
}()
}
}
p.Director(outreq)
outreq.Close = false
// We may modify the header (shallow copied above), so we only copy it.
outreq.Header = make(http.Header)
copyHeader(outreq.Header, req.Header)
// Remove hop-by-hop headers listed in the "Connection" header, Remove hop-by-hop headers.
removeHeaders(outreq.Header)
// Add X-Forwarded-For Header.
addXForwardedForHeader(outreq)
res, err := transport.RoundTrip(outreq)
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
//rw.WriteHeader(http.StatusBadGateway)
return err
}
// Remove hop-by-hop headers listed in the "Connection" header of the response, Remove hop-by-hop headers.
removeHeaders(res.Header)
if p.ModifyResponse != nil {
if err := p.ModifyResponse(res); err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
//rw.WriteHeader(http.StatusBadGateway)
return err
}
}
// Copy header from response to client.
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response, Build it up from Trailer.
if len(res.Trailer) > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode)
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
if fl, ok := rw.(http.Flusher); ok {
fl.Flush()
}
}
p.copyResponse(rw, res.Body)
// close now, instead of defer, to populate res.Trailer
res.Body.Close()
copyHeader(rw.Header(), res.Trailer)
return nil
}
func (p *ReverseProxy) ProxyHTTPS(rw http.ResponseWriter, req *http.Request) error {
hij, ok := rw.(http.Hijacker)
if !ok {
p.logf("http server does not support hijacker")
return errors.New("http server does not support hijacker")
}
clientConn, _, err := hij.Hijack()
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
return err
}
proxyConn, err := net.Dial("tcp", req.URL.Host)
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
return err
}
// The returned net.Conn may have read or write deadlines
// already set, depending on the configuration of the
// Server, to set or clear those deadlines as needed
// we set timeout to 5 minutes
deadline := time.Now()
if p.Timeout == 0 {
deadline = deadline.Add(time.Minute * 5)
} else {
deadline = deadline.Add(p.Timeout)
}
err = clientConn.SetDeadline(deadline)
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
return err
}
err = proxyConn.SetDeadline(deadline)
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
return err
}
_, err = clientConn.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
if err != nil {
if p.Verbal {
p.logf("http: proxy error: %v", err)
}
return err
}
go func() {
io.Copy(clientConn, proxyConn)
clientConn.Close()
proxyConn.Close()
}()
io.Copy(proxyConn, clientConn)
proxyConn.Close()
clientConn.Close()
return nil
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) error {
if req.Method == "CONNECT" {
err := p.ProxyHTTPS(rw, req)
return err
} else {
err := p.ProxyHTTP(rw, req)
return err
}
}

View File

@@ -1,67 +0,0 @@
package statistic
import (
"encoding/json"
"net/http"
"imuslab.com/zoraxy/mod/utils"
)
/*
Handler.go
This script handles incoming request for loading the statistic of the day
*/
func (c *Collector) HandleTodayStatLoad(w http.ResponseWriter, r *http.Request) {
fast, err := utils.GetPara(r, "fast")
if err != nil {
fast = "false"
}
d := c.DailySummary
if fast == "true" {
//Only return the counter
exported := DailySummaryExport{
TotalRequest: d.TotalRequest,
ErrorRequest: d.ErrorRequest,
ValidRequest: d.ValidRequest,
}
js, _ := json.Marshal(exported)
utils.SendJSONResponse(w, string(js))
} else {
//Return everything
exported := DailySummaryExport{
TotalRequest: d.TotalRequest,
ErrorRequest: d.ErrorRequest,
ValidRequest: d.ValidRequest,
ForwardTypes: make(map[string]int),
RequestOrigin: make(map[string]int),
RequestClientIp: make(map[string]int),
}
// Export ForwardTypes sync.Map
d.ForwardTypes.Range(func(key, value interface{}) bool {
exported.ForwardTypes[key.(string)] = value.(int)
return true
})
// Export RequestOrigin sync.Map
d.RequestOrigin.Range(func(key, value interface{}) bool {
exported.RequestOrigin[key.(string)] = value.(int)
return true
})
// Export RequestClientIp sync.Map
d.RequestClientIp.Range(func(key, value interface{}) bool {
exported.RequestClientIp[key.(string)] = value.(int)
return true
})
js, _ := json.Marshal(exported)
utils.SendJSONResponse(w, string(js))
}
}

View File

@@ -1,186 +0,0 @@
package statistic
import (
"strings"
"sync"
"time"
"imuslab.com/zoraxy/mod/database"
)
/*
Statistic Package
This packet is designed to collection information
and store them for future analysis
*/
//Faststat, a interval summary for all collected data and avoid
//looping through every data everytime a overview is needed
type DailySummary struct {
TotalRequest int64 //Total request of the day
ErrorRequest int64 //Invalid request of the day, including error or not found
ValidRequest int64 //Valid request of the day
//Type counters
ForwardTypes *sync.Map //Map that hold the forward types
RequestOrigin *sync.Map //Map that hold [country ISO code]: visitor counter
RequestClientIp *sync.Map //Map that hold all unique request IPs
}
type RequestInfo struct {
IpAddr string
RequestOriginalCountryISOCode string
Succ bool
StatusCode int
ForwardType string
}
type CollectorOption struct {
Database *database.Database
}
type Collector struct {
rtdataStopChan chan bool
DailySummary *DailySummary
Option *CollectorOption
}
func NewStatisticCollector(option CollectorOption) (*Collector, error) {
option.Database.NewTable("stats")
//Create the collector object
thisCollector := Collector{
DailySummary: newDailySummary(),
Option: &option,
}
//Load the stat if exists for today
//This will exists if the program was forcefully restarted
year, month, day := time.Now().Date()
summary := thisCollector.LoadSummaryOfDay(year, month, day)
if summary != nil {
thisCollector.DailySummary = summary
}
//Schedule the realtime statistic clearing at midnight everyday
rtstatStopChan := thisCollector.ScheduleResetRealtimeStats()
thisCollector.rtdataStopChan = rtstatStopChan
return &thisCollector, nil
}
//Write the current in-memory summary to database file
func (c *Collector) SaveSummaryOfDay() {
//When it is called in 0:00am, make sure it is stored as yesterday key
t := time.Now().Add(-30 * time.Second)
summaryKey := t.Format("02_01_2006")
saveData := DailySummaryToExport(*c.DailySummary)
c.Option.Database.Write("stats", summaryKey, saveData)
}
//Load the summary of a day given
func (c *Collector) LoadSummaryOfDay(year int, month time.Month, day int) *DailySummary {
date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.Local)
summaryKey := date.Format("02_01_2006")
targetSummaryExport := DailySummaryExport{}
c.Option.Database.Read("stats", summaryKey, &targetSummaryExport)
targetSummary := DailySummaryExportToSummary(targetSummaryExport)
return &targetSummary
}
//This function gives the current slot in the 288- 5 minutes interval of the day
func (c *Collector) GetCurrentRealtimeStatIntervalId() int {
now := time.Now()
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local).Unix()
secondsSinceStartOfDay := now.Unix() - startOfDay
interval := secondsSinceStartOfDay / (5 * 60)
return int(interval)
}
func (c *Collector) Close() {
//Stop the ticker
c.rtdataStopChan <- true
//Write the buffered data into database
c.SaveSummaryOfDay()
}
//Main function to record all the inbound traffics
//Note that this function run in go routine and might have concurrent R/W issue
//Please make sure there is no racing paramters in this function
func (c *Collector) RecordRequest(ri RequestInfo) {
go func() {
c.DailySummary.TotalRequest++
if ri.Succ {
c.DailySummary.ValidRequest++
} else {
c.DailySummary.ErrorRequest++
}
//Store the request info into correct types of maps
ft, ok := c.DailySummary.ForwardTypes.Load(ri.ForwardType)
if !ok {
c.DailySummary.ForwardTypes.Store(ri.ForwardType, 1)
} else {
c.DailySummary.ForwardTypes.Store(ri.ForwardType, ft.(int)+1)
}
originISO := strings.ToLower(ri.RequestOriginalCountryISOCode)
fo, ok := c.DailySummary.RequestOrigin.Load(originISO)
if !ok {
c.DailySummary.RequestOrigin.Store(originISO, 1)
} else {
c.DailySummary.RequestOrigin.Store(originISO, fo.(int)+1)
}
fi, ok := c.DailySummary.RequestClientIp.Load(ri.IpAddr)
if !ok {
c.DailySummary.RequestClientIp.Store(ri.IpAddr, 1)
} else {
c.DailySummary.RequestClientIp.Store(ri.IpAddr, fi.(int)+1)
}
}()
}
//nightly task
func (c *Collector) ScheduleResetRealtimeStats() chan bool {
doneCh := make(chan bool)
go func() {
defer close(doneCh)
for {
// calculate duration until next midnight
now := time.Now()
// Get midnight of the next day in the local time zone
midnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
// Calculate the duration until midnight
duration := midnight.Sub(now)
select {
case <-time.After(duration):
// store daily summary to database and reset summary
c.SaveSummaryOfDay()
c.DailySummary = newDailySummary()
case <-doneCh:
// stop the routine
return
}
}
}()
return doneCh
}
func newDailySummary() *DailySummary {
return &DailySummary{
TotalRequest: 0,
ErrorRequest: 0,
ValidRequest: 0,
ForwardTypes: &sync.Map{},
RequestOrigin: &sync.Map{},
RequestClientIp: &sync.Map{},
}
}

View File

@@ -1,66 +0,0 @@
package statistic
import "sync"
type DailySummaryExport struct {
TotalRequest int64 //Total request of the day
ErrorRequest int64 //Invalid request of the day, including error or not found
ValidRequest int64 //Valid request of the day
ForwardTypes map[string]int
RequestOrigin map[string]int
RequestClientIp map[string]int
}
func DailySummaryToExport(summary DailySummary) DailySummaryExport {
export := DailySummaryExport{
TotalRequest: summary.TotalRequest,
ErrorRequest: summary.ErrorRequest,
ValidRequest: summary.ValidRequest,
ForwardTypes: make(map[string]int),
RequestOrigin: make(map[string]int),
RequestClientIp: make(map[string]int),
}
summary.ForwardTypes.Range(func(key, value interface{}) bool {
export.ForwardTypes[key.(string)] = value.(int)
return true
})
summary.RequestOrigin.Range(func(key, value interface{}) bool {
export.RequestOrigin[key.(string)] = value.(int)
return true
})
summary.RequestClientIp.Range(func(key, value interface{}) bool {
export.RequestClientIp[key.(string)] = value.(int)
return true
})
return export
}
func DailySummaryExportToSummary(export DailySummaryExport) DailySummary {
summary := DailySummary{
TotalRequest: export.TotalRequest,
ErrorRequest: export.ErrorRequest,
ValidRequest: export.ValidRequest,
ForwardTypes: &sync.Map{},
RequestOrigin: &sync.Map{},
RequestClientIp: &sync.Map{},
}
for k, v := range export.ForwardTypes {
summary.ForwardTypes.Store(k, v)
}
for k, v := range export.RequestOrigin {
summary.RequestOrigin.Store(k, v)
}
for k, v := range export.RequestClientIp {
summary.RequestClientIp.Store(k, v)
}
return summary
}

View File

@@ -1,60 +0,0 @@
package tlscert
import (
"path/filepath"
"strings"
)
//This remove the certificates in the list where either the
//public key or the private key is missing
func getCertPairs(certFiles []string) []string {
crtMap := make(map[string]bool)
keyMap := make(map[string]bool)
for _, filename := range certFiles {
if filepath.Ext(filename) == ".crt" {
crtMap[strings.TrimSuffix(filename, ".crt")] = true
} else if filepath.Ext(filename) == ".key" {
keyMap[strings.TrimSuffix(filename, ".key")] = true
}
}
var result []string
for domain := range crtMap {
if keyMap[domain] {
result = append(result, domain)
}
}
return result
}
//Get the cloest subdomain certificate from a list of domains
func matchClosestDomainCertificate(subdomain string, domains []string) string {
var matchingDomain string = ""
maxLength := 0
for _, domain := range domains {
if strings.HasSuffix(subdomain, "."+domain) && len(domain) > maxLength {
matchingDomain = domain
maxLength = len(domain)
}
}
return matchingDomain
}
//Check if a requesting domain is a subdomain of a given domain
func isSubdomain(subdomain, domain string) bool {
subdomainParts := strings.Split(subdomain, ".")
domainParts := strings.Split(domain, ".")
if len(subdomainParts) < len(domainParts) {
return false
}
for i := range domainParts {
if subdomainParts[len(subdomainParts)-1-i] != domainParts[len(domainParts)-1-i] {
return false
}
}
return true
}

View File

@@ -1,172 +0,0 @@
package tlscert
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"io"
"io/ioutil"
"log"
"os"
"path/filepath"
"strings"
"imuslab.com/zoraxy/mod/utils"
)
type Manager struct {
CertStore string
verbal bool
}
func NewManager(certStore string) (*Manager, error) {
if !utils.FileExists(certStore) {
os.MkdirAll(certStore, 0775)
}
thisManager := Manager{
CertStore: certStore,
verbal: true,
}
return &thisManager, nil
}
func (m *Manager) ListCertDomains() ([]string, error) {
filenames, err := m.ListCerts()
if err != nil {
return []string{}, err
}
//Remove certificates where there are missing public key or private key
filenames = getCertPairs(filenames)
return filenames, nil
}
func (m *Manager) ListCerts() ([]string, error) {
certs, err := ioutil.ReadDir(m.CertStore)
if err != nil {
return []string{}, err
}
filenames := make([]string, 0, len(certs))
for _, cert := range certs {
if !cert.IsDir() {
filenames = append(filenames, cert.Name())
}
}
return filenames, nil
}
func (m *Manager) GetCert(helloInfo *tls.ClientHelloInfo) (*tls.Certificate, error) {
//Check if the domain corrisponding cert exists
pubKey := "./system/localhost.crt"
priKey := "./system/localhost.key"
if utils.FileExists(filepath.Join(m.CertStore, helloInfo.ServerName+".crt")) && utils.FileExists(filepath.Join(m.CertStore, helloInfo.ServerName+".key")) {
pubKey = filepath.Join(m.CertStore, helloInfo.ServerName+".crt")
priKey = filepath.Join(m.CertStore, helloInfo.ServerName+".key")
} else {
domainCerts, _ := m.ListCertDomains()
cloestDomainCert := matchClosestDomainCertificate(helloInfo.ServerName, domainCerts)
if cloestDomainCert != "" {
//There is a matching parent domain for this subdomain. Use this instead.
pubKey = filepath.Join(m.CertStore, cloestDomainCert+".crt")
priKey = filepath.Join(m.CertStore, cloestDomainCert+".key")
} else if m.DefaultCertExists() {
//Use default.crt and default.key
pubKey = filepath.Join(m.CertStore, "default.crt")
priKey = filepath.Join(m.CertStore, "default.key")
if m.verbal {
log.Println("No matching certificate found. Serving with default")
}
} else {
if m.verbal {
log.Println("Matching certificate not found. Serving with default. Requesting server name: ", helloInfo.ServerName)
}
}
}
//Load the cert and serve it
cer, err := tls.LoadX509KeyPair(pubKey, priKey)
if err != nil {
log.Println(err)
return nil, nil
}
return &cer, nil
}
//Check if both the default cert public key and private key exists
func (m *Manager) DefaultCertExists() bool {
return utils.FileExists(filepath.Join(m.CertStore, "default.crt")) && utils.FileExists(filepath.Join(m.CertStore, "default.key"))
}
//Check if the default cert exists returning seperate results for pubkey and prikey
func (m *Manager) DefaultCertExistsSep() (bool, bool) {
return utils.FileExists(filepath.Join(m.CertStore, "default.crt")), utils.FileExists(filepath.Join(m.CertStore, "default.key"))
}
//Delete the cert if exists
func (m *Manager) RemoveCert(domain string) error {
pubKey := filepath.Join(m.CertStore, domain+".crt")
priKey := filepath.Join(m.CertStore, domain+".key")
if utils.FileExists(pubKey) {
err := os.Remove(pubKey)
if err != nil {
return err
}
}
if utils.FileExists(priKey) {
err := os.Remove(priKey)
if err != nil {
return err
}
}
return nil
}
//Check if the given file is a valid TLS file
func IsValidTLSFile(file io.Reader) bool {
// Read the contents of the uploaded file
contents, err := io.ReadAll(file)
if err != nil {
// Handle the error
return false
}
// Parse the contents of the file as a PEM-encoded certificate or key
block, _ := pem.Decode(contents)
if block == nil {
// The file is not a valid PEM-encoded certificate or key
return false
}
// Parse the certificate or key
if strings.Contains(block.Type, "CERTIFICATE") {
// The file contains a certificate
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
// Handle the error
return false
}
// Check if the certificate is a valid TLS/SSL certificate
return cert.IsCA == false && cert.KeyUsage&x509.KeyUsageDigitalSignature != 0 && cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0
} else if strings.Contains(block.Type, "PRIVATE KEY") {
// The file contains a private key
_, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
// Handle the error
return false
}
return true
} else {
return false
}
}

View File

@@ -1,127 +0,0 @@
package upnp
import (
"errors"
"log"
"sync"
"time"
"gitlab.com/NebulousLabs/go-upnp"
)
/*
uPNP Module
This module handles uPNP Connections to the gateway router and create a port forward entry
for the host system at the given port (set with -port paramter)
*/
type UPnPClient struct {
Connection *upnp.IGD //UPnP conenction object
ExternalIP string //Storage of external IP address
RequiredPorts []int //All the required ports will be recored
PolicyNames sync.Map //Name for the required port nubmer
}
func NewUPNPClient() (*UPnPClient, error) {
//Create uPNP forwarding in the NAT router
log.Println("Discovering UPnP router in Local Area Network...")
d, err := upnp.Discover()
if err != nil {
return &UPnPClient{}, err
}
// discover external IP
ip, err := d.ExternalIP()
if err != nil {
return &UPnPClient{}, err
}
//Create the final obejcts
newUPnPObject := &UPnPClient{
Connection: d,
ExternalIP: ip,
RequiredPorts: []int{},
}
return newUPnPObject, nil
}
func (u *UPnPClient) ForwardPort(portNumber int, ruleName string) error {
log.Println("UPnP forwarding new port: ", portNumber, "for "+ruleName+" service")
//Check if port already forwarded
_, ok := u.PolicyNames.Load(portNumber)
if ok {
//Port already forward. Ignore this request
return errors.New("Port already forwarded")
}
// forward a port
err := u.Connection.Forward(uint16(portNumber), ruleName)
if err != nil {
return err
}
u.RequiredPorts = append(u.RequiredPorts, portNumber)
u.PolicyNames.Store(portNumber, ruleName)
return nil
}
func (u *UPnPClient) ClosePort(portNumber int) error {
//Check if port is opened
portOpened := false
newRequiredPort := []int{}
for _, thisPort := range u.RequiredPorts {
if thisPort != portNumber {
newRequiredPort = append(newRequiredPort, thisPort)
} else {
portOpened = true
}
}
if portOpened {
//Update the port list
u.RequiredPorts = newRequiredPort
// Close the port
log.Println("Closing UPnP Port Forward: ", portNumber)
err := u.Connection.Clear(uint16(portNumber))
//Delete the name registry
u.PolicyNames.Delete(portNumber)
if err != nil {
log.Println(err)
return err
}
}
return nil
}
// Renew forward rules, prevent router lease time from flushing the Upnp config
func (u *UPnPClient) RenewForwardRules() {
portsToRenew := u.RequiredPorts
for _, thisPort := range portsToRenew {
ruleName, ok := u.PolicyNames.Load(thisPort)
if !ok {
continue
}
u.ClosePort(thisPort)
time.Sleep(100 * time.Millisecond)
u.ForwardPort(thisPort, ruleName.(string))
}
log.Println("UPnP Port Forward rule renew completed")
}
func (u *UPnPClient) Close() {
//Shutdown the default UPnP Object
if u != nil {
for _, portNumber := range u.RequiredPorts {
err := u.Connection.Clear(uint16(portNumber))
if err != nil {
log.Println(err)
}
}
}
}

View File

@@ -1,227 +0,0 @@
package uptime
import (
"encoding/json"
"log"
"net/http"
"time"
"imuslab.com/zoraxy/mod/utils"
)
type Record struct {
Timestamp int64
ID string
Name string
URL string
Protocol string
Online bool
StatusCode int
Latency int64
}
type Target struct {
ID string
Name string
URL string
Protocol string
}
type Config struct {
Targets []*Target
Interval int
MaxRecordsStore int
}
type Monitor struct {
Config *Config
OnlineStatusLog map[string][]*Record
}
// Default configs
var exampleTarget = Target{
ID: "example",
Name: "Example",
URL: "example.com",
Protocol: "https",
}
//Create a new uptime monitor
func NewUptimeMonitor(config *Config) (*Monitor, error) {
//Create new monitor object
thisMonitor := Monitor{
Config: config,
OnlineStatusLog: map[string][]*Record{},
}
//Start the endpoint listener
ticker := time.NewTicker(time.Duration(config.Interval) * time.Second)
done := make(chan bool)
//Start the uptime check once first before entering loop
thisMonitor.ExecuteUptimeCheck()
go func() {
for {
select {
case <-done:
return
case t := <-ticker.C:
log.Println("Uptime updated - ", t.Unix())
thisMonitor.ExecuteUptimeCheck()
}
}
}()
return &thisMonitor, nil
}
func (m *Monitor) ExecuteUptimeCheck() {
for _, target := range m.Config.Targets {
//For each target to check online, do the following
var thisRecord Record
if target.Protocol == "http" || target.Protocol == "https" {
online, laterncy, statusCode := getWebsiteStatusWithLatency(target.URL)
thisRecord = Record{
Timestamp: time.Now().Unix(),
ID: target.ID,
Name: target.Name,
URL: target.URL,
Protocol: target.Protocol,
Online: online,
StatusCode: statusCode,
Latency: laterncy,
}
//fmt.Println(thisRecord)
} else {
log.Println("Unknown protocol: " + target.Protocol + ". Skipping")
continue
}
thisRecords, ok := m.OnlineStatusLog[target.ID]
if !ok {
//First record. Create the array
m.OnlineStatusLog[target.ID] = []*Record{&thisRecord}
} else {
//Append to the previous record
thisRecords = append(thisRecords, &thisRecord)
//Check if the record is longer than the logged record. If yes, clear out the old records
if len(thisRecords) > m.Config.MaxRecordsStore {
thisRecords = thisRecords[1:]
}
m.OnlineStatusLog[target.ID] = thisRecords
}
}
//TODO: Write results to db
}
func (m *Monitor) AddTargetToMonitor(target *Target) {
// Add target to Config
m.Config.Targets = append(m.Config.Targets, target)
// Add target to OnlineStatusLog
m.OnlineStatusLog[target.ID] = []*Record{}
}
func (m *Monitor) RemoveTargetFromMonitor(targetId string) {
// Remove target from Config
for i, target := range m.Config.Targets {
if target.ID == targetId {
m.Config.Targets = append(m.Config.Targets[:i], m.Config.Targets[i+1:]...)
break
}
}
// Remove target from OnlineStatusLog
delete(m.OnlineStatusLog, targetId)
}
//Scan the config target. If a target exists in m.OnlineStatusLog no longer
//exists in m.Monitor.Config.Targets, it remove it from the log as well.
func (m *Monitor) CleanRecords() {
// Create a set of IDs for all targets in the config
targetIDs := make(map[string]bool)
for _, target := range m.Config.Targets {
targetIDs[target.ID] = true
}
// Iterate over all log entries and remove any that have a target ID that
// is not in the set of current target IDs
newStatusLog := m.OnlineStatusLog
for id, _ := range m.OnlineStatusLog {
_, idExistsInTargets := targetIDs[id]
if !idExistsInTargets {
delete(newStatusLog, id)
}
}
m.OnlineStatusLog = newStatusLog
}
/*
Web Interface Handler
*/
func (m *Monitor) HandleUptimeLogRead(w http.ResponseWriter, r *http.Request) {
id, _ := utils.GetPara(r, "id")
if id == "" {
js, _ := json.Marshal(m.OnlineStatusLog)
w.Header().Set("Content-Type", "application/json")
w.Write(js)
} else {
//Check if that id exists
log, ok := m.OnlineStatusLog[id]
if !ok {
http.NotFound(w, r)
return
}
js, _ := json.MarshalIndent(log, "", " ")
w.Header().Set("Content-Type", "application/json")
w.Write(js)
}
}
/*
Utilities
*/
// Get website stauts with latency given URL, return is conn succ and its latency and status code
func getWebsiteStatusWithLatency(url string) (bool, int64, int) {
start := time.Now().UnixNano() / int64(time.Millisecond)
statusCode, err := getWebsiteStatus(url)
end := time.Now().UnixNano() / int64(time.Millisecond)
if err != nil {
log.Println(err.Error())
return false, 0, 0
} else {
diff := end - start
succ := false
if statusCode >= 200 && statusCode < 300 {
//OK
succ = true
} else if statusCode >= 300 && statusCode < 400 {
//Redirection code
succ = true
} else {
succ = false
}
return succ, diff, statusCode
}
}
func getWebsiteStatus(url string) (int, error) {
resp, err := http.Get(url)
if err != nil {
return 0, err
}
status_code := resp.StatusCode
return status_code, nil
}

View File

@@ -1,16 +0,0 @@
package utils
import "strconv"
func StringToInt64(number string) (int64, error) {
i, err := strconv.ParseInt(number, 10, 64)
if err != nil {
return -1, err
}
return i, nil
}
func Int64ToString(number int64) string {
convedNumber := strconv.FormatInt(number, 10)
return convedNumber
}

View File

@@ -1,19 +0,0 @@
package utils
import (
"net/http"
)
/*
Web Template Generator
This is the main system core module that perform function similar to what PHP did.
To replace part of the content of any file, use {{paramter}} to replace it.
*/
func SendHTMLResponse(w http.ResponseWriter, msg string) {
w.Header().Set("Content-Type", "text/html")
w.Write([]byte(msg))
}

View File

@@ -1,175 +0,0 @@
package utils
import (
"bufio"
"encoding/base64"
"errors"
"io"
"log"
"net/http"
"os"
"strings"
"time"
)
/*
Common
Some commonly used functions in ArozOS
*/
// Response related
func SendTextResponse(w http.ResponseWriter, msg string) {
w.Write([]byte(msg))
}
// Send JSON response, with an extra json header
func SendJSONResponse(w http.ResponseWriter, json string) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(json))
}
func SendErrorResponse(w http.ResponseWriter, errMsg string) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{\"error\":\"" + errMsg + "\"}"))
}
func SendOK(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("\"OK\""))
}
/*
The paramter move function (mv)
You can find similar things in the PHP version of ArOZ Online Beta. You need to pass in
r (HTTP Request Object)
getParamter (string, aka $_GET['This string])
Will return
Paramter string (if any)
Error (if error)
*/
/*
func Mv(r *http.Request, getParamter string, postMode bool) (string, error) {
if postMode == false {
//Access the paramter via GET
keys, ok := r.URL.Query()[getParamter]
if !ok || len(keys[0]) < 1 {
//log.Println("Url Param " + getParamter +" is missing")
return "", errors.New("GET paramter " + getParamter + " not found or it is empty")
}
// Query()["key"] will return an array of items,
// we only want the single item.
key := keys[0]
return string(key), nil
} else {
//Access the parameter via POST
r.ParseForm()
x := r.Form.Get(getParamter)
if len(x) == 0 || x == "" {
return "", errors.New("POST paramter " + getParamter + " not found or it is empty")
}
return string(x), nil
}
}
*/
// Get GET parameter
func GetPara(r *http.Request, key string) (string, error) {
keys, ok := r.URL.Query()[key]
if !ok || len(keys[0]) < 1 {
return "", errors.New("invalid " + key + " given")
} else {
return keys[0], nil
}
}
// Get POST paramter
func PostPara(r *http.Request, key string) (string, error) {
r.ParseForm()
x := r.Form.Get(key)
if x == "" {
return "", errors.New("invalid " + key + " given")
} else {
return x, nil
}
}
func FileExists(filename string) bool {
_, err := os.Stat(filename)
if os.IsNotExist(err) {
return false
}
return true
}
func IsDir(path string) bool {
if FileExists(path) == false {
return false
}
fi, err := os.Stat(path)
if err != nil {
log.Fatal(err)
return false
}
switch mode := fi.Mode(); {
case mode.IsDir():
return true
case mode.IsRegular():
return false
}
return false
}
func TimeToString(targetTime time.Time) string {
return targetTime.Format("2006-01-02 15:04:05")
}
func LoadImageAsBase64(filepath string) (string, error) {
if !FileExists(filepath) {
return "", errors.New("File not exists")
}
f, _ := os.Open(filepath)
reader := bufio.NewReader(f)
content, _ := io.ReadAll(reader)
encoded := base64.StdEncoding.EncodeToString(content)
return string(encoded), nil
}
// Use for redirections
func ConstructRelativePathFromRequestURL(requestURI string, redirectionLocation string) string {
if strings.Count(requestURI, "/") == 1 {
//Already root level
return redirectionLocation
}
for i := 0; i < strings.Count(requestURI, "/")-1; i++ {
redirectionLocation = "../" + redirectionLocation
}
return redirectionLocation
}
// Check if given string in a given slice
func StringInArray(arr []string, str string) bool {
for _, a := range arr {
if a == str {
return true
}
}
return false
}
func StringInArrayIgnoreCase(arr []string, str string) bool {
smallArray := []string{}
for _, item := range arr {
smallArray = append(smallArray, strings.ToLower(item))
}
return StringInArray(smallArray, strings.ToLower(str))
}

View File

@@ -1,20 +0,0 @@
The MIT License (MIT)
Copyright (c) 2014 Koding, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -1,54 +0,0 @@
# WebsocketProxy [![GoDoc](https://godoc.org/github.com/koding/websocketproxy?status.svg)](https://godoc.org/github.com/koding/websocketproxy) [![Build Status](https://travis-ci.org/koding/websocketproxy.svg)](https://travis-ci.org/koding/websocketproxy)
WebsocketProxy is an http.Handler interface build on top of
[gorilla/websocket](https://github.com/gorilla/websocket) that you can plug
into your existing Go webserver to provide WebSocket reverse proxy.
## Install
```bash
go get github.com/koding/websocketproxy
```
## Example
Below is a simple server that proxies to the given backend URL
```go
package main
import (
"flag"
"net/http"
"net/url"
"github.com/koding/websocketproxy"
)
var (
flagBackend = flag.String("backend", "", "Backend URL for proxying")
)
func main() {
u, err := url.Parse(*flagBackend)
if err != nil {
log.Fatalln(err)
}
err = http.ListenAndServe(":80", websocketproxy.NewProxy(u))
if err != nil {
log.Fatalln(err)
}
}
```
Save it as `proxy.go` and run as:
```bash
go run proxy.go -backend ws://example.com:3000
```
Now all incoming WebSocket requests coming to this server will be proxied to
`ws://example.com:3000`

View File

@@ -1,239 +0,0 @@
// Package websocketproxy is a reverse proxy for WebSocket connections.
package websocketproxy
import (
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"strings"
"github.com/gorilla/websocket"
)
var (
// DefaultUpgrader specifies the parameters for upgrading an HTTP
// connection to a WebSocket connection.
DefaultUpgrader = &websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
// DefaultDialer is a dialer with all fields set to the default zero values.
DefaultDialer = websocket.DefaultDialer
)
// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket
// connection and proxies it to another server.
type WebsocketProxy struct {
// Director, if non-nil, is a function that may copy additional request
// headers from the incoming WebSocket connection into the output headers
// which will be forwarded to another server.
Director func(incoming *http.Request, out http.Header)
// Backend returns the backend URL which the proxy uses to reverse proxy
// the incoming WebSocket connection. Request is the initial incoming and
// unmodified request.
Backend func(*http.Request) *url.URL
// Upgrader specifies the parameters for upgrading a incoming HTTP
// connection to a WebSocket connection. If nil, DefaultUpgrader is used.
Upgrader *websocket.Upgrader
// Dialer contains options for connecting to the backend WebSocket server.
// If nil, DefaultDialer is used.
Dialer *websocket.Dialer
Verbal bool
}
// ProxyHandler returns a new http.Handler interface that reverse proxies the
// request to the given target.
func ProxyHandler(target *url.URL) http.Handler { return NewProxy(target) }
// NewProxy returns a new Websocket reverse proxy that rewrites the
// URL's to the scheme, host and base path provider in target.
func NewProxy(target *url.URL) *WebsocketProxy {
backend := func(r *http.Request) *url.URL {
// Shallow copy
u := *target
u.Fragment = r.URL.Fragment
u.Path = r.URL.Path
u.RawQuery = r.URL.RawQuery
return &u
}
return &WebsocketProxy{Backend: backend, Verbal: false}
}
// ServeHTTP implements the http.Handler that proxies WebSocket connections.
func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if w.Backend == nil {
log.Println("websocketproxy: backend function is not defined")
http.Error(rw, "internal server error (code: 1)", http.StatusInternalServerError)
return
}
backendURL := w.Backend(req)
if backendURL == nil {
log.Println("websocketproxy: backend URL is nil")
http.Error(rw, "internal server error (code: 2)", http.StatusInternalServerError)
return
}
dialer := w.Dialer
if w.Dialer == nil {
dialer = DefaultDialer
}
// Pass headers from the incoming request to the dialer to forward them to
// the final destinations.
requestHeader := http.Header{}
if origin := req.Header.Get("Origin"); origin != "" {
requestHeader.Add("Origin", origin)
}
for _, prot := range req.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] {
requestHeader.Add("Sec-WebSocket-Protocol", prot)
}
for _, cookie := range req.Header[http.CanonicalHeaderKey("Cookie")] {
requestHeader.Add("Cookie", cookie)
}
if req.Host != "" {
requestHeader.Set("Host", req.Host)
}
// Pass X-Forwarded-For headers too, code below is a part of
// httputil.ReverseProxy. See http://en.wikipedia.org/wiki/X-Forwarded-For
// for more information
// TODO: use RFC7239 http://tools.ietf.org/html/rfc7239
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
if prior, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
requestHeader.Set("X-Forwarded-For", clientIP)
}
// Set the originating protocol of the incoming HTTP request. The SSL might
// be terminated on our site and because we doing proxy adding this would
// be helpful for applications on the backend.
requestHeader.Set("X-Forwarded-Proto", "http")
if req.TLS != nil {
requestHeader.Set("X-Forwarded-Proto", "https")
}
// Enable the director to copy any additional headers it desires for
// forwarding to the remote server.
if w.Director != nil {
w.Director(req, requestHeader)
}
// Connect to the backend URL, also pass the headers we get from the requst
// together with the Forwarded headers we prepared above.
// TODO: support multiplexing on the same backend connection instead of
// opening a new TCP connection time for each request. This should be
// optional:
// http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-01
connBackend, resp, err := dialer.Dial(backendURL.String(), requestHeader)
if err != nil {
log.Printf("websocketproxy: couldn't dial to remote backend url %s", err)
if resp != nil {
// If the WebSocket handshake fails, ErrBadHandshake is returned
// along with a non-nil *http.Response so that callers can handle
// redirects, authentication, etcetera.
if err := copyResponse(rw, resp); err != nil {
log.Printf("websocketproxy: couldn't write response after failed remote backend handshake: %s", err)
}
} else {
http.Error(rw, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
}
return
}
defer connBackend.Close()
upgrader := w.Upgrader
if w.Upgrader == nil {
upgrader = DefaultUpgrader
}
// Only pass those headers to the upgrader.
upgradeHeader := http.Header{}
if hdr := resp.Header.Get("Sec-Websocket-Protocol"); hdr != "" {
upgradeHeader.Set("Sec-Websocket-Protocol", hdr)
}
if hdr := resp.Header.Get("Set-Cookie"); hdr != "" {
upgradeHeader.Set("Set-Cookie", hdr)
}
// Now upgrade the existing incoming request to a WebSocket connection.
// Also pass the header that we gathered from the Dial handshake.
connPub, err := upgrader.Upgrade(rw, req, upgradeHeader)
if err != nil {
log.Printf("websocketproxy: couldn't upgrade %s", err)
return
}
defer connPub.Close()
errClient := make(chan error, 1)
errBackend := make(chan error, 1)
replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) {
for {
msgType, msg, err := src.ReadMessage()
if err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
if e, ok := err.(*websocket.CloseError); ok {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
}
}
errc <- err
dst.WriteMessage(websocket.CloseMessage, m)
break
}
err = dst.WriteMessage(msgType, msg)
if err != nil {
errc <- err
break
}
}
}
go replicateWebsocketConn(connPub, connBackend, errClient)
go replicateWebsocketConn(connBackend, connPub, errBackend)
var message string
select {
case err = <-errClient:
message = "websocketproxy: Error when copying from backend to client: %v"
case err = <-errBackend:
message = "websocketproxy: Error when copying from client to backend: %v"
}
if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
if w.Verbal {
//Only print message on verbal mode
log.Printf(message, err)
}
}
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func copyResponse(rw http.ResponseWriter, resp *http.Response) error {
copyHeader(rw.Header(), resp.Header)
rw.WriteHeader(resp.StatusCode)
defer resp.Body.Close()
_, err := io.Copy(rw, resp.Body)
return err
}

View File

@@ -1,130 +0,0 @@
package websocketproxy
import (
"log"
"net/http"
"net/url"
"testing"
"time"
"github.com/gorilla/websocket"
)
var (
serverURL = "ws://127.0.0.1:7777"
backendURL = "ws://127.0.0.1:8888"
)
func TestProxy(t *testing.T) {
// websocket proxy
supportedSubProtocols := []string{"test-protocol"}
upgrader := &websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
Subprotocols: supportedSubProtocols,
}
u, _ := url.Parse(backendURL)
proxy := NewProxy(u)
proxy.Upgrader = upgrader
mux := http.NewServeMux()
mux.Handle("/proxy", proxy)
go func() {
if err := http.ListenAndServe(":7777", mux); err != nil {
t.Fatal("ListenAndServe: ", err)
}
}()
time.Sleep(time.Millisecond * 100)
// backend echo server
go func() {
mux2 := http.NewServeMux()
mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// Don't upgrade if original host header isn't preserved
if r.Host != "127.0.0.1:7777" {
log.Printf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", r.Host)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
}
messageType, p, err := conn.ReadMessage()
if err != nil {
return
}
if err = conn.WriteMessage(messageType, p); err != nil {
return
}
})
err := http.ListenAndServe(":8888", mux2)
if err != nil {
t.Fatal("ListenAndServe: ", err)
}
}()
time.Sleep(time.Millisecond * 100)
// let's us define two subprotocols, only one is supported by the server
clientSubProtocols := []string{"test-protocol", "test-notsupported"}
h := http.Header{}
for _, subprot := range clientSubProtocols {
h.Add("Sec-WebSocket-Protocol", subprot)
}
// frontend server, dial now our proxy, which will reverse proxy our
// message to the backend websocket server.
conn, resp, err := websocket.DefaultDialer.Dial(serverURL+"/proxy", h)
if err != nil {
t.Fatal(err)
}
// check if the server really accepted only the first one
in := func(desired string) bool {
for _, prot := range resp.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] {
if desired == prot {
return true
}
}
return false
}
if !in("test-protocol") {
t.Error("test-protocol should be available")
}
if in("test-notsupported") {
t.Error("test-notsupported should be not recevied from the server.")
}
// now write a message and send it to the backend server (which goes trough
// proxy..)
msg := "hello kite"
err = conn.WriteMessage(websocket.TextMessage, []byte(msg))
if err != nil {
t.Error(err)
}
messageType, p, err := conn.ReadMessage()
if err != nil {
t.Error(err)
}
if messageType != websocket.TextMessage {
t.Error("incoming message type is not Text")
}
if msg != string(p) {
t.Errorf("expecting: %s, got: %s", msg, string(p))
}
}