synth_ai_core/api/
graphs.rs1use serde_json::{json, Value};
6
7use crate::http::HttpError;
8use crate::CoreError;
9
10use super::client::SynthClient;
11use super::types::{
12 GraphCompletionRequest, GraphCompletionResponse, RlmOptions, VerifierOptions, VerifierResponse,
13};
14
// Graph ids understood by the completions endpoint.

/// Graph id used by [`GraphsClient::verify`] when `VerifierOptions::verifier_id` is not set.
pub const DEFAULT_VERIFIER: &str = "zero_shot_verifier_rubric_single";

/// Graph id used by [`GraphsClient::rlm_inference`] when `RlmOptions::rlm_id` is not set.
pub const RLM_VERIFIER_V1: &str = "zero_shot_verifier_rubric_rlm";

/// Graph id of the v2 RLM rubric verifier.
///
/// NOTE(review): not used as a default anywhere in this module; callers must pass it explicitly.
pub const RLM_VERIFIER_V2: &str = "zero_shot_verifier_rubric_rlm_v2";

/// Route that every request issued by [`GraphsClient`] is posted to.
const GRAPHS_ENDPOINT: &str = "/api/graphs/completions";
26
27pub struct GraphsClient<'a> {
31 client: &'a SynthClient,
32}
33
34impl<'a> GraphsClient<'a> {
35 pub(crate) fn new(client: &'a SynthClient) -> Self {
37 Self { client }
38 }
39
40 pub async fn complete(
61 &self,
62 request: GraphCompletionRequest,
63 ) -> Result<GraphCompletionResponse, CoreError> {
64 let body = serde_json::to_value(&request)
65 .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
66
67 self.client
68 .http
69 .post_json(GRAPHS_ENDPOINT, &body)
70 .await
71 .map_err(map_http_error)
72 }
73
74 pub async fn complete_raw(&self, request: Value) -> Result<Value, CoreError> {
76 self.client
77 .http
78 .post_json(GRAPHS_ENDPOINT, &request)
79 .await
80 .map_err(map_http_error)
81 }
82
83 pub async fn verify(
117 &self,
118 trace: Value,
119 rubric: Value,
120 options: Option<VerifierOptions>,
121 ) -> Result<VerifierResponse, CoreError> {
122 let options = options.unwrap_or_default();
123 let verifier_id = options
124 .verifier_id
125 .as_deref()
126 .unwrap_or(DEFAULT_VERIFIER);
127
128 let mut input = json!({
129 "trace": trace,
130 "rubric": rubric,
131 });
132
133 if let Some(model) = &options.model {
134 input["model"] = json!(model);
135 }
136
137 let request = GraphCompletionRequest {
138 job_id: verifier_id.to_string(),
139 input,
140 model: options.model.clone(),
141 stream: Some(false),
142 };
143
144 let body = serde_json::to_value(&request)
145 .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
146
147 let response: GraphCompletionResponse = self
148 .client
149 .http
150 .post_json(GRAPHS_ENDPOINT, &body)
151 .await
152 .map_err(map_http_error)?;
153
154 serde_json::from_value(response.output.clone()).map_err(|e| {
156 CoreError::Validation(format!(
157 "failed to parse verifier response: {} (output: {:?})",
158 e, response.output
159 ))
160 })
161 }
162
163 pub async fn rlm_inference(
178 &self,
179 query: &str,
180 context: Value,
181 options: Option<RlmOptions>,
182 ) -> Result<Value, CoreError> {
183 let options = options.unwrap_or_default();
184 let rlm_id = options.rlm_id.as_deref().unwrap_or(RLM_VERIFIER_V1);
185
186 let mut input = json!({
187 "query": query,
188 "context": context,
189 });
190
191 if let Some(max_tokens) = options.max_context_tokens {
192 input["max_context_tokens"] = json!(max_tokens);
193 }
194
195 let request = GraphCompletionRequest {
196 job_id: rlm_id.to_string(),
197 input,
198 model: options.model,
199 stream: Some(false),
200 };
201
202 let body = serde_json::to_value(&request)
203 .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
204
205 let response: GraphCompletionResponse = self
206 .client
207 .http
208 .post_json(GRAPHS_ENDPOINT, &body)
209 .await
210 .map_err(map_http_error)?;
211
212 Ok(response.output)
213 }
214
215 pub async fn policy_inference(
230 &self,
231 job_id: &str,
232 input: Value,
233 model: Option<&str>,
234 ) -> Result<Value, CoreError> {
235 let request = GraphCompletionRequest {
236 job_id: job_id.to_string(),
237 input,
238 model: model.map(|s| s.to_string()),
239 stream: Some(false),
240 };
241
242 let body = serde_json::to_value(&request)
243 .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
244
245 let response: GraphCompletionResponse = self
246 .client
247 .http
248 .post_json(GRAPHS_ENDPOINT, &body)
249 .await
250 .map_err(map_http_error)?;
251
252 Ok(response.output)
253 }
254}
255
256fn map_http_error(e: HttpError) -> CoreError {
258 match e {
259 HttpError::Response(detail) => {
260 if detail.status == 401 || detail.status == 403 {
261 CoreError::Authentication(format!("authentication failed: {}", detail))
262 } else if detail.status == 429 {
263 CoreError::UsageLimit(crate::UsageLimitInfo {
264 limit_type: "rate_limit".to_string(),
265 api: "graphs".to_string(),
266 current: 0.0,
267 limit: 0.0,
268 tier: "unknown".to_string(),
269 retry_after_seconds: None,
270 upgrade_url: "https://usesynth.ai/pricing".to_string(),
271 })
272 } else {
273 CoreError::HttpResponse(crate::HttpErrorInfo {
274 status: detail.status,
275 url: detail.url,
276 message: detail.message,
277 body_snippet: detail.body_snippet,
278 })
279 }
280 }
281 HttpError::Request(e) => CoreError::Http(e),
282 _ => CoreError::Internal(format!("{}", e)),
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Guards against accidental edits to the completions route.
    #[test]
    fn test_graphs_endpoint() {
        assert_eq!(GRAPHS_ENDPOINT, "/api/graphs/completions");
    }

    /// Pins each published graph id to its expected string value.
    #[test]
    fn test_verifier_constants() {
        let cases = [
            (DEFAULT_VERIFIER, "zero_shot_verifier_rubric_single"),
            (RLM_VERIFIER_V1, "zero_shot_verifier_rubric_rlm"),
            (RLM_VERIFIER_V2, "zero_shot_verifier_rubric_rlm_v2"),
        ];

        for (actual, expected) in cases {
            assert_eq!(actual, expected);
        }
    }
}