Skip to main content

seekr_code/embedder/
onnx.rs

1//! ONNX Runtime embedding backend.
2//!
3//! Loads a local ONNX model (all-MiniLM-L6-v2 quantized) and provides
4//! embedding inference. Supports automatic model download and caching.
5
6use std::path::{Path, PathBuf};
7use std::sync::Mutex;
8
9use ort::session::builder::GraphOptimizationLevel;
10use ort::session::Session;
11use ort::value::TensorRef;
12
13use crate::embedder::traits::Embedder;
14use crate::error::EmbedderError;
15
/// Where the quantized all-MiniLM-L6-v2 ONNX export is fetched from (HuggingFace).
const MODEL_URL: &str = concat!(
    "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main",
    "/onnx/model_quantized.onnx"
);

/// Local file name the downloaded ONNX model is stored under inside the model directory.
const MODEL_FILENAME: &str = "all-MiniLM-L6-v2-quantized.onnx";

/// Where the model's `tokenizer.json` is fetched from (HuggingFace).
const TOKENIZER_URL: &str = concat!(
    "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main",
    "/tokenizer.json"
);

/// Local file name the downloaded tokenizer is stored under inside the model directory.
const TOKENIZER_FILENAME: &str = "tokenizer.json";

/// Length of every embedding vector produced by all-MiniLM-L6-v2.
const EMBEDDING_DIM: usize = 384;

