swiftide_integrations/aws_bedrock/
mod.rs1use std::sync::Arc;
5
6use anyhow::Result;
7use async_trait::async_trait;
8use aws_sdk_bedrockruntime::{Client, error::SdkError, primitives::Blob};
9use derive_builder::Builder;
10use serde::Serialize;
11use tokio::runtime::Handle;
12
13#[cfg(test)]
14use mockall::{automock, predicate::*};
15
16mod models;
17mod simple_prompt;
18
19#[derive(Debug, Builder)]
31#[builder(setter(strip_option))]
32pub struct AwsBedrock {
33 #[builder(setter(into))]
34 model_id: String,
36
37 #[builder(default = self.default_client(), setter(custom))]
38 client: Arc<dyn BedrockPrompt>,
40 #[builder(default)]
41 model_config: ModelConfig,
43 model_family: ModelFamily,
45}
46
47#[cfg_attr(test, automock)]
48#[async_trait]
49trait BedrockPrompt: std::fmt::Debug + Send + Sync {
50 async fn prompt_u8(&self, model_id: &str, blob: Blob) -> Result<Vec<u8>>;
51}
52
53#[async_trait]
54impl BedrockPrompt for Client {
55 async fn prompt_u8(&self, model_id: &str, blob: Blob) -> Result<Vec<u8>> {
56 let response = self
57 .invoke_model()
58 .body(blob)
59 .model_id(model_id)
60 .send()
61 .await
62 .map_err(SdkError::into_service_error)?;
63
64 Ok(response.body.into_inner())
65 }
66}
67
68impl Clone for AwsBedrock {
69 fn clone(&self) -> Self {
70 Self {
71 model_id: self.model_id.clone(),
72 client: self.client.clone(),
73 model_config: self.model_config.clone(),
74 model_family: self.model_family.clone(),
75 }
76 }
77}
78
79impl AwsBedrock {
80 pub fn builder() -> AwsBedrockBuilder {
81 AwsBedrockBuilder::default()
82 }
83
84 pub fn build_titan_family(model_id: impl Into<String>) -> AwsBedrockBuilder {
86 Self::builder().titan().model_id(model_id).to_owned()
87 }
88
89 pub fn build_anthropic_family(model_id: impl Into<String>) -> AwsBedrockBuilder {
91 Self::builder().anthropic().model_id(model_id).to_owned()
92 }
93}
94impl AwsBedrockBuilder {
95 pub fn anthropic(&mut self) -> &mut Self {
97 self.model_family = Some(ModelFamily::Anthropic);
98 self
99 }
100
101 pub fn titan(&mut self) -> &mut Self {
103 self.model_family = Some(ModelFamily::Titan);
104 self
105 }
106
107 #[allow(clippy::unused_self)]
108 fn default_config(&self) -> aws_config::SdkConfig {
109 tokio::task::block_in_place(|| {
110 Handle::current().block_on(async { aws_config::from_env().load().await })
111 })
112 }
113 fn default_client(&self) -> Arc<Client> {
114 Arc::new(Client::new(&self.default_config()))
115 }
116
117 pub fn client(&mut self, client: Client) -> &mut Self {
119 self.client = Some(Arc::new(client));
120 self
121 }
122
123 #[cfg(test)]
124 #[allow(private_bounds)]
125 pub fn test_client(&mut self, client: impl BedrockPrompt + 'static) -> &mut Self {
126 self.client = Some(Arc::new(client));
127 self
128 }
129}
130
131use self::models::ModelFamily;
132
133#[derive(Serialize, Debug, Clone)]
134#[serde(rename_all = "camelCase")]
135pub struct ModelConfig {
136 temperature: f32,
137 top_p: f32,
138 max_token_count: i32,
139 stop_sequences: Vec<String>,
140}
141
142impl Default for ModelConfig {
143 fn default() -> Self {
144 Self {
145 temperature: 0.5,
146 top_p: 0.9,
147 max_token_count: 8192,
148 stop_sequences: vec![],
149 }
150 }
151}