1use std::collections::HashMap;
7use std::time::Duration;
8
9use reqwest::header::{HeaderMap, HeaderValue};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
13use crate::api::types::PolicyJobStatus;
14use crate::api::SynthClient;
15use crate::auth;
16use crate::errors::CoreError;
17
18use super::events::{ParsedEvent, TerminalStatus};
19use super::progress::ProgressTracker;
20use crate::sse::stream_sse_events;
21use futures_util::StreamExt;
22
/// Outcome snapshot of a prompt-learning job, as assembled by [`PromptLearningJob`]
/// (`get_status`, `poll_until_complete`, `stream_until_complete`).
///
/// Fields annotated with `#[serde(default)]` may be missing from the input when
/// deserializing; `job_id` and `status` are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptLearningResult {
    /// Identifier of the job this result describes.
    pub job_id: String,
    /// Job status reported by the backend, or inferred from a terminal stream event.
    pub status: PolicyJobStatus,
    /// Best reward observed; deserialization also accepts the key `best_score`.
    #[serde(default, alias = "best_score")]
    pub best_reward: Option<f64>,
    /// Best candidate as opaque JSON. [`PromptLearningResult::get_system_prompt`] looks for
    /// `system_prompt`, `instruction`, or `stages.main.instruction` inside it.
    #[serde(default)]
    pub best_candidate: Option<Value>,
    /// Lever summary as opaque JSON, copied from the job status.
    #[serde(default)]
    pub lever_summary: Option<Value>,
    /// Sensor frames (opaque JSON values), copied from the job status.
    #[serde(default)]
    pub sensor_frames: Vec<Value>,
    /// Map of string keys to versions; presumably lever id -> version — TODO confirm.
    #[serde(default)]
    pub lever_versions: HashMap<String, i64>,
    /// Lever version associated with the best candidate, if reported.
    #[serde(default)]
    pub best_lever_version: Option<i64>,
    /// Baseline reward; deserialization also accepts `baseline_score`. Only
    /// `stream_until_complete` fills it (from the progress tracker); the other
    /// constructors in this file leave it `None`.
    #[serde(default, alias = "baseline_score")]
    pub baseline_reward: Option<f64>,
    /// Number of candidates evaluated (0 when absent).
    #[serde(default)]
    pub candidates_evaluated: i32,
    /// Number of generations completed (0 when absent).
    #[serde(default)]
    pub generations_completed: i32,
    /// Error message reported for the job, if any.
    #[serde(default)]
    pub error: Option<String>,
    /// Raw payload slot; every constructor in this file sets it to `Value::Null`.
    #[serde(default)]
    pub raw: Value,
}
64
65impl PromptLearningResult {
66 pub fn succeeded(&self) -> bool {
68 self.status == PolicyJobStatus::Succeeded
69 }
70
71 pub fn failed(&self) -> bool {
73 self.status == PolicyJobStatus::Failed
74 }
75
76 pub fn is_terminal(&self) -> bool {
78 self.status.is_terminal()
79 }
80
81 pub fn get_system_prompt(&self) -> Option<String> {
83 self.best_candidate.as_ref().and_then(|p| {
84 p.get("system_prompt")
86 .and_then(|v| v.as_str())
87 .or_else(|| p.get("instruction").and_then(|v| v.as_str()))
88 .or_else(|| {
89 p.get("stages")
90 .and_then(|s| s.get("main"))
91 .and_then(|m| m.get("instruction"))
92 .and_then(|v| v.as_str())
93 })
94 .map(|s| s.to_string())
95 })
96 }
97}
98
/// One entry of [`PromptResults::top_prompts`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedPrompt {
    /// 1-based position after sorting by score (1 = best) in `get_results`.
    pub rank: i32,
    /// Identifier of the candidate, copied from the progress tracker.
    pub candidate_id: String,
    /// Training score; `get_results` fills it from the tracker candidate's `reward`.
    #[serde(default)]
    pub train_accuracy: Option<f64>,
    /// Validation score; `get_results` fills it from the tracker candidate's `val_reward`.
    #[serde(default)]
    pub val_accuracy: Option<f64>,
    /// Prompt payload as opaque JSON; `get_results` currently leaves it `None`.
    #[serde(default)]
    pub prompt: Option<Value>,
}
116
/// Aggregated results returned by [`PromptLearningJob::get_results`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptResults {
    /// Prompt text extracted from the best candidate via
    /// [`PromptLearningResult::get_system_prompt`].
    #[serde(default)]
    pub best_candidate: Option<String>,
    /// Best reward; deserialization also accepts the key `best_score`.
    #[serde(default, alias = "best_score")]
    pub best_reward: Option<f64>,
    /// Accepted or Pareto candidates, best first (validation score, else training score).
    #[serde(default)]
    pub top_prompts: Vec<RankedPrompt>,
    /// Lever summary as opaque JSON, copied from the job status.
    #[serde(default)]
    pub lever_summary: Option<Value>,
    /// Sensor frames (opaque JSON values), copied from the job status.
    #[serde(default)]
    pub sensor_frames: Vec<Value>,
    /// Map of string keys to versions; presumably lever id -> version — TODO confirm.
    #[serde(default)]
    pub lever_versions: HashMap<String, i64>,
    /// Lever version associated with the best candidate, if reported.
    #[serde(default)]
    pub best_lever_version: Option<i64>,
}
142
/// Handle for submitting, monitoring and controlling one prompt-learning job.
pub struct PromptLearningJob {
    /// API client used for every backend call.
    client: SynthClient,
    /// Backend job id; `None` until `submit` succeeds (or preset by `from_job_id`).
    job_id: Option<String>,
    /// Job configuration sent on submit; `Value::Null` when built with `from_job_id`.
    config: Value,
    /// Optional token forwarded to `submit_raw_with_worker_token` on submit.
    container_worker_token: Option<String>,
    /// Progress state updated from events seen by `stream_until_complete`.
    tracker: ProgressTracker,
}
156
157impl PromptLearningJob {
158 pub fn from_dict(
183 config: Value,
184 api_key: Option<&str>,
185 base_url: Option<&str>,
186 container_worker_token: Option<String>,
187 ) -> Result<Self, CoreError> {
188 let api_key = match api_key {
189 Some(k) => k.to_string(),
190 None => auth::get_api_key(None)
191 .ok_or_else(|| CoreError::Authentication("SYNTH_API_KEY not found".to_string()))?,
192 };
193
194 let client = SynthClient::new(&api_key, base_url)?;
195
196 Ok(Self {
197 client,
198 job_id: None,
199 config,
200 container_worker_token,
201 tracker: ProgressTracker::new(),
202 })
203 }
204
205 pub fn from_job_id(
213 job_id: &str,
214 api_key: Option<&str>,
215 base_url: Option<&str>,
216 ) -> Result<Self, CoreError> {
217 let api_key = match api_key {
218 Some(k) => k.to_string(),
219 None => auth::get_api_key(None)
220 .ok_or_else(|| CoreError::Authentication("SYNTH_API_KEY not found".to_string()))?,
221 };
222
223 let client = SynthClient::new(&api_key, base_url)?;
224
225 Ok(Self {
226 client,
227 job_id: Some(job_id.to_string()),
228 config: Value::Null,
229 container_worker_token: None,
230 tracker: ProgressTracker::new(),
231 })
232 }
233
234 pub fn job_id(&self) -> Option<&str> {
236 self.job_id.as_deref()
237 }
238
239 pub fn tracker(&self) -> &ProgressTracker {
241 &self.tracker
242 }
243
244 pub async fn submit(&mut self) -> Result<String, CoreError> {
248 if self.job_id.is_some() {
249 return Err(CoreError::Validation("job already submitted".to_string()));
250 }
251
252 if self.config.is_null() {
253 return Err(CoreError::Validation(
254 "no configuration provided".to_string(),
255 ));
256 }
257
258 let job_id = self
260 .client
261 .jobs()
262 .submit_raw_with_worker_token(self.config.clone(), self.container_worker_token.clone())
263 .await?;
264 self.job_id = Some(job_id.clone());
265
266 Ok(job_id)
267 }
268
269 pub async fn get_status(&self) -> Result<PromptLearningResult, CoreError> {
271 let job_id = self
272 .job_id
273 .as_ref()
274 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
275
276 let result = self.client.jobs().get_status(job_id).await?;
277
278 Ok(PromptLearningResult {
279 job_id: result.job_id,
280 status: result.status,
281 best_reward: result.best_reward,
282 best_candidate: result.best_candidate,
283 lever_summary: result.lever_summary,
284 sensor_frames: result.sensor_frames,
285 lever_versions: result.lever_versions,
286 best_lever_version: result.best_lever_version,
287 baseline_reward: None,
288 candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
289 generations_completed: result.generations_completed.unwrap_or(0),
290 error: result.error,
291 raw: Value::Null,
292 })
293 }
294
295 pub async fn poll_until_complete(
302 &self,
303 timeout_secs: f64,
304 interval_secs: f64,
305 ) -> Result<PromptLearningResult, CoreError> {
306 let job_id = self
307 .job_id
308 .as_ref()
309 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
310
311 let result = self
312 .client
313 .jobs()
314 .poll_until_complete(job_id, timeout_secs, interval_secs)
315 .await?;
316
317 Ok(PromptLearningResult {
318 job_id: result.job_id,
319 status: result.status,
320 best_reward: result.best_reward,
321 best_candidate: result.best_candidate,
322 lever_summary: result.lever_summary,
323 sensor_frames: result.sensor_frames,
324 lever_versions: result.lever_versions,
325 best_lever_version: result.best_lever_version,
326 baseline_reward: None,
327 candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
328 generations_completed: result.generations_completed.unwrap_or(0),
329 error: result.error,
330 raw: Value::Null,
331 })
332 }
333
334 pub async fn stream_until_complete<F>(
341 &mut self,
342 timeout_secs: f64,
343 mut on_event: Option<F>,
344 ) -> Result<PromptLearningResult, CoreError>
345 where
346 F: FnMut(&ParsedEvent),
347 {
348 use std::cell::Cell;
349
350 let job_id = self
351 .job_id
352 .as_ref()
353 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
354
355 eprintln!(
356 "[PL] stream_until_complete: job={} timeout={:.0}s",
357 job_id, timeout_secs
358 );
359
360 let timeout = Duration::from_secs_f64(timeout_secs);
361 let base_url = self.client.base_url().trim_end_matches('/').to_string();
362 let events_url = format!(
363 "{}/api/prompt-learning/online/jobs/{}/events/stream",
364 base_url, job_id
365 );
366 let api_key = self.client.http().api_key().to_string();
367 let mut headers = HeaderMap::new();
368 headers.insert("Accept", HeaderValue::from_static("text/event-stream"));
369 headers.insert(
370 "Authorization",
371 HeaderValue::from_str(&format!("Bearer {}", api_key))
372 .map_err(|_| CoreError::Validation("invalid api key".to_string()))?,
373 );
374 headers.insert(
375 "X-API-Key",
376 HeaderValue::from_str(&api_key)
377 .map_err(|_| CoreError::Validation("invalid api key".to_string()))?,
378 );
379
380 let terminal_reached = Cell::new(false);
382 let event_count = Cell::new(0u64);
383 let terminal_status = Cell::new(None);
384
385 {
386 let tracker = &mut self.tracker;
387
388 let mut stream =
389 stream_sse_events(&events_url, "GET", headers, None, Some(timeout)).await?;
390
391 while let Some(item) = stream.next().await {
392 let event = item?;
393 if event.data.trim() == "[DONE]" {
394 break;
395 }
396
397 let payload: Value = serde_json::from_str(&event.data).unwrap_or(Value::Null);
398 let parsed = super::events::EventParser::parse(&payload);
399 let count = event_count.get() + 1;
400 event_count.set(count);
401
402 tracker.update(&parsed);
404
405 if count % 5 == 0 || parsed.category.is_terminal() {
407 eprintln!(
408 "[PL] Event #{}: type={} category={:?} | tracker: best={:.3} baseline={:?} candidates={} gens={}",
409 count,
410 parsed.event_type,
411 parsed.category,
412 tracker.best_reward(),
413 tracker.baseline_reward(),
414 tracker.progress.candidates_evaluated,
415 tracker.progress.generations_completed,
416 );
417 }
418
419 if let Some(ref mut cb) = on_event {
421 cb(&parsed);
422 }
423
424 if parsed.category.is_terminal() {
426 eprintln!(
427 "[PL] Terminal event received: {} (category={:?})",
428 parsed.event_type, parsed.category
429 );
430 terminal_status.set(super::events::EventParser::terminal_status(
431 &parsed.event_type,
432 ));
433 terminal_reached.set(true);
434 break;
435 }
436 }
437 }
438
439 eprintln!(
440 "[PL] stream_until_complete: streaming finished, processed {} events",
441 event_count.get()
442 );
443
444 if !terminal_reached.get() {
445 return Err(CoreError::Timeout(
446 "stream ended without terminal event".to_string(),
447 ));
448 }
449
450 eprintln!("[PL] Fetching final job status...");
452 let status_result = match self.get_status().await {
453 Ok(result) => Some(result),
454 Err(err) => {
455 eprintln!("[PL] Warning: failed to fetch final job status: {}", err);
456 None
457 }
458 };
459
460 let mut final_status = status_result
461 .as_ref()
462 .map(|result| result.status)
463 .unwrap_or(crate::api::types::PolicyJobStatus::Succeeded);
464 if !final_status.is_terminal() {
465 if let Some(status) = terminal_status.get() {
466 final_status = match status {
467 TerminalStatus::Succeeded => crate::api::types::PolicyJobStatus::Succeeded,
468 TerminalStatus::Failed => crate::api::types::PolicyJobStatus::Failed,
469 TerminalStatus::Cancelled => crate::api::types::PolicyJobStatus::Cancelled,
470 TerminalStatus::Paused => crate::api::types::PolicyJobStatus::Paused,
471 };
472 eprintln!(
473 "[PL] Final status override from terminal event: {:?}",
474 final_status
475 );
476 }
477 }
478 eprintln!(
479 "[PL] Final status: status={:?} best_reward={:?} error={:?}",
480 final_status,
481 status_result.as_ref().and_then(|result| result.best_reward),
482 status_result
483 .as_ref()
484 .and_then(|result| result.error.clone())
485 );
486
487 let result = PromptLearningResult {
489 job_id: status_result
490 .as_ref()
491 .map(|result| result.job_id.clone())
492 .unwrap_or_else(|| job_id.to_string()),
493 status: final_status,
494 best_reward: status_result
495 .as_ref()
496 .and_then(|result| result.best_reward)
497 .or(Some(self.tracker.best_reward())),
498 best_candidate: status_result
499 .as_ref()
500 .and_then(|result| result.best_candidate.clone()),
501 lever_summary: status_result
502 .as_ref()
503 .and_then(|result| result.lever_summary.clone()),
504 sensor_frames: status_result
505 .as_ref()
506 .map(|result| result.sensor_frames.clone())
507 .unwrap_or_default(),
508 lever_versions: status_result
509 .as_ref()
510 .map(|result| result.lever_versions.clone())
511 .unwrap_or_default(),
512 best_lever_version: status_result
513 .as_ref()
514 .and_then(|result| result.best_lever_version),
515 baseline_reward: self.tracker.baseline_reward(),
516 candidates_evaluated: self.tracker.progress.candidates_evaluated,
517 generations_completed: self.tracker.progress.generations_completed,
518 error: status_result.and_then(|result| result.error),
519 raw: Value::Null,
520 };
521
522 eprintln!(
523 "[PL] RESULT: status={:?} best={:?} baseline={:?} candidates={} gens={}",
524 result.status,
525 result.best_reward,
526 result.baseline_reward,
527 result.candidates_evaluated,
528 result.generations_completed
529 );
530
531 Ok(result)
532 }
533
534 pub async fn cancel(&self, reason: Option<&str>) -> Result<(), CoreError> {
540 let job_id = self
541 .job_id
542 .as_ref()
543 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
544
545 self.client.jobs().cancel(job_id, reason).await
546 }
547
548 pub async fn pause(&self, reason: Option<&str>) -> Result<(), CoreError> {
554 let job_id = self
555 .job_id
556 .as_ref()
557 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
558
559 self.client.jobs().pause(job_id, reason).await
560 }
561
562 pub async fn resume(&self, reason: Option<&str>) -> Result<(), CoreError> {
568 let job_id = self
569 .job_id
570 .as_ref()
571 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
572
573 self.client.jobs().resume(job_id, reason).await
574 }
575
576 pub async fn get_results(&self) -> Result<PromptResults, CoreError> {
580 let status = self.get_status().await?;
582
583 let best_candidate = status.get_system_prompt();
584 let best_reward = status.best_reward.or(Some(self.tracker.best_reward()));
585
586 let mut top_prompts: Vec<RankedPrompt> = self
588 .tracker
589 .candidates
590 .iter()
591 .filter(|c| c.accepted || c.is_pareto)
592 .map(|c| RankedPrompt {
593 rank: 0,
594 candidate_id: c.candidate_id.clone(),
595 train_accuracy: c.reward,
596 val_accuracy: c.val_reward,
597 prompt: None,
598 })
599 .collect();
600
601 top_prompts.sort_by(|a, b| {
603 let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
604 let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
605 b_score
606 .partial_cmp(&a_score)
607 .unwrap_or(std::cmp::Ordering::Equal)
608 });
609
610 for (i, prompt) in top_prompts.iter_mut().enumerate() {
612 prompt.rank = (i + 1) as i32;
613 }
614
615 Ok(PromptResults {
616 best_candidate,
617 best_reward,
618 top_prompts,
619 lever_summary: status.lever_summary,
620 sensor_frames: status.sensor_frames,
621 lever_versions: status.lever_versions,
622 best_lever_version: status.best_lever_version,
623 })
624 }
625}
626
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Succeeded result with fixed metrics; only the best candidate varies.
    fn succeeded_result(best_candidate: Option<Value>) -> PromptLearningResult {
        PromptLearningResult {
            job_id: "test".to_string(),
            status: PolicyJobStatus::Succeeded,
            best_reward: Some(0.85),
            best_candidate,
            lever_summary: None,
            sensor_frames: Vec::new(),
            lever_versions: HashMap::new(),
            best_lever_version: None,
            baseline_reward: None,
            candidates_evaluated: 10,
            generations_completed: 3,
            error: None,
            raw: Value::Null,
        }
    }

    /// Unranked prompt with the given scores and no payload.
    fn unranked(id: &str, train: Option<f64>, val: Option<f64>) -> RankedPrompt {
        RankedPrompt {
            rank: 0,
            candidate_id: id.to_string(),
            train_accuracy: train,
            val_accuracy: val,
            prompt: None,
        }
    }

    #[test]
    fn test_result_status() {
        let result = succeeded_result(None);

        assert!(result.succeeded());
        assert!(!result.failed());
        assert!(result.is_terminal());
    }

    #[test]
    fn test_result_get_system_prompt() {
        let candidate = json!({ "system_prompt": "You are a helpful assistant." });
        let result = succeeded_result(Some(candidate));

        assert_eq!(
            result.get_system_prompt().as_deref(),
            Some("You are a helpful assistant.")
        );
    }

    #[test]
    fn test_ranked_prompt_sorting() {
        let mut prompts = vec![
            unranked("a", Some(0.7), None),
            unranked("b", Some(0.9), None),
            unranked("c", Some(0.8), Some(0.85)),
        ];

        let score = |p: &RankedPrompt| p.val_accuracy.or(p.train_accuracy).unwrap_or(0.0);
        prompts.sort_by(|lhs, rhs| {
            score(rhs)
                .partial_cmp(&score(lhs))
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let order: Vec<&str> = prompts.iter().map(|p| p.candidate_id.as_str()).collect();
        assert_eq!(order, ["b", "c", "a"]);
    }
}