migrate from sqlite3 file to psql

This commit is contained in:
Blake Ridgway
2026-04-17 19:38:35 -05:00
parent ba1770b493
commit b6ef3a73f2
5 changed files with 81 additions and 157 deletions

View File

@@ -1,9 +1,9 @@
# Required: your bot token from https://discord.com/developers/applications # Required: your bot token from https://discord.com/developers/applications
DISCORD_TOKEN=your-bot-token-here DISCORD_TOKEN=your-bot-token-here
# Required: PostgreSQL connection string
DATABASE_URL=postgres://user:password@localhost:5432/cyclingbot?sslmode=disable
# Optional: restrict slash command registration to a specific server (instant propagation) # Optional: restrict slash command registration to a specific server (instant propagation)
# Leave blank for global commands (~1 hour to propagate) # Leave blank for global commands (~1 hour to propagate)
GUILD_ID= GUILD_ID=
# Optional: path to the SQLite database file (default: cycling_bot.db)
DB_PATH=cycling_bot.db

165
db/db.go
View File

@@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"time" "time"
_ "modernc.org/sqlite" _ "github.com/lib/pq"
) )
type DB struct { type DB struct {
@@ -36,11 +36,14 @@ type ChallengeArchive struct {
EndDate time.Time EndDate time.Time
} }
func Open(path string) (*DB, error) { func Open(connStr string) (*DB, error) {
conn, err := sql.Open("sqlite", path) conn, err := sql.Open("postgres", connStr)
if err != nil { if err != nil {
return nil, fmt.Errorf("open db: %w", err) return nil, fmt.Errorf("open db: %w", err)
} }
if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("ping db: %w", err)
}
d := &DB{conn: conn} d := &DB{conn: conn}
if err := d.migrate(); err != nil { if err := d.migrate(); err != nil {
return nil, fmt.Errorf("migrate: %w", err) return nil, fmt.Errorf("migrate: %w", err)
@@ -55,14 +58,14 @@ func (d *DB) Close() error {
func (d *DB) migrate() error { func (d *DB) migrate() error {
_, err := d.conn.Exec(` _, err := d.conn.Exec(`
CREATE TABLE IF NOT EXISTS distance_logs ( CREATE TABLE IF NOT EXISTS distance_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT, id SERIAL PRIMARY KEY,
guild_id TEXT NOT NULL DEFAULT '', guild_id TEXT NOT NULL DEFAULT '',
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
username TEXT NOT NULL, username TEXT NOT NULL,
km REAL NOT NULL, km DOUBLE PRECISION NOT NULL,
message_id TEXT NOT NULL UNIQUE, message_id TEXT NOT NULL UNIQUE,
channel_id TEXT NOT NULL, channel_id TEXT NOT NULL,
logged_at DATETIME DEFAULT CURRENT_TIMESTAMP logged_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
); );
CREATE TABLE IF NOT EXISTS settings ( CREATE TABLE IF NOT EXISTS settings (
@@ -73,13 +76,13 @@ func (d *DB) migrate() error {
); );
CREATE TABLE IF NOT EXISTS challenge_archive ( CREATE TABLE IF NOT EXISTS challenge_archive (
id INTEGER PRIMARY KEY AUTOINCREMENT, id SERIAL PRIMARY KEY,
guild_id TEXT NOT NULL, guild_id TEXT NOT NULL,
name TEXT NOT NULL DEFAULT '', name TEXT NOT NULL DEFAULT '',
total_km REAL NOT NULL, total_km DOUBLE PRECISION NOT NULL,
riders INTEGER NOT NULL, riders INTEGER NOT NULL,
start_date DATETIME NOT NULL, start_date TIMESTAMPTZ NOT NULL,
end_date DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP end_date TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
); );
CREATE TABLE IF NOT EXISTS user_preferences ( CREATE TABLE IF NOT EXISTS user_preferences (
@@ -90,20 +93,19 @@ func (d *DB) migrate() error {
); );
CREATE TABLE IF NOT EXISTS kudos ( CREATE TABLE IF NOT EXISTS kudos (
id INTEGER PRIMARY KEY AUTOINCREMENT, id SERIAL PRIMARY KEY,
guild_id TEXT NOT NULL, guild_id TEXT NOT NULL,
from_user_id TEXT NOT NULL, from_user_id TEXT NOT NULL,
from_username TEXT NOT NULL, from_username TEXT NOT NULL,
to_user_id TEXT NOT NULL, to_user_id TEXT NOT NULL,
to_username TEXT NOT NULL, to_username TEXT NOT NULL,
given_at DATETIME DEFAULT CURRENT_TIMESTAMP given_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
); );
`) `)
if err != nil { if err != nil {
return err return err
} }
// Non-fatal migration for existing databases _, _ = d.conn.Exec(`ALTER TABLE distance_logs ADD COLUMN IF NOT EXISTS guild_id TEXT NOT NULL DEFAULT ''`)
_, _ = d.conn.Exec(`ALTER TABLE distance_logs ADD COLUMN guild_id TEXT NOT NULL DEFAULT ''`)
return nil return nil
} }
@@ -112,7 +114,7 @@ func (d *DB) migrate() error {
func (d *DB) GetSetting(ctx context.Context, guildID, key string) (string, bool, error) { func (d *DB) GetSetting(ctx context.Context, guildID, key string) (string, bool, error) {
var value string var value string
err := d.conn.QueryRowContext(ctx, err := d.conn.QueryRowContext(ctx,
`SELECT value FROM settings WHERE guild_id = ? AND key = ?`, guildID, key, `SELECT value FROM settings WHERE guild_id = $1 AND key = $2`, guildID, key,
).Scan(&value) ).Scan(&value)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", false, nil return "", false, nil
@@ -122,25 +124,24 @@ func (d *DB) GetSetting(ctx context.Context, guildID, key string) (string, bool,
func (d *DB) SetSetting(ctx context.Context, guildID, key, value string) error { func (d *DB) SetSetting(ctx context.Context, guildID, key, value string) error {
_, err := d.conn.ExecContext(ctx, ` _, err := d.conn.ExecContext(ctx, `
INSERT INTO settings (guild_id, key, value) VALUES (?, ?, ?) INSERT INTO settings (guild_id, key, value) VALUES ($1, $2, $3)
ON CONFLICT(guild_id, key) DO UPDATE SET value = excluded.value ON CONFLICT(guild_id, key) DO UPDATE SET value = EXCLUDED.value
`, guildID, key, value) `, guildID, key, value)
return err return err
} }
func (d *DB) DeleteSetting(ctx context.Context, guildID, key string) error { func (d *DB) DeleteSetting(ctx context.Context, guildID, key string) error {
_, err := d.conn.ExecContext(ctx, _, err := d.conn.ExecContext(ctx,
`DELETE FROM settings WHERE guild_id = ? AND key = ?`, guildID, key) `DELETE FROM settings WHERE guild_id = $1 AND key = $2`, guildID, key)
return err return err
} }
// ── Core Logging ────────────────────────────────────────────────────────────── // ── Core Logging ──────────────────────────────────────────────────────────────
// AddLog records a distance entry. Returns false if the message was already processed.
func (d *DB) AddLog(ctx context.Context, guildID, userID, username, messageID, channelID string, km float64) (bool, error) { func (d *DB) AddLog(ctx context.Context, guildID, userID, username, messageID, channelID string, km float64) (bool, error) {
res, err := d.conn.ExecContext(ctx, ` res, err := d.conn.ExecContext(ctx, `
INSERT INTO distance_logs (guild_id, user_id, username, km, message_id, channel_id) INSERT INTO distance_logs (guild_id, user_id, username, km, message_id, channel_id)
VALUES (?, ?, ?, ?, ?, ?) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT(message_id) DO NOTHING ON CONFLICT(message_id) DO NOTHING
`, guildID, userID, username, km, messageID, channelID) `, guildID, userID, username, km, messageID, channelID)
if err != nil { if err != nil {
@@ -150,43 +151,40 @@ func (d *DB) AddLog(ctx context.Context, guildID, userID, username, messageID, c
return rows > 0, nil return rows > 0, nil
} }
// RemoveLog deletes a log by message ID. Returns the removed KM (0 if not found).
func (d *DB) RemoveLog(ctx context.Context, messageID string) (float64, error) { func (d *DB) RemoveLog(ctx context.Context, messageID string) (float64, error) {
var km float64 var km float64
err := d.conn.QueryRowContext(ctx, err := d.conn.QueryRowContext(ctx,
`SELECT km FROM distance_logs WHERE message_id = ?`, messageID).Scan(&km) `SELECT km FROM distance_logs WHERE message_id = $1`, messageID).Scan(&km)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return 0, nil return 0, nil
} }
if err != nil { if err != nil {
return 0, err return 0, err
} }
_, err = d.conn.ExecContext(ctx, `DELETE FROM distance_logs WHERE message_id = ?`, messageID) _, err = d.conn.ExecContext(ctx, `DELETE FROM distance_logs WHERE message_id = $1`, messageID)
return km, err return km, err
} }
// AdjustKM manually adds or subtracts KM for a user.
func (d *DB) AdjustKM(ctx context.Context, guildID, userID, username string, km float64) error { func (d *DB) AdjustKM(ctx context.Context, guildID, userID, username string, km float64) error {
_, err := d.conn.ExecContext(ctx, ` _, err := d.conn.ExecContext(ctx, `
INSERT INTO distance_logs (guild_id, user_id, username, km, message_id, channel_id) INSERT INTO distance_logs (guild_id, user_id, username, km, message_id, channel_id)
VALUES (?, ?, ?, ?, 'manual-' || hex(randomblob(8)), 'admin') VALUES ($1, $2, $3, $4, 'manual-' || gen_random_uuid()::text, 'admin')
`, guildID, userID, username, km) `, guildID, userID, username, km)
return err return err
} }
// ── Stats Queries ───────────────────────────────────────────────────────────── // ── Stats Queries ─────────────────────────────────────────────────────────────
// GetLeaderboard returns top N users by total KM within the optional since time.
func (d *DB) GetLeaderboard(ctx context.Context, guildID string, since time.Time, limit int) ([]*UserStats, error) { func (d *DB) GetLeaderboard(ctx context.Context, guildID string, since time.Time, limit int) ([]*UserStats, error) {
q := `SELECT user_id, username, SUM(km), COUNT(*), MAX(logged_at) q := `SELECT user_id, username, SUM(km), COUNT(*), MAX(logged_at)
FROM distance_logs WHERE guild_id = ?` FROM distance_logs WHERE guild_id = $1`
args := []interface{}{guildID} args := []interface{}{guildID}
if !since.IsZero() { if !since.IsZero() {
q += ` AND logged_at >= ?`
args = append(args, since) args = append(args, since)
q += fmt.Sprintf(` AND logged_at >= $%d`, len(args))
} }
q += ` GROUP BY user_id ORDER BY SUM(km) DESC LIMIT ?`
args = append(args, limit) args = append(args, limit)
q += fmt.Sprintf(` GROUP BY user_id, username ORDER BY SUM(km) DESC LIMIT $%d`, len(args))
rows, err := d.conn.QueryContext(ctx, q, args...) rows, err := d.conn.QueryContext(ctx, q, args...)
if err != nil { if err != nil {
@@ -197,35 +195,35 @@ func (d *DB) GetLeaderboard(ctx context.Context, guildID string, since time.Time
var results []*UserStats var results []*UserStats
for rows.Next() { for rows.Next() {
s := &UserStats{} s := &UserStats{}
if err := rows.Scan(&s.UserID, &s.Username, &s.TotalKM, &s.LogCount, &s.LastUpdated); err != nil { var lastUpdated time.Time
if err := rows.Scan(&s.UserID, &s.Username, &s.TotalKM, &s.LogCount, &lastUpdated); err != nil {
return nil, err return nil, err
} }
s.LastUpdated = lastUpdated.Format(time.RFC3339)
results = append(results, s) results = append(results, s)
} }
return results, rows.Err() return results, rows.Err()
} }
// GetTotalKM returns combined KM for a guild within the optional since time.
func (d *DB) GetTotalKM(ctx context.Context, guildID string, since time.Time) (float64, error) { func (d *DB) GetTotalKM(ctx context.Context, guildID string, since time.Time) (float64, error) {
q := `SELECT COALESCE(SUM(km), 0) FROM distance_logs WHERE guild_id = ?` q := `SELECT COALESCE(SUM(km), 0) FROM distance_logs WHERE guild_id = $1`
args := []interface{}{guildID} args := []interface{}{guildID}
if !since.IsZero() { if !since.IsZero() {
q += ` AND logged_at >= ?`
args = append(args, since) args = append(args, since)
q += fmt.Sprintf(` AND logged_at >= $%d`, len(args))
} }
var total float64 var total float64
err := d.conn.QueryRowContext(ctx, q, args...).Scan(&total) err := d.conn.QueryRowContext(ctx, q, args...).Scan(&total)
return total, err return total, err
} }
// GetUserStats returns cumulative stats for a single user within the optional since time.
func (d *DB) GetUserStats(ctx context.Context, guildID, userID string, since time.Time) (*UserStats, error) { func (d *DB) GetUserStats(ctx context.Context, guildID, userID string, since time.Time) (*UserStats, error) {
q := `SELECT user_id, username, COALESCE(SUM(km), 0), COUNT(*), COALESCE(MAX(logged_at), '') q := `SELECT user_id, username, COALESCE(SUM(km), 0), COUNT(*), COALESCE(MAX(logged_at)::text, '')
FROM distance_logs WHERE guild_id = ? AND user_id = ?` FROM distance_logs WHERE guild_id = $1 AND user_id = $2`
args := []interface{}{guildID, userID} args := []interface{}{guildID, userID}
if !since.IsZero() { if !since.IsZero() {
q += ` AND logged_at >= ?`
args = append(args, since) args = append(args, since)
q += fmt.Sprintf(` AND logged_at >= $%d`, len(args))
} }
q += ` GROUP BY user_id, username` q += ` GROUP BY user_id, username`
@@ -238,15 +236,14 @@ func (d *DB) GetUserStats(ctx context.Context, guildID, userID string, since tim
return s, err return s, err
} }
// GetStatsInRange returns user stats for a specific date range (for weekly/monthly reports).
func (d *DB) GetStatsInRange(ctx context.Context, guildID string, from, to time.Time, limit int) ([]*UserStats, error) { func (d *DB) GetStatsInRange(ctx context.Context, guildID string, from, to time.Time, limit int) ([]*UserStats, error) {
rows, err := d.conn.QueryContext(ctx, ` rows, err := d.conn.QueryContext(ctx, `
SELECT user_id, username, SUM(km), COUNT(*), MAX(logged_at) SELECT user_id, username, SUM(km), COUNT(*), MAX(logged_at)
FROM distance_logs FROM distance_logs
WHERE guild_id = ? AND logged_at >= ? AND logged_at <= ? WHERE guild_id = $1 AND logged_at >= $2 AND logged_at <= $3
GROUP BY user_id GROUP BY user_id, username
ORDER BY SUM(km) DESC ORDER BY SUM(km) DESC
LIMIT ? LIMIT $4
`, guildID, from, to, limit) `, guildID, from, to, limit)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -256,20 +253,21 @@ func (d *DB) GetStatsInRange(ctx context.Context, guildID string, from, to time.
var results []*UserStats var results []*UserStats
for rows.Next() { for rows.Next() {
s := &UserStats{} s := &UserStats{}
if err := rows.Scan(&s.UserID, &s.Username, &s.TotalKM, &s.LogCount, &s.LastUpdated); err != nil { var lastUpdated time.Time
if err := rows.Scan(&s.UserID, &s.Username, &s.TotalKM, &s.LogCount, &lastUpdated); err != nil {
return nil, err return nil, err
} }
s.LastUpdated = lastUpdated.Format(time.RFC3339)
results = append(results, s) results = append(results, s)
} }
return results, rows.Err() return results, rows.Err()
} }
// GetYearTotals returns total KM per calendar year for a guild, newest first.
func (d *DB) GetYearTotals(ctx context.Context, guildID string) ([]YearTotal, error) { func (d *DB) GetYearTotals(ctx context.Context, guildID string) ([]YearTotal, error) {
rows, err := d.conn.QueryContext(ctx, ` rows, err := d.conn.QueryContext(ctx, `
SELECT CAST(strftime('%Y', logged_at) AS INTEGER) as yr, SUM(km) SELECT EXTRACT(YEAR FROM logged_at)::INTEGER as yr, SUM(km)
FROM distance_logs FROM distance_logs
WHERE guild_id = ? WHERE guild_id = $1
GROUP BY yr GROUP BY yr
ORDER BY yr DESC ORDER BY yr DESC
`, guildID) `, guildID)
@@ -294,16 +292,15 @@ type YearTotal struct {
TotalKM float64 TotalKM float64
} }
// GetYearlyLeaderboard returns the top N users for a given calendar year.
func (d *DB) GetYearlyLeaderboard(ctx context.Context, guildID string, year, limit int) ([]*UserStats, error) { func (d *DB) GetYearlyLeaderboard(ctx context.Context, guildID string, year, limit int) ([]*UserStats, error) {
rows, err := d.conn.QueryContext(ctx, ` rows, err := d.conn.QueryContext(ctx, `
SELECT user_id, username, SUM(km), COUNT(*), MAX(logged_at) SELECT user_id, username, SUM(km), COUNT(*), MAX(logged_at)
FROM distance_logs FROM distance_logs
WHERE guild_id = ? AND strftime('%Y', logged_at) = ? WHERE guild_id = $1 AND EXTRACT(YEAR FROM logged_at) = $2
GROUP BY user_id GROUP BY user_id, username
ORDER BY SUM(km) DESC ORDER BY SUM(km) DESC
LIMIT ? LIMIT $3
`, guildID, fmt.Sprintf("%d", year), limit) `, guildID, year, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -312,40 +309,43 @@ func (d *DB) GetYearlyLeaderboard(ctx context.Context, guildID string, year, lim
var results []*UserStats var results []*UserStats
for rows.Next() { for rows.Next() {
s := &UserStats{} s := &UserStats{}
if err := rows.Scan(&s.UserID, &s.Username, &s.TotalKM, &s.LogCount, &s.LastUpdated); err != nil { var lastUpdated time.Time
if err := rows.Scan(&s.UserID, &s.Username, &s.TotalKM, &s.LogCount, &lastUpdated); err != nil {
return nil, err return nil, err
} }
s.LastUpdated = lastUpdated.Format(time.RFC3339)
results = append(results, s) results = append(results, s)
} }
return results, rows.Err() return results, rows.Err()
} }
// GetUserYearlyStats returns a user's stats for a given calendar year.
func (d *DB) GetUserYearlyStats(ctx context.Context, guildID, userID string, year int) (*UserStats, error) { func (d *DB) GetUserYearlyStats(ctx context.Context, guildID, userID string, year int) (*UserStats, error) {
s := &UserStats{UserID: userID} s := &UserStats{UserID: userID}
var lastUpdated sql.NullTime
err := d.conn.QueryRowContext(ctx, ` err := d.conn.QueryRowContext(ctx, `
SELECT user_id, username, COALESCE(SUM(km), 0), COUNT(*), COALESCE(MAX(logged_at), '') SELECT user_id, username, COALESCE(SUM(km), 0), COUNT(*), MAX(logged_at)
FROM distance_logs FROM distance_logs
WHERE guild_id = ? AND user_id = ? AND strftime('%Y', logged_at) = ? WHERE guild_id = $1 AND user_id = $2 AND EXTRACT(YEAR FROM logged_at) = $3
GROUP BY user_id, username GROUP BY user_id, username
`, guildID, userID, fmt.Sprintf("%d", year)).Scan( `, guildID, userID, year).Scan(&s.UserID, &s.Username, &s.TotalKM, &s.LogCount, &lastUpdated)
&s.UserID, &s.Username, &s.TotalKM, &s.LogCount, &s.LastUpdated)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return s, nil return s, nil
} }
if lastUpdated.Valid {
s.LastUpdated = lastUpdated.Time.Format(time.RFC3339)
}
return s, err return s, err
} }
// ── Personal Stats ──────────────────────────────────────────────────────────── // ── Personal Stats ────────────────────────────────────────────────────────────
// GetUserHistory returns the last N ride logs for a user.
func (d *DB) GetUserHistory(ctx context.Context, guildID, userID string, limit int) ([]*RideLog, error) { func (d *DB) GetUserHistory(ctx context.Context, guildID, userID string, limit int) ([]*RideLog, error) {
rows, err := d.conn.QueryContext(ctx, ` rows, err := d.conn.QueryContext(ctx, `
SELECT id, km, message_id, logged_at SELECT id, km, message_id, logged_at
FROM distance_logs FROM distance_logs
WHERE guild_id = ? AND user_id = ? WHERE guild_id = $1 AND user_id = $2
ORDER BY logged_at DESC ORDER BY logged_at DESC
LIMIT ? LIMIT $3
`, guildID, userID, limit) `, guildID, userID, limit)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -355,31 +355,27 @@ func (d *DB) GetUserHistory(ctx context.Context, guildID, userID string, limit i
var logs []*RideLog var logs []*RideLog
for rows.Next() { for rows.Next() {
l := &RideLog{} l := &RideLog{}
var loggedAt string if err := rows.Scan(&l.ID, &l.KM, &l.MessageID, &l.LoggedAt); err != nil {
if err := rows.Scan(&l.ID, &l.KM, &l.MessageID, &loggedAt); err != nil {
return nil, err return nil, err
} }
l.LoggedAt, _ = time.Parse("2006-01-02 15:04:05", loggedAt)
logs = append(logs, l) logs = append(logs, l)
} }
return logs, rows.Err() return logs, rows.Err()
} }
// GetUserPB returns a user's personal best single-ride distance.
func (d *DB) GetUserPB(ctx context.Context, guildID, userID string) (float64, error) { func (d *DB) GetUserPB(ctx context.Context, guildID, userID string) (float64, error) {
var pb float64 var pb float64
err := d.conn.QueryRowContext(ctx, err := d.conn.QueryRowContext(ctx,
`SELECT COALESCE(MAX(km), 0) FROM distance_logs WHERE guild_id = ? AND user_id = ?`, `SELECT COALESCE(MAX(km), 0) FROM distance_logs WHERE guild_id = $1 AND user_id = $2`,
guildID, userID).Scan(&pb) guildID, userID).Scan(&pb)
return pb, err return pb, err
} }
// GetUserStreak returns the number of consecutive days a user has logged a ride.
func (d *DB) GetUserStreak(ctx context.Context, userID string) (int, error) { func (d *DB) GetUserStreak(ctx context.Context, userID string) (int, error) {
rows, err := d.conn.QueryContext(ctx, ` rows, err := d.conn.QueryContext(ctx, `
SELECT DISTINCT DATE(logged_at) as ride_date SELECT DISTINCT logged_at::date as ride_date
FROM distance_logs FROM distance_logs
WHERE user_id = ? WHERE user_id = $1
ORDER BY ride_date DESC ORDER BY ride_date DESC
`, userID) `, userID)
if err != nil { if err != nil {
@@ -389,14 +385,10 @@ func (d *DB) GetUserStreak(ctx context.Context, userID string) (int, error) {
var dates []time.Time var dates []time.Time
for rows.Next() { for rows.Next() {
var s string var t time.Time
if err := rows.Scan(&s); err != nil { if err := rows.Scan(&t); err != nil {
return 0, err return 0, err
} }
t, err := time.Parse("2006-01-02", s)
if err != nil {
continue
}
dates = append(dates, t) dates = append(dates, t)
} }
if len(dates) == 0 { if len(dates) == 0 {
@@ -404,7 +396,6 @@ func (d *DB) GetUserStreak(ctx context.Context, userID string) (int, error) {
} }
today := time.Now().UTC().Truncate(24 * time.Hour) today := time.Now().UTC().Truncate(24 * time.Hour)
// Streak must include today or yesterday
if dates[0].Before(today.Add(-24 * time.Hour)) { if dates[0].Before(today.Add(-24 * time.Hour)) {
return 0, nil return 0, nil
} }
@@ -425,7 +416,7 @@ func (d *DB) GetUserStreak(ctx context.Context, userID string) (int, error) {
func (d *DB) GetUserPreference(ctx context.Context, userID, guildID string) (string, error) { func (d *DB) GetUserPreference(ctx context.Context, userID, guildID string) (string, error) {
var unit string var unit string
err := d.conn.QueryRowContext(ctx, err := d.conn.QueryRowContext(ctx,
`SELECT unit FROM user_preferences WHERE user_id = ? AND guild_id = ?`, `SELECT unit FROM user_preferences WHERE user_id = $1 AND guild_id = $2`,
userID, guildID).Scan(&unit) userID, guildID).Scan(&unit)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "km", nil return "km", nil
@@ -435,8 +426,8 @@ func (d *DB) GetUserPreference(ctx context.Context, userID, guildID string) (str
func (d *DB) SetUserPreference(ctx context.Context, userID, guildID, unit string) error { func (d *DB) SetUserPreference(ctx context.Context, userID, guildID, unit string) error {
_, err := d.conn.ExecContext(ctx, ` _, err := d.conn.ExecContext(ctx, `
INSERT INTO user_preferences (user_id, guild_id, unit) VALUES (?, ?, ?) INSERT INTO user_preferences (user_id, guild_id, unit) VALUES ($1, $2, $3)
ON CONFLICT(user_id, guild_id) DO UPDATE SET unit = excluded.unit ON CONFLICT(user_id, guild_id) DO UPDATE SET unit = EXCLUDED.unit
`, userID, guildID, unit) `, userID, guildID, unit)
return err return err
} }
@@ -446,7 +437,7 @@ func (d *DB) SetUserPreference(ctx context.Context, userID, guildID, unit string
func (d *DB) GiveKudos(ctx context.Context, guildID, fromUserID, fromUsername, toUserID, toUsername string) error { func (d *DB) GiveKudos(ctx context.Context, guildID, fromUserID, fromUsername, toUserID, toUsername string) error {
_, err := d.conn.ExecContext(ctx, ` _, err := d.conn.ExecContext(ctx, `
INSERT INTO kudos (guild_id, from_user_id, from_username, to_user_id, to_username) INSERT INTO kudos (guild_id, from_user_id, from_username, to_user_id, to_username)
VALUES (?, ?, ?, ?, ?) VALUES ($1, $2, $3, $4, $5)
`, guildID, fromUserID, fromUsername, toUserID, toUsername) `, guildID, fromUserID, fromUsername, toUserID, toUsername)
return err return err
} }
@@ -454,14 +445,13 @@ func (d *DB) GiveKudos(ctx context.Context, guildID, fromUserID, fromUsername, t
func (d *DB) GetKudosReceived(ctx context.Context, guildID, toUserID string) (int, error) { func (d *DB) GetKudosReceived(ctx context.Context, guildID, toUserID string) (int, error) {
var count int var count int
err := d.conn.QueryRowContext(ctx, err := d.conn.QueryRowContext(ctx,
`SELECT COUNT(*) FROM kudos WHERE guild_id = ? AND to_user_id = ?`, `SELECT COUNT(*) FROM kudos WHERE guild_id = $1 AND to_user_id = $2`,
guildID, toUserID).Scan(&count) guildID, toUserID).Scan(&count)
return count, err return count, err
} }
// ── Challenge Management ────────────────────────────────────────────────────── // ── Challenge Management ──────────────────────────────────────────────────────
// GetChallengeStart returns the start time of the current challenge period.
func (d *DB) GetChallengeStart(ctx context.Context, guildID string) (time.Time, bool, error) { func (d *DB) GetChallengeStart(ctx context.Context, guildID string) (time.Time, bool, error) {
val, ok, err := d.GetSetting(ctx, guildID, "challenge_start") val, ok, err := d.GetSetting(ctx, guildID, "challenge_start")
if !ok || err != nil { if !ok || err != nil {
@@ -471,11 +461,9 @@ func (d *DB) GetChallengeStart(ctx context.Context, guildID string) (time.Time,
return t, err == nil, err return t, err == nil, err
} }
// ResetChallenge archives current stats and starts a new challenge period.
func (d *DB) ResetChallenge(ctx context.Context, guildID, name string) error { func (d *DB) ResetChallenge(ctx context.Context, guildID, name string) error {
challengeStart, hasPrev, _ := d.GetChallengeStart(ctx, guildID) challengeStart, hasPrev, _ := d.GetChallengeStart(ctx, guildID)
// Calculate current totals before reset
total, err := d.GetTotalKM(ctx, guildID, challengeStart) total, err := d.GetTotalKM(ctx, guildID, challengeStart)
if err != nil { if err != nil {
return err return err
@@ -486,7 +474,6 @@ func (d *DB) ResetChallenge(ctx context.Context, guildID, name string) error {
return err return err
} }
// Archive if there was a previous period with data
if hasPrev && total > 0 { if hasPrev && total > 0 {
archiveName := name archiveName := name
if archiveName == "" { if archiveName == "" {
@@ -494,23 +481,21 @@ func (d *DB) ResetChallenge(ctx context.Context, guildID, name string) error {
} }
_, err = d.conn.ExecContext(ctx, ` _, err = d.conn.ExecContext(ctx, `
INSERT INTO challenge_archive (guild_id, name, total_km, riders, start_date) INSERT INTO challenge_archive (guild_id, name, total_km, riders, start_date)
VALUES (?, ?, ?, ?, ?) VALUES ($1, $2, $3, $4, $5)
`, guildID, archiveName, total, len(entries), challengeStart) `, guildID, archiveName, total, len(entries), challengeStart)
if err != nil { if err != nil {
return err return err
} }
} }
// Set new challenge start
return d.SetSetting(ctx, guildID, "challenge_start", time.Now().UTC().Format(time.RFC3339)) return d.SetSetting(ctx, guildID, "challenge_start", time.Now().UTC().Format(time.RFC3339))
} }
// GetChallengeArchive returns past challenge records for a guild.
func (d *DB) GetChallengeArchive(ctx context.Context, guildID string) ([]*ChallengeArchive, error) { func (d *DB) GetChallengeArchive(ctx context.Context, guildID string) ([]*ChallengeArchive, error) {
rows, err := d.conn.QueryContext(ctx, ` rows, err := d.conn.QueryContext(ctx, `
SELECT name, total_km, riders, start_date, end_date SELECT name, total_km, riders, start_date, end_date
FROM challenge_archive FROM challenge_archive
WHERE guild_id = ? WHERE guild_id = $1
ORDER BY end_date DESC ORDER BY end_date DESC
`, guildID) `, guildID)
if err != nil { if err != nil {
@@ -521,12 +506,9 @@ func (d *DB) GetChallengeArchive(ctx context.Context, guildID string) ([]*Challe
var results []*ChallengeArchive var results []*ChallengeArchive
for rows.Next() { for rows.Next() {
a := &ChallengeArchive{} a := &ChallengeArchive{}
var start, end string if err := rows.Scan(&a.Name, &a.TotalKM, &a.Riders, &a.StartDate, &a.EndDate); err != nil {
if err := rows.Scan(&a.Name, &a.TotalKM, &a.Riders, &start, &end); err != nil {
return nil, err return nil, err
} }
a.StartDate, _ = time.Parse("2006-01-02 15:04:05", start)
a.EndDate, _ = time.Parse("2006-01-02 15:04:05", end)
results = append(results, a) results = append(results, a)
} }
return results, rows.Err() return results, rows.Err()
@@ -534,7 +516,6 @@ func (d *DB) GetChallengeArchive(ctx context.Context, guildID string) ([]*Challe
// ── Admin ───────────────────────────────────────────────────────────────────── // ── Admin ─────────────────────────────────────────────────────────────────────
// GetUserLogs returns all ride logs for a user (for audit).
func (d *DB) GetUserLogs(ctx context.Context, guildID, userID string, limit int) ([]*RideLog, error) { func (d *DB) GetUserLogs(ctx context.Context, guildID, userID string, limit int) ([]*RideLog, error) {
return d.GetUserHistory(ctx, guildID, userID, limit) return d.GetUserHistory(ctx, guildID, userID, limit)
} }

14
go.mod
View File

@@ -5,23 +5,11 @@ go 1.22
require ( require (
github.com/bwmarrin/discordgo v0.28.1 github.com/bwmarrin/discordgo v0.28.1
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
modernc.org/sqlite v1.29.10 github.com/lib/pq v1.12.3
) )
require ( require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/websocket v1.4.2 // indirect github.com/gorilla/websocket v1.4.2 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b // indirect golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b // indirect
golang.org/x/sys v0.19.0 // indirect golang.org/x/sys v0.19.0 // indirect
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
modernc.org/libc v1.49.3 // indirect
modernc.org/mathutil v1.6.0 // indirect
modernc.org/memory v1.8.0 // indirect
modernc.org/strutil v1.2.0 // indirect
modernc.org/token v1.1.0 // indirect
) )

49
go.sum
View File

@@ -1,62 +1,17 @@
github.com/bwmarrin/discordgo v0.28.1 h1:gXsuo2GBO7NbR6uqmrrBDplPUx2T3nzu775q/Rd1aG4= github.com/bwmarrin/discordgo v0.28.1 h1:gXsuo2GBO7NbR6uqmrrBDplPUx2T3nzu775q/Rd1aG4=
github.com/bwmarrin/discordgo v0.28.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/bwmarrin/discordgo v0.28.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b h1:7mWr3k41Qtv8XlltBkDkl8LoP3mpSgBW8BUoxtEdbXg= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b h1:7mWr3k41Qtv8XlltBkDkl8LoP3mpSgBW8BUoxtEdbXg=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
modernc.org/cc/v4 v4.20.0 h1:45Or8mQfbUqJOG9WaxvlFYOAQO0lQ5RvqBcFCXngjxk=
modernc.org/cc/v4 v4.20.0/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ=
modernc.org/ccgo/v4 v4.16.0 h1:ofwORa6vx2FMm0916/CkZjpFPSR70VwTjUCe2Eg5BnA=
modernc.org/ccgo/v4 v4.16.0/go.mod h1:dkNyWIjFrVIZ68DTo36vHK+6/ShBn4ysU61So6PIqCI=
modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw=
modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU=
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI=
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4=
modernc.org/libc v1.49.3 h1:j2MRCRdwJI2ls/sGbeSk0t2bypOG/uvPZUsGQFDulqg=
modernc.org/libc v1.49.3/go.mod h1:yMZuGkn7pXbKfoT/M35gFJOAEdSKdxL0q64sF7KqCDo=
modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo=
modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E=
modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU=
modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc=
modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss=
modernc.org/sqlite v1.29.10 h1:3u93dz83myFnMilBGCOLbr+HjklS6+5rJLx4q86RDAg=
modernc.org/sqlite v1.29.10/go.mod h1:ItX2a1OVGgNsFh6Dv60JQvGfJfTPHPVpV6DF59akYOA=
modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA=
modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

View File

@@ -17,10 +17,10 @@ func main() {
_ = godotenv.Load() _ = godotenv.Load()
token := mustEnv("DISCORD_TOKEN") token := mustEnv("DISCORD_TOKEN")
dbPath := getEnv("DB_PATH", "cycling_bot.db") dbURL := mustEnv("DATABASE_URL")
guildID := getEnv("GUILD_ID", "") // empty = global commands (takes ~1h to propagate) guildID := getEnv("GUILD_ID", "") // empty = global commands (takes ~1h to propagate)
database, err := db.Open(dbPath) database, err := db.Open(dbURL)
if err != nil { if err != nil {
log.Fatalf("open database: %v", err) log.Fatalf("open database: %v", err)
} }