Optimized memory usage and root routing

+ Added unset subdomain custom redirection feature #46
+ Optimized memory usage by space time tradeoff in geoip lookup to fix #52
+ Replaced all stori/go.uuid to google/uuid for security reasons #55
This commit is contained in:
Toby Chui
2023-08-27 10:18:49 +08:00
parent 4f7f60188f
commit 73ab9ca778
22 changed files with 9701 additions and 212 deletions

View File

@@ -20,13 +20,16 @@ type Store struct {
WhitelistEnabled bool
geodb [][]string //Parsed geodb list
geodbIpv6 [][]string //Parsed geodb list for ipv6
geotrie *trie
geotrieIpv6 *trie
geotrie *trie
geotrieIpv6 *trie
//geoipCache sync.Map
sysdb *database.Database
option *StoreOptions
}
sysdb *database.Database
type StoreOptions struct {
AllowSlowIpv4LookUp bool
AllowSloeIpv6Lookup bool
}
type CountryInfo struct {
@@ -34,7 +37,7 @@ type CountryInfo struct {
ContinetCode string
}
func NewGeoDb(sysdb *database.Database) (*Store, error) {
func NewGeoDb(sysdb *database.Database, option *StoreOptions) (*Store, error) {
parsedGeoData, err := parseCSV(geoipv4)
if err != nil {
return nil, err
@@ -79,14 +82,25 @@ func NewGeoDb(sysdb *database.Database) (*Store, error) {
log.Println("Database pointer set to nil: Entering debug mode")
}
var ipv4Trie *trie
if !option.AllowSlowIpv4LookUp {
ipv4Trie = constrctTrieTree(parsedGeoData)
}
var ipv6Trie *trie
if !option.AllowSloeIpv6Lookup {
ipv6Trie = constrctTrieTree(parsedGeoDataIpv6)
}
return &Store{
BlacklistEnabled: blacklistEnabled,
WhitelistEnabled: whitelistEnabled,
geodb: parsedGeoData,
geotrie: constrctTrieTree(parsedGeoData),
geotrie: ipv4Trie,
geodbIpv6: parsedGeoDataIpv6,
geotrieIpv6: constrctTrieTree(parsedGeoDataIpv6),
geotrieIpv6: ipv6Trie,
sysdb: sysdb,
option: option,
}, nil
}
@@ -106,6 +120,7 @@ func (s *Store) ResolveCountryCodeFromIP(ipstring string) (*CountryInfo, error)
CountryIsoCode: cc,
ContinetCode: "",
}, nil
}
func (s *Store) Close() {

View File

@@ -41,7 +41,10 @@ func TestTrieConstruct(t *testing.T) {
func TestResolveCountryCodeFromIP(t *testing.T) {
// Create a new store
store, err := geodb.NewGeoDb(nil)
store, err := geodb.NewGeoDb(nil, &geodb.StoreOptions{
false,
false,
})
if err != nil {
t.Errorf("error creating store: %v", err)
return

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"encoding/csv"
"io"
"net"
"strings"
)
@@ -26,9 +25,17 @@ func (s *Store) search(ip string) string {
//Search in geotrie tree
cc := ""
if IsIPv6(ip) {
cc = s.geotrieIpv6.search(ip)
if s.geotrieIpv6 == nil {
cc = s.slowSearchIpv6(ip)
} else {
cc = s.geotrieIpv6.search(ip)
}
} else {
cc = s.geotrie.search(ip)
if s.geotrie == nil {
cc = s.slowSearchIpv4(ip)
} else {
cc = s.geotrie.search(ip)
}
}
/*
@@ -69,27 +76,3 @@ func parseCSV(content []byte) ([][]string, error) {
}
return records, nil
}
// Check if a ip string is within the range of two others
func isIPInRange(ip, start, end string) bool {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
return false
}
startAddr := net.ParseIP(start)
if startAddr == nil {
return false
}
endAddr := net.ParseIP(end)
if endAddr == nil {
return false
}
if ipAddr.To4() == nil || startAddr.To4() == nil || endAddr.To4() == nil {
return false
}
return bytes.Compare(ipAddr.To4(), startAddr.To4()) >= 0 && bytes.Compare(ipAddr.To4(), endAddr.To4()) <= 0
}

View File

@@ -0,0 +1,81 @@
package geodb
import (
"errors"
"math/big"
"net"
)
/*
slowSearch.go
This script implement the slow search method for ip to country code
lookup. If you have the memory allocation for near O(1) lookup,
you should not be using slow search mode.
*/
func ipv4ToUInt32(ip net.IP) uint32 {
ip = ip.To4()
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}
func isIPv4InRange(startIP, endIP, testIP string) (bool, error) {
start := net.ParseIP(startIP)
end := net.ParseIP(endIP)
test := net.ParseIP(testIP)
if start == nil || end == nil || test == nil {
return false, errors.New("invalid IP address format")
}
startUint := ipv4ToUInt32(start)
endUint := ipv4ToUInt32(end)
testUint := ipv4ToUInt32(test)
return testUint >= startUint && testUint <= endUint, nil
}
func isIPv6InRange(startIP, endIP, testIP string) (bool, error) {
start := net.ParseIP(startIP)
end := net.ParseIP(endIP)
test := net.ParseIP(testIP)
if start == nil || end == nil || test == nil {
return false, errors.New("invalid IP address format")
}
startInt := new(big.Int).SetBytes(start.To16())
endInt := new(big.Int).SetBytes(end.To16())
testInt := new(big.Int).SetBytes(test.To16())
return testInt.Cmp(startInt) >= 0 && testInt.Cmp(endInt) <= 0, nil
}
// Slow country code lookup for
func (s *Store) slowSearchIpv4(ipAddr string) string {
for _, ipRange := range s.geodb {
startIp := ipRange[0]
endIp := ipRange[1]
cc := ipRange[2]
inRange, _ := isIPv4InRange(startIp, endIp, ipAddr)
if inRange {
return cc
}
}
return ""
}
func (s *Store) slowSearchIpv6(ipAddr string) string {
for _, ipRange := range s.geodbIpv6 {
startIp := ipRange[0]
endIp := ipRange[1]
cc := ipRange[2]
inRange, _ := isIPv6InRange(startIp, endIp, ipAddr)
if inRange {
return cc
}
}
return ""
}

View File

@@ -1,15 +1,12 @@
package geodb
import (
"fmt"
"math"
"net"
"strconv"
"strings"
)
type trie_Node struct {
childrens [2]*trie_Node
ends bool
cc string
}
@@ -18,7 +15,7 @@ type trie struct {
root *trie_Node
}
func ipToBitString(ip string) string {
func ipToBytes(ip string) []byte {
// Parse the IP address string into a net.IP object
parsedIP := net.ParseIP(ip)
@@ -29,49 +26,7 @@ func ipToBitString(ip string) string {
ipBytes = parsedIP.To16()
}
// Convert each byte in the IP address to its 8-bit binary representation
var result []string
for _, b := range ipBytes {
result = append(result, fmt.Sprintf("%08b", b))
}
// Join the binary representation of each byte with dots to form the final bit string
return strings.Join(result, "")
}
func bitStringToIp(bitString string) string {
// Check if the bit string represents an IPv4 or IPv6 address
isIPv4 := len(bitString) == 32
// Split the bit string into 8-bit segments
segments := make([]string, 0)
if isIPv4 {
for i := 0; i < 4; i++ {
segments = append(segments, bitString[i*8:(i+1)*8])
}
} else {
for i := 0; i < 16; i++ {
segments = append(segments, bitString[i*8:(i+1)*8])
}
}
// Convert each segment to its decimal equivalent
decimalSegments := make([]int, len(segments))
for i, s := range segments {
val, _ := strconv.ParseInt(s, 2, 64)
decimalSegments[i] = int(val)
}
// Construct the IP address string based on the type (IPv4 or IPv6)
if isIPv4 {
return fmt.Sprintf("%d.%d.%d.%d", decimalSegments[0], decimalSegments[1], decimalSegments[2], decimalSegments[3])
} else {
ip := make(net.IP, net.IPv6len)
for i := 0; i < net.IPv6len; i++ {
ip[i] = byte(decimalSegments[i])
}
return ip.String()
}
return ipBytes
}
// inititlaizing a new trie
@@ -83,20 +38,39 @@ func newTrie() *trie {
// Passing words to trie
func (t *trie) insert(ipAddr string, cc string) {
word := ipToBitString(ipAddr)
ipBytes := ipToBytes(ipAddr)
current := t.root
for _, wr := range word {
index := wr - '0'
if current.childrens[index] == nil {
current.childrens[index] = &trie_Node{
childrens: [2]*trie_Node{},
ends: false,
cc: cc,
for _, b := range ipBytes {
//For each byte in the ip address
//each byte is 8 bit
for j := 0; j < 8; j++ {
bitwise := (b&uint8(math.Pow(float64(2), float64(j))) > 0)
bit := 0b0000
if bitwise {
bit = 0b0001
}
if current.childrens[bit] == nil {
current.childrens[bit] = &trie_Node{
childrens: [2]*trie_Node{},
cc: cc,
}
}
current = current.childrens[bit]
}
current = current.childrens[index]
}
current.ends = true
/*
for i := 63; i >= 0; i-- {
bit := (ipInt64 >> uint(i)) & 1
if current.childrens[bit] == nil {
current.childrens[bit] = &trie_Node{
childrens: [2]*trie_Node{},
cc: cc,
}
}
current = current.childrens[bit]
}
*/
}
func isReservedIP(ip string) bool {
@@ -126,16 +100,34 @@ func (t *trie) search(ipAddr string) string {
if isReservedIP(ipAddr) {
return ""
}
word := ipToBitString(ipAddr)
ipBytes := ipToBytes(ipAddr)
current := t.root
for _, wr := range word {
index := wr - '0'
if current.childrens[index] == nil {
return current.cc
for _, b := range ipBytes {
//For each byte in the ip address
//each byte is 8 bit
for j := 0; j < 8; j++ {
bitwise := (b&uint8(math.Pow(float64(2), float64(j))) > 0)
bit := 0b0000
if bitwise {
bit = 0b0001
}
if current.childrens[bit] == nil {
return current.cc
}
current = current.childrens[bit]
}
current = current.childrens[index]
}
if current.ends {
/*
for i := 63; i >= 0; i-- {
bit := (ipInt64 >> uint(i)) & 1
if current.childrens[bit] == nil {
return current.cc
}
current = current.childrens[bit]
}
*/
if len(current.childrens) == 0 {
return current.cc
}