// swiftide_integrations/anthropic/chat_completion.rs

1use futures_util::{StreamExt as _, TryStreamExt as _, stream};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{Context as _, Result};
5use async_anthropic::types::{
6    CreateMessagesRequestBuilder, Message, MessageBuilder, MessageContent, MessageContentList,
7    MessageRole, MessagesStreamEvent, ToolChoice, ToolResultBuilder, ToolUseBuilder,
8};
9use async_trait::async_trait;
10use serde_json::{Value, json};
11use swiftide_core::{
12    ChatCompletion, ChatCompletionStream,
13    chat_completion::{
14        ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolSpec, Usage,
15        UsageBuilder, errors::LanguageModelError,
16    },
17};
18
19use super::Anthropic;
20
21#[cfg(feature = "metrics")]
22use swiftide_core::metrics::emit_usage;
23
24#[async_trait]
25impl ChatCompletion for Anthropic {
26    #[tracing::instrument(skip_all, err)]
27    async fn complete(
28        &self,
29        request: &ChatCompletionRequest,
30    ) -> Result<ChatCompletionResponse, LanguageModelError> {
31        let model = &self.default_options.prompt_model;
32        let request = self
33            .build_request(request)
34            .and_then(|b| b.build().map_err(LanguageModelError::permanent))?;
35
36        tracing::debug!(
37            model = &model,
38            messages = serde_json::to_string_pretty(&request).expect("Infallible"),
39            "[ChatCompletion] Request to anthropic"
40        );
41
42        let response = self
43            .client
44            .messages()
45            .create(request)
46            .await
47            .map_err(LanguageModelError::permanent)?;
48
49        tracing::debug!(
50            response = serde_json::to_string_pretty(&response).expect("Infallible"),
51            "[ChatCompletion] Response from anthropic"
52        );
53
54        let maybe_tool_calls = response
55            .messages()
56            .iter()
57            .flat_map(Message::tool_uses)
58            .map(|atool| {
59                ToolCall::builder()
60                    .id(atool.id)
61                    .name(atool.name)
62                    .args(atool.input.to_string())
63                    .build()
64                    .expect("infallible")
65            })
66            .collect::<Vec<_>>();
67        let maybe_tool_calls = if maybe_tool_calls.is_empty() {
68            None
69        } else {
70            Some(maybe_tool_calls)
71        };
72
73        let mut builder = ChatCompletionResponse::builder()
74            .maybe_message(response.messages().iter().find_map(Message::text))
75            .maybe_tool_calls(maybe_tool_calls)
76            .to_owned();
77
78        if let Some(usage) = &response.usage {
79            let input_tokens = usage.input_tokens.unwrap_or_default();
80            let output_tokens = usage.output_tokens.unwrap_or_default();
81            let total_tokens = input_tokens + output_tokens;
82
83            #[cfg(feature = "metrics")]
84            emit_usage(
85                model,
86                input_tokens.into(),
87                output_tokens.into(),
88                total_tokens.into(),
89                self.metric_metadata.as_ref(),
90            );
91
92            let usage = Usage {
93                prompt_tokens: input_tokens,
94                completion_tokens: output_tokens,
95                total_tokens,
96            };
97            if let Some(callback) = &self.on_usage {
98                callback(&usage).await?;
99            }
100
101            let usage = UsageBuilder::default()
102                .prompt_tokens(input_tokens)
103                .completion_tokens(output_tokens)
104                .total_tokens(total_tokens)
105                .build()
106                .map_err(LanguageModelError::permanent)?;
107
108            builder.usage(usage);
109        }
110        builder.build().map_err(LanguageModelError::from)
111    }
112
113    #[tracing::instrument(skip_all)]
114    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
115        let model = &self.default_options.prompt_model;
116        let request = match self
117            .build_request(request)
118            .and_then(|b| b.build().map_err(LanguageModelError::permanent))
119        {
120            Ok(request) => request,
121            Err(e) => {
122                return e.into();
123            }
124        };
125
126        tracing::debug!(
127            model = &model,
128            messages = serde_json::to_string_pretty(&request).expect("Infallible"),
129            "[ChatCompletion] Request to anthropic"
130        );
131
132        let response = self.client.messages().create_stream(request).await;
133
134        let accumulating_response = Arc::new(Mutex::new(ChatCompletionResponse::default()));
135        let final_response = Arc::clone(&accumulating_response);
136        #[cfg(feature = "metrics")]
137        let model = model.clone();
138        #[cfg(feature = "metrics")]
139        let metric_metadata = self.metric_metadata.clone();
140
141        let maybe_usage_callback = self.on_usage.clone();
142
143        response
144            .map_ok(move |chunk| {
145                let accumulating_response = Arc::clone(&accumulating_response);
146
147                let mut lock = accumulating_response.lock().unwrap();
148
149                append_delta_from_chunk(&chunk, &mut lock);
150                lock.clone()
151            })
152            .map_err(LanguageModelError::permanent)
153            .chain(
154                stream::iter(vec![final_response]).map(move |final_response| {
155                    if let Some(usage) = final_response.lock().unwrap().usage.as_ref() {
156                        if let Some(callback) = maybe_usage_callback.as_ref() {
157                            let usage = usage.clone();
158                            let callback = callback.clone();
159
160                            tokio::spawn(async move {
161                                if let Err(e) = callback(&usage).await {
162                                    tracing::error!("Error in on_usage callback: {}", e);
163                                }
164                            });
165                        }
166
167                        #[cfg(feature = "metrics")]
168                        emit_usage(
169                            &model,
170                            usage.prompt_tokens.into(),
171                            usage.completion_tokens.into(),
172                            usage.total_tokens.into(),
173                            metric_metadata.as_ref(),
174                        );
175                    }
176
177                    Ok(final_response.lock().unwrap().clone())
178                }),
179            )
180            .boxed()
181    }
182}
183
184#[allow(clippy::collapsible_match)]
185fn append_delta_from_chunk(chunk: &MessagesStreamEvent, lock: &mut ChatCompletionResponse) {
186    match chunk {
187        MessagesStreamEvent::ContentBlockStart {
188            index,
189            content_block,
190        } => match content_block {
191            MessageContent::ToolUse(tool_use) => {
192                lock.append_tool_call_delta(*index, Some(&tool_use.id), Some(&tool_use.name), None);
193            }
194            MessageContent::Text(text) => {
195                lock.append_message_delta(Some(&text.text));
196            }
197            MessageContent::ToolResult(_tool_result) => (),
198        },
199        MessagesStreamEvent::ContentBlockDelta { index, delta } => match delta {
200            async_anthropic::types::ContentBlockDelta::TextDelta { text } => {
201                lock.append_message_delta(Some(text));
202            }
203            async_anthropic::types::ContentBlockDelta::InputJsonDelta { partial_json } => {
204                lock.append_tool_call_delta(*index, None, None, Some(partial_json));
205            }
206        },
207        #[allow(clippy::cast_possible_truncation)]
208        MessagesStreamEvent::MessageDelta { usage, .. } => {
209            if let Some(usage) = usage {
210                let input_tokens = usage.input_tokens.unwrap_or_default();
211                let output_tokens = usage.output_tokens.unwrap_or_default();
212                let total_tokens = input_tokens + output_tokens;
213                lock.append_usage_delta(input_tokens, output_tokens, total_tokens);
214            }
215        }
216
217        MessagesStreamEvent::MessageStart { message, usage } => {
218            if let Some(usage) = usage {
219                let input_tokens = usage.input_tokens.unwrap_or_default();
220                let output_tokens = usage.output_tokens.unwrap_or_default();
221                let total_tokens = input_tokens + output_tokens;
222                lock.append_usage_delta(input_tokens, output_tokens, total_tokens);
223            }
224            if let Some(message_usage) = &message.usage {
225                let input_tokens = message_usage.input_tokens.unwrap_or_default();
226                let output_tokens = message_usage.output_tokens.unwrap_or_default();
227                let total_tokens = input_tokens + output_tokens;
228                lock.append_usage_delta(input_tokens, output_tokens, total_tokens);
229            }
230        }
231        _ => {}
232    }
233}
234
235impl Anthropic {
236    fn build_request(
237        &self,
238        request: &ChatCompletionRequest,
239    ) -> Result<async_anthropic::types::CreateMessagesRequestBuilder, LanguageModelError> {
240        let model = &self.default_options.prompt_model;
241        let mut messages = request.messages().to_vec();
242
243        let maybe_system = messages
244            .iter()
245            .position(ChatMessage::is_system)
246            .map(|idx| messages.remove(idx));
247
248        let messages = messages
249            .iter()
250            .map(message_to_antropic)
251            .collect::<Result<Vec<_>>>()?;
252
253        let mut anthropic_request = CreateMessagesRequestBuilder::default()
254            .model(model)
255            .messages(messages)
256            .to_owned();
257
258        if let Some(ChatMessage::System(system)) = maybe_system {
259            anthropic_request.system(system);
260        }
261
262        if !request.tools_spec.is_empty() {
263            anthropic_request
264                .tools(
265                    request
266                        .tools_spec()
267                        .iter()
268                        .map(tools_to_anthropic)
269                        .collect::<Result<Vec<_>>>()?,
270                )
271                .tool_choice(ToolChoice::Auto);
272        }
273
274        Ok(anthropic_request)
275    }
276}
277
278#[allow(clippy::items_after_statements)]
279fn message_to_antropic(message: &ChatMessage) -> Result<Message> {
280    let mut builder = MessageBuilder::default().role(MessageRole::User).to_owned();
281
282    use ChatMessage::{Assistant, Summary, System, ToolOutput, User};
283
284    match message {
285        ToolOutput(tool_call, tool_output) => builder.content(
286            ToolResultBuilder::default()
287                .tool_use_id(tool_call.id())
288                .content(tool_output.content().unwrap_or("Success"))
289                .build()?,
290        ),
291        Summary(msg) | System(msg) | User(msg) => builder.content(msg),
292        Assistant(msg, tool_calls) => {
293            builder.role(MessageRole::Assistant);
294
295            let mut content_list: Vec<MessageContent> = Vec::new();
296
297            if let Some(msg) = msg {
298                content_list.push(msg.into());
299            }
300
301            if let Some(tool_calls) = tool_calls {
302                for tool_call in tool_calls {
303                    let tool_call = ToolUseBuilder::default()
304                        .id(tool_call.id())
305                        .name(tool_call.name())
306                        .input(tool_call.args().and_then(|v| v.parse::<Value>().ok()))
307                        .build()?;
308
309                    content_list.push(tool_call.into());
310                }
311            }
312
313            let content_list = MessageContentList(content_list);
314
315            builder.content(content_list)
316        }
317    };
318
319    builder.build().context("Failed to build message")
320}
321
322fn tools_to_anthropic(
323    spec: &ToolSpec,
324) -> Result<serde_json::value::Map<String, serde_json::Value>> {
325    let mut map = json!({
326        "name": &spec.name,
327        "description": &spec.description,
328    })
329    .as_object_mut()
330    .context("Failed to build tool")?
331    .to_owned();
332
333    let schema = match &spec.parameters_schema {
334        Some(schema) => serde_json::to_value(schema)?,
335        None => json!({
336            "type": "object",
337            "properties": {},
338        }),
339    };
340
341    map.insert("input_schema".to_string(), schema);
342
343    Ok(map)
344}
345
#[cfg(test)]
mod tests {

