// swiftide_integrations/anthropic/simple_prompt.rs
use anyhow::Context as _;
2use async_anthropic::{errors::AnthropicError, types::CreateMessagesRequestBuilder};
3use async_trait::async_trait;
4use swiftide_core::{
5 chat_completion::{Usage, errors::LanguageModelError},
6 indexing::SimplePrompt,
7};
8
9#[cfg(feature = "metrics")]
10use swiftide_core::metrics::emit_usage;
11
12use super::Anthropic;
13
14#[async_trait]
15impl SimplePrompt for Anthropic {
16 #[tracing::instrument(skip_all, err)]
17 async fn prompt(
18 &self,
19 prompt: swiftide_core::prompt::Prompt,
20 ) -> Result<String, LanguageModelError> {
21 let model = &self.default_options.prompt_model;
22
23 let request = CreateMessagesRequestBuilder::default()
24 .model(model)
25 .messages(vec![prompt.render()?.into()])
26 .build()
27 .map_err(LanguageModelError::permanent)?;
28
29 tracing::debug!(
30 model = &model,
31 messages =
32 serde_json::to_string_pretty(&request).map_err(LanguageModelError::permanent)?,
33 "[SimplePrompt] Request to anthropic"
34 );
35
36 let response = self.client.messages().create(request).await.map_err(|e| {
37 match &e {
38 AnthropicError::NetworkError(_) => LanguageModelError::TransientError(e.into()),
39 _ => LanguageModelError::PermanentError(e.into()),
44 }
45 })?;
46
47 tracing::debug!(
48 response =
49 serde_json::to_string_pretty(&response).map_err(LanguageModelError::permanent)?,
50 "[SimplePrompt] Response from anthropic"
51 );
52
53 if let Some(usage) = response.usage.as_ref() {
54 if let Some(callback) = &self.on_usage {
55 let usage = Usage {
56 prompt_tokens: usage.input_tokens.unwrap_or_default(),
57 completion_tokens: usage.output_tokens.unwrap_or_default(),
58 total_tokens: (usage.input_tokens.unwrap_or_default()
59 + usage.output_tokens.unwrap_or_default()),
60 };
61 callback(&usage).await?;
62 }
63
64 #[cfg(feature = "metrics")]
65 {
66 emit_usage(
67 model,
68 usage.input_tokens.unwrap_or_default().into(),
69 usage.output_tokens.unwrap_or_default().into(),
70 (usage.input_tokens.unwrap_or_default()
71 + usage.output_tokens.unwrap_or_default())
72 .into(),
73 self.metric_metadata.as_ref(),
74 );
75 }
76 }
77
78 let message = response
79 .messages()
80 .into_iter()
81 .next()
82 .context("No messages in response")
83 .map_err(LanguageModelError::permanent)?;
84
85 message
86 .text()
87 .context("No text in response")
88 .map_err(LanguageModelError::permanent)
89 }
90}
91
#[cfg(test)]
mod tests {
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;

    /// `prompt` should return the text content of a mocked Anthropic
    /// messages response.
    #[tokio::test]
    async fn test_simple_prompt_with_mock() {
        // Stand up a fake Anthropic API endpoint.
        let server = MockServer::start().await;

        let body = serde_json::json!({
            "content": [{"type": "text", "text": "mocked response"}]
        });
        let template = ResponseTemplate::new(200).set_body_json(body);

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(template)
            .mount(&server)
            .await;

        // Point the underlying async-anthropic client at the mock server.
        let inner = async_anthropic::Client::builder()
            .base_url(server.uri())
            .build()
            .unwrap();

        let mut builder = Anthropic::builder();
        builder.client(inner);
        let anthropic = builder.build().unwrap();

        let answer = anthropic.prompt("hello".into()).await.unwrap();

        assert_eq!(answer, "mocked response");
    }
}