swiftide_integrations/anthropic/
chat_completion.rs1use futures_util::{StreamExt as _, TryStreamExt as _, stream};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{Context as _, Result};
5use async_anthropic::types::{
6 CreateMessagesRequestBuilder, Message, MessageBuilder, MessageContent, MessageContentList,
7 MessageRole, MessagesStreamEvent, ToolChoice, ToolResultBuilder, ToolUseBuilder,
8};
9use async_trait::async_trait;
10use serde_json::{Value, json};
11use swiftide_core::{
12 ChatCompletion, ChatCompletionStream,
13 chat_completion::{
14 ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolSpec, Usage,
15 UsageBuilder, errors::LanguageModelError,
16 },
17};
18
19use super::Anthropic;
20
21#[cfg(feature = "metrics")]
22use swiftide_core::metrics::emit_usage;
23
24#[async_trait]
25impl ChatCompletion for Anthropic {
26 #[tracing::instrument(skip_all, err)]
27 async fn complete(
28 &self,
29 request: &ChatCompletionRequest,
30 ) -> Result<ChatCompletionResponse, LanguageModelError> {
31 let model = &self.default_options.prompt_model;
32 let request = self
33 .build_request(request)
34 .and_then(|b| b.build().map_err(LanguageModelError::permanent))?;
35
36 tracing::debug!(
37 model = &model,
38 messages = serde_json::to_string_pretty(&request).expect("Infallible"),
39 "[ChatCompletion] Request to anthropic"
40 );
41
42 let response = self
43 .client
44 .messages()
45 .create(request)
46 .await
47 .map_err(LanguageModelError::permanent)?;
48
49 tracing::debug!(
50 response = serde_json::to_string_pretty(&response).expect("Infallible"),
51 "[ChatCompletion] Response from anthropic"
52 );
53
54 let maybe_tool_calls = response
55 .messages()
56 .iter()
57 .flat_map(Message::tool_uses)
58 .map(|atool| {
59 ToolCall::builder()
60 .id(atool.id)
61 .name(atool.name)
62 .args(atool.input.to_string())
63 .build()
64 .expect("infallible")
65 })
66 .collect::<Vec<_>>();
67 let maybe_tool_calls = if maybe_tool_calls.is_empty() {
68 None
69 } else {
70 Some(maybe_tool_calls)
71 };
72
73 let mut builder = ChatCompletionResponse::builder()
74 .maybe_message(response.messages().iter().find_map(Message::text))
75 .maybe_tool_calls(maybe_tool_calls)
76 .to_owned();
77
78 if let Some(usage) = &response.usage {
79 let input_tokens = usage.input_tokens.unwrap_or_default();
80 let output_tokens = usage.output_tokens.unwrap_or_default();
81 let total_tokens = input_tokens + output_tokens;
82
83 #[cfg(feature = "metrics")]
84 emit_usage(
85 model,
86 input_tokens.into(),
87 output_tokens.into(),
88 total_tokens.into(),
89 self.metric_metadata.as_ref(),
90 );
91
92 let usage = Usage {
93 prompt_tokens: input_tokens,
94 completion_tokens: output_tokens,
95 total_tokens,
96 };
97 if let Some(callback) = &self.on_usage {
98 callback(&usage).await?;
99 }
100
101 let usage = UsageBuilder::default()
102 .prompt_tokens(input_tokens)
103 .completion_tokens(output_tokens)
104 .total_tokens(total_tokens)
105 .build()
106 .map_err(LanguageModelError::permanent)?;
107
108 builder.usage(usage);
109 }
110 builder.build().map_err(LanguageModelError::from)
111 }
112
113 #[tracing::instrument(skip_all)]
114 async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
115 let model = &self.default_options.prompt_model;
116 let request = match self
117 .build_request(request)
118 .and_then(|b| b.build().map_err(LanguageModelError::permanent))
119 {
120 Ok(request) => request,
121 Err(e) => {
122 return e.into();
123 }
124 };
125
126 tracing::debug!(
127 model = &model,
128 messages = serde_json::to_string_pretty(&request).expect("Infallible"),
129 "[ChatCompletion] Request to anthropic"
130 );
131
132 let response = self.client.messages().create_stream(request).await;
133
134 let accumulating_response = Arc::new(Mutex::new(ChatCompletionResponse::default()));
135 let final_response = Arc::clone(&accumulating_response);
136 #[cfg(feature = "metrics")]
137 let model = model.clone();
138 #[cfg(feature = "metrics")]
139 let metric_metadata = self.metric_metadata.clone();
140
141 let maybe_usage_callback = self.on_usage.clone();
142
143 response
144 .map_ok(move |chunk| {
145 let accumulating_response = Arc::clone(&accumulating_response);
146
147 let mut lock = accumulating_response.lock().unwrap();
148
149 append_delta_from_chunk(&chunk, &mut lock);
150 lock.clone()
151 })
152 .map_err(LanguageModelError::permanent)
153 .chain(
154 stream::iter(vec![final_response]).map(move |final_response| {
155 if let Some(usage) = final_response.lock().unwrap().usage.as_ref() {
156 if let Some(callback) = maybe_usage_callback.as_ref() {
157 let usage = usage.clone();
158 let callback = callback.clone();
159
160 tokio::spawn(async move {
161 if let Err(e) = callback(&usage).await {
162 tracing::error!("Error in on_usage callback: {}", e);
163 }
164 });
165 }
166
167 #[cfg(feature = "metrics")]
168 emit_usage(
169 &model,
170 usage.prompt_tokens.into(),
171 usage.completion_tokens.into(),
172 usage.total_tokens.into(),
173 metric_metadata.as_ref(),
174 );
175 }
176
177 Ok(final_response.lock().unwrap().clone())
178 }),
179 )
180 .boxed()
181 }
182}
183
/// Folds a single Anthropic stream event into the accumulating response.
///
/// Text deltas are appended to the message, tool-use deltas to the tool call
/// at the event's content-block `index`, and usage deltas to the running
/// token counts. Events not handled here (e.g. stop/ping events) are ignored.
#[allow(clippy::collapsible_match)]
fn append_delta_from_chunk(chunk: &MessagesStreamEvent, lock: &mut ChatCompletionResponse) {
    match chunk {
        // A new content block opens: seed the tool call (id + name) or text.
        MessagesStreamEvent::ContentBlockStart {
            index,
            content_block,
        } => match content_block {
            MessageContent::ToolUse(tool_use) => {
                lock.append_tool_call_delta(*index, Some(&tool_use.id), Some(&tool_use.name), None);
            }
            MessageContent::Text(text) => {
                lock.append_message_delta(Some(&text.text));
            }
            // Tool results never appear in assistant output; nothing to do.
            MessageContent::ToolResult(_tool_result) => (),
        },
        // Incremental payload for an already-opened content block.
        MessagesStreamEvent::ContentBlockDelta { index, delta } => match delta {
            async_anthropic::types::ContentBlockDelta::TextDelta { text } => {
                lock.append_message_delta(Some(text));
            }
            // Tool arguments stream in as partial JSON fragments keyed by index.
            async_anthropic::types::ContentBlockDelta::InputJsonDelta { partial_json } => {
                lock.append_tool_call_delta(*index, None, None, Some(partial_json));
            }
        },
        #[allow(clippy::cast_possible_truncation)]
        MessagesStreamEvent::MessageDelta { usage, .. } => {
            if let Some(usage) = usage {
                let input_tokens = usage.input_tokens.unwrap_or_default();
                let output_tokens = usage.output_tokens.unwrap_or_default();
                let total_tokens = input_tokens + output_tokens;
                lock.append_usage_delta(input_tokens, output_tokens, total_tokens);
            }
        }

        MessagesStreamEvent::MessageStart { message, usage } => {
            // NOTE(review): both the event-level `usage` and `message.usage`
            // are folded in below. If async_anthropic ever populates both for
            // the same event this double-counts tokens — confirm they are
            // mutually exclusive in practice.
            if let Some(usage) = usage {
                let input_tokens = usage.input_tokens.unwrap_or_default();
                let output_tokens = usage.output_tokens.unwrap_or_default();
                let total_tokens = input_tokens + output_tokens;
                lock.append_usage_delta(input_tokens, output_tokens, total_tokens);
            }
            if let Some(message_usage) = &message.usage {
                let input_tokens = message_usage.input_tokens.unwrap_or_default();
                let output_tokens = message_usage.output_tokens.unwrap_or_default();
                let total_tokens = input_tokens + output_tokens;
                lock.append_usage_delta(input_tokens, output_tokens, total_tokens);
            }
        }
        _ => {}
    }
}
234
235impl Anthropic {
236 fn build_request(
237 &self,
238 request: &ChatCompletionRequest,
239 ) -> Result<async_anthropic::types::CreateMessagesRequestBuilder, LanguageModelError> {
240 let model = &self.default_options.prompt_model;
241 let mut messages = request.messages().to_vec();
242
243 let maybe_system = messages
244 .iter()
245 .position(ChatMessage::is_system)
246 .map(|idx| messages.remove(idx));
247
248 let messages = messages
249 .iter()
250 .map(message_to_antropic)
251 .collect::<Result<Vec<_>>>()?;
252
253 let mut anthropic_request = CreateMessagesRequestBuilder::default()
254 .model(model)
255 .messages(messages)
256 .to_owned();
257
258 if let Some(ChatMessage::System(system)) = maybe_system {
259 anthropic_request.system(system);
260 }
261
262 if !request.tools_spec.is_empty() {
263 anthropic_request
264 .tools(
265 request
266 .tools_spec()
267 .iter()
268 .map(tools_to_anthropic)
269 .collect::<Result<Vec<_>>>()?,
270 )
271 .tool_choice(ToolChoice::Auto);
272 }
273
274 Ok(anthropic_request)
275 }
276}
277
278#[allow(clippy::items_after_statements)]
279fn message_to_antropic(message: &ChatMessage) -> Result<Message> {
280 let mut builder = MessageBuilder::default().role(MessageRole::User).to_owned();
281
282 use ChatMessage::{Assistant, Summary, System, ToolOutput, User};
283
284 match message {
285 ToolOutput(tool_call, tool_output) => builder.content(
286 ToolResultBuilder::default()
287 .tool_use_id(tool_call.id())
288 .content(tool_output.content().unwrap_or("Success"))
289 .build()?,
290 ),
291 Summary(msg) | System(msg) | User(msg) => builder.content(msg),
292 Assistant(msg, tool_calls) => {
293 builder.role(MessageRole::Assistant);
294
295 let mut content_list: Vec<MessageContent> = Vec::new();
296
297 if let Some(msg) = msg {
298 content_list.push(msg.into());
299 }
300
301 if let Some(tool_calls) = tool_calls {
302 for tool_call in tool_calls {
303 let tool_call = ToolUseBuilder::default()
304 .id(tool_call.id())
305 .name(tool_call.name())
306 .input(tool_call.args().and_then(|v| v.parse::<Value>().ok()))
307 .build()?;
308
309 content_list.push(tool_call.into());
310 }
311 }
312
313 let content_list = MessageContentList(content_list);
314
315 builder.content(content_list)
316 }
317 };
318
319 builder.build().context("Failed to build message")
320}
321
322fn tools_to_anthropic(
323 spec: &ToolSpec,
324) -> Result<serde_json::value::Map<String, serde_json::Value>> {
325 let mut map = json!({
326 "name": &spec.name,
327 "description": &spec.description,
328 })
329 .as_object_mut()
330 .context("Failed to build tool")?
331 .to_owned();
332
333 let schema = match &spec.parameters_schema {
334 Some(schema) => serde_json::to_value(schema)?,
335 None => json!({
336 "type": "object",
337 "properties": {},
338 }),
339 };
340
341 map.insert("input_schema".to_string(), schema);
342
343 Ok(map)
344}
345
#[cfg(test)]
mod tests {

