479 lines
13 KiB
Go
479 lines
13 KiB
Go
package ai
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"math"
|
||
"net/http"
|
||
"os"
|
||
"time"
|
||
|
||
"rideaware/internal/event"
|
||
"rideaware/internal/nutrition"
|
||
"rideaware/internal/stats"
|
||
"rideaware/internal/user"
|
||
"rideaware/internal/workout"
|
||
"rideaware/pkg/database"
|
||
)
|
||
|
||
type Service struct {
|
||
userRepo *user.Repository
|
||
statsRepo *stats.Repository
|
||
workoutRepo *workout.Repository
|
||
eventRepo *event.Repository
|
||
nutritionSvc *nutrition.Service
|
||
aiRepo *Repository
|
||
}
|
||
|
||
func NewService() *Service {
|
||
return &Service{
|
||
userRepo: user.NewRepository(),
|
||
statsRepo: stats.NewRepository(),
|
||
workoutRepo: workout.NewRepository(),
|
||
eventRepo: event.NewRepository(),
|
||
nutritionSvc: nutrition.NewService(),
|
||
aiRepo: NewRepository(),
|
||
}
|
||
}
|
||
|
||
// GenerateWorkouts orchestrates the AI generation process
|
||
func (s *Service) GenerateWorkouts(userID uint, req GenerateRequest) (*GenerateResponse, error) {
|
||
// 1. Gather user context
|
||
userCtx, err := s.buildUserContext(userID, req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to build user context: %w", err)
|
||
}
|
||
|
||
// 2. Call DeepSeek API with retry logic
|
||
var aiResponse string
|
||
maxRetries := 2
|
||
for i := 0; i <= maxRetries; i++ {
|
||
if i > 0 {
|
||
log.Printf("[AI] Retry attempt %d/%d", i, maxRetries)
|
||
time.Sleep(time.Duration(i) * 2 * time.Second) // Exponential backoff
|
||
}
|
||
aiResponse, err = s.callDeepSeekAPI(userCtx, req)
|
||
if err == nil {
|
||
break
|
||
}
|
||
log.Printf("[AI] API call attempt %d failed: %v", i+1, err)
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("AI API call failed after %d attempts: %w", maxRetries+1, err)
|
||
}
|
||
|
||
// 3. Parse and validate response
|
||
genResponse, err := s.parseAIResponse(aiResponse)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to parse AI response: %w", err)
|
||
}
|
||
|
||
// 4. Fix duration mismatches (AI sometimes miscalculates)
|
||
s.fixWorkoutDurations(genResponse.Workouts)
|
||
|
||
// 5. Filter out incomplete workouts (defensive against malformed AI responses)
|
||
validWorkouts := make([]AIWorkout, 0, len(genResponse.Workouts))
|
||
for i, w := range genResponse.Workouts {
|
||
if w.Duration == 0 || len(w.Segments) == 0 || w.Title == "" {
|
||
log.Printf("[AI] Skipping incomplete workout %d (duration=%d, segments=%d, title=%q)",
|
||
i, w.Duration, len(w.Segments), w.Title)
|
||
continue
|
||
}
|
||
validWorkouts = append(validWorkouts, w)
|
||
}
|
||
genResponse.Workouts = validWorkouts
|
||
|
||
if len(validWorkouts) == 0 {
|
||
return nil, fmt.Errorf("no valid workouts generated by AI")
|
||
}
|
||
|
||
// 6. Validate workout structures
|
||
if err := validateWorkouts(genResponse.Workouts); err != nil {
|
||
return nil, fmt.Errorf("workout validation failed: %w", err)
|
||
}
|
||
|
||
// 7. Calculate total TSS
|
||
totalTSS := 0.0
|
||
for _, w := range genResponse.Workouts {
|
||
totalTSS += w.EstimatedTSS
|
||
}
|
||
genResponse.TotalTSS = totalTSS
|
||
|
||
// 8. Store recommendation in database
|
||
recID, err := s.aiRepo.SaveRecommendation(userID, userCtx, req, aiResponse, genResponse)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to save recommendation: %w", err)
|
||
}
|
||
|
||
genResponse.RecommendationID = recID
|
||
return genResponse, nil
|
||
}
|
||
|
||
// buildUserContext gathers all relevant user data
|
||
func (s *Service) buildUserContext(userID uint, req GenerateRequest) (UserContext, error) {
|
||
var ctx UserContext
|
||
|
||
// Get user profile
|
||
userObj, err := s.userRepo.GetUserByID(userID)
|
||
if err != nil {
|
||
return ctx, err
|
||
}
|
||
|
||
// Set default values if profile doesn't exist
|
||
ctx.FTP = 200 // default FTP
|
||
ctx.MaxHR = 180
|
||
ctx.RestingHR = 60
|
||
ctx.Weight = 70.0
|
||
ctx.WeeklyHours = req.WeeklyHours
|
||
|
||
// Override with actual profile data if available
|
||
if userObj.Profile != nil {
|
||
if userObj.Profile.FTP > 0 {
|
||
ctx.FTP = userObj.Profile.FTP
|
||
}
|
||
if userObj.Profile.MaxHR > 0 {
|
||
ctx.MaxHR = userObj.Profile.MaxHR
|
||
}
|
||
if userObj.Profile.RestingHR > 0 {
|
||
ctx.RestingHR = userObj.Profile.RestingHR
|
||
}
|
||
if userObj.Profile.Weight > 0 {
|
||
ctx.Weight = userObj.Profile.Weight
|
||
}
|
||
}
|
||
|
||
// Get recent workouts (last 30 days)
|
||
endDate := time.Now()
|
||
startDate := endDate.AddDate(0, 0, -30)
|
||
workouts, err := s.workoutRepo.GetWorkoutsByDateRange(userID, startDate, endDate)
|
||
if err == nil {
|
||
ctx.RecentWorkouts = s.summarizeWorkouts(workouts, ctx.FTP)
|
||
} else {
|
||
ctx.RecentWorkouts = []RecentWorkoutSummary{}
|
||
}
|
||
|
||
// Get training load (CTL/ATL/TSB)
|
||
ctx.TrainingLoad = s.calculateTrainingLoad(userID, ctx.FTP)
|
||
|
||
// Get upcoming events
|
||
upcomingEvents, err := s.eventRepo.GetUpcomingEvents(userID, 5)
|
||
if err == nil && len(upcomingEvents) > 0 {
|
||
now := time.Now()
|
||
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||
for _, ev := range upcomingEvents {
|
||
evDate := time.Date(ev.EventDate.Year(), ev.EventDate.Month(), ev.EventDate.Day(), 0, 0, 0, 0, ev.EventDate.Location())
|
||
daysAway := int(evDate.Sub(now).Hours() / 24)
|
||
summary := UpcomingEventSummary{
|
||
Name: ev.Name,
|
||
Date: ev.EventDate.Format("2006-01-02"),
|
||
EventType: ev.EventType,
|
||
Distance: ev.Distance,
|
||
Priority: ev.Priority,
|
||
DaysAway: daysAway,
|
||
}
|
||
ctx.UpcomingEvents = append(ctx.UpcomingEvents, summary)
|
||
|
||
// If this is the target event, set it
|
||
if req.TargetEventID != nil && ev.ID == *req.TargetEventID {
|
||
target := summary
|
||
ctx.TargetEvent = &target
|
||
}
|
||
}
|
||
}
|
||
|
||
// Get nutrition context
|
||
nutritionTargets, err := s.nutritionSvc.GetTargets(userID)
|
||
if err == nil && nutritionTargets != nil && nutritionTargets.IsConfigured {
|
||
ctx.Nutrition = &NutritionContext{
|
||
Goal: nutritionTargets.NutritionGoal,
|
||
DailyCalories: nutritionTargets.DailyCalories,
|
||
ProteinG: nutritionTargets.Protein,
|
||
CarbsG: nutritionTargets.Carbs,
|
||
FatG: nutritionTargets.Fat,
|
||
DietaryPref: nutritionTargets.DietaryPref,
|
||
}
|
||
}
|
||
|
||
return ctx, nil
|
||
}
|
||
|
||
// summarizeWorkouts converts workouts to summaries
|
||
func (s *Service) summarizeWorkouts(workouts []workout.Workout, ftp int) []RecentWorkoutSummary {
|
||
summaries := make([]RecentWorkoutSummary, 0, len(workouts))
|
||
|
||
// Limit to last 10 workouts to keep prompt size manageable
|
||
count := 0
|
||
for _, w := range workouts {
|
||
if w.Status != "completed" {
|
||
continue
|
||
}
|
||
if count >= 10 {
|
||
break
|
||
}
|
||
count++
|
||
|
||
tss := 0.0
|
||
if ftp > 0 && w.AvgPower > 0 && w.Duration > 0 {
|
||
// TSS = (duration_seconds × (avgPower/FTP)²) / 36
|
||
intensityFactor := float64(w.AvgPower) / float64(ftp)
|
||
tss = (float64(w.Duration) * intensityFactor * intensityFactor) / 36.0
|
||
}
|
||
|
||
summaries = append(summaries, RecentWorkoutSummary{
|
||
Date: w.ScheduledDate.Format("2006-01-02"),
|
||
Type: w.Type,
|
||
Duration: w.Duration,
|
||
AvgPower: w.AvgPower,
|
||
TSS: tss,
|
||
})
|
||
}
|
||
|
||
return summaries
|
||
}
|
||
|
||
// calculateTrainingLoad computes CTL/ATL/TSB
|
||
func (s *Service) calculateTrainingLoad(userID uint, ftp int) TrainingLoadSummary {
|
||
summary := TrainingLoadSummary{CTL: 0, ATL: 0, TSB: 0}
|
||
|
||
if ftp == 0 {
|
||
return summary
|
||
}
|
||
|
||
// Get daily TSS for last 42 days (for CTL)
|
||
dailyTSS, err := s.statsRepo.GetDailyTSS(userID, ftp, 42)
|
||
if err != nil || len(dailyTSS) == 0 {
|
||
return summary
|
||
}
|
||
|
||
// Calculate CTL (42-day exponential moving average)
|
||
ctlTC := 42.0 // time constant
|
||
ctl := 0.0
|
||
for _, day := range dailyTSS {
|
||
ctl = ctl + (day.TSS-ctl)*(1.0/ctlTC)
|
||
}
|
||
|
||
// Calculate ATL (7-day exponential moving average, using last 7 days)
|
||
atlTC := 7.0
|
||
atl := 0.0
|
||
start := len(dailyTSS) - 7
|
||
if start < 0 {
|
||
start = 0
|
||
}
|
||
for i := start; i < len(dailyTSS); i++ {
|
||
atl = atl + (dailyTSS[i].TSS-atl)*(1.0/atlTC)
|
||
}
|
||
|
||
summary.CTL = math.Round(ctl*10) / 10
|
||
summary.ATL = math.Round(atl*10) / 10
|
||
summary.TSB = math.Round((ctl-atl)*10) / 10
|
||
|
||
return summary
|
||
}
|
||
|
||
// callDeepSeekAPI makes the HTTP request to DeepSeek
|
||
func (s *Service) callDeepSeekAPI(userCtx UserContext, req GenerateRequest) (string, error) {
|
||
apiKey := os.Getenv("DEEPSEEK_API_KEY")
|
||
if apiKey == "" {
|
||
return "", fmt.Errorf("DEEPSEEK_API_KEY not configured")
|
||
}
|
||
|
||
// Build prompts
|
||
systemPrompt := BuildSystemPrompt()
|
||
userPrompt := BuildUserPrompt(userCtx, req)
|
||
|
||
// Create request
|
||
deepseekReq := DeepSeekRequest{
|
||
Model: "deepseek-chat",
|
||
Messages: []DeepSeekMessage{
|
||
{Role: "system", Content: systemPrompt},
|
||
{Role: "user", Content: userPrompt},
|
||
},
|
||
Temperature: 0.7,
|
||
MaxTokens: 8000, // Ensure enough tokens for complete response
|
||
}
|
||
|
||
reqBody, _ := json.Marshal(deepseekReq)
|
||
|
||
log.Printf("[AI] Calling DeepSeek API for user context with FTP=%d", userCtx.FTP)
|
||
|
||
// Make HTTP request
|
||
httpReq, err := http.NewRequest("POST", "https://api.deepseek.com/v1/chat/completions", bytes.NewBuffer(reqBody))
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||
|
||
// Increase timeout to 180 seconds (3 minutes) for AI generation
|
||
client := &http.Client{Timeout: 180 * time.Second}
|
||
|
||
log.Printf("[AI] Sending request to DeepSeek API...")
|
||
resp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// Read response
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return "", fmt.Errorf("DeepSeek API error (status %d): %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
// Parse response
|
||
var deepseekResp DeepSeekResponse
|
||
if err := json.Unmarshal(body, &deepseekResp); err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if deepseekResp.Error != nil {
|
||
return "", fmt.Errorf("DeepSeek API error: %s", deepseekResp.Error.Message)
|
||
}
|
||
|
||
if len(deepseekResp.Choices) == 0 {
|
||
return "", fmt.Errorf("no response from DeepSeek API")
|
||
}
|
||
|
||
content := deepseekResp.Choices[0].Message.Content
|
||
log.Printf("[AI] DeepSeek API call successful, response length: %d bytes", len(content))
|
||
|
||
// Log preview of response
|
||
if len(content) > 200 {
|
||
log.Printf("[AI] Response preview: %s...", content[:200])
|
||
} else {
|
||
log.Printf("[AI] Full response: %s", content)
|
||
}
|
||
|
||
return content, nil
|
||
}
|
||
|
||
// parseAIResponse extracts structured workout data from AI response
|
||
func (s *Service) parseAIResponse(aiResponse string) (*GenerateResponse, error) {
|
||
var response GenerateResponse
|
||
|
||
// Try to extract JSON from response (AI might wrap it in markdown)
|
||
cleaned := extractJSON(aiResponse)
|
||
|
||
log.Printf("[AI] Raw response length: %d bytes", len(aiResponse))
|
||
log.Printf("[AI] Cleaned response length: %d bytes", len(cleaned))
|
||
|
||
// Log first 500 chars for debugging
|
||
if len(cleaned) > 500 {
|
||
log.Printf("[AI] Response preview: %s...", cleaned[:500])
|
||
} else {
|
||
log.Printf("[AI] Full cleaned response: %s", cleaned)
|
||
}
|
||
|
||
if err := json.Unmarshal([]byte(cleaned), &response); err != nil {
|
||
log.Printf("[AI] Failed to parse JSON: %v", err)
|
||
log.Printf("[AI] Full raw response: %s", aiResponse)
|
||
return nil, fmt.Errorf("failed to parse JSON: %w (check logs for full response)", err)
|
||
}
|
||
|
||
return &response, nil
|
||
}
|
||
|
||
// ScheduleWorkouts converts AI recommendations to actual workouts
|
||
func (s *Service) ScheduleWorkouts(userID uint, recommendationID uint, workoutIndices []int) ([]*workout.Workout, error) {
|
||
// Get the recommendation
|
||
rec, err := s.aiRepo.GetRecommendation(recommendationID, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Prevent duplicate scheduling
|
||
if rec.Status == "scheduled" {
|
||
return nil, fmt.Errorf("this training plan has already been scheduled")
|
||
}
|
||
|
||
// Parse generated workouts
|
||
var aiWorkouts []AIWorkout
|
||
if err := json.Unmarshal(rec.GeneratedWorkouts.Data, &aiWorkouts); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Schedule selected workouts
|
||
var scheduled []*workout.Workout
|
||
for _, idx := range workoutIndices {
|
||
if idx < 0 || idx >= len(aiWorkouts) {
|
||
continue
|
||
}
|
||
|
||
aiWorkout := aiWorkouts[idx]
|
||
scheduledDate, _ := time.Parse("2006-01-02", aiWorkout.ScheduledDate)
|
||
|
||
// Convert to workout model
|
||
workoutData := workout.WorkoutDataJSON{
|
||
Name: aiWorkout.Title,
|
||
TotalDuration: aiWorkout.Duration,
|
||
Segments: aiWorkout.Segments,
|
||
}
|
||
|
||
newWorkout := &workout.Workout{
|
||
UserID: userID,
|
||
Title: aiWorkout.Title,
|
||
Description: aiWorkout.Description,
|
||
Type: aiWorkout.Type,
|
||
Status: "planned",
|
||
ScheduledDate: scheduledDate,
|
||
Duration: aiWorkout.Duration,
|
||
WorkoutData: workoutData,
|
||
Notes: aiWorkout.Notes,
|
||
}
|
||
|
||
if err := s.workoutRepo.CreateWorkout(newWorkout); err != nil {
|
||
log.Printf("[AI] Failed to create workout: %v", err)
|
||
continue
|
||
}
|
||
|
||
scheduled = append(scheduled, newWorkout)
|
||
}
|
||
|
||
// Update recommendation status
|
||
if len(scheduled) > 0 {
|
||
s.aiRepo.UpdateRecommendationStatus(recommendationID, "scheduled")
|
||
}
|
||
|
||
return scheduled, nil
|
||
}
|
||
|
||
// GetUserRecommendations fetches recommendation history
|
||
func (s *Service) GetUserRecommendations(userID uint, limit int) ([]AIRecommendation, error) {
|
||
return s.aiRepo.GetUserRecommendations(userID, limit)
|
||
}
|
||
|
||
// fixWorkoutDurations corrects duration mismatches by recalculating from segments
|
||
func (s *Service) fixWorkoutDurations(workouts []AIWorkout) {
|
||
for i := range workouts {
|
||
totalDuration := 0
|
||
for _, seg := range workouts[i].Segments {
|
||
totalDuration += seg.Duration
|
||
}
|
||
|
||
// If there's a mismatch, fix the workout duration
|
||
if workouts[i].Duration != totalDuration {
|
||
log.Printf("[AI] Fixing workout %d duration: %d -> %d (sum of segments)",
|
||
i, workouts[i].Duration, totalDuration)
|
||
workouts[i].Duration = totalDuration
|
||
}
|
||
}
|
||
}
|
||
|
||
// GetUserFTP queries FTP from user_profiles table directly
|
||
func (s *Service) GetUserFTP(userID uint) (int, error) {
|
||
var ftp int
|
||
err := database.DB.Table("user_profiles").
|
||
Select("ftp").
|
||
Where("user_id = ?", userID).
|
||
Scan(&ftp).Error
|
||
return ftp, err
|
||
}
|