// synth_ai_core/orchestration/prompt_learning.rs
use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11use crate::api::types::PolicyJobStatus;
12use crate::api::SynthClient;
13use crate::auth;
14use crate::errors::CoreError;
15
16use super::events::ParsedEvent;
17use super::progress::ProgressTracker;
18use super::streaming::EventStream;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct PromptLearningResult {
23 pub job_id: String,
25 pub status: PolicyJobStatus,
27 #[serde(default)]
29 pub best_score: Option<f64>,
30 #[serde(default)]
32 pub best_prompt: Option<Value>,
33 #[serde(default)]
35 pub baseline_score: Option<f64>,
36 #[serde(default)]
38 pub candidates_evaluated: i32,
39 #[serde(default)]
41 pub generations_completed: i32,
42 #[serde(default)]
44 pub error: Option<String>,
45 #[serde(default)]
47 pub raw: Value,
48}
49
50impl PromptLearningResult {
51 pub fn succeeded(&self) -> bool {
53 self.status == PolicyJobStatus::Succeeded
54 }
55
56 pub fn failed(&self) -> bool {
58 self.status == PolicyJobStatus::Failed
59 }
60
61 pub fn is_terminal(&self) -> bool {
63 self.status.is_terminal()
64 }
65
66 pub fn get_system_prompt(&self) -> Option<String> {
68 self.best_prompt.as_ref().and_then(|p| {
69 p.get("system_prompt")
71 .and_then(|v| v.as_str())
72 .or_else(|| p.get("instruction").and_then(|v| v.as_str()))
73 .or_else(|| {
74 p.get("stages")
75 .and_then(|s| s.get("main"))
76 .and_then(|m| m.get("instruction"))
77 .and_then(|v| v.as_str())
78 })
79 .map(|s| s.to_string())
80 })
81 }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct RankedPrompt {
87 pub rank: i32,
89 pub candidate_id: String,
91 #[serde(default)]
93 pub train_accuracy: Option<f64>,
94 #[serde(default)]
96 pub val_accuracy: Option<f64>,
97 #[serde(default)]
99 pub prompt: Option<Value>,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct PromptResults {
105 #[serde(default)]
107 pub best_prompt: Option<String>,
108 #[serde(default)]
110 pub best_score: Option<f64>,
111 #[serde(default)]
113 pub top_prompts: Vec<RankedPrompt>,
114}
115
116pub struct PromptLearningJob {
118 client: SynthClient,
120 job_id: Option<String>,
122 config: Value,
124 tracker: ProgressTracker,
126}
127
128impl PromptLearningJob {
129 pub fn from_dict(
153 config: Value,
154 api_key: Option<&str>,
155 base_url: Option<&str>,
156 ) -> Result<Self, CoreError> {
157 let api_key = match api_key {
158 Some(k) => k.to_string(),
159 None => auth::get_api_key(None).ok_or_else(|| {
160 CoreError::Authentication("SYNTH_API_KEY not found".to_string())
161 })?,
162 };
163
164 let client = SynthClient::new(&api_key, base_url)?;
165
166 Ok(Self {
167 client,
168 job_id: None,
169 config,
170 tracker: ProgressTracker::new(),
171 })
172 }
173
174 pub fn from_job_id(
182 job_id: &str,
183 api_key: Option<&str>,
184 base_url: Option<&str>,
185 ) -> Result<Self, CoreError> {
186 let api_key = match api_key {
187 Some(k) => k.to_string(),
188 None => auth::get_api_key(None).ok_or_else(|| {
189 CoreError::Authentication("SYNTH_API_KEY not found".to_string())
190 })?,
191 };
192
193 let client = SynthClient::new(&api_key, base_url)?;
194
195 Ok(Self {
196 client,
197 job_id: Some(job_id.to_string()),
198 config: Value::Null,
199 tracker: ProgressTracker::new(),
200 })
201 }
202
203 pub fn job_id(&self) -> Option<&str> {
205 self.job_id.as_deref()
206 }
207
208 pub fn tracker(&self) -> &ProgressTracker {
210 &self.tracker
211 }
212
213 pub async fn submit(&mut self) -> Result<String, CoreError> {
217 if self.job_id.is_some() {
218 return Err(CoreError::Validation(
219 "job already submitted".to_string(),
220 ));
221 }
222
223 if self.config.is_null() {
224 return Err(CoreError::Validation(
225 "no configuration provided".to_string(),
226 ));
227 }
228
229 let job_id = self.client.jobs().submit_raw(self.config.clone()).await?;
231 self.job_id = Some(job_id.clone());
232
233 Ok(job_id)
234 }
235
236 pub async fn get_status(&self) -> Result<PromptLearningResult, CoreError> {
238 let job_id = self.job_id.as_ref().ok_or_else(|| {
239 CoreError::Validation("job not submitted yet".to_string())
240 })?;
241
242 let result = self.client.jobs().get_status(job_id).await?;
243
244 Ok(PromptLearningResult {
245 job_id: result.job_id,
246 status: result.status,
247 best_score: result.best_score,
248 best_prompt: result.best_prompt,
249 baseline_score: None,
250 candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
251 generations_completed: result.generations_completed.unwrap_or(0),
252 error: result.error,
253 raw: Value::Null,
254 })
255 }
256
257 pub async fn poll_until_complete(
264 &self,
265 timeout_secs: f64,
266 interval_secs: f64,
267 ) -> Result<PromptLearningResult, CoreError> {
268 let job_id = self.job_id.as_ref().ok_or_else(|| {
269 CoreError::Validation("job not submitted yet".to_string())
270 })?;
271
272 let result = self
273 .client
274 .jobs()
275 .poll_until_complete(job_id, timeout_secs, interval_secs)
276 .await?;
277
278 Ok(PromptLearningResult {
279 job_id: result.job_id,
280 status: result.status,
281 best_score: result.best_score,
282 best_prompt: result.best_prompt,
283 baseline_score: None,
284 candidates_evaluated: result.candidates_evaluated.unwrap_or(0),
285 generations_completed: result.generations_completed.unwrap_or(0),
286 error: result.error,
287 raw: Value::Null,
288 })
289 }
290
291 pub async fn stream_until_complete<F>(
298 &mut self,
299 timeout_secs: f64,
300 mut on_event: Option<F>,
301 ) -> Result<PromptLearningResult, CoreError>
302 where
303 F: FnMut(&ParsedEvent),
304 {
305 use std::cell::Cell;
306
307 let job_id = self.job_id.as_ref().ok_or_else(|| {
308 CoreError::Validation("job not submitted yet".to_string())
309 })?;
310
311 let mut stream = EventStream::new(
312 self.client.http().clone(),
313 self.client.base_url(),
314 job_id,
315 );
316
317 let timeout = Duration::from_secs_f64(timeout_secs);
318 let poll_interval = Duration::from_secs(5);
319
320 let terminal_reached = Cell::new(false);
322
323 {
324 let tracker = &mut self.tracker;
325
326 stream
327 .stream_until(
328 |event| {
329 tracker.update(event);
331
332 if let Some(ref mut cb) = on_event {
334 cb(event);
335 }
336
337 if event.category.is_terminal() {
339 terminal_reached.set(true);
340 }
341 },
342 timeout,
343 poll_interval,
344 || terminal_reached.get(),
345 )
346 .await?;
347 }
348
349 let status_result = self.get_status().await?;
351
352 Ok(PromptLearningResult {
354 job_id: status_result.job_id,
355 status: status_result.status,
356 best_score: status_result.best_score.or(Some(self.tracker.best_score())),
357 best_prompt: status_result.best_prompt,
358 baseline_score: self.tracker.baseline_score(),
359 candidates_evaluated: self.tracker.progress.candidates_evaluated,
360 generations_completed: self.tracker.progress.generations_completed,
361 error: status_result.error,
362 raw: Value::Null,
363 })
364 }
365
366 pub async fn cancel(&self, reason: Option<&str>) -> Result<(), CoreError> {
372 let job_id = self.job_id.as_ref().ok_or_else(|| {
373 CoreError::Validation("job not submitted yet".to_string())
374 })?;
375
376 self.client.jobs().cancel(job_id, reason).await
377 }
378
379 pub async fn get_results(&self) -> Result<PromptResults, CoreError> {
383 let job_id = self.job_id.as_ref().ok_or_else(|| {
384 CoreError::Validation("job not submitted yet".to_string())
385 })?;
386
387 let status = self.get_status().await?;
389
390 let best_prompt = status.get_system_prompt();
391 let best_score = status.best_score.or(Some(self.tracker.best_score()));
392
393 let mut top_prompts: Vec<RankedPrompt> = self
395 .tracker
396 .candidates
397 .iter()
398 .filter(|c| c.accepted || c.is_pareto)
399 .map(|c| RankedPrompt {
400 rank: 0,
401 candidate_id: c.candidate_id.clone(),
402 train_accuracy: c.accuracy,
403 val_accuracy: c.val_accuracy,
404 prompt: None,
405 })
406 .collect();
407
408 top_prompts.sort_by(|a, b| {
410 let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
411 let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
412 b_score.partial_cmp(&a_score).unwrap_or(std::cmp::Ordering::Equal)
413 });
414
415 for (i, prompt) in top_prompts.iter_mut().enumerate() {
417 prompt.rank = (i + 1) as i32;
418 }
419
420 Ok(PromptResults {
421 best_prompt,
422 best_score,
423 top_prompts,
424 })
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a succeeded result with fixed score and counters and the given
    /// best prompt.
    fn succeeded_result(best_prompt: Option<Value>) -> PromptLearningResult {
        PromptLearningResult {
            job_id: "test".to_string(),
            status: PolicyJobStatus::Succeeded,
            best_score: Some(0.85),
            best_prompt,
            baseline_score: None,
            candidates_evaluated: 10,
            generations_completed: 3,
            error: None,
            raw: Value::Null,
        }
    }

    /// Builds an unranked prompt entry (rank 0, no prompt payload).
    fn unranked(
        candidate_id: &str,
        train_accuracy: Option<f64>,
        val_accuracy: Option<f64>,
    ) -> RankedPrompt {
        RankedPrompt {
            rank: 0,
            candidate_id: candidate_id.to_string(),
            train_accuracy,
            val_accuracy,
            prompt: None,
        }
    }

    #[test]
    fn test_result_status() {
        let result = succeeded_result(None);

        assert!(result.succeeded());
        assert!(!result.failed());
        assert!(result.is_terminal());
    }

    #[test]
    fn test_result_get_system_prompt() {
        let result = succeeded_result(Some(json!({
            "system_prompt": "You are a helpful assistant."
        })));

        assert_eq!(
            result.get_system_prompt(),
            Some("You are a helpful assistant.".to_string())
        );
    }

    #[test]
    fn test_ranked_prompt_sorting() {
        let mut prompts = vec![
            unranked("a", Some(0.7), None),
            unranked("b", Some(0.9), None),
            unranked("c", Some(0.8), Some(0.85)),
        ];

        // Descending by validation accuracy, falling back to training accuracy.
        prompts.sort_by(|a, b| {
            let a_score = a.val_accuracy.or(a.train_accuracy).unwrap_or(0.0);
            let b_score = b.val_accuracy.or(b.train_accuracy).unwrap_or(0.0);
            b_score.partial_cmp(&a_score).unwrap_or(std::cmp::Ordering::Equal)
        });

        let order: Vec<&str> = prompts.iter().map(|p| p.candidate_id.as_str()).collect();
        assert_eq!(order, ["b", "c", "a"]);
    }
}