diff --git a/src/mod/auth/plugin_middleware.go b/src/mod/auth/plugin_middleware.go index 507d42e..568c801 100644 --- a/src/mod/auth/plugin_middleware.go +++ b/src/mod/auth/plugin_middleware.go @@ -3,60 +3,92 @@ package auth import ( + "errors" + "fmt" "net/http" "strings" ) +const ( + PLUGIN_API_PREFIX = "/plugin" +) + +type PluginMiddlewareOptions struct { + DeniedHandler http.HandlerFunc //Thing(s) to do when request is rejected + ApiKeyManager *APIKeyManager + TargetMux *http.ServeMux +} + // PluginAuthMiddleware provides authentication middleware for plugin API requests type PluginAuthMiddleware struct { - apiKeyManager *APIKeyManager + option PluginMiddlewareOptions + endpoints map[string]http.HandlerFunc } // NewPluginAuthMiddleware creates a new plugin authentication middleware -func NewPluginAuthMiddleware(apiKeyManager *APIKeyManager) *PluginAuthMiddleware { +func NewPluginAuthMiddleware(option PluginMiddlewareOptions) *PluginAuthMiddleware { return &PluginAuthMiddleware{ - apiKeyManager: apiKeyManager, + option: option, + endpoints: make(map[string]http.HandlerFunc), } } -// WrapHandler wraps an HTTP handler with plugin authentication middleware -func (m *PluginAuthMiddleware) WrapHandler(endpoint string, handler http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // First, remove any existing plugin authentication headers - r.Header.Del("X-Zoraxy-Plugin-ID") - r.Header.Del("X-Zoraxy-Plugin-Auth") - - // Check for API key in the Authorization header - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - // No authorization header, proceed with normal authentication - handler(w, r) - return - } - - // Check if it's a plugin API key (Bearer token) - if !strings.HasPrefix(authHeader, "Bearer ") { - // Not a Bearer token, proceed with normal authentication - handler(w, r) - return - } - - // Extract the API key - apiKey := strings.TrimPrefix(authHeader, "Bearer ") - - // Validate the API key for this endpoint - pluginAPIKey, err := m.apiKeyManager.ValidateAPIKeyForEndpoint(endpoint, r.Method, apiKey) - if err != nil { - // Invalid API key or endpoint not permitted - http.Error(w, "Unauthorized: Invalid API key or endpoint not permitted", http.StatusUnauthorized) - return - } - - // Add plugin information to the request context - r.Header.Set("X-Zoraxy-Plugin-ID", pluginAPIKey.PluginID) - r.Header.Set("X-Zoraxy-Plugin-Auth", "true") - - // Call the original handler - handler(w, r) +func (m *PluginAuthMiddleware) HandleAuthCheck(w http.ResponseWriter, r *http.Request, handler http.HandlerFunc) { + // Check for API key in the Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + // No authorization header + m.option.DeniedHandler(w, r) + return } + + // Check if it's a plugin API key (Bearer token) + if !strings.HasPrefix(authHeader, "Bearer ") { + // Not a Bearer token + m.option.DeniedHandler(w, r) + return + } + + // Extract the API key + apiKey := strings.TrimPrefix(authHeader, "Bearer ") + + // Validate the API key for this endpoint + _, err := m.option.ApiKeyManager.ValidateAPIKeyForEndpoint(r.URL.Path, r.Method, apiKey) + if err != nil { + // Invalid API key or endpoint not permitted + m.option.DeniedHandler(w, r) + return + } + + // Call the original handler + handler(w, r) +} + +// wraps an HTTP handler with plugin authentication middleware +func (m *PluginAuthMiddleware) HandleFunc(endpoint string, handler http.HandlerFunc) error { + // ensure the endpoint is prefixed with PLUGIN_API_PREFIX + if !strings.HasPrefix(endpoint, PLUGIN_API_PREFIX) { + endpoint = PLUGIN_API_PREFIX + endpoint + } + + // Check if the endpoint already registered + if _, exist := m.endpoints[endpoint]; exist { + fmt.Println("WARNING! Duplicated registering of plugin api endpoint: " + endpoint) + return errors.New("endpoint register duplicated") + } + + m.endpoints[endpoint] = handler + + wrappedHandler := func(w http.ResponseWriter, r *http.Request) { + m.HandleAuthCheck(w, r, handler) + } + + // Ok. Register handler + if m.option.TargetMux == nil { + http.HandleFunc(endpoint, wrappedHandler) + } else { + m.option.TargetMux.HandleFunc(endpoint, wrappedHandler) + } + + return nil } diff --git a/src/start.go b/src/start.go index 1cc55bf..b280871 100644 --- a/src/start.go +++ b/src/start.go @@ -91,7 +91,6 @@ func startupSequence() { os.MkdirAll(CONF_HTTP_PROXY, 0775) //Create an auth agent - pluginApiKeyManager = auth.NewAPIKeyManager() sessionKey, err := auth.GetSessionKey(sysdb, SystemWideLogger) if err != nil { log.Fatal(err) @@ -99,7 +98,10 @@ func startupSequence() { authAgent = auth.NewAuthenticationAgent(SYSTEM_NAME, []byte(sessionKey), sysdb, true, SystemWideLogger, func(w http.ResponseWriter, r *http.Request) { //Not logged in. Redirecting to login page http.Redirect(w, r, "/login.html", http.StatusTemporaryRedirect) - }, pluginApiKeyManager) + }) + + // Create an API key manager for plugin authentication + pluginApiKeyManager = auth.NewAPIKeyManager() //Create a TLS certificate manager tlsCertManager, err = tlscert.NewManager(CONF_CERT_STORE, SystemWideLogger)