Skip to main content

rustic_ai/providers/
anthropic.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use base64::{Engine as _, engine::general_purpose};
5use reqwest::{Client, Url};
6use serde::Deserialize;
7use serde_json::{Map, Value, json};
8use uuid::Uuid;
9
10use crate::messages::{
11    ModelMessage, ModelRequestPart, ModelResponse, ModelResponsePart, TextPart, ToolCallPart,
12    UserContent,
13};
14use crate::model::{Model, ModelError, ModelRequestParameters, ModelSettings, OutputMode};
15use crate::providers::{Provider, ProviderError};
16use crate::usage::RequestUsage;
17
18fn map_reqwest_error(label: &str, error: reqwest::Error) -> ModelError {
19    if error.is_timeout() {
20        return ModelError::Timeout;
21    }
22    if error.is_connect() {
23        return ModelError::Transport(format!("{label} connect error: {error}"));
24    }
25    ModelError::Transport(format!("{label} request failed: {error}"))
26}
27
/// Shortens an HTTP error body so it can be logged safely.
///
/// Bodies of at most 512 bytes are returned unchanged. Longer bodies are cut
/// at the last UTF-8 character boundary at or before byte 512. The result
/// ends with `"... (<total> bytes)"`, where `<total>` is the original length
/// in bytes.
fn truncate_error_body(body: &str) -> String {
    const LIMIT: usize = 512;
    if body.len() <= LIMIT {
        return body.to_string();
    }
    // Error bodies are arbitrary server text. Slicing at a fixed byte offset
    // panics if that offset falls inside a multi-byte character, so back off
    // to the nearest boundary (index 0 is always one, so the loop terminates).
    let mut end = LIMIT;
    while !body.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}... ({} bytes)", &body[..end], body.len())
}
36
37fn normalize_tool_call_id(id: Option<String>) -> String {
38    match id {
39        Some(value) if !value.trim().is_empty() => value,
40        _ => format!("call_{}", Uuid::new_v4().simple()),
41    }
42}
43
44fn normalize_tool_call_id_str(id: &str) -> String {
45    if id.trim().is_empty() {
46        format!("call_{}", Uuid::new_v4().simple())
47    } else {
48        id.to_string()
49    }
50}
51
52fn tool_return_content(value: &Value) -> String {
53    match value {
54        Value::String(value) => value.clone(),
55        _ => serde_json::to_string(value).unwrap_or_else(|_| value.to_string()),
56    }
57}
58
/// Returns `true` when content of `media_type` can be forwarded to the model
/// as plain text.
///
/// Media types are compared case-insensitively and without parameters, so
/// `Application/JSON; charset=utf-8` is treated like `application/json`.
/// Accepted types are:
/// - every `text/*` type;
/// - structured-syntax suffixes `+json` / `+xml` (e.g. `application/ld+json`);
/// - a fixed list of textual `application/*` types.
fn is_text_like_media_type(media_type: &str) -> bool {
    // The "essence" is the type/subtype before any `;` parameters, lowercased.
    let essence = media_type
        .split(';')
        .next()
        .unwrap_or("")
        .trim()
        .to_ascii_lowercase();
    essence.starts_with("text/")
        || essence.ends_with("+json")
        || essence.ends_with("+xml")
        || matches!(
            essence.as_str(),
            "application/json"
                | "application/xml"
                | "application/xhtml+xml"
                | "application/javascript"
                | "application/x-www-form-urlencoded"
        )
}
70
/// Heuristically decides whether `url` points at a PDF, based on its path.
///
/// The query string and fragment are ignored (whichever starts first), and
/// the `.pdf` extension is matched case-insensitively. Both
/// `https://x/a.PDF?dl=1` and `https://x/a.pdf#page=2` therefore qualify.
fn is_pdf_url(url: &str) -> bool {
    let path = url
        .split(|c: char| c == '?' || c == '#')
        .next()
        .unwrap_or(url);
    // Compare raw bytes: no lowercase allocation, and no risk of slicing a
    // `&str` inside a multi-byte character near the end of the path.
    let bytes = path.as_bytes();
    bytes.len() >= 4 && bytes[bytes.len() - 4..].eq_ignore_ascii_case(b".pdf")
}
76
77#[derive(Clone, Debug)]
78pub struct AnthropicProvider {
79    api_key: String,
80    base_url: Url,
81}
82
83impl AnthropicProvider {
84    pub fn new(
85        api_key: impl Into<String>,
86        base_url: impl AsRef<str>,
87    ) -> Result<Self, ProviderError> {
88        let url = Url::parse(base_url.as_ref())
89            .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
90        Ok(Self {
91            api_key: api_key.into(),
92            base_url: url,
93        })
94    }
95
96    pub fn from_env() -> Result<Self, ProviderError> {
97        let api_key = std::env::var("ANTHROPIC_API_KEY")
98            .map_err(|_| ProviderError::MissingApiKey("anthropic".to_string()))?;
99        Self::new(api_key, "https://api.anthropic.com")
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{
        BinaryContent, DocumentUrl, ModelMessage, ModelRequest, ModelRequestPart, ModelResponse,
        ModelResponsePart, ToolCallPart, ToolReturnPart,
    };
    use base64::engine::general_purpose::STANDARD;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    /// Loads `tests/fixtures/<name>` relative to the crate root.
    fn fixture_bytes(name: &str) -> Vec<u8> {
        let path: PathBuf = [env!("CARGO_MANIFEST_DIR"), "tests", "fixtures", name]
            .iter()
            .collect();
        std::fs::read(path).expect("fixture read")
    }

    /// Asserts that `value[key]` is exactly the JSON string `expected`.
    #[track_caller]
    fn assert_str_field(value: &Value, key: &str, expected: &str) {
        assert_eq!(value.get(key), Some(&Value::String(expected.to_string())));
    }

    /// Returns the first entry of `message["content"]` whose `type` equals `kind`.
    /// Panics with `content_msg` if there is no content array, and with
    /// `block_msg` if no entry matches.
    #[track_caller]
    fn find_block<'a>(
        message: &'a Value,
        kind: &str,
        content_msg: &str,
        block_msg: &str,
    ) -> &'a Value {
        let wanted = Value::String(kind.to_string());
        message
            .get("content")
            .and_then(Value::as_array)
            .expect(content_msg)
            .iter()
            .find(|block| block.get("type") == Some(&wanted))
            .expect(block_msg)
    }

    #[test]
    fn convert_user_content_handles_documents_and_images() {
        let image_bytes = fixture_bytes("fixture.jpg");
        let pdf_bytes = fixture_bytes("fixture.pdf");
        let doc_url = "https://example.com/fixture.pdf";

        let content = vec![
            UserContent::Binary(BinaryContent {
                data: image_bytes.clone(),
                media_type: "image/jpeg".to_string(),
            }),
            UserContent::Binary(BinaryContent {
                data: pdf_bytes.clone(),
                media_type: "application/pdf".to_string(),
            }),
            UserContent::Document(DocumentUrl {
                url: doc_url.to_string(),
                media_type: None,
            }),
        ];

        let parts = convert_user_content(&content);
        assert_eq!(parts.len(), 3);
        let (image, pdf, doc) = (&parts[0], &parts[1], &parts[2]);

        // Raw JPEG bytes -> base64 image block.
        assert_str_field(image, "type", "image");
        let image_source = image.get("source").expect("image source");
        assert_str_field(image_source, "type", "base64");
        assert_str_field(image_source, "media_type", "image/jpeg");
        assert_str_field(image_source, "data", &STANDARD.encode(&image_bytes));

        // Raw PDF bytes -> base64 document block.
        assert_str_field(pdf, "type", "document");
        let pdf_source = pdf.get("source").expect("pdf source");
        assert_str_field(pdf_source, "type", "base64");
        assert_str_field(pdf_source, "media_type", "application/pdf");
        assert_str_field(pdf_source, "data", &STANDARD.encode(&pdf_bytes));

        // PDF URL without an explicit media type -> URL document block.
        assert_str_field(doc, "type", "document");
        let doc_source = doc.get("source").expect("doc source");
        assert_str_field(doc_source, "type", "url");
        assert_str_field(doc_source, "url", doc_url);
    }

    #[test]
    fn split_system_replays_tool_calls() {
        let tool_call = ModelResponsePart::ToolCall(ToolCallPart {
            id: "call-1".to_string(),
            name: "get_data".to_string(),
            arguments: json!({"a": 1}),
        });
        let tool_return = ModelRequestPart::ToolReturn(ToolReturnPart {
            tool_name: "get_data".to_string(),
            tool_call_id: "call-1".to_string(),
            content: json!({"ok": true}),
        });
        let history = [
            ModelMessage::Response(ModelResponse {
                parts: vec![tool_call],
                usage: None,
                model_name: None,
                finish_reason: None,
            }),
            ModelMessage::Request(ModelRequest {
                parts: vec![tool_return],
                instructions: None,
            }),
        ];

        let (_system, out) = AnthropicModel::split_system(&history);
        assert_eq!(out.len(), 2);

        // Assistant turn replaying the tool call.
        let assistant = &out[0];
        assert!(assistant.is_object(), "assistant message");
        assert_str_field(assistant, "role", "assistant");
        let tool_use = find_block(assistant, "tool_use", "assistant content", "tool_use part");
        assert_str_field(tool_use, "id", "call-1");
        assert_str_field(tool_use, "name", "get_data");
        assert_eq!(tool_use.get("input"), Some(&json!({"a": 1})));

        // User turn carrying the serialized tool result.
        let user = &out[1];
        assert!(user.is_object(), "tool result message");
        assert_str_field(user, "role", "user");
        let tool_result = find_block(user, "tool_result", "user content", "tool_result part");
        assert_str_field(tool_result, "tool_use_id", "call-1");
        assert_str_field(tool_result, "content", "{\"ok\":true}");
    }
}
267
268impl Provider for AnthropicProvider {
269    fn name(&self) -> &str {
270        "anthropic"
271    }
272
273    fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model> {
274        Arc::new(AnthropicModel::new(
275            model,
276            self.api_key.clone(),
277            self.base_url.clone(),
278            settings,
279        ))
280    }
281}
282
283#[derive(Clone, Debug)]
284pub struct AnthropicModel {
285    model: String,
286    api_key: String,
287    base_url: Url,
288    client: Client,
289    default_settings: Option<ModelSettings>,
290}
291
292impl AnthropicModel {
293    pub fn new(
294        model: impl Into<String>,
295        api_key: String,
296        base_url: Url,
297        settings: Option<ModelSettings>,
298    ) -> Self {
299        Self {
300            model: model.into(),
301            api_key,
302            base_url,
303            client: Client::new(),
304            default_settings: settings,
305        }
306    }
307
308    fn endpoint(&self) -> Result<Url, ModelError> {
309        self.base_url
310            .join("v1/messages")
311            .map_err(|e| ModelError::Provider(format!("invalid base url: {e}")))
312    }
313
314    fn split_system(messages: &[ModelMessage]) -> (Option<String>, Vec<Value>) {
315        let mut system_parts = Vec::new();
316        let mut out = Vec::new();
317
318        for message in messages {
319            match message {
320                ModelMessage::Request(req) => {
321                    if let Some(instructions) = req
322                        .instructions
323                        .as_ref()
324                        .filter(|value| !value.trim().is_empty())
325                    {
326                        system_parts.push(instructions.to_string());
327                    }
328                    for part in &req.parts {
329                        match part {
330                            ModelRequestPart::SystemPrompt(prompt) => {
331                                system_parts.push(prompt.content.clone());
332                            }
333                            ModelRequestPart::UserPrompt(prompt) => {
334                                out.push(json!({
335                                    "role": "user",
336                                    "content": convert_user_content(&prompt.content)
337                                }));
338                            }
339                            ModelRequestPart::ToolReturn(tool_return) => {
340                                let content = tool_return_content(&tool_return.content);
341                                out.push(json!({
342                                    "role": "user",
343                                    "content": [{
344                                        "type": "tool_result",
345                                        "tool_use_id": normalize_tool_call_id_str(&tool_return.tool_call_id),
346                                        "content": content,
347                                        "is_error": false,
348                                    }]
349                                }));
350                            }
351                            ModelRequestPart::RetryPrompt(retry) => {
352                                if retry.tool_name.is_some() {
353                                    out.push(json!({
354                                        "role": "user",
355                                        "content": [{
356                                            "type": "tool_result",
357                                            "tool_use_id": normalize_tool_call_id(retry.tool_call_id.clone()),
358                                            "content": retry.content,
359                                            "is_error": true,
360                                        }]
361                                    }));
362                                } else {
363                                    out.push(json!({
364                                        "role": "user",
365                                        "content": [{"type": "text", "text": retry.content}]
366                                    }));
367                                }
368                            }
369                        }
370                    }
371                }
372                ModelMessage::Response(res) => {
373                    let mut content = Vec::new();
374                    if let Some(text) = res.text() {
375                        content.push(json!({"type": "text", "text": text}));
376                    }
377                    for call in res.tool_calls() {
378                        content.push(json!({
379                            "type": "tool_use",
380                            "id": normalize_tool_call_id_str(&call.id),
381                            "name": call.name,
382                            "input": call.arguments,
383                        }));
384                    }
385
386                    if !content.is_empty() {
387                        out.push(json!({
388                            "role": "assistant",
389                            "content": content
390                        }));
391                    }
392                }
393            }
394        }
395
396        let system = if system_parts.is_empty() {
397            None
398        } else {
399            Some(system_parts.join("\n\n"))
400        };
401
402        (system, out)
403    }
404}
405
406fn convert_user_content(content: &[UserContent]) -> Vec<Value> {
407    let mut parts = Vec::new();
408    for item in content {
409        match item {
410            UserContent::Text(text) => parts.push(json!({"type": "text", "text": text})),
411            UserContent::Image(image) => parts.push(json!({
412                "type": "image",
413                "source": {"type": "url", "url": image.url}
414            })),
415            UserContent::Binary(binary) => {
416                if binary.media_type.starts_with("image/") {
417                    let encoded = general_purpose::STANDARD.encode(&binary.data);
418                    parts.push(json!({
419                        "type": "image",
420                        "source": {
421                            "type": "base64",
422                            "media_type": binary.media_type,
423                            "data": encoded
424                        }
425                    }));
426                } else if is_text_like_media_type(&binary.media_type) {
427                    match std::str::from_utf8(&binary.data) {
428                        Ok(text) => parts.push(json!({"type": "text", "text": text})),
429                        Err(_) => parts.push(json!({
430                            "type": "text",
431                            "text": format!("[binary content: {} bytes]", binary.data.len())
432                        })),
433                    }
434                } else if binary.media_type == "application/pdf" {
435                    let encoded = general_purpose::STANDARD.encode(&binary.data);
436                    parts.push(json!({
437                        "type": "document",
438                        "source": {
439                            "type": "base64",
440                            "media_type": binary.media_type,
441                            "data": encoded
442                        }
443                    }));
444                } else {
445                    parts.push(json!({
446                        "type": "text",
447                        "text": format!("[binary content: {} bytes]", binary.data.len())
448                    }));
449                }
450            }
451            UserContent::Audio(audio) => parts.push(json!({
452                "type": "text",
453                "text": format!("[audio: {}]", audio.url)
454            })),
455            UserContent::Video(video) => parts.push(json!({
456                "type": "text",
457                "text": format!("[video: {}]", video.url)
458            })),
459            UserContent::Document(doc) => {
460                let media_type = doc.media_type.as_deref();
461                if media_type == Some("application/pdf")
462                    || (media_type.is_none() && is_pdf_url(&doc.url))
463                {
464                    parts.push(json!({
465                        "type": "document",
466                        "source": {"type": "url", "url": doc.url}
467                    }))
468                } else {
469                    parts.push(json!({
470                        "type": "text",
471                        "text": format!("[document: {}]", doc.url)
472                    }))
473                }
474            }
475        }
476    }
477    parts
478}
479
480#[async_trait]
481impl Model for AnthropicModel {
482    fn name(&self) -> &str {
483        &self.model
484    }
485
486    async fn request(
487        &self,
488        messages: &[ModelMessage],
489        settings: Option<&ModelSettings>,
490        params: &ModelRequestParameters,
491    ) -> Result<ModelResponse, ModelError> {
492        tracing::debug!(
493            model = %self.model,
494            tool_count = params.function_tools.len(),
495            output_schema = params.output_schema.is_some(),
496            "Anthropic request"
497        );
498        let (system, messages) = Self::split_system(messages);
499        let mut body = Map::new();
500        body.insert("model".to_string(), Value::String(self.model.clone()));
501        body.insert("messages".to_string(), Value::Array(messages));
502        if let Some(system) = system {
503            body.insert("system".to_string(), Value::String(system));
504        }
505
506        if !params.function_tools.is_empty() {
507            let tools = params
508                .function_tools
509                .iter()
510                .map(|tool| {
511                    json!({
512                        "name": tool.name,
513                        "description": tool.description,
514                        "input_schema": tool.parameters_json_schema,
515                    })
516                })
517                .collect();
518            body.insert("tools".to_string(), Value::Array(tools));
519            let mut tool_choice = json!({"type": "auto"});
520            if params.function_tools.iter().any(|tool| tool.sequential)
521                && let Value::Object(map) = &mut tool_choice
522            {
523                map.insert("disable_parallel_tool_use".to_string(), Value::Bool(true));
524            }
525            body.insert("tool_choice".to_string(), tool_choice);
526        }
527
528        if params.output_mode == OutputMode::JsonSchema
529            && let Some(schema) = params.output_schema.clone()
530        {
531            body.insert(
532                "output_format".to_string(),
533                json!({
534                    "type": "json_schema",
535                    "schema": schema
536                }),
537            );
538        }
539
540        if let Some(settings) = &self.default_settings {
541            for (key, value) in settings {
542                body.insert(key.clone(), value.clone());
543            }
544        }
545
546        if let Some(settings) = settings {
547            for (key, value) in settings {
548                body.insert(key.clone(), value.clone());
549            }
550        }
551
552        if !body.contains_key("max_tokens") {
553            body.insert("max_tokens".to_string(), Value::Number(1024.into()));
554        }
555
556        let mut request = self
557            .client
558            .post(self.endpoint()?)
559            .header("x-api-key", &self.api_key)
560            .header("anthropic-version", "2023-06-01");
561
562        if params.output_mode == OutputMode::JsonSchema && params.output_schema.is_some() {
563            request = request.header("anthropic-beta", "structured-outputs-2025-11-13");
564        }
565
566        let response = request
567            .json(&Value::Object(body))
568            .send()
569            .await
570            .map_err(|e| map_reqwest_error("Anthropic", e))?;
571
572        let status = response.status();
573        if !status.is_success() {
574            let body = response.text().await.unwrap_or_default();
575            tracing::error!(
576                status = status.as_u16(),
577                model = %self.model,
578                body = %truncate_error_body(&body),
579                "Anthropic request failed"
580            );
581            return Err(ModelError::HttpStatus {
582                status: status.as_u16(),
583            });
584        }
585
586        let body: AnthropicResponse = response.json().await.map_err(|e| {
587            tracing::error!(
588                error = %e,
589                model = %self.model,
590                "Anthropic response parse failed"
591            );
592            ModelError::Provider(format!("Anthropic response parse failed: {e}"))
593        })?;
594
595        let mut parts = Vec::new();
596        for content in body.content {
597            match content.kind.as_str() {
598                "text" => {
599                    if let Some(text) = content.text {
600                        parts.push(ModelResponsePart::Text(TextPart { content: text }));
601                    }
602                }
603                "tool_use" => {
604                    parts.push(ModelResponsePart::ToolCall(ToolCallPart {
605                        id: normalize_tool_call_id(content.id),
606                        name: content.name.unwrap_or_else(|| "tool".to_string()),
607                        arguments: content.input.unwrap_or_else(|| Value::Object(Map::new())),
608                    }));
609                }
610                _ => {}
611            }
612        }
613
614        let usage = body.usage.map(|usage| RequestUsage {
615            input_tokens: usage.input_tokens.unwrap_or(0),
616            output_tokens: usage.output_tokens.unwrap_or(0),
617            ..Default::default()
618        });
619
620        Ok(ModelResponse {
621            parts,
622            usage,
623            model_name: Some(self.model.clone()),
624            finish_reason: body.stop_reason,
625        })
626    }
627}
628
629#[derive(Debug, Deserialize)]
630struct AnthropicResponse {
631    content: Vec<AnthropicContent>,
632    stop_reason: Option<String>,
633    usage: Option<AnthropicUsage>,
634}
635
636#[derive(Debug, Deserialize)]
637struct AnthropicContent {
638    #[serde(rename = "type")]
639    kind: String,
640    text: Option<String>,
641    id: Option<String>,
642    name: Option<String>,
643    input: Option<Value>,
644}
645
646#[derive(Debug, Deserialize)]
647struct AnthropicUsage {
648    input_tokens: Option<u64>,
649    output_tokens: Option<u64>,
650}