Files
rideaware-api/internal/ai/service.go
2026-05-17 20:39:47 -05:00

479 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"math"
"net/http"
"os"
"time"
"rideaware/internal/event"
"rideaware/internal/nutrition"
"rideaware/internal/stats"
"rideaware/internal/user"
"rideaware/internal/workout"
"rideaware/pkg/database"
)
// Service coordinates AI-driven workout generation and scheduling. It pulls
// user profile, training history, upcoming events, and nutrition targets from
// the repositories below to build the prompt context, and persists generated
// recommendations through aiRepo.
type Service struct {
	userRepo     *user.Repository     // user lookup (profile FTP, HR, weight)
	statsRepo    *stats.Repository    // daily TSS history for CTL/ATL/TSB
	workoutRepo  *workout.Repository  // recent workouts; creation of scheduled workouts
	eventRepo    *event.Repository    // upcoming events
	nutritionSvc *nutrition.Service   // nutrition targets
	aiRepo       *Repository          // stored AI recommendations
}
// NewService returns a Service wired to freshly constructed user, stats,
// workout, and event repositories, a nutrition service, and an AI
// recommendation repository.
func NewService() *Service {
	svc := new(Service)
	svc.userRepo = user.NewRepository()
	svc.statsRepo = stats.NewRepository()
	svc.workoutRepo = workout.NewRepository()
	svc.eventRepo = event.NewRepository()
	svc.nutritionSvc = nutrition.NewService()
	svc.aiRepo = NewRepository()
	return svc
}
// GenerateWorkouts orchestrates AI workout generation for a user.
//
// It builds the user's training context, asks DeepSeek for a plan (retrying
// failed calls with backoff), parses the reply, reconciles each workout's
// duration with its segments, drops incomplete workouts, validates the rest,
// totals their estimated TSS, and stores the result as a recommendation whose
// ID is set on the returned response.
//
// An error is returned if any step fails or if no usable workouts remain
// after filtering.
func (s *Service) GenerateWorkouts(userID uint, req GenerateRequest) (*GenerateResponse, error) {
	// 1. Gather user context.
	userCtx, err := s.buildUserContext(userID, req)
	if err != nil {
		return nil, fmt.Errorf("failed to build user context: %w", err)
	}

	// 2. Call the DeepSeek API, retrying failed attempts.
	aiResponse, err := s.callDeepSeekWithRetry(userCtx, req)
	if err != nil {
		return nil, err
	}

	// 3. Parse the reply into structured workouts.
	genResponse, err := s.parseAIResponse(aiResponse)
	if err != nil {
		return nil, fmt.Errorf("failed to parse AI response: %w", err)
	}

	// 4. The model sometimes miscalculates totals; trust the segment sum.
	s.fixWorkoutDurations(genResponse.Workouts)

	// 5. Defend against malformed model output.
	genResponse.Workouts = s.dropIncompleteWorkouts(genResponse.Workouts)
	if len(genResponse.Workouts) == 0 {
		return nil, fmt.Errorf("no valid workouts generated by AI")
	}

	// 6. Validate workout structures.
	if err := validateWorkouts(genResponse.Workouts); err != nil {
		return nil, fmt.Errorf("workout validation failed: %w", err)
	}

	// 7. Total the plan's estimated training stress.
	totalTSS := 0.0
	for i := range genResponse.Workouts {
		totalTSS += genResponse.Workouts[i].EstimatedTSS
	}
	genResponse.TotalTSS = totalTSS

	// 8. Persist the recommendation (prompt context, request, raw reply, parsed result).
	recID, err := s.aiRepo.SaveRecommendation(userID, userCtx, req, aiResponse, genResponse)
	if err != nil {
		return nil, fmt.Errorf("failed to save recommendation: %w", err)
	}
	genResponse.RecommendationID = recID
	return genResponse, nil
}

