diff --git a/main.go b/main.go index a08799b..46a4cde 100644 --- a/main.go +++ b/main.go @@ -27,10 +27,10 @@ import ( ) type Route struct { - Name string - Method string - Pattern string - HandlerFunc http.HandlerFunc + Name string + Method string + Pattern string + Handler http.Handler } type Routes []Route @@ -40,13 +40,13 @@ var routes = Routes{ "SendMail", "POST", "/sendmail", - SendMail, + MuxSecAllowedDomainsHandler(http.HandlerFunc(SendMail)), }, Route{ "Healthz", "GET", "/healthz", - Healthz, + http.HandlerFunc(Healthz), }, } @@ -67,9 +67,7 @@ func MuxLoggerHandler(inner http.Handler, name string) http.Handler { func NewRouter() *mux.Router { router := mux.NewRouter().StrictSlash(true) for _, route := range routes { - var handler http.Handler - handler = route.HandlerFunc - handler = MuxLoggerHandler(handler, route.Name) + handler := MuxLoggerHandler(route.Handler, route.Name) router. Methods(route.Method). diff --git a/sendmail.go b/sendmail.go index 6347291..704f4fb 100644 --- a/sendmail.go +++ b/sendmail.go @@ -148,32 +148,40 @@ func (m *SendMailRequest) ParseTemplate(templateFileName string, data interface{ return nil } +// MuxSecAllowedDomainsHandler is a security middleware which controls allowed domains. +func MuxSecAllowedDomainsHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + allowedDomains := strings.Split(viper.GetString("ALLOWED_ORIGINS"), ",") + allowedOrigins := make(map[string]bool) + + for _, domain := range allowedDomains { + domainTrimmed := strings.TrimSpace(domain) + allowedOrigins[fmt.Sprintf("http://%s", domainTrimmed)] = true + allowedOrigins[fmt.Sprintf("https://%s", domainTrimmed)] = true + allowedOrigins[fmt.Sprintf("http://www.%s", domainTrimmed)] = true + allowedOrigins[fmt.Sprintf("https://www.%s", domainTrimmed)] = true + } + + if len(r.Header["Origin"]) == 0 || len(r.Header["Referer"]) == 0 { + rawHeader, _ := json.Marshal(r.Header) + log.Infoln("request with unexpected headers", string(rawHeader)) + w.WriteHeader(http.StatusForbidden) + return + } + + reqOrigin := r.Header["Origin"][0] + if _, domainFound := allowedOrigins[reqOrigin]; !domainFound { + log.Errorln("not allowed origin", reqOrigin) + w.WriteHeader(http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) +} + // SendMail handles HTTP request to send email func SendMail(httpResp http.ResponseWriter, httpReq *http.Request) { - - allowedDomains := strings.Split(viper.GetString("ALLOWED_ORIGINS"), ",") - allowedOrigins := make(map[string]bool) - for _, domain := range allowedDomains { - domainTrimmed := strings.TrimSpace(domain) - allowedOrigins[fmt.Sprintf("http://%s", domainTrimmed)] = true - allowedOrigins[fmt.Sprintf("https://%s", domainTrimmed)] = true - allowedOrigins[fmt.Sprintf("http://www.%s", domainTrimmed)] = true - allowedOrigins[fmt.Sprintf("https://www.%s", domainTrimmed)] = true - } - if len(httpReq.Header["Origin"]) == 0 || len(httpReq.Header["Referer"]) == 0 { - rawHeader, _ := json.Marshal(httpReq.Header) - log.Infoln("request with unexpected headers", string(rawHeader)) - httpResp.WriteHeader(http.StatusForbidden) - return - } - - reqOrigin := httpReq.Header["Origin"][0] - if _, domainFound := allowedOrigins[reqOrigin]; !domainFound { - log.Errorln("not allowed origin", reqOrigin) - httpResp.WriteHeader(http.StatusForbidden) - return - } - httpReq.ParseForm() contactRequest := ContactRequest{