swiftide_integrations/aws_bedrock/
mod.rs

1//! An integration with the AWS Bedrock service.
2//!
3//! Supports various model families for prompting.
4use std::sync::Arc;
5
6use anyhow::Result;
7use async_trait::async_trait;
8use aws_sdk_bedrockruntime::{Client, error::SdkError, primitives::Blob};
9use derive_builder::Builder;
10use serde::Serialize;
11use tokio::runtime::Handle;
12
13#[cfg(test)]
14use mockall::{automock, predicate::*};
15
16mod models;
17mod simple_prompt;
18
19/// An integration with the AWS Bedrock service.
20///
21/// Can be used as `SimplePrompt`.
22///
23/// To use Bedrock, you need to have a model id and access to the service.
24/// By default, the aws sdk will be configured from the environment.
25/// If you have the aws cli properly configured with a region set, it should work out of the box.
26///
27/// Otherwise, you can use the builder for customization.
28///
29/// See the aws cli documentation for more information on how to get access to the service.
30#[derive(Debug, Builder)]
31#[builder(setter(strip_option))]
32pub struct AwsBedrock {
33    #[builder(setter(into))]
34    /// The model id or arn of the model to use
35    model_id: String,
36
37    #[builder(default = self.default_client(), setter(custom))]
38    /// The bedrock runtime client
39    client: Arc<dyn BedrockPrompt>,
40    #[builder(default)]
41    /// The model configuration to use
42    model_config: ModelConfig,
43    /// The model family to use. In bedrock, families share their api.
44    model_family: ModelFamily,
45}
46
47#[cfg_attr(test, automock)]
48#[async_trait]
49trait BedrockPrompt: std::fmt::Debug + Send + Sync {
50    async fn prompt_u8(&self, model_id: &str, blob: Blob) -> Result<Vec<u8>>;
51}
52
53#[async_trait]
54impl BedrockPrompt for Client {
55    async fn prompt_u8(&self, model_id: &str, blob: Blob) -> Result<Vec<u8>> {
56        let response = self
57            .invoke_model()
58            .body(blob)
59            .model_id(model_id)
60            .send()
61            .await
62            .map_err(SdkError::into_service_error)?;
63
64        Ok(response.body.into_inner())
65    }
66}
67
68impl Clone for AwsBedrock {
69    fn clone(&self) -> Self {
70        Self {
71            model_id: self.model_id.clone(),
72            client: self.client.clone(),
73            model_config: self.model_config.clone(),
74            model_family: self.model_family.clone(),
75        }
76    }
77}
78
79impl AwsBedrock {
80    pub fn builder() -> AwsBedrockBuilder {
81        AwsBedrockBuilder::default()
82    }
83
84    /// Build a new `AwsBedrock` instance with the Titan model family
85    pub fn build_titan_family(model_id: impl Into<String>) -> AwsBedrockBuilder {
86        Self::builder().titan().model_id(model_id).to_owned()
87    }
88
89    /// Build a new `AwsBedrock` instance with the Anthropic model family
90    pub fn build_anthropic_family(model_id: impl Into<String>) -> AwsBedrockBuilder {
91        Self::builder().anthropic().model_id(model_id).to_owned()
92    }
93}
94impl AwsBedrockBuilder {
95    /// Set the model family to Anthropic
96    pub fn anthropic(&mut self) -> &mut Self {
97        self.model_family = Some(ModelFamily::Anthropic);
98        self
99    }
100
101    /// Set the model family to Titan
102    pub fn titan(&mut self) -> &mut Self {
103        self.model_family = Some(ModelFamily::Titan);
104        self
105    }
106
107    #[allow(clippy::unused_self)]
108    fn default_config(&self) -> aws_config::SdkConfig {
109        tokio::task::block_in_place(|| {
110            Handle::current().block_on(async { aws_config::from_env().load().await })
111        })
112    }
113    fn default_client(&self) -> Arc<Client> {
114        Arc::new(Client::new(&self.default_config()))
115    }
116
117    /// Set the aws bedrock runtime client
118    pub fn client(&mut self, client: Client) -> &mut Self {
119        self.client = Some(Arc::new(client));
120        self
121    }
122
123    #[cfg(test)]
124    #[allow(private_bounds)]
125    pub fn test_client(&mut self, client: impl BedrockPrompt + 'static) -> &mut Self {
126        self.client = Some(Arc::new(client));
127        self
128    }
129}
130
131use self::models::ModelFamily;
132
133#[derive(Serialize, Debug, Clone)]
134#[serde(rename_all = "camelCase")]
135pub struct ModelConfig {
136    temperature: f32,
137    top_p: f32,
138    max_token_count: i32,
139    stop_sequences: Vec<String>,
140}
141
142impl Default for ModelConfig {
143    fn default() -> Self {
144        Self {
145            temperature: 0.5,
146            top_p: 0.9,
147            max_token_count: 8192,
148            stop_sequences: vec![],
149        }
150    }
151}