1use crate::embedding::Embedding;
4use crate::error::{EmbeddingError, EmbeddingResult};
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone)]
10pub struct EmbeddingOutput {
11 pub embeddings: Vec<Embedding>,
13 pub total_tokens: Option<u64>,
15 pub model: String,
17}
18
19impl EmbeddingOutput {
20 pub fn new(embeddings: Vec<Embedding>, model: impl Into<String>) -> Self {
22 Self {
23 embeddings,
24 total_tokens: None,
25 model: model.into(),
26 }
27 }
28
29 pub fn with_tokens(mut self, tokens: u64) -> Self {
31 self.total_tokens = Some(tokens);
32 self
33 }
34
35 pub fn embedding(&self) -> Option<&Embedding> {
37 self.embeddings.first()
38 }
39
40 pub fn dimensions(&self) -> Option<usize> {
42 self.embeddings.first().map(|e| e.dimensions())
43 }
44
45 pub fn is_empty(&self) -> bool {
47 self.embeddings.is_empty()
48 }
49
50 pub fn len(&self) -> usize {
52 self.embeddings.len()
53 }
54}
55
/// Text to embed: either one query string or a batch of document strings.
#[derive(Debug, Clone)]
pub enum EmbedInput {
    /// A single query text.
    Query(String),
    /// A batch of document texts.
    Documents(Vec<String>),
}

impl EmbedInput {
    /// Number of texts in this input.
    ///
    /// A `Query` always counts as one text, even when its string is empty.
    pub fn len(&self) -> usize {
        if let Self::Documents(batch) = self {
            batch.len()
        } else {
            1
        }
    }

    /// Whether this input carries nothing to embed.
    ///
    /// For `Query` this tests the string itself; for `Documents` it tests the
    /// list. Consequently an empty `Query` reports `is_empty() == true` while
    /// `len()` still returns 1.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Documents(batch) => batch.is_empty(),
            Self::Query(text) => text.is_empty(),
        }
    }

    /// Consumes the input and returns its texts as owned strings; a `Query`
    /// becomes a one-element vector.
    pub fn into_texts(self) -> Vec<String> {
        match self {
            Self::Documents(batch) => batch,
            Self::Query(text) => vec![text],
        }
    }

    /// Borrows the texts as string slices; a `Query` yields a one-element vector.
    pub fn texts(&self) -> Vec<&str> {
        // View both variants as a slice of `String` so one mapping handles both.
        let owned: &[String] = match self {
            Self::Query(text) => std::slice::from_ref(text),
            Self::Documents(batch) => batch.as_slice(),
        };
        owned.iter().map(String::as_str).collect()
    }
}
98
99impl From<&str> for EmbedInput {
100 fn from(s: &str) -> Self {
101 Self::Query(s.to_string())
102 }
103}
104
105impl From<String> for EmbedInput {
106 fn from(s: String) -> Self {
107 Self::Query(s)
108 }
109}
110
111impl From<Vec<String>> for EmbedInput {
112 fn from(docs: Vec<String>) -> Self {
113 Self::Documents(docs)
114 }
115}
116
117impl From<Vec<&str>> for EmbedInput {
118 fn from(docs: Vec<&str>) -> Self {
119 Self::Documents(docs.into_iter().map(String::from).collect())
120 }
121}
122
/// Encoding format option carried in `EmbeddingSettings::encoding_format`.
///
/// With `rename_all = "lowercase"`, serde (de)serializes the variants as
/// `"float"` and `"base64"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
 /// The `Default` variant.
 /// NOTE(review): presumably embeddings are returned as plain floats - confirm.
 #[default]
 Float,
 /// NOTE(review): presumably embeddings are returned base64-encoded - confirm
 /// against the provider documentation.
 Base64,
}
133
134#[derive(Debug, Clone, Default)]
136pub struct EmbeddingSettings {
137 pub dimensions: Option<usize>,
139 pub encoding_format: Option<EncodingFormat>,
141 pub user: Option<String>,
143 pub input_type: Option<InputType>,
145 pub truncation: Option<TruncationMode>,
147}
148
149impl EmbeddingSettings {
150 pub fn new() -> Self {
152 Self::default()
153 }
154
155 pub fn dimensions(mut self, dims: usize) -> Self {
157 self.dimensions = Some(dims);
158 self
159 }
160
161 pub fn encoding_format(mut self, format: EncodingFormat) -> Self {
163 self.encoding_format = Some(format);
164 self
165 }
166
167 pub fn user(mut self, user: impl Into<String>) -> Self {
169 self.user = Some(user.into());
170 self
171 }
172
173 pub fn input_type(mut self, input_type: InputType) -> Self {
175 self.input_type = Some(input_type);
176 self
177 }
178
179 pub fn truncation(mut self, mode: TruncationMode) -> Self {
181 self.truncation = Some(mode);
182 self
183 }
184}
185
/// Kind of use the embedded text is intended for.
///
/// With `rename_all = "snake_case"`, serde (de)serializes the variants as
/// `"search_query"`, `"search_document"`, `"classification"` and
/// `"clustering"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InputType {
 /// Set by the default `EmbeddingModel::embed_query` implementation.
 SearchQuery,
 /// Set by the default `EmbeddingModel::embed_documents` implementation.
 SearchDocument,
 /// NOTE(review): presumably text embedded as classifier input - confirm.
 Classification,
 /// NOTE(review): presumably text embedded for clustering - confirm.
 Clustering,
}
199
/// Truncation option carried in `EmbeddingSettings::truncation`.
///
/// NOTE(review): the exact truncation behaviour is decided by the model
/// implementation; the variant notes below are assumptions to confirm.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TruncationMode {
 /// The `Default` variant.
 /// NOTE(review): presumably over-long input is cut at its end - confirm.
 #[default]
 End,
 /// NOTE(review): presumably over-long input is cut at its start - confirm.
 Start,
 /// NOTE(review): presumably no truncation is applied - confirm. Not to be
 /// confused with `Option::None`.
 None,
}
211
212#[async_trait]
214pub trait EmbeddingModel: Send + Sync {
215 fn name(&self) -> &str;
217
218 fn dimensions(&self) -> usize;
220
221 fn max_tokens(&self) -> usize {
223 8192 }
225
226 async fn embed(
228 &self,
229 input: EmbedInput,
230 settings: &EmbeddingSettings,
231 ) -> EmbeddingResult<EmbeddingOutput>;
232
233 async fn embed_query(&self, query: &str) -> EmbeddingResult<EmbeddingOutput> {
235 self.embed(
236 EmbedInput::Query(query.to_string()),
237 &EmbeddingSettings::default().input_type(InputType::SearchQuery),
238 )
239 .await
240 }
241
242 async fn embed_documents(&self, docs: Vec<String>) -> EmbeddingResult<EmbeddingOutput> {
244 self.embed(
245 EmbedInput::Documents(docs),
246 &EmbeddingSettings::default().input_type(InputType::SearchDocument),
247 )
248 .await
249 }
250
251 async fn count_tokens(&self, _text: &str) -> EmbeddingResult<u64> {
253 Err(EmbeddingError::NotSupported("Token counting".into()))
254 }
255}
256
/// Owned, dynamically dispatched [`EmbeddingModel`] trait object.
pub type BoxedEmbeddingModel = Box<dyn EmbeddingModel>;
259
#[cfg(test)]
mod tests {
    use super::*;

