// synth_ai_core/orchestration/prompt_learning.rs
//! Prompt learning job orchestration.
//!
//! This module provides the high-level `PromptLearningJob` type for
//! submitting and tracking GEPA/MIPRO optimization jobs.

use std::time::Duration;

use futures_util::StreamExt;
use reqwest::header::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::api::types::PolicyJobStatus;
use crate::api::SynthClient;
use crate::auth;
use crate::errors::CoreError;
use crate::sse::stream_sse_events;

use super::events::{ParsedEvent, TerminalStatus};
use super::progress::ProgressTracker;

/// Result from a prompt learning job.
///
/// Mirrors the backend's job payload; serde aliases accept the older
/// `*_score` field names for the reward fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptLearningResult {
    /// Job ID
    pub job_id: String,
    /// Current status
    pub status: PolicyJobStatus,
    /// Best reward achieved (also deserialized from `best_score`)
    #[serde(default, alias = "best_score")]
    pub best_reward: Option<f64>,
    /// Best prompt configuration
    #[serde(default)]
    pub best_prompt: Option<Value>,
    /// Baseline reward (also deserialized from `baseline_score`)
    #[serde(default, alias = "baseline_score")]
    pub baseline_reward: Option<f64>,
    /// Number of candidates evaluated
    #[serde(default)]
    pub candidates_evaluated: i32,
    /// Number of generations completed
    #[serde(default)]
    pub generations_completed: i32,
    /// Error message if failed
    #[serde(default)]
    pub error: Option<String>,
    /// Raw response data (`Value::Null` when not populated)
    #[serde(default)]
    pub raw: Value,
}

52impl PromptLearningResult {
53    /// Check if the job succeeded.
54    pub fn succeeded(&self) -> bool {
55        self.status == PolicyJobStatus::Succeeded
56    }
57
58    /// Check if the job failed.
59    pub fn failed(&self) -> bool {
60        self.status == PolicyJobStatus::Failed
61    }
62
63    /// Check if the job is in a terminal state.
64    pub fn is_terminal(&self) -> bool {
65        self.status.is_terminal()
66    }
67
68    /// Get the system prompt from best_prompt if available.
69    pub fn get_system_prompt(&self) -> Option<String> {
70        self.best_prompt.as_ref().and_then(|p| {
71            // Try various paths where system prompt might be
72            p.get("system_prompt")
73                .and_then(|v| v.as_str())
74                .or_else(|| p.get("instruction").and_then(|v| v.as_str()))
75                .or_else(|| {
76                    p.get("stages")
77                        .and_then(|s| s.get("main"))
78                        .and_then(|m| m.get("instruction"))
79                        .and_then(|v| v.as_str())
80                })
81                .map(|s| s.to_string())
82        })
83    }
84}
85
/// Ranked prompt from results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedPrompt {
    /// Rank (1 = best); 0 means not yet ranked
    pub rank: i32,
    /// Candidate ID
    pub candidate_id: String,
    /// Training accuracy
    #[serde(default)]
    pub train_accuracy: Option<f64>,
    /// Validation accuracy (preferred over training accuracy when ranking)
    #[serde(default)]
    pub val_accuracy: Option<f64>,
    /// Prompt text or configuration
    #[serde(default)]
    pub prompt: Option<Value>,
}

/// Extracted prompt results.
///
/// Produced by `PromptLearningJob::get_results` from the final job status
/// plus tracker-accumulated candidates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptResults {
    /// Best prompt text
    #[serde(default)]
    pub best_prompt: Option<String>,
    /// Best reward (also deserialized from `best_score`)
    #[serde(default, alias = "best_score")]
    pub best_reward: Option<f64>,
    /// Top prompts ranked by score (rank 1 = best)
    #[serde(default)]
    pub top_prompts: Vec<RankedPrompt>,
}

