strands_agents/models/anthropic.rs

1//! Anthropic Claude model provider.
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6
7use super::{Model, ModelConfig, StreamEventStream};
8use crate::types::{
9    content::{ContentBlock, Message, Role, SystemContentBlock},
10    errors::StrandsError,
11    streaming::{
12        ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockDeltaToolUse, ContentBlockStart,
13        ContentBlockStartEvent, ContentBlockStartToolUse, ContentBlockStopEvent, MessageStartEvent,
14        MessageStopEvent, MetadataEvent, Metrics, StopReason, StreamEvent, Usage,
15    },
16    tools::{ToolChoice, ToolSpec},
17};
18
/// Model id used by `AnthropicModel::new` until `with_model` / `with_config` replace it.
const DEFAULT_MODEL_ID: &str = "claude-sonnet-4-20250514";

/// Endpoint every streaming request is POSTed to.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";

/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
22
23/// Anthropic Claude model provider.
24#[derive(Clone)]
25pub struct AnthropicModel {
26    config: ModelConfig,
27    api_key: String,
28    max_tokens: u32,
29    client: Client,
30}
31
32impl std::fmt::Debug for AnthropicModel {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("AnthropicModel")
35            .field("config", &self.config)
36            .field("max_tokens", &self.max_tokens)
37            .finish()
38    }
39}
40
41#[derive(Debug, Serialize)]
42struct AnthropicRequest {
43    model: String,
44    messages: Vec<AnthropicMessage>,
45    max_tokens: u32,
46    stream: bool,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    system: Option<String>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    temperature: Option<f32>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    top_p: Option<f32>,
53    #[serde(skip_serializing_if = "Vec::is_empty")]
54    tools: Vec<AnthropicTool>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    tool_choice: Option<serde_json::Value>,
57}
58
59#[derive(Debug, Serialize)]
60struct AnthropicMessage {
61    role: String,
62    content: Vec<AnthropicContent>,
63}
64
65#[derive(Debug, Serialize)]
66#[serde(untagged)]
67enum AnthropicContent {
68    Text { #[serde(rename = "type")] content_type: String, text: String },
69    ToolUse { #[serde(rename = "type")] content_type: String, id: String, name: String, input: serde_json::Value },
70    ToolResult { #[serde(rename = "type")] content_type: String, tool_use_id: String, content: Vec<AnthropicToolResultContent>, is_error: bool },
71}
72
73#[derive(Debug, Serialize)]
74struct AnthropicToolResultContent {
75    #[serde(rename = "type")]
76    content_type: String,
77    text: String,
78}
79
80#[derive(Debug, Serialize)]
81struct AnthropicTool {
82    name: String,
83    description: String,
84    input_schema: serde_json::Value,
85}
86
87#[derive(Debug, Deserialize)]
88struct AnthropicStreamEvent {
89    #[serde(rename = "type")]
90    event_type: String,
91    #[serde(default)]
92    index: Option<usize>,
93    #[serde(default)]
94    content_block: Option<AnthropicContentBlock>,
95    #[serde(default)]
96    delta: Option<AnthropicDelta>,
97    #[serde(default)]
98    message: Option<AnthropicMessageInfo>,
99    #[serde(default)]
100    usage: Option<AnthropicUsage>,
101}
102
103#[derive(Debug, Deserialize)]
104struct AnthropicContentBlock {
105    #[serde(rename = "type")]
106    block_type: String,
107    #[serde(default)]
108    id: Option<String>,
109    #[serde(default)]
110    name: Option<String>,
111}
112
113#[derive(Debug, Deserialize)]
114struct AnthropicDelta {
115    #[serde(rename = "type")]
116    delta_type: String,
117    #[serde(default)]
118    text: Option<String>,
119    #[serde(default)]
120    partial_json: Option<String>,
121}
122
123#[derive(Debug, Deserialize)]
124struct AnthropicMessageInfo {
125    #[serde(default)]
126    stop_reason: Option<String>,
127}
128
129#[derive(Debug, Deserialize)]
130struct AnthropicUsage {
131    input_tokens: u32,
132    output_tokens: u32,
133}
134
135impl AnthropicModel {
136    pub fn new(api_key: impl Into<String>, max_tokens: u32) -> Self {
137        Self {
138            config: ModelConfig::new(DEFAULT_MODEL_ID),
139            api_key: api_key.into(),
140            max_tokens,
141            client: Client::new(),
142        }
143    }
144
145    pub fn with_model(mut self, model_id: impl Into<String>) -> Self {
146        self.config.model_id = model_id.into();
147        self
148    }
149
150    pub fn with_config(mut self, config: ModelConfig) -> Self {
151        self.config = config;
152        self
153    }
154
155    fn format_messages(&self, messages: &[Message]) -> Vec<AnthropicMessage> {
156        messages
157            .iter()
158            .map(|msg| {
159                let role = match msg.role {
160                    Role::User => "user",
161                    Role::Assistant => "assistant",
162                };
163
164                let content: Vec<AnthropicContent> = msg
165                    .content
166                    .iter()
167                    .filter_map(|block| self.format_content_block(block))
168                    .collect();
169
170                AnthropicMessage {
171                    role: role.to_string(),
172                    content,
173                }
174            })
175            .collect()
176    }
177
178    fn format_content_block(&self, block: &ContentBlock) -> Option<AnthropicContent> {
179        if let Some(ref text) = block.text {
180            return Some(AnthropicContent::Text {
181                content_type: "text".to_string(),
182                text: text.clone(),
183            });
184        }
185
186        if let Some(ref tu) = block.tool_use {
187            return Some(AnthropicContent::ToolUse {
188                content_type: "tool_use".to_string(),
189                id: tu.tool_use_id.clone(),
190                name: tu.name.clone(),
191                input: tu.input.clone(),
192            });
193        }
194
195        if let Some(ref tr) = block.tool_result {
196            let content: Vec<AnthropicToolResultContent> = tr
197                .content
198                .iter()
199                .filter_map(|c| {
200                    c.text.as_ref().map(|t| AnthropicToolResultContent {
201                        content_type: "text".to_string(),
202                        text: t.clone(),
203                    })
204                })
205                .collect();
206
207            let is_error = tr.status == crate::types::tools::ToolResultStatus::Error;
208
209            return Some(AnthropicContent::ToolResult {
210                content_type: "tool_result".to_string(),
211                tool_use_id: tr.tool_use_id.clone(),
212                content,
213                is_error,
214            });
215        }
216
217        None
218    }
219
220    fn format_tools(&self, tool_specs: &[ToolSpec]) -> Vec<AnthropicTool> {
221        tool_specs
222            .iter()
223            .map(|spec| AnthropicTool {
224                name: spec.name.clone(),
225                description: spec.description.clone(),
226                input_schema: spec.input_schema.json.clone(),
227            })
228            .collect()
229    }
230
231    fn format_tool_choice(&self, tool_choice: Option<ToolChoice>) -> Option<serde_json::Value> {
232        tool_choice.map(|tc| match tc {
233            ToolChoice::Auto(_) => serde_json::json!({ "type": "auto" }),
234            ToolChoice::Any(_) => serde_json::json!({ "type": "any" }),
235            ToolChoice::Tool(t) => serde_json::json!({ "type": "tool", "name": t.name }),
236        })
237    }
238
239    fn map_stop_reason(reason: &str) -> StopReason {
240        match reason {
241            "tool_use" => StopReason::ToolUse,
242            "max_tokens" => StopReason::MaxTokens,
243            "end_turn" | "stop_sequence" => StopReason::EndTurn,
244            _ => StopReason::EndTurn,
245        }
246    }
247}
248
249#[async_trait]
250impl Model for AnthropicModel {
251    fn config(&self) -> &ModelConfig {
252        &self.config
253    }
254
255    fn update_config(&mut self, config: ModelConfig) {
256        self.config = config;
257    }
258
259    fn stream<'a>(
260        &'a self,
261        messages: &'a [Message],
262        tool_specs: Option<&'a [ToolSpec]>,
263        system_prompt: Option<&'a str>,
264        tool_choice: Option<ToolChoice>,
265        _system_prompt_content: Option<&'a [SystemContentBlock]>,
266    ) -> StreamEventStream<'a> {
267        let api_key = self.api_key.clone();
268        let client = self.client.clone();
269
270        let request = AnthropicRequest {
271            model: self.config.model_id.clone(),
272            messages: self.format_messages(messages),
273            max_tokens: self.max_tokens,
274            stream: true,
275            system: system_prompt.map(String::from),
276            temperature: self.config.temperature,
277            top_p: self.config.top_p,
278            tools: tool_specs.map(|s| self.format_tools(s)).unwrap_or_default(),
279            tool_choice: self.format_tool_choice(tool_choice),
280        };
281
282        Box::pin(async_stream::stream! {
283            let response = match client
284                .post(ANTHROPIC_API_URL)
285                .header("x-api-key", &api_key)
286                .header("anthropic-version", ANTHROPIC_VERSION)
287                .header("Content-Type", "application/json")
288                .json(&request)
289                .send()
290                .await
291            {
292                Ok(resp) => resp,
293                Err(e) => {
294                    yield Err(StrandsError::NetworkError(e.to_string()));
295                    return;
296                }
297            };
298
299            if !response.status().is_success() {
300                let status = response.status();
301                let body = response.text().await.unwrap_or_default();
302                if status.as_u16() == 429 {
303                    yield Err(StrandsError::ModelThrottled { message: body });
304                } else if body.contains("prompt is too long") || body.contains("context") {
305                    yield Err(StrandsError::ContextWindowOverflow { message: body });
306                } else {
307                    yield Err(StrandsError::model_error(format!("HTTP {status}: {body}")));
308                }
309                return;
310            }
311
312            use futures::StreamExt;
313            let mut byte_stream = response.bytes_stream();
314            let mut buffer = String::new();
315            let mut final_usage: Option<AnthropicUsage> = None;
316            let mut stop_reason_str: Option<String> = None;
317
318            while let Some(chunk_result) = byte_stream.next().await {
319                let chunk = match chunk_result {
320                    Ok(bytes) => String::from_utf8_lossy(&bytes).to_string(),
321                    Err(e) => {
322                        yield Err(StrandsError::NetworkError(e.to_string()));
323                        return;
324                    }
325                };
326
327                buffer.push_str(&chunk);
328
329                let lines: Vec<String> = buffer.lines().map(String::from).collect();
330                buffer.clear();
331
332                for line in &lines {
333                    let line = line.trim();
334                    if line.is_empty() {
335                        continue;
336                    }
337
338                    if let Some(json_str) = line.strip_prefix("data: ") {
339                        if let Ok(event) = serde_json::from_str::<AnthropicStreamEvent>(json_str) {
340                            match event.event_type.as_str() {
341                                "message_start" => {
342                                    yield Ok(StreamEvent {
343                                        message_start: Some(MessageStartEvent { role: Role::Assistant }),
344                                        ..Default::default()
345                                    });
346                                }
347
348                                "content_block_start" => {
349                                    let index = event.index.unwrap_or(0) as u32;
350                                    let start = event.content_block.as_ref().and_then(|cb| {
351                                        if cb.block_type == "tool_use" {
352                                            Some(ContentBlockStart {
353                                                tool_use: Some(ContentBlockStartToolUse {
354                                                    name: cb.name.clone().unwrap_or_default(),
355                                                    tool_use_id: cb.id.clone().unwrap_or_default(),
356                                                }),
357                                            })
358                                        } else {
359                                            None
360                                        }
361                                    });
362
363                                    yield Ok(StreamEvent {
364                                        content_block_start: Some(ContentBlockStartEvent {
365                                            content_block_index: Some(index),
366                                            start,
367                                        }),
368                                        ..Default::default()
369                                    });
370                                }
371
372                                "content_block_delta" => {
373                                    let index = event.index.unwrap_or(0) as u32;
374                                    if let Some(ref delta) = event.delta {
375                                        let block_delta = match delta.delta_type.as_str() {
376                                            "text_delta" => ContentBlockDelta {
377                                                text: delta.text.clone(),
378                                                ..Default::default()
379                                            },
380                                            "input_json_delta" => ContentBlockDelta {
381                                                tool_use: Some(ContentBlockDeltaToolUse {
382                                                    input: delta.partial_json.clone().unwrap_or_default(),
383                                                }),
384                                                ..Default::default()
385                                            },
386                                            _ => ContentBlockDelta::default(),
387                                        };
388
389                                        yield Ok(StreamEvent {
390                                            content_block_delta: Some(ContentBlockDeltaEvent {
391                                                content_block_index: Some(index),
392                                                delta: Some(block_delta),
393                                            }),
394                                            ..Default::default()
395                                        });
396                                    }
397                                }
398
399                                "content_block_stop" => {
400                                    let index = event.index.unwrap_or(0) as u32;
401                                    yield Ok(StreamEvent {
402                                        content_block_stop: Some(ContentBlockStopEvent {
403                                            content_block_index: Some(index),
404                                        }),
405                                        ..Default::default()
406                                    });
407                                }
408
409                                "message_delta" => {
410                                    if let Some(ref usage) = event.usage {
411                                        final_usage = Some(AnthropicUsage {
412                                            input_tokens: usage.input_tokens,
413                                            output_tokens: usage.output_tokens,
414                                        });
415                                    }
416                                    if let Some(ref delta) = event.delta {
417                                        if let Some(ref text) = delta.text {
418                                            stop_reason_str = Some(text.clone());
419                                        }
420                                    }
421                                }
422
423                                "message_stop" => {
424                                    let reason = event.message
425                                        .as_ref()
426                                        .and_then(|m| m.stop_reason.as_ref())
427                                        .map(|s| Self::map_stop_reason(s))
428                                        .or_else(|| stop_reason_str.as_ref().map(|s| Self::map_stop_reason(s)))
429                                        .unwrap_or(StopReason::EndTurn);
430
431                                    yield Ok(StreamEvent {
432                                        message_stop: Some(MessageStopEvent {
433                                            stop_reason: Some(reason),
434                                            additional_model_response_fields: None,
435                                        }),
436                                        ..Default::default()
437                                    });
438                                }
439
440                                _ => {}
441                            }
442                        }
443                    }
444                }
445            }
446
447            if let Some(usage) = final_usage {
448                yield Ok(StreamEvent {
449                    metadata: Some(MetadataEvent {
450                        usage: Some(Usage {
451                            input_tokens: usage.input_tokens,
452                            output_tokens: usage.output_tokens,
453                            total_tokens: usage.input_tokens + usage.output_tokens,
454                            cache_read_input_tokens: 0,
455                            cache_write_input_tokens: 0,
456                        }),
457                        metrics: Some(Metrics {
458                            latency_ms: 0,
459                            time_to_first_byte_ms: 0,
460                        }),
461                        trace: None,
462                    }),
463                    ..Default::default()
464                });
465            }
466        })
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_anthropic_model_creation() {
        let model = AnthropicModel::new("test-key", 4096).with_model("claude-3-opus-20240229");
        assert_eq!(model.config().model_id, "claude-3-opus-20240229");
        assert_eq!(model.max_tokens, 4096);
    }

    /// The API key must never appear in `Debug` output.
    #[test]
    fn test_debug_output_hides_api_key() {
        let model = AnthropicModel::new("super-secret-key", 1024);
        let rendered = format!("{model:?}");
        assert!(rendered.contains("AnthropicModel"));
        assert!(!rendered.contains("super-secret-key"));
    }

    #[test]
    fn test_map_stop_reason() {
        assert!(matches!(AnthropicModel::map_stop_reason("tool_use"), StopReason::ToolUse));
        assert!(matches!(AnthropicModel::map_stop_reason("max_tokens"), StopReason::MaxTokens));
        assert!(matches!(AnthropicModel::map_stop_reason("end_turn"), StopReason::EndTurn));
        assert!(matches!(AnthropicModel::map_stop_reason("stop_sequence"), StopReason::EndTurn));
        assert!(matches!(AnthropicModel::map_stop_reason("not_a_reason"), StopReason::EndTurn));
    }

    #[test]
    fn test_text_delta_event_parses() {
        let event: AnthropicStreamEvent = serde_json::from_str(
            r#"{"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":"hi"}}"#,
        )
        .expect("well-formed content_block_delta payload");
        assert_eq!(event.event_type, "content_block_delta");
        assert_eq!(event.index, Some(1));
        assert_eq!(
            event.delta.as_ref().and_then(|d| d.text.as_deref()),
            Some("hi")
        );
    }

    /// Guards the untagged-enum wire format: `content_type` must serialize as `type`.
    #[test]
    fn test_tool_result_serialization() {
        let value = serde_json::to_value(AnthropicContent::ToolResult {
            content_type: "tool_result".to_string(),
            tool_use_id: "toolu_1".to_string(),
            content: vec![AnthropicToolResultContent {
                content_type: "text".to_string(),
                text: "ok".to_string(),
            }],
            is_error: false,
        })
        .expect("serializable content");
        assert_eq!(
            value,
            serde_json::json!({
                "type": "tool_result",
                "tool_use_id": "toolu_1",
                "content": [{ "type": "text", "text": "ok" }],
                "is_error": false
            })
        );
    }
}
481