extract security checks from sendmail into a specific handler
use `http.Handler` wrappers to split the sendmail handler into 2 distinct handlers: one for validating the query (currently, checking the allowed domains) and the other one for the sendmail functionality.
This commit is contained in:
parent
5f9a9ea9a7
commit
40f4fd2d30
16
main.go
16
main.go
|
@ -27,10 +27,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Route struct {
|
type Route struct {
|
||||||
Name string
|
Name string
|
||||||
Method string
|
Method string
|
||||||
Pattern string
|
Pattern string
|
||||||
HandlerFunc http.HandlerFunc
|
Handler http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
type Routes []Route
|
type Routes []Route
|
||||||
|
@ -40,13 +40,13 @@ var routes = Routes{
|
||||||
"SendMail",
|
"SendMail",
|
||||||
"POST",
|
"POST",
|
||||||
"/sendmail",
|
"/sendmail",
|
||||||
SendMail,
|
MuxSecAllowedDomainsHandler(http.HandlerFunc(SendMail)),
|
||||||
},
|
},
|
||||||
Route{
|
Route{
|
||||||
"Healthz",
|
"Healthz",
|
||||||
"GET",
|
"GET",
|
||||||
"/healthz",
|
"/healthz",
|
||||||
Healthz,
|
http.HandlerFunc(Healthz),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,9 +67,7 @@ func MuxLoggerHandler(inner http.Handler, name string) http.Handler {
|
||||||
func NewRouter() *mux.Router {
|
func NewRouter() *mux.Router {
|
||||||
router := mux.NewRouter().StrictSlash(true)
|
router := mux.NewRouter().StrictSlash(true)
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
var handler http.Handler
|
handler := MuxLoggerHandler(route.Handler, route.Name)
|
||||||
handler = route.HandlerFunc
|
|
||||||
handler = MuxLoggerHandler(handler, route.Name)
|
|
||||||
|
|
||||||
router.
|
router.
|
||||||
Methods(route.Method).
|
Methods(route.Method).
|
||||||
|
|
56
sendmail.go
56
sendmail.go
|
@ -148,32 +148,40 @@ func (m *SendMailRequest) ParseTemplate(templateFileName string, data interface{
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MuxSecAllowedDomainsHandler is a security middleware which controls allowed domains.
|
||||||
|
func MuxSecAllowedDomainsHandler(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
allowedDomains := strings.Split(viper.GetString("ALLOWED_ORIGINS"), ",")
|
||||||
|
allowedOrigins := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, domain := range allowedDomains {
|
||||||
|
domainTrimmed := strings.TrimSpace(domain)
|
||||||
|
allowedOrigins[fmt.Sprintf("http://%s", domainTrimmed)] = true
|
||||||
|
allowedOrigins[fmt.Sprintf("https://%s", domainTrimmed)] = true
|
||||||
|
allowedOrigins[fmt.Sprintf("http://www.%s", domainTrimmed)] = true
|
||||||
|
allowedOrigins[fmt.Sprintf("https://www.%s", domainTrimmed)] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Header["Origin"]) == 0 || len(r.Header["Referer"]) == 0 {
|
||||||
|
rawHeader, _ := json.Marshal(r.Header)
|
||||||
|
log.Infoln("request with unexpected headers", string(rawHeader))
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqOrigin := r.Header["Origin"][0]
|
||||||
|
if _, domainFound := allowedOrigins[reqOrigin]; !domainFound {
|
||||||
|
log.Errorln("not allowed origin", reqOrigin)
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SendMail handles HTTP request to send email
|
// SendMail handles HTTP request to send email
|
||||||
func SendMail(httpResp http.ResponseWriter, httpReq *http.Request) {
|
func SendMail(httpResp http.ResponseWriter, httpReq *http.Request) {
|
||||||
|
|
||||||
allowedDomains := strings.Split(viper.GetString("ALLOWED_ORIGINS"), ",")
|
|
||||||
allowedOrigins := make(map[string]bool)
|
|
||||||
for _, domain := range allowedDomains {
|
|
||||||
domainTrimmed := strings.TrimSpace(domain)
|
|
||||||
allowedOrigins[fmt.Sprintf("http://%s", domainTrimmed)] = true
|
|
||||||
allowedOrigins[fmt.Sprintf("https://%s", domainTrimmed)] = true
|
|
||||||
allowedOrigins[fmt.Sprintf("http://www.%s", domainTrimmed)] = true
|
|
||||||
allowedOrigins[fmt.Sprintf("https://www.%s", domainTrimmed)] = true
|
|
||||||
}
|
|
||||||
if len(httpReq.Header["Origin"]) == 0 || len(httpReq.Header["Referer"]) == 0 {
|
|
||||||
rawHeader, _ := json.Marshal(httpReq.Header)
|
|
||||||
log.Infoln("request with unexpected headers", string(rawHeader))
|
|
||||||
httpResp.WriteHeader(http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
reqOrigin := httpReq.Header["Origin"][0]
|
|
||||||
if _, domainFound := allowedOrigins[reqOrigin]; !domainFound {
|
|
||||||
log.Errorln("not allowed origin", reqOrigin)
|
|
||||||
httpResp.WriteHeader(http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
httpReq.ParseForm()
|
httpReq.ParseForm()
|
||||||
|
|
||||||
contactRequest := ContactRequest{
|
contactRequest := ContactRequest{
|
||||||
|
|
Loading…
Reference in New Issue