1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6 AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
7 TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9
10use synaptic_models::{ProviderBackend, ProviderRequest, ProviderResponse};
11
/// Static configuration for talking to the Anthropic Messages API.
#[derive(Debug, Clone)]
pub struct AnthropicConfig {
    /// API key sent in the `x-api-key` header.
    pub api_key: String,
    /// Model identifier (e.g. a Claude model name).
    pub model: String,
    /// API origin; overridable for proxies or test servers.
    pub base_url: String,
    /// Maximum number of tokens the model may generate.
    pub max_tokens: u32,
    /// Optional nucleus-sampling parameter.
    pub top_p: Option<f64>,
    /// Optional stop sequences (sent as `stop_sequences`).
    pub stop: Option<Vec<String>>,
}

impl AnthropicConfig {
    /// Builds a configuration with the production base URL and a
    /// default generation budget of 1024 tokens.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
        const DEFAULT_MAX_TOKENS: u32 = 1024;
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: DEFAULT_BASE_URL.to_string(),
            max_tokens: DEFAULT_MAX_TOKENS,
            top_p: None,
            stop: None,
        }
    }

    /// Overrides the API origin (builder style).
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Overrides the generation token budget (builder style).
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Enables nucleus sampling with the given `top_p` (builder style).
    pub fn with_top_p(mut self, top_p: f64) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets the stop sequences (builder style).
    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }
}
54
/// `ChatModel` implementation backed by the Anthropic Messages API.
pub struct AnthropicChatModel {
    // Request settings: credentials, model id, sampling parameters.
    config: AnthropicConfig,
    // Transport that performs the actual HTTP exchange.
    backend: Arc<dyn ProviderBackend>,
}
59
60impl AnthropicChatModel {
61 pub fn new(config: AnthropicConfig, backend: Arc<dyn ProviderBackend>) -> Self {
62 Self { config, backend }
63 }
64
65 fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
66 let mut system_text: Option<String> = None;
67 let mut messages: Vec<Value> = Vec::new();
68
69 for msg in &request.messages {
70 match msg {
71 Message::System { content, .. } => {
72 system_text = Some(content.clone());
73 }
74 Message::Human { content, .. } => {
75 messages.push(json!({
76 "role": "user",
77 "content": content,
78 }));
79 }
80 Message::AI {
81 content,
82 tool_calls,
83 ..
84 } => {
85 let mut content_blocks: Vec<Value> = Vec::new();
86 if !content.is_empty() {
87 content_blocks.push(json!({
88 "type": "text",
89 "text": content,
90 }));
91 }
92 for tc in tool_calls {
93 content_blocks.push(json!({
94 "type": "tool_use",
95 "id": tc.id,
96 "name": tc.name,
97 "input": tc.arguments,
98 }));
99 }
100 messages.push(json!({
101 "role": "assistant",
102 "content": content_blocks,
103 }));
104 }
105 Message::Tool {
106 content,
107 tool_call_id,
108 ..
109 } => {
110 messages.push(json!({
111 "role": "user",
112 "content": [{
113 "type": "tool_result",
114 "tool_use_id": tool_call_id,
115 "content": content,
116 }],
117 }));
118 }
119 Message::Chat {
120 custom_role,
121 content,
122 ..
123 } => {
124 messages.push(json!({
125 "role": custom_role,
126 "content": content,
127 }));
128 }
129 Message::Remove { .. } => { }
130 }
131 }
132
133 let mut body = json!({
134 "model": self.config.model,
135 "max_tokens": self.config.max_tokens,
136 "messages": messages,
137 "stream": stream,
138 });
139
140 if let Some(system) = system_text {
141 body["system"] = json!(system);
142 }
143
144 if let Some(top_p) = self.config.top_p {
145 body["top_p"] = json!(top_p);
146 }
147 if let Some(ref stop) = self.config.stop {
148 body["stop_sequences"] = json!(stop);
149 }
150
151 if !request.tools.is_empty() {
152 body["tools"] = json!(request
153 .tools
154 .iter()
155 .map(tool_def_to_anthropic)
156 .collect::<Vec<_>>());
157 }
158 if let Some(ref choice) = request.tool_choice {
159 body["tool_choice"] = match choice {
160 ToolChoice::Auto => json!({"type": "auto"}),
161 ToolChoice::Required => json!({"type": "any"}),
162 ToolChoice::None => json!({"type": "none"}),
163 ToolChoice::Specific(name) => json!({"type": "tool", "name": name}),
164 };
165 }
166
167 ProviderRequest {
168 url: format!("{}/v1/messages", self.config.base_url),
169 headers: vec![
170 ("x-api-key".to_string(), self.config.api_key.clone()),
171 ("anthropic-version".to_string(), "2023-06-01".to_string()),
172 ("Content-Type".to_string(), "application/json".to_string()),
173 ],
174 body,
175 }
176 }
177}
178
179fn tool_def_to_anthropic(def: &ToolDefinition) -> Value {
180 json!({
181 "name": def.name,
182 "description": def.description,
183 "input_schema": def.parameters,
184 })
185}
186
187fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapticError> {
188 check_error_status(resp)?;
189
190 let content_blocks = resp.body["content"].as_array().cloned().unwrap_or_default();
191
192 let mut text = String::new();
193 let mut tool_calls = Vec::new();
194
195 for block in &content_blocks {
196 match block["type"].as_str() {
197 Some("text") => {
198 if let Some(t) = block["text"].as_str() {
199 text.push_str(t);
200 }
201 }
202 Some("tool_use") => {
203 if let (Some(id), Some(name)) = (block["id"].as_str(), block["name"].as_str()) {
204 tool_calls.push(ToolCall {
205 id: id.to_string(),
206 name: name.to_string(),
207 arguments: block["input"].clone(),
208 });
209 }
210 }
211 _ => {}
212 }
213 }
214
215 let usage = parse_usage(&resp.body["usage"]);
216
217 let message = if tool_calls.is_empty() {
218 Message::ai(text)
219 } else {
220 Message::ai_with_tool_calls(text, tool_calls)
221 };
222
223 Ok(ChatResponse { message, usage })
224}
225
226fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapticError> {
227 if resp.status == 429 {
228 let msg = resp.body["error"]["message"]
229 .as_str()
230 .unwrap_or("rate limited")
231 .to_string();
232 return Err(SynapticError::RateLimit(msg));
233 }
234 if resp.status >= 400 {
235 let msg = resp.body["error"]["message"]
236 .as_str()
237 .unwrap_or("unknown API error")
238 .to_string();
239 return Err(SynapticError::Model(format!(
240 "Anthropic API error ({}): {}",
241 resp.status, msg
242 )));
243 }
244 Ok(())
245}
246
247fn parse_usage(usage: &Value) -> Option<TokenUsage> {
248 if usage.is_null() {
249 return None;
250 }
251 Some(TokenUsage {
252 input_tokens: usage["input_tokens"].as_u64().unwrap_or(0) as u32,
253 output_tokens: usage["output_tokens"].as_u64().unwrap_or(0) as u32,
254 total_tokens: (usage["input_tokens"].as_u64().unwrap_or(0)
255 + usage["output_tokens"].as_u64().unwrap_or(0)) as u32,
256 input_details: None,
257 output_details: None,
258 })
259}
260
261fn parse_stream_event(event_type: &str, data: &str) -> Option<AIMessageChunk> {
262 let v: Value = serde_json::from_str(data).ok()?;
263
264 match event_type {
265 "content_block_delta" => {
266 let delta = &v["delta"];
267 match delta["type"].as_str() {
268 Some("text_delta") => Some(AIMessageChunk {
269 content: delta["text"].as_str().unwrap_or("").to_string(),
270 ..Default::default()
271 }),
272 Some("input_json_delta") => {
273 None
276 }
277 _ => None,
278 }
279 }
280 "content_block_start" => {
281 let block = &v["content_block"];
282 if block["type"].as_str() == Some("tool_use") {
283 let id = block["id"].as_str().unwrap_or("").to_string();
284 let name = block["name"].as_str().unwrap_or("").to_string();
285 Some(AIMessageChunk {
286 tool_calls: vec![ToolCall {
287 id,
288 name,
289 arguments: block["input"].clone(),
290 }],
291 ..Default::default()
292 })
293 } else {
294 None
295 }
296 }
297 "message_delta" => {
298 let usage = parse_usage(&v["usage"]);
299 if usage.is_some() {
300 Some(AIMessageChunk {
301 usage,
302 ..Default::default()
303 })
304 } else {
305 None
306 }
307 }
308 _ => None,
309 }
310}
311
#[async_trait]
impl ChatModel for AnthropicChatModel {
    /// Sends a single, non-streaming chat request and parses the reply.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
        let provider_req = self.build_request(&request, false);
        let resp = self.backend.send(provider_req).await?;
        parse_response(&resp)
    }

    /// Streams a chat completion as a sequence of `AIMessageChunk`s.
    ///
    /// The backend's byte stream is decoded as server-sent events and each
    /// recognized event is mapped through `parse_stream_event`. The stream
    /// ends on Anthropic's `message_stop` event; the first SSE decoding
    /// error is yielded and then terminates the stream.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        Box::pin(async_stream::stream! {
            let provider_req = self.build_request(&request, true);
            let byte_stream = self.backend.send_stream(provider_req).await;

            // Connection-level failures surface as a single error item.
            let byte_stream = match byte_stream {
                Ok(s) => s,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            use eventsource_stream::Eventsource;
            use futures::StreamExt;

            // Backend errors are mapped to io::Error so the stream's error
            // type satisfies the eventsource decoder.
            let mut event_stream = byte_stream
                .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
                .eventsource();

            while let Some(event) = event_stream.next().await {
                match event {
                    Ok(ev) => {
                        // `message_stop` marks the end of the completion.
                        if ev.event == "message_stop" {
                            break;
                        }
                        // Events that don't map to a chunk are skipped.
                        if let Some(chunk) = parse_stream_event(&ev.event, &ev.data) {
                            yield Ok(chunk);
                        }
                    }
                    Err(e) => {
                        yield Err(SynapticError::Model(format!("SSE parse error: {e}")));
                        break;
                    }
                }
            }
        })
    }
}