Skip to main content

talon_core/inference/
embedding.rs

1//! Blocking HTTP client for embedding endpoints.
2
3use std::convert::TryFrom;
4
5use crate::config::{CredentialsConfig, EmbeddingAdapter, EmbeddingConfig, ResolvedAuth};
6use crate::inference::error::InferenceError;
7use crate::inference::http::{AuthenticatedHttp, DEFAULT_INFERENCE_TIMEOUT};
8use crate::inference::types::{
9    EmbedChunkedDataItem, EmbedChunkedRequest, EmbedChunkedResponse, EmbedRequest,
10    OpenAiEmbeddingRequest, OpenAiEmbeddingResponse,
11};
12
13/// Blocking client for configured embedding endpoints.
14#[derive(Debug, Clone)]
15pub struct EmbeddingClient {
16    adapter: EmbeddingAdapter,
17    base_url: String,
18    model: String,
19    document_model: String,
20    http: AuthenticatedHttp,
21}
22
23impl EmbeddingClient {
24    /// Builds a client from config and resolved credentials.
25    ///
26    /// # Errors
27    ///
28    /// Returns [`InferenceError::Build`] when the HTTP client cannot be built.
29    /// Returns [`InferenceError::Config`] when auth resolution fails.
30    pub fn from_config(
31        config: &EmbeddingConfig,
32        credentials: &CredentialsConfig,
33    ) -> Result<Self, InferenceError> {
34        let auth = config
35            .auth
36            .resolve(credentials)
37            .map_err(|err| InferenceError::Config {
38                message: err.to_string(),
39            })?;
40        let http = AuthenticatedHttp::with_timeout(DEFAULT_INFERENCE_TIMEOUT, auth, 3)?;
41        Ok(Self {
42            adapter: config.adapter,
43            base_url: config.base_url.clone(),
44            model: config.model.clone(),
45            document_model: config.document_model().to_owned(),
46            http,
47        })
48    }
49
50    /// Builds a TEI client for tests and wiremock fixtures.
51    ///
52    /// # Errors
53    ///
54    /// Returns [`InferenceError::Build`] when the HTTP client cannot be built.
55    pub fn tei_for_tests(
56        base_url: impl Into<String>,
57        model: impl Into<String>,
58    ) -> Result<Self, InferenceError> {
59        let model = model.into();
60        let http =
61            AuthenticatedHttp::with_timeout(DEFAULT_INFERENCE_TIMEOUT, ResolvedAuth::default(), 3)?;
62        Ok(Self {
63            adapter: EmbeddingAdapter::Tei,
64            base_url: base_url.into(),
65            document_model: model.clone(),
66            model,
67            http,
68        })
69    }
70
71    /// Model slug written to vector metadata for single-chunk notes.
72    #[must_use]
73    pub fn chunk_model(&self) -> &str {
74        &self.model
75    }
76
77    /// Model slug written to vector metadata for multi-chunk notes.
78    #[must_use]
79    pub fn document_model(&self) -> &str {
80        &self.document_model
81    }
82
83    /// Embeds a batch of texts and returns one vector per input.
84    ///
85    /// # Errors
86    ///
87    /// Returns [`InferenceError::Http`] or [`InferenceError::Decode`].
88    pub fn embed(&self, inputs: &[String]) -> Result<Vec<Vec<f32>>, InferenceError> {
89        match self.adapter {
90            EmbeddingAdapter::Tei => self.embed_tei(inputs),
91            EmbeddingAdapter::OpenAi => self.embed_openai(inputs, &self.model),
92        }
93    }
94
95    /// Embeds grouped chunks (one group per note).
96    ///
97    /// # Errors
98    ///
99    /// Returns [`InferenceError::Http`] or [`InferenceError::Decode`].
100    pub fn embed_chunked(
101        &self,
102        input: &[Vec<String>],
103    ) -> Result<EmbedChunkedResponse, InferenceError> {
104        match self.adapter {
105            EmbeddingAdapter::Tei => self.embed_chunked_tei(input),
106            EmbeddingAdapter::OpenAi => self.embed_chunked_openai(input),
107        }
108    }
109
110    fn embed_tei(&self, inputs: &[String]) -> Result<Vec<Vec<f32>>, InferenceError> {
111        let url = format!("{}/embed", self.base_url.trim_end_matches('/'));
112        let body = EmbedRequest {
113            inputs: inputs.to_vec(),
114        };
115        self.http.post_json(&url, &body)
116    }
117
118    fn embed_openai(
119        &self,
120        inputs: &[String],
121        model: &str,
122    ) -> Result<Vec<Vec<f32>>, InferenceError> {
123        let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
124        let body = OpenAiEmbeddingRequest {
125            model: model.to_owned(),
126            input: inputs.to_vec(),
127        };
128        let response: OpenAiEmbeddingResponse = self.http.post_json(&url, &body)?;
129        let mut rows = response.data;
130        rows.sort_by_key(|row| row.index);
131        Ok(rows.into_iter().map(|row| row.embedding).collect())
132    }
133
134    fn embed_chunked_tei(
135        &self,
136        input: &[Vec<String>],
137    ) -> Result<EmbedChunkedResponse, InferenceError> {
138        let url = format!("{}/embed-chunked", self.base_url.trim_end_matches('/'));
139        let body = EmbedChunkedRequest {
140            input: input.to_vec(),
141        };
142        if input.len() <= 1 {
143            return self.http.post_json(&url, &body);
144        }
145        self.http
146            .post_json_with_retry(&url, &body)
147            .map_or_else(|_| self.embed_chunked_tei_fallback(&url, input), Ok)
148    }
149
150    fn embed_chunked_tei_fallback(
151        &self,
152        url: &str,
153        input: &[Vec<String>],
154    ) -> Result<EmbedChunkedResponse, InferenceError> {
155        let mut data = Vec::with_capacity(input.len());
156        let mut model: Option<String> = None;
157
158        for (index, group) in input.iter().enumerate() {
159            let body = EmbedChunkedRequest {
160                input: vec![group.clone()],
161            };
162            let mut response: EmbedChunkedResponse = self.http.post_json(url, &body)?;
163            let group_index = u32::try_from(index).map_err(|_| InferenceError::Decode {
164                message: "embed-chunked index overflow".to_owned(),
165            })?;
166            let Some(mut item) = response.data.pop() else {
167                return Err(InferenceError::Decode {
168                    message: "embed-chunked fallback returned no data".to_owned(),
169                });
170            };
171            if !response.data.is_empty() {
172                return Err(InferenceError::Decode {
173                    message: "embed-chunked fallback returned unexpected response shape".to_owned(),
174                });
175            }
176            item.index = group_index;
177            data.push(item);
178            if model.is_none() {
179                model = Some(response.model);
180            }
181        }
182
183        Ok(EmbedChunkedResponse {
184            data,
185            model: model.unwrap_or_default(),
186        })
187    }
188
189    fn embed_chunked_openai(
190        &self,
191        input: &[Vec<String>],
192    ) -> Result<EmbedChunkedResponse, InferenceError> {
193        let mut data = Vec::with_capacity(input.len());
194        for (index, group) in input.iter().enumerate() {
195            let embeddings = self.embed_openai(group, self.document_model())?;
196            data.push(EmbedChunkedDataItem {
197                embeddings,
198                index: u32::try_from(index).map_err(|_| InferenceError::Decode {
199                    message: "embed-chunked index overflow".to_owned(),
200                })?,
201            });
202        }
203        Ok(EmbedChunkedResponse {
204            data,
205            model: self.document_model().to_owned(),
206        })
207    }
208}