    use super::*;
    use schemars::{JsonSchema, schema_for};
    use swiftide_core::{
        AgentContext, Tool,
        chat_completion::{ChatCompletionRequest, ChatMessage},
    };
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{body_partial_json, method, path},
    };

    // Minimal Tool impl used only for its `tool_spec`; `invoke` is never
    // called in these tests.
    #[derive(Clone)]
    struct FakeTool();

    #[derive(JsonSchema, serde::Serialize, serde::Deserialize)]
    struct LocationArgs {
        location: String,
    }

    #[async_trait]
    impl Tool for FakeTool {
        async fn invoke(
            &self,
            _agent_context: &dyn AgentContext,
            _tool_call: &ToolCall,
        ) -> std::result::Result<
            swiftide_core::chat_completion::ToolOutput,
            swiftide_core::chat_completion::errors::ToolError,
        > {
            // Intentionally unimplemented: tests only exercise the spec.
            todo!()
        }

        fn name(&self) -> std::borrow::Cow<'_, str> {
            "get_weather".into()
        }

        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::builder()
                .description("Gets the weather")
                .name("get_weather")
                .parameters_schema(schema_for!(LocationArgs))
                .build()
                .unwrap()
        }
    }

    // Plain text response: message is extracted, no tool calls reported.
    #[test_log::test(tokio::test)]
    async fn test_complete_without_tools() {
        let mock_server = MockServer::start().await;

        let mock_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
        "content": [{"type": "text", "text": "mocked response"}]
        }));

        Mock::given(method("POST"))
            .and(path("/v1/messages")) .respond_with(mock_response)
            .mount(&mock_server)
            .await;

        let client = async_anthropic::Client::builder()
            .base_url(mock_server.uri())
            .build()
            .unwrap();

        let mut client_builder = Anthropic::builder();
        client_builder.client(client);
        let client = client_builder.build().unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("hello".into())])
            .build()
            .unwrap();

        let result = client.complete(&request).await.unwrap();

        assert_eq!(result.message, Some("mocked response".into()));
        assert!(result.tool_calls.is_none());
    }

    // Response with a tool_use block: both the text message and the tool
    // call (name + JSON args) are mapped onto the response.
    #[test_log::test(tokio::test)]
    async fn test_complete_with_tools() {
        let mock_server = MockServer::start().await;

        let mock_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
        "id": "msg_016zKNb88WhhgBQXhSaQf1rs",
        "content": [
            {
                "type": "text",
                "text": "I'll check the current weather in San Francisco, CA for you."
            },
            {
                "type": "tool_use",
                "id": "toolu_01E1yxpxXU4hBgCMLzPL1FuR",
                "input": {
                    "location": "San Francisco, CA"
                },
                "name": "get_weather"
            }
        ],
        "model": "claude-3-5-sonnet-20241022",
        "stop_reason": "tool_use",
        "stop_sequence": null,
        "usage": {
            "input_tokens": 403,
            "output_tokens": 71
        }
        }));

        Mock::given(method("POST"))
            .and(path("/v1/messages")) .respond_with(mock_response)
            .mount(&mock_server)
            .await;

        let client = async_anthropic::Client::builder()
            .base_url(mock_server.uri())
            .build()
            .unwrap();

        let mut client_builder = Anthropic::builder();
        client_builder.client(client);
        let client = client_builder.build().unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("hello".into())])
            .tool_specs([FakeTool().tool_spec()])
            .build()
            .unwrap();

        let result = client.complete(&request).await.unwrap();

        assert_eq!(
            result.message,
            Some("I'll check the current weather in San Francisco, CA for you.".into())
        );
        assert!(result.tool_calls.is_some());

        let Some(tool_call) = result.tool_calls.and_then(|f| f.first().cloned()) else {
            panic!("No tool call found")
        };
        assert_eq!(tool_call.name(), "get_weather");
        assert_eq!(
            tool_call.args(),
            Some(
                json!({"location": "San Francisco, CA"})
                    .to_string()
                    .as_str()
            )
        );
    }

    // Verifies the system message is lifted into the top-level `system`
    // field (via body_partial_json) and that usage tokens are summed.
    #[test_log::test(tokio::test)]
    async fn test_complete_with_system_prompt() {
        let mock_server = MockServer::start().await;

        let mock_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
        "content": [{"type": "text", "text": "Response with system prompt"}],
        "usage": {
            "input_tokens": 19,
            "output_tokens": 10,
        }
        }));

        Mock::given(method("POST"))
            .and(path("/v1/messages")) .and(body_partial_json(json!({
            "system": "System message",
            "messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]
            })))
            .respond_with(mock_response)
            .mount(&mock_server)
            .await;

        let client = async_anthropic::Client::builder()
            .base_url(mock_server.uri())
            .build()
            .unwrap();

        let mut client_builder = Anthropic::builder();
        client_builder.client(client);
        let client = client_builder.build().unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .build()
            .unwrap();

        let response = client.complete(&request).await.unwrap();

        assert_eq!(response.message, Some("Response with system prompt".into()));

        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 19);
        assert_eq!(usage.completion_tokens, 10);
        assert_eq!(usage.total_tokens, 29);
    }

    // Pure conversion test: the produced map must carry name, description,
    // and the serialized JSON schema under `input_schema`.
    #[test]
    fn test_tools_to_anthropic() {
        let tool_spec = ToolSpec::builder()
            .description("Gets the weather")
            .name("get_weather")
            .parameters_schema(schema_for!(LocationArgs))
            .build()
            .unwrap();

        let result = tools_to_anthropic(&tool_spec).unwrap();
        let expected_schema = serde_json::to_value(schema_for!(LocationArgs)).unwrap();
        let expected = json!({
            "name": "get_weather",
            "description": "Gets the weather",
            "input_schema": expected_schema,
        });

        assert_eq!(serde_json::Value::Object(result), expected);
    }
}
590}