syncable_cli/bedrock/
embedding.rs

1use aws_smithy_types::Blob;
2use rig::embeddings::{self, Embedding, EmbeddingError};
3use serde::{Deserialize, Serialize};
4
5use super::{client::Client, types::errors::AwsSdkInvokeModelError};
6
7#[derive(Serialize)]
8#[serde(rename_all = "camelCase")]
9pub struct EmbeddingRequest {
10    pub input_text: String,
11    pub dimensions: usize,
12    pub normalize: bool,
13}
14
15#[derive(Deserialize, Debug)]
16#[serde(rename_all = "camelCase")]
17pub struct EmbeddingResponse {
18    pub embedding: Vec<f64>,
19    pub input_text_token_count: usize,
20}
21
/// Model identifier string `amazon.titan-embed-text-v1`.
pub const AMAZON_TITAN_EMBED_TEXT_V1: &str = "amazon.titan-embed-text-v1";
/// Model identifier string `amazon.titan-embed-text-v2:0`.
pub const AMAZON_TITAN_EMBED_TEXT_V2_0: &str = "amazon.titan-embed-text-v2:0";
/// Model identifier string `amazon.titan-embed-image-v1`.
pub const AMAZON_TITAN_EMBED_IMAGE_V1: &str = "amazon.titan-embed-image-v1";
/// Model identifier string `cohere.embed-english-v3`.
pub const COHERE_EMBED_ENGLISH_V3: &str = "cohere.embed-english-v3";
/// Model identifier string `cohere.embed-multilingual-v3`.
pub const COHERE_EMBED_MULTILINGUAL_V3: &str = "cohere.embed-multilingual-v3";
32
33#[derive(Clone)]
34pub struct EmbeddingModel {
35    client: Client,
36    model: String,
37    ndims: Option<usize>,
38}
39
40impl EmbeddingModel {
41    pub fn new(client: Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
42        Self {
43            client,
44            model: model.into(),
45            ndims,
46        }
47    }
48
49    pub async fn document_to_embeddings(
50        &self,
51        request: EmbeddingRequest,
52    ) -> Result<EmbeddingResponse, EmbeddingError> {
53        let input_document = serde_json::to_string(&request).map_err(EmbeddingError::JsonError)?;
54
55        let model_response = self
56            .client
57            .get_inner()
58            .await
59            .invoke_model()
60            .model_id(self.model.as_str())
61            .content_type("application/json")
62            .accept("application/json")
63            .body(Blob::new(input_document))
64            .send()
65            .await;
66
67        let response = model_response
68            .map_err(|sdk_error| AwsSdkInvokeModelError(sdk_error).into())
69            .map_err(|e: EmbeddingError| e)?;
70
71        let response_str = String::from_utf8(response.body.into_inner())
72            .map_err(|e| EmbeddingError::ResponseError(e.to_string()))?;
73
74        let result: EmbeddingResponse =
75            serde_json::from_str(&response_str).map_err(EmbeddingError::JsonError)?;
76
77        Ok(result)
78    }
79}
80
81impl embeddings::EmbeddingModel for EmbeddingModel {
82    const MAX_DOCUMENTS: usize = 1024;
83
84    type Client = Client;
85
86    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
87        Self::new(client.clone(), model, dims)
88    }
89
90    fn ndims(&self) -> usize {
91        self.ndims.unwrap_or_default()
92    }
93
94    async fn embed_texts(
95        &self,
96        documents: impl IntoIterator<Item = String> + Send,
97    ) -> Result<Vec<Embedding>, EmbeddingError> {
98        let documents: Vec<_> = documents.into_iter().collect();
99
100        let mut results = Vec::new();
101        let mut errors = Vec::new();
102
103        let mut iterator = documents.into_iter();
104        while let Some(embedding) = iterator.next().map(|doc| async move {
105            let request = EmbeddingRequest {
106                input_text: doc.to_owned(),
107                dimensions: self.ndims(),
108                normalize: true,
109            };
110            self.document_to_embeddings(request)
111                .await
112                .map(|embeddings| Embedding {
113                    document: doc.to_owned(),
114                    vec: embeddings.embedding,
115                })
116        }) {
117            match embedding.await {
118                Ok(embedding) => results.push(embedding),
119                Err(err) => errors.push(err),
120            }
121        }
122
123        match errors.as_slice() {
124            [] => Ok(results),
125            [err, ..] => Err(EmbeddingError::ResponseError(err.to_string())),
126        }
127    }
128}