// callDeepSeekWithRetry calls the DeepSeek API up to maxRetries+1 times,
// sleeping between attempts with exponential backoff (2s, 4s, ...). It
// returns the first successful reply, or the last error wrapped with the
// number of attempts made.
func (s *Service) callDeepSeekWithRetry(userCtx UserContext, req GenerateRequest) (string, error) {
	const (
		maxRetries  = 2
		baseBackoff = 2 * time.Second
	)
	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			log.Printf("[AI] Retry attempt %d/%d", attempt, maxRetries)
			// Doubles per retry: 2s before attempt 2, 4s before attempt 3, ...
			time.Sleep(baseBackoff << (attempt - 1))
		}
		reply, err := s.callDeepSeekAPI(userCtx, req)
		if err == nil {
			return reply, nil
		}
		lastErr = err
		log.Printf("[AI] API call attempt %d failed: %v", attempt+1, err)
	}
	return "", fmt.Errorf("AI API call failed after %d attempts: %w", maxRetries+1, lastErr)
}

// dropIncompleteWorkouts returns a new slice containing only workouts that
// have a non-zero duration, at least one segment, and a title, logging each
// workout it discards. The input slice is left unchanged.
func (s *Service) dropIncompleteWorkouts(workouts []AIWorkout) []AIWorkout {
	valid := make([]AIWorkout, 0, len(workouts))
	for i, w := range workouts {
		if w.Duration == 0 || len(w.Segments) == 0 || w.Title == "" {
			log.Printf("[AI] Skipping incomplete workout %d (duration=%d, segments=%d, title=%q)",
				i, w.Duration, len(w.Segments), w.Title)
			continue
		}
		valid = append(valid, w)
	}
	return valid
}
// buildUserContext assembles the UserContext that is fed into the AI prompt.
//
// Physiological values start from defaults (FTP 200, max HR 180, resting HR
// 60, weight 70) and are overridden by any positive value in the user's
// profile. The following parts are best-effort: a failure loading any of them
// leaves that part of the context empty instead of failing the call.
//   - Recent workouts from the last 30 days.
//   - Training load.
//   - Up to 5 upcoming events.
//   - Nutrition targets.
//
// Only a failure to load the user itself is returned as an error.
func (s *Service) buildUserContext(userID uint, req GenerateRequest) (UserContext, error) {
	var uc UserContext

	userObj, err := s.userRepo.GetUserByID(userID)
	if err != nil {
		return uc, err
	}

	// Defaults used when the profile is missing or a field is unset.
	uc.FTP = 200
	uc.MaxHR = 180
	uc.RestingHR = 60
	uc.Weight = 70.0
	uc.WeeklyHours = req.WeeklyHours

	if p := userObj.Profile; p != nil {
		if p.FTP > 0 {
			uc.FTP = p.FTP
		}
		if p.MaxHR > 0 {
			uc.MaxHR = p.MaxHR
		}
		if p.RestingHR > 0 {
			uc.RestingHR = p.RestingHR
		}
		if p.Weight > 0 {
			uc.Weight = p.Weight
		}
	}

	// Recent workouts (last 30 days).
	endDate := time.Now()
	startDate := endDate.AddDate(0, 0, -30)
	workouts, err := s.workoutRepo.GetWorkoutsByDateRange(userID, startDate, endDate)
	if err != nil {
		log.Printf("[AI] Could not load recent workouts for user %d: %v", userID, err)
		uc.RecentWorkouts = []RecentWorkoutSummary{}
	} else {
		uc.RecentWorkouts = s.summarizeWorkouts(workouts, uc.FTP)
	}

	// Training load (CTL/ATL/TSB).
	uc.TrainingLoad = s.calculateTrainingLoad(userID, uc.FTP)

	// Upcoming events.
	if upcomingEvents, err := s.eventRepo.GetUpcomingEvents(userID, 5); err != nil {
		log.Printf("[AI] Could not load upcoming events for user %d: %v", userID, err)
	} else {
		// Compare calendar dates as UTC midnights. Every day is then exactly
		// 24h, so the count is not cut short by DST days (23h). It is also not
		// cut short when the event's time zone differs from the server's
		// (e.g. event stored in UTC while the server runs at -05:00).
		now := time.Now()
		today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
		for _, ev := range upcomingEvents {
			evDay := time.Date(ev.EventDate.Year(), ev.EventDate.Month(), ev.EventDate.Day(), 0, 0, 0, 0, time.UTC)
			daysAway := int(evDay.Sub(today) / (24 * time.Hour))
			summary := UpcomingEventSummary{
				Name:      ev.Name,
				Date:      ev.EventDate.Format("2006-01-02"),
				EventType: ev.EventType,
				Distance:  ev.Distance,
				Priority:  ev.Priority,
				DaysAway:  daysAway,
			}
			uc.UpcomingEvents = append(uc.UpcomingEvents, summary)
			// The target event is only found if it is among these upcoming events.
			if req.TargetEventID != nil && ev.ID == *req.TargetEventID {
				target := summary
				uc.TargetEvent = &target
			}
		}
	}

	// Nutrition context, only when the user has configured targets.
	nutritionTargets, err := s.nutritionSvc.GetTargets(userID)
	if err == nil && nutritionTargets != nil && nutritionTargets.IsConfigured {
		uc.Nutrition = &NutritionContext{
			Goal:          nutritionTargets.NutritionGoal,
			DailyCalories: nutritionTargets.DailyCalories,
			ProteinG:      nutritionTargets.Protein,
			CarbsG:        nutritionTargets.Carbs,
			FatG:          nutritionTargets.Fat,
			DietaryPref:   nutritionTargets.DietaryPref,
		}
	}

	return uc, nil
}
// summarizeWorkouts turns completed workouts into compact summaries for the
// AI prompt. Workouts whose Status is not "completed" are ignored. At most 10
// summaries are produced to keep the prompt small.
//
// NOTE(review): the cap keeps the first 10 completed workouts in the order
// supplied. That is "the most recent 10" only if the slice is ordered newest
// first; confirm GetWorkoutsByDateRange's ordering.
//
// TSS is estimated from average power as duration_s * IF^2 / 36, where
// IF = avgPower / ftp. This is the standard duration_s * NP * IF / (FTP * 3600)
// * 100 with average power standing in for NP, and Duration taken to be
// seconds. TSS is 0 when ftp, power, or duration is not positive.
func (s *Service) summarizeWorkouts(workouts []workout.Workout, ftp int) []RecentWorkoutSummary {
	const maxSummaries = 10
	out := make([]RecentWorkoutSummary, 0, len(workouts))
	for i := range workouts {
		w := &workouts[i] // avoid copying each workout struct
		if w.Status != "completed" {
			continue
		}
		if len(out) >= maxSummaries {
			break
		}

		var tss float64
		if ftp > 0 && w.AvgPower > 0 && w.Duration > 0 {
			intensity := float64(w.AvgPower) / float64(ftp)
			tss = (float64(w.Duration) * intensity * intensity) / 36.0
		}

		out = append(out, RecentWorkoutSummary{
			Date:     w.ScheduledDate.Format("2006-01-02"),
			Type:     w.Type,
			Duration: w.Duration,
			AvgPower: w.AvgPower,
			TSS:      tss,
		})
	}
	return out
}
// calculateTrainingLoad estimates the user's chronic training load (CTL),
// acute training load (ATL), and training stress balance (TSB = CTL - ATL)
// from up to 42 days of daily TSS.
//
// CTL and ATL are exponentially weighted moving averages of daily TSS with
// time constants of 42 and 7 days. Both are run over the whole series, seeded
// at 0. The previous version restarted ATL from 0 over only the final 7
// samples, which understated fatigue and inflated TSB. With at most 42 days of
// history, the zero seed still biases CTL low. Values are rounded to one
// decimal place.
//
// A zero summary is returned when ftp is not positive, the TSS query fails,
// or no history exists.
//
// NOTE(review): assumes GetDailyTSS returns days oldest-first; confirm.
func (s *Service) calculateTrainingLoad(userID uint, ftp int) TrainingLoadSummary {
	const (
		historyDays = 42
		ctlTC       = 42.0 // CTL time constant, days
		atlTC       = 7.0  // ATL time constant, days
	)

	var summary TrainingLoadSummary
	if ftp <= 0 {
		return summary
	}

	dailyTSS, err := s.statsRepo.GetDailyTSS(userID, ftp, historyDays)
	if err != nil || len(dailyTSS) == 0 {
		return summary
	}

	var ctl, atl float64
	for _, day := range dailyTSS {
		ctl += (day.TSS - ctl) / ctlTC
		atl += (day.TSS - atl) / atlTC
	}

	summary.CTL = math.Round(ctl*10) / 10
	summary.ATL = math.Round(atl*10) / 10
	summary.TSB = math.Round((ctl-atl)*10) / 10
	return summary
}
// deepseekHTTPClient is shared by all DeepSeek calls so keep-alive
// connections are reused instead of building a new client (and transport
// pool) per request. The 3-minute timeout accommodates long generations.
var deepseekHTTPClient = &http.Client{Timeout: 180 * time.Second}

