1use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6
7use super::{Model, ModelConfig, StreamEventStream};
8use crate::types::{
9 content::{ContentBlock, Message, Role, SystemContentBlock},
10 errors::StrandsError,
11 streaming::{
12 ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockDeltaToolUse, ContentBlockStart,
13 ContentBlockStartEvent, ContentBlockStartToolUse, ContentBlockStopEvent, MessageStartEvent,
14 MessageStopEvent, MetadataEvent, Metrics, StopReason, StreamEvent, Usage,
15 },
16 tools::{ToolChoice, ToolSpec},
17};
18
/// Model id used by [`AnthropicModel::new`] until overridden via `with_model`/`with_config`.
const DEFAULT_MODEL_ID: &str = "claude-sonnet-4-20250514";
/// Messages API endpoint that every request is `POST`ed to.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Value of the `anthropic-version` header sent with every request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
22
23#[derive(Clone)]
25pub struct AnthropicModel {
26 config: ModelConfig,
27 api_key: String,
28 max_tokens: u32,
29 client: Client,
30}
31
32impl std::fmt::Debug for AnthropicModel {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("AnthropicModel")
35 .field("config", &self.config)
36 .field("max_tokens", &self.max_tokens)
37 .finish()
38 }
39}
40
41#[derive(Debug, Serialize)]
42struct AnthropicRequest {
43 model: String,
44 messages: Vec<AnthropicMessage>,
45 max_tokens: u32,
46 stream: bool,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 system: Option<String>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 temperature: Option<f32>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 top_p: Option<f32>,
53 #[serde(skip_serializing_if = "Vec::is_empty")]
54 tools: Vec<AnthropicTool>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 tool_choice: Option<serde_json::Value>,
57}
58
59#[derive(Debug, Serialize)]
60struct AnthropicMessage {
61 role: String,
62 content: Vec<AnthropicContent>,
63}
64
65#[derive(Debug, Serialize)]
66#[serde(untagged)]
67enum AnthropicContent {
68 Text { #[serde(rename = "type")] content_type: String, text: String },
69 ToolUse { #[serde(rename = "type")] content_type: String, id: String, name: String, input: serde_json::Value },
70 ToolResult { #[serde(rename = "type")] content_type: String, tool_use_id: String, content: Vec<AnthropicToolResultContent>, is_error: bool },
71}
72
73#[derive(Debug, Serialize)]
74struct AnthropicToolResultContent {
75 #[serde(rename = "type")]
76 content_type: String,
77 text: String,
78}
79
80#[derive(Debug, Serialize)]
81struct AnthropicTool {
82 name: String,
83 description: String,
84 input_schema: serde_json::Value,
85}
86
87#[derive(Debug, Deserialize)]
88struct AnthropicStreamEvent {
89 #[serde(rename = "type")]
90 event_type: String,
91 #[serde(default)]
92 index: Option<usize>,
93 #[serde(default)]
94 content_block: Option<AnthropicContentBlock>,
95 #[serde(default)]
96 delta: Option<AnthropicDelta>,
97 #[serde(default)]
98 message: Option<AnthropicMessageInfo>,
99 #[serde(default)]
100 usage: Option<AnthropicUsage>,
101}
102
103#[derive(Debug, Deserialize)]
104struct AnthropicContentBlock {
105 #[serde(rename = "type")]
106 block_type: String,
107 #[serde(default)]
108 id: Option<String>,
109 #[serde(default)]
110 name: Option<String>,
111}
112
113#[derive(Debug, Deserialize)]
114struct AnthropicDelta {
115 #[serde(rename = "type")]
116 delta_type: String,
117 #[serde(default)]
118 text: Option<String>,
119 #[serde(default)]
120 partial_json: Option<String>,
121}
122
123#[derive(Debug, Deserialize)]
124struct AnthropicMessageInfo {
125 #[serde(default)]
126 stop_reason: Option<String>,
127}
128
129#[derive(Debug, Deserialize)]
130struct AnthropicUsage {
131 input_tokens: u32,
132 output_tokens: u32,
133}
134
135impl AnthropicModel {
136 pub fn new(api_key: impl Into<String>, max_tokens: u32) -> Self {
137 Self {
138 config: ModelConfig::new(DEFAULT_MODEL_ID),
139 api_key: api_key.into(),
140 max_tokens,
141 client: Client::new(),
142 }
143 }
144
145 pub fn with_model(mut self, model_id: impl Into<String>) -> Self {
146 self.config.model_id = model_id.into();
147 self
148 }
149
150 pub fn with_config(mut self, config: ModelConfig) -> Self {
151 self.config = config;
152 self
153 }
154
155 fn format_messages(&self, messages: &[Message]) -> Vec<AnthropicMessage> {
156 messages
157 .iter()
158 .map(|msg| {
159 let role = match msg.role {
160 Role::User => "user",
161 Role::Assistant => "assistant",
162 };
163
164 let content: Vec<AnthropicContent> = msg
165 .content
166 .iter()
167 .filter_map(|block| self.format_content_block(block))
168 .collect();
169
170 AnthropicMessage {
171 role: role.to_string(),
172 content,
173 }
174 })
175 .collect()
176 }
177
178 fn format_content_block(&self, block: &ContentBlock) -> Option<AnthropicContent> {
179 if let Some(ref text) = block.text {
180 return Some(AnthropicContent::Text {
181 content_type: "text".to_string(),
182 text: text.clone(),
183 });
184 }
185
186 if let Some(ref tu) = block.tool_use {
187 return Some(AnthropicContent::ToolUse {
188 content_type: "tool_use".to_string(),
189 id: tu.tool_use_id.clone(),
190 name: tu.name.clone(),
191 input: tu.input.clone(),
192 });
193 }
194
195 if let Some(ref tr) = block.tool_result {
196 let content: Vec<AnthropicToolResultContent> = tr
197 .content
198 .iter()
199 .filter_map(|c| {
200 c.text.as_ref().map(|t| AnthropicToolResultContent {
201 content_type: "text".to_string(),
202 text: t.clone(),
203 })
204 })
205 .collect();
206
207 let is_error = tr.status == crate::types::tools::ToolResultStatus::Error;
208
209 return Some(AnthropicContent::ToolResult {
210 content_type: "tool_result".to_string(),
211 tool_use_id: tr.tool_use_id.clone(),
212 content,
213 is_error,
214 });
215 }
216
217 None
218 }
219
220 fn format_tools(&self, tool_specs: &[ToolSpec]) -> Vec<AnthropicTool> {
221 tool_specs
222 .iter()
223 .map(|spec| AnthropicTool {
224 name: spec.name.clone(),
225 description: spec.description.clone(),
226 input_schema: spec.input_schema.json.clone(),
227 })
228 .collect()
229 }
230
231 fn format_tool_choice(&self, tool_choice: Option<ToolChoice>) -> Option<serde_json::Value> {
232 tool_choice.map(|tc| match tc {
233 ToolChoice::Auto(_) => serde_json::json!({ "type": "auto" }),
234 ToolChoice::Any(_) => serde_json::json!({ "type": "any" }),
235 ToolChoice::Tool(t) => serde_json::json!({ "type": "tool", "name": t.name }),
236 })
237 }
238
239 fn map_stop_reason(reason: &str) -> StopReason {
240 match reason {
241 "tool_use" => StopReason::ToolUse,
242 "max_tokens" => StopReason::MaxTokens,
243 "end_turn" | "stop_sequence" => StopReason::EndTurn,
244 _ => StopReason::EndTurn,
245 }
246 }
247}
248
249#[async_trait]
250impl Model for AnthropicModel {
251 fn config(&self) -> &ModelConfig {
252 &self.config
253 }
254
255 fn update_config(&mut self, config: ModelConfig) {
256 self.config = config;
257 }
258
259 fn stream<'a>(
260 &'a self,
261 messages: &'a [Message],
262 tool_specs: Option<&'a [ToolSpec]>,
263 system_prompt: Option<&'a str>,
264 tool_choice: Option<ToolChoice>,
265 _system_prompt_content: Option<&'a [SystemContentBlock]>,
266 ) -> StreamEventStream<'a> {
267 let api_key = self.api_key.clone();
268 let client = self.client.clone();
269
270 let request = AnthropicRequest {
271 model: self.config.model_id.clone(),
272 messages: self.format_messages(messages),
273 max_tokens: self.max_tokens,
274 stream: true,
275 system: system_prompt.map(String::from),
276 temperature: self.config.temperature,
277 top_p: self.config.top_p,
278 tools: tool_specs.map(|s| self.format_tools(s)).unwrap_or_default(),
279 tool_choice: self.format_tool_choice(tool_choice),
280 };
281
282 Box::pin(async_stream::stream! {
283 let response = match client
284 .post(ANTHROPIC_API_URL)
285 .header("x-api-key", &api_key)
286 .header("anthropic-version", ANTHROPIC_VERSION)
287 .header("Content-Type", "application/json")
288 .json(&request)
289 .send()
290 .await
291 {
292 Ok(resp) => resp,
293 Err(e) => {
294 yield Err(StrandsError::NetworkError(e.to_string()));
295 return;
296 }
297 };
298
299 if !response.status().is_success() {
300 let status = response.status();
301 let body = response.text().await.unwrap_or_default();
302 if status.as_u16() == 429 {
303 yield Err(StrandsError::ModelThrottled { message: body });
304 } else if body.contains("prompt is too long") || body.contains("context") {
305 yield Err(StrandsError::ContextWindowOverflow { message: body });
306 } else {
307 yield Err(StrandsError::model_error(format!("HTTP {status}: {body}")));
308 }
309 return;
310 }
311
312 use futures::StreamExt;
313 let mut byte_stream = response.bytes_stream();
314 let mut buffer = String::new();
315 let mut final_usage: Option<AnthropicUsage> = None;
316 let mut stop_reason_str: Option<String> = None;
317
318 while let Some(chunk_result) = byte_stream.next().await {
319 let chunk = match chunk_result {
320 Ok(bytes) => String::from_utf8_lossy(&bytes).to_string(),
321 Err(e) => {
322 yield Err(StrandsError::NetworkError(e.to_string()));
323 return;
324 }
325 };
326
327 buffer.push_str(&chunk);
328
329 let lines: Vec<String> = buffer.lines().map(String::from).collect();
330 buffer.clear();
331
332 for line in &lines {
333 let line = line.trim();
334 if line.is_empty() {
335 continue;
336 }
337
338 if let Some(json_str) = line.strip_prefix("data: ") {
339 if let Ok(event) = serde_json::from_str::<AnthropicStreamEvent>(json_str) {
340 match event.event_type.as_str() {
341 "message_start" => {
342 yield Ok(StreamEvent {
343 message_start: Some(MessageStartEvent { role: Role::Assistant }),
344 ..Default::default()
345 });
346 }
347
348 "content_block_start" => {
349 let index = event.index.unwrap_or(0) as u32;
350 let start = event.content_block.as_ref().and_then(|cb| {
351 if cb.block_type == "tool_use" {
352 Some(ContentBlockStart {
353 tool_use: Some(ContentBlockStartToolUse {
354 name: cb.name.clone().unwrap_or_default(),
355 tool_use_id: cb.id.clone().unwrap_or_default(),
356 }),
357 })
358 } else {
359 None
360 }
361 });
362
363 yield Ok(StreamEvent {
364 content_block_start: Some(ContentBlockStartEvent {
365 content_block_index: Some(index),
366 start,
367 }),
368 ..Default::default()
369 });
370 }
371
372 "content_block_delta" => {
373 let index = event.index.unwrap_or(0) as u32;
374 if let Some(ref delta) = event.delta {
375 let block_delta = match delta.delta_type.as_str() {
376 "text_delta" => ContentBlockDelta {
377 text: delta.text.clone(),
378 ..Default::default()
379 },
380 "input_json_delta" => ContentBlockDelta {
381 tool_use: Some(ContentBlockDeltaToolUse {
382 input: delta.partial_json.clone().unwrap_or_default(),
383 }),
384 ..Default::default()
385 },
386 _ => ContentBlockDelta::default(),
387 };
388
389 yield Ok(StreamEvent {
390 content_block_delta: Some(ContentBlockDeltaEvent {
391 content_block_index: Some(index),
392 delta: Some(block_delta),
393 }),
394 ..Default::default()
395 });
396 }
397 }
398
399 "content_block_stop" => {
400 let index = event.index.unwrap_or(0) as u32;
401 yield Ok(StreamEvent {
402 content_block_stop: Some(ContentBlockStopEvent {
403 content_block_index: Some(index),
404 }),
405 ..Default::default()
406 });
407 }
408
409 "message_delta" => {
410 if let Some(ref usage) = event.usage {
411 final_usage = Some(AnthropicUsage {
412 input_tokens: usage.input_tokens,
413 output_tokens: usage.output_tokens,
414 });
415 }
416 if let Some(ref delta) = event.delta {
417 if let Some(ref text) = delta.text {
418 stop_reason_str = Some(text.clone());
419 }
420 }
421 }
422
423 "message_stop" => {
424 let reason = event.message
425 .as_ref()
426 .and_then(|m| m.stop_reason.as_ref())
427 .map(|s| Self::map_stop_reason(s))
428 .or_else(|| stop_reason_str.as_ref().map(|s| Self::map_stop_reason(s)))
429 .unwrap_or(StopReason::EndTurn);
430
431 yield Ok(StreamEvent {
432 message_stop: Some(MessageStopEvent {
433 stop_reason: Some(reason),
434 additional_model_response_fields: None,
435 }),
436 ..Default::default()
437 });
438 }
439
440 _ => {}
441 }
442 }
443 }
444 }
445 }
446
447 if let Some(usage) = final_usage {
448 yield Ok(StreamEvent {
449 metadata: Some(MetadataEvent {
450 usage: Some(Usage {
451 input_tokens: usage.input_tokens,
452 output_tokens: usage.output_tokens,
453 total_tokens: usage.input_tokens + usage.output_tokens,
454 cache_read_input_tokens: 0,
455 cache_write_input_tokens: 0,
456 }),
457 metrics: Some(Metrics {
458 latency_ms: 0,
459 time_to_first_byte_ms: 0,
460 }),
461 trace: None,
462 }),
463 ..Default::default()
464 });
465 }
466 })
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps the requested token cap and applies the model override.
    #[test]
    fn test_anthropic_model_creation() {
        const OVERRIDE_ID: &str = "claude-3-opus-20240229";
        const TOKEN_CAP: u32 = 4096;

        let built = AnthropicModel::new("test-key", TOKEN_CAP).with_model(OVERRIDE_ID);

        assert_eq!(built.config().model_id, OVERRIDE_ID);
        assert_eq!(built.max_tokens, TOKEN_CAP);
    }
}
481