// syncable_cli/bedrock/embedding.rs
use aws_smithy_types::Blob;
2use rig::embeddings::{self, Embedding, EmbeddingError};
3use serde::{Deserialize, Serialize};
4
5use super::{client::Client, types::errors::AwsSdkInvokeModelError};
6
7#[derive(Serialize)]
8#[serde(rename_all = "camelCase")]
9pub struct EmbeddingRequest {
10 pub input_text: String,
11 pub dimensions: usize,
12 pub normalize: bool,
13}
14
15#[derive(Deserialize, Debug)]
16#[serde(rename_all = "camelCase")]
17pub struct EmbeddingResponse {
18 pub embedding: Vec<f64>,
19 pub input_text_token_count: usize,
20}
21
/// Model identifier string `amazon.titan-embed-text-v1`.
pub const AMAZON_TITAN_EMBED_TEXT_V1: &str = "amazon.titan-embed-text-v1";

/// Model identifier string `amazon.titan-embed-text-v2:0`.
pub const AMAZON_TITAN_EMBED_TEXT_V2_0: &str = "amazon.titan-embed-text-v2:0";

/// Model identifier string `amazon.titan-embed-image-v1`.
pub const AMAZON_TITAN_EMBED_IMAGE_V1: &str = "amazon.titan-embed-image-v1";

/// Model identifier string `cohere.embed-english-v3`.
pub const COHERE_EMBED_ENGLISH_V3: &str = "cohere.embed-english-v3";

/// Model identifier string `cohere.embed-multilingual-v3`.
pub const COHERE_EMBED_MULTILINGUAL_V3: &str = "cohere.embed-multilingual-v3";
32
33#[derive(Clone)]
34pub struct EmbeddingModel {
35 client: Client,
36 model: String,
37 ndims: Option<usize>,
38}
39
40impl EmbeddingModel {
41 pub fn new(client: Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
42 Self {
43 client,
44 model: model.into(),
45 ndims,
46 }
47 }
48
49 pub async fn document_to_embeddings(
50 &self,
51 request: EmbeddingRequest,
52 ) -> Result<EmbeddingResponse, EmbeddingError> {
53 let input_document = serde_json::to_string(&request).map_err(EmbeddingError::JsonError)?;
54
55 let model_response = self
56 .client
57 .get_inner()
58 .await
59 .invoke_model()
60 .model_id(self.model.as_str())
61 .content_type("application/json")
62 .accept("application/json")
63 .body(Blob::new(input_document))
64 .send()
65 .await;
66
67 let response = model_response
68 .map_err(|sdk_error| AwsSdkInvokeModelError(sdk_error).into())
69 .map_err(|e: EmbeddingError| e)?;
70
71 let response_str = String::from_utf8(response.body.into_inner())
72 .map_err(|e| EmbeddingError::ResponseError(e.to_string()))?;
73
74 let result: EmbeddingResponse =
75 serde_json::from_str(&response_str).map_err(EmbeddingError::JsonError)?;
76
77 Ok(result)
78 }
79}
80
81impl embeddings::EmbeddingModel for EmbeddingModel {
82 const MAX_DOCUMENTS: usize = 1024;
83
84 type Client = Client;
85
86 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
87 Self::new(client.clone(), model, dims)
88 }
89
90 fn ndims(&self) -> usize {
91 self.ndims.unwrap_or_default()
92 }
93
94 async fn embed_texts(
95 &self,
96 documents: impl IntoIterator<Item = String> + Send,
97 ) -> Result<Vec<Embedding>, EmbeddingError> {
98 let documents: Vec<_> = documents.into_iter().collect();
99
100 let mut results = Vec::new();
101 let mut errors = Vec::new();
102
103 let mut iterator = documents.into_iter();
104 while let Some(embedding) = iterator.next().map(|doc| async move {
105 let request = EmbeddingRequest {
106 input_text: doc.to_owned(),
107 dimensions: self.ndims(),
108 normalize: true,
109 };
110 self.document_to_embeddings(request)
111 .await
112 .map(|embeddings| Embedding {
113 document: doc.to_owned(),
114 vec: embeddings.embedding,
115 })
116 }) {
117 match embedding.await {
118 Ok(embedding) => results.push(embedding),
119 Err(err) => errors.push(err),
120 }
121 }
122
123 match errors.as_slice() {
124 [] => Ok(results),
125 [err, ..] => Err(EmbeddingError::ResponseError(err.to_string())),
126 }
127 }
128}