// Source: synth_ai_core/api/graphs.rs
1//! Graphs API client.
2//!
3//! This module provides methods for graph completions and verifier inference.
4
5use serde_json::{json, Value};
6
7use crate::http::HttpError;
8use crate::CoreError;
9
10use super::client::SynthClient;
11use super::types::{
12    GraphCompletionRequest, GraphCompletionResponse, RlmOptions, VerifierOptions, VerifierResponse,
13};
14
/// API endpoint for graph completions.
///
/// Every operation in this module (completions, verification, RLM and
/// policy inference) POSTs to this single route.
const GRAPHS_ENDPOINT: &str = "/api/graphs/completions";

/// Default verifier graph ID, used by [`GraphsClient::verify`] when no
/// `verifier_id` is supplied in the options.
pub const DEFAULT_VERIFIER: &str = "zero_shot_verifier_rubric_single";

/// RLM v1 verifier for large contexts; the default graph for
/// [`GraphsClient::rlm_inference`] when no `rlm_id` is supplied.
pub const RLM_VERIFIER_V1: &str = "zero_shot_verifier_rubric_rlm";

/// RLM v2 verifier (multi-agent). Not referenced in this module; callers
/// may pass it explicitly via `RlmOptions::rlm_id`.
pub const RLM_VERIFIER_V2: &str = "zero_shot_verifier_rubric_rlm_v2";
26
/// Graphs API client.
///
/// Use this for graph completions and verifier inference. Borrows the
/// owning [`SynthClient`] for the lifetime `'a`; obtained via
/// `client.graphs()` (see the method examples below).
pub struct GraphsClient<'a> {
    // Shared transport/config; all requests go through `client.http`.
    client: &'a SynthClient,
}
33
34impl<'a> GraphsClient<'a> {
35    /// Create a new Graphs client.
36    pub(crate) fn new(client: &'a SynthClient) -> Self {
37        Self { client }
38    }
39
40    /// Execute a graph completion.
41    ///
42    /// # Arguments
43    ///
44    /// * `request` - The graph completion request
45    ///
46    /// # Returns
47    ///
48    /// The graph output.
49    ///
50    /// # Example
51    ///
52    /// ```ignore
53    /// let response = client.graphs().complete(GraphCompletionRequest {
54    ///     job_id: "my-graph".into(),
55    ///     input: json!({"prompt": "Hello"}),
56    ///     model: None,
57    ///     stream: None,
58    /// }).await?;
59    /// ```
60    pub async fn complete(
61        &self,
62        request: GraphCompletionRequest,
63    ) -> Result<GraphCompletionResponse, CoreError> {
64        let body = serde_json::to_value(&request)
65            .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
66
67        self.client
68            .http
69            .post_json(GRAPHS_ENDPOINT, &body)
70            .await
71            .map_err(map_http_error)
72    }
73
74    /// Execute a raw graph completion from a JSON value.
75    pub async fn complete_raw(&self, request: Value) -> Result<Value, CoreError> {
76        self.client
77            .http
78            .post_json(GRAPHS_ENDPOINT, &request)
79            .await
80            .map_err(map_http_error)
81    }
82
83    /// Run verifier inference on a trace.
84    ///
85    /// This evaluates a trace against a rubric using the verifier graph.
86    ///
87    /// # Arguments
88    ///
89    /// * `trace` - The trace to verify (JSON object with events)
90    /// * `rubric` - The rubric to evaluate against
91    /// * `options` - Optional verifier configuration
92    ///
93    /// # Returns
94    ///
95    /// The verification result with scores and reviews.
96    ///
97    /// # Example
98    ///
99    /// ```ignore
100    /// let result = client.graphs().verify(
101    ///     json!({
102    ///         "events": [
103    ///             {"type": "user_message", "content": "Hello"},
104    ///             {"type": "assistant_message", "content": "Hi there!"}
105    ///         ]
106    ///     }),
107    ///     json!({
108    ///         "objectives": [
109    ///             {"name": "helpfulness", "description": "Be helpful"}
110    ///         ]
111    ///     }),
112    ///     None,
113    /// ).await?;
114    /// println!("Objectives: {:?}", result.objectives);
115    /// ```
116    pub async fn verify(
117        &self,
118        trace: Value,
119        rubric: Value,
120        options: Option<VerifierOptions>,
121    ) -> Result<VerifierResponse, CoreError> {
122        let options = options.unwrap_or_default();
123        let verifier_id = options
124            .verifier_id
125            .as_deref()
126            .unwrap_or(DEFAULT_VERIFIER);
127
128        let mut input = json!({
129            "trace": trace,
130            "rubric": rubric,
131        });
132
133        if let Some(model) = &options.model {
134            input["model"] = json!(model);
135        }
136
137        let request = GraphCompletionRequest {
138            job_id: verifier_id.to_string(),
139            input,
140            model: options.model.clone(),
141            stream: Some(false),
142        };
143
144        let body = serde_json::to_value(&request)
145            .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
146
147        let response: GraphCompletionResponse = self
148            .client
149            .http
150            .post_json(GRAPHS_ENDPOINT, &body)
151            .await
152            .map_err(map_http_error)?;
153
154        // Parse the output as VerifierResponse
155        serde_json::from_value(response.output.clone()).map_err(|e| {
156            CoreError::Validation(format!(
157                "failed to parse verifier response: {} (output: {:?})",
158                e, response.output
159            ))
160        })
161    }
162
163    /// Run RLM (Retrieval-augmented LM) inference.
164    ///
165    /// This is useful for large context scenarios where the full trace
166    /// doesn't fit in a single context window.
167    ///
168    /// # Arguments
169    ///
170    /// * `query` - The query/question to answer
171    /// * `context` - The context to search through
172    /// * `options` - Optional RLM configuration
173    ///
174    /// # Returns
175    ///
176    /// The RLM output as a JSON value.
177    pub async fn rlm_inference(
178        &self,
179        query: &str,
180        context: Value,
181        options: Option<RlmOptions>,
182    ) -> Result<Value, CoreError> {
183        let options = options.unwrap_or_default();
184        let rlm_id = options.rlm_id.as_deref().unwrap_or(RLM_VERIFIER_V1);
185
186        let mut input = json!({
187            "query": query,
188            "context": context,
189        });
190
191        if let Some(max_tokens) = options.max_context_tokens {
192            input["max_context_tokens"] = json!(max_tokens);
193        }
194
195        let request = GraphCompletionRequest {
196            job_id: rlm_id.to_string(),
197            input,
198            model: options.model,
199            stream: Some(false),
200        };
201
202        let body = serde_json::to_value(&request)
203            .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
204
205        let response: GraphCompletionResponse = self
206            .client
207            .http
208            .post_json(GRAPHS_ENDPOINT, &body)
209            .await
210            .map_err(map_http_error)?;
211
212        Ok(response.output)
213    }
214
215    /// Execute a policy/prompt from a job.
216    ///
217    /// This runs inference using a trained policy from a completed
218    /// optimization job.
219    ///
220    /// # Arguments
221    ///
222    /// * `job_id` - The optimization job ID
223    /// * `input` - The input to the policy
224    /// * `model` - Optional model override
225    ///
226    /// # Returns
227    ///
228    /// The policy output.
229    pub async fn policy_inference(
230        &self,
231        job_id: &str,
232        input: Value,
233        model: Option<&str>,
234    ) -> Result<Value, CoreError> {
235        let request = GraphCompletionRequest {
236            job_id: job_id.to_string(),
237            input,
238            model: model.map(|s| s.to_string()),
239            stream: Some(false),
240        };
241
242        let body = serde_json::to_value(&request)
243            .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
244
245        let response: GraphCompletionResponse = self
246            .client
247            .http
248            .post_json(GRAPHS_ENDPOINT, &body)
249            .await
250            .map_err(map_http_error)?;
251
252        Ok(response.output)
253    }
254}
255
256/// Map HTTP errors to CoreError.
257fn map_http_error(e: HttpError) -> CoreError {
258    match e {
259        HttpError::Response(detail) => {
260            if detail.status == 401 || detail.status == 403 {
261                CoreError::Authentication(format!("authentication failed: {}", detail))
262            } else if detail.status == 429 {
263                CoreError::UsageLimit(crate::UsageLimitInfo {
264                    limit_type: "rate_limit".to_string(),
265                    api: "graphs".to_string(),
266                    current: 0.0,
267                    limit: 0.0,
268                    tier: "unknown".to_string(),
269                    retry_after_seconds: None,
270                    upgrade_url: "https://usesynth.ai/pricing".to_string(),
271                })
272            } else {
273                CoreError::HttpResponse(crate::HttpErrorInfo {
274                    status: detail.status,
275                    url: detail.url,
276                    message: detail.message,
277                    body_snippet: detail.body_snippet,
278                })
279            }
280        }
281        HttpError::Request(e) => CoreError::Http(e),
282        _ => CoreError::Internal(format!("{}", e)),
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graphs_endpoint() {
        // The endpoint constant must match the backend route exactly.
        assert_eq!(GRAPHS_ENDPOINT, "/api/graphs/completions");
    }

    #[test]
    fn test_verifier_constants() {
        // Table-driven check that each verifier graph ID is pinned.
        let expected = [
            (DEFAULT_VERIFIER, "zero_shot_verifier_rubric_single"),
            (RLM_VERIFIER_V1, "zero_shot_verifier_rubric_rlm"),
            (RLM_VERIFIER_V2, "zero_shot_verifier_rubric_rlm_v2"),
        ];
        for (actual, want) in expected {
            assert_eq!(actual, want);
        }
    }
}
301}