// synth_ai_core/api/graphs.rs
//! Graphs API client.
//!
//! This module provides methods for graph completions and verifier inference.
5use serde_json::Map;
6use serde_json::{json, Value};
7
8use crate::http::HttpError;
9use crate::CoreError;
10
11use super::client::SynthClient;
12use super::types::{
13    GraphCompletionRequest, GraphCompletionResponse, RlmOptions, VerifierOptions, VerifierResponse,
14};
15
/// API endpoint path for graph completions (all completion-style POSTs go here).
const GRAPHS_ENDPOINT: &str = "/api/graphs/completions";

/// Default verifier graph ID: single-shot rubric verifier.
pub const DEFAULT_VERIFIER: &str = "zero_shot_verifier_rubric_single";

/// RLM v1 verifier graph ID, used for large contexts.
pub const RLM_VERIFIER_V1: &str = "zero_shot_verifier_rubric_rlm";

/// RLM v2 verifier graph ID (multi-agent variant).
pub const RLM_VERIFIER_V2: &str = "zero_shot_verifier_rubric_rlm_v2";
27
/// Graphs API client.
///
/// Use this for graph completions and verifier inference.
/// Borrows the parent [`SynthClient`]; per the examples in this module it is
/// obtained via `client.graphs()` — confirm the accessor on `SynthClient`.
pub struct GraphsClient<'a> {
    // Parent client providing the HTTP transport used by every call.
    client: &'a SynthClient,
}
34
35impl<'a> GraphsClient<'a> {
36    /// Create a new Graphs client.
37    pub(crate) fn new(client: &'a SynthClient) -> Self {
38        Self { client }
39    }
40
41    /// Execute a graph completion.
42    ///
43    /// # Arguments
44    ///
45    /// * `request` - The graph completion request
46    ///
47    /// # Returns
48    ///
49    /// The graph output.
50    ///
51    /// # Example
52    ///
53    /// ```ignore
54    /// let response = client.graphs().complete(GraphCompletionRequest {
55    ///     job_id: "my-graph".into(),
56    ///     input: json!({"prompt": "Hello"}),
57    ///     model: None,
58    ///     stream: None,
59    /// }).await?;
60    /// ```
61    pub async fn complete(
62        &self,
63        request: GraphCompletionRequest,
64    ) -> Result<GraphCompletionResponse, CoreError> {
65        let body = serde_json::to_value(&request)
66            .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
67
68        self.client
69            .http
70            .post_json(GRAPHS_ENDPOINT, &body)
71            .await
72            .map_err(map_http_error)
73    }
74
75    /// List graphs registered to the org.
76    pub async fn list_graphs(
77        &self,
78        kind: Option<&str>,
79        limit: Option<i32>,
80    ) -> Result<Value, CoreError> {
81        let mut params = Vec::new();
82        let limit_str;
83        if let Some(limit_val) = limit {
84            limit_str = limit_val.to_string();
85            params.push(("limit", limit_str.as_str()));
86        }
87
88        let kind_val;
89        if let Some(kind_val_raw) = kind {
90            kind_val = kind_val_raw.to_string();
91            params.push(("kind", kind_val.as_str()));
92        }
93
94        let params_ref: Option<&[(&str, &str)]> = if params.is_empty() {
95            None
96        } else {
97            Some(&params)
98        };
99
100        self.client
101            .http
102            .get_json("/graph-evolve/graphs", params_ref)
103            .await
104            .map_err(map_http_error)
105    }
106
107    /// Execute a raw graph completion from a JSON value.
108    pub async fn complete_raw(&self, request: Value) -> Result<Value, CoreError> {
109        self.client
110            .http
111            .post_json(GRAPHS_ENDPOINT, &request)
112            .await
113            .map_err(map_http_error)
114    }
115
116    /// Run verifier inference on a trace.
117    ///
118    /// This evaluates a trace against a rubric using the verifier graph.
119    ///
120    /// # Arguments
121    ///
122    /// * `trace` - The trace to verify (JSON object with events)
123    /// * `rubric` - The rubric to evaluate against
124    /// * `options` - Optional verifier configuration
125    ///
126    /// # Returns
127    ///
128    /// The verification result with scores and reviews.
129    ///
130    /// # Example
131    ///
132    /// ```ignore
133    /// let result = client.graphs().verify(
134    ///     json!({
135    ///         "events": [
136    ///             {"type": "user_message", "content": "Hello"},
137    ///             {"type": "assistant_message", "content": "Hi there!"}
138    ///         ]
139    ///     }),
140    ///     json!({
141    ///         "objectives": [
142    ///             {"name": "helpfulness", "description": "Be helpful"}
143    ///         ]
144    ///     }),
145    ///     None,
146    /// ).await?;
147    /// println!("Objectives: {:?}", result.objectives);
148    /// ```
149    pub async fn verify(
150        &self,
151        trace: Value,
152        rubric: Value,
153        options: Option<VerifierOptions>,
154    ) -> Result<VerifierResponse, CoreError> {
155        let options = options.unwrap_or_default();
156        let verifier_id = options.verifier_id.as_deref().unwrap_or(DEFAULT_VERIFIER);
157
158        let mut input = json!({
159            "trace": trace,
160            "rubric": rubric,
161        });
162
163        if let Some(model) = &options.model {
164            input["model"] = json!(model);
165        }
166
167        let request = GraphCompletionRequest {
168            job_id: verifier_id.to_string(),
169            input,
170            model: options.model.clone(),
171            prompt_snapshot_id: None,
172            stream: Some(false),
173        };
174
175        let body = serde_json::to_value(&request)
176            .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
177
178        let response: GraphCompletionResponse = self
179            .client
180            .http
181            .post_json(GRAPHS_ENDPOINT, &body)
182            .await
183            .map_err(map_http_error)?;
184
185        // Parse the output as VerifierResponse
186        serde_json::from_value(response.output.clone()).map_err(|e| {
187            CoreError::Validation(format!(
188                "failed to parse verifier response: {} (output: {:?})",
189                e, response.output
190            ))
191        })
192    }
193
194    /// Run RLM (Retrieval-augmented LM) inference.
195    ///
196    /// This is useful for large context scenarios where the full trace
197    /// doesn't fit in a single context window.
198    ///
199    /// # Arguments
200    ///
201    /// * `query` - The query/question to answer
202    /// * `context` - The context to search through
203    /// * `options` - Optional RLM configuration
204    ///
205    /// # Returns
206    ///
207    /// The RLM output as a JSON value.
208    pub async fn rlm_inference(
209        &self,
210        query: &str,
211        context: Value,
212        options: Option<RlmOptions>,
213    ) -> Result<Value, CoreError> {
214        let options = options.unwrap_or_default();
215        let rlm_id = options.rlm_id.as_deref().unwrap_or(RLM_VERIFIER_V1);
216
217        let mut input = json!({
218            "query": query,
219            "context": context,
220        });
221
222        if let Some(max_tokens) = options.max_context_tokens {
223            input["max_context_tokens"] = json!(max_tokens);
224        }
225
226        let request = GraphCompletionRequest {
227            job_id: rlm_id.to_string(),
228            input,
229            model: options.model,
230            prompt_snapshot_id: None,
231            stream: Some(false),
232        };
233
234        let body = serde_json::to_value(&request)
235            .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
236
237        let response: GraphCompletionResponse = self
238            .client
239            .http
240            .post_json(GRAPHS_ENDPOINT, &body)
241            .await
242            .map_err(map_http_error)?;
243
244        Ok(response.output)
245    }
246
247    /// Execute a policy/prompt from a job.
248    ///
249    /// This runs inference using a trained policy from a completed
250    /// optimization job.
251    ///
252    /// # Arguments
253    ///
254    /// * `job_id` - The optimization job ID
255    /// * `input` - The input to the policy
256    /// * `model` - Optional model override
257    ///
258    /// # Returns
259    ///
260    /// The policy output.
261    pub async fn policy_inference(
262        &self,
263        job_id: &str,
264        input: Value,
265        model: Option<&str>,
266    ) -> Result<Value, CoreError> {
267        let request = GraphCompletionRequest {
268            job_id: job_id.to_string(),
269            input,
270            model: model.map(|s| s.to_string()),
271            prompt_snapshot_id: None,
272            stream: Some(false),
273        };
274
275        let body = serde_json::to_value(&request)
276            .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
277
278        let response: GraphCompletionResponse = self
279            .client
280            .http
281            .post_json(GRAPHS_ENDPOINT, &body)
282            .await
283            .map_err(map_http_error)?;
284
285        Ok(response.output)
286    }
287}
288
289/// Build a verifier graph completion request from SDK inputs.
290///
291/// This matches the Python SDK behavior for `verify_with_rubric`:
292/// - Uses trace_ref if provided, otherwise trace_content
293/// - Defaults to RLM for trace_ref, otherwise chooses by estimated size
294/// - Supports optional system/user prompts and options payload
295pub fn build_verifier_request(
296    trace_content: Option<Value>,
297    trace_ref: Option<String>,
298    rubric: Value,
299    system_prompt: Option<String>,
300    user_prompt: Option<String>,
301    options: Option<Value>,
302    model: Option<String>,
303    verifier_shape: Option<String>,
304    rlm_impl: Option<String>,
305) -> Result<GraphCompletionRequest, CoreError> {
306    let has_ref = trace_ref.as_ref().map(|s| !s.is_empty()).unwrap_or(false);
307    let has_content = trace_content.is_some();
308
309    if !has_ref && !has_content {
310        return Err(CoreError::Validation(
311            "trace_content or trace_ref is required".to_string(),
312        ));
313    }
314
315    let shape = match verifier_shape.as_deref() {
316        Some("single") => "single",
317        Some("rlm") => "rlm",
318        Some(other) => {
319            return Err(CoreError::Validation(format!(
320                "unsupported verifier_shape: {}",
321                other
322            )))
323        }
324        None => {
325            if has_ref {
326                "rlm"
327            } else {
328                let tokens = estimate_trace_tokens(trace_content.as_ref().unwrap())?;
329                if tokens < 50_000 {
330                    "single"
331                } else {
332                    "rlm"
333                }
334            }
335        }
336    };
337
338    let verifier_id = if shape == "single" {
339        if rlm_impl
340            .as_deref()
341            .map(|value| !value.is_empty())
342            .unwrap_or(false)
343        {
344            return Err(CoreError::Validation(
345                "rlm_impl is only valid when verifier_shape is 'rlm'".to_string(),
346            ));
347        }
348        DEFAULT_VERIFIER
349    } else if matches!(rlm_impl.as_deref(), Some("v2")) {
350        RLM_VERIFIER_V2
351    } else {
352        RLM_VERIFIER_V1
353    };
354
355    let mut input = Map::new();
356    input.insert("rubric".to_string(), rubric);
357    input.insert(
358        "options".to_string(),
359        options.unwrap_or_else(|| Value::Object(Map::new())),
360    );
361
362    if let Some(trace_ref) = trace_ref {
363        if !trace_ref.is_empty() {
364            input.insert("trace_ref".to_string(), Value::String(trace_ref));
365        }
366    }
367
368    if !input.contains_key("trace_ref") {
369        if let Some(trace_content) = trace_content {
370            input.insert("trace_content".to_string(), trace_content);
371        }
372    }
373
374    if let Some(system_prompt) = system_prompt {
375        input.insert("system_prompt".to_string(), Value::String(system_prompt));
376    }
377    if let Some(user_prompt) = user_prompt {
378        input.insert("user_prompt".to_string(), Value::String(user_prompt));
379    }
380
381    Ok(GraphCompletionRequest {
382        job_id: verifier_id.to_string(),
383        input: Value::Object(input),
384        model,
385        prompt_snapshot_id: None,
386        stream: Some(false),
387    })
388}
389
390fn estimate_trace_tokens(trace: &Value) -> Result<usize, CoreError> {
391    let payload = serde_json::to_string(trace)
392        .map_err(|e| CoreError::Validation(format!("failed to serialize trace: {}", e)))?;
393    Ok(payload.len() / 4)
394}
395
396/// Resolve a graph job ID from explicit job_id or graph target spec.
397///
398/// Mirrors the Python SDK logic for graph target resolution.
399pub fn resolve_graph_job_id(
400    job_id: Option<String>,
401    graph: Option<Value>,
402) -> Result<String, CoreError> {
403    if let Some(job_id) = job_id {
404        if !job_id.trim().is_empty() {
405            return Ok(job_id);
406        }
407    }
408
409    let graph = graph
410        .ok_or_else(|| CoreError::Validation("graph_completions_missing_job_id".to_string()))?;
411
412    let graph_obj = graph
413        .as_object()
414        .ok_or_else(|| CoreError::Validation("graph target must be an object".to_string()))?;
415
416    if let Some(Value::String(job_id)) = graph_obj.get("job_id") {
417        if !job_id.trim().is_empty() {
418            return Ok(job_id.clone());
419        }
420    }
421
422    let kind = graph_obj.get("kind").and_then(|v| v.as_str()).unwrap_or("");
423    if kind == "zero_shot" {
424        if let Some(shape) = graph_obj
425            .get("verifier_shape")
426            .and_then(|v| v.as_str())
427            .or_else(|| graph_obj.get("graph_name").and_then(|v| v.as_str()))
428        {
429            return Ok(shape.to_string());
430        }
431        return Err(CoreError::Validation(
432            "graph_completions_missing_verifier_shape".to_string(),
433        ));
434    }
435
436    if kind == "graphgen" {
437        if let Some(graphgen_job_id) = graph_obj.get("graphgen_job_id").and_then(|v| v.as_str()) {
438            return Ok(graphgen_job_id.to_string());
439        }
440        return Err(CoreError::Validation(
441            "graph_completions_missing_graphgen_job_id".to_string(),
442        ));
443    }
444
445    if let Some(graph_name) = graph_obj.get("graph_name").and_then(|v| v.as_str()) {
446        return Ok(graph_name.to_string());
447    }
448
449    Err(CoreError::Validation(
450        "graph_completions_missing_graph_target".to_string(),
451    ))
452}
453
454/// Map HTTP errors to CoreError.
455fn map_http_error(e: HttpError) -> CoreError {
456    match e {
457        HttpError::Response(detail) => {
458            if detail.status == 401 || detail.status == 403 {
459                CoreError::Authentication(format!("authentication failed: {}", detail))
460            } else if detail.status == 429 {
461                CoreError::UsageLimit(crate::UsageLimitInfo::from_http_429("graphs", &detail))
462            } else {
463                CoreError::HttpResponse(crate::HttpErrorInfo {
464                    status: detail.status,
465                    url: detail.url,
466                    message: detail.message,
467                    body_snippet: detail.body_snippet,
468                })
469            }
470        }
471        HttpError::Request(e) => CoreError::Http(e),
472        _ => CoreError::Internal(format!("{}", e)),
473    }
474}
475
476#[cfg(test)]
477mod tests {
478    use super::*;
479
    // Pin the completions endpoint path used by every POST in this module.
    #[test]
    fn test_graphs_endpoint() {
        assert_eq!(GRAPHS_ENDPOINT, "/api/graphs/completions");
    }
484
    // Pin the verifier graph IDs; these are wire-visible values the backend
    // must recognize, so any change here is a breaking change.
    #[test]
    fn test_verifier_constants() {
        assert_eq!(DEFAULT_VERIFIER, "zero_shot_verifier_rubric_single");
        assert_eq!(RLM_VERIFIER_V1, "zero_shot_verifier_rubric_rlm");
        assert_eq!(RLM_VERIFIER_V2, "zero_shot_verifier_rubric_rlm_v2");
    }
491}