added more billing components.

This commit is contained in:
Blake Ridgway
2026-04-16 21:30:11 -05:00
parent a621b1deb9
commit ffc2bde162
7 changed files with 157 additions and 43 deletions

41
internal/web/csrf.go Normal file
View File

@@ -0,0 +1,41 @@
package web
import (
"crypto/rand"
"encoding/hex"
"net/http"
)
const csrfCookieName = "_csrf"
// ensureCSRFToken returns the current CSRF token from the cookie, generating
// and setting a new one if the cookie is absent.
func ensureCSRFToken(w http.ResponseWriter, r *http.Request, secure bool) string {
if cookie, err := r.Cookie(csrfCookieName); err == nil && cookie.Value != "" {
return cookie.Value
}
raw := make([]byte, 32)
_, _ = rand.Read(raw)
token := hex.EncodeToString(raw)
http.SetCookie(w, &http.Cookie{
Name: csrfCookieName,
Value: token,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteStrictMode,
})
return token
}
// validateCSRF returns true when the form's csrf_token field matches the
// _csrf cookie. Call after r.ParseForm().
func validateCSRF(r *http.Request) bool {
cookie, err := r.Cookie(csrfCookieName)
if err != nil || cookie.Value == "" {
return false
}
formToken := r.FormValue("csrf_token")
return formToken != "" && formToken == cookie.Value
}

View File