/// High-level prompt learning job orchestration.
///
/// Wraps a `SynthClient` plus the job configuration, and accumulates
/// streaming progress in a `ProgressTracker`.
pub struct PromptLearningJob {
    /// Synth API client
    client: SynthClient,
    /// Job ID (set after submit, or at construction via `from_job_id`)
    job_id: Option<String>,
    /// Job configuration (`Value::Null` when reconnected via `from_job_id`)
    config: Value,
    /// Optional SynthTunnel worker token
    task_app_worker_token: Option<String>,
    /// Progress tracker
    tracker: ProgressTracker,
}

impl PromptLearningJob {
    /// Create a job from a configuration dict.
    ///
    /// # Arguments
    ///
    /// * `config` - Job configuration (algorithm, task_app_url, policy, etc.)
    /// * `api_key` - Optional API key (uses env if not provided)
    /// * `base_url` - Optional base URL (uses default if not provided)
    /// * `task_app_worker_token` - Optional SynthTunnel worker token
    ///
    /// # Example
    ///
    /// ```ignore
    /// let job = PromptLearningJob::from_dict(
    ///     serde_json::json!({
    ///         "algorithm": "gepa",
    ///         "task_app_url": "http://localhost:8000",
    ///         "env_name": "default",
    ///         "policy": { "model": "gpt-4o-mini", "provider": "openai" },
    ///         "gepa": { "rollout_budget": 100 }
    ///     }),
    ///     None,
    ///     None,
    ///     None,
    /// )?;
    /// ```
    pub fn from_dict(
        config: Value,
        api_key: Option<&str>,
        base_url: Option<&str>,
        task_app_worker_token: Option<String>,
    ) -> Result<Self, CoreError> {
        // An explicit key wins; otherwise fall back to the environment.
        let resolved_key = api_key
            .map(str::to_owned)
            .or_else(|| auth::get_api_key(None))
            .ok_or_else(|| CoreError::Authentication("SYNTH_API_KEY not found".to_string()))?;

        Ok(Self {
            client: SynthClient::new(&resolved_key, base_url)?,
            job_id: None,
            config,
            task_app_worker_token,
            tracker: ProgressTracker::new(),
        })
    }

180    /// Reconnect to an existing job by ID.
181    ///
182    /// # Arguments
183    ///
184    /// * `job_id` - Existing job ID
185    /// * `api_key` - Optional API key
186    /// * `base_url` - Optional base URL
187    pub fn from_job_id(
188        job_id: &str,
189        api_key: Option<&str>,
190        base_url: Option<&str>,
191    ) -> Result<Self, CoreError> {
192        let api_key = match api_key {
193            Some(k) => k.to_string(),
194            None => auth::get_api_key(None)
195                .ok_or_else(|| CoreError::Authentication("SYNTH_API_KEY not found".to_string()))?,
196        };
197
198        let client = SynthClient::new(&api_key, base_url)?;
199
200        Ok(Self {
201            client,
202            job_id: Some(job_id.to_string()),
203            config: Value::Null,
204            task_app_worker_token: None,
205            tracker: ProgressTracker::new(),
206        })
207    }
208
    /// Get the job ID (if submitted).
    ///
    /// `None` until `submit()` succeeds, unless the job was created with
    /// `from_job_id`.
    pub fn job_id(&self) -> Option<&str> {
        self.job_id.as_deref()
    }

