1use async_trait::async_trait;
2use futures::stream::{self, StreamExt};
3use reqwest::Client;
4use serde_json::json;
5use std::collections::HashMap;
6
7use super::{Brain, BrainEvent, BrainRequest, BrainStream, ContentBlock, LatencyClass, ModelCaps};
8
9pub struct AnthropicAdapter {
10 model: String,
11 api_key: String,
12 base_url: String,
13 client: Client,
14 caps: ModelCaps,
15}
16
17impl AnthropicAdapter {
18 pub fn new(model: &str, api_key: impl Into<String>, base_url: Option<&str>) -> Self {
19 let model = model.to_string();
20 let caps = Self::model_caps(&model);
21 Self {
22 model,
23 api_key: api_key.into(),
24 base_url: base_url.unwrap_or("https://api.anthropic.com").to_string(),
25 client: Client::new(),
26 caps,
27 }
28 }
29
30 pub fn with_caps(mut self, caps: ModelCaps) -> Self {
31 self.caps = caps;
32 self
33 }
34
35 fn model_caps(model: &str) -> ModelCaps {
36 if model.contains("opus") {
37 ModelCaps {
38 context_window: 200_000,
39 max_output: 32_000,
40 tools: true,
41 vision: true,
42 cost_input_per_mtok: 15.0,
43 cost_output_per_mtok: 75.0,
44 latency: LatencyClass::Slow,
45 }
46 } else if model.contains("sonnet") {
47 ModelCaps {
48 context_window: 200_000,
49 max_output: 16_000,
50 tools: true,
51 vision: true,
52 cost_input_per_mtok: 3.0,
53 cost_output_per_mtok: 15.0,
54 latency: LatencyClass::Medium,
55 }
56 } else {
57 ModelCaps {
59 context_window: 200_000,
60 max_output: 8_000,
61 tools: true,
62 vision: true,
63 cost_input_per_mtok: 0.8,
64 cost_output_per_mtok: 4.0,
65 latency: LatencyClass::Fast,
66 }
67 }
68 }
69}
70
71fn cache_control_value(req: &BrainRequest) -> serde_json::Value {
72 json!({
73 "type": "ephemeral",
74 "ttl": req.cache.ttl.anthropic_ttl(),
75 })
76}
77
78fn text_block(text: &str, cache_control: Option<serde_json::Value>) -> serde_json::Value {
79 let mut block = json!({"type": "text", "text": text});
80 if let Some(cache_control) = cache_control {
81 block["cache_control"] = cache_control;
82 }
83 block
84}
85
86fn build_messages_body(model: &str, req: &BrainRequest) -> serde_json::Value {
87 let system: Option<String> = req.system.clone();
88 let mut messages = Vec::new();
89
90 for msg in &req.messages {
92 let mut content: Vec<serde_json::Value> = Vec::new();
93
94 for block in &msg.content {
95 match block {
96 ContentBlock::Text { text } => {
97 content.push(text_block(text, None));
98 }
99 ContentBlock::Image { source } => match source {
100 super::ImageSource::Base64 { media_type, data } => {
101 content.push(json!({
102 "type": "image",
103 "source": {
104 "type": "base64",
105 "media_type": media_type,
106 "data": data,
107 }
108 }));
109 }
110 super::ImageSource::Url { url } => {
111 content.push(json!({
112 "type": "image",
113 "source": {
114 "type": "url",
115 "url": url,
116 }
117 }));
118 }
119 },
120 ContentBlock::ToolResult {
121 tool_use_id,
122 content: tool_content,
123 is_error,
124 } => {
125 let inner: Vec<serde_json::Value> = tool_content
126 .iter()
127 .map(|b| match b {
128 ContentBlock::Text { text } => text_block(text, None),
129 _ => json!({"type": "text", "text": format!("{:?}", b)}),
130 })
131 .collect();
132 let mut val = json!({
133 "type": "tool_result",
134 "tool_use_id": tool_use_id,
135 "content": inner,
136 });
137 if let Some(true) = is_error {
138 val["is_error"] = json!(true);
139 }
140 content.push(val);
141 }
142 ContentBlock::ToolUse { .. } => {}
143 ContentBlock::Reasoning { .. } => {}
148 }
149 }
150
151 messages.push(json!({
152 "role": msg.role,
153 "content": content,
154 }));
155 }
156
157 let tools: Vec<serde_json::Value> = if req.tools.is_empty() {
159 vec![]
160 } else {
161 req.tools
162 .iter()
163 .map(|t| {
164 json!({
165 "name": t.name,
166 "description": t.description,
167 "input_schema": t.input_schema,
168 })
169 })
170 .collect()
171 };
172
173 let mut body = json!({
174 "model": model,
175 "max_tokens": req.max_tokens,
176 "temperature": req.temperature,
177 "messages": messages,
178 "stream": true,
179 });
180
181 if let Some(sys) = &system {
182 body["system"] = if req.cache.enabled {
183 json!([text_block(sys, Some(cache_control_value(req)))])
184 } else {
185 json!(sys)
186 };
187 }
188 if !tools.is_empty() {
189 body["tools"] = json!(tools);
190 }
191 if !req.stop.is_empty() {
192 body["stop_sequences"] = json!(req.stop);
193 }
194
195 body
196}
197
198#[async_trait]
199impl Brain for AnthropicAdapter {
200 fn id(&self) -> &str {
201 &self.model
202 }
203
204 fn caps(&self) -> ModelCaps {
205 self.caps.clone()
206 }
207
208 async fn complete(&self, req: BrainRequest) -> anyhow::Result<BrainStream> {
209 let body = build_messages_body(&self.model, &req);
210
211 let response = self
212 .client
213 .post(format!("{}/v1/messages", self.base_url))
214 .header("x-api-key", &self.api_key)
215 .header("anthropic-version", "2023-06-01")
216 .json(&body)
217 .send()
218 .await?;
219
220 if !response.status().is_success() {
221 let status = response.status().as_u16();
222 let body = response.text().await.unwrap_or_default();
223 return Err(anyhow::anyhow!("Anthropic API error {}: {}", status, body));
224 }
225
226 let stream = response.bytes_stream();
227 let model = self.model.clone();
228
229 struct AnthropicSse {
232 tools: HashMap<u64, String>,
233 lines: super::sse_buffer::LineBuffer,
234 }
235 let event_stream = stream
236 .scan(
237 AnthropicSse {
238 tools: HashMap::new(),
239 lines: super::sse_buffer::LineBuffer::new(),
240 },
241 move |state, chunk| {
242 let _model = model.clone();
243 let events = match chunk {
244 Ok(bytes) => {
245 let lines = state.lines.push(&bytes);
246 let tool_ids = &mut state.tools;
247 let mut events = Vec::new();
248 for line in lines {
249 let line = line.trim();
250 if line.is_empty() || !line.starts_with("data: ") {
251 continue;
252 }
253 let data = &line[6..]; let event: serde_json::Value = match serde_json::from_str(data) {
255 Ok(v) => v,
256 Err(_) => continue,
257 };
258
259 let event_type = event["type"].as_str().unwrap_or("");
260 match event_type {
261 "content_block_start" => {
262 let index = event["index"].as_u64().unwrap_or(0);
263 let content_type =
264 event["content_block"]["type"].as_str().unwrap_or("");
265 if content_type == "tool_use" {
266 let id = event["content_block"]["id"]
267 .as_str()
268 .unwrap_or("")
269 .to_string();
270 let name = event["content_block"]["name"]
271 .as_str()
272 .unwrap_or("")
273 .to_string();
274 if !id.is_empty() {
275 tool_ids.insert(index, id.clone());
276 }
277 events.push(BrainEvent::ToolUseStart { id, name });
278 }
279 }
280 "content_block_delta" => {
281 let delta_type =
282 event["delta"]["type"].as_str().unwrap_or("");
283 if delta_type == "text_delta" {
284 let text = event["delta"]["text"]
285 .as_str()
286 .unwrap_or("")
287 .to_string();
288 events.push(BrainEvent::TextDelta(text));
289 } else if delta_type == "input_json_delta" {
290 let partial = event["delta"]["partial_json"]
291 .as_str()
292 .unwrap_or("")
293 .to_string();
294 let index = event["index"].as_u64().unwrap_or(0);
295 let id = tool_ids
296 .get(&index)
297 .cloned()
298 .unwrap_or_else(|| index.to_string());
299 events.push(BrainEvent::ToolUseDelta {
300 id,
301 json: partial,
302 });
303 }
304 }
305 "content_block_stop" => {
306 let index = event["index"].as_u64().unwrap_or(0);
307 let id = tool_ids
308 .remove(&index)
309 .unwrap_or_else(|| index.to_string());
310 events.push(BrainEvent::ToolUseEnd { id });
311 }
312 "message_delta" => {
313 if let Some(usage) = event["usage"].as_object() {
314 events.push(BrainEvent::Usage(
315 crate::event::TokenUsage {
316 input: usage["input_tokens"]
317 .as_u64()
318 .unwrap_or(0),
319 output: usage["output_tokens"]
320 .as_u64()
321 .unwrap_or(0),
322 },
323 ));
324 }
325 let stop_reason = event["delta"]["stop_reason"]
326 .as_str()
327 .unwrap_or("end_turn");
328 let reason = match stop_reason {
329 "end_turn" => crate::event::StopReason::EndTurn,
330 "max_tokens" => crate::event::StopReason::MaxTokens,
331 "tool_use" => crate::event::StopReason::ToolUse,
332 s => crate::event::StopReason::StopSequence(
333 s.to_string(),
334 ),
335 };
336 events.push(BrainEvent::Done(reason));
337 }
338 "message_stop" => {}
339 _ => {}
340 }
341 }
342 events
343 }
344 Err(e) => {
345 vec![BrainEvent::Error(format!("stream error: {}", e))]
346 }
347 };
348 futures::future::ready(Some(stream::iter(events)))
349 },
350 )
351 .flatten();
352
353 Ok(Box::pin(event_stream))
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{Msg, PromptCacheConfig, PromptCacheTtl};

    /// Request with a stable system prompt and one dynamic user turn.
    fn request_with_cache(enabled: bool) -> BrainRequest {
        BrainRequest {
            system: Some("stable sparrow system".into()),
            messages: vec![Msg {
                role: "user".into(),
                content: vec![ContentBlock::Text {
                    text: "dynamic task".into(),
                }],
            }],
            cache: PromptCacheConfig {
                enabled,
                ttl: PromptCacheTtl::OneHour,
                key: Some("repo-key".into()),
            },
            ..BrainRequest::default()
        }
    }

    #[test]
    fn anthropic_system_prompt_gets_cache_control() {
        let body = build_messages_body("claude-test", &request_with_cache(true));
        assert_eq!(
            body["system"][0]["cache_control"],
            json!({"type":"ephemeral","ttl":"1h"})
        );
        assert!(body["messages"][0]["content"][0]["cache_control"].is_null());
    }

    #[test]
    fn anthropic_system_prompt_is_plain_string_without_cache() {
        let body = build_messages_body("claude-test", &request_with_cache(false));
        assert_eq!(body["system"], json!("stable sparrow system"));
        assert_eq!(body["model"], json!("claude-test"));
        assert_eq!(body["stream"], json!(true));
        assert_eq!(
            body["messages"][0]["content"][0],
            json!({"type": "text", "text": "dynamic task"})
        );
    }
}