1use std::collections::HashMap;
2use std::fs;
3use std::path::Path;
4
5use serde::{Deserialize, Serialize};
6use serde_json::{json, Map, Value};
7
8use crate::client::{AuthStyle, SynthClient};
9use crate::sse::{stream_sse, SseStream};
10use crate::types::{Result, SynthError};
11
/// Names of the optimization algorithms known to this module.
///
/// Because of `rename_all = "lowercase"`, serde reads and writes the variants
/// as the strings `"gepa"` and `"mipro"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Algorithm {
    /// Serialized as `"gepa"`.
    Gepa,
    /// Serialized as `"mipro"`.
    Mipro,
}
18
/// A job configuration held as an untyped JSON document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyOptimizationJobConfig {
    /// The raw configuration value; its schema is not enforced by this type.
    pub config: Value,
}
23
24impl PolicyOptimizationJobConfig {
25 pub fn from_json(config: Value) -> Self {
26 Self { config }
27 }
28
29 pub fn from_toml_str(input: &str) -> Result<Self> {
30 let value: toml::Value =
31 toml::from_str(input).map_err(|err| SynthError::UnexpectedResponse(err.to_string()))?;
32 let config = serde_json::to_value(value)?;
33 Ok(Self { config })
34 }
35
36 pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
37 let content = fs::read_to_string(path)?;
38 Self::from_toml_str(&content)
39 }
40
41 pub fn to_payload(&self) -> Value {
42 let mut config = self.config.clone();
43 if let Value::Object(ref mut obj) = config {
44 if let Some(policy_opt) = obj.remove("policy_optimization") {
45 obj.insert("prompt_learning".to_string(), policy_opt);
46 }
47 if let Some(Value::Object(pl)) = obj.get_mut("prompt_learning") {
48 if let Some(local_url) = pl.remove("localapi_url") {
49 pl.insert("task_app_url".to_string(), local_url);
50 }
51 }
52 }
53 config
54 }
55}
56
/// Summary of a job assembled from its event list by
/// `PromptLearningResults::from_events`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PromptLearningResults {
    /// The `best_prompt` field taken from a `new_best` or `job.completed` event.
    pub best_prompt: Option<Value>,
    /// Reward/score associated with the best prompt, when an event carried one.
    pub best_score: Option<f64>,
    /// One JSON object per ranked prompt (`rank`, accuracies, `pattern`, `full_text`).
    pub top_prompts: Vec<Value>,
    /// The `optimized_candidates` array from the GEPA `job.completed` event.
    pub optimized_candidates: Vec<Value>,
    /// The `attempted_candidates` array from the GEPA `job.completed` event.
    pub attempted_candidates: Vec<Value>,
    /// Raw payloads of every GEPA `validation.completed` event, in event order.
    pub validation_results: Vec<Value>,
}
66
/// Handle to a single remote job: a client plus the job's identifier.
#[derive(Clone)]
pub struct PolicyOptimizationJob {
    /// Client used for every request made through this handle.
    client: SynthClient,
    /// Identifier interpolated into the job's endpoint paths.
    job_id: String,
}
72
73impl PolicyOptimizationJob {
74 pub fn new(client: SynthClient, job_id: impl Into<String>) -> Self {
75 Self {
76 client,
77 job_id: job_id.into(),
78 }
79 }
80
81 pub fn job_id(&self) -> &str {
82 &self.job_id
83 }
84
85 pub async fn submit(client: SynthClient, config: &PolicyOptimizationJobConfig) -> Result<Self> {
86 let payload = config.to_payload();
87 let algorithm = payload
88 .get("prompt_learning")
89 .and_then(|v| v.get("algorithm"))
90 .and_then(|v| v.as_str())
91 .unwrap_or("gepa");
92 let submit_body = json!({
93 "algorithm": algorithm,
94 "config_body": payload,
95 });
96 let resp = client
97 .post_json_fallback(
98 &[
99 "/policy-optimization/online/jobs",
100 "/prompt-learning/online/jobs",
101 ],
102 &submit_body,
103 AuthStyle::Both,
104 )
105 .await?;
106 let job_id = resp
107 .get("job_id")
108 .and_then(|v| v.as_str())
109 .ok_or_else(|| SynthError::UnexpectedResponse("missing job_id".to_string()))?;
110 Ok(Self::new(client, job_id))
111 }
112
113 pub async fn status(&self) -> Result<Value> {
114 let path = format!(
115 "/policy-optimization/online/jobs/{}",
116 self.job_id
117 );
118 let fallback = format!("/prompt-learning/online/jobs/{}", self.job_id);
119 self.client
120 .get_json_fallback(
121 &[path.as_str(), fallback.as_str()],
122 AuthStyle::Both,
123 )
124 .await
125 }
126
127 pub async fn events(&self) -> Result<Vec<Value>> {
128 let path = format!(
129 "/policy-optimization/online/jobs/{}/events",
130 self.job_id
131 );
132 let fallback = format!("/prompt-learning/online/jobs/{}/events", self.job_id);
133 let value = self
134 .client
135 .get_json_fallback(
136 &[path.as_str(), fallback.as_str()],
137 AuthStyle::Both,
138 )
139 .await?;
140 parse_events(value)
141 }
142
143 pub async fn results(&self) -> Result<PromptLearningResults> {
144 let events = self.events().await?;
145 Ok(PromptLearningResults::from_events(&events))
146 }
147
148 pub async fn stream_events(&self) -> Result<SseStream> {
149 let primary = format!(
150 "{}/policy-optimization/online/jobs/{}/events/stream",
151 self.client.api_base(),
152 self.job_id
153 );
154 let fallback = format!(
155 "{}/prompt-learning/online/jobs/{}/events/stream",
156 self.client.api_base(),
157 self.job_id
158 );
159 let headers = self.client.auth_headers(AuthStyle::Both);
160 match stream_sse(self.client.http(), primary, headers.clone()).await {
161 Ok(stream) => Ok(stream),
162 Err(SynthError::Api { status: 404, .. }) => {
163 stream_sse(self.client.http(), fallback, headers).await
164 }
165 Err(err) => Err(err),
166 }
167 }
168}
169
170impl PromptLearningResults {
171 pub fn from_events(events: &[Value]) -> Self {
172 let mut results = PromptLearningResults::default();
173 let mut validation_by_rank: HashMap<i64, f64> = HashMap::new();
174
175 for event in events {
176 let event_type = event.get("type").and_then(|v| v.as_str()).unwrap_or("");
177 let data = event.get("data").and_then(|v| v.as_object());
178 if data.is_none() {
179 continue;
180 }
181 let data = data.unwrap();
182
183 match event_type {
184 "learning.policy.gepa.candidate.new_best" => {
185 results.best_prompt = data.get("best_prompt").cloned();
186 if results.best_score.is_none() {
187 results.best_score = extract_reward_value(data, &["best_score"]);
188 }
189 }
190 "learning.policy.gepa.candidate.evaluated" => {
191 if let Some(rank) = data.get("rank").and_then(|v| v.as_i64()) {
192 let mut prompt_entry = Map::new();
193 prompt_entry.insert("rank".to_string(), json!(rank));
194 prompt_entry.insert(
195 "train_accuracy".to_string(),
196 data.get("train_accuracy").cloned().unwrap_or(Value::Null),
197 );
198 prompt_entry.insert(
199 "val_accuracy".to_string(),
200 data.get("val_accuracy").cloned().unwrap_or(Value::Null),
201 );
202 if let Some(pattern) = data.get("pattern") {
203 prompt_entry.insert("pattern".to_string(), pattern.clone());
204 if let Some(text) = extract_full_text_from_pattern(pattern) {
205 prompt_entry.insert("full_text".to_string(), json!(text));
206 }
207 } else if let Some(template) = data.get("template") {
208 if let Some(pattern) = convert_template_to_pattern(template) {
209 prompt_entry.insert("pattern".to_string(), pattern.clone());
210 if let Some(text) = extract_full_text_from_pattern(&pattern) {
211 prompt_entry.insert("full_text".to_string(), json!(text));
212 }
213 }
214 }
215 results.top_prompts.push(Value::Object(prompt_entry));
216 }
217 }
218 "learning.policy.gepa.job.completed" => {
219 if let Some(cands) = data.get("optimized_candidates").and_then(|v| v.as_array())
220 {
221 results.optimized_candidates = cands.clone();
222 }
223 if let Some(cands) = data.get("attempted_candidates").and_then(|v| v.as_array())
224 {
225 results.attempted_candidates = cands.clone();
226 }
227 if results.best_prompt.is_none() {
228 results.best_prompt = data.get("best_prompt").cloned();
229 }
230 if results.best_score.is_none() {
231 results.best_score = extract_reward_value(data, &["best_score"]);
232 }
233
234 if let Some(validation) = data.get("validation").and_then(|v| v.as_array()) {
235 for val in validation {
236 if let Some(val_obj) = val.as_object() {
237 if let (Some(rank), Some(score)) = (
238 val_obj.get("rank").and_then(|v| v.as_i64()),
239 extract_reward_value(val_obj, &[]),
240 ) {
241 validation_by_rank.insert(rank, score);
242 }
243 }
244 }
245 }
246 }
247 "learning.policy.gepa.validation.completed" => {
248 results.validation_results.push(Value::Object(data.clone()));
249 if let (Some(rank), Some(score)) = (
250 data.get("rank").and_then(|v| v.as_i64()),
251 extract_reward_value(data, &[]),
252 ) {
253 validation_by_rank.insert(rank, score);
254 }
255 }
256 "learning.policy.mipro.job.completed" => {
257 if results.best_score.is_none() {
258 results.best_score = extract_reward_value(
259 data,
260 &["best_score", "best_full_score", "best_minibatch_score"],
261 );
262 }
263 }
264 _ => {}
265 }
266 }
267
268 if results.top_prompts.is_empty() && !results.optimized_candidates.is_empty() {
269 for (idx, cand) in results.optimized_candidates.iter().enumerate() {
270 let cand_obj = match cand.as_object() {
271 Some(obj) => obj,
272 None => continue,
273 };
274 let rank = cand_obj
275 .get("rank")
276 .and_then(|v| v.as_i64())
277 .unwrap_or((idx + 1) as i64);
278 let mut prompt_entry = Map::new();
279 prompt_entry.insert("rank".to_string(), json!(rank));
280
281 let train_accuracy = cand_obj
282 .get("score")
283 .and_then(|v| v.as_object())
284 .and_then(|v| extract_reward_value(v, &[]))
285 .or_else(|| extract_reward_value(cand_obj, &[]));
286 if let Some(score) = train_accuracy {
287 prompt_entry.insert("train_accuracy".to_string(), json!(score));
288 }
289 if let Some(val) = validation_by_rank.get(&rank) {
290 prompt_entry.insert("val_accuracy".to_string(), json!(*val));
291 }
292
293 if let Some(pattern) = cand_obj.get("pattern") {
294 prompt_entry.insert("pattern".to_string(), pattern.clone());
295 if let Some(text) = extract_full_text_from_pattern(pattern) {
296 prompt_entry.insert("full_text".to_string(), json!(text));
297 }
298 } else if let Some(template) = cand_obj.get("template") {
299 if let Some(pattern) = convert_template_to_pattern(template) {
300 if let Some(text) = extract_full_text_from_pattern(&pattern) {
301 prompt_entry.insert("full_text".to_string(), json!(text));
302 }
303 prompt_entry.insert("pattern".to_string(), pattern);
304 }
305 }
306
307 results.top_prompts.push(Value::Object(prompt_entry));
308 }
309 }
310
311 results
312 }
313}
314
315fn parse_events(value: Value) -> Result<Vec<Value>> {
316 if let Value::Array(items) = value {
317 return Ok(items);
318 }
319 if let Value::Object(obj) = value {
320 if let Some(Value::Array(items)) = obj.get("events") {
321 return Ok(items.clone());
322 }
323 }
324 Err(SynthError::UnexpectedResponse(
325 "events response did not contain an events list".to_string(),
326 ))
327}
328
329fn coerce_f64(value: &Value) -> Option<f64> {
330 match value {
331 Value::Number(num) => num.as_f64(),
332 Value::String(s) => s.parse::<f64>().ok(),
333 _ => None,
334 }
335}
336
337fn extract_outcome_reward(payload: &Map<String, Value>) -> Option<f64> {
338 if let Some(Value::Object(obj)) = payload.get("outcome_objectives") {
339 if let Some(val) = obj.get("reward").and_then(coerce_f64) {
340 return Some(val);
341 }
342 }
343 payload.get("outcome_reward").and_then(coerce_f64)
344}
345
346fn extract_reward_value(payload: &Map<String, Value>, fallback_keys: &[&str]) -> Option<f64> {
347 if let Some(val) = extract_outcome_reward(payload) {
348 return Some(val);
349 }
350 for key in fallback_keys {
351 if let Some(val) = payload.get(*key).and_then(coerce_f64) {
352 return Some(val);
353 }
354 }
355 None
356}
357
358fn convert_template_to_pattern(template: &Value) -> Option<Value> {
359 let sections = template
360 .get("sections")
361 .and_then(|v| v.as_array())
362 .filter(|v| !v.is_empty())
363 .or_else(|| template.get("prompt_sections").and_then(|v| v.as_array()))?;
364 let mut messages = Vec::new();
365 for sec in sections {
366 let sec_obj = sec.as_object()?;
367 let content = sec_obj.get("content")?;
368 if content.is_null() {
369 continue;
370 }
371 let role = sec_obj
372 .get("role")
373 .and_then(|v| v.as_str())
374 .or_else(|| sec_obj.get("name").and_then(|v| v.as_str()))
375 .unwrap_or("system");
376 let name = sec_obj.get("name").and_then(|v| v.as_str()).unwrap_or("");
377 messages.push(json!({
378 "role": role,
379 "name": name,
380 "pattern": content,
381 }));
382 }
383 if messages.is_empty() {
384 return None;
385 }
386 Some(json!({ "messages": messages }))
387}
388
389fn extract_full_text_from_pattern(pattern: &Value) -> Option<String> {
390 let messages = pattern.get("messages")?.as_array()?;
391 let mut parts = Vec::new();
392 for msg in messages {
393 let msg_obj = msg.as_object()?;
394 let role = msg_obj.get("role").and_then(|v| v.as_str()).unwrap_or("");
395 let name = msg_obj.get("name").and_then(|v| v.as_str()).unwrap_or("");
396 let content = msg_obj
397 .get("pattern")
398 .or_else(|| msg_obj.get("content"))
399 .and_then(|v| v.as_str())
400 .unwrap_or("");
401 parts.push(format!("[{role} | {name}]\n{content}"));
402 }
403 if parts.is_empty() {
404 None
405 } else {
406 Some(parts.join("\n\n"))
407 }
408}