    /// Get the progress tracker.
    ///
    /// The tracker is updated per-event during `stream_until_complete`
    /// (best/baseline rewards, candidate and generation counts).
    pub fn tracker(&self) -> &ProgressTracker {
        &self.tracker
    }

219    /// Submit the job to the backend.
220    ///
221    /// Returns the job ID on success.
222    pub async fn submit(&mut self) -> Result<String, CoreError> {
223        if self.job_id.is_some() {
224            return Err(CoreError::Validation("job already submitted".to_string()));
225        }
226
227        if self.config.is_null() {
228            return Err(CoreError::Validation(
229                "no configuration provided".to_string(),
230            ));
231        }
232
233        // Submit via jobs API
234        let job_id = self
235            .client
236            .jobs()
237            .submit_raw_with_worker_token(self.config.clone(), self.task_app_worker_token.clone())
238            .await?;
239        self.job_id = Some(job_id.clone());
240
241        Ok(job_id)
242    }
243
244    /// Get the current job status.
245    pub async fn get_status(&self) -> Result<PromptLearningResult, CoreError> {
246        let job_id = self
247            .job_id
248            .as_ref()
249            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
250
251        let result = self.client.jobs().get_status(job_id).await?;
252
253        Ok(PromptLearningResult {
254            job_id: result.job_id,
255            status: result.status,
256            best_reward: result.best_reward,
257            best_prompt: result.best_prompt,
258            baseline_reward: None,
259            candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
260            generations_completed: result.generations_completed.unwrap_or(0),
261            error: result.error,
262            raw: Value::Null,
263        })
264    }
265
266    /// Poll until the job reaches a terminal state.
267    ///
268    /// # Arguments
269    ///
270    /// * `timeout_secs` - Maximum time to wait
271    /// * `interval_secs` - Polling interval
272    pub async fn poll_until_complete(
273        &self,
274        timeout_secs: f64,
275        interval_secs: f64,
276    ) -> Result<PromptLearningResult, CoreError> {
277        let job_id = self
278            .job_id
279            .as_ref()
280            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
281
282        let result = self
283            .client
284            .jobs()
285            .poll_until_complete(job_id, timeout_secs, interval_secs)
286            .await?;
287
288        Ok(PromptLearningResult {
289            job_id: result.job_id,
290            status: result.status,
291            best_reward: result.best_reward,
292            best_prompt: result.best_prompt,
293            baseline_reward: None,
294            candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
295            generations_completed: result.generations_completed.unwrap_or(0),
296            error: result.error,
297            raw: Value::Null,
298        })
299    }
300
301    /// Stream events until completion with callback.
302    ///
303    /// # Arguments
304    ///
305    /// * `timeout_secs` - Maximum time to wait
306    /// * `on_event` - Optional callback for each event
307    pub async fn stream_until_complete<F>(
308        &mut self,
309        timeout_secs: f64,
310        mut on_event: Option<F>,
311    ) -> Result<PromptLearningResult, CoreError>
312    where
313        F: FnMut(&ParsedEvent),
314    {
315        use std::cell::Cell;
316
317        let job_id = self
318            .job_id
319            .as_ref()
320            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
321
322        eprintln!(
323            "[PL] stream_until_complete: job={} timeout={:.0}s",
324            job_id, timeout_secs
325        );
326
327        let timeout = Duration::from_secs_f64(timeout_secs);
328        let base_url = self.client.base_url().trim_end_matches('/').to_string();
329        let events_url = format!(
330            "{}/api/prompt-learning/online/jobs/{}/events/stream",
331            base_url, job_id
332        );
333        let api_key = self.client.http().api_key().to_string();
334        let mut headers = HeaderMap::new();
335        headers.insert("Accept", HeaderValue::from_static("text/event-stream"));
336        headers.insert(
337            "Authorization",
338            HeaderValue::from_str(&format!("Bearer {}", api_key))
339                .map_err(|_| CoreError::Validation("invalid api key".to_string()))?,
340        );
341        headers.insert(
342            "X-API-Key",
343            HeaderValue::from_str(&api_key)
344                .map_err(|_| CoreError::Validation("invalid api key".to_string()))?,
345        );
346
347        // Use Cell for interior mutability to satisfy borrow checker
348        let terminal_reached = Cell::new(false);
349        let event_count = Cell::new(0u64);
350        let terminal_status = Cell::new(None);
351
352        {
353            let tracker = &mut self.tracker;
354
355            let mut stream =
356                stream_sse_events(&events_url, "GET", headers, None, Some(timeout)).await?;
357
358            while let Some(item) = stream.next().await {
359                let event = item?;
360                if event.data.trim() == "[DONE]" {
361                    break;
362                }
363
364                let payload: Value = serde_json::from_str(&event.data).unwrap_or(Value::Null);
365                let parsed = super::events::EventParser::parse(&payload);
366                let count = event_count.get() + 1;
367                event_count.set(count);
368
369                // Update tracker
370                tracker.update(&parsed);
371
372                // Log progress periodically
373                if count % 5 == 0 || parsed.category.is_terminal() {
374                    eprintln!(
375                        "[PL] Event #{}: type={} category={:?} | tracker: best={:.3} baseline={:?} candidates={} gens={}",
376                        count,
377                        parsed.event_type,
378                        parsed.category,
379                        tracker.best_reward(),
380                        tracker.baseline_reward(),
381                        tracker.progress.candidates_evaluated,
382                        tracker.progress.generations_completed,
383                    );
384                }
385
386                // Call user callback
387                if let Some(ref mut cb) = on_event {
388                    cb(&parsed);
389                }
390
391                // Check for terminal
392                if parsed.category.is_terminal() {
393                    eprintln!(
394                        "[PL] Terminal event received: {} (category={:?})",
395                        parsed.event_type, parsed.category
396                    );
397                    terminal_status.set(super::events::EventParser::terminal_status(
398                        &parsed.event_type,
399                    ));
400                    terminal_reached.set(true);
401                    break;
402                }
403            }
404        }
405
406        eprintln!(
407            "[PL] stream_until_complete: streaming finished, processed {} events",
408            event_count.get()
409        );
410
411        if !terminal_reached.get() {
412            return Err(CoreError::Timeout(
413                "stream ended without terminal event".to_string(),
414            ));
415        }
416
417        // Get final status (tracker borrow is dropped now)
418        eprintln!("[PL] Fetching final job status...");
419        let status_result = match self.get_status().await {
420            Ok(result) => Some(result),
421            Err(err) => {
422                eprintln!("[PL] Warning: failed to fetch final job status: {}", err);
423                None
424            }
425        };
426
427        let mut final_status = status_result
428            .as_ref()
429            .map(|result| result.status)
430            .unwrap_or(crate::api::types::PolicyJobStatus::Succeeded);
431        if !final_status.is_terminal() {
432            if let Some(status) = terminal_status.get() {
433                final_status = match status {
434                    TerminalStatus::Succeeded => crate::api::types::PolicyJobStatus::Succeeded,
435                    TerminalStatus::Failed => crate::api::types::PolicyJobStatus::Failed,
436                    TerminalStatus::Cancelled => crate::api::types::PolicyJobStatus::Cancelled,
437                    TerminalStatus::Paused => crate::api::types::PolicyJobStatus::Paused,
438                };
439                eprintln!(
440                    "[PL] Final status override from terminal event: {:?}",
441                    final_status
442                );
443            }
444        }
445        eprintln!(
446            "[PL] Final status: status={:?} best_reward={:?} error={:?}",
447            final_status,
448            status_result.as_ref().and_then(|result| result.best_reward),
449            status_result
450                .as_ref()
451                .and_then(|result| result.error.clone())
452        );
453
454        // Merge tracker data with status (fall back to tracker if status fetch failed)
455        let result = PromptLearningResult {
456            job_id: status_result
457                .as_ref()
458                .map(|result| result.job_id.clone())
459                .unwrap_or_else(|| job_id.to_string()),
460            status: final_status,
461            best_reward: status_result
462                .as_ref()
463                .and_then(|result| result.best_reward)
464                .or(Some(self.tracker.best_reward())),
465            best_prompt: status_result
466                .as_ref()
467                .and_then(|result| result.best_prompt.clone()),
468            baseline_reward: self.tracker.baseline_reward(),
469            candidates_evaluated: self.tracker.progress.candidates_evaluated,
470            generations_completed: self.tracker.progress.generations_completed,
471            error: status_result.and_then(|result| result.error),
472            raw: Value::Null,
473        };
474
475        eprintln!(
476            "[PL] RESULT: status={:?} best={:?} baseline={:?} candidates={} gens={}",
477            result.status,
478            result.best_reward,
479            result.baseline_reward,
480            result.candidates_evaluated,
481            result.generations_completed
482        );
483
484        Ok(result)
485    }
486
487    /// Cancel a running job.
488    ///
489    /// # Arguments
490    ///
491    /// * `reason` - Optional cancellation reason
492    pub async fn cancel(&self, reason: Option<&str>) -> Result<(), CoreError> {
493        let job_id = self
494            .job_id
495            .as_ref()
496            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
497
498        self.client.jobs().cancel(job_id, reason).await
499    }
500
501    /// Pause a running job.
502    ///
503    /// # Arguments
504    ///
505    /// * `reason` - Optional pause reason
506    pub async fn pause(&self, reason: Option<&str>) -> Result<(), CoreError> {
507        let job_id = self
508            .job_id
509            .as_ref()
510            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
511
512        self.client.jobs().pause(job_id, reason).await
513    }
514
515    /// Resume a paused job.
516    ///
517    /// # Arguments
518    ///
519    /// * `reason` - Optional resume reason
520    pub async fn resume(&self, reason: Option<&str>) -> Result<(), CoreError> {
521        let job_id = self
522            .job_id
523            .as_ref()
524            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
525
526        self.client.jobs().resume(job_id, reason).await
527    }
528
529    /// Get detailed results including prompt extraction.
530    ///
531    /// This fetches events to extract the best prompts.
532    pub async fn get_results(&self) -> Result<PromptResults, CoreError> {
533        // Get final status for best_prompt
534        let status = self.get_status().await?;
535
536        let best_prompt = status.get_system_prompt();
537        let best_reward = status.best_reward.or(Some(self.tracker.best_reward()));
538
539        // Build ranked prompts from tracker candidates
540        let mut top_prompts: Vec<RankedPrompt> = self
541            .tracker
542            .candidates
543            .iter()
544            .filter(|c| c.accepted || c.is_pareto)
545            .map(|c| RankedPrompt {
546                rank: 0,
547                candidate_id: c.candidate_id.clone(),
548                train_accuracy: c.reward,
549                val_accuracy: c.val_reward,
550                prompt: None,
551            })
552            .collect();
553
554        // Sort by accuracy descending
555        top_prompts.sort_by(|a, b| {
556            let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
557            let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
558            b_score
559                .partial_cmp(&a_score)
560                .unwrap_or(std::cmp::Ordering::Equal)
561        });
562
563        // Assign ranks
564        for (i, prompt) in top_prompts.iter_mut().enumerate() {
565            prompt.rank = (i + 1) as i32;
566        }
567
568        Ok(PromptResults {
569            best_prompt,
570            best_reward,
571            top_prompts,
572        })
573    }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal succeeded result fixture used by the status tests.
    fn succeeded_result(best_prompt: Option<Value>) -> PromptLearningResult {
        PromptLearningResult {
            job_id: "test".to_string(),
            status: PolicyJobStatus::Succeeded,
            best_reward: Some(0.85),
            best_prompt,
            baseline_reward: None,
            candidates_evaluated: 10,
            generations_completed: 3,
            error: None,
            raw: Value::Null,
        }
    }