// callDeepSeekAPI sends the system and user prompts built from userCtx and req
// to the DeepSeek chat-completions endpoint and returns the first choice's
// message content.
//
// It returns an error in these cases:
//   - DEEPSEEK_API_KEY is unset.
//   - The request cannot be encoded or sent.
//   - The HTTP status is not 200 (the error includes the response body).
//   - The response carries an API error or contains no choices.
func (s *Service) callDeepSeekAPI(userCtx UserContext, req GenerateRequest) (string, error) {
	apiKey := os.Getenv("DEEPSEEK_API_KEY")
	if apiKey == "" {
		return "", fmt.Errorf("DEEPSEEK_API_KEY not configured")
	}

	deepseekReq := DeepSeekRequest{
		Model: "deepseek-chat",
		Messages: []DeepSeekMessage{
			{Role: "system", Content: BuildSystemPrompt()},
			{Role: "user", Content: BuildUserPrompt(userCtx, req)},
		},
		Temperature: 0.7,
		MaxTokens:   8000, // headroom so the JSON plan is not truncated
	}
	reqBody, err := json.Marshal(deepseekReq)
	if err != nil {
		return "", fmt.Errorf("encoding DeepSeek request: %w", err)
	}
	log.Printf("[AI] Calling DeepSeek API for user context with FTP=%d", userCtx.FTP)

	httpReq, err := http.NewRequest(http.MethodPost, "https://api.deepseek.com/v1/chat/completions", bytes.NewReader(reqBody))
	if err != nil {
		return "", err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+apiKey)

	log.Printf("[AI] Sending request to DeepSeek API...")
	resp, err := deepseekHTTPClient.Do(httpReq)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("DeepSeek API error (status %d): %s", resp.StatusCode, string(body))
	}

	var deepseekResp DeepSeekResponse
	if err := json.Unmarshal(body, &deepseekResp); err != nil {
		return "", err
	}
	if deepseekResp.Error != nil {
		return "", fmt.Errorf("DeepSeek API error: %s", deepseekResp.Error.Message)
	}
	if len(deepseekResp.Choices) == 0 {
		return "", fmt.Errorf("no response from DeepSeek API")
	}

	content := deepseekResp.Choices[0].Message.Content
	log.Printf("[AI] DeepSeek API call successful, response length: %d bytes", len(content))
	if len(content) > 200 {
		log.Printf("[AI] Response preview: %s...", content[:200])
	} else {
		log.Printf("[AI] Full response: %s", content)
	}
	return content, nil
}
// parseAIResponse decodes the model's reply into a GenerateResponse.
//
// The reply is first passed through extractJSON, since the model may wrap its
// JSON in markdown. Raw and cleaned sizes plus a preview (up to 500 bytes) are
// logged. If decoding fails, the full raw reply is logged and a wrapped error
// is returned.
func (s *Service) parseAIResponse(aiResponse string) (*GenerateResponse, error) {
	const previewLimit = 500

	cleaned := extractJSON(aiResponse)
	log.Printf("[AI] Raw response length: %d bytes", len(aiResponse))
	log.Printf("[AI] Cleaned response length: %d bytes", len(cleaned))

	if len(cleaned) <= previewLimit {
		log.Printf("[AI] Full cleaned response: %s", cleaned)
	} else {
		log.Printf("[AI] Response preview: %s...", cleaned[:previewLimit])
	}

	parsed := new(GenerateResponse)
	if err := json.Unmarshal([]byte(cleaned), parsed); err != nil {
		log.Printf("[AI] Failed to parse JSON: %v", err)
		log.Printf("[AI] Full raw response: %s", aiResponse)
		return nil, fmt.Errorf("failed to parse JSON: %w (check logs for full response)", err)
	}
	return parsed, nil
}
// ScheduleWorkouts turns selected workouts from a stored AI recommendation
// into planned workouts for the user.
//
// workoutIndices selects entries from the recommendation's generated
// workouts. The following are skipped:
//   - Out-of-range indices.
//   - Repeated indices, so each workout is created once.
//   - Workouts whose ScheduledDate is not a valid YYYY-MM-DD date, which would
//     otherwise be scheduled on 0001-01-01.
//
// Workouts that fail to save are logged and skipped. If at least one workout
// is created, the recommendation is marked "scheduled", and any later call for
// it is rejected. It returns the created workouts, which may be none.
func (s *Service) ScheduleWorkouts(userID uint, recommendationID uint, workoutIndices []int) ([]*workout.Workout, error) {
	rec, err := s.aiRepo.GetRecommendation(recommendationID, userID)
	if err != nil {
		return nil, err
	}

	// Prevent duplicate scheduling.
	// NOTE(review): check-then-update is not atomic; two concurrent requests
	// can both pass this check.
	if rec.Status == "scheduled" {
		return nil, fmt.Errorf("this training plan has already been scheduled")
	}

	var aiWorkouts []AIWorkout
	if err := json.Unmarshal(rec.GeneratedWorkouts.Data, &aiWorkouts); err != nil {
		return nil, fmt.Errorf("decoding generated workouts for recommendation %d: %w", recommendationID, err)
	}

	var scheduled []*workout.Workout
	seen := make(map[int]struct{}, len(workoutIndices))
	for _, idx := range workoutIndices {
		if idx < 0 || idx >= len(aiWorkouts) {
			continue
		}
		if _, dup := seen[idx]; dup {
			continue
		}
		seen[idx] = struct{}{}

		aiWorkout := aiWorkouts[idx]
		scheduledDate, err := time.Parse("2006-01-02", aiWorkout.ScheduledDate)
		if err != nil {
			log.Printf("[AI] Skipping workout %d with invalid scheduled date %q: %v",
				idx, aiWorkout.ScheduledDate, err)
			continue
		}

		newWorkout := &workout.Workout{
			UserID:        userID,
			Title:         aiWorkout.Title,
			Description:   aiWorkout.Description,
			Type:          aiWorkout.Type,
			Status:        "planned",
			ScheduledDate: scheduledDate,
			Duration:      aiWorkout.Duration,
			WorkoutData: workout.WorkoutDataJSON{
				Name:          aiWorkout.Title,
				TotalDuration: aiWorkout.Duration,
				Segments:      aiWorkout.Segments,
			},
			Notes: aiWorkout.Notes,
		}
		if err := s.workoutRepo.CreateWorkout(newWorkout); err != nil {
			log.Printf("[AI] Failed to create workout: %v", err)
			continue
		}
		scheduled = append(scheduled, newWorkout)
	}

	if len(scheduled) > 0 {
		// NOTE(review): any result of this call is ignored, as before; if it
		// returns an error, consider logging it.
		s.aiRepo.UpdateRecommendationStatus(recommendationID, "scheduled")
	}
	return scheduled, nil
}
// GetUserRecommendations returns up to limit of the user's stored AI
// recommendations, delegating directly to the AI repository.
func (s *Service) GetUserRecommendations(userID uint, limit int) ([]AIRecommendation, error) {
	recs, err := s.aiRepo.GetUserRecommendations(userID, limit)
	return recs, err
}
// fixWorkoutDurations modifies workouts in place. Whenever a workout's
// Duration differs from the sum of its segment durations, Duration is
// overwritten with that sum and the correction is logged.
func (s *Service) fixWorkoutDurations(workouts []AIWorkout) {
	for idx := range workouts {
		w := &workouts[idx]

		segmentSum := 0
		for _, seg := range w.Segments {
			segmentSum += seg.Duration
		}

		if w.Duration == segmentSum {
			continue
		}
		log.Printf("[AI] Fixing workout %d duration: %d -> %d (sum of segments)",
			idx, w.Duration, segmentSum)
		w.Duration = segmentSum
	}
}
// GetUserFTP reads the ftp column for userID directly from the user_profiles
// table and returns it together with the query error.
//
// NOTE(review): when no profile row exists, the scan presumably leaves ftp at
// 0 with a nil error, so callers cannot tell "no profile" from "FTP unset";
// confirm against the ORM's Scan semantics.
func (s *Service) GetUserFTP(userID uint) (int, error) {
	ftp := 0
	query := database.DB.
		Table("user_profiles").
		Select("ftp").
		Where("user_id = ?", userID)
	err := query.Scan(&ftp).Error
	return ftp, err
}