Skip to main content

sapphire_retrieve/
embed.rs

1//! Text embedding providers.
2//!
3//! Converts text into float vectors used for semantic similarity search.
4//!
5//! The supported providers are:
6//!
7//! - **`"openai"`** — OpenAI-compatible `/v1/embeddings` endpoint.
8//! - **`"ollama"`** — Ollama `/api/embed` endpoint.
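//! - **`"fastembed"`** — Local ONNX inference via the `fastembed` crate.
//!   No server required; model weights are downloaded from Hugging Face
//!   on first use and cached in [`EmbedderConfig::cache_dir`] (e.g.
//!   `~/.cache/sapphire-retrieve/fastembed/`), falling back to a `fastembed`
//!   directory under the OS temporary directory when that is unset.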
12
13use crate::error::{Error, Result};
14
15// ── configuration ─────────────────────────────────────────────────────────────
16
/// Runtime embedding provider configuration passed to [`build_embedder`].
///
/// This is the minimal, non-serializable config used to construct an
/// [`Embedder`] at runtime.  For the user-facing, serde-annotated config
/// see [`crate::config::EmbeddingConfig`].
///
/// `Default` yields empty `provider`/`model` strings and no overrides.  It is
/// mainly useful with struct-update syntax, e.g.
/// `EmbedderConfig { provider: "ollama".into(), model: "...".into(), ..Default::default() }`.
/// An empty `provider` is rejected by [`build_embedder`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EmbedderConfig {
    /// Embedding provider: `"openai"`, `"ollama"`, or `"fastembed"`.
    pub provider: String,
    /// Model name or identifier (provider-specific).  For `"fastembed"` this
    /// must be one of the `fastembed::EmbeddingModel` variant names accepted by
    /// this module, e.g. `"BGESmallENV15"`.
    pub model: String,
    /// Environment variable holding the API key (default: `"OPENAI_API_KEY"`).
    /// Only used by the `"openai"` provider.
    pub api_key_env: Option<String>,
    /// Base URL override for the embedding endpoint, without the API path.
    /// The provider-specific path (`/v1/embeddings` or `/api/embed`) is appended.
    /// For `"openai"`: defaults to `https://api.openai.com`.
    /// For `"ollama"`: defaults to `http://localhost:11434`.
    pub base_url: Option<String>,
    /// Directory where downloaded model weights are cached.
    /// Only used by the `"fastembed"` provider.
    /// Falls back to a `fastembed` subdirectory of the OS temporary directory
    /// when `None`.
    pub cache_dir: Option<std::path::PathBuf>,
}
40
41// ── Embedder trait ────────────────────────────────────────────────────────────
42
43/// Abstraction over a text embedding provider.
44///
45/// Implementations hold any long-lived state needed for efficient repeated
46/// inference (e.g. the loaded ONNX model for `fastembed`).  REST-backed
47/// providers (OpenAI, Ollama) are stateless and simply store their config.
48pub trait Embedder: Send + Sync {
49    /// Generate embeddings for a batch of texts.
50    ///
51    /// Returns one `Vec<f32>` per input text, in the same order.
52    /// Returns an empty `Vec` when `texts` is empty.
53    fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
54}
55
56/// Build an [`Embedder`] from a config.
57///
58/// For `"fastembed"` this loads the ONNX model from disk (or downloads it on
59/// first use), which can take several seconds.  For REST providers the
60/// returned value is lightweight.
61pub fn build_embedder(config: &EmbedderConfig) -> Result<Box<dyn Embedder + Send + Sync>> {
62    match config.provider.as_str() {
63        "openai" | "ollama" => Ok(Box::new(RestEmbedder {
64            config: config.clone(),
65        })),
66        #[cfg(feature = "fastembed-embed")]
67        "fastembed" => Ok(Box::new(FastEmbedEmbedder::new(config)?)),
68        other => Err(Error::Embed(format!(
69            "unknown embedding provider `{other}`; supported values: openai, ollama{}",
70            if cfg!(feature = "fastembed-embed") {
71                ", fastembed"
72            } else {
73                ""
74            }
75        ))),
76    }
77}
78
79// ── REST embedder (OpenAI / Ollama) ───────────────────────────────────────────
80
81struct RestEmbedder {
82    config: EmbedderConfig,
83}
84
85impl Embedder for RestEmbedder {
86    fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
87        if texts.is_empty() {
88            return Ok(Vec::new());
89        }
90        match self.config.provider.as_str() {
91            "openai" => embed_openai(&self.config, texts),
92            "ollama" => embed_ollama(&self.config, texts),
93            other => Err(Error::Embed(format!("unknown REST provider `{other}`"))),
94        }
95    }
96}
97
98// ── fastembed embedder ────────────────────────────────────────────────────────
99
/// Local ONNX embedder backed by a loaded `fastembed::TextEmbedding`.
///
/// The model sits behind a `Mutex` so that `embed_texts`, which only receives
/// `&self`, can obtain exclusive access to it.
/// NOTE(review): presumably needed because `TextEmbedding::embed` takes
/// `&mut self` in the fastembed version in use — confirm.
#[cfg(feature = "fastembed-embed")]
struct FastEmbedEmbedder {
    model: std::sync::Mutex<fastembed::TextEmbedding>,
}

#[cfg(feature = "fastembed-embed")]
impl FastEmbedEmbedder {
    /// Loads the fastembed model named by `config.model`, downloading the
    /// weights into the cache directory on first use.
    ///
    /// The cache directory is `config.cache_dir`, or `<temp dir>/fastembed`
    /// when that is `None`.  Download progress display is enabled.
    ///
    /// # Errors
    ///
    /// Returns `Error::Embed` if the model name is not supported or the model
    /// fails to load.
    fn new(config: &EmbedderConfig) -> Result<Self> {
        let variant = Self::model_from_name(&config.model)?;

        let cache_dir = match &config.cache_dir {
            Some(dir) => dir.clone(),
            None => std::env::temp_dir().join("fastembed"),
        };
        let options = fastembed::InitOptions::new(variant)
            .with_cache_dir(cache_dir)
            .with_show_download_progress(true);

        match fastembed::TextEmbedding::try_new(options) {
            Ok(loaded) => Ok(Self {
                model: std::sync::Mutex::new(loaded),
            }),
            Err(e) => Err(Error::Embed(format!("failed to load fastembed model: {e}"))),
        }
    }

    /// Maps a configuration model name onto the matching `fastembed` variant.
    fn model_from_name(name: &str) -> Result<fastembed::EmbeddingModel> {
        use fastembed::EmbeddingModel as M;

        let variant = match name {
            "AllMiniLML6V2" => M::AllMiniLML6V2,
            "BGESmallENV15" => M::BGESmallENV15,
            "BGEBaseENV15" => M::BGEBaseENV15,
            "BGELargeENV15" => M::BGELargeENV15,
            "NomicEmbedTextV1" => M::NomicEmbedTextV1,
            "NomicEmbedTextV15" => M::NomicEmbedTextV15,
            "MultilingualE5Small" => M::MultilingualE5Small,
            "MultilingualE5Base" => M::MultilingualE5Base,
            "MultilingualE5Large" => M::MultilingualE5Large,
            unknown => {
                return Err(Error::Embed(format!(
                    "unknown fastembed model `{unknown}`; \
                     supported: AllMiniLML6V2, BGESmallENV15, BGEBaseENV15, BGELargeENV15, \
                     NomicEmbedTextV1, NomicEmbedTextV15, \
                     MultilingualE5Small, MultilingualE5Base, MultilingualE5Large"
                )));
            }
        };
        Ok(variant)
    }
}
146
#[cfg(feature = "fastembed-embed")]
impl Embedder for FastEmbedEmbedder {
    /// Embeds `texts` with the locally loaded fastembed model.
    ///
    /// Batch sizing is left to fastembed (`None`).
    ///
    /// # Errors
    ///
    /// Returns `Error::Embed` in either of these cases:
    /// - The model lock was poisoned by a panic in an earlier call.
    /// - fastembed reports an inference failure.
    fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
        // A poisoned lock means an earlier call panicked while holding the
        // model.  Library code reports that as an error instead of panicking
        // again in every subsequent caller.
        self.model
            .lock()
            .map_err(|_| {
                Error::Embed("fastembed model lock poisoned by an earlier panic".into())
            })?
            .embed(texts_owned, None)
            .map_err(|e| Error::Embed(format!("fastembed embedding failed: {e}")))
    }
}
161
162// ── OpenAI-compatible ─────────────────────────────────────────────────────────
163
164fn embed_openai(config: &EmbedderConfig, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
165    let api_key_env = config.api_key_env.as_deref().unwrap_or("OPENAI_API_KEY");
166    let api_key = std::env::var(api_key_env)
167        .map_err(|_| Error::Embed(format!("environment variable `{api_key_env}` is not set")))?;
168
169    let base_url = config
170        .base_url
171        .as_deref()
172        .unwrap_or("https://api.openai.com");
173    let url = format!("{base_url}/v1/embeddings");
174
175    let body = serde_json::json!({
176        "model": config.model,
177        "input": texts,
178    });
179
180    let response: serde_json::Value = ureq::post(&url)
181        .header("Authorization", &format!("Bearer {api_key}"))
182        .header("Content-Type", "application/json")
183        .send_json(body)
184        .map_err(|e| Error::Embed(e.to_string()))?
185        .into_body()
186        .read_json()
187        .map_err(|e| Error::Embed(e.to_string()))?;
188
189    parse_openai_response(&response, texts.len())
190}
191
192fn parse_openai_response(response: &serde_json::Value, expected: usize) -> Result<Vec<Vec<f32>>> {
193    let data = response["data"]
194        .as_array()
195        .ok_or_else(|| Error::Embed("unexpected OpenAI response: missing `data` array".into()))?;
196
197    let mut results = vec![Vec::new(); expected];
198    for item in data {
199        let index = item["index"]
200            .as_u64()
201            .ok_or_else(|| Error::Embed("missing `index` in embedding object".into()))?
202            as usize;
203        let vec = parse_float_array(&item["embedding"])?;
204        if index < results.len() {
205            results[index] = vec;
206        }
207    }
208    Ok(results)
209}
210
211// ── Ollama ────────────────────────────────────────────────────────────────────
212
213fn embed_ollama(config: &EmbedderConfig, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
214    let base_url = config
215        .base_url
216        .as_deref()
217        .unwrap_or("http://localhost:11434");
218    let url = format!("{base_url}/api/embed");
219
220    let body = serde_json::json!({
221        "model": config.model,
222        "input": texts,
223    });
224
225    let response: serde_json::Value = ureq::post(&url)
226        .header("Content-Type", "application/json")
227        .send_json(body)
228        .map_err(|e| Error::Embed(e.to_string()))?
229        .into_body()
230        .read_json()
231        .map_err(|e| Error::Embed(e.to_string()))?;
232
233    response["embeddings"]
234        .as_array()
235        .ok_or_else(|| {
236            Error::Embed("unexpected Ollama response: missing `embeddings` array".into())
237        })?
238        .iter()
239        .map(parse_float_array)
240        .collect()
241}
242
243// ── helpers ───────────────────────────────────────────────────────────────────
244
245fn parse_float_array(value: &serde_json::Value) -> Result<Vec<f32>> {
246    value
247        .as_array()
248        .ok_or_else(|| Error::Embed("embedding value is not a JSON array".into()))?
249        .iter()
250        .map(|v| {
251            v.as_f64()
252                .map(|f| f as f32)
253                .ok_or_else(|| Error::Embed("non-numeric value in embedding vector".into()))
254        })
255        .collect()
256}