stygian_graph/adapters/ai/
claude.rs1use std::time::Duration;
27
28use async_trait::async_trait;
29use futures::stream::{self, BoxStream};
30use reqwest::Client;
31use serde_json::{Value, json};
32
33use crate::domain::error::{ProviderError, Result, StygianError};
34use crate::ports::{AIProvider, ProviderCapabilities};
35
/// Model identifier used by [`ClaudeConfig::new`] unless overridden with `with_model`.
const DEFAULT_MODEL: &str = "claude-sonnet-4-5";

/// Anthropic Messages API endpoint that extraction requests are POSTed to.
const API_URL: &str = "https://api.anthropic.com/v1/messages";

/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
44
/// Configuration for the Claude provider.
///
/// `Debug` is implemented by hand so that the API key never ends up in logs
/// or panic messages; every other field is printed as usual.
#[derive(Clone)]
pub struct ClaudeConfig {
    /// Secret API key, sent in the `x-api-key` request header.
    pub api_key: String,
    /// Model identifier placed in the request body's `model` field.
    pub model: String,
    /// Value placed in the request body's `max_tokens` field.
    pub max_tokens: u32,
    /// Timeout applied to the HTTP client built from this configuration.
    pub timeout: Duration,
}

impl std::fmt::Debug for ClaudeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClaudeConfig")
            // Never print the secret itself; only show that the field exists.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}
