extract security checks from sendmail into a specific handler
use `http.Handler` wrappers to split the sendmail handler into 2 distinct handlers: one for validating the query (currently, checking the allowed domains) and the other one for the sendmail functionality.
This commit is contained in:
parent
5f9a9ea9a7
commit
40f4fd2d30
10
main.go
10
main.go
|
@ -30,7 +30,7 @@ type Route struct {
|
|||
Name string
|
||||
Method string
|
||||
Pattern string
|
||||
HandlerFunc http.HandlerFunc
|
||||
Handler http.Handler
|
||||
}
|
||||
|
||||
type Routes []Route
|
||||
|
@ -40,13 +40,13 @@ var routes = Routes{
|
|||
"SendMail",
|
||||
"POST",
|
||||
"/sendmail",
|
||||
SendMail,
|
||||
MuxSecAllowedDomainsHandler(http.HandlerFunc(SendMail)),
|
||||
},
|
||||
Route{
|
||||
"Healthz",
|
||||
"GET",
|
||||
"/healthz",
|
||||
Healthz,
|
||||
http.HandlerFunc(Healthz),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -67,9 +67,7 @@ func MuxLoggerHandler(inner http.Handler, name string) http.Handler {
|
|||
func NewRouter() *mux.Router {
|
||||
router := mux.NewRouter().StrictSlash(true)
|
||||
for _, route := range routes {
|
||||
var handler http.Handler
|
||||
handler = route.HandlerFunc
|
||||
handler = MuxLoggerHandler(handler, route.Name)
|
||||
handler := MuxLoggerHandler(route.Handler, route.Name)
|
||||
|
||||
router.
|
||||
Methods(route.Method).
|
||||
|
|
24
sendmail.go
24
sendmail.go
|
@ -148,11 +148,12 @@ func (m *SendMailRequest) ParseTemplate(templateFileName string, data interface{
|
|||
return nil
|
||||
}
|
||||
|
||||
// SendMail handles HTTP request to send email
|
||||
func SendMail(httpResp http.ResponseWriter, httpReq *http.Request) {
|
||||
|
||||
// MuxSecAllowedDomainsHandler is a security middleware which controls allowed domains.
|
||||
func MuxSecAllowedDomainsHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
allowedDomains := strings.Split(viper.GetString("ALLOWED_ORIGINS"), ",")
|
||||
allowedOrigins := make(map[string]bool)
|
||||
|
||||
for _, domain := range allowedDomains {
|
||||
domainTrimmed := strings.TrimSpace(domain)
|
||||
allowedOrigins[fmt.Sprintf("http://%s", domainTrimmed)] = true
|
||||
|
@ -160,20 +161,27 @@ func SendMail(httpResp http.ResponseWriter, httpReq *http.Request) {
|
|||
allowedOrigins[fmt.Sprintf("http://www.%s", domainTrimmed)] = true
|
||||
allowedOrigins[fmt.Sprintf("https://www.%s", domainTrimmed)] = true
|
||||
}
|
||||
if len(httpReq.Header["Origin"]) == 0 || len(httpReq.Header["Referer"]) == 0 {
|
||||
rawHeader, _ := json.Marshal(httpReq.Header)
|
||||
|
||||
if len(r.Header["Origin"]) == 0 || len(r.Header["Referer"]) == 0 {
|
||||
rawHeader, _ := json.Marshal(r.Header)
|
||||
log.Infoln("request with unexpected headers", string(rawHeader))
|
||||
httpResp.WriteHeader(http.StatusForbidden)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
reqOrigin := httpReq.Header["Origin"][0]
|
||||
reqOrigin := r.Header["Origin"][0]
|
||||
if _, domainFound := allowedOrigins[reqOrigin]; !domainFound {
|
||||
log.Errorln("not allowed origin", reqOrigin)
|
||||
httpResp.WriteHeader(http.StatusForbidden)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// SendMail handles HTTP request to send email
|
||||
func SendMail(httpResp http.ResponseWriter, httpReq *http.Request) {
|
||||
httpReq.ParseForm()
|
||||
|
||||
contactRequest := ContactRequest{
|
||||
|
|
Loading…
Reference in New Issue