// synth_ai_core/orchestration/prompt_learning.rs
1//! Prompt learning job orchestration.
2//!
3//! This module provides the high-level `PromptLearningJob` class for
4//! submitting and tracking GEPA/MIPRO optimization jobs.
5
6use std::collections::HashMap;
7use std::time::Duration;
8
9use reqwest::header::{HeaderMap, HeaderValue};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
13use crate::api::types::PolicyJobStatus;
14use crate::api::SynthClient;
15use crate::auth;
16use crate::errors::CoreError;
17
18use super::events::{ParsedEvent, TerminalStatus};
19use super::progress::ProgressTracker;
20use crate::sse::stream_sse_events;
21use futures_util::StreamExt;
22
23/// Result from a prompt learning job.
/// Result from a prompt learning job.
///
/// Built from backend status payloads (optionally merged with tracker-derived
/// progress). Serde aliases accept the backend's `best_score` /
/// `baseline_score` field names for the reward fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptLearningResult {
    /// Backend job identifier.
    pub job_id: String,
    /// Current job status (not necessarily terminal).
    pub status: PolicyJobStatus,
    /// Best reward achieved so far (accepts alias `best_score`).
    #[serde(default, alias = "best_score")]
    pub best_reward: Option<f64>,
    /// Best candidate configuration (opaque JSON emitted by the optimizer).
    #[serde(default)]
    pub best_candidate: Option<Value>,
    /// Lever summary emitted by optimizer runtimes.
    #[serde(default)]
    pub lever_summary: Option<Value>,
    /// Sensor frame summaries emitted by optimizer runtimes.
    #[serde(default)]
    pub sensor_frames: Vec<Value>,
    /// Lever versions for the selected/best candidate.
    #[serde(default)]
    pub lever_versions: HashMap<String, i64>,
    /// Highest lever version represented in `lever_versions`.
    #[serde(default)]
    pub best_lever_version: Option<i64>,
    /// Baseline (unoptimized) reward (accepts alias `baseline_score`).
    #[serde(default, alias = "baseline_score")]
    pub baseline_reward: Option<f64>,
    /// Number of candidates evaluated.
    #[serde(default)]
    pub candidates_evaluated: i32,
    /// Number of generations completed.
    #[serde(default)]
    pub generations_completed: i32,
    /// Error message if the job failed.
    #[serde(default)]
    pub error: Option<String>,
    /// Raw response data (this module currently always sets `Value::Null`).
    #[serde(default)]
    pub raw: Value,
}
64
65impl PromptLearningResult {
66    /// Check if the job succeeded.
67    pub fn succeeded(&self) -> bool {
68        self.status == PolicyJobStatus::Succeeded
69    }
70
71    /// Check if the job failed.
72    pub fn failed(&self) -> bool {
73        self.status == PolicyJobStatus::Failed
74    }
75
76    /// Check if the job is in a terminal state.
77    pub fn is_terminal(&self) -> bool {
78        self.status.is_terminal()
79    }
80
81    /// Get the system prompt from best_candidate if available.
82    pub fn get_system_prompt(&self) -> Option<String> {
83        self.best_candidate.as_ref().and_then(|p| {
84            // Try various paths where system prompt might be
85            p.get("system_prompt")
86                .and_then(|v| v.as_str())
87                .or_else(|| p.get("instruction").and_then(|v| v.as_str()))
88                .or_else(|| {
89                    p.get("stages")
90                        .and_then(|s| s.get("main"))
91                        .and_then(|m| m.get("instruction"))
92                        .and_then(|v| v.as_str())
93                })
94                .map(|s| s.to_string())
95        })
96    }
97}
98
99/// Ranked prompt from results.
/// A single candidate prompt ranked within the final results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedPrompt {
    /// Rank position (1 = best).
    pub rank: i32,
    /// Candidate identifier from the optimizer.
    pub candidate_id: String,
    /// Training accuracy, when reported.
    #[serde(default)]
    pub train_accuracy: Option<f64>,
    /// Validation accuracy, when reported (takes precedence when ranking).
    #[serde(default)]
    pub val_accuracy: Option<f64>,
    /// Prompt text or configuration payload, when available.
    #[serde(default)]
    pub prompt: Option<Value>,
}
116
117/// Extracted prompt results.
/// Extracted prompt results: the best prompt plus a ranked candidate list.
///
/// The `best_score` alias mirrors the backend's field name for `best_reward`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptResults {
    /// Best candidate prompt text, when one could be extracted.
    #[serde(default)]
    pub best_candidate: Option<String>,
    /// Best reward achieved (accepts alias `best_score`).
    #[serde(default, alias = "best_score")]
    pub best_reward: Option<f64>,
    /// Candidate prompts ranked by score, best first.
    #[serde(default)]
    pub top_prompts: Vec<RankedPrompt>,
    /// Lever summary emitted by optimizer runtimes.
    #[serde(default)]
    pub lever_summary: Option<Value>,
    /// Sensor frame summaries emitted by optimizer runtimes.
    #[serde(default)]
    pub sensor_frames: Vec<Value>,
    /// Lever versions for the selected/best candidate.
    #[serde(default)]
    pub lever_versions: HashMap<String, i64>,
    /// Highest lever version represented in `lever_versions`.
    #[serde(default)]
    pub best_lever_version: Option<i64>,
}
142
143/// High-level prompt learning job orchestration.
/// High-level prompt learning job orchestration.
///
/// Wraps a [`SynthClient`] plus the job configuration and a progress tracker,
/// covering the full lifecycle: submit, poll/stream, pause/resume/cancel,
/// and result extraction.
pub struct PromptLearningJob {
    /// Synth API client used for all backend calls.
    client: SynthClient,
    /// Job ID; `None` until `submit()` succeeds (or set via `from_job_id`).
    job_id: Option<String>,
    /// Job configuration JSON; `Value::Null` when reconnecting by job ID.
    config: Value,
    /// Optional SynthTunnel worker token forwarded at submission time.
    container_worker_token: Option<String>,
    /// Accumulates progress derived from streamed events.
    tracker: ProgressTracker,
}
156
157impl PromptLearningJob {
158    /// Create a job from a configuration dict.
159    ///
160    /// # Arguments
161    ///
162    /// * `config` - Job configuration (algorithm, container_url, policy, etc.)
163    /// * `api_key` - Optional API key (uses env if not provided)
164    /// * `base_url` - Optional base URL (uses default if not provided)
165    ///
166    /// # Example
167    ///
168    /// ```ignore
169    /// let job = PromptLearningJob::from_dict(
170    ///     serde_json::json!({
171    ///         "algorithm": "gepa",
172    ///         "container_url": "http://localhost:8000",
173    ///         "env_name": "default",
174    ///         "policy": { "model": "gpt-4o-mini", "provider": "openai" },
175    ///         "gepa": { "rollout_budget": 100 }
176    ///     }),
177    ///     None,
178    ///     None,
179    ///     None,
180    /// )?;
181    /// ```
182    pub fn from_dict(
183        config: Value,
184        api_key: Option<&str>,
185        base_url: Option<&str>,
186        container_worker_token: Option<String>,
187    ) -> Result<Self, CoreError> {
188        let api_key = match api_key {
189            Some(k) => k.to_string(),
190            None => auth::get_api_key(None)
191                .ok_or_else(|| CoreError::Authentication("SYNTH_API_KEY not found".to_string()))?,
192        };
193
194        let client = SynthClient::new(&api_key, base_url)?;
195
196        Ok(Self {
197            client,
198            job_id: None,
199            config,
200            container_worker_token,
201            tracker: ProgressTracker::new(),
202        })
203    }
204
205    /// Reconnect to an existing job by ID.
206    ///
207    /// # Arguments
208    ///
209    /// * `job_id` - Existing job ID
210    /// * `api_key` - Optional API key
211    /// * `base_url` - Optional base URL
212    pub fn from_job_id(
213        job_id: &str,
214        api_key: Option<&str>,
215        base_url: Option<&str>,
216    ) -> Result<Self, CoreError> {
217        let api_key = match api_key {
218            Some(k) => k.to_string(),
219            None => auth::get_api_key(None)
220                .ok_or_else(|| CoreError::Authentication("SYNTH_API_KEY not found".to_string()))?,
221        };
222
223        let client = SynthClient::new(&api_key, base_url)?;
224
225        Ok(Self {
226            client,
227            job_id: Some(job_id.to_string()),
228            config: Value::Null,
229            container_worker_token: None,
230            tracker: ProgressTracker::new(),
231        })
232    }
233
234    /// Get the job ID (if submitted).
235    pub fn job_id(&self) -> Option<&str> {
236        self.job_id.as_deref()
237    }
238
239    /// Get the progress tracker.
240    pub fn tracker(&self) -> &ProgressTracker {
241        &self.tracker
242    }
243
244    /// Submit the job to the backend.
245    ///
246    /// Returns the job ID on success.
247    pub async fn submit(&mut self) -> Result<String, CoreError> {
248        if self.job_id.is_some() {
249            return Err(CoreError::Validation("job already submitted".to_string()));
250        }
251
252        if self.config.is_null() {
253            return Err(CoreError::Validation(
254                "no configuration provided".to_string(),
255            ));
256        }
257
258        // Submit via jobs API
259        let job_id = self
260            .client
261            .jobs()
262            .submit_raw_with_worker_token(self.config.clone(), self.container_worker_token.clone())
263            .await?;
264        self.job_id = Some(job_id.clone());
265
266        Ok(job_id)
267    }
268
269    /// Get the current job status.
270    pub async fn get_status(&self) -> Result<PromptLearningResult, CoreError> {
271        let job_id = self
272            .job_id
273            .as_ref()
274            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
275
276        let result = self.client.jobs().get_status(job_id).await?;
277
278        Ok(PromptLearningResult {
279            job_id: result.job_id,
280            status: result.status,
281            best_reward: result.best_reward,
282            best_candidate: result.best_candidate,
283            lever_summary: result.lever_summary,
284            sensor_frames: result.sensor_frames,
285            lever_versions: result.lever_versions,
286            best_lever_version: result.best_lever_version,
287            baseline_reward: None,
288            candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
289            generations_completed: result.generations_completed.unwrap_or(0),
290            error: result.error,
291            raw: Value::Null,
292        })
293    }
294
295    /// Poll until the job reaches a terminal state.
296    ///
297    /// # Arguments
298    ///
299    /// * `timeout_secs` - Maximum time to wait
300    /// * `interval_secs` - Polling interval
301    pub async fn poll_until_complete(
302        &self,
303        timeout_secs: f64,
304        interval_secs: f64,
305    ) -> Result<PromptLearningResult, CoreError> {
306        let job_id = self
307            .job_id
308            .as_ref()
309            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
310
311        let result = self
312            .client
313            .jobs()
314            .poll_until_complete(job_id, timeout_secs, interval_secs)
315            .await?;
316
317        Ok(PromptLearningResult {
318            job_id: result.job_id,
319            status: result.status,
320            best_reward: result.best_reward,
321            best_candidate: result.best_candidate,
322            lever_summary: result.lever_summary,
323            sensor_frames: result.sensor_frames,
324            lever_versions: result.lever_versions,
325            best_lever_version: result.best_lever_version,
326            baseline_reward: None,
327            candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
328            generations_completed: result.generations_completed.unwrap_or(0),
329            error: result.error,
330            raw: Value::Null,
331        })
332    }
333
334    /// Stream events until completion with callback.
335    ///
336    /// # Arguments
337    ///
338    /// * `timeout_secs` - Maximum time to wait
339    /// * `on_event` - Optional callback for each event
340    pub async fn stream_until_complete<F>(
341        &mut self,
342        timeout_secs: f64,
343        mut on_event: Option<F>,
344    ) -> Result<PromptLearningResult, CoreError>
345    where
346        F: FnMut(&ParsedEvent),
347    {
348        use std::cell::Cell;
349
350        let job_id = self
351            .job_id
352            .as_ref()
353            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
354
355        eprintln!(
356            "[PL] stream_until_complete: job={} timeout={:.0}s",
357            job_id, timeout_secs
358        );
359
360        let timeout = Duration::from_secs_f64(timeout_secs);
361        let base_url = self.client.base_url().trim_end_matches('/').to_string();
362        let events_url = format!(
363            "{}/api/prompt-learning/online/jobs/{}/events/stream",
364            base_url, job_id
365        );
366        let api_key = self.client.http().api_key().to_string();
367        let mut headers = HeaderMap::new();
368        headers.insert("Accept", HeaderValue::from_static("text/event-stream"));
369        headers.insert(
370            "Authorization",
371            HeaderValue::from_str(&format!("Bearer {}", api_key))
372                .map_err(|_| CoreError::Validation("invalid api key".to_string()))?,
373        );
374        headers.insert(
375            "X-API-Key",
376            HeaderValue::from_str(&api_key)
377                .map_err(|_| CoreError::Validation("invalid api key".to_string()))?,
378        );
379
380        // Use Cell for interior mutability to satisfy borrow checker
381        let terminal_reached = Cell::new(false);
382        let event_count = Cell::new(0u64);
383        let terminal_status = Cell::new(None);
384
385        {
386            let tracker = &mut self.tracker;
387
388            let mut stream =
389                stream_sse_events(&events_url, "GET", headers, None, Some(timeout)).await?;
390
391            while let Some(item) = stream.next().await {
392                let event = item?;
393                if event.data.trim() == "[DONE]" {
394                    break;
395                }
396
397                let payload: Value = serde_json::from_str(&event.data).unwrap_or(Value::Null);
398                let parsed = super::events::EventParser::parse(&payload);
399                let count = event_count.get() + 1;
400                event_count.set(count);
401
402                // Update tracker
403                tracker.update(&parsed);
404
405                // Log progress periodically
406                if count % 5 == 0 || parsed.category.is_terminal() {
407                    eprintln!(
408                        "[PL] Event #{}: type={} category={:?} | tracker: best={:.3} baseline={:?} candidates={} gens={}",
409                        count,
410                        parsed.event_type,
411                        parsed.category,
412                        tracker.best_reward(),
413                        tracker.baseline_reward(),
414                        tracker.progress.candidates_evaluated,
415                        tracker.progress.generations_completed,
416                    );
417                }
418
419                // Call user callback
420                if let Some(ref mut cb) = on_event {
421                    cb(&parsed);
422                }
423
424                // Check for terminal
425                if parsed.category.is_terminal() {
426                    eprintln!(
427                        "[PL] Terminal event received: {} (category={:?})",
428                        parsed.event_type, parsed.category
429                    );
430                    terminal_status.set(super::events::EventParser::terminal_status(
431                        &parsed.event_type,
432                    ));
433                    terminal_reached.set(true);
434                    break;
435                }
436            }
437        }
438
439        eprintln!(
440            "[PL] stream_until_complete: streaming finished, processed {} events",
441            event_count.get()
442        );
443
444        if !terminal_reached.get() {
445            return Err(CoreError::Timeout(
446                "stream ended without terminal event".to_string(),
447            ));
448        }
449
450        // Get final status (tracker borrow is dropped now)
451        eprintln!("[PL] Fetching final job status...");
452        let status_result = match self.get_status().await {
453            Ok(result) => Some(result),
454            Err(err) => {
455                eprintln!("[PL] Warning: failed to fetch final job status: {}", err);
456                None
457            }
458        };
459
460        let mut final_status = status_result
461            .as_ref()
462            .map(|result| result.status)
463            .unwrap_or(crate::api::types::PolicyJobStatus::Succeeded);
464        if !final_status.is_terminal() {
465            if let Some(status) = terminal_status.get() {
466                final_status = match status {
467                    TerminalStatus::Succeeded => crate::api::types::PolicyJobStatus::Succeeded,
468                    TerminalStatus::Failed => crate::api::types::PolicyJobStatus::Failed,
469                    TerminalStatus::Cancelled => crate::api::types::PolicyJobStatus::Cancelled,
470                    TerminalStatus::Paused => crate::api::types::PolicyJobStatus::Paused,
471                };
472                eprintln!(
473                    "[PL] Final status override from terminal event: {:?}",
474                    final_status
475                );
476            }
477        }
478        eprintln!(
479            "[PL] Final status: status={:?} best_reward={:?} error={:?}",
480            final_status,
481            status_result.as_ref().and_then(|result| result.best_reward),
482            status_result
483                .as_ref()
484                .and_then(|result| result.error.clone())
485        );
486
487        // Merge tracker data with status (fall back to tracker if status fetch failed)
488        let result = PromptLearningResult {
489            job_id: status_result
490                .as_ref()
491                .map(|result| result.job_id.clone())
492                .unwrap_or_else(|| job_id.to_string()),
493            status: final_status,
494            best_reward: status_result
495                .as_ref()
496                .and_then(|result| result.best_reward)
497                .or(Some(self.tracker.best_reward())),
498            best_candidate: status_result
499                .as_ref()
500                .and_then(|result| result.best_candidate.clone()),
501            lever_summary: status_result
502                .as_ref()
503                .and_then(|result| result.lever_summary.clone()),
504            sensor_frames: status_result
505                .as_ref()
506                .map(|result| result.sensor_frames.clone())
507                .unwrap_or_default(),
508            lever_versions: status_result
509                .as_ref()
510                .map(|result| result.lever_versions.clone())
511                .unwrap_or_default(),
512            best_lever_version: status_result
513                .as_ref()
514                .and_then(|result| result.best_lever_version),
515            baseline_reward: self.tracker.baseline_reward(),
516            candidates_evaluated: self.tracker.progress.candidates_evaluated,
517            generations_completed: self.tracker.progress.generations_completed,
518            error: status_result.and_then(|result| result.error),
519            raw: Value::Null,
520        };
521
522        eprintln!(
523            "[PL] RESULT: status={:?} best={:?} baseline={:?} candidates={} gens={}",
524            result.status,
525            result.best_reward,
526            result.baseline_reward,
527            result.candidates_evaluated,
528            result.generations_completed
529        );
530
531        Ok(result)
532    }
533
534    /// Cancel a running job.
535    ///
536    /// # Arguments
537    ///
538    /// * `reason` - Optional cancellation reason
539    pub async fn cancel(&self, reason: Option<&str>) -> Result<(), CoreError> {
540        let job_id = self
541            .job_id
542            .as_ref()
543            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
544
545        self.client.jobs().cancel(job_id, reason).await
546    }
547
548    /// Pause a running job.
549    ///
550    /// # Arguments
551    ///
552    /// * `reason` - Optional pause reason
553    pub async fn pause(&self, reason: Option<&str>) -> Result<(), CoreError> {
554        let job_id = self
555            .job_id
556            .as_ref()
557            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
558
559        self.client.jobs().pause(job_id, reason).await
560    }
561
562    /// Resume a paused job.
563    ///
564    /// # Arguments
565    ///
566    /// * `reason` - Optional resume reason
567    pub async fn resume(&self, reason: Option<&str>) -> Result<(), CoreError> {
568        let job_id = self
569            .job_id
570            .as_ref()
571            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
572
573        self.client.jobs().resume(job_id, reason).await
574    }
575
576    /// Get detailed results including prompt extraction.
577    ///
578    /// This fetches events to extract the best prompts.
579    pub async fn get_results(&self) -> Result<PromptResults, CoreError> {
580        // Get final status for best_candidate
581        let status = self.get_status().await?;
582
583        let best_candidate = status.get_system_prompt();
584        let best_reward = status.best_reward.or(Some(self.tracker.best_reward()));
585
586        // Build ranked prompts from tracker candidates
587        let mut top_prompts: Vec<RankedPrompt> = self
588            .tracker
589            .candidates
590            .iter()
591            .filter(|c| c.accepted || c.is_pareto)
592            .map(|c| RankedPrompt {
593                rank: 0,
594                candidate_id: c.candidate_id.clone(),
595                train_accuracy: c.reward,
596                val_accuracy: c.val_reward,
597                prompt: None,
598            })
599            .collect();
600
601        // Sort by accuracy descending
602        top_prompts.sort_by(|a, b| {
603            let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
604            let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
605            b_score
606                .partial_cmp(&a_score)
607                .unwrap_or(std::cmp::Ordering::Equal)
608        });
609
610        // Assign ranks
611        for (i, prompt) in top_prompts.iter_mut().enumerate() {
612            prompt.rank = (i + 1) as i32;
613        }
614
615        Ok(PromptResults {
616            best_candidate,
617            best_reward,
618            top_prompts,
619            lever_summary: status.lever_summary,
620            sensor_frames: status.sensor_frames,
621            lever_versions: status.lever_versions,
622            best_lever_version: status.best_lever_version,
623        })
624    }
625}
626
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a minimal succeeded result with the given best candidate.
    fn sample_result(best_candidate: Option<Value>) -> PromptLearningResult {
        PromptLearningResult {
            job_id: "test".to_string(),
            status: PolicyJobStatus::Succeeded,
            best_reward: Some(0.85),
            best_candidate,
            lever_summary: None,
            sensor_frames: Vec::new(),
            lever_versions: HashMap::new(),
            best_lever_version: None,
            baseline_reward: None,
            candidates_evaluated: 10,
            generations_completed: 3,
            error: None,
            raw: Value::Null,
        }
    }

    #[test]
    fn test_result_status() {
        let result = sample_result(None);

        assert!(result.succeeded());
        assert!(!result.failed());
        assert!(result.is_terminal());
    }

    #[test]
    fn test_result_get_system_prompt() {
        let result = sample_result(Some(json!({
            "system_prompt": "You are a helpful assistant."
        })));

        assert_eq!(
            result.get_system_prompt(),
            Some("You are a helpful assistant.".to_string())
        );
    }

    #[test]
    fn test_ranked_prompt_sorting() {
        // Small fixture builder for unranked prompts.
        let make = |id: &str, train: f64, val: Option<f64>| RankedPrompt {
            rank: 0,
            candidate_id: id.to_string(),
            train_accuracy: Some(train),
            val_accuracy: val,
            prompt: None,
        };
        let mut prompts = vec![
            make("a", 0.7, None),
            make("b", 0.9, None),
            make("c", 0.8, Some(0.85)),
        ];

        // Sort by accuracy descending (val_accuracy takes precedence)
        prompts.sort_by(|a, b| {
            let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
            let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
            b_score
                .partial_cmp(&a_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        assert_eq!(prompts[0].candidate_id, "b"); // 0.9
        assert_eq!(prompts[1].candidate_id, "c"); // 0.85 (val)
        assert_eq!(prompts[2].candidate_id, "a"); // 0.7
    }
}