swiftide_integrations/anthropic/
simple_prompt.rs

1use anyhow::Context as _;
2use async_anthropic::{errors::AnthropicError, types::CreateMessagesRequestBuilder};
3use async_trait::async_trait;
4use swiftide_core::{
5    chat_completion::{Usage, errors::LanguageModelError},
6    indexing::SimplePrompt,
7};
8
9#[cfg(feature = "metrics")]
10use swiftide_core::metrics::emit_usage;
11
12use super::Anthropic;
13
14#[async_trait]
15impl SimplePrompt for Anthropic {
16    #[tracing::instrument(skip_all, err)]
17    async fn prompt(
18        &self,
19        prompt: swiftide_core::prompt::Prompt,
20    ) -> Result<String, LanguageModelError> {
21        let model = &self.default_options.prompt_model;
22
23        let request = CreateMessagesRequestBuilder::default()
24            .model(model)
25            .messages(vec![prompt.render()?.into()])
26            .build()
27            .map_err(LanguageModelError::permanent)?;
28
29        tracing::debug!(
30            model = &model,
31            messages =
32                serde_json::to_string_pretty(&request).map_err(LanguageModelError::permanent)?,
33            "[SimplePrompt] Request to anthropic"
34        );
35
36        let response = self.client.messages().create(request).await.map_err(|e| {
37            match &e {
38                AnthropicError::NetworkError(_) => LanguageModelError::TransientError(e.into()),
39                // TODO: The Rust Anthropic client is not documented well, we should figure out
40                // which of these errors are client errors and which are server errors.
41                // And which would be the ContextLengthExceeded error
42                // For now, we'll just map all of them to client errors so we get feedback.
43                _ => LanguageModelError::PermanentError(e.into()),
44            }
45        })?;
46
47        tracing::debug!(
48            response =
49                serde_json::to_string_pretty(&response).map_err(LanguageModelError::permanent)?,
50            "[SimplePrompt] Response from anthropic"
51        );
52
53        if let Some(usage) = response.usage.as_ref() {
54            if let Some(callback) = &self.on_usage {
55                let usage = Usage {
56                    prompt_tokens: usage.input_tokens.unwrap_or_default(),
57                    completion_tokens: usage.output_tokens.unwrap_or_default(),
58                    total_tokens: (usage.input_tokens.unwrap_or_default()
59                        + usage.output_tokens.unwrap_or_default()),
60                };
61                callback(&usage).await?;
62            }
63
64            #[cfg(feature = "metrics")]
65            {
66                emit_usage(
67                    model,
68                    usage.input_tokens.unwrap_or_default().into(),
69                    usage.output_tokens.unwrap_or_default().into(),
70                    (usage.input_tokens.unwrap_or_default()
71                        + usage.output_tokens.unwrap_or_default())
72                    .into(),
73                    self.metric_metadata.as_ref(),
74                );
75            }
76        }
77
78        let message = response
79            .messages()
80            .into_iter()
81            .next()
82            .context("No messages in response")
83            .map_err(LanguageModelError::permanent)?;
84
85        message
86            .text()
87            .context("No text in response")
88            .map_err(LanguageModelError::permanent)
89    }
90}
91
#[cfg(test)]
mod tests {
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;

    /// Runs `prompt` against a local mock server that answers `POST /v1/messages` with a single
    /// text content block, and checks that the returned string is that block's text.
    #[tokio::test]
    async fn test_simple_prompt_with_mock() {
        let server = MockServer::start().await;

        // Body the mocked messages endpoint replies with.
        let body = serde_json::json!({
            "content": [{"type": "text", "text": "mocked response"}]
        });

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&server)
            .await;

        // Point the underlying API client at the mock server instead of the real service.
        let api_client = async_anthropic::Client::builder()
            .base_url(server.uri())
            .build()
            .unwrap();

        let mut builder = Anthropic::builder();
        builder.client(api_client);
        let anthropic = builder.build().unwrap();

        let answer = anthropic.prompt("hello".into()).await.unwrap();

        assert_eq!(answer, "mocked response");
    }
}