swiftide_integrations/aws_bedrock/
simple_prompt.rs1use anyhow::Result;
2use async_trait::async_trait;
3use aws_sdk_bedrockruntime::primitives::Blob;
4use swiftide_core::{
5 chat_completion::errors::LanguageModelError, indexing::SimplePrompt, prompt::Prompt,
6};
7
8use super::AwsBedrock;
9
10#[async_trait]
11impl SimplePrompt for AwsBedrock {
12 #[tracing::instrument(skip_all, err)]
13 async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
14 let blob = self
15 .model_family
16 .build_request_to_bytes(prompt.render()?, &self.model_config)
17 .map(Blob::new)?;
18
19 let response_bytes = self.client.prompt_u8(&self.model_id, blob).await?;
20
21 tracing::debug!(
22 "Received response: {:?}",
23 std::str::from_utf8(&response_bytes).map_err(LanguageModelError::permanent)
24 );
25
26 self.model_family
27 .output_message_from_bytes(&response_bytes)
28 .map_err(std::convert::Into::into)
29 }
30}
31
#[cfg(test)]
mod test {
    use crate::aws_bedrock::MockBedrockPrompt;
    use crate::aws_bedrock::models::*;

    use super::*;
    use anyhow::Context as _;

    // `test_log` needs no `use`: the `#[test_log::test]` attribute path resolves
    // through the extern prelude, so a bare `use test_log;` is redundant.

    /// Prompting a Titan-family model returns the text from the mocked Titan response.
    #[test_log::test(tokio::test)]
    async fn test_prompt_with_titan() {
        let mut bedrock_mock = MockBedrockPrompt::new();

        bedrock_mock.expect_prompt_u8().once().returning(|_, _| {
            serde_json::to_vec(&TitanResponse {
                input_text_token_count: 1,
                results: vec![TitanTextResult {
                    output_text: "Hello, world!".to_string(),
                    token_count: 1,
                    completion_reason: "STOP".to_string(),
                }],
            })
            .context("Failed to serialize response")
        });

        let bedrock = AwsBedrock::build_titan_family("my_model")
            .test_client(bedrock_mock)
            .build()
            .unwrap();

        let response = bedrock.prompt("Hello".into()).await.unwrap();

        assert_eq!(response, "Hello, world!");
    }

    /// Prompting an Anthropic-family model returns the text from the mocked
    /// Anthropic response.
    #[test_log::test(tokio::test)]
    async fn test_prompt_with_anthropic() {
        let mut bedrock_mock = MockBedrockPrompt::new();

        bedrock_mock.expect_prompt_u8().once().returning(|_, _| {
            serde_json::to_vec(&AnthropicResponse {
                content: vec![AnthropicMessageContent {
                    _type: "text".to_string(),
                    text: "Hello, world!".to_string(),
                }],
                id: "id".to_string(),
                model: "model".to_string(),
                // Mirror a real Messages API response: the envelope type is
                // "message" and the responding role is "assistant".
                _type: "message".to_string(),
                role: "assistant".to_string(),
                stop_reason: Some("max_tokens".to_string()),
                stop_sequence: None,
                usage: AnthropicUsage {
                    input_tokens: 10,
                    output_tokens: 10,
                },
            })
            .context("Failed to serialize response")
        });

        let bedrock = AwsBedrock::build_anthropic_family("my_model")
            .test_client(bedrock_mock)
            .build()
            .unwrap();

        let response = bedrock.prompt("Hello".into()).await.unwrap();

        assert_eq!(response, "Hello, world!");
    }
}