57
58impl ClaudeConfig {
59 #[must_use]
61 pub fn new(api_key: String) -> Self {
62 Self {
63 api_key,
64 model: DEFAULT_MODEL.to_string(),
65 max_tokens: 4096,
66 timeout: Duration::from_mins(2),
67 }
68 }
69
70 #[must_use]
72 pub fn with_model(mut self, model: impl Into<String>) -> Self {
73 self.model = model.into();
74 self
75 }
76
77 #[must_use]
79 pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
80 self.max_tokens = max_tokens;
81 self
82 }
83}
84
85pub struct ClaudeProvider {
90 config: ClaudeConfig,
91 client: Client,
92}
93
94impl ClaudeProvider {
95 #[must_use]
105 pub fn new(api_key: String) -> Self {
106 let config = ClaudeConfig::new(api_key);
107 Self::with_config(config)
108 }
109
110 #[must_use]
128 pub fn with_config(config: ClaudeConfig) -> Self {
129 #[allow(clippy::expect_used)]
131 let client = Client::builder()
132 .timeout(config.timeout)
133 .build()
134 .expect("Failed to build HTTP client");
135 Self { config, client }
136 }
137
138 fn build_extract_body(&self, content: &str, schema: &Value) -> Value {
143 let system = "You are a precise data extraction assistant. \
144 Extract the requested information from the provided content and \
145 return it using the extract_data tool. \
146 Always extract exactly what the schema requests — nothing more, nothing less.";
147
148 let tool = json!({
149 "name": "extract_data",
150 "description": "Extract structured data from the provided content according to the schema.",
151 "input_schema": schema
152 });
153
154 json!({
155 "model": self.config.model,
156 "max_tokens": self.config.max_tokens,
157 "system": system,
158 "tools": [tool],
159 "tool_choice": {"type": "tool", "name": "extract_data"},
160 "messages": [
161 {
162 "role": "user",
163 "content": format!("Extract data from the following content:\n\n{content}")
164 }
165 ]
166 })
167 }
168
169 #[allow(dead_code, clippy::indexing_slicing)]
171 fn build_stream_body(&self, content: &str, schema: &Value) -> Value {
172 let mut body = self.build_extract_body(content, schema);
173 body["stream"] = json!(true);
174 body
175 }
176
177 fn parse_extract_response(response: &Value) -> Result<Value> {
179 let content = response
181 .get("content")
182 .and_then(Value::as_array)
183 .ok_or_else(|| {
184 StygianError::Provider(ProviderError::ApiError(
185 "No content in Claude response".to_string(),
186 ))
187 })?;
188
189 for block in content {
190 if block.get("type").and_then(Value::as_str) == Some("tool_use")
191 && let Some(input) = block.get("input")
192 {
193 return Ok(input.clone());
194 }
195 }
196
197 Err(StygianError::Provider(ProviderError::ApiError(
198 "Claude response contained no tool_use block".to_string(),
199 )))
200 }
201
202 fn map_http_error(status: u16, body: &str) -> StygianError {
204 match status {
205 401 => StygianError::Provider(ProviderError::InvalidCredentials),
206 429 => StygianError::Provider(ProviderError::ApiError(format!(
207 "Rate limited by Anthropic API: {body}"
208 ))),
209 400 => {
210 if body.contains("token") {
211 StygianError::Provider(ProviderError::TokenLimitExceeded(body.to_string()))
212 } else if body.contains("policy") {
213 StygianError::Provider(ProviderError::ContentPolicyViolation(body.to_string()))
214 } else {
215 StygianError::Provider(ProviderError::ApiError(body.to_string()))
216 }
217 }
218 _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
219 }
220 }
221}
222
223#[async_trait]
224impl AIProvider for ClaudeProvider {
225 async fn extract(&self, content: String, schema: Value) -> Result<Value> {
245 let body = self.build_extract_body(&content, &schema);
246
247 let response = self
248 .client
249 .post(API_URL)
250 .header("x-api-key", &self.config.api_key)
251 .header("anthropic-version", ANTHROPIC_VERSION)
252 .header("content-type", "application/json")
253 .json(&body)
254 .send()
255 .await
256 .map_err(|e| {
257 StygianError::Provider(ProviderError::ApiError(format!(
258 "Request to Anthropic API failed: {e}"
259 )))
260 })?;
261
262 let status = response.status().as_u16();
263 let text = response.text().await.map_err(|e| {
264 StygianError::Provider(ProviderError::ApiError(format!(
265 "Failed to read Anthropic response body: {e}"
266 )))
267 })?;
268
269 if status != 200 {
270 return Err(Self::map_http_error(status, &text));
271 }
272
273 let json_value: Value = serde_json::from_str(&text).map_err(|e| {
274 StygianError::Provider(ProviderError::ApiError(format!(
275 "Failed to parse Anthropic response JSON: {e}"
276 )))
277 })?;
278
279 Self::parse_extract_response(&json_value)
280 }
281
282 async fn stream_extract(
302 &self,
303 content: String,
304 schema: Value,
305 ) -> Result<BoxStream<'static, Result<Value>>> {
306 let result = self.extract(content, schema).await;
311 let stream = stream::once(async move { result });
312 Ok(Box::pin(stream))
313 }
314
315 fn capabilities(&self) -> ProviderCapabilities {
316 ProviderCapabilities {
317 streaming: true,
318 vision: true,
319 tool_use: true,
320 json_mode: true,
321 }
322 }
323
324 fn name(&self) -> &'static str {
325 "claude"
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_provider_name() {
        let p = ClaudeProvider::new("key".to_string());
        assert_eq!(p.name(), "claude");
    }

    #[test]
    fn test_capabilities() {
        let p = ClaudeProvider::new("key".to_string());
        let caps = p.capabilities();
        assert!(caps.streaming);
        assert!(caps.vision);
        assert!(caps.tool_use);
        assert!(caps.json_mode);
    }

    #[test]
    fn test_build_extract_body_contains_tool() -> std::result::Result<(), Box<dyn std::error::Error>>
    {
        let p = ClaudeProvider::new("key".to_string());
        let schema = json!({"type": "object"});
        let body = p.build_extract_body("some content", &schema);

        assert_eq!(
            body.get("model").and_then(Value::as_str),
            Some(DEFAULT_MODEL)
        );
        let tools = body
            .get("tools")
            .and_then(Value::as_array)
            .ok_or("no tools field")?;
        assert_eq!(tools.len(), 1);
        assert_eq!(
            tools
                .first()
                .and_then(|t| t.get("name"))
                .and_then(Value::as_str),
            Some("extract_data")
        );
        assert_eq!(
            body.get("tool_choice")
                .and_then(|tc| tc.get("name"))
                .and_then(Value::as_str),
            Some("extract_data")
        );
        Ok(())
    }

    #[test]
    fn test_build_stream_body_sets_stream_flag() {
        let p = ClaudeProvider::new("key".to_string());
        let schema = json!({"type": "object"});
        let body = p.build_stream_body("some content", &schema);

        assert_eq!(body.get("stream"), Some(&json!(true)));
        // Everything else is inherited from the non-streaming body.
        assert_eq!(
            body.get("model").and_then(Value::as_str),
            Some(DEFAULT_MODEL)
        );
    }

    #[test]
    fn test_parse_extract_response_success() -> Result<()> {
        let response = json!({
            "content": [
                {"type": "tool_use", "name": "extract_data", "input": {"title": "Hello"}}
            ]
        });
        let result = ClaudeProvider::parse_extract_response(&response)?;
        assert_eq!(result.get("title").and_then(Value::as_str), Some("Hello"));
        Ok(())
    }

    #[test]
    fn test_parse_extract_response_skips_non_tool_blocks() -> Result<()> {
        // A leading text block must not hide the tool call that follows it.
        let response = json!({
            "content": [
                {"type": "text", "text": "thinking out loud"},
                {"type": "tool_use", "name": "extract_data", "input": {"title": "Hello"}}
            ]
        });
        let result = ClaudeProvider::parse_extract_response(&response)?;
        assert_eq!(result.get("title").and_then(Value::as_str), Some("Hello"));
        Ok(())
    }

    #[test]
    fn test_parse_extract_response_no_tool_use() {
        let response = json!({
            "content": [{"type": "text", "text": "some text"}]
        });
        let err_result = ClaudeProvider::parse_extract_response(&response);
        assert!(err_result.is_err(), "expected Err but got Ok");
        if let Err(e) = err_result {
            assert!(e.to_string().contains("tool_use"));
        }
    }

    #[test]
    fn test_parse_extract_response_no_content() {
        let response = json!({"stop_reason": "end_turn"});
        let err_result = ClaudeProvider::parse_extract_response(&response);
        assert!(err_result.is_err(), "expected Err but got Ok");
        if let Err(e) = err_result {
            assert!(e.to_string().contains("content") || e.to_string().contains("API error"));
        }
    }

    #[test]
    fn test_map_http_error_401() {
        let e = ClaudeProvider::map_http_error(401, "unauthorized");
        assert!(matches!(
            e,
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_429() {
        let e = ClaudeProvider::map_http_error(429, "rate limited");
        assert!(e.to_string().contains("Rate limited"));
    }

    #[test]
    fn test_map_http_error_400_token_limit() {
        let e = ClaudeProvider::map_http_error(400, "prompt is too long: too many tokens");
        assert!(matches!(
            e,
            StygianError::Provider(ProviderError::TokenLimitExceeded(_))
        ));
    }

    #[test]
    fn test_map_http_error_400_content_policy() {
        let e = ClaudeProvider::map_http_error(400, "request violates usage policy");
        assert!(matches!(
            e,
            StygianError::Provider(ProviderError::ContentPolicyViolation(_))
        ));
    }

    #[test]
    fn test_map_http_error_400_other() {
        let e = ClaudeProvider::map_http_error(400, "bad request");
        assert!(matches!(
            e,
            StygianError::Provider(ProviderError::ApiError(_))
        ));
    }

    #[test]
    fn test_map_http_error_other_status() {
        let e = ClaudeProvider::map_http_error(500, "boom");
        assert!(matches!(
            e,
            StygianError::Provider(ProviderError::ApiError(msg)) if msg == "HTTP 500: boom"
        ));
    }

    #[test]
    fn test_config_builder() {
        let config = ClaudeConfig::new("key".to_string())
            .with_model("claude-3-5-sonnet-20241022")
            .with_max_tokens(2048);
        assert_eq!(config.model, "claude-3-5-sonnet-20241022");
        assert_eq!(config.max_tokens, 2048);
    }

    // NOTE: `stream_extract` awaits `extract` before returning, so this test
    // attempts a real HTTP request with an invalid key. Whatever the outcome,
    // it is delivered as the stream's single item, which is all that is
    // asserted here.
    #[tokio::test]
    async fn test_stream_extract_returns_stream() {
        use futures::StreamExt;
        let p = ClaudeProvider::new("invalid-key".to_string());
        let schema = json!({"type": "object"});
        let result = p.stream_extract("content".to_string(), schema).await;
        assert!(result.is_ok(), "stream_extract should return Ok(stream)");
        if let Ok(mut s) = result {
            let item = s.next().await;
            assert!(item.is_some());
        }
    }
}