/// Fixed number of token positions fed to the model per input (after truncation/padding).
const MAX_SEQ_LENGTH: usize = 256;
33
34/// ONNX-based embedding backend using all-MiniLM-L6-v2.
35pub struct OnnxEmbedder {
36    /// Session wrapped in Mutex because `session.run()` requires `&mut self`.
37    session: Mutex<Session>,
38    /// HuggingFace WordPiece tokenizer loaded from tokenizer.json.
39    tokenizer: tokenizers::Tokenizer,
40    model_dir: PathBuf,
41}
42
43impl OnnxEmbedder {
44    /// Create a new OnnxEmbedder.
45    ///
46    /// If the model is not found in `model_dir`, it will be downloaded
47    /// automatically from HuggingFace.
48    pub fn new(model_dir: &Path) -> Result<Self, EmbedderError> {
49        std::fs::create_dir_all(model_dir).map_err(EmbedderError::Io)?;
50
51        let model_path = model_dir.join(MODEL_FILENAME);
52
53        // Download model if not present
54        if !model_path.exists() {
55            tracing::info!("Downloading ONNX model to {}...", model_path.display());
56            download_file(MODEL_URL, &model_path)?;
57            tracing::info!("Model downloaded successfully.");
58        }
59
60        // Download tokenizer if not present
61        let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
62        if !tokenizer_path.exists() {
63            tracing::info!("Downloading tokenizer...");
64            download_file(TOKENIZER_URL, &tokenizer_path)?;
65            tracing::info!("Tokenizer downloaded successfully.");
66        }
67
68        // Create ONNX Runtime session
69        let session = Session::builder()
70            .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
71            .with_optimization_level(GraphOptimizationLevel::Level3)
72            .unwrap_or_else(|e| e.recover())
73            .with_intra_threads(4)
74            .unwrap_or_else(|e| e.recover())
75            .commit_from_file(&model_path)
76            .map_err(|e| EmbedderError::OnnxError(format!("Failed to load model: {}", e)))?;
77
78        // Load HuggingFace tokenizer from downloaded tokenizer.json
79        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
80            .map_err(|e| EmbedderError::OnnxError(format!("Failed to load tokenizer: {}", e)))?;
81
82        Ok(Self {
83            session: Mutex::new(session),
84            tokenizer,
85            model_dir: model_dir.to_path_buf(),
86        })
87    }
88
89    /// Get the model directory path.
90    pub fn model_dir(&self) -> &Path {
91        &self.model_dir
92    }
93
94    /// Tokenize text using the HuggingFace WordPiece tokenizer.
95    ///
96    /// Uses the real tokenizer.json from all-MiniLM-L6-v2 for proper
97    /// WordPiece tokenization, producing correct token IDs that match
98    /// the model's vocabulary.
99    fn tokenize(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
100        // Encode with special tokens ([CLS] and [SEP] are added automatically)
101        let encoding = self
102            .tokenizer
103            .encode(text, true)
104            .unwrap_or_else(|_| {
105                // Fallback: return empty encoding on error
106                self.tokenizer.encode("", true).unwrap()
107            });
108
109        let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
110        let mut attention_mask: Vec<i64> =
111            encoding.get_attention_mask().iter().map(|&m| m as i64).collect();
112
113        // Truncate to MAX_SEQ_LENGTH if needed
114        if input_ids.len() > MAX_SEQ_LENGTH {
115            input_ids.truncate(MAX_SEQ_LENGTH);
116            attention_mask.truncate(MAX_SEQ_LENGTH);
117            // Ensure the last token is [SEP] (id = 102 for BERT-based models)
118            if let Some(last) = input_ids.last_mut() {
119                *last = 102;
120            }
121        }
122
123        // Pad to fixed length for batching
124        while input_ids.len() < MAX_SEQ_LENGTH {
125            input_ids.push(0);
126            attention_mask.push(0);
127        }
128
129        (input_ids, attention_mask)
130    }
131
132    /// Run inference on tokenized input.
133    fn run_inference(
134        &self,
135        input_ids: &[i64],
136        attention_mask: &[i64],
137    ) -> Result<Vec<f32>, EmbedderError> {
138        let seq_len = input_ids.len();
139
140        let input_ids_array =
141            ndarray::Array2::from_shape_vec((1, seq_len), input_ids.to_vec())
142                .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
143        let attention_mask_array =
144            ndarray::Array2::from_shape_vec((1, seq_len), attention_mask.to_vec())
145                .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
146        let token_type_ids_array = ndarray::Array2::<i64>::zeros((1, seq_len));
147
148        // Create TensorRef inputs
149        let input_ids_tensor = TensorRef::from_array_view(&input_ids_array)
150            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
151        let attention_mask_tensor = TensorRef::from_array_view(&attention_mask_array)
152            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
153        let token_type_ids_tensor = TensorRef::from_array_view(&token_type_ids_array)
154            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
155
156        let mut session = self.session.lock().map_err(|e| {
157            EmbedderError::OnnxError(format!("Session lock poisoned: {}", e))
158        })?;
159
160        let outputs = session
161            .run(ort::inputs![
162                "input_ids" => input_ids_tensor,
163                "attention_mask" => attention_mask_tensor,
164                "token_type_ids" => token_type_ids_tensor
165            ])
166            .map_err(|e| EmbedderError::OnnxError(format!("Inference error: {}", e)))?;
167
168        // Extract output tensor — try common output names, then fall back to first output
169        let output = if outputs.contains_key("last_hidden_state") {
170            &outputs["last_hidden_state"]
171        } else if outputs.contains_key("token_embeddings") {
172            &outputs["token_embeddings"]
173        } else {
174            &outputs[0]
175        };
176
177        let tensor = output
178            .try_extract_array::<f32>()
179            .map_err(|e| EmbedderError::OnnxError(format!("Extract error: {}", e)))?;
180
181        // Mean pooling: average over the sequence dimension (dim 1)
182        // tensor shape: [1, seq_len, hidden_size]
183        let shape = tensor.shape();
184        if shape.len() != 3 {
185            return Err(EmbedderError::OnnxError(format!(
186                "Unexpected output shape: {:?}",
187                shape
188            )));
189        }
190
191        let hidden_size = shape[2];
192        let mut pooled = vec![0.0f32; hidden_size];
193        let active_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();
194
195        if active_tokens > 0.0 {
196            for seq_idx in 0..shape[1] {
197                let mask = attention_mask.get(seq_idx).copied().unwrap_or(0) as f32;
198                if mask > 0.0 {
199                    for dim in 0..hidden_size {
200                        pooled[dim] += tensor[[0, seq_idx, dim]];
201                    }
202                }
203            }
204            for val in &mut pooled {
205                *val /= active_tokens;
206            }
207        }
208
209        // L2 normalize
210        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
211        if norm > 0.0 {
212            for x in &mut pooled {
213                *x /= norm;
214            }
215        }
216
217        Ok(pooled)
218    }
219}
220
221impl Embedder for OnnxEmbedder {
222    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
223        let (input_ids, attention_mask) = self.tokenize(text);
224        self.run_inference(&input_ids, &attention_mask)
225    }
226
227    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedderError> {
228        // For now, process sequentially. A proper implementation would
229        // batch inputs together for a single session.run() call.
230        texts.iter().map(|text| self.embed(text)).collect()
231    }
232
233    fn dimension(&self) -> usize {
234        EMBEDDING_DIM
235    }
236}
237
238/// Download a file from a URL to a local path.
239fn download_file(url: &str, dest: &Path) -> Result<(), EmbedderError> {
240    let response = reqwest::blocking::get(url)
241        .map_err(|e| EmbedderError::DownloadFailed(format!("HTTP request failed: {}", e)))?;
242
243    if !response.status().is_success() {
244        return Err(EmbedderError::DownloadFailed(format!(
245            "HTTP {} for {}",
246            response.status(),
247            url
248        )));
249    }
250
251    let bytes = response
252        .bytes()
253        .map_err(|e| EmbedderError::DownloadFailed(format!("Failed to read response: {}", e)))?;
254
255    // Verify download isn't empty
256    if bytes.is_empty() {
257        return Err(EmbedderError::DownloadFailed(
258            "Downloaded file is empty".to_string(),
259        ));
260    }
261
262    // Write to temporary file then rename (atomic-ish)
263    let tmp_path = dest.with_extension("tmp");
264    std::fs::write(&tmp_path, &bytes).map_err(EmbedderError::Io)?;
265    std::fs::rename(&tmp_path, dest).map_err(EmbedderError::Io)?;
266
267    tracing::info!(
268        "Downloaded {} bytes to {}",
269        bytes.len(),
270        dest.display()
271    );
272
273    Ok(())
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end tokenizer check against the real all-MiniLM-L6-v2 tokenizer.
    ///
    /// `OnnxEmbedder::new` downloads the model and tokenizer from HuggingFace
    /// into a temp directory when they are missing, so this test needs network
    /// access (at least on first run) and is ignored by default.
    /// Run it with `cargo test -- --ignored`.
    #[test]
    #[ignore = "downloads the ONNX model and tokenizer from HuggingFace"]
    fn test_tokenize_output_length() {
        let model_dir = std::env::temp_dir().join("seekr_test_tokenizer");
        let embedder = OnnxEmbedder::new(&model_dir)
            .expect("failed to download or load the ONNX model/tokenizer");

        let (ids, mask) = embedder.tokenize("hello world");
        assert_eq!(ids.len(), MAX_SEQ_LENGTH);
        assert_eq!(mask.len(), MAX_SEQ_LENGTH);

        // First token should be [CLS] = 101
        assert_eq!(ids[0], 101);

        // Active tokens (mask = 1) form a prefix; the rest is padding.
        let active = mask.iter().take_while(|&&m| m == 1).count();
        assert!(active > 0, "Should have at least some active tokens");
        assert!(mask[active..].iter().all(|&m| m == 0));
        assert!(ids[active..].iter().all(|&id| id == 0));

        // Last active token should be [SEP] = 102
        assert_eq!(ids[active - 1], 102);
    }

    #[test]
    fn test_embedding_dimension() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    #[test]
    fn test_max_seq_length() {
        assert_eq!(MAX_SEQ_LENGTH, 256);
    }
}