@@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"embed" "embed"
"errors" "errors"
"fmt"
"html/template" "html/template"
"io" "io"
"log/slog" "log/slog"
@@ -126,6 +127,7 @@ type dashboardData struct {
Subscription *subscriptionRow Subscription *subscriptionRow
Invoices []invoiceRow Invoices []invoiceRow
Flash string Flash string
CSRFToken string
} }
// ---- DB helpers ---- // ---- DB helpers ----
@@ -181,37 +183,16 @@ func loadRecentInvoices(db *sql.DB, customerID int64) ([]invoiceRow, error) {
return result, rows.Err() return result, rows.Err()
} }
func validEmail(s string) bool {
at := strings.Index(s, "@")
return at > 0 && at < len(s)-1 && strings.Contains(s[at+1:], ".")
}
func formatCurrency(dollars, cents int64, currency string) string { func formatCurrency(dollars, cents int64, currency string) string {
if currency == "USD" || currency == "" { if currency == "USD" || currency == "" {
return "$" + itoa(dollars) + "." + pad2(cents) return fmt.Sprintf("$%d.%02d", dollars, cents)
} }
return itoa(dollars) + "." + pad2(cents) + " " + currency return fmt.Sprintf("%d.%02d %s", dollars, cents, currency)
}
func itoa(n int64) string {
if n == 0 {
return "0"
}
s := ""
neg := n < 0
if neg {
n = -n
}
for n > 0 {
s = string(rune('0'+n%10)) + s
n /= 10
}
if neg {
s = "-" + s
}
return s
}
func pad2(n int64) string {
if n < 10 {
return "0" + itoa(n)
}
return itoa(n)
} }
// ---- session cookie helpers ---- // ---- session cookie helpers ----
@@ -265,8 +246,11 @@ func (h *Handler) IndexHandler(w http.ResponseWriter, r *http.Request) {
} }
func (h *Handler) LoginGET(w http.ResponseWriter, r *http.Request) { func (h *Handler) LoginGET(w http.ResponseWriter, r *http.Request) {
token := ensureCSRFToken(w, r, sessionSecure())
h.ts.render(w, "login.html", map[string]any{ h.ts.render(w, "login.html", map[string]any{
"Error": r.URL.Query().Get("error"), "Error": r.URL.Query().Get("error"),
"reset": r.URL.Query().Get("reset"),
"CSRFToken": token,
}) })
} }
@@ -275,6 +259,10 @@ func (h *Handler) LoginPOST(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest) http.Error(w, "bad request", http.StatusBadRequest)
return return
} }
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
email := strings.TrimSpace(strings.ToLower(r.FormValue("email"))) email := strings.TrimSpace(strings.ToLower(r.FormValue("email")))
password := r.FormValue("password") password := r.FormValue("password")
@@ -313,8 +301,10 @@ func (h *Handler) LoginPOST(w http.ResponseWriter, r *http.Request) {
} }
func (h *Handler) RegisterGET(w http.ResponseWriter, r *http.Request) { func (h *Handler) RegisterGET(w http.ResponseWriter, r *http.Request) {
token := ensureCSRFToken(w, r, sessionSecure())
h.ts.render(w, "register.html", map[string]any{ h.ts.render(w, "register.html", map[string]any{
"Error": r.URL.Query().Get("error"), "Error": r.URL.Query().Get("error"),
"CSRFToken": token,
}) })
} }
@@ -323,6 +313,10 @@ func (h *Handler) RegisterPOST(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest) http.Error(w, "bad request", http.StatusBadRequest)
return return
} }
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
firstName := strings.TrimSpace(r.FormValue("first_name")) firstName := strings.TrimSpace(r.FormValue("first_name"))
lastName := strings.TrimSpace(r.FormValue("last_name")) lastName := strings.TrimSpace(r.FormValue("last_name"))
@@ -334,6 +328,10 @@ func (h *Handler) RegisterPOST(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/register?error=missing_fields", http.StatusSeeOther) http.Redirect(w, r, "/register?error=missing_fields", http.StatusSeeOther)
return return
} }
if !validEmail(email) {
http.Redirect(w, r, "/register?error=invalid_email", http.StatusSeeOther)
return
}
if password != confirm { if password != confirm {
http.Redirect(w, r, "/register?error=password_mismatch", http.StatusSeeOther) http.Redirect(w, r, "/register?error=password_mismatch", http.StatusSeeOther)
return return
@@ -416,22 +414,41 @@ func (h *Handler) DashboardGET(w http.ResponseWriter, r *http.Request) {
case "cancelled": case "cancelled":
flash = "Checkout was cancelled. No charge was made." flash = "Checkout was cancelled. No charge was made."
} }
if r.URL.Query().Get("cancelled") == "1" { switch r.URL.Query().Get("cancelled") {
case "1":
flash = "Your subscription has been cancelled and will not renew. You retain access until the end of the current billing period." flash = "Your subscription has been cancelled and will not renew. You retain access until the end of the current billing period."
} }
if r.URL.Query().Get("error") == "cancel_failed" { switch r.URL.Query().Get("error") {
case "cancel_failed":
flash = "Could not cancel subscription. Please contact support." flash = "Could not cancel subscription. Please contact support."
case "already_cancelling":
flash = "Your subscription is already scheduled for cancellation."
case "no_subscription":
flash = "No active subscription found."
case "already_subscribed":
flash = "You already have an active subscription."
} }
csrfToken := ensureCSRFToken(w, r, sessionSecure())
h.ts.render(w, "dashboard.html", dashboardData{ h.ts.render(w, "dashboard.html", dashboardData{
Customer: c, Customer: c,
Subscription: sub, Subscription: sub,
Invoices: invoices, Invoices: invoices,
Flash: flash, Flash: flash,
CSRFToken: csrfToken,
}) })
} }
func (h *Handler) LogoutPOST(w http.ResponseWriter, r *http.Request) { func (h *Handler) LogoutPOST(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
cookie, err := r.Cookie("session") cookie, err := r.Cookie("session")
if err == nil { if err == nil {
_ = auth.DeleteSession(h.DB, cookie.Value) _ = auth.DeleteSession(h.DB, cookie.Value)
@@ -441,9 +458,11 @@ func (h *Handler) LogoutPOST(w http.ResponseWriter, r *http.Request) {
} }
func (h *Handler) ResetGET(w http.ResponseWriter, r *http.Request) { func (h *Handler) ResetGET(w http.ResponseWriter, r *http.Request) {
token := ensureCSRFToken(w, r, sessionSecure())
h.ts.render(w, "reset-request.html", map[string]any{ h.ts.render(w, "reset-request.html", map[string]any{
"Sent": r.URL.Query().Get("sent"), "Sent": r.URL.Query().Get("sent"),
"Error": r.URL.Query().Get("error"), "Error": r.URL.Query().Get("error"),
"CSRFToken": token,
}) })
} }
@@ -452,6 +471,10 @@ func (h *Handler) ResetPOST(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest) http.Error(w, "bad request", http.StatusBadRequest)
return return
} }
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
email := strings.TrimSpace(strings.ToLower(r.FormValue("email"))) email := strings.TrimSpace(strings.ToLower(r.FormValue("email")))
if email == "" { if email == "" {
@@ -480,10 +503,12 @@ func (h *Handler) ResetPOST(w http.ResponseWriter, r *http.Request) {
} }
func (h *Handler) ResetConfirmGET(w http.ResponseWriter, r *http.Request) { func (h *Handler) ResetConfirmGET(w http.ResponseWriter, r *http.Request) {
token := r.PathValue("token") pathToken := r.PathValue("token")
csrfToken := ensureCSRFToken(w, r, sessionSecure())
h.ts.render(w, "reset-confirm.html", map[string]any{ h.ts.render(w, "reset-confirm.html", map[string]any{
"Token": token, "Token": pathToken,
"Error": r.URL.Query().Get("error"), "Error": r.URL.Query().Get("error"),
"CSRFToken": csrfToken,
}) })
} }
@@ -494,6 +519,10 @@ func (h *Handler) ResetConfirmPOST(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest) http.Error(w, "bad request", http.StatusBadRequest)
return return
} }
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
password := r.FormValue("password") password := r.FormValue("password")
confirm := r.FormValue("confirm_password") confirm := r.FormValue("confirm_password")
@@ -552,6 +581,30 @@ func (h *Handler) CheckoutGET(w http.ResponseWriter, r *http.Request) {
return return
} }
// Reject price IDs not in the configured set.
validPrice := false
for _, id := range h.Stripe.PriceIDs {
if id == priceID {
validPrice = true
break
}
}
if !validPrice {
http.Error(w, "invalid plan", http.StatusBadRequest)
return
}
// Block customers who already have an active or cancelling subscription.
var existingCount int
_ = h.DB.QueryRow(
`SELECT COUNT(*) FROM subscriptions WHERE customer_id = ? AND status IN ('active', 'cancelling')`,
customerID,
).Scan(&existingCount)
if existingCount > 0 {
http.Redirect(w, r, "/dashboard?error=already_subscribed", http.StatusSeeOther)
return
}
var stripeCustomerID string var stripeCustomerID string
err := h.DB.QueryRow( err := h.DB.QueryRow(
`SELECT stripe_customer_id FROM customers WHERE id = ?`, customerID, `SELECT stripe_customer_id FROM customers WHERE id = ?`, customerID,
@@ -620,20 +673,33 @@ func (h *Handler) WebhookPOST(w http.ResponseWriter, r *http.Request) {
} }
func (h *Handler) CancelPOST(w http.ResponseWriter, r *http.Request) { func (h *Handler) CancelPOST(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
customerID := auth.CustomerIDFromContext(r.Context()) customerID := auth.CustomerIDFromContext(r.Context())
var stripeSubID string var stripeSubID, subStatus string
err := h.DB.QueryRow( err := h.DB.QueryRow(
`SELECT stripe_subscription_id FROM subscriptions `SELECT stripe_subscription_id, status FROM subscriptions
WHERE customer_id = ? AND status = 'active' WHERE customer_id = ? AND status IN ('active', 'cancelling')
ORDER BY created_at DESC LIMIT 1`, ORDER BY created_at DESC LIMIT 1`,
customerID, customerID,
).Scan(&stripeSubID) ).Scan(&stripeSubID, &subStatus)
if err != nil { if err != nil {
slog.Error("cancel: find subscription", "err", err) slog.Error("cancel: find subscription", "err", err)
http.Redirect(w, r, "/dashboard?error=no_subscription", http.StatusSeeOther) http.Redirect(w, r, "/dashboard?error=no_subscription", http.StatusSeeOther)
return return
} }
if subStatus == "cancelling" {
http.Redirect(w, r, "/dashboard?error=already_cancelling", http.StatusSeeOther)
return
}
if err := payments.CancelSubscription(stripeSubID); err != nil { if err := payments.CancelSubscription(stripeSubID); err != nil {
slog.Error("cancel: stripe cancel", "err", err) slog.Error("cancel: stripe cancel", "err", err)

View File

@@ -4,6 +4,7 @@
{{define "nav-actions"}} {{define "nav-actions"}}
<form method="POST" action="/logout" style="display:inline;"> <form method="POST" action="/logout" style="display:inline;">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<button class="btn btn--ghost btn--sm" type="submit">logout</button> <button class="btn btn--ghost btn--sm" type="submit">logout</button>
</form> </form>
{{end}} {{end}}
@@ -60,6 +61,7 @@
<div class="dash-actions"> <div class="dash-actions">
<form method="POST" action="/cancel" <form method="POST" action="/cancel"
onsubmit="return confirm('Are you sure you want to cancel your subscription? This cannot be undone.')"> onsubmit="return confirm('Are you sure you want to cancel your subscription? This cannot be undone.')">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<button class="btn btn--danger btn--sm" type="submit">cancel subscription</button> <button class="btn btn--danger btn--sm" type="submit">cancel subscription</button>
</form> </form>
</div> </div>
@@ -109,8 +111,8 @@
<tbody> <tbody>
{{range .Invoices}} {{range .Invoices}}
<tr class="inv-table__row"> <tr class="inv-table__row">
<td class="inv-table__td inv-table__td--muted">{{.CreatedAt | slice 0 10}}</td> <td class="inv-table__td inv-table__td--muted">{{.CreatedAt | fmtDate}}</td>
<td class="inv-table__td inv-table__td--muted">{{.PeriodStart | slice 0 10}} {{.PeriodEnd | slice 0 10}}</td> <td class="inv-table__td inv-table__td--muted">{{.PeriodStart | fmtDate}} {{.PeriodEnd | fmtDate}}</td>
<td class="inv-table__td inv-table__td--right">{{.AmountDisplay}}</td> <td class="inv-table__td inv-table__td--right">{{.AmountDisplay}}</td>
<td class="inv-table__td"><span class="status-badge status-badge--{{.Status}}">{{.Status}}</span></td> <td class="inv-table__td"><span class="status-badge status-badge--{{.Status}}">{{.Status}}</span></td>
<td class="inv-table__td"> <td class="inv-table__td">

View File

@@ -25,6 +25,7 @@
{{end}} {{end}}
<form method="POST" action="/login" class="form"> <form method="POST" action="/login" class="form">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form__group"> <div class="form__group">
<label class="form__label" for="email">Email</label> <label class="form__label" for="email">Email</label>
<input class="form__input" type="email" id="email" name="email" <input class="form__input" type="email" id="email" name="email"

View File

@@ -17,12 +17,14 @@
{{else if eq .Error "password_mismatch"}}Passwords do not match. {{else if eq .Error "password_mismatch"}}Passwords do not match.
{{else if eq .Error "password_too_short"}}Password must be at least 8 characters. {{else if eq .Error "password_too_short"}}Password must be at least 8 characters.
{{else if eq .Error "email_taken"}}An account with that email already exists. {{else if eq .Error "email_taken"}}An account with that email already exists.
{{else if eq .Error "invalid_email"}}Please enter a valid email address.
{{else if eq .Error "server_error"}}A server error occurred. Please try again. {{else if eq .Error "server_error"}}A server error occurred. Please try again.
{{else}}An error occurred. Please try again.{{end}} {{else}}An error occurred. Please try again.{{end}}
</div> </div>
{{end}} {{end}}
<form method="POST" action="/register" class="form"> <form method="POST" action="/register" class="form">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form__row"> <div class="form__row">
<div class="form__group"> <div class="form__group">
<label class="form__label" for="first_name">First name</label> <label class="form__label" for="first_name">First name</label>

View File

@@ -23,6 +23,7 @@
{{end}} {{end}}
<form method="POST" action="/reset/{{.Token}}" class="form"> <form method="POST" action="/reset/{{.Token}}" class="form">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form__group"> <div class="form__group">
<label class="form__label" for="password">New password <span class="form__hint">(min. 8 characters)</span></label> <label class="form__label" for="password">New password <span class="form__hint">(min. 8 characters)</span></label>
<input class="form__input" type="password" id="password" name="password" <input class="form__input" type="password" id="password" name="password"

View File

@@ -29,6 +29,7 @@
</p> </p>
<form method="POST" action="/reset" class="form"> <form method="POST" action="/reset" class="form">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form__group"> <div class="form__group">
<label class="form__label" for="email">Email</label> <label class="form__label" for="email">Email</label>
<input class="form__input" type="email" id="email" name="email" <input class="form__input" type="email" id="email" name="email"