    #[test]
    fn test_result_status() {
        let result = succeeded_result(None);

        assert!(result.succeeded());
        assert!(!result.failed());
        assert!(result.is_terminal());
    }

    #[test]
    fn test_result_get_system_prompt() {
        let result = succeeded_result(Some(json!({
            "system_prompt": "You are a helpful assistant."
        })));

        assert_eq!(
            result.get_system_prompt(),
            Some("You are a helpful assistant.".to_string())
        );
    }

    #[test]
    fn test_ranked_prompt_sorting() {
        // Fixture builder for unranked prompts.
        let ranked = |id: &str, train: f64, val: Option<f64>| RankedPrompt {
            rank: 0,
            candidate_id: id.to_string(),
            train_accuracy: Some(train),
            val_accuracy: val,
            prompt: None,
        };
        let mut prompts = vec![
            ranked("a", 0.7, None),
            ranked("b", 0.9, None),
            ranked("c", 0.8, Some(0.85)),
        ];

        // Sort by accuracy descending (val_accuracy takes precedence)
        prompts.sort_by(|a, b| {
            let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
            let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
            b_score
                .partial_cmp(&a_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // b (0.9) > c (0.85 val) > a (0.7)
        let order: Vec<&str> = prompts.iter().map(|p| p.candidate_id.as_str()).collect();
        assert_eq!(order, ["b", "c", "a"]);
    }
}