// File: synth_ai_core/orchestration/graph_evolve.rs

1//! Graph Evolve (GraphGen) helpers.
2//!
3//! These helpers normalize datasets/configs and build payloads so
4//! both Rust and Python SDKs share the same core logic.
5
6use serde_json::{json, Map, Value};
7
8use crate::data::enums::GraphType;
9use crate::errors::CoreError;
10
11const DEFAULT_POPULATION_SIZE: i64 = 4;
12const DEFAULT_NUM_PARENTS: i64 = 2;
13const DEFAULT_ROLLOUT_MAX_CONCURRENT: i64 = 25;
14const DEFAULT_ROLLOUT_TIMEOUT_SECONDS: f64 = 60.0;
15
16fn parse_graph_type(value: Option<&str>) -> Result<GraphType, CoreError> {
17    let Some(raw) = value else {
18        return Ok(GraphType::Policy);
19    };
20    match raw.trim().to_lowercase().as_str() {
21        "policy" => Ok(GraphType::Policy),
22        "verifier" => Ok(GraphType::Verifier),
23        "rlm" => Ok(GraphType::Rlm),
24        other => Err(CoreError::Validation(format!(
25            "invalid graph_type '{}'; expected 'policy', 'verifier', or 'rlm'",
26            other
27        ))),
28    }
29}
30
31fn as_i64(value: Option<&Value>) -> Option<i64> {
32    value.and_then(|v| v.as_i64())
33}
34
35fn ensure_task_list(dataset: &Map<String, Value>) -> Result<(), CoreError> {
36    match dataset.get("tasks") {
37        Some(Value::Array(tasks)) if !tasks.is_empty() => Ok(()),
38        Some(Value::Array(_)) => Err(CoreError::Validation(
39            "dataset must contain at least one task".to_string(),
40        )),
41        _ => Err(CoreError::Validation(
42            "dataset.tasks must be a non-empty list".to_string(),
43        )),
44    }
45}
46
47/// Parse and validate a Graph Evolve dataset JSON object.
48pub fn parse_graph_evolve_dataset(dataset: &Value) -> Result<Value, CoreError> {
49    let dataset_map = dataset
50        .as_object()
51        .ok_or_else(|| CoreError::Validation("dataset must be an object".to_string()))?;
52    ensure_task_list(dataset_map)?;
53    Ok(Value::Object(dataset_map.clone()))
54}
55
56/// Load and validate a Graph Evolve dataset from a JSON file.
57pub fn load_graph_evolve_dataset(path: &str) -> Result<Value, CoreError> {
58    let contents = std::fs::read_to_string(path).map_err(|e| {
59        CoreError::InvalidInput(format!("failed to read dataset file '{}': {}", path, e))
60    })?;
61    let value: Value = serde_json::from_str(&contents).map_err(|e| {
62        CoreError::Validation(format!("failed to parse dataset JSON '{}': {}", path, e))
63    })?;
64    parse_graph_evolve_dataset(&value)
65}
66
67/// Ensure policy model list is non-empty.
68pub fn normalize_graph_evolve_policy_models(models: Vec<String>) -> Result<Vec<String>, CoreError> {
69    let filtered: Vec<String> = models
70        .into_iter()
71        .map(|m| m.trim().to_string())
72        .filter(|m| !m.is_empty())
73        .collect();
74    if filtered.is_empty() {
75        return Err(CoreError::Validation(
76            "policy_models must contain at least one model".to_string(),
77        ));
78    }
79    Ok(filtered)
80}
81
82/// Build a Graph Evolve config dict with defaults.
83#[allow(clippy::too_many_arguments)]
84pub fn build_graph_evolve_config(
85    policy_models: Vec<String>,
86    rollout_budget: i64,
87    proposer_effort: &str,
88    verifier_model: Option<String>,
89    verifier_provider: Option<String>,
90    population_size: i64,
91    num_generations: Option<i64>,
92    problem_spec: Option<String>,
93    target_llm_calls: Option<i64>,
94    graph_type: Option<String>,
95    initial_graph_id: Option<String>,
96) -> Result<Value, CoreError> {
97    let policy_models = normalize_graph_evolve_policy_models(policy_models)?;
98
99    if rollout_budget < 10 || rollout_budget > 10000 {
100        return Err(CoreError::Validation(format!(
101            "rollout_budget must be between 10 and 10000, got {}",
102            rollout_budget
103        )));
104    }
105
106    let effort = proposer_effort.trim().to_lowercase();
107    if effort != "low" && effort != "medium" && effort != "high" {
108        return Err(CoreError::Validation(
109            "proposer_effort must be one of: low, medium, high".to_string(),
110        ));
111    }
112
113    if population_size < 2 || population_size > 20 {
114        return Err(CoreError::Validation(format!(
115            "population_size must be between 2 and 20, got {}",
116            population_size
117        )));
118    }
119
120    if let Some(value) = num_generations {
121        if value < 1 || value > 50 {
122            return Err(CoreError::Validation(format!(
123                "num_generations must be between 1 and 50, got {}",
124                value
125            )));
126        }
127    }
128
129    if let Some(value) = target_llm_calls {
130        if value < 1 || value > 10 {
131            return Err(CoreError::Validation(format!(
132                "target_llm_calls must be between 1 and 10, got {}",
133                value
134            )));
135        }
136    }
137
138    let initial_graph_id = initial_graph_id.ok_or_else(|| {
139        CoreError::Validation(
140            "initial_graph_id is required for Graph Evolve (de-novo graph generation is disabled)"
141                .to_string(),
142        )
143    })?;
144
145    let graph_type = parse_graph_type(graph_type.as_deref())?;
146
147    let mut map = Map::new();
148    map.insert("graph_type".to_string(), json!(graph_type));
149    map.insert("policy_models".to_string(), json!(policy_models));
150    map.insert("rollout_budget".to_string(), json!(rollout_budget));
151    map.insert(
152        "rollout_max_concurrent".to_string(),
153        json!(DEFAULT_ROLLOUT_MAX_CONCURRENT),
154    );
155    map.insert(
156        "rollout_timeout_seconds".to_string(),
157        json!(DEFAULT_ROLLOUT_TIMEOUT_SECONDS),
158    );
159    map.insert("proposer_effort".to_string(), json!(effort));
160    map.insert("population_size".to_string(), json!(population_size));
161    map.insert("num_parents".to_string(), json!(DEFAULT_NUM_PARENTS));
162    map.insert("initial_graph_id".to_string(), json!(initial_graph_id));
163
164    if let Some(value) = verifier_model {
165        map.insert("verifier_model".to_string(), json!(value));
166    }
167    if let Some(value) = verifier_provider {
168        map.insert("verifier_provider".to_string(), json!(value));
169    }
170    if let Some(value) = num_generations {
171        map.insert("num_generations".to_string(), json!(value));
172    }
173    if let Some(value) = problem_spec {
174        map.insert("problem_spec".to_string(), json!(value));
175    }
176    if let Some(value) = target_llm_calls {
177        map.insert("target_llm_calls".to_string(), json!(value));
178    }
179
180    Ok(Value::Object(map))
181}
182
183/// Build a Graph Evolve payload.
184pub fn build_graph_evolve_payload(
185    dataset: &Value,
186    config: &Value,
187    metadata: Option<&Value>,
188    auto_start: bool,
189) -> Result<Value, CoreError> {
190    let mut dataset_map = dataset
191        .as_object()
192        .ok_or_else(|| CoreError::Validation("dataset must be an object".to_string()))?
193        .clone();
194    ensure_task_list(&dataset_map)?;
195
196    let config_map = config
197        .as_object()
198        .ok_or_else(|| CoreError::Validation("config must be an object".to_string()))?;
199
200    if !dataset_map.contains_key("initial_prompt") {
201        let fallback = config_map
202            .get("problem_spec")
203            .and_then(|v| v.as_str())
204            .unwrap_or("Optimizing prompt graph...");
205        dataset_map.insert(
206            "initial_prompt".to_string(),
207            Value::String(fallback.to_string()),
208        );
209    }
210
211    let mut metadata_map = match metadata {
212        Some(Value::Object(map)) => map.clone(),
213        _ => Map::new(),
214    };
215
216    if let Some(value) = as_i64(config_map.get("num_generations")) {
217        metadata_map.insert("num_generations".to_string(), json!(value));
218    }
219    if let Some(value) = as_i64(config_map.get("population_size")) {
220        if value != DEFAULT_POPULATION_SIZE {
221            metadata_map.insert("population_size".to_string(), json!(value));
222        }
223    }
224    if let Some(value) = as_i64(config_map.get("num_parents")) {
225        if value != DEFAULT_NUM_PARENTS {
226            metadata_map.insert("num_parents".to_string(), json!(value));
227        }
228    }
229    if let Some(Value::Array(seeds)) = config_map.get("evaluation_seeds") {
230        metadata_map.insert("evaluation_seeds".to_string(), Value::Array(seeds.clone()));
231    }
232
233    let eval_sample_size = metadata_map.remove("eval_sample_size");
234    let feedback_sample_size = metadata_map.remove("feedback_sample_size");
235
236    let policy_models = config_map
237        .get("policy_models")
238        .ok_or_else(|| CoreError::Validation("policy_models missing from config".to_string()))?
239        .clone();
240    let rollout_budget = config_map
241        .get("rollout_budget")
242        .ok_or_else(|| CoreError::Validation("rollout_budget missing from config".to_string()))?
243        .clone();
244    let proposer_effort = config_map
245        .get("proposer_effort")
246        .ok_or_else(|| CoreError::Validation("proposer_effort missing from config".to_string()))?
247        .clone();
248
249    let mut payload = Map::new();
250    payload.insert("dataset".to_string(), Value::Object(dataset_map));
251    payload.insert("initial_prompt".to_string(), Value::Null);
252    payload.insert("policy_models".to_string(), policy_models);
253    payload.insert("rollout_budget".to_string(), rollout_budget);
254    payload.insert("proposer_effort".to_string(), proposer_effort);
255
256    if let Some(value) = config_map.get("policy_provider") {
257        if !value.is_null() {
258            payload.insert("policy_provider".to_string(), value.clone());
259        }
260    }
261    if let Some(value) = config_map.get("verifier_model") {
262        if !value.is_null() {
263            payload.insert("judge_model".to_string(), value.clone());
264        }
265    }
266    if let Some(value) = config_map.get("verifier_provider") {
267        if !value.is_null() {
268            payload.insert("judge_provider".to_string(), value.clone());
269        }
270    }
271    if let Some(value) = config_map.get("problem_spec") {
272        if !value.is_null() {
273            payload.insert("problem_spec".to_string(), value.clone());
274        }
275    }
276    if let Some(value) = config_map.get("target_llm_calls") {
277        if !value.is_null() {
278            payload.insert("target_llm_calls".to_string(), value.clone());
279        }
280    }
281    if let Some(value) = config_map.get("initial_graph_id") {
282        if !value.is_null() {
283            payload.insert("initial_graph_id".to_string(), value.clone());
284        } else {
285            return Err(CoreError::Validation(
286                "initial_graph_id missing from config".to_string(),
287            ));
288        }
289    } else {
290        return Err(CoreError::Validation(
291            "initial_graph_id missing from config".to_string(),
292        ));
293    }
294
295    if let Some(value) = eval_sample_size {
296        payload.insert("eval_sample_size".to_string(), value);
297    }
298    if let Some(value) = feedback_sample_size {
299        payload.insert("feedback_sample_size".to_string(), value);
300    }
301
302    payload.insert("metadata".to_string(), Value::Object(metadata_map));
303    payload.insert("auto_start".to_string(), Value::Bool(auto_start));
304
305    Ok(Value::Object(payload))
306}
307
308/// Resolve prompt/graph snapshot IDs for graph-evolve requests.
309pub fn resolve_graph_evolve_snapshot_id(
310    prompt_snapshot_id: Option<&str>,
311    graph_snapshot_id: Option<&str>,
312) -> Result<Option<String>, CoreError> {
313    if prompt_snapshot_id.is_some() && graph_snapshot_id.is_some() {
314        return Err(CoreError::Validation(
315            "Provide only one of prompt_snapshot_id or graph_snapshot_id.".to_string(),
316        ));
317    }
318    Ok(graph_snapshot_id
319        .map(|s| s.to_string())
320        .or_else(|| prompt_snapshot_id.map(|s| s.to_string())))
321}
322
323/// Build payload for graph record download.
324pub fn build_graph_evolve_graph_record_payload(
325    job_id: &str,
326    prompt_snapshot_id: Option<&str>,
327    graph_snapshot_id: Option<&str>,
328) -> Result<Value, CoreError> {
329    let snapshot_id = resolve_graph_evolve_snapshot_id(prompt_snapshot_id, graph_snapshot_id)?;
330    let mut map = Map::new();
331    map.insert("job_id".to_string(), Value::String(job_id.to_string()));
332    if let Some(snapshot_id) = snapshot_id {
333        map.insert("prompt_snapshot_id".to_string(), Value::String(snapshot_id));
334    }
335    Ok(Value::Object(map))
336}
337
338/// Build payload for graph inference.
339pub fn build_graph_evolve_inference_payload(
340    job_id: &str,
341    input: &Value,
342    model: Option<&str>,
343    prompt_snapshot_id: Option<&str>,
344    graph_snapshot_id: Option<&str>,
345) -> Result<Value, CoreError> {
346    let snapshot_id = resolve_graph_evolve_snapshot_id(prompt_snapshot_id, graph_snapshot_id)?;
347    let mut map = Map::new();
348    map.insert("job_id".to_string(), Value::String(job_id.to_string()));
349    map.insert("input".to_string(), input.clone());
350    if let Some(model) = model {
351        map.insert("model".to_string(), Value::String(model.to_string()));
352    }
353    if let Some(snapshot_id) = snapshot_id {
354        map.insert("prompt_snapshot_id".to_string(), Value::String(snapshot_id));
355    }
356    Ok(Value::Object(map))
357}
358
359/// Build a placeholder dataset for resumed jobs.
360pub fn build_graph_evolve_placeholder_dataset() -> Value {
361    json!({
362        "metadata": {"name": "(resumed job)"},
363        "tasks": [{"id": "placeholder", "input": {}}]
364    })
365}
366
367// =============================================================================
368// Graph Evolve job orchestration
369// =============================================================================
370
371use crate::api::SynthClient;
372
373/// High-level Graph Evolve job orchestration.
374pub struct GraphEvolveJob {
375    client: SynthClient,
376    job_id: Option<String>,
377    legacy_graphgen_job_id: Option<String>,
378    payload: Option<Value>,
379}
380
381impl GraphEvolveJob {
382    /// Create a job from a payload.
383    pub fn from_payload(
384        payload: Value,
385        api_key: Option<&str>,
386        base_url: Option<&str>,
387    ) -> Result<Self, CoreError> {
388        let api_key = match api_key {
389            Some(k) => k.to_string(),
390            None => crate::auth::get_api_key(None)
391                .ok_or_else(|| CoreError::Authentication("SYNTH_API_KEY not found".to_string()))?,
392        };
393
394        let client = SynthClient::new(&api_key, base_url)?;
395
396        Ok(Self {
397            client,
398            job_id: None,
399            legacy_graphgen_job_id: None,
400            payload: Some(payload),
401        })
402    }
403
404    /// Reconnect to an existing job by ID.
405    pub fn from_job_id(
406        job_id: &str,
407        api_key: Option<&str>,
408        base_url: Option<&str>,
409    ) -> Result<Self, CoreError> {
410        let api_key = match api_key {
411            Some(k) => k.to_string(),
412            None => crate::auth::get_api_key(None)
413                .ok_or_else(|| CoreError::Authentication("SYNTH_API_KEY not found".to_string()))?,
414        };
415
416        let client = SynthClient::new(&api_key, base_url)?;
417
418        let mut job = Self {
419            client,
420            job_id: Some(job_id.to_string()),
421            legacy_graphgen_job_id: None,
422            payload: None,
423        };
424
425        if job_id.starts_with("graphgen_") {
426            job.legacy_graphgen_job_id = Some(job_id.to_string());
427        }
428
429        Ok(job)
430    }
431
432    /// Get the job ID (primary or legacy).
433    pub fn job_id(&self) -> Option<&str> {
434        if let Some(id) = self.job_id.as_deref() {
435            return Some(id);
436        }
437        self.legacy_graphgen_job_id.as_deref()
438    }
439
440    /// Get the legacy GraphGen job ID, if known.
441    pub fn legacy_job_id(&self) -> Option<&str> {
442        self.legacy_graphgen_job_id.as_deref()
443    }
444
445    fn require_job_id(&self) -> Result<&str, CoreError> {
446        self.job_id()
447            .ok_or_else(|| CoreError::Validation("job not submitted yet".to_string()))
448    }
449
450    /// Submit the job and return the backend response.
451    pub async fn submit(&mut self) -> Result<Value, CoreError> {
452        if self.job_id.is_some() || self.legacy_graphgen_job_id.is_some() {
453            return Err(CoreError::Validation("job already submitted".to_string()));
454        }
455        let payload = self
456            .payload
457            .as_ref()
458            .ok_or_else(|| CoreError::Validation("payload missing".to_string()))?;
459        let response = self
460            .client
461            .graph_evolve()
462            .submit_job(payload.clone())
463            .await?;
464
465        if let Some(id) = response.get("graph_evolve_job_id").and_then(|v| v.as_str()) {
466            self.job_id = Some(id.to_string());
467        }
468        if let Some(id) = response.get("graphgen_job_id").and_then(|v| v.as_str()) {
469            self.legacy_graphgen_job_id = Some(id.to_string());
470            if self.job_id.is_none() {
471                self.job_id = Some(id.to_string());
472            }
473        }
474
475        Ok(response)
476    }
477
478    /// Get current job status.
479    pub async fn get_status(&self) -> Result<Value, CoreError> {
480        let job_id = self.require_job_id()?;
481        self.client.graph_evolve().get_status(job_id).await
482    }
483
484    /// Start a queued job.
485    pub async fn start(&self) -> Result<Value, CoreError> {
486        let job_id = self.require_job_id()?;
487        self.client.graph_evolve().start_job(job_id).await
488    }
489
490    /// Fetch events for the job.
491    pub async fn get_events(&self, since_seq: i64, limit: i64) -> Result<Value, CoreError> {
492        let job_id = self.require_job_id()?;
493        self.client
494            .graph_evolve()
495            .get_events(job_id, since_seq, limit)
496            .await
497    }
498
499    /// Fetch metrics for the job.
500    pub async fn get_metrics(&self, query_string: &str) -> Result<Value, CoreError> {
501        let job_id = self.require_job_id()?;
502        self.client
503            .graph_evolve()
504            .get_metrics(job_id, query_string)
505            .await
506    }
507
508    /// Download prompt (JSON response).
509    pub async fn download_prompt(&self) -> Result<Value, CoreError> {
510        let job_id = self.require_job_id()?;
511        self.client.graph_evolve().download_prompt(job_id).await
512    }
513
514    /// Download redacted graph export.
515    pub async fn download_graph_txt(&self) -> Result<String, CoreError> {
516        let job_id = self.require_job_id()?;
517        self.client.graph_evolve().download_graph_txt(job_id).await
518    }
519
520    /// Run inference using the optimized graph.
521    pub async fn run_inference(&self, payload: Value) -> Result<Value, CoreError> {
522        self.client.graph_evolve().run_inference(payload).await
523    }
524
525    /// Fetch a graph record snapshot.
526    pub async fn get_graph_record(&self, payload: Value) -> Result<Value, CoreError> {
527        self.client.graph_evolve().get_graph_record(payload).await
528    }
529
530    /// Cancel the job.
531    pub async fn cancel(&self, payload: Value) -> Result<Value, CoreError> {
532        let job_id = self.require_job_id()?;
533        self.client.graph_evolve().cancel_job(job_id, payload).await
534    }
535
536    /// Query workflow state.
537    pub async fn query_workflow_state(&self) -> Result<Value, CoreError> {
538        let job_id = self.require_job_id()?;
539        self.client
540            .graph_evolve()
541            .query_workflow_state(job_id)
542            .await
543    }
544}