Files
uptimemonitor/handler/middleware.go
2025-08-01 18:01:55 +02:00

119 lines
2.8 KiB
Go

package handler
import (
"context"
"log/slog"
"net/http"
"strings"
"uptimemonitor"
)
// contextKey is the type of every key this package passes to
// context.WithValue. Using a dedicated named type instead of a plain string
// keeps these keys distinct from string keys set by other packages.
type contextKey string

// Context keys used by the middleware in this file: UserFromCookie stores the
// resolved session under sessionContextKey and that session's User under
// userContextKey.
const (
	sessionContextKey = contextKey("session")
	userContextKey    = contextKey("user")
)
// Installed is middleware that only lets requests through once at least one
// user exists in the store.
//
// For every request it asks m.Store.CountUsers for the number of users:
//   - if the store returns an error, the error is logged and the client
//     receives 500 Internal Server Error;
//   - if there are no users, GET requests are redirected to /setup with
//     303 See Other and requests with any other method receive 403 Forbidden;
//   - otherwise the request is passed to next unchanged.
func (m *Handler) Installed(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		count, err := m.Store.CountUsers(r.Context())
		if err != nil {
			// The client only sees a generic 500, so record the cause here;
			// this middleware is the last place the error is visible.
			slog.Error("count users", "error", err)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		if count == 0 {
			// Only GET is redirected; other methods are rejected outright
			// instead of being bounced to the setup page.
			if r.Method == http.MethodGet {
				http.Redirect(w, r, "/setup", http.StatusSeeOther)
				return
			}
			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
			return
		}

		next.ServeHTTP(w, r)
	})
}
// UserFromCookie is middleware that attaches the current session to the
// request context when the request carries a usable "session" cookie.
//
// The cookie value is passed to m.Store.GetSessionByUuid. On success the
// returned session is stored in the context under sessionContextKey and its
// User field under userContextKey. If the cookie is missing or empty, or the
// lookup returns an error, the request is forwarded to next unmodified; this
// middleware never rejects a request itself.
func (m *Handler) UserFromCookie(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// req stays the original request unless a session is resolved below.
		req := r

		if cookie, err := r.Cookie("session"); err == nil && cookie.Value != "" {
			if sess, err := m.Store.GetSessionByUuid(r.Context(), cookie.Value); err == nil {
				withSession := context.WithValue(r.Context(), sessionContextKey, sess)
				withUser := context.WithValue(withSession, userContextKey, sess.User)
				req = r.WithContext(withUser)
			}
		}

		next.ServeHTTP(w, req)
	})
}
// Authenticated is middleware that requires a session in the request context.
//
// It looks up sessionContextKey in the context and expects the value to be an
// uptimemonitor.Session. When the value is absent or has a different type the
// client is redirected to /login with 303 See Other; otherwise the request is
// passed on to next.
func (m *Handler) Authenticated(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// A type assertion on a nil interface value reports ok == false, so
		// this single check covers both the "missing" and "wrong type" cases.
		if _, isSession := r.Context().Value(sessionContextKey).(uptimemonitor.Session); !isSession {
			http.Redirect(w, r, "/login", http.StatusSeeOther)
			return
		}

		next.ServeHTTP(w, r)
	})
}
// Guest is middleware for routes meant only for visitors without a session.
//
// If the request context holds any non-nil value under sessionContextKey the
// client is redirected to / with 303 See Other; otherwise the request is
// passed on to next.
func (m *Handler) Guest(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Context().Value(sessionContextKey) == nil {
			next.ServeHTTP(w, r)
			return
		}

		http.Redirect(w, r, "/", http.StatusSeeOther)
	})
}
// Recoverer is middleware that stops a panic in a downstream handler from
// propagating further.
//
// A recovered panic is logged through slog at error level and the response
// status is set to 500 Internal Server Error. If the handler had already
// started writing the response before panicking, that WriteHeader call cannot
// change the status that was sent.
//
// The one exception is http.ErrAbortHandler: net/http documents that value as
// the way for a handler to abort a response on purpose, so it is re-panicked
// and left for the server to handle instead of being logged as a failure.
func (m *Handler) Recoverer(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}

			// Deliberate abort, not a bug: hand it back to the HTTP server.
			if rec == http.ErrAbortHandler {
				panic(rec)
			}

			slog.Error("panic", "error", rec)
			w.WriteHeader(http.StatusInternalServerError)
		}()

		next.ServeHTTP(w, r)
	})
}
func (m *Handler) NoCache(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/static") {
w.Header().Set("Cache-Control", "public, max-age=31536000")
next.ServeHTTP(w, r)
return
}
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
next.ServeHTTP(w, r)
})
}