syncable_cli/bedrock/client.rs

1use super::image::ImageGenerationModel;
2use super::{completion::CompletionModel, embedding::EmbeddingModel};
3use aws_config::{BehaviorVersion, Region};
4use rig::client::Nothing;
5use rig::prelude::*;
6use std::sync::Arc;
7use tokio::sync::OnceCell;
8
/// AWS region that [`ClientBuilder`] targets unless [`ClientBuilder::region`] overrides it.
pub const DEFAULT_AWS_REGION: &'static str = "us-east-1";
10
/// Builder for a [`Client`] that is pinned to an explicit AWS region.
///
/// The region starts out as [`DEFAULT_AWS_REGION`] and can be changed with
/// [`ClientBuilder::region`] before calling [`ClientBuilder::build`].
// `Debug` is derived for consistency with `Client` and so builders can be logged.
#[derive(Clone, Debug)]
pub struct ClientBuilder<'a> {
    /// Region name handed to the AWS SDK configuration in [`ClientBuilder::build`].
    region: &'a str,
}
15
16impl<'a> ClientBuilder<'a> {
17    #[deprecated(
18        since = "0.2.6",
19        note = "Use `Client::from_env` or `Client::with_profile_name(\"aws_profile\")` instead"
20    )]
21    pub fn new() -> Self {
22        Self {
23            region: DEFAULT_AWS_REGION,
24        }
25    }
26
27    /// Make sure to verify model and region [compatibility]
28    ///
29    /// [compatibility]: https://docs.aws.amazon.com/bedrock/latest/userguide/models-regions.html
30    pub fn region(mut self, region: &'a str) -> Self {
31        self.region = region;
32        self
33    }
34
35    /// Make sure you have permissions to access [Amazon Bedrock foundation model]
36    ///
37    /// [ Amazon Bedrock foundation model]: <https://docs.aws.amazon.com/bedrock/latest/userguide/model-access-modify.html>
38    pub async fn build(self) -> Client {
39        let sdk_config = aws_config::defaults(BehaviorVersion::latest())
40            .region(Region::new(String::from(self.region)))
41            .load()
42            .await;
43        let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
44        Client {
45            profile_name: None,
46            aws_client: Arc::new(OnceCell::from(client)),
47        }
48    }
49}
50
51impl Default for ClientBuilder<'_> {
52    fn default() -> Self {
53        #[allow(deprecated)]
54        Self::new()
55    }
56}
57
58#[derive(Clone, Debug)]
59pub struct Client {
60    profile_name: Option<String>,
61    pub(crate) aws_client: Arc<OnceCell<aws_sdk_bedrockruntime::Client>>,
62}
63
64impl From<aws_sdk_bedrockruntime::Client> for Client {
65    fn from(aws_client: aws_sdk_bedrockruntime::Client) -> Self {
66        Client {
67            profile_name: None,
68            aws_client: Arc::new(OnceCell::from(aws_client)),
69        }
70    }
71}
72
73impl Client {
74    fn new() -> Self {
75        Self {
76            profile_name: None,
77            aws_client: Arc::new(OnceCell::new()),
78        }
79    }
80
81    /// Create an AWS Bedrock client using AWS profile name
82    pub fn with_profile_name(profile_name: &str) -> Self {
83        Self {
84            profile_name: Some(profile_name.into()),
85            aws_client: Arc::new(OnceCell::new()),
86        }
87    }
88
89    pub async fn get_inner(&self) -> &aws_sdk_bedrockruntime::Client {
90        self.aws_client
91            .get_or_init(|| async {
92                let config = if let Some(profile_name) = &self.profile_name {
93                    aws_config::defaults(BehaviorVersion::latest())
94                        .profile_name(profile_name)
95                        .load()
96                        .await
97                } else {
98                    aws_config::load_from_env().await
99                };
100                aws_sdk_bedrockruntime::Client::new(&config)
101            })
102            .await
103    }
104}
105
106impl ProviderClient for Client {
107    type Input = Nothing;
108
109    fn from_env() -> Self
110    where
111        Self: Sized,
112    {
113        Client::new()
114    }
115
116    fn from_val(_: Nothing) -> Self
117    where
118        Self: Sized,
119    {
120        panic!(
121            "Please use `Client::from_env` or `Client::with_profile_name(\"aws_profile\")` instead"
122        );
123    }
124}
125
126impl CompletionClient for Client {
127    type CompletionModel = CompletionModel;
128
129    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
130        CompletionModel::new(self.clone(), model)
131    }
132}
133
134impl EmbeddingsClient for Client {
135    type EmbeddingModel = EmbeddingModel;
136
137    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
138        EmbeddingModel::new(self.clone(), model, None)
139    }
140
141    fn embedding_model_with_ndims(
142        &self,
143        model: impl Into<String>,
144        ndims: usize,
145    ) -> Self::EmbeddingModel {
146        EmbeddingModel::new(self.clone(), model, Some(ndims))
147    }
148}
149
150impl ImageGenerationClient for Client {
151    type ImageGenerationModel = ImageGenerationModel;
152
153    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
154        ImageGenerationModel::new(self.clone(), model)
155    }
156}
157
158impl VerifyClient for Client {
159    async fn verify(&self) -> Result<(), VerifyError> {
160        // No API endpoint to verify the API key
161        Ok(())
162    }
163}