syncable_cli/bedrock/
client.rs1use super::image::ImageGenerationModel;
2use super::{completion::CompletionModel, embedding::EmbeddingModel};
3use aws_config::{BehaviorVersion, Region};
4use rig::client::Nothing;
5use rig::prelude::*;
6use std::sync::Arc;
7use tokio::sync::OnceCell;
8
9pub const DEFAULT_AWS_REGION: &str = "us-east-1";
10
11#[derive(Clone)]
12pub struct ClientBuilder<'a> {
13 region: &'a str,
14}
15
16impl<'a> ClientBuilder<'a> {
17 #[deprecated(
18 since = "0.2.6",
19 note = "Use `Client::from_env` or `Client::with_profile_name(\"aws_profile\")` instead"
20 )]
21 pub fn new() -> Self {
22 Self {
23 region: DEFAULT_AWS_REGION,
24 }
25 }
26
27 pub fn region(mut self, region: &'a str) -> Self {
31 self.region = region;
32 self
33 }
34
35 pub async fn build(self) -> Client {
39 let sdk_config = aws_config::defaults(BehaviorVersion::latest())
40 .region(Region::new(String::from(self.region)))
41 .load()
42 .await;
43 let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
44 Client {
45 profile_name: None,
46 aws_client: Arc::new(OnceCell::from(client)),
47 }
48 }
49}
50
51impl Default for ClientBuilder<'_> {
52 fn default() -> Self {
53 #[allow(deprecated)]
54 Self::new()
55 }
56}
57
58#[derive(Clone, Debug)]
59pub struct Client {
60 profile_name: Option<String>,
61 pub(crate) aws_client: Arc<OnceCell<aws_sdk_bedrockruntime::Client>>,
62}
63
64impl From<aws_sdk_bedrockruntime::Client> for Client {
65 fn from(aws_client: aws_sdk_bedrockruntime::Client) -> Self {
66 Client {
67 profile_name: None,
68 aws_client: Arc::new(OnceCell::from(aws_client)),
69 }
70 }
71}
72
73impl Client {
74 fn new() -> Self {
75 Self {
76 profile_name: None,
77 aws_client: Arc::new(OnceCell::new()),
78 }
79 }
80
81 pub fn with_profile_name(profile_name: &str) -> Self {
83 Self {
84 profile_name: Some(profile_name.into()),
85 aws_client: Arc::new(OnceCell::new()),
86 }
87 }
88
89 pub async fn get_inner(&self) -> &aws_sdk_bedrockruntime::Client {
90 self.aws_client
91 .get_or_init(|| async {
92 let config = if let Some(profile_name) = &self.profile_name {
93 aws_config::defaults(BehaviorVersion::latest())
94 .profile_name(profile_name)
95 .load()
96 .await
97 } else {
98 aws_config::load_from_env().await
99 };
100 aws_sdk_bedrockruntime::Client::new(&config)
101 })
102 .await
103 }
104}
105
106impl ProviderClient for Client {
107 type Input = Nothing;
108
109 fn from_env() -> Self
110 where
111 Self: Sized,
112 {
113 Client::new()
114 }
115
116 fn from_val(_: Nothing) -> Self
117 where
118 Self: Sized,
119 {
120 panic!(
121 "Please use `Client::from_env` or `Client::with_profile_name(\"aws_profile\")` instead"
122 );
123 }
124}
125
126impl CompletionClient for Client {
127 type CompletionModel = CompletionModel;
128
129 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
130 CompletionModel::new(self.clone(), model)
131 }
132}
133
134impl EmbeddingsClient for Client {
135 type EmbeddingModel = EmbeddingModel;
136
137 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
138 EmbeddingModel::new(self.clone(), model, None)
139 }
140
141 fn embedding_model_with_ndims(
142 &self,
143 model: impl Into<String>,
144 ndims: usize,
145 ) -> Self::EmbeddingModel {
146 EmbeddingModel::new(self.clone(), model, Some(ndims))
147 }
148}
149
150impl ImageGenerationClient for Client {
151 type ImageGenerationModel = ImageGenerationModel;
152
153 fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
154 ImageGenerationModel::new(self.clone(), model)
155 }
156}
157
158impl VerifyClient for Client {
159 async fn verify(&self) -> Result<(), VerifyError> {
160 Ok(())
162 }
163}