extract security checks from sendmail into a specific handler

use `http.Handler` wrappers to split the sendmail handler into 2
distinct handlers: one for validating the query (currently, checking the allowed
domains) and the other one for the sendmail functionality.
This commit is contained in:
Chris Camel 2020-06-01 19:48:13 +02:00
parent 5f9a9ea9a7
commit 40f4fd2d30
No known key found for this signature in database
GPG Key ID: 125EFEF60AEF6949
2 changed files with 39 additions and 33 deletions

10
main.go
View File

@ -30,7 +30,7 @@ type Route struct {
Name string Name string
Method string Method string
Pattern string Pattern string
HandlerFunc http.HandlerFunc Handler http.Handler
} }
type Routes []Route type Routes []Route
@ -40,13 +40,13 @@ var routes = Routes{
"SendMail", "SendMail",
"POST", "POST",
"/sendmail", "/sendmail",
SendMail, MuxSecAllowedDomainsHandler(http.HandlerFunc(SendMail)),
}, },
Route{ Route{
"Healthz", "Healthz",
"GET", "GET",
"/healthz", "/healthz",
Healthz, http.HandlerFunc(Healthz),
}, },
} }
@ -67,9 +67,7 @@ func MuxLoggerHandler(inner http.Handler, name string) http.Handler {
func NewRouter() *mux.Router { func NewRouter() *mux.Router {
router := mux.NewRouter().StrictSlash(true) router := mux.NewRouter().StrictSlash(true)
for _, route := range routes { for _, route := range routes {
var handler http.Handler handler := MuxLoggerHandler(route.Handler, route.Name)
handler = route.HandlerFunc
handler = MuxLoggerHandler(handler, route.Name)
router. router.
Methods(route.Method). Methods(route.Method).

View File

@ -148,11 +148,12 @@ func (m *SendMailRequest) ParseTemplate(templateFileName string, data interface{
return nil return nil
} }
// SendMail handles HTTP request to send email // MuxSecAllowedDomainsHandler is a security middleware which controls allowed domains.
func SendMail(httpResp http.ResponseWriter, httpReq *http.Request) { func MuxSecAllowedDomainsHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
allowedDomains := strings.Split(viper.GetString("ALLOWED_ORIGINS"), ",") allowedDomains := strings.Split(viper.GetString("ALLOWED_ORIGINS"), ",")
allowedOrigins := make(map[string]bool) allowedOrigins := make(map[string]bool)
for _, domain := range allowedDomains { for _, domain := range allowedDomains {
domainTrimmed := strings.TrimSpace(domain) domainTrimmed := strings.TrimSpace(domain)
allowedOrigins[fmt.Sprintf("http://%s", domainTrimmed)] = true allowedOrigins[fmt.Sprintf("http://%s", domainTrimmed)] = true
@ -160,20 +161,27 @@ func SendMail(httpResp http.ResponseWriter, httpReq *http.Request) {
allowedOrigins[fmt.Sprintf("http://www.%s", domainTrimmed)] = true allowedOrigins[fmt.Sprintf("http://www.%s", domainTrimmed)] = true
allowedOrigins[fmt.Sprintf("https://www.%s", domainTrimmed)] = true allowedOrigins[fmt.Sprintf("https://www.%s", domainTrimmed)] = true
} }
if len(httpReq.Header["Origin"]) == 0 || len(httpReq.Header["Referer"]) == 0 {
rawHeader, _ := json.Marshal(httpReq.Header) if len(r.Header["Origin"]) == 0 || len(r.Header["Referer"]) == 0 {
rawHeader, _ := json.Marshal(r.Header)
log.Infoln("request with unexpected headers", string(rawHeader)) log.Infoln("request with unexpected headers", string(rawHeader))
httpResp.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
return return
} }
reqOrigin := httpReq.Header["Origin"][0] reqOrigin := r.Header["Origin"][0]
if _, domainFound := allowedOrigins[reqOrigin]; !domainFound { if _, domainFound := allowedOrigins[reqOrigin]; !domainFound {
log.Errorln("not allowed origin", reqOrigin) log.Errorln("not allowed origin", reqOrigin)
httpResp.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
return return
} }
next.ServeHTTP(w, r)
})
}
// SendMail handles HTTP request to send email
func SendMail(httpResp http.ResponseWriter, httpReq *http.Request) {
httpReq.ParseForm() httpReq.ParseForm()
contactRequest := ContactRequest{ contactRequest := ContactRequest{