sapphire_retrieve/
embed.rs1use crate::error::{Error, Result};
14
/// Settings that select and configure an embedding backend.
///
/// Consumed by [`build_embedder`], which dispatches on `provider`.
#[derive(Debug, Clone)]
pub struct EmbedderConfig {
    /// Backend identifier. [`build_embedder`] accepts `"openai"` and `"ollama"`, plus
    /// `"fastembed"` when the `fastembed-embed` feature is enabled.
    pub provider: String,
    /// Model name. The REST backends send it verbatim as the request's `model` field;
    /// the fastembed backend expects a variant name such as `"BGESmallENV15"`.
    pub model: String,
    /// Name of the environment variable that holds the API key (OpenAI backend only).
    /// When `None`, `OPENAI_API_KEY` is used.
    pub api_key_env: Option<String>,
    /// Override for the REST endpoint root. When `None`, the OpenAI backend uses
    /// `https://api.openai.com` and the Ollama backend uses `http://localhost:11434`.
    pub base_url: Option<String>,
    /// Directory used by the fastembed backend as its model cache.
    /// When `None`, `<system temp dir>/fastembed` is used.
    pub cache_dir: Option<std::path::PathBuf>,
}
40
41pub trait Embedder: Send + Sync {
49 fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
54}
55
56pub fn build_embedder(config: &EmbedderConfig) -> Result<Box<dyn Embedder + Send + Sync>> {
62 match config.provider.as_str() {
63 "openai" | "ollama" => Ok(Box::new(RestEmbedder {
64 config: config.clone(),
65 })),
66 #[cfg(feature = "fastembed-embed")]
67 "fastembed" => Ok(Box::new(FastEmbedEmbedder::new(config)?)),
68 other => Err(Error::Embed(format!(
69 "unknown embedding provider `{other}`; supported values: openai, ollama{}",
70 if cfg!(feature = "fastembed-embed") {
71 ", fastembed"
72 } else {
73 ""
74 }
75 ))),
76 }
77}
78
79struct RestEmbedder {
82 config: EmbedderConfig,
83}
84
85impl Embedder for RestEmbedder {
86 fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
87 if texts.is_empty() {
88 return Ok(Vec::new());
89 }
90 match self.config.provider.as_str() {
91 "openai" => embed_openai(&self.config, texts),
92 "ollama" => embed_ollama(&self.config, texts),
93 other => Err(Error::Embed(format!("unknown REST provider `{other}`"))),
94 }
95 }
96}
97
/// Local, in-process embedder backed by the `fastembed` crate.
#[cfg(feature = "fastembed-embed")]
struct FastEmbedEmbedder {
    /// The loaded model, behind a `Mutex` so `embed_texts(&self)` can obtain exclusive
    /// access. NOTE(review): presumably because the fastembed embedding call needs `&mut`
    /// in the pinned version; confirm.
    model: std::sync::Mutex<fastembed::TextEmbedding>,
}
104
#[cfg(feature = "fastembed-embed")]
impl FastEmbedEmbedder {
    /// Initialises the fastembed model named by `config.model`.
    ///
    /// The model cache is `config.cache_dir`, or `<system temp dir>/fastembed` when that
    /// is `None`. Fastembed is asked to display download progress.
    ///
    /// # Errors
    /// Returns [`Error::Embed`] if `config.model` is not a supported variant name, or if
    /// fastembed fails to construct the model.
    fn new(config: &EmbedderConfig) -> Result<Self> {
        use fastembed::{InitOptions, TextEmbedding};

        let name = config.model.as_str();
        let variant = Self::model_from_name(name).ok_or_else(|| {
            Error::Embed(format!(
                "unknown fastembed model `{name}`; \
                 supported: AllMiniLML6V2, BGESmallENV15, BGEBaseENV15, BGELargeENV15, \
                 NomicEmbedTextV1, NomicEmbedTextV15, \
                 MultilingualE5Small, MultilingualE5Base, MultilingualE5Large"
            ))
        })?;

        let cache_dir = match &config.cache_dir {
            Some(dir) => dir.clone(),
            None => std::env::temp_dir().join("fastembed"),
        };
        let options = InitOptions::new(variant)
            .with_cache_dir(cache_dir)
            .with_show_download_progress(true);
        let model = TextEmbedding::try_new(options)
            .map_err(|e| Error::Embed(format!("failed to load fastembed model: {e}")))?;

        Ok(Self {
            model: std::sync::Mutex::new(model),
        })
    }

