1use std::time::Duration;
7
8use reqwest::header::{HeaderMap, HeaderValue};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use crate::api::types::PolicyJobStatus;
13use crate::api::SynthClient;
14use crate::auth;
15use crate::errors::CoreError;
16
17use super::events::{ParsedEvent, TerminalStatus};
18use super::progress::ProgressTracker;
19use crate::sse::stream_sse_events;
20use futures_util::StreamExt;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct PromptLearningResult {
25 pub job_id: String,
27 pub status: PolicyJobStatus,
29 #[serde(default, alias = "best_score")]
31 pub best_reward: Option<f64>,
32 #[serde(default)]
34 pub best_prompt: Option<Value>,
35 #[serde(default, alias = "baseline_score")]
37 pub baseline_reward: Option<f64>,
38 #[serde(default)]
40 pub candidates_evaluated: i32,
41 #[serde(default)]
43 pub generations_completed: i32,
44 #[serde(default)]
46 pub error: Option<String>,
47 #[serde(default)]
49 pub raw: Value,
50}
51
52impl PromptLearningResult {
53 pub fn succeeded(&self) -> bool {
55 self.status == PolicyJobStatus::Succeeded
56 }
57
58 pub fn failed(&self) -> bool {
60 self.status == PolicyJobStatus::Failed
61 }
62
63 pub fn is_terminal(&self) -> bool {
65 self.status.is_terminal()
66 }
67
68 pub fn get_system_prompt(&self) -> Option<String> {
70 self.best_prompt.as_ref().and_then(|p| {
71 p.get("system_prompt")
73 .and_then(|v| v.as_str())
74 .or_else(|| p.get("instruction").and_then(|v| v.as_str()))
75 .or_else(|| {
76 p.get("stages")
77 .and_then(|s| s.get("main"))
78 .and_then(|m| m.get("instruction"))
79 .and_then(|v| v.as_str())
80 })
81 .map(|s| s.to_string())
82 })
83 }
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct RankedPrompt {
89 pub rank: i32,
91 pub candidate_id: String,
93 #[serde(default)]
95 pub train_accuracy: Option<f64>,
96 #[serde(default)]
98 pub val_accuracy: Option<f64>,
99 #[serde(default)]
101 pub prompt: Option<Value>,
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct PromptResults {
107 #[serde(default)]
109 pub best_prompt: Option<String>,
110 #[serde(default, alias = "best_score")]
112 pub best_reward: Option<f64>,
113 #[serde(default)]
115 pub top_prompts: Vec<RankedPrompt>,
116}
117
/// Handle for one prompt-learning job: submits a configuration, then polls, streams or
/// controls the job through the backend API.
pub struct PromptLearningJob {
    // Authenticated API client used for every backend request.
    client: SynthClient,
    // Backend job id; `None` until `submit` succeeds (set up-front by `from_job_id`).
    job_id: Option<String>,
    // Raw job configuration sent by `submit`; `Value::Null` when built via `from_job_id`.
    config: Value,
    // Optional worker token passed along with the configuration on submit.
    task_app_worker_token: Option<String>,
    // Progress accumulated from events processed by `stream_until_complete`.
    tracker: ProgressTracker,
}
131
132impl PromptLearningJob {
133 pub fn from_dict(
158 config: Value,
159 api_key: Option<&str>,
160 base_url: Option<&str>,
161 task_app_worker_token: Option<String>,
162 ) -> Result<Self, CoreError> {
163 let api_key = match api_key {
164 Some(k) => k.to_string(),
165 None => auth::get_api_key(None)
166 .ok_or_else(|| CoreError::Authentication("SYNTH_API_KEY not found".to_string()))?,
167 };
168
169 let client = SynthClient::new(&api_key, base_url)?;
170
171 Ok(Self {
172 client,
173 job_id: None,
174 config,
175 task_app_worker_token,
176 tracker: ProgressTracker::new(),
177 })
178 }
179
180 pub fn from_job_id(
188 job_id: &str,
189 api_key: Option<&str>,
190 base_url: Option<&str>,
191 ) -> Result<Self, CoreError> {
192 let api_key = match api_key {
193 Some(k) => k.to_string(),
194 None => auth::get_api_key(None)
195 .ok_or_else(|| CoreError::Authentication("SYNTH_API_KEY not found".to_string()))?,
196 };
197
198 let client = SynthClient::new(&api_key, base_url)?;
199
200 Ok(Self {
201 client,
202 job_id: Some(job_id.to_string()),
203 config: Value::Null,
204 task_app_worker_token: None,
205 tracker: ProgressTracker::new(),
206 })
207 }
208
209 pub fn job_id(&self) -> Option<&str> {
211 self.job_id.as_deref()
212 }
213
214 pub fn tracker(&self) -> &ProgressTracker {
216 &self.tracker
217 }
218
219 pub async fn submit(&mut self) -> Result<String, CoreError> {
223 if self.job_id.is_some() {
224 return Err(CoreError::Validation("job already submitted".to_string()));
225 }
226
227 if self.config.is_null() {
228 return Err(CoreError::Validation(
229 "no configuration provided".to_string(),
230 ));
231 }
232
233 let job_id = self
235 .client
236 .jobs()
237 .submit_raw_with_worker_token(self.config.clone(), self.task_app_worker_token.clone())
238 .await?;
239 self.job_id = Some(job_id.clone());
240
241 Ok(job_id)
242 }
243
244 pub async fn get_status(&self) -> Result<PromptLearningResult, CoreError> {
246 let job_id = self
247 .job_id
248 .as_ref()
249 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
250
251 let result = self.client.jobs().get_status(job_id).await?;
252
253 Ok(PromptLearningResult {
254 job_id: result.job_id,
255 status: result.status,
256 best_reward: result.best_reward,
257 best_prompt: result.best_prompt,
258 baseline_reward: None,
259 candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
260 generations_completed: result.generations_completed.unwrap_or(0),
261 error: result.error,
262 raw: Value::Null,
263 })
264 }
265
266 pub async fn poll_until_complete(
273 &self,
274 timeout_secs: f64,
275 interval_secs: f64,
276 ) -> Result<PromptLearningResult, CoreError> {
277 let job_id = self
278 .job_id
279 .as_ref()
280 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
281
282 let result = self
283 .client
284 .jobs()
285 .poll_until_complete(job_id, timeout_secs, interval_secs)
286 .await?;
287
288 Ok(PromptLearningResult {
289 job_id: result.job_id,
290 status: result.status,
291 best_reward: result.best_reward,
292 best_prompt: result.best_prompt,
293 baseline_reward: None,
294 candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
295 generations_completed: result.generations_completed.unwrap_or(0),
296 error: result.error,
297 raw: Value::Null,
298 })
299 }
300
301 pub async fn stream_until_complete<F>(
308 &mut self,
309 timeout_secs: f64,
310 mut on_event: Option<F>,
311 ) -> Result<PromptLearningResult, CoreError>
312 where
313 F: FnMut(&ParsedEvent),
314 {
315 use std::cell::Cell;
316
317 let job_id = self
318 .job_id
319 .as_ref()
320 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
321
322 eprintln!(
323 "[PL] stream_until_complete: job={} timeout={:.0}s",
324 job_id, timeout_secs
325 );
326
327 let timeout = Duration::from_secs_f64(timeout_secs);
328 let base_url = self.client.base_url().trim_end_matches('/').to_string();
329 let events_url = format!(
330 "{}/api/prompt-learning/online/jobs/{}/events/stream",
331 base_url, job_id
332 );
333 let api_key = self.client.http().api_key().to_string();
334 let mut headers = HeaderMap::new();
335 headers.insert("Accept", HeaderValue::from_static("text/event-stream"));
336 headers.insert(
337 "Authorization",
338 HeaderValue::from_str(&format!("Bearer {}", api_key))
339 .map_err(|_| CoreError::Validation("invalid api key".to_string()))?,
340 );
341 headers.insert(
342 "X-API-Key",
343 HeaderValue::from_str(&api_key)
344 .map_err(|_| CoreError::Validation("invalid api key".to_string()))?,
345 );
346
347 let terminal_reached = Cell::new(false);
349 let event_count = Cell::new(0u64);
350 let terminal_status = Cell::new(None);
351
352 {
353 let tracker = &mut self.tracker;
354
355 let mut stream =
356 stream_sse_events(&events_url, "GET", headers, None, Some(timeout)).await?;
357
358 while let Some(item) = stream.next().await {
359 let event = item?;
360 if event.data.trim() == "[DONE]" {
361 break;
362 }
363
364 let payload: Value = serde_json::from_str(&event.data).unwrap_or(Value::Null);
365 let parsed = super::events::EventParser::parse(&payload);
366 let count = event_count.get() + 1;
367 event_count.set(count);
368
369 tracker.update(&parsed);
371
372 if count % 5 == 0 || parsed.category.is_terminal() {
374 eprintln!(
375 "[PL] Event #{}: type={} category={:?} | tracker: best={:.3} baseline={:?} candidates={} gens={}",
376 count,
377 parsed.event_type,
378 parsed.category,
379 tracker.best_reward(),
380 tracker.baseline_reward(),
381 tracker.progress.candidates_evaluated,
382 tracker.progress.generations_completed,
383 );
384 }
385
386 if let Some(ref mut cb) = on_event {
388 cb(&parsed);
389 }
390
391 if parsed.category.is_terminal() {
393 eprintln!(
394 "[PL] Terminal event received: {} (category={:?})",
395 parsed.event_type, parsed.category
396 );
397 terminal_status.set(super::events::EventParser::terminal_status(
398 &parsed.event_type,
399 ));
400 terminal_reached.set(true);
401 break;
402 }
403 }
404 }
405
406 eprintln!(
407 "[PL] stream_until_complete: streaming finished, processed {} events",
408 event_count.get()
409 );
410
411 if !terminal_reached.get() {
412 return Err(CoreError::Timeout(
413 "stream ended without terminal event".to_string(),
414 ));
415 }
416
417 eprintln!("[PL] Fetching final job status...");
419 let status_result = match self.get_status().await {
420 Ok(result) => Some(result),
421 Err(err) => {
422 eprintln!("[PL] Warning: failed to fetch final job status: {}", err);
423 None
424 }
425 };
426
427 let mut final_status = status_result
428 .as_ref()
429 .map(|result| result.status)
430 .unwrap_or(crate::api::types::PolicyJobStatus::Succeeded);
431 if !final_status.is_terminal() {
432 if let Some(status) = terminal_status.get() {
433 final_status = match status {
434 TerminalStatus::Succeeded => crate::api::types::PolicyJobStatus::Succeeded,
435 TerminalStatus::Failed => crate::api::types::PolicyJobStatus::Failed,
436 TerminalStatus::Cancelled => crate::api::types::PolicyJobStatus::Cancelled,
437 TerminalStatus::Paused => crate::api::types::PolicyJobStatus::Paused,
438 };
439 eprintln!(
440 "[PL] Final status override from terminal event: {:?}",
441 final_status
442 );
443 }
444 }
445 eprintln!(
446 "[PL] Final status: status={:?} best_reward={:?} error={:?}",
447 final_status,
448 status_result.as_ref().and_then(|result| result.best_reward),
449 status_result
450 .as_ref()
451 .and_then(|result| result.error.clone())
452 );
453
454 let result = PromptLearningResult {
456 job_id: status_result
457 .as_ref()
458 .map(|result| result.job_id.clone())
459 .unwrap_or_else(|| job_id.to_string()),
460 status: final_status,
461 best_reward: status_result
462 .as_ref()
463 .and_then(|result| result.best_reward)
464 .or(Some(self.tracker.best_reward())),
465 best_prompt: status_result
466 .as_ref()
467 .and_then(|result| result.best_prompt.clone()),
468 baseline_reward: self.tracker.baseline_reward(),
469 candidates_evaluated: self.tracker.progress.candidates_evaluated,
470 generations_completed: self.tracker.progress.generations_completed,
471 error: status_result.and_then(|result| result.error),
472 raw: Value::Null,
473 };
474
475 eprintln!(
476 "[PL] RESULT: status={:?} best={:?} baseline={:?} candidates={} gens={}",
477 result.status,
478 result.best_reward,
479 result.baseline_reward,
480 result.candidates_evaluated,
481 result.generations_completed
482 );
483
484 Ok(result)
485 }
486
487 pub async fn cancel(&self, reason: Option<&str>) -> Result<(), CoreError> {
493 let job_id = self
494 .job_id
495 .as_ref()
496 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
497
498 self.client.jobs().cancel(job_id, reason).await
499 }
500
501 pub async fn pause(&self, reason: Option<&str>) -> Result<(), CoreError> {
507 let job_id = self
508 .job_id
509 .as_ref()
510 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
511
512 self.client.jobs().pause(job_id, reason).await
513 }
514
515 pub async fn resume(&self, reason: Option<&str>) -> Result<(), CoreError> {
521 let job_id = self
522 .job_id
523 .as_ref()
524 .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))?;
525
526 self.client.jobs().resume(job_id, reason).await
527 }
528
529 pub async fn get_results(&self) -> Result<PromptResults, CoreError> {
533 let status = self.get_status().await?;
535
536 let best_prompt = status.get_system_prompt();
537 let best_reward = status.best_reward.or(Some(self.tracker.best_reward()));
538
539 let mut top_prompts: Vec<RankedPrompt> = self
541 .tracker
542 .candidates
543 .iter()
544 .filter(|c| c.accepted || c.is_pareto)
545 .map(|c| RankedPrompt {
546 rank: 0,
547 candidate_id: c.candidate_id.clone(),
548 train_accuracy: c.reward,
549 val_accuracy: c.val_reward,
550 prompt: None,
551 })
552 .collect();
553
554 top_prompts.sort_by(|a, b| {
556 let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
557 let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
558 b_score
559 .partial_cmp(&a_score)
560 .unwrap_or(std::cmp::Ordering::Equal)
561 });
562
563 for (i, prompt) in top_prompts.iter_mut().enumerate() {
565 prompt.rank = (i + 1) as i32;
566 }
567
568 Ok(PromptResults {
569 best_prompt,
570 best_reward,
571 top_prompts,
572 })
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a succeeded result fixture carrying the given `best_prompt`.
    fn result_with_prompt(best_prompt: Option<Value>) -> PromptLearningResult {
        PromptLearningResult {
            job_id: "test".to_string(),
            status: PolicyJobStatus::Succeeded,
            best_reward: Some(0.85),
            best_prompt,
            baseline_reward: None,
            candidates_evaluated: 10,
            generations_completed: 3,
            error: None,
            raw: Value::Null,
        }
    }

    /// Builds an unranked prompt fixture.
    fn ranked(candidate_id: &str, train: Option<f64>, val: Option<f64>) -> RankedPrompt {
        RankedPrompt {
            rank: 0,
            candidate_id: candidate_id.to_string(),
            train_accuracy: train,
            val_accuracy: val,
            prompt: None,
        }
    }

    #[test]
    fn test_result_status() {
        let result = result_with_prompt(None);

        assert!(result.succeeded());
        assert!(!result.failed());
        assert!(result.is_terminal());
    }

    #[test]
    fn test_failed_result_status() {
        let result = PromptLearningResult {
            status: PolicyJobStatus::Failed,
            error: Some("boom".to_string()),
            ..result_with_prompt(None)
        };

        assert!(result.failed());
        assert!(!result.succeeded());
    }

    #[test]
    fn test_result_get_system_prompt() {
        let result = result_with_prompt(Some(json!({
            "system_prompt": "You are a helpful assistant."
        })));

        assert_eq!(
            result.get_system_prompt(),
            Some("You are a helpful assistant.".to_string())
        );
    }

    #[test]
    fn test_get_system_prompt_prefers_system_prompt() {
        let result = result_with_prompt(Some(json!({
            "system_prompt": "primary",
            "instruction": "secondary"
        })));

        assert_eq!(result.get_system_prompt().as_deref(), Some("primary"));
    }

    #[test]
    fn test_get_system_prompt_falls_back_to_instruction() {
        // A non-string `system_prompt` is skipped rather than stringified.
        let result = result_with_prompt(Some(json!({
            "system_prompt": 42,
            "instruction": "Be concise."
        })));

        assert_eq!(result.get_system_prompt().as_deref(), Some("Be concise."));
    }

    #[test]
    fn test_get_system_prompt_reads_main_stage_instruction() {
        let result = result_with_prompt(Some(json!({
            "stages": { "main": { "instruction": "Stage prompt" } }
        })));

        assert_eq!(result.get_system_prompt().as_deref(), Some("Stage prompt"));
    }

    #[test]
    fn test_get_system_prompt_missing() {
        assert_eq!(result_with_prompt(None).get_system_prompt(), None);
        assert_eq!(
            result_with_prompt(Some(json!({ "other": "x" }))).get_system_prompt(),
            None
        );
    }

    #[test]
    fn test_ranked_prompt_sorting() {
        let mut prompts = vec![
            ranked("a", Some(0.7), None),
            ranked("b", Some(0.9), None),
            ranked("c", Some(0.8), Some(0.85)),
        ];

        // Same comparator as `PromptLearningJob::get_results`: the validation score wins
        // over the training score, highest first.
        prompts.sort_by(|a, b| {
            let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
            let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
            b_score
                .partial_cmp(&a_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let order: Vec<&str> = prompts.iter().map(|p| p.candidate_id.as_str()).collect();
        assert_eq!(order, ["b", "c", "a"]);
    }
}