first commit
This commit is contained in:
150
internal/web/middleware.go
Normal file
150
internal/web/middleware.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// securityHeaders are applied to every response.
|
||||
var securityHeaders = map[string]string{
|
||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "SAMEORIGIN",
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
|
||||
"Permissions-Policy": "camera=(), microphone=(), geolocation=(), interest-cohort=()",
|
||||
"Content-Security-Policy": "default-src 'self'; " +
|
||||
"style-src 'self' 'unsafe-inline'; " +
|
||||
"font-src 'self'; " +
|
||||
"script-src 'self'; " +
|
||||
"img-src 'self' data:; " +
|
||||
"connect-src 'self'; " +
|
||||
"frame-ancestors 'none'; " +
|
||||
"base-uri 'self'; " +
|
||||
"form-action 'self'",
|
||||
}
|
||||
|
||||
func applySecurityHeaders(w http.ResponseWriter) {
|
||||
for k, v := range securityHeaders {
|
||||
if w.Header().Get(k) == "" {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- per-IP rate limiter ----
|
||||
|
||||
type ipLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// RateLimiter tracks per-IP request rates.
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
limiters map[string]*ipLimiter
|
||||
r rate.Limit
|
||||
burst int
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a RateLimiter with the given sustained rate and burst.
|
||||
func NewRateLimiter(r rate.Limit, burst int) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
limiters: make(map[string]*ipLimiter),
|
||||
r: r,
|
||||
burst: burst,
|
||||
}
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) get(ip string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
entry, ok := rl.limiters[ip]
|
||||
if !ok {
|
||||
entry = &ipLimiter{limiter: rate.NewLimiter(rl.r, rl.burst)}
|
||||
rl.limiters[ip] = entry
|
||||
}
|
||||
entry.lastSeen = time.Now()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
for ip, e := range rl.limiters {
|
||||
if time.Since(e.lastSeen) > 10*time.Minute {
|
||||
delete(rl.limiters, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// ---- status recorder for logging ----
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
wroteHeader bool
|
||||
}
|
||||
|
||||
func (sr *statusRecorder) WriteHeader(code int) {
|
||||
if !sr.wroteHeader {
|
||||
sr.status = code
|
||||
sr.wroteHeader = true
|
||||
}
|
||||
sr.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (sr *statusRecorder) Write(b []byte) (int, error) {
|
||||
if !sr.wroteHeader {
|
||||
sr.status = http.StatusOK
|
||||
sr.wroteHeader = true
|
||||
}
|
||||
return sr.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// BuildMiddleware wraps mux with: logging → timeout → rate-limit → security headers.
|
||||
func BuildMiddleware(mux http.Handler, rl *RateLimiter, timeout time.Duration) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ip = r.RemoteAddr
|
||||
}
|
||||
|
||||
if !rl.get(ip).Allow() {
|
||||
http.Error(w, "429 Too Many Requests", http.StatusTooManyRequests)
|
||||
slog.Info("rate limited", "ip", ip, "path", r.URL.Path)
|
||||
return
|
||||
}
|
||||
|
||||
applySecurityHeaders(w)
|
||||
|
||||
sr := &statusRecorder{ResponseWriter: w}
|
||||
mux.ServeHTTP(sr, r)
|
||||
|
||||
slog.Info("request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", sr.status,
|
||||
"ip", ip,
|
||||
"duration", time.Since(start).String(),
|
||||
)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user