    /// A `&str` converts into a single-text `Query` that keeps the original text.
    #[test]
    fn test_embed_input_from_str() {
        let input: EmbedInput = "hello".into();
        assert!(matches!(input, EmbedInput::Query(_)));
        assert_eq!(input.len(), 1);
        assert!(!input.is_empty());
        assert_eq!(input.texts(), vec!["hello"]);
    }

    /// A `Vec<&str>` converts into `Documents`, preserving order and count.
    #[test]
    fn test_embed_input_from_vec() {
        let input: EmbedInput = vec!["a", "b", "c"].into();
        assert!(matches!(input, EmbedInput::Documents(_)));
        assert_eq!(input.len(), 3);
        assert!(!input.is_empty());
        assert_eq!(input.texts(), vec!["a", "b", "c"]);
    }

    /// Accessors on `EmbeddingOutput` reflect the constructor and builder inputs.
    #[test]
    fn test_embedding_output() {
        let embeddings = vec![
            Embedding::new(vec![1.0, 2.0, 3.0]),
            Embedding::new(vec![4.0, 5.0, 6.0]),
        ];
        let output = EmbeddingOutput::new(embeddings, "test-model").with_tokens(100);

        assert_eq!(output.len(), 2);
        assert!(!output.is_empty());
        assert!(output.embedding().is_some());
        assert_eq!(output.dimensions(), Some(3));
        assert_eq!(output.total_tokens, Some(100));
        assert_eq!(output.model, "test-model");
    }

    /// Builder methods set their own field and leave the others unset.
    #[test]
    fn test_embedding_settings() {
        let settings = EmbeddingSettings::new()
            .dimensions(1536)
            .input_type(InputType::SearchQuery)
            .user("user-123");

        assert_eq!(settings.dimensions, Some(1536));
        assert_eq!(settings.input_type, Some(InputType::SearchQuery));
        // Previously `user` was set but never verified.
        assert_eq!(settings.user.as_deref(), Some("user-123"));
        // Options that were not configured stay unset.
        assert!(settings.encoding_format.is_none());
        assert!(settings.truncation.is_none());
    }
}