zoraxy/src/mod/acme/acme.go
Toby Chui 01f68c5ef5 Added tour for basic operations
- added static website setup tour
- added subdomain setup tour
2024-08-15 22:35:43 +08:00

505 lines
14 KiB
Go

package acme
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge/http01"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
"imuslab.com/zoraxy/mod/database"
"imuslab.com/zoraxy/mod/info/logger"
"imuslab.com/zoraxy/mod/utils"
)
type CertificateInfoJSON struct {
AcmeName string `json:"acme_name"`
AcmeUrl string `json:"acme_url"`
SkipTLS bool `json:"skip_tls"`
UseDNS bool `json:"dns"`
}
// ACMEUser represents a user in the ACME system.
type ACMEUser struct {
Email string
Registration *registration.Resource
key crypto.PrivateKey
}
type EABConfig struct {
Kid string `json:"kid"`
HmacKey string `json:"HmacKey"`
}
// GetEmail returns the email of the ACMEUser.
func (u *ACMEUser) GetEmail() string {
return u.Email
}
// GetRegistration returns the registration resource of the ACMEUser.
func (u ACMEUser) GetRegistration() *registration.Resource {
return u.Registration
}
// GetPrivateKey returns the private key of the ACMEUser.
func (u *ACMEUser) GetPrivateKey() crypto.PrivateKey {
return u.key
}
// ACMEHandler handles ACME-related operations.
type ACMEHandler struct {
DefaultAcmeServer string
Port string
Database *database.Database
Logger *logger.Logger
}
// NewACME creates a new ACMEHandler instance.
func NewACME(defaultAcmeServer string, port string, database *database.Database, logger *logger.Logger) *ACMEHandler {
return &ACMEHandler{
DefaultAcmeServer: defaultAcmeServer,
Port: port,
Database: database,
Logger: logger,
}
}
func (a *ACMEHandler) Logf(message string, err error) {
a.Logger.PrintAndLog("ACME", message, err)
}
// ObtainCert obtains a certificate for the specified domains.
func (a *ACMEHandler) ObtainCert(domains []string, certificateName string, email string, caName string, caUrl string, skipTLS bool, useDNS bool) (bool, error) {
a.Logf("Obtaining certificate for: "+strings.Join(domains, ", "), nil)
// generate private key
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
a.Logf("Private key generation failed", err)
return false, err
}
// create a admin user for our new generation
adminUser := ACMEUser{
Email: email,
key: privateKey,
}
// create config
config := lego.NewConfig(&adminUser)
// skip TLS verify if need
// Ref: https://github.com/go-acme/lego/blob/6af2c756ac73a9cb401621afca722d0f4112b1b8/lego/client_config.go#L74
if skipTLS {
a.Logf("Ignoring TLS/SSL Verification Error for ACME Server", nil)
config.HTTPClient.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 30 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
}
//Fallback to Let's Encrypt if it is not set
if caName == "" {
caName = "Let's Encrypt"
}
// setup the custom ACME url endpoint.
if caUrl != "" {
config.CADirURL = caUrl
}
// if not custom ACME url, load it from ca.json
if caName == "custom" {
a.Logf("Using Custom ACME "+caUrl+" for CA Directory URL", nil)
} else {
caLinkOverwrite, err := loadCAApiServerFromName(caName)
if err == nil {
config.CADirURL = caLinkOverwrite
a.Logf("Using "+caLinkOverwrite+" for CA Directory URL", nil)
} else {
// (caName == "" || caUrl == "") will use default acme
config.CADirURL = a.DefaultAcmeServer
a.Logf("Using Default ACME "+a.DefaultAcmeServer+" for CA Directory URL", nil)
}
}
config.Certificate.KeyType = certcrypto.RSA2048
client, err := lego.NewClient(config)
if err != nil {
a.Logf("Failed to spawn new ACME client from current config", err)
return false, err
}
// setup how to receive challenge
if useDNS {
if !a.Database.TableExists("acme") {
a.Database.NewTable("acme")
return false, errors.New("DNS Provider and DNS Credenital configuration required for ACME Provider (Error -1)")
}
if !a.Database.KeyExists("acme", certificateName+"_dns_provider") || !a.Database.KeyExists("acme", certificateName+"_dns_credentials") {
return false, errors.New("DNS Provider and DNS Credenital configuration required for ACME Provider (Error -2)")
}
var dnsCredentials string
err := a.Database.Read("acme", certificateName+"_dns_credentials", &dnsCredentials)
if err != nil {
a.Logf("Read DNS credential failed", err)
return false, err
}
var dnsProvider string
err = a.Database.Read("acme", certificateName+"_dns_provider", &dnsProvider)
if err != nil {
a.Logf("Read DNS Provider failed", err)
return false, err
}
provider, err := GetDnsChallengeProviderByName(dnsProvider, dnsCredentials)
if err != nil {
a.Logf("Unable to resolve DNS challenge provider", err)
return false, err
}
err = client.Challenge.SetDNS01Provider(provider)
if err != nil {
a.Logf("Failed to resolve DNS01 Provider", err)
return false, err
}
} else {
err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", a.Port))
if err != nil {
a.Logf("Failed to resolve HTTP01 Provider", err)
return false, err
}
}
// New users will need to register
/*
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil {
log.Println(err)
return false, err
}
*/
var reg *registration.Resource
// New users will need to register
if client.GetExternalAccountRequired() {
a.Logf("External Account Required for this ACME Provider", nil)
// IF KID and HmacEncoded is overidden
if !a.Database.TableExists("acme") {
a.Database.NewTable("acme")
return false, errors.New("kid and HmacEncoded configuration required for ACME Provider (Error -1)")
}
if !a.Database.KeyExists("acme", config.CADirURL+"_kid") || !a.Database.KeyExists("acme", config.CADirURL+"_hmacEncoded") {
return false, errors.New("kid and HmacEncoded configuration required for ACME Provider (Error -2)")
}
var kid string
var hmacEncoded string
err := a.Database.Read("acme", config.CADirURL+"_kid", &kid)
if err != nil {
a.Logf("Failed to read kid from database", err)
return false, err
}
err = a.Database.Read("acme", config.CADirURL+"_hmacEncoded", &hmacEncoded)
if err != nil {
a.Logf("Failed to read HMAC from database", err)
return false, err
}
a.Logf("EAB Credential retrieved: "+kid+" / "+hmacEncoded, nil)
if kid != "" && hmacEncoded != "" {
reg, err = client.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{
TermsOfServiceAgreed: true,
Kid: kid,
HmacEncoded: hmacEncoded,
})
}
if err != nil {
a.Logf("Register with external account binder failed", err)
return false, err
}
//return false, errors.New("External Account Required for this ACME Provider.")
} else {
reg, err = client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil {
a.Logf("Unable to register client", err)
return false, err
}
}
adminUser.Registration = reg
// obtain the certificate
request := certificate.ObtainRequest{
Domains: domains,
Bundle: true,
}
certificates, err := client.Certificate.Obtain(request)
if err != nil {
a.Logf("Obtain certificate failed", err)
return false, err
}
// Each certificate comes back with the cert bytes, the bytes of the client's
// private key, and a certificate URL.
err = os.WriteFile("./conf/certs/"+certificateName+".pem", certificates.Certificate, 0777)
if err != nil {
a.Logf("Failed to write public key to disk", err)
return false, err
}
err = os.WriteFile("./conf/certs/"+certificateName+".key", certificates.PrivateKey, 0777)
if err != nil {
a.Logf("Failed to write private key to disk", err)
return false, err
}
// Save certificate's ACME info for renew usage
certInfo := &CertificateInfoJSON{
AcmeName: caName,
AcmeUrl: caUrl,
SkipTLS: skipTLS,
UseDNS: useDNS,
}
certInfoBytes, err := json.Marshal(certInfo)
if err != nil {
a.Logf("Marshal certificate renew config failed", err)
return false, err
}
err = os.WriteFile("./conf/certs/"+certificateName+".json", certInfoBytes, 0777)
if err != nil {
a.Logf("Failed to write certificate renew config to file", err)
return false, err
}
return true, nil
}
// CheckCertificate returns a list of domains that are in expired certificates.
// It will return all domains that is in expired certificates
// *** if there is a vaild certificate contains the domain and there is a expired certificate contains the same domain
// it will said expired as well!
func (a *ACMEHandler) CheckCertificate() []string {
// read from dir
filenames, err := os.ReadDir("./conf/certs/")
expiredCerts := []string{}
if err != nil {
a.Logf("Failed to load certificate folder", err)
return []string{}
}
for _, filename := range filenames {
certFilepath := filepath.Join("./conf/certs/", filename.Name())
certBytes, err := os.ReadFile(certFilepath)
if err != nil {
// Unable to load this file
continue
} else {
// Cert loaded. Check its expiry time
block, _ := pem.Decode(certBytes)
if block != nil {
cert, err := x509.ParseCertificate(block.Bytes)
if err == nil {
elapsed := time.Since(cert.NotAfter)
if elapsed > 0 {
// if it is expired then add it in
// make sure it's uniqueless
for _, dnsName := range cert.DNSNames {
if !contains(expiredCerts, dnsName) {
expiredCerts = append(expiredCerts, dnsName)
}
}
if !contains(expiredCerts, cert.Subject.CommonName) {
expiredCerts = append(expiredCerts, cert.Subject.CommonName)
}
}
}
}
}
}
return expiredCerts
}
// return the current port number
func (a *ACMEHandler) Getport() string {
return a.Port
}
// contains checks if a string is present in a slice.
func contains(slice []string, str string) bool {
for _, s := range slice {
if s == str {
return true
}
}
return false
}
// HandleGetExpiredDomains handles the HTTP GET request to retrieve the list of expired domains.
// It calls the CheckCertificate method to obtain the expired domains and sends a JSON response
// containing the list of expired domains.
func (a *ACMEHandler) HandleGetExpiredDomains(w http.ResponseWriter, r *http.Request) {
type ExpiredDomains struct {
Domain []string `json:"domain"`
}
info := ExpiredDomains{
Domain: a.CheckCertificate(),
}
js, _ := json.MarshalIndent(info, "", " ")
utils.SendJSONResponse(w, string(js))
}
// HandleRenewCertificate handles the HTTP GET request to renew a certificate for the provided domains.
// It retrieves the domains and filename parameters from the request, calls the ObtainCert method
// to renew the certificate, and sends a JSON response indicating the result of the renewal process.
func (a *ACMEHandler) HandleRenewCertificate(w http.ResponseWriter, r *http.Request) {
domainPara, err := utils.PostPara(r, "domains")
if err != nil {
utils.SendErrorResponse(w, jsonEscape(err.Error()))
return
}
filename, err := utils.PostPara(r, "filename")
if err != nil {
utils.SendErrorResponse(w, jsonEscape(err.Error()))
return
}
//Make sure the wildcard * do not goes into the filename
filename = strings.ReplaceAll(filename, "*", "_")
email, err := utils.PostPara(r, "email")
if err != nil {
utils.SendErrorResponse(w, jsonEscape(err.Error()))
return
}
var caUrl string
ca, err := utils.PostPara(r, "ca")
if err != nil {
a.Logf("CA not set. Using default", nil)
ca, caUrl = "", ""
}
if ca == "custom" {
caUrl, err = utils.PostPara(r, "caURL")
if err != nil {
a.Logf("Custom CA set but no URL provide, Using default", nil)
ca, caUrl = "", ""
}
}
if ca == "" {
//default. Use Let's Encrypt
ca = "Let's Encrypt"
}
var skipTLS bool
if skipTLSString, err := utils.PostPara(r, "skipTLS"); err != nil {
skipTLS = false
} else if skipTLSString != "true" {
skipTLS = false
} else {
skipTLS = true
}
var dns bool
if dnsString, err := utils.PostPara(r, "dns"); err != nil {
dns = false
} else if dnsString != "true" {
dns = false
} else {
dns = true
}
domains := strings.Split(domainPara, ",")
//Clean spaces in front or behind each domain
cleanedDomains := []string{}
for _, domain := range domains {
cleanedDomains = append(cleanedDomains, strings.TrimSpace(domain))
}
result, err := a.ObtainCert(cleanedDomains, filename, email, ca, caUrl, skipTLS, dns)
if err != nil {
utils.SendErrorResponse(w, jsonEscape(err.Error()))
return
}
utils.SendJSONResponse(w, strconv.FormatBool(result))
}
// Escape JSON string
func jsonEscape(i string) string {
b, err := json.Marshal(i)
if err != nil {
//log.Println("Unable to escape json data: " + err.Error())
return i
}
s := string(b)
return s[1 : len(s)-1]
}
// Helper function to check if a port is in use
func IsPortInUse(port int) bool {
address := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", address)
if err != nil {
return true // Port is in use
}
defer listener.Close()
return false // Port is not in use
}
// Load cert information from json file
func LoadCertInfoJSON(filename string) (*CertificateInfoJSON, error) {
certInfoBytes, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
certInfo := &CertificateInfoJSON{}
if err = json.Unmarshal(certInfoBytes, certInfo); err != nil {
return nil, err
}
return certInfo, nil
}