    /// Maps a configuration string to its fastembed model variant.
    ///
    /// Returns `None` if the name is not supported.
    fn model_from_name(name: &str) -> Option<fastembed::EmbeddingModel> {
        use fastembed::EmbeddingModel as Model;

        let variant = match name {
            "AllMiniLML6V2" => Model::AllMiniLML6V2,
            "BGESmallENV15" => Model::BGESmallENV15,
            "BGEBaseENV15" => Model::BGEBaseENV15,
            "BGELargeENV15" => Model::BGELargeENV15,
            "NomicEmbedTextV1" => Model::NomicEmbedTextV1,
            "NomicEmbedTextV15" => Model::NomicEmbedTextV15,
            "MultilingualE5Small" => Model::MultilingualE5Small,
            "MultilingualE5Base" => Model::MultilingualE5Base,
            "MultilingualE5Large" => Model::MultilingualE5Large,
            _ => return None,
        };
        Some(variant)
    }
}
146
#[cfg(feature = "fastembed-embed")]
impl Embedder for FastEmbedEmbedder {
    /// Embeds `texts` with the local fastembed model.
    ///
    /// The batch size is passed as `None`, which leaves the choice to fastembed.
    ///
    /// # Errors
    /// Returns [`Error::Embed`] if the model lock is poisoned (an earlier call panicked
    /// while holding it) or if fastembed reports an embedding failure.
    fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
        // Surface a poisoned lock as an error rather than panicking again inside
        // library code; callers already handle `Error::Embed` from this method.
        self.model
            .lock()
            .map_err(|_| {
                Error::Embed("fastembed model lock poisoned by an earlier panic".into())
            })?
            .embed(texts_owned, None)
            .map_err(|e| Error::Embed(format!("fastembed embedding failed: {e}")))
    }
}
161
162fn embed_openai(config: &EmbedderConfig, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
165 let api_key_env = config.api_key_env.as_deref().unwrap_or("OPENAI_API_KEY");
166 let api_key = std::env::var(api_key_env)
167 .map_err(|_| Error::Embed(format!("environment variable `{api_key_env}` is not set")))?;
168
169 let base_url = config
170 .base_url
171 .as_deref()
172 .unwrap_or("https://api.openai.com");
173 let url = format!("{base_url}/v1/embeddings");
174
175 let body = serde_json::json!({
176 "model": config.model,
177 "input": texts,
178 });
179
180 let response: serde_json::Value = ureq::post(&url)
181 .header("Authorization", &format!("Bearer {api_key}"))
182 .header("Content-Type", "application/json")
183 .send_json(body)
184 .map_err(|e| Error::Embed(e.to_string()))?
185 .into_body()
186 .read_json()
187 .map_err(|e| Error::Embed(e.to_string()))?;
188
189 parse_openai_response(&response, texts.len())
190}
191
192fn parse_openai_response(response: &serde_json::Value, expected: usize) -> Result<Vec<Vec<f32>>> {
193 let data = response["data"]
194 .as_array()
195 .ok_or_else(|| Error::Embed("unexpected OpenAI response: missing `data` array".into()))?;
196
197 let mut results = vec![Vec::new(); expected];
198 for item in data {
199 let index = item["index"]
200 .as_u64()
201 .ok_or_else(|| Error::Embed("missing `index` in embedding object".into()))?
202 as usize;
203 let vec = parse_float_array(&item["embedding"])?;
204 if index < results.len() {
205 results[index] = vec;
206 }
207 }
208 Ok(results)
209}
210
211fn embed_ollama(config: &EmbedderConfig, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
214 let base_url = config
215 .base_url
216 .as_deref()
217 .unwrap_or("http://localhost:11434");
218 let url = format!("{base_url}/api/embed");
219
220 let body = serde_json::json!({
221 "model": config.model,
222 "input": texts,
223 });
224
225 let response: serde_json::Value = ureq::post(&url)
226 .header("Content-Type", "application/json")
227 .send_json(body)
228 .map_err(|e| Error::Embed(e.to_string()))?
229 .into_body()
230 .read_json()
231 .map_err(|e| Error::Embed(e.to_string()))?;
232
233 response["embeddings"]
234 .as_array()
235 .ok_or_else(|| {
236 Error::Embed("unexpected Ollama response: missing `embeddings` array".into())
237 })?
238 .iter()
239 .map(parse_float_array)
240 .collect()
241}
242
243fn parse_float_array(value: &serde_json::Value) -> Result<Vec<f32>> {
246 value
247 .as_array()
248 .ok_or_else(|| Error::Embed("embedding value is not a JSON array".into()))?
249 .iter()
250 .map(|v| {
251 v.as_f64()
252 .map(|f| f as f32)
253 .ok_or_else(|| Error::Embed("non-numeric value in embedding vector".into()))
254 })
255 .collect()
256}