added more billing components.
This commit is contained in:
41
internal/web/csrf.go
Normal file
41
internal/web/csrf.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const csrfCookieName = "_csrf"
|
||||
|
||||
// ensureCSRFToken returns the current CSRF token from the cookie, generating
|
||||
// and setting a new one if the cookie is absent.
|
||||
func ensureCSRFToken(w http.ResponseWriter, r *http.Request, secure bool) string {
|
||||
if cookie, err := r.Cookie(csrfCookieName); err == nil && cookie.Value != "" {
|
||||
return cookie.Value
|
||||
}
|
||||
raw := make([]byte, 32)
|
||||
_, _ = rand.Read(raw)
|
||||
token := hex.EncodeToString(raw)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: csrfCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
return token
|
||||
}
|
||||
|
||||
// validateCSRF returns true when the form's csrf_token field matches the
|
||||
// _csrf cookie. Call after r.ParseForm().
|
||||
func validateCSRF(r *http.Request) bool {
|
||||
cookie, err := r.Cookie(csrfCookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
return false
|
||||
}
|
||||
formToken := r.FormValue("csrf_token")
|
||||
return formToken != "" && formToken == cookie.Value
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -126,6 +127,7 @@ type dashboardData struct {
|
||||
Subscription *subscriptionRow
|
||||
Invoices []invoiceRow
|
||||
Flash string
|
||||
CSRFToken string
|
||||
}
|
||||
|
||||
// ---- DB helpers ----
|
||||
@@ -181,37 +183,16 @@ func loadRecentInvoices(db *sql.DB, customerID int64) ([]invoiceRow, error) {
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func validEmail(s string) bool {
|
||||
at := strings.Index(s, "@")
|
||||
return at > 0 && at < len(s)-1 && strings.Contains(s[at+1:], ".")
|
||||
}
|
||||
|
||||
func formatCurrency(dollars, cents int64, currency string) string {
|
||||
if currency == "USD" || currency == "" {
|
||||
return "$" + itoa(dollars) + "." + pad2(cents)
|
||||
return fmt.Sprintf("$%d.%02d", dollars, cents)
|
||||
}
|
||||
return itoa(dollars) + "." + pad2(cents) + " " + currency
|
||||
}
|
||||
|
||||
func itoa(n int64) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
s := ""
|
||||
neg := n < 0
|
||||
if neg {
|
||||
n = -n
|
||||
}
|
||||
for n > 0 {
|
||||
s = string(rune('0'+n%10)) + s
|
||||
n /= 10
|
||||
}
|
||||
if neg {
|
||||
s = "-" + s
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func pad2(n int64) string {
|
||||
if n < 10 {
|
||||
return "0" + itoa(n)
|
||||
}
|
||||
return itoa(n)
|
||||
return fmt.Sprintf("%d.%02d %s", dollars, cents, currency)
|
||||
}
|
||||
|
||||
// ---- session cookie helpers ----
|
||||
@@ -265,8 +246,11 @@ func (h *Handler) IndexHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) LoginGET(w http.ResponseWriter, r *http.Request) {
|
||||
token := ensureCSRFToken(w, r, sessionSecure())
|
||||
h.ts.render(w, "login.html", map[string]any{
|
||||
"Error": r.URL.Query().Get("error"),
|
||||
"Error": r.URL.Query().Get("error"),
|
||||
"reset": r.URL.Query().Get("reset"),
|
||||
"CSRFToken": token,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -275,6 +259,10 @@ func (h *Handler) LoginPOST(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validateCSRF(r) {
|
||||
http.Error(w, "invalid request", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(r.FormValue("email")))
|
||||
password := r.FormValue("password")
|
||||
@@ -313,8 +301,10 @@ func (h *Handler) LoginPOST(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) RegisterGET(w http.ResponseWriter, r *http.Request) {
|
||||
token := ensureCSRFToken(w, r, sessionSecure())
|
||||
h.ts.render(w, "register.html", map[string]any{
|
||||
"Error": r.URL.Query().Get("error"),
|
||||
"Error": r.URL.Query().Get("error"),
|
||||
"CSRFToken": token,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -323,6 +313,10 @@ func (h *Handler) RegisterPOST(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validateCSRF(r) {
|
||||
http.Error(w, "invalid request", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
firstName := strings.TrimSpace(r.FormValue("first_name"))
|
||||
lastName := strings.TrimSpace(r.FormValue("last_name"))
|
||||
@@ -334,6 +328,10 @@ func (h *Handler) RegisterPOST(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/register?error=missing_fields", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
if !validEmail(email) {
|
||||
http.Redirect(w, r, "/register?error=invalid_email", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
if password != confirm {
|
||||
http.Redirect(w, r, "/register?error=password_mismatch", http.StatusSeeOther)
|
||||
return
|
||||
@@ -416,22 +414,41 @@ func (h *Handler) DashboardGET(w http.ResponseWriter, r *http.Request) {
|
||||
case "cancelled":
|
||||
flash = "Checkout was cancelled. No charge was made."
|
||||
}
|
||||
if r.URL.Query().Get("cancelled") == "1" {
|
||||
switch r.URL.Query().Get("cancelled") {
|
||||
case "1":
|
||||
flash = "Your subscription has been cancelled and will not renew. You retain access until the end of the current billing period."
|
||||
}
|
||||
if r.URL.Query().Get("error") == "cancel_failed" {
|
||||
switch r.URL.Query().Get("error") {
|
||||
case "cancel_failed":
|
||||
flash = "Could not cancel subscription. Please contact support."
|
||||
case "already_cancelling":
|
||||
flash = "Your subscription is already scheduled for cancellation."
|
||||
case "no_subscription":
|
||||
flash = "No active subscription found."
|
||||
case "already_subscribed":
|
||||
flash = "You already have an active subscription."
|
||||
}
|
||||
|
||||
csrfToken := ensureCSRFToken(w, r, sessionSecure())
|
||||
|
||||
h.ts.render(w, "dashboard.html", dashboardData{
|
||||
Customer: c,
|
||||
Subscription: sub,
|
||||
Invoices: invoices,
|
||||
Flash: flash,
|
||||
CSRFToken: csrfToken,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) LogoutPOST(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validateCSRF(r) {
|
||||
http.Error(w, "invalid request", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
cookie, err := r.Cookie("session")
|
||||
if err == nil {
|
||||
_ = auth.DeleteSession(h.DB, cookie.Value)
|
||||
@@ -441,9 +458,11 @@ func (h *Handler) LogoutPOST(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) ResetGET(w http.ResponseWriter, r *http.Request) {
|
||||
token := ensureCSRFToken(w, r, sessionSecure())
|
||||
h.ts.render(w, "reset-request.html", map[string]any{
|
||||
"Sent": r.URL.Query().Get("sent"),
|
||||
"Error": r.URL.Query().Get("error"),
|
||||
"Sent": r.URL.Query().Get("sent"),
|
||||
"Error": r.URL.Query().Get("error"),
|
||||
"CSRFToken": token,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -452,6 +471,10 @@ func (h *Handler) ResetPOST(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validateCSRF(r) {
|
||||
http.Error(w, "invalid request", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(r.FormValue("email")))
|
||||
if email == "" {
|
||||
@@ -480,10 +503,12 @@ func (h *Handler) ResetPOST(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) ResetConfirmGET(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.PathValue("token")
|
||||
pathToken := r.PathValue("token")
|
||||
csrfToken := ensureCSRFToken(w, r, sessionSecure())
|
||||
h.ts.render(w, "reset-confirm.html", map[string]any{
|
||||
"Token": token,
|
||||
"Error": r.URL.Query().Get("error"),
|
||||
"Token": pathToken,
|
||||
"Error": r.URL.Query().Get("error"),
|
||||
"CSRFToken": csrfToken,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -494,6 +519,10 @@ func (h *Handler) ResetConfirmPOST(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validateCSRF(r) {
|
||||
http.Error(w, "invalid request", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
password := r.FormValue("password")
|
||||
confirm := r.FormValue("confirm_password")
|
||||
@@ -552,6 +581,30 @@ func (h *Handler) CheckoutGET(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Reject price IDs not in the configured set.
|
||||
validPrice := false
|
||||
for _, id := range h.Stripe.PriceIDs {
|
||||
if id == priceID {
|
||||
validPrice = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !validPrice {
|
||||
http.Error(w, "invalid plan", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Block customers who already have an active or cancelling subscription.
|
||||
var existingCount int
|
||||
_ = h.DB.QueryRow(
|
||||
`SELECT COUNT(*) FROM subscriptions WHERE customer_id = ? AND status IN ('active', 'cancelling')`,
|
||||
customerID,
|
||||
).Scan(&existingCount)
|
||||
if existingCount > 0 {
|
||||
http.Redirect(w, r, "/dashboard?error=already_subscribed", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
var stripeCustomerID string
|
||||
err := h.DB.QueryRow(
|
||||
`SELECT stripe_customer_id FROM customers WHERE id = ?`, customerID,
|
||||
@@ -620,20 +673,33 @@ func (h *Handler) WebhookPOST(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) CancelPOST(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validateCSRF(r) {
|
||||
http.Error(w, "invalid request", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
customerID := auth.CustomerIDFromContext(r.Context())
|
||||
|
||||
var stripeSubID string
|
||||
var stripeSubID, subStatus string
|
||||
err := h.DB.QueryRow(
|
||||
`SELECT stripe_subscription_id FROM subscriptions
|
||||
WHERE customer_id = ? AND status = 'active'
|
||||
`SELECT stripe_subscription_id, status FROM subscriptions
|
||||
WHERE customer_id = ? AND status IN ('active', 'cancelling')
|
||||
ORDER BY created_at DESC LIMIT 1`,
|
||||
customerID,
|
||||
).Scan(&stripeSubID)
|
||||
).Scan(&stripeSubID, &subStatus)
|
||||
if err != nil {
|
||||
slog.Error("cancel: find subscription", "err", err)
|
||||
http.Redirect(w, r, "/dashboard?error=no_subscription", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
if subStatus == "cancelling" {
|
||||
http.Redirect(w, r, "/dashboard?error=already_cancelling", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
if err := payments.CancelSubscription(stripeSubID); err != nil {
|
||||
slog.Error("cancel: stripe cancel", "err", err)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
{{define "nav-actions"}}
|
||||
<form method="POST" action="/logout" style="display:inline;">
|
||||
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
|
||||
<button class="btn btn--ghost btn--sm" type="submit">logout</button>
|
||||
</form>
|
||||
{{end}}
|
||||
@@ -60,6 +61,7 @@
|
||||
<div class="dash-actions">
|
||||
<form method="POST" action="/cancel"
|
||||
onsubmit="return confirm('Are you sure you want to cancel your subscription? This cannot be undone.')">
|
||||
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
|
||||
<button class="btn btn--danger btn--sm" type="submit">cancel subscription</button>
|
||||
</form>
|
||||
</div>
|
||||
@@ -109,8 +111,8 @@
|
||||
<tbody>
|
||||
{{range .Invoices}}
|
||||
<tr class="inv-table__row">
|
||||
<td class="inv-table__td inv-table__td--muted">{{.CreatedAt | slice 0 10}}</td>
|
||||
<td class="inv-table__td inv-table__td--muted">{{.PeriodStart | slice 0 10}} – {{.PeriodEnd | slice 0 10}}</td>
|
||||
<td class="inv-table__td inv-table__td--muted">{{.CreatedAt | fmtDate}}</td>
|
||||
<td class="inv-table__td inv-table__td--muted">{{.PeriodStart | fmtDate}} – {{.PeriodEnd | fmtDate}}</td>
|
||||
<td class="inv-table__td inv-table__td--right">{{.AmountDisplay}}</td>
|
||||
<td class="inv-table__td"><span class="status-badge status-badge--{{.Status}}">{{.Status}}</span></td>
|
||||
<td class="inv-table__td">
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
{{end}}
|
||||
|
||||
<form method="POST" action="/login" class="form">
|
||||
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
|
||||
<div class="form__group">
|
||||
<label class="form__label" for="email">Email</label>
|
||||
<input class="form__input" type="email" id="email" name="email"
|
||||
|
||||
@@ -17,12 +17,14 @@
|
||||
{{else if eq .Error "password_mismatch"}}Passwords do not match.
|
||||
{{else if eq .Error "password_too_short"}}Password must be at least 8 characters.
|
||||
{{else if eq .Error "email_taken"}}An account with that email already exists.
|
||||
{{else if eq .Error "invalid_email"}}Please enter a valid email address.
|
||||
{{else if eq .Error "server_error"}}A server error occurred. Please try again.
|
||||
{{else}}An error occurred. Please try again.{{end}}
|
||||
</div>
|
||||
{{end}}
|
||||
|
||||
<form method="POST" action="/register" class="form">
|
||||
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
|
||||
<div class="form__row">
|
||||
<div class="form__group">
|
||||
<label class="form__label" for="first_name">First name</label>
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
{{end}}
|
||||
|
||||
<form method="POST" action="/reset/{{.Token}}" class="form">
|
||||
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
|
||||
<div class="form__group">
|
||||
<label class="form__label" for="password">New password <span class="form__hint">(min. 8 characters)</span></label>
|
||||
<input class="form__input" type="password" id="password" name="password"
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
</p>
|
||||
|
||||
<form method="POST" action="/reset" class="form">
|
||||
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
|
||||
<div class="form__group">
|
||||
<label class="form__label" for="email">Email</label>
|
||||
<input class="form__input" type="email" id="email" name="email"
|
||||
|
||||
Reference in New Issue
Block a user