    use super::*;
    use schemars::{JsonSchema, schema_for};
    use swiftide_core::{
        AgentContext, Tool,
        chat_completion::{ChatCompletionRequest, ChatMessage},
    };
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{body_partial_json, method, path},
    };

    // Minimal tool stub: only `name` and `tool_spec` matter for these tests;
    // `invoke` is never reached because the mocked server supplies the
    // responses.
    #[derive(Clone)]
    struct FakeTool();

    // Argument schema for the fake weather tool, serialized into the tool
    // spec's `input_schema`.
    #[derive(JsonSchema, serde::Serialize, serde::Deserialize)]
    struct LocationArgs {
        location: String,
    }

    #[async_trait]
    impl Tool for FakeTool {
        async fn invoke(
            &self,
            _agent_context: &dyn AgentContext,
            _tool_call: &ToolCall,
        ) -> std::result::Result<
            swiftide_core::chat_completion::ToolOutput,
            swiftide_core::chat_completion::errors::ToolError,
        > {
            todo!()
        }

        fn name(&self) -> std::borrow::Cow<'_, str> {
            "get_weather".into()
        }

        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::builder()
                .description("Gets the weather")
                .name("get_weather")
                .parameters_schema(schema_for!(LocationArgs))
                .build()
                .unwrap()
        }
    }

    // Plain text completion: the message is populated and no tool calls are
    // reported.
    #[test_log::test(tokio::test)]
    async fn test_complete_without_tools() {
        // Start a wiremock server
        let mock_server = MockServer::start().await;

        // Create a mock response
        let mock_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "content": [{"type": "text", "text": "mocked response"}]
        }));

        // Mock the expected endpoint
        Mock::given(method("POST"))
            .and(path("/v1/messages")) // Adjust path to match expected endpoint
            .respond_with(mock_response)
            .mount(&mock_server)
            .await;

        let client = async_anthropic::Client::builder()
            .base_url(mock_server.uri())
            .build()
            .unwrap();

        // Build an Anthropic client with the mock server's URL
        let mut client_builder = Anthropic::builder();
        client_builder.client(client);
        let client = client_builder.build().unwrap();

        // Prepare a sample request
        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("hello".into())])
            .build()
            .unwrap();

        // Call the complete method
        let result = client.complete(&request).await.unwrap();

        // Assert the result
        assert_eq!(result.message, Some("mocked response".into()));
        assert!(result.tool_calls.is_none());
    }

    // Tool-use completion: both the text content and the `tool_use` block in
    // the mocked body must surface as `message` and `tool_calls`.
    #[test_log::test(tokio::test)]
    async fn test_complete_with_tools() {
        // Start a wiremock server
        let mock_server = MockServer::start().await;

        // Create a mock response
        let mock_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "id": "msg_016zKNb88WhhgBQXhSaQf1rs",
            "content": [
            {
                "type": "text",
                "text": "I'll check the current weather in San Francisco, CA for you."
            },
            {
                "type": "tool_use",
                "id": "toolu_01E1yxpxXU4hBgCMLzPL1FuR",
                "input": {
                "location": "San Francisco, CA"
                },
                "name": "get_weather"
            }
            ],
            "model": "claude-3-5-sonnet-20241022",
            "stop_reason": "tool_use",
            "stop_sequence": null,
            "usage": {
            "input_tokens": 403,
            "output_tokens": 71
            }
        }));

        // Mock the expected endpoint
        Mock::given(method("POST"))
            .and(path("/v1/messages")) // Adjust path to match expected endpoint
            .respond_with(mock_response)
            .mount(&mock_server)
            .await;

        let client = async_anthropic::Client::builder()
            .base_url(mock_server.uri())
            .build()
            .unwrap();

        // Build an Anthropic client with the mock server's URL
        let mut client_builder = Anthropic::builder();
        client_builder.client(client);
        let client = client_builder.build().unwrap();

        // Prepare a sample request
        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("hello".into())])
            .tool_specs([FakeTool().tool_spec()])
            .build()
            .unwrap();

        // Call the complete method
        let result = client.complete(&request).await.unwrap();

        // Assert the result
        assert_eq!(
            result.message,
            Some("I'll check the current weather in San Francisco, CA for you.".into())
        );
        assert!(result.tool_calls.is_some());

        let Some(tool_call) = result.tool_calls.and_then(|f| f.first().cloned()) else {
            panic!("No tool call found")
        };
        assert_eq!(tool_call.name(), "get_weather");
        // Tool-call args are stored as the serialized JSON string of `input`.
        assert_eq!(
            tool_call.args(),
            Some(
                json!({"location": "San Francisco, CA"})
                    .to_string()
                    .as_str()
            )
        );
    }

    // System-prompt handling: the system message must be lifted out of the
    // message list into the request's top-level `system` field (enforced by
    // the body_partial_json matcher), and usage must be accumulated.
    #[test_log::test(tokio::test)]
    async fn test_complete_with_system_prompt() {
        // Start a wiremock server
        let mock_server = MockServer::start().await;

        // Create a mock response
        let mock_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "content": [{"type": "text", "text": "Response with system prompt"}],
            "usage": {
                "input_tokens": 19,
                "output_tokens": 10,
            }
        }));

        // Mock the expected endpoint
        Mock::given(method("POST"))
            .and(path("/v1/messages")) // Adjust path to match expected endpoint
            .and(body_partial_json(json!({
                "system": "System message",
                "messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]
            })))
            .respond_with(mock_response)
            .mount(&mock_server)
            .await;

        let client = async_anthropic::Client::builder()
            .base_url(mock_server.uri())
            .build()
            .unwrap();

        // Build an Anthropic client with the mock server's URL
        let mut client_builder = Anthropic::builder();
        client_builder.client(client);
        let client = client_builder.build().unwrap();

        // Prepare a sample request with a system message
        let request = ChatCompletionRequest::builder()
            .messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .build()
            .unwrap();

        // Call the complete method
        let response = client.complete(&request).await.unwrap();

        // Assert the result
        assert_eq!(response.message, Some("Response with system prompt".into()));

        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 19);
        assert_eq!(usage.completion_tokens, 10);
        assert_eq!(usage.total_tokens, 29);
    }

    // Pure serialization test: the tool spec must map onto Anthropic's
    // { name, description, input_schema } object shape.
    #[test]
    fn test_tools_to_anthropic() {
        let tool_spec = ToolSpec::builder()
            .description("Gets the weather")
            .name("get_weather")
            .parameters_schema(schema_for!(LocationArgs))
            .build()
            .unwrap();

        let result = tools_to_anthropic(&tool_spec).unwrap();
        let expected_schema = serde_json::to_value(schema_for!(LocationArgs)).unwrap();
        let expected = json!({
            "name": "get_weather",
            "description": "Gets the weather",
            "input_schema": expected_schema,
        });

        assert_eq!(serde_json::Value::Object(result), expected);
    }
}