added more billing components.

This commit is contained in:
Blake Ridgway
2026-04-16 21:30:11 -05:00
parent a621b1deb9
commit ffc2bde162
7 changed files with 157 additions and 43 deletions

41
internal/web/csrf.go Normal file
View File

@@ -0,0 +1,41 @@
package web
import (
"crypto/rand"
"encoding/hex"
"net/http"
)
const csrfCookieName = "_csrf"
// ensureCSRFToken returns the current CSRF token from the cookie, generating
// and setting a new one if the cookie is absent.
func ensureCSRFToken(w http.ResponseWriter, r *http.Request, secure bool) string {
if cookie, err := r.Cookie(csrfCookieName); err == nil && cookie.Value != "" {
return cookie.Value
}
raw := make([]byte, 32)
_, _ = rand.Read(raw)
token := hex.EncodeToString(raw)
http.SetCookie(w, &http.Cookie{
Name: csrfCookieName,
Value: token,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteStrictMode,
})
return token
}
// validateCSRF returns true when the form's csrf_token field matches the
// _csrf cookie. Call after r.ParseForm().
func validateCSRF(r *http.Request) bool {
cookie, err := r.Cookie(csrfCookieName)
if err != nil || cookie.Value == "" {
return false
}
formToken := r.FormValue("csrf_token")
return formToken != "" && formToken == cookie.Value
}

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"embed"
"errors"
"fmt"
"html/template"
"io"
"log/slog"
@@ -126,6 +127,7 @@ type dashboardData struct {
Subscription *subscriptionRow
Invoices []invoiceRow
Flash string
CSRFToken string
}
// ---- DB helpers ----
@@ -181,37 +183,16 @@ func loadRecentInvoices(db *sql.DB, customerID int64) ([]invoiceRow, error) {
return result, rows.Err()
}
func validEmail(s string) bool {
at := strings.Index(s, "@")
return at > 0 && at < len(s)-1 && strings.Contains(s[at+1:], ".")
}
func formatCurrency(dollars, cents int64, currency string) string {
if currency == "USD" || currency == "" {
return "$" + itoa(dollars) + "." + pad2(cents)
return fmt.Sprintf("$%d.%02d", dollars, cents)
}
return itoa(dollars) + "." + pad2(cents) + " " + currency
}
func itoa(n int64) string {
if n == 0 {
return "0"
}
s := ""
neg := n < 0
if neg {
n = -n
}
for n > 0 {
s = string(rune('0'+n%10)) + s
n /= 10
}
if neg {
s = "-" + s
}
return s
}
func pad2(n int64) string {
if n < 10 {
return "0" + itoa(n)
}
return itoa(n)
return fmt.Sprintf("%d.%02d %s", dollars, cents, currency)
}
// ---- session cookie helpers ----
@@ -265,8 +246,11 @@ func (h *Handler) IndexHandler(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) LoginGET(w http.ResponseWriter, r *http.Request) {
token := ensureCSRFToken(w, r, sessionSecure())
h.ts.render(w, "login.html", map[string]any{
"Error": r.URL.Query().Get("error"),
"Error": r.URL.Query().Get("error"),
"reset": r.URL.Query().Get("reset"),
"CSRFToken": token,
})
}
@@ -275,6 +259,10 @@ func (h *Handler) LoginPOST(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
email := strings.TrimSpace(strings.ToLower(r.FormValue("email")))
password := r.FormValue("password")
@@ -313,8 +301,10 @@ func (h *Handler) LoginPOST(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) RegisterGET(w http.ResponseWriter, r *http.Request) {
token := ensureCSRFToken(w, r, sessionSecure())
h.ts.render(w, "register.html", map[string]any{
"Error": r.URL.Query().Get("error"),
"Error": r.URL.Query().Get("error"),
"CSRFToken": token,
})
}
@@ -323,6 +313,10 @@ func (h *Handler) RegisterPOST(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
firstName := strings.TrimSpace(r.FormValue("first_name"))
lastName := strings.TrimSpace(r.FormValue("last_name"))
@@ -334,6 +328,10 @@ func (h *Handler) RegisterPOST(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/register?error=missing_fields", http.StatusSeeOther)
return
}
if !validEmail(email) {
http.Redirect(w, r, "/register?error=invalid_email", http.StatusSeeOther)
return
}
if password != confirm {
http.Redirect(w, r, "/register?error=password_mismatch", http.StatusSeeOther)
return
@@ -416,22 +414,41 @@ func (h *Handler) DashboardGET(w http.ResponseWriter, r *http.Request) {
case "cancelled":
flash = "Checkout was cancelled. No charge was made."
}
if r.URL.Query().Get("cancelled") == "1" {
switch r.URL.Query().Get("cancelled") {
case "1":
flash = "Your subscription has been cancelled and will not renew. You retain access until the end of the current billing period."
}
if r.URL.Query().Get("error") == "cancel_failed" {
switch r.URL.Query().Get("error") {
case "cancel_failed":
flash = "Could not cancel subscription. Please contact support."
case "already_cancelling":
flash = "Your subscription is already scheduled for cancellation."
case "no_subscription":
flash = "No active subscription found."
case "already_subscribed":
flash = "You already have an active subscription."
}
csrfToken := ensureCSRFToken(w, r, sessionSecure())
h.ts.render(w, "dashboard.html", dashboardData{
Customer: c,
Subscription: sub,
Invoices: invoices,
Flash: flash,
CSRFToken: csrfToken,
})
}
func (h *Handler) LogoutPOST(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
cookie, err := r.Cookie("session")
if err == nil {
_ = auth.DeleteSession(h.DB, cookie.Value)
@@ -441,9 +458,11 @@ func (h *Handler) LogoutPOST(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) ResetGET(w http.ResponseWriter, r *http.Request) {
token := ensureCSRFToken(w, r, sessionSecure())
h.ts.render(w, "reset-request.html", map[string]any{
"Sent": r.URL.Query().Get("sent"),
"Error": r.URL.Query().Get("error"),
"Sent": r.URL.Query().Get("sent"),
"Error": r.URL.Query().Get("error"),
"CSRFToken": token,
})
}
@@ -452,6 +471,10 @@ func (h *Handler) ResetPOST(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
email := strings.TrimSpace(strings.ToLower(r.FormValue("email")))
if email == "" {
@@ -480,10 +503,12 @@ func (h *Handler) ResetPOST(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) ResetConfirmGET(w http.ResponseWriter, r *http.Request) {
token := r.PathValue("token")
pathToken := r.PathValue("token")
csrfToken := ensureCSRFToken(w, r, sessionSecure())
h.ts.render(w, "reset-confirm.html", map[string]any{
"Token": token,
"Error": r.URL.Query().Get("error"),
"Token": pathToken,
"Error": r.URL.Query().Get("error"),
"CSRFToken": csrfToken,
})
}
@@ -494,6 +519,10 @@ func (h *Handler) ResetConfirmPOST(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
password := r.FormValue("password")
confirm := r.FormValue("confirm_password")
@@ -552,6 +581,30 @@ func (h *Handler) CheckoutGET(w http.ResponseWriter, r *http.Request) {
return
}
// Reject price IDs not in the configured set.
validPrice := false
for _, id := range h.Stripe.PriceIDs {
if id == priceID {
validPrice = true
break
}
}
if !validPrice {
http.Error(w, "invalid plan", http.StatusBadRequest)
return
}
// Block customers who already have an active or cancelling subscription.
var existingCount int
_ = h.DB.QueryRow(
`SELECT COUNT(*) FROM subscriptions WHERE customer_id = ? AND status IN ('active', 'cancelling')`,
customerID,
).Scan(&existingCount)
if existingCount > 0 {
http.Redirect(w, r, "/dashboard?error=already_subscribed", http.StatusSeeOther)
return
}
var stripeCustomerID string
err := h.DB.QueryRow(
`SELECT stripe_customer_id FROM customers WHERE id = ?`, customerID,
@@ -620,20 +673,33 @@ func (h *Handler) WebhookPOST(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) CancelPOST(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !validateCSRF(r) {
http.Error(w, "invalid request", http.StatusForbidden)
return
}
customerID := auth.CustomerIDFromContext(r.Context())
var stripeSubID string
var stripeSubID, subStatus string
err := h.DB.QueryRow(
`SELECT stripe_subscription_id FROM subscriptions
WHERE customer_id = ? AND status = 'active'
`SELECT stripe_subscription_id, status FROM subscriptions
WHERE customer_id = ? AND status IN ('active', 'cancelling')
ORDER BY created_at DESC LIMIT 1`,
customerID,
).Scan(&stripeSubID)
).Scan(&stripeSubID, &subStatus)
if err != nil {
slog.Error("cancel: find subscription", "err", err)
http.Redirect(w, r, "/dashboard?error=no_subscription", http.StatusSeeOther)
return
}
if subStatus == "cancelling" {
http.Redirect(w, r, "/dashboard?error=already_cancelling", http.StatusSeeOther)
return
}
if err := payments.CancelSubscription(stripeSubID); err != nil {
slog.Error("cancel: stripe cancel", "err", err)

View File

@@ -4,6 +4,7 @@
{{define "nav-actions"}}
<form method="POST" action="/logout" style="display:inline;">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<button class="btn btn--ghost btn--sm" type="submit">logout</button>
</form>
{{end}}
@@ -60,6 +61,7 @@
<div class="dash-actions">
<form method="POST" action="/cancel"
onsubmit="return confirm('Are you sure you want to cancel your subscription? This cannot be undone.')">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<button class="btn btn--danger btn--sm" type="submit">cancel subscription</button>
</form>
</div>
@@ -109,8 +111,8 @@
<tbody>
{{range .Invoices}}
<tr class="inv-table__row">
<td class="inv-table__td inv-table__td--muted">{{.CreatedAt | slice 0 10}}</td>
<td class="inv-table__td inv-table__td--muted">{{.PeriodStart | slice 0 10}} {{.PeriodEnd | slice 0 10}}</td>
<td class="inv-table__td inv-table__td--muted">{{.CreatedAt | fmtDate}}</td>
<td class="inv-table__td inv-table__td--muted">{{.PeriodStart | fmtDate}} {{.PeriodEnd | fmtDate}}</td>
<td class="inv-table__td inv-table__td--right">{{.AmountDisplay}}</td>
<td class="inv-table__td"><span class="status-badge status-badge--{{.Status}}">{{.Status}}</span></td>
<td class="inv-table__td">

View File

@@ -25,6 +25,7 @@
{{end}}
<form method="POST" action="/login" class="form">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form__group">
<label class="form__label" for="email">Email</label>
<input class="form__input" type="email" id="email" name="email"

View File

@@ -17,12 +17,14 @@
{{else if eq .Error "password_mismatch"}}Passwords do not match.
{{else if eq .Error "password_too_short"}}Password must be at least 8 characters.
{{else if eq .Error "email_taken"}}An account with that email already exists.
{{else if eq .Error "invalid_email"}}Please enter a valid email address.
{{else if eq .Error "server_error"}}A server error occurred. Please try again.
{{else}}An error occurred. Please try again.{{end}}
</div>
{{end}}
<form method="POST" action="/register" class="form">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form__row">
<div class="form__group">
<label class="form__label" for="first_name">First name</label>

View File

@@ -23,6 +23,7 @@
{{end}}
<form method="POST" action="/reset/{{.Token}}" class="form">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form__group">
<label class="form__label" for="password">New password <span class="form__hint">(min. 8 characters)</span></label>
<input class="form__input" type="password" id="password" name="password"

View File

@@ -29,6 +29,7 @@
</p>
<form method="POST" action="/reset" class="form">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form__group">
<label class="form__label" for="email">Email</label>
<input class="form__input" type="email" id="email" name="email"