1use serde_json::Map;
6use serde_json::{json, Value};
7
8use crate::http::HttpError;
9use crate::CoreError;
10
11use super::client::SynthClient;
12use super::types::{
13 GraphCompletionRequest, GraphCompletionResponse, RlmOptions, VerifierOptions, VerifierResponse,
14};
15
/// Endpoint path that every graph completion request is POSTed to.
const GRAPHS_ENDPOINT: &str = "/api/graphs/completions";

/// Graph job id of the single-shape rubric verifier: the fallback verifier id for
/// `GraphsClient::verify` and the graph chosen for the `"single"` verifier shape.
pub const DEFAULT_VERIFIER: &str = "zero_shot_verifier_rubric_single";

/// Graph job id of the first RLM rubric verifier: the default graph for the `"rlm"`
/// verifier shape and the fallback graph id for `GraphsClient::rlm_inference`.
pub const RLM_VERIFIER_V1: &str = "zero_shot_verifier_rubric_rlm";

/// Graph job id of the second RLM rubric verifier, selected when `rlm_impl` is `"v2"`.
pub const RLM_VERIFIER_V2: &str = "zero_shot_verifier_rubric_rlm_v2";
27
28pub struct GraphsClient<'a> {
32 client: &'a SynthClient,
33}
34
35impl<'a> GraphsClient<'a> {
36 pub(crate) fn new(client: &'a SynthClient) -> Self {
38 Self { client }
39 }
40
41 pub async fn complete(
62 &self,
63 request: GraphCompletionRequest,
64 ) -> Result<GraphCompletionResponse, CoreError> {
65 let body = serde_json::to_value(&request)
66 .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
67
68 self.client
69 .http
70 .post_json(GRAPHS_ENDPOINT, &body)
71 .await
72 .map_err(map_http_error)
73 }
74
75 pub async fn list_graphs(
77 &self,
78 kind: Option<&str>,
79 limit: Option<i32>,
80 ) -> Result<Value, CoreError> {
81 let mut params = Vec::new();
82 let limit_str;
83 if let Some(limit_val) = limit {
84 limit_str = limit_val.to_string();
85 params.push(("limit", limit_str.as_str()));
86 }
87
88 let kind_val;
89 if let Some(kind_val_raw) = kind {
90 kind_val = kind_val_raw.to_string();
91 params.push(("kind", kind_val.as_str()));
92 }
93
94 let params_ref: Option<&[(&str, &str)]> = if params.is_empty() {
95 None
96 } else {
97 Some(¶ms)
98 };
99
100 self.client
101 .http
102 .get_json("/graph-evolve/graphs", params_ref)
103 .await
104 .map_err(map_http_error)
105 }
106
107 pub async fn complete_raw(&self, request: Value) -> Result<Value, CoreError> {
109 self.client
110 .http
111 .post_json(GRAPHS_ENDPOINT, &request)
112 .await
113 .map_err(map_http_error)
114 }
115
116 pub async fn verify(
150 &self,
151 trace: Value,
152 rubric: Value,
153 options: Option<VerifierOptions>,
154 ) -> Result<VerifierResponse, CoreError> {
155 let options = options.unwrap_or_default();
156 let verifier_id = options.verifier_id.as_deref().unwrap_or(DEFAULT_VERIFIER);
157
158 let mut input = json!({
159 "trace": trace,
160 "rubric": rubric,
161 });
162
163 if let Some(model) = &options.model {
164 input["model"] = json!(model);
165 }
166
167 let request = GraphCompletionRequest {
168 job_id: verifier_id.to_string(),
169 input,
170 model: options.model.clone(),
171 prompt_snapshot_id: None,
172 stream: Some(false),
173 };
174
175 let body = serde_json::to_value(&request)
176 .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
177
178 let response: GraphCompletionResponse = self
179 .client
180 .http
181 .post_json(GRAPHS_ENDPOINT, &body)
182 .await
183 .map_err(map_http_error)?;
184
185 serde_json::from_value(response.output.clone()).map_err(|e| {
187 CoreError::Validation(format!(
188 "failed to parse verifier response: {} (output: {:?})",
189 e, response.output
190 ))
191 })
192 }
193
194 pub async fn rlm_inference(
209 &self,
210 query: &str,
211 context: Value,
212 options: Option<RlmOptions>,
213 ) -> Result<Value, CoreError> {
214 let options = options.unwrap_or_default();
215 let rlm_id = options.rlm_id.as_deref().unwrap_or(RLM_VERIFIER_V1);
216
217 let mut input = json!({
218 "query": query,
219 "context": context,
220 });
221
222 if let Some(max_tokens) = options.max_context_tokens {
223 input["max_context_tokens"] = json!(max_tokens);
224 }
225
226 let request = GraphCompletionRequest {
227 job_id: rlm_id.to_string(),
228 input,
229 model: options.model,
230 prompt_snapshot_id: None,
231 stream: Some(false),
232 };
233
234 let body = serde_json::to_value(&request)
235 .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
236
237 let response: GraphCompletionResponse = self
238 .client
239 .http
240 .post_json(GRAPHS_ENDPOINT, &body)
241 .await
242 .map_err(map_http_error)?;
243
244 Ok(response.output)
245 }
246
247 pub async fn policy_inference(
262 &self,
263 job_id: &str,
264 input: Value,
265 model: Option<&str>,
266 ) -> Result<Value, CoreError> {
267 let request = GraphCompletionRequest {
268 job_id: job_id.to_string(),
269 input,
270 model: model.map(|s| s.to_string()),
271 prompt_snapshot_id: None,
272 stream: Some(false),
273 };
274
275 let body = serde_json::to_value(&request)
276 .map_err(|e| CoreError::Validation(format!("failed to serialize request: {}", e)))?;
277
278 let response: GraphCompletionResponse = self
279 .client
280 .http
281 .post_json(GRAPHS_ENDPOINT, &body)
282 .await
283 .map_err(map_http_error)?;
284
285 Ok(response.output)
286 }
287}
288
289pub fn build_verifier_request(
296 trace_content: Option<Value>,
297 trace_ref: Option<String>,
298 rubric: Value,
299 system_prompt: Option<String>,
300 user_prompt: Option<String>,
301 options: Option<Value>,
302 model: Option<String>,
303 verifier_shape: Option<String>,
304 rlm_impl: Option<String>,
305) -> Result<GraphCompletionRequest, CoreError> {
306 let has_ref = trace_ref.as_ref().map(|s| !s.is_empty()).unwrap_or(false);
307 let has_content = trace_content.is_some();
308
309 if !has_ref && !has_content {
310 return Err(CoreError::Validation(
311 "trace_content or trace_ref is required".to_string(),
312 ));
313 }
314
315 let shape = match verifier_shape.as_deref() {
316 Some("single") => "single",
317 Some("rlm") => "rlm",
318 Some(other) => {
319 return Err(CoreError::Validation(format!(
320 "unsupported verifier_shape: {}",
321 other
322 )))
323 }
324 None => {
325 if has_ref {
326 "rlm"
327 } else {
328 let tokens = estimate_trace_tokens(trace_content.as_ref().unwrap())?;
329 if tokens < 50_000 {
330 "single"
331 } else {
332 "rlm"
333 }
334 }
335 }
336 };
337
338 let verifier_id = if shape == "single" {
339 if rlm_impl
340 .as_deref()
341 .map(|value| !value.is_empty())
342 .unwrap_or(false)
343 {
344 return Err(CoreError::Validation(
345 "rlm_impl is only valid when verifier_shape is 'rlm'".to_string(),
346 ));
347 }
348 DEFAULT_VERIFIER
349 } else if matches!(rlm_impl.as_deref(), Some("v2")) {
350 RLM_VERIFIER_V2
351 } else {
352 RLM_VERIFIER_V1
353 };
354
355 let mut input = Map::new();
356 input.insert("rubric".to_string(), rubric);
357 input.insert(
358 "options".to_string(),
359 options.unwrap_or_else(|| Value::Object(Map::new())),
360 );
361
362 if let Some(trace_ref) = trace_ref {
363 if !trace_ref.is_empty() {
364 input.insert("trace_ref".to_string(), Value::String(trace_ref));
365 }
366 }
367
368 if !input.contains_key("trace_ref") {
369 if let Some(trace_content) = trace_content {
370 input.insert("trace_content".to_string(), trace_content);
371 }
372 }
373
374 if let Some(system_prompt) = system_prompt {
375 input.insert("system_prompt".to_string(), Value::String(system_prompt));
376 }
377 if let Some(user_prompt) = user_prompt {
378 input.insert("user_prompt".to_string(), Value::String(user_prompt));
379 }
380
381 Ok(GraphCompletionRequest {
382 job_id: verifier_id.to_string(),
383 input: Value::Object(input),
384 model,
385 prompt_snapshot_id: None,
386 stream: Some(false),
387 })
388}
389
390fn estimate_trace_tokens(trace: &Value) -> Result<usize, CoreError> {
391 let payload = serde_json::to_string(trace)
392 .map_err(|e| CoreError::Validation(format!("failed to serialize trace: {}", e)))?;
393 Ok(payload.len() / 4)
394}
395
396pub fn resolve_graph_job_id(
400 job_id: Option<String>,
401 graph: Option<Value>,
402) -> Result<String, CoreError> {
403 if let Some(job_id) = job_id {
404 if !job_id.trim().is_empty() {
405 return Ok(job_id);
406 }
407 }
408
409 let graph = graph
410 .ok_or_else(|| CoreError::Validation("graph_completions_missing_job_id".to_string()))?;
411
412 let graph_obj = graph
413 .as_object()
414 .ok_or_else(|| CoreError::Validation("graph target must be an object".to_string()))?;
415
416 if let Some(Value::String(job_id)) = graph_obj.get("job_id") {
417 if !job_id.trim().is_empty() {
418 return Ok(job_id.clone());
419 }
420 }
421
422 let kind = graph_obj.get("kind").and_then(|v| v.as_str()).unwrap_or("");
423 if kind == "zero_shot" {
424 if let Some(shape) = graph_obj
425 .get("verifier_shape")
426 .and_then(|v| v.as_str())
427 .or_else(|| graph_obj.get("graph_name").and_then(|v| v.as_str()))
428 {
429 return Ok(shape.to_string());
430 }
431 return Err(CoreError::Validation(
432 "graph_completions_missing_verifier_shape".to_string(),
433 ));
434 }
435
436 if kind == "graphgen" {
437 if let Some(graphgen_job_id) = graph_obj.get("graphgen_job_id").and_then(|v| v.as_str()) {
438 return Ok(graphgen_job_id.to_string());
439 }
440 return Err(CoreError::Validation(
441 "graph_completions_missing_graphgen_job_id".to_string(),
442 ));
443 }
444
445 if let Some(graph_name) = graph_obj.get("graph_name").and_then(|v| v.as_str()) {
446 return Ok(graph_name.to_string());
447 }
448
449 Err(CoreError::Validation(
450 "graph_completions_missing_graph_target".to_string(),
451 ))
452}
453
454fn map_http_error(e: HttpError) -> CoreError {
456 match e {
457 HttpError::Response(detail) => {
458 if detail.status == 401 || detail.status == 403 {
459 CoreError::Authentication(format!("authentication failed: {}", detail))
460 } else if detail.status == 429 {
461 CoreError::UsageLimit(crate::UsageLimitInfo::from_http_429("graphs", &detail))
462 } else {
463 CoreError::HttpResponse(crate::HttpErrorInfo {
464 status: detail.status,
465 url: detail.url,
466 message: detail.message,
467 body_snippet: detail.body_snippet,
468 })
469 }
470 }
471 HttpError::Request(e) => CoreError::Http(e),
472 _ => CoreError::Internal(format!("{}", e)),
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls `build_verifier_request`, varying only the trace and shape arguments.
    fn build(
        trace_content: Option<Value>,
        trace_ref: Option<&str>,
        verifier_shape: Option<&str>,
        rlm_impl: Option<&str>,
    ) -> Result<GraphCompletionRequest, CoreError> {
        build_verifier_request(
            trace_content,
            trace_ref.map(str::to_string),
            json!({}),
            None,
            None,
            None,
            None,
            verifier_shape.map(str::to_string),
            rlm_impl.map(str::to_string),
        )
    }

    /// Unwraps a successful request without requiring `Debug` on the error type.
    fn ok_request(result: Result<GraphCompletionRequest, CoreError>) -> GraphCompletionRequest {
        result
            .ok()
            .expect("expected build_verifier_request to succeed")
    }

    #[test]
    fn test_graphs_endpoint() {
        assert_eq!(GRAPHS_ENDPOINT, "/api/graphs/completions");
    }

    #[test]
    fn test_verifier_constants() {
        assert_eq!(DEFAULT_VERIFIER, "zero_shot_verifier_rubric_single");
        assert_eq!(RLM_VERIFIER_V1, "zero_shot_verifier_rubric_rlm");
        assert_eq!(RLM_VERIFIER_V2, "zero_shot_verifier_rubric_rlm_v2");
    }

    #[test]
    fn test_estimate_trace_tokens_is_json_bytes_over_four() {
        // `{"a":1}` is 7 bytes; `"abcdefgh"` (with quotes) is 10 bytes.
        assert_eq!(estimate_trace_tokens(&json!({"a": 1})).ok(), Some(1));
        assert_eq!(estimate_trace_tokens(&json!("abcdefgh")).ok(), Some(2));
    }

    #[test]
    fn test_build_requires_a_trace() {
        assert!(matches!(
            build(None, None, None, None),
            Err(CoreError::Validation(_))
        ));
        // An empty trace_ref does not count as a trace.
        assert!(matches!(
            build(None, Some(""), None, None),
            Err(CoreError::Validation(_))
        ));
    }

    #[test]
    fn test_small_inline_trace_selects_single_verifier() {
        let req = ok_request(build(Some(json!({"events": []})), None, None, None));
        assert_eq!(req.job_id, DEFAULT_VERIFIER);
        assert!(req.input.get("trace_content").is_some());
        assert!(req.input.get("trace_ref").is_none());
        assert_eq!(req.input.get("options"), Some(&json!({})));
        assert_eq!(req.stream, Some(false));
    }

    #[test]
    fn test_empty_trace_ref_falls_back_to_content() {
        let req = ok_request(build(Some(json!({"x": 1})), Some(""), None, None));
        assert_eq!(req.job_id, DEFAULT_VERIFIER);
        assert!(req.input.get("trace_content").is_some());
        assert!(req.input.get("trace_ref").is_none());
    }

    #[test]
    fn test_large_inline_trace_selects_rlm_verifier() {
        let big = json!({ "blob": "a".repeat(200_100) });
        let req = ok_request(build(Some(big), None, None, None));
        assert_eq!(req.job_id, RLM_VERIFIER_V1);
    }

    #[test]
    fn test_trace_ref_selects_rlm_and_wins_over_content() {
        let req = ok_request(build(Some(json!({"x": 1})), Some("traces/abc"), None, None));
        assert_eq!(req.job_id, RLM_VERIFIER_V1);
        assert_eq!(req.input.get("trace_ref"), Some(&json!("traces/abc")));
        assert!(req.input.get("trace_content").is_none());
    }

    #[test]
    fn test_rlm_impl_v2_selects_v2_verifier() {
        let req = ok_request(build(Some(json!({})), None, Some("rlm"), Some("v2")));
        assert_eq!(req.job_id, RLM_VERIFIER_V2);
    }

    #[test]
    fn test_invalid_shape_combinations_are_rejected() {
        assert!(matches!(
            build(Some(json!({})), None, Some("single"), Some("v2")),
            Err(CoreError::Validation(_))
        ));
        assert!(matches!(
            build(Some(json!({})), None, Some("tree"), None),
            Err(CoreError::Validation(_))
        ));
    }

    #[test]
    fn test_prompts_options_and_model_are_forwarded() {
        let req = ok_request(build_verifier_request(
            Some(json!({})),
            None,
            json!({"criteria": []}),
            Some("sys".to_string()),
            Some("usr".to_string()),
            Some(json!({"k": 1})),
            Some("m".to_string()),
            None,
            None,
        ));
        assert_eq!(req.input.get("rubric"), Some(&json!({"criteria": []})));
        assert_eq!(req.input.get("system_prompt"), Some(&json!("sys")));
        assert_eq!(req.input.get("user_prompt"), Some(&json!("usr")));
        assert_eq!(req.input.get("options"), Some(&json!({"k": 1})));
        assert_eq!(req.model.as_deref(), Some("m"));
    }

    #[test]
    fn test_resolve_graph_job_id_precedence() {
        let graph = json!({"job_id": "from-graph", "graph_name": "name"});
        assert_eq!(
            resolve_graph_job_id(Some("explicit".to_string()), Some(graph.clone()))
                .ok()
                .as_deref(),
            Some("explicit")
        );
        assert_eq!(
            resolve_graph_job_id(Some("  ".to_string()), Some(graph))
                .ok()
                .as_deref(),
            Some("from-graph")
        );
        assert_eq!(
            resolve_graph_job_id(
                None,
                Some(json!({"kind": "zero_shot", "verifier_shape": "rlm"}))
            )
            .ok()
            .as_deref(),
            Some("rlm")
        );
        assert_eq!(
            resolve_graph_job_id(
                None,
                Some(json!({"kind": "graphgen", "graphgen_job_id": "gg-1"}))
            )
            .ok()
            .as_deref(),
            Some("gg-1")
        );
        assert_eq!(
            resolve_graph_job_id(None, Some(json!({"graph_name": "custom"})))
                .ok()
                .as_deref(),
            Some("custom")
        );
    }

    #[test]
    fn test_resolve_graph_job_id_errors() {
        assert!(matches!(
            resolve_graph_job_id(None, None),
            Err(CoreError::Validation(_))
        ));
        assert!(matches!(
            resolve_graph_job_id(None, Some(json!("not-an-object"))),
            Err(CoreError::Validation(_))
        ));
        assert!(matches!(
            resolve_graph_job_id(None, Some(json!({"kind": "zero_shot"}))),
            Err(CoreError::Validation(_))
        ));
        assert!(matches!(
            resolve_graph_job_id(None, Some(json!({"kind": "graphgen"}))),
            Err(CoreError::Validation(_))
        ));
        assert!(matches!(
            resolve_graph_job_id(None, Some(json!({}))),
            Err(CoreError::Validation(_))
        ));
    }
}