Skip to main content

sparrow/provider/
anthropic.rs

1use async_trait::async_trait;
2use futures::stream::{self, StreamExt};
3use reqwest::Client;
4use serde_json::json;
5use std::collections::HashMap;
6
7use super::{Brain, BrainEvent, BrainRequest, BrainStream, ContentBlock, LatencyClass, ModelCaps};
8
9pub struct AnthropicAdapter {
10    model: String,
11    api_key: String,
12    base_url: String,
13    client: Client,
14    caps: ModelCaps,
15}
16
17impl AnthropicAdapter {
18    pub fn new(model: &str, api_key: impl Into<String>, base_url: Option<&str>) -> Self {
19        let model = model.to_string();
20        let caps = Self::model_caps(&model);
21        Self {
22            model,
23            api_key: api_key.into(),
24            base_url: base_url.unwrap_or("https://api.anthropic.com").to_string(),
25            client: Client::new(),
26            caps,
27        }
28    }
29
30    pub fn with_caps(mut self, caps: ModelCaps) -> Self {
31        self.caps = caps;
32        self
33    }
34
35    fn model_caps(model: &str) -> ModelCaps {
36        if model.contains("opus") {
37            ModelCaps {
38                context_window: 200_000,
39                max_output: 32_000,
40                tools: true,
41                vision: true,
42                cost_input_per_mtok: 15.0,
43                cost_output_per_mtok: 75.0,
44                latency: LatencyClass::Slow,
45            }
46        } else if model.contains("sonnet") {
47            ModelCaps {
48                context_window: 200_000,
49                max_output: 16_000,
50                tools: true,
51                vision: true,
52                cost_input_per_mtok: 3.0,
53                cost_output_per_mtok: 15.0,
54                latency: LatencyClass::Medium,
55            }
56        } else {
57            // haiku
58            ModelCaps {
59                context_window: 200_000,
60                max_output: 8_000,
61                tools: true,
62                vision: true,
63                cost_input_per_mtok: 0.8,
64                cost_output_per_mtok: 4.0,
65                latency: LatencyClass::Fast,
66            }
67        }
68    }
69}
70
71fn cache_control_value(req: &BrainRequest) -> serde_json::Value {
72    json!({
73        "type": "ephemeral",
74        "ttl": req.cache.ttl.anthropic_ttl(),
75    })
76}
77
78fn text_block(text: &str, cache_control: Option<serde_json::Value>) -> serde_json::Value {
79    let mut block = json!({"type": "text", "text": text});
80    if let Some(cache_control) = cache_control {
81        block["cache_control"] = cache_control;
82    }
83    block
84}
85
86fn build_messages_body(model: &str, req: &BrainRequest) -> serde_json::Value {
87    let system: Option<String> = req.system.clone();
88    let mut messages = Vec::new();
89
90    // Build Anthropic-formatted messages from our Msg vec
91    for msg in &req.messages {
92        let mut content: Vec<serde_json::Value> = Vec::new();
93
94        for block in &msg.content {
95            match block {
96                ContentBlock::Text { text } => {
97                    content.push(text_block(text, None));
98                }
99                ContentBlock::Image { source } => match source {
100                    super::ImageSource::Base64 { media_type, data } => {
101                        content.push(json!({
102                            "type": "image",
103                            "source": {
104                                "type": "base64",
105                                "media_type": media_type,
106                                "data": data,
107                            }
108                        }));
109                    }
110                    super::ImageSource::Url { url } => {
111                        content.push(json!({
112                            "type": "image",
113                            "source": {
114                                "type": "url",
115                                "url": url,
116                            }
117                        }));
118                    }
119                },
120                ContentBlock::ToolResult {
121                    tool_use_id,
122                    content: tool_content,
123                    is_error,
124                } => {
125                    let inner: Vec<serde_json::Value> = tool_content
126                        .iter()
127                        .map(|b| match b {
128                            ContentBlock::Text { text } => text_block(text, None),
129                            _ => json!({"type": "text", "text": format!("{:?}", b)}),
130                        })
131                        .collect();
132                    let mut val = json!({
133                        "type": "tool_result",
134                        "tool_use_id": tool_use_id,
135                        "content": inner,
136                    });
137                    if let Some(true) = is_error {
138                        val["is_error"] = json!(true);
139                    }
140                    content.push(val);
141                }
142                ContentBlock::ToolUse { .. } => {}
143                // Anthropic doesn't use the openai-style `reasoning_content`
144                // field; thinking content is handled via the separate
145                // `thinking` API parameter. Drop reasoning blocks here so
146                // they don't leak as text — they're transcript-only.
147                ContentBlock::Reasoning { .. } => {}
148            }
149        }
150
151        messages.push(json!({
152            "role": msg.role,
153            "content": content,
154        }));
155    }
156
157    // Build tools
158    let tools: Vec<serde_json::Value> = if req.tools.is_empty() {
159        vec![]
160    } else {
161        req.tools
162            .iter()
163            .map(|t| {
164                json!({
165                    "name": t.name,
166                    "description": t.description,
167                    "input_schema": t.input_schema,
168                })
169            })
170            .collect()
171    };
172
173    let mut body = json!({
174        "model": model,
175        "max_tokens": req.max_tokens,
176        "temperature": req.temperature,
177        "messages": messages,
178        "stream": true,
179    });
180
181    if let Some(sys) = &system {
182        body["system"] = if req.cache.enabled {
183            json!([text_block(sys, Some(cache_control_value(req)))])
184        } else {
185            json!(sys)
186        };
187    }
188    if !tools.is_empty() {
189        body["tools"] = json!(tools);
190    }
191    if !req.stop.is_empty() {
192        body["stop_sequences"] = json!(req.stop);
193    }
194
195    body
196}
197
198#[async_trait]
199impl Brain for AnthropicAdapter {
200    fn id(&self) -> &str {
201        &self.model
202    }
203
204    fn caps(&self) -> ModelCaps {
205        self.caps.clone()
206    }
207
208    async fn complete(&self, req: BrainRequest) -> anyhow::Result<BrainStream> {
209        let body = build_messages_body(&self.model, &req);
210
211        let response = self
212            .client
213            .post(format!("{}/v1/messages", self.base_url))
214            .header("x-api-key", &self.api_key)
215            .header("anthropic-version", "2023-06-01")
216            .json(&body)
217            .send()
218            .await?;
219
220        if !response.status().is_success() {
221            let status = response.status().as_u16();
222            let body = response.text().await.unwrap_or_default();
223            return Err(anyhow::anyhow!("Anthropic API error {}: {}", status, body));
224        }
225
226        let stream = response.bytes_stream();
227        let model = self.model.clone();
228
229        // Tool-id map per content-block index + line buffer that survives
230        // chunk boundaries (see provider/sse_buffer.rs).
231        struct AnthropicSse {
232            tools: HashMap<u64, String>,
233            lines: super::sse_buffer::LineBuffer,
234        }
235        let event_stream = stream
236            .scan(
237                AnthropicSse {
238                    tools: HashMap::new(),
239                    lines: super::sse_buffer::LineBuffer::new(),
240                },
241                move |state, chunk| {
242                    let _model = model.clone();
243                    let events = match chunk {
244                        Ok(bytes) => {
245                            let lines = state.lines.push(&bytes);
246                            let tool_ids = &mut state.tools;
247                            let mut events = Vec::new();
248                            for line in lines {
249                                let line = line.trim();
250                                if line.is_empty() || !line.starts_with("data: ") {
251                                    continue;
252                                }
253                                let data = &line[6..]; // Strip "data: "
254                                let event: serde_json::Value = match serde_json::from_str(data) {
255                                    Ok(v) => v,
256                                    Err(_) => continue,
257                                };
258
259                                let event_type = event["type"].as_str().unwrap_or("");
260                                match event_type {
261                                    "content_block_start" => {
262                                        let index = event["index"].as_u64().unwrap_or(0);
263                                        let content_type =
264                                            event["content_block"]["type"].as_str().unwrap_or("");
265                                        if content_type == "tool_use" {
266                                            let id = event["content_block"]["id"]
267                                                .as_str()
268                                                .unwrap_or("")
269                                                .to_string();
270                                            let name = event["content_block"]["name"]
271                                                .as_str()
272                                                .unwrap_or("")
273                                                .to_string();
274                                            if !id.is_empty() {
275                                                tool_ids.insert(index, id.clone());
276                                            }
277                                            events.push(BrainEvent::ToolUseStart { id, name });
278                                        }
279                                    }
280                                    "content_block_delta" => {
281                                        let delta_type =
282                                            event["delta"]["type"].as_str().unwrap_or("");
283                                        if delta_type == "text_delta" {
284                                            let text = event["delta"]["text"]
285                                                .as_str()
286                                                .unwrap_or("")
287                                                .to_string();
288                                            events.push(BrainEvent::TextDelta(text));
289                                        } else if delta_type == "input_json_delta" {
290                                            let partial = event["delta"]["partial_json"]
291                                                .as_str()
292                                                .unwrap_or("")
293                                                .to_string();
294                                            let index = event["index"].as_u64().unwrap_or(0);
295                                            let id = tool_ids
296                                                .get(&index)
297                                                .cloned()
298                                                .unwrap_or_else(|| index.to_string());
299                                            events.push(BrainEvent::ToolUseDelta {
300                                                id,
301                                                json: partial,
302                                            });
303                                        }
304                                    }
305                                    "content_block_stop" => {
306                                        let index = event["index"].as_u64().unwrap_or(0);
307                                        let id = tool_ids
308                                            .remove(&index)
309                                            .unwrap_or_else(|| index.to_string());
310                                        events.push(BrainEvent::ToolUseEnd { id });
311                                    }
312                                    "message_delta" => {
313                                        if let Some(usage) = event["usage"].as_object() {
314                                            events.push(BrainEvent::Usage(
315                                                crate::event::TokenUsage {
316                                                    input: usage["input_tokens"]
317                                                        .as_u64()
318                                                        .unwrap_or(0),
319                                                    output: usage["output_tokens"]
320                                                        .as_u64()
321                                                        .unwrap_or(0),
322                                                },
323                                            ));
324                                        }
325                                        let stop_reason = event["delta"]["stop_reason"]
326                                            .as_str()
327                                            .unwrap_or("end_turn");
328                                        let reason = match stop_reason {
329                                            "end_turn" => crate::event::StopReason::EndTurn,
330                                            "max_tokens" => crate::event::StopReason::MaxTokens,
331                                            "tool_use" => crate::event::StopReason::ToolUse,
332                                            s => crate::event::StopReason::StopSequence(
333                                                s.to_string(),
334                                            ),
335                                        };
336                                        events.push(BrainEvent::Done(reason));
337                                    }
338                                    "message_stop" => {}
339                                    _ => {}
340                                }
341                            }
342                            events
343                        }
344                        Err(e) => {
345                            vec![BrainEvent::Error(format!("stream error: {}", e))]
346                        }
347                    };
348                    futures::future::ready(Some(stream::iter(events)))
349                },
350            )
351            .flatten();
352
353        Ok(Box::pin(event_stream))
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{Msg, PromptCacheConfig, PromptCacheTtl};

    /// A request with a one-hour-cached system prompt and a single user turn.
    fn cached_request() -> BrainRequest {
        let user_turn = Msg {
            role: "user".into(),
            content: vec![ContentBlock::Text {
                text: "dynamic task".into(),
            }],
        };
        BrainRequest {
            system: Some("stable sparrow system".into()),
            messages: vec![user_turn],
            cache: PromptCacheConfig {
                enabled: true,
                ttl: PromptCacheTtl::OneHour,
                key: Some("repo-key".into()),
            },
            ..BrainRequest::default()
        }
    }

    #[test]
    fn anthropic_system_prompt_gets_cache_control() {
        let body = build_messages_body("claude-test", &cached_request());

        // The stable system prompt is the cache breakpoint...
        let expected = json!({"type":"ephemeral","ttl":"1h"});
        assert_eq!(body["system"][0]["cache_control"], expected);

        // ...while the dynamic user content stays uncached.
        let first_user_block = &body["messages"][0]["content"][0];
        assert!(first_user_block["cache_control"].is_null());
    }
}