1use std::convert::TryFrom;
4
5use crate::config::{CredentialsConfig, EmbeddingAdapter, EmbeddingConfig, ResolvedAuth};
6use crate::inference::error::InferenceError;
7use crate::inference::http::{AuthenticatedHttp, DEFAULT_INFERENCE_TIMEOUT};
8use crate::inference::types::{
9 EmbedChunkedDataItem, EmbedChunkedRequest, EmbedChunkedResponse, EmbedRequest,
10 OpenAiEmbeddingRequest, OpenAiEmbeddingResponse,
11};
12
/// HTTP client that requests embeddings from a remote inference endpoint.
///
/// The wire protocol is chosen by `adapter`; every request is sent through the
/// shared `http` handle to a URL derived from `base_url`.
#[derive(Debug, Clone)]
pub struct EmbeddingClient {
    /// Selects which request/response protocol the client speaks.
    adapter: EmbeddingAdapter,
    /// Root URL of the service; trailing `/` characters are trimmed before an
    /// endpoint path is appended.
    base_url: String,
    /// Model name sent with non-chunked OpenAI-adapter requests and returned by
    /// `chunk_model`.
    model: String,
    /// Model name sent with chunked OpenAI-adapter requests and returned by
    /// `document_model`.
    document_model: String,
    /// HTTP handle used for every POST issued by this client.
    http: AuthenticatedHttp,
}
22
23impl EmbeddingClient {
24 pub fn from_config(
31 config: &EmbeddingConfig,
32 credentials: &CredentialsConfig,
33 ) -> Result<Self, InferenceError> {
34 let auth = config
35 .auth
36 .resolve(credentials)
37 .map_err(|err| InferenceError::Config {
38 message: err.to_string(),
39 })?;
40 let http = AuthenticatedHttp::with_timeout(DEFAULT_INFERENCE_TIMEOUT, auth, 3)?;
41 Ok(Self {
42 adapter: config.adapter,
43 base_url: config.base_url.clone(),
44 model: config.model.clone(),
45 document_model: config.document_model().to_owned(),
46 http,
47 })
48 }
49
50 pub fn tei_for_tests(
56 base_url: impl Into<String>,
57 model: impl Into<String>,
58 ) -> Result<Self, InferenceError> {
59 let model = model.into();
60 let http =
61 AuthenticatedHttp::with_timeout(DEFAULT_INFERENCE_TIMEOUT, ResolvedAuth::default(), 3)?;
62 Ok(Self {
63 adapter: EmbeddingAdapter::Tei,
64 base_url: base_url.into(),
65 document_model: model.clone(),
66 model,
67 http,
68 })
69 }
70
71 #[must_use]
73 pub fn chunk_model(&self) -> &str {
74 &self.model
75 }
76
77 #[must_use]
79 pub fn document_model(&self) -> &str {
80 &self.document_model
81 }
82
83 pub fn embed(&self, inputs: &[String]) -> Result<Vec<Vec<f32>>, InferenceError> {
89 match self.adapter {
90 EmbeddingAdapter::Tei => self.embed_tei(inputs),
91 EmbeddingAdapter::OpenAi => self.embed_openai(inputs, &self.model),
92 }
93 }
94
95 pub fn embed_chunked(
101 &self,
102 input: &[Vec<String>],
103 ) -> Result<EmbedChunkedResponse, InferenceError> {
104 match self.adapter {
105 EmbeddingAdapter::Tei => self.embed_chunked_tei(input),
106 EmbeddingAdapter::OpenAi => self.embed_chunked_openai(input),
107 }
108 }
109
110 fn embed_tei(&self, inputs: &[String]) -> Result<Vec<Vec<f32>>, InferenceError> {
111 let url = format!("{}/embed", self.base_url.trim_end_matches('/'));
112 let body = EmbedRequest {
113 inputs: inputs.to_vec(),
114 };
115 self.http.post_json(&url, &body)
116 }
117
118 fn embed_openai(
119 &self,
120 inputs: &[String],
121 model: &str,
122 ) -> Result<Vec<Vec<f32>>, InferenceError> {
123 let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
124 let body = OpenAiEmbeddingRequest {
125 model: model.to_owned(),
126 input: inputs.to_vec(),
127 };
128 let response: OpenAiEmbeddingResponse = self.http.post_json(&url, &body)?;
129 let mut rows = response.data;
130 rows.sort_by_key(|row| row.index);
131 Ok(rows.into_iter().map(|row| row.embedding).collect())
132 }
133
134 fn embed_chunked_tei(
135 &self,
136 input: &[Vec<String>],
137 ) -> Result<EmbedChunkedResponse, InferenceError> {
138 let url = format!("{}/embed-chunked", self.base_url.trim_end_matches('/'));
139 let body = EmbedChunkedRequest {
140 input: input.to_vec(),
141 };
142 if input.len() <= 1 {
143 return self.http.post_json(&url, &body);
144 }
145 self.http
146 .post_json_with_retry(&url, &body)
147 .map_or_else(|_| self.embed_chunked_tei_fallback(&url, input), Ok)
148 }
149
150 fn embed_chunked_tei_fallback(
151 &self,
152 url: &str,
153 input: &[Vec<String>],
154 ) -> Result<EmbedChunkedResponse, InferenceError> {
155 let mut data = Vec::with_capacity(input.len());
156 let mut model: Option<String> = None;
157
158 for (index, group) in input.iter().enumerate() {
159 let body = EmbedChunkedRequest {
160 input: vec![group.clone()],
161 };
162 let mut response: EmbedChunkedResponse = self.http.post_json(url, &body)?;
163 let group_index = u32::try_from(index).map_err(|_| InferenceError::Decode {
164 message: "embed-chunked index overflow".to_owned(),
165 })?;
166 let Some(mut item) = response.data.pop() else {
167 return Err(InferenceError::Decode {
168 message: "embed-chunked fallback returned no data".to_owned(),
169 });
170 };
171 if !response.data.is_empty() {
172 return Err(InferenceError::Decode {
173 message: "embed-chunked fallback returned unexpected response shape".to_owned(),
174 });
175 }
176 item.index = group_index;
177 data.push(item);
178 if model.is_none() {
179 model = Some(response.model);
180 }
181 }
182
183 Ok(EmbedChunkedResponse {
184 data,
185 model: model.unwrap_or_default(),
186 })
187 }
188
189 fn embed_chunked_openai(
190 &self,
191 input: &[Vec<String>],
192 ) -> Result<EmbedChunkedResponse, InferenceError> {
193 let mut data = Vec::with_capacity(input.len());
194 for (index, group) in input.iter().enumerate() {
195 let embeddings = self.embed_openai(group, self.document_model())?;
196 data.push(EmbedChunkedDataItem {
197 embeddings,
198 index: u32::try_from(index).map_err(|_| InferenceError::Decode {
199 message: "embed-chunked index overflow".to_owned(),
200 })?,
201 });
202 }
203 Ok(EmbedChunkedResponse {
204 data,
205 model: self.document_model().to_owned(),
206 })
207 }
208}