// synth_ai_core/orchestration/prompt_learning.rs
1//! Prompt learning job orchestration.
2//!
//! This module provides the high-level `PromptLearningJob` type for
4//! submitting and tracking GEPA/MIPRO optimization jobs.
5
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11use crate::api::types::PolicyJobStatus;
12use crate::api::SynthClient;
13use crate::auth;
14use crate::errors::CoreError;
15
16use super::events::ParsedEvent;
17use super::progress::ProgressTracker;
18use super::streaming::EventStream;
19
/// Result from a prompt learning job.
///
/// Deserialized from backend responses; all optional fields fall back to
/// their defaults when absent from the payload (`#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptLearningResult {
    /// Backend job identifier.
    pub job_id: String,
    /// Current status (may be non-terminal when fetched mid-run).
    pub status: PolicyJobStatus,
    /// Best score achieved so far, if any was reported.
    #[serde(default)]
    pub best_score: Option<f64>,
    /// Best prompt configuration as free-form JSON; see `get_system_prompt`
    /// for the paths probed to extract the prompt text.
    #[serde(default)]
    pub best_prompt: Option<Value>,
    /// Baseline score before optimization, when known.
    #[serde(default)]
    pub baseline_score: Option<f64>,
    /// Number of candidates evaluated (0 when unreported).
    #[serde(default)]
    pub candidates_evaluated: i32,
    /// Number of generations completed (0 when unreported).
    #[serde(default)]
    pub generations_completed: i32,
    /// Error message if the job failed.
    #[serde(default)]
    pub error: Option<String>,
    /// Raw response data; the constructors in this module currently leave
    /// this as `Value::Null`.
    #[serde(default)]
    pub raw: Value,
}
49
50impl PromptLearningResult {
51    /// Check if the job succeeded.
52    pub fn succeeded(&self) -> bool {
53        self.status == PolicyJobStatus::Succeeded
54    }
55
56    /// Check if the job failed.
57    pub fn failed(&self) -> bool {
58        self.status == PolicyJobStatus::Failed
59    }
60
61    /// Check if the job is in a terminal state.
62    pub fn is_terminal(&self) -> bool {
63        self.status.is_terminal()
64    }
65
66    /// Get the system prompt from best_prompt if available.
67    pub fn get_system_prompt(&self) -> Option<String> {
68        self.best_prompt.as_ref().and_then(|p| {
69            // Try various paths where system prompt might be
70            p.get("system_prompt")
71                .and_then(|v| v.as_str())
72                .or_else(|| p.get("instruction").and_then(|v| v.as_str()))
73                .or_else(|| {
74                    p.get("stages")
75                        .and_then(|s| s.get("main"))
76                        .and_then(|m| m.get("instruction"))
77                        .and_then(|v| v.as_str())
78                })
79                .map(|s| s.to_string())
80        })
81    }
82}
83
/// Ranked prompt from results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedPrompt {
    /// 1-based rank (1 = best).
    pub rank: i32,
    /// Candidate identifier from the optimizer.
    pub candidate_id: String,
    /// Training accuracy, when reported.
    #[serde(default)]
    pub train_accuracy: Option<f64>,
    /// Validation accuracy, when reported; takes precedence over
    /// `train_accuracy` when ranking.
    #[serde(default)]
    pub val_accuracy: Option<f64>,
    /// Prompt text or configuration payload, when available.
    #[serde(default)]
    pub prompt: Option<Value>,
}
101
/// Extracted prompt results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptResults {
    /// Best prompt text, when extractable from the job's best candidate.
    #[serde(default)]
    pub best_prompt: Option<String>,
    /// Best score, when reported.
    #[serde(default)]
    pub best_score: Option<f64>,
    /// Top prompts ranked by score, descending (rank 1 = best).
    #[serde(default)]
    pub top_prompts: Vec<RankedPrompt>,
}
115
/// High-level prompt learning job orchestration.
///
/// Wraps a [`SynthClient`] and drives one GEPA/MIPRO optimization job:
/// submit, poll, stream events, cancel, and extract results.
pub struct PromptLearningJob {
    /// Authenticated Synth API client used for all backend calls.
    client: SynthClient,
    /// Backend job ID; `None` until `submit` succeeds (or set directly by
    /// `from_job_id` when reconnecting).
    job_id: Option<String>,
    /// Raw job configuration JSON; `Value::Null` when reconnecting by ID.
    config: Value,
    /// Aggregates progress and candidate data from streamed events.
    tracker: ProgressTracker,
}
127
128impl PromptLearningJob {
129    /// Create a job from a configuration dict.
130    ///
131    /// # Arguments
132    ///
133    /// * `config` - Job configuration (algorithm, task_app_url, policy, etc.)
134    /// * `api_key` - Optional API key (uses env if not provided)
135    /// * `base_url` - Optional base URL (uses default if not provided)
136    ///
137    /// # Example
138    ///
139    /// ```ignore
140    /// let job = PromptLearningJob::from_dict(
141    ///     serde_json::json!({
142    ///         "algorithm": "gepa",
143    ///         "task_app_url": "http://localhost:8000",
144    ///         "env_name": "default",
145    ///         "policy": { "model": "gpt-4o-mini", "provider": "openai" },
146    ///         "gepa": { "rollout_budget": 100 }
147    ///     }),
148    ///     None,
149    ///     None,
150    /// )?;
151    /// ```
152    pub fn from_dict(
153        config: Value,
154        api_key: Option<&str>,
155        base_url: Option<&str>,
156    ) -> Result<Self, CoreError> {
157        let api_key = match api_key {
158            Some(k) => k.to_string(),
159            None => auth::get_api_key(None).ok_or_else(|| {
160                CoreError::Authentication("SYNTH_API_KEY not found".to_string())
161            })?,
162        };
163
164        let client = SynthClient::new(&api_key, base_url)?;
165
166        Ok(Self {
167            client,
168            job_id: None,
169            config,
170            tracker: ProgressTracker::new(),
171        })
172    }
173
174    /// Reconnect to an existing job by ID.
175    ///
176    /// # Arguments
177    ///
178    /// * `job_id` - Existing job ID
179    /// * `api_key` - Optional API key
180    /// * `base_url` - Optional base URL
181    pub fn from_job_id(
182        job_id: &str,
183        api_key: Option<&str>,
184        base_url: Option<&str>,
185    ) -> Result<Self, CoreError> {
186        let api_key = match api_key {
187            Some(k) => k.to_string(),
188            None => auth::get_api_key(None).ok_or_else(|| {
189                CoreError::Authentication("SYNTH_API_KEY not found".to_string())
190            })?,
191        };
192
193        let client = SynthClient::new(&api_key, base_url)?;
194
195        Ok(Self {
196            client,
197            job_id: Some(job_id.to_string()),
198            config: Value::Null,
199            tracker: ProgressTracker::new(),
200        })
201    }
202
203    /// Get the job ID (if submitted).
204    pub fn job_id(&self) -> Option<&str> {
205        self.job_id.as_deref()
206    }
207
208    /// Get the progress tracker.
209    pub fn tracker(&self) -> &ProgressTracker {
210        &self.tracker
211    }
212
213    /// Submit the job to the backend.
214    ///
215    /// Returns the job ID on success.
216    pub async fn submit(&mut self) -> Result<String, CoreError> {
217        if self.job_id.is_some() {
218            return Err(CoreError::Validation(
219                "job already submitted".to_string(),
220            ));
221        }
222
223        if self.config.is_null() {
224            return Err(CoreError::Validation(
225                "no configuration provided".to_string(),
226            ));
227        }
228
229        // Submit via jobs API
230        let job_id = self.client.jobs().submit_raw(self.config.clone()).await?;
231        self.job_id = Some(job_id.clone());
232
233        Ok(job_id)
234    }
235
236    /// Get the current job status.
237    pub async fn get_status(&self) -> Result<PromptLearningResult, CoreError> {
238        let job_id = self.job_id.as_ref().ok_or_else(|| {
239            CoreError::Validation("job not submitted yet".to_string())
240        })?;
241
242        let result = self.client.jobs().get_status(job_id).await?;
243
244        Ok(PromptLearningResult {
245            job_id: result.job_id,
246            status: result.status,
247            best_score: result.best_score,
248            best_prompt: result.best_prompt,
249            baseline_score: None,
250            candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
251            generations_completed: result.generations_completed.unwrap_or(0),
252            error: result.error,
253            raw: Value::Null,
254        })
255    }
256
257    /// Poll until the job reaches a terminal state.
258    ///
259    /// # Arguments
260    ///
261    /// * `timeout_secs` - Maximum time to wait
262    /// * `interval_secs` - Polling interval
263    pub async fn poll_until_complete(
264        &self,
265        timeout_secs: f64,
266        interval_secs: f64,
267    ) -> Result<PromptLearningResult, CoreError> {
268        let job_id = self.job_id.as_ref().ok_or_else(|| {
269            CoreError::Validation("job not submitted yet".to_string())
270        })?;
271
272        let result = self
273            .client
274            .jobs()
275            .poll_until_complete(job_id, timeout_secs, interval_secs)
276            .await?;
277
278        Ok(PromptLearningResult {
279            job_id: result.job_id,
280            status: result.status,
281            best_score: result.best_score,
282            best_prompt: result.best_prompt,
283            baseline_score: None,
284            candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
285            generations_completed: result.generations_completed.unwrap_or(0),
286            error: result.error,
287            raw: Value::Null,
288        })
289    }
290
291    /// Stream events until completion with callback.
292    ///
293    /// # Arguments
294    ///
295    /// * `timeout_secs` - Maximum time to wait
296    /// * `on_event` - Optional callback for each event
297    pub async fn stream_until_complete<F>(
298        &mut self,
299        timeout_secs: f64,
300        mut on_event: Option<F>,
301    ) -> Result<PromptLearningResult, CoreError>
302    where
303        F: FnMut(&ParsedEvent),
304    {
305        use std::cell::Cell;
306
307        let job_id = self.job_id.as_ref().ok_or_else(|| {
308            CoreError::Validation("job not submitted yet".to_string())
309        })?;
310
311        let mut stream = EventStream::new(
312            self.client.http().clone(),
313            self.client.base_url(),
314            job_id,
315        );
316
317        let timeout = Duration::from_secs_f64(timeout_secs);
318        let poll_interval = Duration::from_secs(5);
319
320        // Use Cell for interior mutability to satisfy borrow checker
321        let terminal_reached = Cell::new(false);
322
323        {
324            let tracker = &mut self.tracker;
325
326            stream
327                .stream_until(
328                    |event| {
329                        // Update tracker
330                        tracker.update(event);
331
332                        // Call user callback
333                        if let Some(ref mut cb) = on_event {
334                            cb(event);
335                        }
336
337                        // Check for terminal
338                        if event.category.is_terminal() {
339                            terminal_reached.set(true);
340                        }
341                    },
342                    timeout,
343                    poll_interval,
344                    || terminal_reached.get(),
345                )
346                .await?;
347        }
348
349        // Get final status (tracker borrow is dropped now)
350        let status_result = self.get_status().await?;
351
352        // Merge tracker data with status
353        Ok(PromptLearningResult {
354            job_id: status_result.job_id,
355            status: status_result.status,
356            best_score: status_result.best_score.or(Some(self.tracker.best_score())),
357            best_prompt: status_result.best_prompt,
358            baseline_score: self.tracker.baseline_score(),
359            candidates_evaluated: self.tracker.progress.candidates_evaluated,
360            generations_completed: self.tracker.progress.generations_completed,
361            error: status_result.error,
362            raw: Value::Null,
363        })
364    }
365
366    /// Cancel a running job.
367    ///
368    /// # Arguments
369    ///
370    /// * `reason` - Optional cancellation reason
371    pub async fn cancel(&self, reason: Option<&str>) -> Result<(), CoreError> {
372        let job_id = self.job_id.as_ref().ok_or_else(|| {
373            CoreError::Validation("job not submitted yet".to_string())
374        })?;
375
376        self.client.jobs().cancel(job_id, reason).await
377    }
378
379    /// Get detailed results including prompt extraction.
380    ///
381    /// This fetches events to extract the best prompts.
382    pub async fn get_results(&self) -> Result<PromptResults, CoreError> {
383        let job_id = self.job_id.as_ref().ok_or_else(|| {
384            CoreError::Validation("job not submitted yet".to_string())
385        })?;
386
387        // Get final status for best_prompt
388        let status = self.get_status().await?;
389
390        let best_prompt = status.get_system_prompt();
391        let best_score = status.best_score.or(Some(self.tracker.best_score()));
392
393        // Build ranked prompts from tracker candidates
394        let mut top_prompts: Vec<RankedPrompt> = self
395            .tracker
396            .candidates
397            .iter()
398            .filter(|c| c.accepted || c.is_pareto)
399            .map(|c| RankedPrompt {
400                rank: 0,
401                candidate_id: c.candidate_id.clone(),
402                train_accuracy: c.accuracy,
403                val_accuracy: c.val_accuracy,
404                prompt: None,
405            })
406            .collect();
407
408        // Sort by accuracy descending
409        top_prompts.sort_by(|a, b| {
410            let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
411            let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
412            b_score.partial_cmp(&a_score).unwrap_or(std::cmp::Ordering::Equal)
413        });
414
415        // Assign ranks
416        for (i, prompt) in top_prompts.iter_mut().enumerate() {
417            prompt.rank = (i + 1) as i32;
418        }
419
420        Ok(PromptResults {
421            best_prompt,
422            best_score,
423            top_prompts,
424        })
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Status helpers must agree with the stored PolicyJobStatus.
    #[test]
    fn test_result_status() {
        let result = PromptLearningResult {
            job_id: "test".to_string(),
            status: PolicyJobStatus::Succeeded,
            best_score: Some(0.85),
            best_prompt: None,
            baseline_score: None,
            candidates_evaluated: 10,
            generations_completed: 3,
            error: None,
            raw: Value::Null,
        };

        assert!(result.succeeded());
        assert!(!result.failed());
        assert!(result.is_terminal());
    }

    // The top-level `system_prompt` key is the first extraction path probed.
    #[test]
    fn test_result_get_system_prompt() {
        let result = PromptLearningResult {
            job_id: "test".to_string(),
            status: PolicyJobStatus::Succeeded,
            best_score: Some(0.85),
            best_prompt: Some(json!({
                "system_prompt": "You are a helpful assistant."
            })),
            baseline_score: None,
            candidates_evaluated: 10,
            generations_completed: 3,
            error: None,
            raw: Value::Null,
        };

        assert_eq!(
            result.get_system_prompt(),
            Some("You are a helpful assistant.".to_string())
        );
    }

    // Mirrors the sort used in get_results: descending score, with
    // val_accuracy preferred over train_accuracy when present.
    #[test]
    fn test_ranked_prompt_sorting() {
        let mut prompts = vec![
            RankedPrompt {
                rank: 0,
                candidate_id: "a".to_string(),
                train_accuracy: Some(0.7),
                val_accuracy: None,
                prompt: None,
            },
            RankedPrompt {
                rank: 0,
                candidate_id: "b".to_string(),
                train_accuracy: Some(0.9),
                val_accuracy: None,
                prompt: None,
            },
            RankedPrompt {
                rank: 0,
                candidate_id: "c".to_string(),
                train_accuracy: Some(0.8),
                val_accuracy: Some(0.85),
                prompt: None,
            },
        ];

        // Sort by accuracy descending (val_accuracy takes precedence)
        prompts.sort_by(|a, b| {
            let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
            let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
            b_score.partial_cmp(&a_score).unwrap_or(std::cmp::Ordering::Equal)
        });

        assert_eq!(prompts[0].candidate_id, "b"); // 0.9
        assert_eq!(prompts[1].candidate_id, "c"); // 0.85 (val)
        assert_eq!(prompts[2].candidate_id, "a"); // 0.7
    }
}