// swiftide_integrations/aws_bedrock/simple_prompt.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use aws_sdk_bedrockruntime::primitives::Blob;
4use swiftide_core::{
5    chat_completion::errors::LanguageModelError, indexing::SimplePrompt, prompt::Prompt,
6};
7
8use super::AwsBedrock;
9
10#[async_trait]
11impl SimplePrompt for AwsBedrock {
12    #[tracing::instrument(skip_all, err)]
13    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
14        let blob = self
15            .model_family
16            .build_request_to_bytes(prompt.render()?, &self.model_config)
17            .map(Blob::new)?;
18
19        let response_bytes = self.client.prompt_u8(&self.model_id, blob).await?;
20
21        tracing::debug!(
22            "Received response: {:?}",
23            std::str::from_utf8(&response_bytes).map_err(LanguageModelError::permanent)
24        );
25
26        self.model_family
27            .output_message_from_bytes(&response_bytes)
28            .map_err(std::convert::Into::into)
29    }
30}
31
#[cfg(test)]
mod test {
    use crate::aws_bedrock::MockBedrockPrompt;
    use crate::aws_bedrock::models::*;

    use super::*;
    use anyhow::Context as _;

    /// A mocked Titan-family response whose result text is "Hello, world!" should be
    /// returned by `prompt` as exactly that string.
    #[test_log::test(tokio::test)]
    async fn test_prompt_with_titan() {
        let mut bedrock_mock = MockBedrockPrompt::new();

        // The mock ignores its inputs and always answers with a serialized Titan response.
        bedrock_mock.expect_prompt_u8().once().returning(|_, _| {
            serde_json::to_vec(&TitanResponse {
                input_text_token_count: 1,
                results: vec![TitanTextResult {
                    output_text: "Hello, world!".to_string(),
                    token_count: 1,
                    completion_reason: "STOP".to_string(),
                }],
            })
            .context("Failed to serialize response")
        });

        let bedrock = AwsBedrock::build_titan_family("my_model")
            .test_client(bedrock_mock)
            .build()
            .unwrap();

        let response = bedrock.prompt("Hello".into()).await.unwrap();

        assert_eq!(response, "Hello, world!");
    }

    /// A mocked Anthropic-family response whose content text is "Hello, world!" should be
    /// returned by `prompt` as exactly that string.
    #[test_log::test(tokio::test)]
    async fn test_prompt_with_anthropic() {
        let mut bedrock_mock = MockBedrockPrompt::new();

        // The mock ignores its inputs and always answers with a serialized Anthropic response.
        bedrock_mock.expect_prompt_u8().once().returning(|_, _| {
            serde_json::to_vec(&AnthropicResponse {
                content: vec![AnthropicMessageContent {
                    _type: "text".to_string(),
                    text: "Hello, world!".to_string(),
                }],
                id: "id".to_string(),
                model: "model".to_string(),
                _type: "text".to_string(),
                role: "user".to_string(),
                stop_reason: Some("max_tokens".to_string()),
                stop_sequence: None,
                usage: AnthropicUsage {
                    input_tokens: 10,
                    output_tokens: 10,
                },
            })
            .context("Failed to serialize response")
        });

        let bedrock = AwsBedrock::build_anthropic_family("my_model")
            .test_client(bedrock_mock)
            .build()
            .unwrap();

        let response = bedrock.prompt("Hello".into()).await.unwrap();

        assert_eq!(response, "Hello, world!");
    }
}