extract security checks from sendmail into a specific handler

use `http.Handler` wrappers to split the sendmail handler into 2
distinct handlers: one for validating the query (currently, checking the allowed
domains) and the other one for the sendmail functionality.
This commit is contained in:
Chris Camel 2020-06-01 19:48:13 +02:00
parent 5f9a9ea9a7
commit 40f4fd2d30
No known key found for this signature in database
GPG Key ID: 125EFEF60AEF6949
2 changed files with 39 additions and 33 deletions

16
main.go
View File

@ -27,10 +27,10 @@ import (
)
type Route struct {
Name string
Method string
Pattern string
HandlerFunc http.HandlerFunc
Name string
Method string
Pattern string
Handler http.Handler
}
type Routes []Route
@ -40,13 +40,13 @@ var routes = Routes{
"SendMail",
"POST",
"/sendmail",
SendMail,
MuxSecAllowedDomainsHandler(http.HandlerFunc(SendMail)),
},
Route{
"Healthz",
"GET",
"/healthz",
Healthz,
http.HandlerFunc(Healthz),
},
}
@ -67,9 +67,7 @@ func MuxLoggerHandler(inner http.Handler, name string) http.Handler {
func NewRouter() *mux.Router {
router := mux.NewRouter().StrictSlash(true)
for _, route := range routes {
var handler http.Handler
handler = route.HandlerFunc
handler = MuxLoggerHandler(handler, route.Name)
handler := MuxLoggerHandler(route.Handler, route.Name)
router.
Methods(route.Method).

View File

@ -148,32 +148,40 @@ func (m *SendMailRequest) ParseTemplate(templateFileName string, data interface{
return nil
}
// MuxSecAllowedDomainsHandler is a security middleware which controls allowed domains.
func MuxSecAllowedDomainsHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
allowedDomains := strings.Split(viper.GetString("ALLOWED_ORIGINS"), ",")
allowedOrigins := make(map[string]bool)
for _, domain := range allowedDomains {
domainTrimmed := strings.TrimSpace(domain)
allowedOrigins[fmt.Sprintf("http://%s", domainTrimmed)] = true
allowedOrigins[fmt.Sprintf("https://%s", domainTrimmed)] = true
allowedOrigins[fmt.Sprintf("http://www.%s", domainTrimmed)] = true
allowedOrigins[fmt.Sprintf("https://www.%s", domainTrimmed)] = true
}
if len(r.Header["Origin"]) == 0 || len(r.Header["Referer"]) == 0 {
rawHeader, _ := json.Marshal(r.Header)
log.Infoln("request with unexpected headers", string(rawHeader))
w.WriteHeader(http.StatusForbidden)
return
}
reqOrigin := r.Header["Origin"][0]
if _, domainFound := allowedOrigins[reqOrigin]; !domainFound {
log.Errorln("not allowed origin", reqOrigin)
w.WriteHeader(http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
// SendMail handles HTTP request to send email
func SendMail(httpResp http.ResponseWriter, httpReq *http.Request) {
allowedDomains := strings.Split(viper.GetString("ALLOWED_ORIGINS"), ",")
allowedOrigins := make(map[string]bool)
for _, domain := range allowedDomains {
domainTrimmed := strings.TrimSpace(domain)
allowedOrigins[fmt.Sprintf("http://%s", domainTrimmed)] = true
allowedOrigins[fmt.Sprintf("https://%s", domainTrimmed)] = true
allowedOrigins[fmt.Sprintf("http://www.%s", domainTrimmed)] = true
allowedOrigins[fmt.Sprintf("https://www.%s", domainTrimmed)] = true
}
if len(httpReq.Header["Origin"]) == 0 || len(httpReq.Header["Referer"]) == 0 {
rawHeader, _ := json.Marshal(httpReq.Header)
log.Infoln("request with unexpected headers", string(rawHeader))
httpResp.WriteHeader(http.StatusForbidden)
return
}
reqOrigin := httpReq.Header["Origin"][0]
if _, domainFound := allowedOrigins[reqOrigin]; !domainFound {
log.Errorln("not allowed origin", reqOrigin)
httpResp.WriteHeader(http.StatusForbidden)
return
}
httpReq.ParseForm()
contactRequest := ContactRequest{