Skip to main content

rustic_ai/providers/
anthropic.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use base64::{Engine as _, engine::general_purpose};
5use reqwest::{Client, Url};
6use serde::Deserialize;
7use serde_json::{Map, Value, json};
8use uuid::Uuid;
9
10use crate::messages::{
11    ModelMessage, ModelRequestPart, ModelResponse, ModelResponsePart, TextPart, ToolCallPart,
12    UserContent,
13};
14use crate::model::{Model, ModelError, ModelRequestParameters, ModelSettings, OutputMode};
15use crate::providers::{Provider, ProviderError};
16use crate::usage::RequestUsage;
17
18fn map_reqwest_error(label: &str, error: reqwest::Error) -> ModelError {
19    if error.is_timeout() {
20        return ModelError::Timeout;
21    }
22    if error.is_connect() {
23        return ModelError::Transport(format!("{label} connect error: {error}"));
24    }
25    ModelError::Transport(format!("{label} request failed: {error}"))
26}
27
/// Shortens an HTTP error body so it can be logged safely.
///
/// Bodies of at most 512 bytes are returned unchanged. Longer bodies are cut to at most
/// 512 bytes and suffixed with `... (<total> bytes)`. The cut is moved back to the nearest
/// UTF-8 character boundary so a multi-byte character is never split.
fn truncate_error_body(body: &str) -> String {
    const LIMIT: usize = 512;
    if body.len() <= LIMIT {
        return body.to_string();
    }
    // Slicing a `str` at a non-boundary byte index panics, and error bodies come from the
    // network, so they may contain arbitrary UTF-8. Index 0 is always a boundary, so this
    // loop terminates.
    let mut end = LIMIT;
    while !body.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}... ({} bytes)", &body[..end], body.len())
}
36
37fn normalize_tool_call_id(id: Option<String>) -> String {
38    match id {
39        Some(value) if !value.trim().is_empty() => value,
40        _ => format!("call_{}", Uuid::new_v4().simple()),
41    }
42}
43
44fn normalize_tool_call_id_str(id: &str) -> String {
45    if id.trim().is_empty() {
46        format!("call_{}", Uuid::new_v4().simple())
47    } else {
48        id.to_string()
49    }
50}
51
52fn tool_return_content(value: &Value) -> String {
53    match value {
54        Value::String(value) => value.clone(),
55        _ => serde_json::to_string(value).unwrap_or_else(|_| value.to_string()),
56    }
57}
58
/// Reports whether a MIME type denotes textual data that can be forwarded as a `text` block.
///
/// Parameters such as `; charset=utf-8` are ignored, and the comparison is ASCII
/// case-insensitive, because MIME type and subtype names are case-insensitive (RFC 2045).
/// The following are accepted:
/// - any `text/*` type;
/// - the listed `application/*` types;
/// - any type using the `+json` or `+xml` structured-syntax suffix (RFC 6839).
fn is_text_like_media_type(media_type: &str) -> bool {
    // Reduce e.g. "Application/JSON; charset=utf-8" to its essence "application/json".
    let essence = media_type
        .split(';')
        .next()
        .unwrap_or("")
        .trim()
        .to_ascii_lowercase();
    essence.starts_with("text/")
        || essence.ends_with("+json")
        || essence.ends_with("+xml")
        || matches!(
            essence.as_str(),
            "application/json"
                | "application/xml"
                | "application/xhtml+xml"
                | "application/javascript"
                | "application/x-www-form-urlencoded"
        )
}
70
/// Reports whether a URL's path ends in `.pdf`, ignoring ASCII case.
///
/// Both the query string (`?…`) and the fragment (`#…`) are stripped first. Links such as
/// `doc.pdf?download=1` and `doc.pdf#page=2` are therefore still recognised.
fn is_pdf_url(url: &str) -> bool {
    let path = url.split(['?', '#']).next().unwrap_or("");
    // Compare raw bytes: this needs no lowercase copy, and a multi-byte character near the end
    // of the path cannot trigger a char-boundary panic.
    let bytes = path.as_bytes();
    bytes.len() >= 4 && bytes[bytes.len() - 4..].eq_ignore_ascii_case(b".pdf")
}
76
77#[derive(Clone, Debug)]
78pub struct AnthropicProvider {
79    api_key: String,
80    base_url: Url,
81}
82
83impl AnthropicProvider {
84    pub fn new(
85        api_key: impl Into<String>,
86        base_url: impl AsRef<str>,
87    ) -> Result<Self, ProviderError> {
88        let url = Url::parse(base_url.as_ref())
89            .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
90        Ok(Self {
91            api_key: api_key.into(),
92            base_url: url,
93        })
94    }
95
96    pub fn from_env() -> Result<Self, ProviderError> {
97        let api_key = std::env::var("ANTHROPIC_API_KEY")
98            .map_err(|_| ProviderError::MissingApiKey("anthropic".to_string()))?;
99        Self::new(api_key, "https://api.anthropic.com")
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{
        BinaryContent, DocumentUrl, ModelMessage, ModelRequest, ModelRequestPart, ModelResponse,
        ModelResponsePart, ToolCallPart, ToolReturnPart,
    };
    use base64::engine::general_purpose::STANDARD;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    /// Reads `tests/fixtures/<name>` relative to the crate root, panicking if it cannot be read.
    fn fixture_bytes(name: &str) -> Vec<u8> {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join("fixtures")
            .join(name);
        std::fs::read(path).expect("fixture read")
    }

    /// Builds a JSON string value for comparison against `Value::get` results.
    fn string_value(text: &str) -> Value {
        Value::String(text.to_string())
    }

    /// Asserts that `block` is a `block_type` block whose `source` holds `data` as base64 with
    /// the given `media_type`.
    fn assert_base64_block(block: &Value, block_type: &str, media_type: &str, data: &[u8]) {
        assert_eq!(block.get("type"), Some(&string_value(block_type)));
        let source = block.get("source").expect("base64 source");
        assert_eq!(source.get("type"), Some(&string_value("base64")));
        assert_eq!(source.get("media_type"), Some(&string_value(media_type)));
        assert_eq!(
            source.get("data"),
            Some(&string_value(&STANDARD.encode(data)))
        );
    }

    /// Returns the first entry of `message["content"]` whose `type` equals `block_type`.
    fn find_block<'a>(message: &'a Value, block_type: &str) -> &'a Value {
        message
            .get("content")
            .and_then(Value::as_array)
            .expect("message content array")
            .iter()
            .find(|part| part.get("type") == Some(&string_value(block_type)))
            .expect("content block of requested type")
    }

    #[test]
    fn convert_user_content_handles_documents_and_images() {
        let image_bytes = fixture_bytes("fixture.jpg");
        let pdf_bytes = fixture_bytes("fixture.pdf");

        let content = vec![
            UserContent::Binary(BinaryContent {
                data: image_bytes.clone(),
                media_type: "image/jpeg".to_string(),
            }),
            UserContent::Binary(BinaryContent {
                data: pdf_bytes.clone(),
                media_type: "application/pdf".to_string(),
            }),
            UserContent::Document(DocumentUrl {
                url: "https://example.com/fixture.pdf".to_string(),
                media_type: None,
            }),
        ];

        let parts = convert_user_content(&content);
        assert_eq!(parts.len(), 3);

        assert_base64_block(&parts[0], "image", "image/jpeg", &image_bytes);
        assert_base64_block(&parts[1], "document", "application/pdf", &pdf_bytes);

        // A PDF link (no explicit media type) becomes a URL-sourced document block.
        let doc = &parts[2];
        assert_eq!(doc.get("type"), Some(&string_value("document")));
        let doc_source = doc.get("source").expect("doc source");
        assert_eq!(doc_source.get("type"), Some(&string_value("url")));
        assert_eq!(
            doc_source.get("url"),
            Some(&string_value("https://example.com/fixture.pdf"))
        );
    }

    #[test]
    fn split_system_replays_tool_calls() {
        let messages = vec![
            ModelMessage::Response(ModelResponse {
                parts: vec![ModelResponsePart::ToolCall(ToolCallPart {
                    id: "call-1".to_string(),
                    name: "get_data".to_string(),
                    arguments: json!({"a": 1}),
                })],
                usage: None,
                model_name: None,
                finish_reason: None,
            }),
            ModelMessage::Request(ModelRequest {
                parts: vec![ModelRequestPart::ToolReturn(ToolReturnPart {
                    tool_name: "get_data".to_string(),
                    tool_call_id: "call-1".to_string(),
                    content: json!({"ok": true}),
                })],
                instructions: None,
            }),
        ];

        let (_system, out) = AnthropicModel::split_system(&messages);
        assert_eq!(out.len(), 2);

        let assistant = &out[0];
        assert!(assistant.is_object(), "assistant message");
        assert_eq!(assistant.get("role"), Some(&string_value("assistant")));
        let tool_use = find_block(assistant, "tool_use");
        assert_eq!(tool_use.get("id"), Some(&string_value("call-1")));
        assert_eq!(tool_use.get("name"), Some(&string_value("get_data")));
        assert_eq!(tool_use.get("input"), Some(&json!({"a": 1})));

        let user = &out[1];
        assert!(user.is_object(), "tool result message");
        assert_eq!(user.get("role"), Some(&string_value("user")));
        let tool_result = find_block(user, "tool_result");
        assert_eq!(tool_result.get("tool_use_id"), Some(&string_value("call-1")));
        // Non-string tool output is serialized to compact JSON text.
        assert_eq!(
            tool_result.get("content"),
            Some(&string_value("{\"ok\":true}"))
        );
        assert_eq!(tool_result.get("is_error"), Some(&Value::Bool(false)));
    }

    #[test]
    fn helper_functions_cover_ids_and_media() {
        let id = normalize_tool_call_id(Some("".to_string()));
        assert!(id.starts_with("call_"));
        let id = normalize_tool_call_id_str("");
        assert!(id.starts_with("call_"));

        assert!(is_text_like_media_type("text/plain"));
        assert!(is_text_like_media_type("application/json"));
        assert!(!is_text_like_media_type("image/png"));

        assert_eq!(tool_return_content(&json!("ok")), "ok");
        assert_eq!(tool_return_content(&json!({"a": 1})), "{\"a\":1}");

        assert!(is_pdf_url("https://example.com/doc.pdf"));
        assert!(is_pdf_url("https://example.com/doc.pdf?x=1"));
        assert!(is_pdf_url("https://example.com/DOC.PDF"));
        assert!(!is_pdf_url("https://example.com/doc.txt"));
    }

    #[test]
    fn truncate_error_body_limits_length() {
        // Bodies at or below the 512-byte limit pass through untouched.
        assert_eq!(truncate_error_body("short body"), "short body");
        let at_limit = "b".repeat(512);
        assert_eq!(truncate_error_body(&at_limit), at_limit);

        // Longer bodies keep the first 512 bytes and report the original length.
        let truncated = truncate_error_body(&"a".repeat(600));
        assert_eq!(truncated, format!("{}... (600 bytes)", "a".repeat(512)));
    }
}
292
293impl Provider for AnthropicProvider {
294    fn name(&self) -> &str {
295        "anthropic"
296    }
297
298    fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model> {
299        Arc::new(AnthropicModel::new(
300            model,
301            self.api_key.clone(),
302            self.base_url.clone(),
303            settings,
304        ))
305    }
306}
307
308#[derive(Clone, Debug)]
309pub struct AnthropicModel {
310    model: String,
311    api_key: String,
312    base_url: Url,
313    client: Client,
314    default_settings: Option<ModelSettings>,
315}
316
317impl AnthropicModel {
318    pub fn new(
319        model: impl Into<String>,
320        api_key: String,
321        base_url: Url,
322        settings: Option<ModelSettings>,
323    ) -> Self {
324        Self {
325            model: model.into(),
326            api_key,
327            base_url,
328            client: Client::new(),
329            default_settings: settings,
330        }
331    }
332
333    fn endpoint(&self) -> Result<Url, ModelError> {
334        self.base_url
335            .join("v1/messages")
336            .map_err(|e| ModelError::Provider(format!("invalid base url: {e}")))
337    }
338
339    fn split_system(messages: &[ModelMessage]) -> (Option<String>, Vec<Value>) {
340        let mut system_parts = Vec::new();
341        let mut out = Vec::new();
342
343        for message in messages {
344            match message {
345                ModelMessage::Request(req) => {
346                    if let Some(instructions) = req
347                        .instructions
348                        .as_ref()
349                        .filter(|value| !value.trim().is_empty())
350                    {
351                        system_parts.push(instructions.to_string());
352                    }
353                    for part in &req.parts {
354                        match part {
355                            ModelRequestPart::SystemPrompt(prompt) => {
356                                system_parts.push(prompt.content.clone());
357                            }
358                            ModelRequestPart::UserPrompt(prompt) => {
359                                out.push(json!({
360                                    "role": "user",
361                                    "content": convert_user_content(&prompt.content)
362                                }));
363                            }
364                            ModelRequestPart::ToolReturn(tool_return) => {
365                                let content = tool_return_content(&tool_return.content);
366                                out.push(json!({
367                                    "role": "user",
368                                    "content": [{
369                                        "type": "tool_result",
370                                        "tool_use_id": normalize_tool_call_id_str(&tool_return.tool_call_id),
371                                        "content": content,
372                                        "is_error": false,
373                                    }]
374                                }));
375                            }
376                            ModelRequestPart::RetryPrompt(retry) => {
377                                if retry.tool_name.is_some() {
378                                    out.push(json!({
379                                        "role": "user",
380                                        "content": [{
381                                            "type": "tool_result",
382                                            "tool_use_id": normalize_tool_call_id(retry.tool_call_id.clone()),
383                                            "content": retry.content,
384                                            "is_error": true,
385                                        }]
386                                    }));
387                                } else {
388                                    out.push(json!({
389                                        "role": "user",
390                                        "content": [{"type": "text", "text": retry.content}]
391                                    }));
392                                }
393                            }
394                        }
395                    }
396                }
397                ModelMessage::Response(res) => {
398                    let mut content = Vec::new();
399                    if let Some(text) = res.text() {
400                        content.push(json!({"type": "text", "text": text}));
401                    }
402                    for call in res.tool_calls() {
403                        content.push(json!({
404                            "type": "tool_use",
405                            "id": normalize_tool_call_id_str(&call.id),
406                            "name": call.name,
407                            "input": call.arguments,
408                        }));
409                    }
410
411                    if !content.is_empty() {
412                        out.push(json!({
413                            "role": "assistant",
414                            "content": content
415                        }));
416                    }
417                }
418            }
419        }
420
421        let system = if system_parts.is_empty() {
422            None
423        } else {
424            Some(system_parts.join("\n\n"))
425        };
426
427        (system, out)
428    }
429}
430
431fn convert_user_content(content: &[UserContent]) -> Vec<Value> {
432    let mut parts = Vec::new();
433    for item in content {
434        match item {
435            UserContent::Text(text) => parts.push(json!({"type": "text", "text": text})),
436            UserContent::Image(image) => parts.push(json!({
437                "type": "image",
438                "source": {"type": "url", "url": image.url}
439            })),
440            UserContent::Binary(binary) => {
441                if binary.media_type.starts_with("image/") {
442                    let encoded = general_purpose::STANDARD.encode(&binary.data);
443                    parts.push(json!({
444                        "type": "image",
445                        "source": {
446                            "type": "base64",
447                            "media_type": binary.media_type,
448                            "data": encoded
449                        }
450                    }));
451                } else if is_text_like_media_type(&binary.media_type) {
452                    match std::str::from_utf8(&binary.data) {
453                        Ok(text) => parts.push(json!({"type": "text", "text": text})),
454                        Err(_) => parts.push(json!({
455                            "type": "text",
456                            "text": format!("[binary content: {} bytes]", binary.data.len())
457                        })),
458                    }
459                } else if binary.media_type == "application/pdf" {
460                    let encoded = general_purpose::STANDARD.encode(&binary.data);
461                    parts.push(json!({
462                        "type": "document",
463                        "source": {
464                            "type": "base64",
465                            "media_type": binary.media_type,
466                            "data": encoded
467                        }
468                    }));
469                } else {
470                    parts.push(json!({
471                        "type": "text",
472                        "text": format!("[binary content: {} bytes]", binary.data.len())
473                    }));
474                }
475            }
476            UserContent::Audio(audio) => parts.push(json!({
477                "type": "text",
478                "text": format!("[audio: {}]", audio.url)
479            })),
480            UserContent::Video(video) => parts.push(json!({
481                "type": "text",
482                "text": format!("[video: {}]", video.url)
483            })),
484            UserContent::Document(doc) => {
485                let media_type = doc.media_type.as_deref();
486                if media_type == Some("application/pdf")
487                    || (media_type.is_none() && is_pdf_url(&doc.url))
488                {
489                    parts.push(json!({
490                        "type": "document",
491                        "source": {"type": "url", "url": doc.url}
492                    }))
493                } else {
494                    parts.push(json!({
495                        "type": "text",
496                        "text": format!("[document: {}]", doc.url)
497                    }))
498                }
499            }
500        }
501    }
502    parts
503}
504
505#[async_trait]
506impl Model for AnthropicModel {
507    fn name(&self) -> &str {
508        &self.model
509    }
510
511    async fn request(
512        &self,
513        messages: &[ModelMessage],
514        settings: Option<&ModelSettings>,
515        params: &ModelRequestParameters,
516    ) -> Result<ModelResponse, ModelError> {
517        tracing::debug!(
518            model = %self.model,
519            tool_count = params.function_tools.len(),
520            output_schema = params.output_schema.is_some(),
521            "Anthropic request"
522        );
523        let (system, messages) = Self::split_system(messages);
524        let mut body = Map::new();
525        body.insert("model".to_string(), Value::String(self.model.clone()));
526        body.insert("messages".to_string(), Value::Array(messages));
527        if let Some(system) = system {
528            body.insert("system".to_string(), Value::String(system));
529        }
530
531        if !params.function_tools.is_empty() {
532            let tools = params
533                .function_tools
534                .iter()
535                .map(|tool| {
536                    json!({
537                        "name": tool.name,
538                        "description": tool.description,
539                        "input_schema": tool.parameters_json_schema,
540                    })
541                })
542                .collect();
543            body.insert("tools".to_string(), Value::Array(tools));
544            let mut tool_choice = json!({"type": "auto"});
545            if params.function_tools.iter().any(|tool| tool.sequential)
546                && let Value::Object(map) = &mut tool_choice
547            {
548                map.insert("disable_parallel_tool_use".to_string(), Value::Bool(true));
549            }
550            body.insert("tool_choice".to_string(), tool_choice);
551        }
552
553        if params.output_mode == OutputMode::JsonSchema
554            && let Some(schema) = params.output_schema.clone()
555        {
556            body.insert(
557                "output_format".to_string(),
558                json!({
559                    "type": "json_schema",
560                    "schema": schema
561                }),
562            );
563        }
564
565        if let Some(settings) = &self.default_settings {
566            for (key, value) in settings {
567                body.insert(key.clone(), value.clone());
568            }
569        }
570
571        if let Some(settings) = settings {
572            for (key, value) in settings {
573                body.insert(key.clone(), value.clone());
574            }
575        }
576
577        if !body.contains_key("max_tokens") {
578            body.insert("max_tokens".to_string(), Value::Number(1024.into()));
579        }
580
581        let mut request = self
582            .client
583            .post(self.endpoint()?)
584            .header("x-api-key", &self.api_key)
585            .header("anthropic-version", "2023-06-01");
586
587        if params.output_mode == OutputMode::JsonSchema && params.output_schema.is_some() {
588            request = request.header("anthropic-beta", "structured-outputs-2025-11-13");
589        }
590
591        let response = request
592            .json(&Value::Object(body))
593            .send()
594            .await
595            .map_err(|e| map_reqwest_error("Anthropic", e))?;
596
597        let status = response.status();
598        if !status.is_success() {
599            let body = response.text().await.unwrap_or_default();
600            tracing::error!(
601                status = status.as_u16(),
602                model = %self.model,
603                body = %truncate_error_body(&body),
604                "Anthropic request failed"
605            );
606            return Err(ModelError::HttpStatus {
607                status: status.as_u16(),
608            });
609        }
610
611        let body: AnthropicResponse = response.json().await.map_err(|e| {
612            tracing::error!(
613                error = %e,
614                model = %self.model,
615                "Anthropic response parse failed"
616            );
617            ModelError::Provider(format!("Anthropic response parse failed: {e}"))
618        })?;
619
620        let mut parts = Vec::new();
621        for content in body.content {
622            match content.kind.as_str() {
623                "text" => {
624                    if let Some(text) = content.text {
625                        parts.push(ModelResponsePart::Text(TextPart { content: text }));
626                    }
627                }
628                "tool_use" => {
629                    parts.push(ModelResponsePart::ToolCall(ToolCallPart {
630                        id: normalize_tool_call_id(content.id),
631                        name: content.name.unwrap_or_else(|| "tool".to_string()),
632                        arguments: content.input.unwrap_or_else(|| Value::Object(Map::new())),
633                    }));
634                }
635                _ => {}
636            }
637        }
638
639        let usage = body.usage.map(|usage| RequestUsage {
640            input_tokens: usage.input_tokens.unwrap_or(0),
641            output_tokens: usage.output_tokens.unwrap_or(0),
642            ..Default::default()
643        });
644
645        Ok(ModelResponse {
646            parts,
647            usage,
648            model_name: Some(self.model.clone()),
649            finish_reason: body.stop_reason,
650        })
651    }
652}
653
/// The subset of a Messages API response body that this provider reads.
///
/// Only these fields are deserialized. Without `deny_unknown_fields`, serde ignores any other
/// keys in the payload.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    /// Content blocks, in the order the API returned them.
    content: Vec<AnthropicContent>,
    /// Why generation stopped; copied verbatim into `ModelResponse::finish_reason`.
    stop_reason: Option<String>,
    /// Token accounting, when the response includes it.
    usage: Option<AnthropicUsage>,
}
660
/// One content block of a response. Which optional fields are populated depends on `kind`.
///
/// `request` reads `text` for `"text"` blocks, and `id` / `name` / `input` for `"tool_use"`
/// blocks. It skips blocks of any other kind.
#[derive(Debug, Deserialize)]
struct AnthropicContent {
    /// The block's `type` tag (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    kind: String,
    /// Text of a `"text"` block.
    text: Option<String>,
    /// Tool-call id of a `"tool_use"` block; blank or missing ids are replaced downstream.
    id: Option<String>,
    /// Tool name of a `"tool_use"` block.
    name: Option<String>,
    /// Tool arguments of a `"tool_use"` block, as arbitrary JSON.
    input: Option<Value>,
}
670
/// Token counts reported for a request. `request` treats a missing count as 0.
#[derive(Debug, Deserialize)]
struct AnthropicUsage {
    input_tokens: Option<u64>,
    output_tokens: Option<u64>,
}