Skip to main content

synaptic_anthropic/
chat_model.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6    AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
7    TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9
10use synaptic_models::{ProviderBackend, ProviderRequest, ProviderResponse};
11
/// Connection and sampling settings for [`AnthropicChatModel`].
///
/// Create one with [`AnthropicConfig::new`] and refine it with the `with_*`
/// builder methods, each of which consumes and returns the config.
#[derive(Debug, Clone)]
pub struct AnthropicConfig {
    /// API key, sent in the `x-api-key` header.
    pub api_key: String,
    /// Model identifier placed in the request body's `model` field.
    pub model: String,
    /// Scheme + host the `/v1/messages` path is appended to.
    pub base_url: String,
    /// Value sent as the request body's `max_tokens` field.
    pub max_tokens: u32,
    /// Optional `top_p`; omitted from the request when `None`.
    pub top_p: Option<f64>,
    /// Optional stop strings, sent as `stop_sequences` when set.
    pub stop: Option<Vec<String>>,
}

impl AnthropicConfig {
    /// Builds a config for `model` authenticated with `api_key`.
    ///
    /// Defaults: base URL `https://api.anthropic.com`, `max_tokens` 1024,
    /// no `top_p`, no stop sequences.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        Self {
            api_key,
            model,
            base_url: String::from("https://api.anthropic.com"),
            max_tokens: 1024,
            top_p: None,
            stop: None,
        }
    }

    /// Replaces the API base URL (useful for proxies and test servers).
    pub fn with_base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: url.into(),
            ..self
        }
    }

    /// Replaces the `max_tokens` limit.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self { max_tokens, ..self }
    }

    /// Sets `top_p`.
    pub fn with_top_p(self, top_p: f64) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Sets the stop sequences.
    pub fn with_stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }
}
54
55pub struct AnthropicChatModel {
56    config: AnthropicConfig,
57    backend: Arc<dyn ProviderBackend>,
58}
59
60impl AnthropicChatModel {
61    pub fn new(config: AnthropicConfig, backend: Arc<dyn ProviderBackend>) -> Self {
62        Self { config, backend }
63    }
64
65    fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
66        let mut system_text: Option<String> = None;
67        let mut messages: Vec<Value> = Vec::new();
68
69        for msg in &request.messages {
70            match msg {
71                Message::System { content, .. } => {
72                    system_text = Some(content.clone());
73                }
74                Message::Human { content, .. } => {
75                    messages.push(json!({
76                        "role": "user",
77                        "content": content,
78                    }));
79                }
80                Message::AI {
81                    content,
82                    tool_calls,
83                    ..
84                } => {
85                    let mut content_blocks: Vec<Value> = Vec::new();
86                    if !content.is_empty() {
87                        content_blocks.push(json!({
88                            "type": "text",
89                            "text": content,
90                        }));
91                    }
92                    for tc in tool_calls {
93                        content_blocks.push(json!({
94                            "type": "tool_use",
95                            "id": tc.id,
96                            "name": tc.name,
97                            "input": tc.arguments,
98                        }));
99                    }
100                    messages.push(json!({
101                        "role": "assistant",
102                        "content": content_blocks,
103                    }));
104                }
105                Message::Tool {
106                    content,
107                    tool_call_id,
108                    ..
109                } => {
110                    messages.push(json!({
111                        "role": "user",
112                        "content": [{
113                            "type": "tool_result",
114                            "tool_use_id": tool_call_id,
115                            "content": content,
116                        }],
117                    }));
118                }
119                Message::Chat {
120                    custom_role,
121                    content,
122                    ..
123                } => {
124                    messages.push(json!({
125                        "role": custom_role,
126                        "content": content,
127                    }));
128                }
129                Message::Remove { .. } => { /* skip Remove messages */ }
130            }
131        }
132
133        let mut body = json!({
134            "model": self.config.model,
135            "max_tokens": self.config.max_tokens,
136            "messages": messages,
137            "stream": stream,
138        });
139
140        if let Some(system) = system_text {
141            body["system"] = json!(system);
142        }
143
144        if let Some(top_p) = self.config.top_p {
145            body["top_p"] = json!(top_p);
146        }
147        if let Some(ref stop) = self.config.stop {
148            body["stop_sequences"] = json!(stop);
149        }
150
151        if !request.tools.is_empty() {
152            body["tools"] = json!(request
153                .tools
154                .iter()
155                .map(tool_def_to_anthropic)
156                .collect::<Vec<_>>());
157        }
158        if let Some(ref choice) = request.tool_choice {
159            body["tool_choice"] = match choice {
160                ToolChoice::Auto => json!({"type": "auto"}),
161                ToolChoice::Required => json!({"type": "any"}),
162                ToolChoice::None => json!({"type": "none"}),
163                ToolChoice::Specific(name) => json!({"type": "tool", "name": name}),
164            };
165        }
166
167        ProviderRequest {
168            url: format!("{}/v1/messages", self.config.base_url),
169            headers: vec![
170                ("x-api-key".to_string(), self.config.api_key.clone()),
171                ("anthropic-version".to_string(), "2023-06-01".to_string()),
172                ("Content-Type".to_string(), "application/json".to_string()),
173            ],
174            body,
175        }
176    }
177}
178
179fn tool_def_to_anthropic(def: &ToolDefinition) -> Value {
180    json!({
181        "name": def.name,
182        "description": def.description,
183        "input_schema": def.parameters,
184    })
185}
186
187fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapticError> {
188    check_error_status(resp)?;
189
190    let content_blocks = resp.body["content"].as_array().cloned().unwrap_or_default();
191
192    let mut text = String::new();
193    let mut tool_calls = Vec::new();
194
195    for block in &content_blocks {
196        match block["type"].as_str() {
197            Some("text") => {
198                if let Some(t) = block["text"].as_str() {
199                    text.push_str(t);
200                }
201            }
202            Some("tool_use") => {
203                if let (Some(id), Some(name)) = (block["id"].as_str(), block["name"].as_str()) {
204                    tool_calls.push(ToolCall {
205                        id: id.to_string(),
206                        name: name.to_string(),
207                        arguments: block["input"].clone(),
208                    });
209                }
210            }
211            _ => {}
212        }
213    }
214
215    let usage = parse_usage(&resp.body["usage"]);
216
217    let message = if tool_calls.is_empty() {
218        Message::ai(text)
219    } else {
220        Message::ai_with_tool_calls(text, tool_calls)
221    };
222
223    Ok(ChatResponse { message, usage })
224}
225
226fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapticError> {
227    if resp.status == 429 {
228        let msg = resp.body["error"]["message"]
229            .as_str()
230            .unwrap_or("rate limited")
231            .to_string();
232        return Err(SynapticError::RateLimit(msg));
233    }
234    if resp.status >= 400 {
235        let msg = resp.body["error"]["message"]
236            .as_str()
237            .unwrap_or("unknown API error")
238            .to_string();
239        return Err(SynapticError::Model(format!(
240            "Anthropic API error ({}): {}",
241            resp.status, msg
242        )));
243    }
244    Ok(())
245}
246
247fn parse_usage(usage: &Value) -> Option<TokenUsage> {
248    if usage.is_null() {
249        return None;
250    }
251    Some(TokenUsage {
252        input_tokens: usage["input_tokens"].as_u64().unwrap_or(0) as u32,
253        output_tokens: usage["output_tokens"].as_u64().unwrap_or(0) as u32,
254        total_tokens: (usage["input_tokens"].as_u64().unwrap_or(0)
255            + usage["output_tokens"].as_u64().unwrap_or(0)) as u32,
256        input_details: None,
257        output_details: None,
258    })
259}
260
261fn parse_stream_event(event_type: &str, data: &str) -> Option<AIMessageChunk> {
262    let v: Value = serde_json::from_str(data).ok()?;
263
264    match event_type {
265        "content_block_delta" => {
266            let delta = &v["delta"];
267            match delta["type"].as_str() {
268                Some("text_delta") => Some(AIMessageChunk {
269                    content: delta["text"].as_str().unwrap_or("").to_string(),
270                    ..Default::default()
271                }),
272                Some("input_json_delta") => {
273                    // Tool input streaming — we accumulate partial JSON
274                    // For simplicity, we emit tool_calls once in content_block_start
275                    None
276                }
277                _ => None,
278            }
279        }
280        "content_block_start" => {
281            let block = &v["content_block"];
282            if block["type"].as_str() == Some("tool_use") {
283                let id = block["id"].as_str().unwrap_or("").to_string();
284                let name = block["name"].as_str().unwrap_or("").to_string();
285                Some(AIMessageChunk {
286                    tool_calls: vec![ToolCall {
287                        id,
288                        name,
289                        arguments: block["input"].clone(),
290                    }],
291                    ..Default::default()
292                })
293            } else {
294                None
295            }
296        }
297        "message_delta" => {
298            let usage = parse_usage(&v["usage"]);
299            if usage.is_some() {
300                Some(AIMessageChunk {
301                    usage,
302                    ..Default::default()
303                })
304            } else {
305                None
306            }
307        }
308        _ => None,
309    }
310}
311
312#[async_trait]
313impl ChatModel for AnthropicChatModel {
314    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
315        let provider_req = self.build_request(&request, false);
316        let resp = self.backend.send(provider_req).await?;
317        parse_response(&resp)
318    }
319
320    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
321        Box::pin(async_stream::stream! {
322            let provider_req = self.build_request(&request, true);
323            let byte_stream = self.backend.send_stream(provider_req).await;
324
325            let byte_stream = match byte_stream {
326                Ok(s) => s,
327                Err(e) => {
328                    yield Err(e);
329                    return;
330                }
331            };
332
333            use eventsource_stream::Eventsource;
334            use futures::StreamExt;
335
336            let mut event_stream = byte_stream
337                .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
338                .eventsource();
339
340            while let Some(event) = event_stream.next().await {
341                match event {
342                    Ok(ev) => {
343                        if ev.event == "message_stop" {
344                            break;
345                        }
346                        if let Some(chunk) = parse_stream_event(&ev.event, &ev.data) {
347                            yield Ok(chunk);
348                        }
349                    }
350                    Err(e) => {
351                        yield Err(SynapticError::Model(format!("SSE parse error: {e}")));
352                        break;
353                    }
354                }
355            }
356        })
357    }
358}