// seekr_code/embedder/onnx.rs

1//! ONNX Runtime embedding backend.
2//!
3//! Loads a local ONNX model (all-MiniLM-L6-v2 quantized) and provides
4//! embedding inference. Supports automatic model download and caching.
5
6use std::path::{Path, PathBuf};
7use std::sync::Mutex;
8
9use ort::session::Session;
10use ort::session::builder::GraphOptimizationLevel;
11use ort::value::TensorRef;
12
13use crate::embedder::traits::Embedder;
14use crate::error::EmbedderError;
15
16/// HuggingFace model URL for all-MiniLM-L6-v2 ONNX.
17const MODEL_URL: &str = "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx";
18
19/// Expected model filename.
20const MODEL_FILENAME: &str = "all-MiniLM-L6-v2-quantized.onnx";
21
22/// Tokenizer URL.
23const TOKENIZER_URL: &str =
24    "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json";
25
26/// Tokenizer filename.
27const TOKENIZER_FILENAME: &str = "tokenizer.json";
28
29/// Embedding dimension for all-MiniLM-L6-v2.
30const EMBEDDING_DIM: usize = 384;
31
32/// Maximum sequence length.
33const MAX_SEQ_LENGTH: usize = 256;
34
/// ONNX-based embedding backend using all-MiniLM-L6-v2.
///
/// Holds the ONNX Runtime session, the WordPiece tokenizer, and the
/// directory the model artifacts live in. Downloads both artifacts on
/// first use (see [`OnnxEmbedder::new`]).
pub struct OnnxEmbedder {
    /// Session wrapped in `Mutex` because `session.run()` requires
    /// `&mut self`, while the `Embedder` trait methods take `&self`.
    session: Mutex<Session>,
    /// HuggingFace WordPiece tokenizer loaded from tokenizer.json.
    tokenizer: tokenizers::Tokenizer,
    /// Directory containing the downloaded model and tokenizer files.
    model_dir: PathBuf,
}
43
44impl OnnxEmbedder {
45    /// Create a new OnnxEmbedder.
46    ///
47    /// If the model is not found in `model_dir`, it will be downloaded
48    /// automatically from HuggingFace.
49    pub fn new(model_dir: &Path) -> Result<Self, EmbedderError> {
50        std::fs::create_dir_all(model_dir).map_err(EmbedderError::Io)?;
51
52        let model_path = model_dir.join(MODEL_FILENAME);
53
54        // Download model if not present
55        if !model_path.exists() {
56            tracing::info!("Downloading ONNX model to {}...", model_path.display());
57            download_file(MODEL_URL, &model_path)?;
58            tracing::info!("Model downloaded successfully.");
59        }
60
61        // Download tokenizer if not present
62        let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
63        if !tokenizer_path.exists() {
64            tracing::info!("Downloading tokenizer...");
65            download_file(TOKENIZER_URL, &tokenizer_path)?;
66            tracing::info!("Tokenizer downloaded successfully.");
67        }
68
69        // Create ONNX Runtime session
70        let session = Session::builder()
71            .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
72            .with_optimization_level(GraphOptimizationLevel::Level3)
73            .unwrap_or_else(|e| e.recover())
74            .with_intra_threads(4)
75            .unwrap_or_else(|e| e.recover())
76            .commit_from_file(&model_path)
77            .map_err(|e| EmbedderError::OnnxError(format!("Failed to load model: {}", e)))?;
78
79        // Load HuggingFace tokenizer from downloaded tokenizer.json
80        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
81            .map_err(|e| EmbedderError::OnnxError(format!("Failed to load tokenizer: {}", e)))?;
82
83        Ok(Self {
84            session: Mutex::new(session),
85            tokenizer,
86            model_dir: model_dir.to_path_buf(),
87        })
88    }
89
90    /// Get the model directory path.
91    pub fn model_dir(&self) -> &Path {
92        &self.model_dir
93    }
94
95    /// Tokenize text using the HuggingFace WordPiece tokenizer.
96    ///
97    /// Uses the real tokenizer.json from all-MiniLM-L6-v2 for proper
98    /// WordPiece tokenization, producing correct token IDs that match
99    /// the model's vocabulary.
100    fn tokenize(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
101        // Encode with special tokens ([CLS] and [SEP] are added automatically)
102        let encoding = self.tokenizer.encode(text, true).unwrap_or_else(|_| {
103            // Fallback: return empty encoding on error
104            self.tokenizer.encode("", true).unwrap()
105        });
106
107        let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
108        let mut attention_mask: Vec<i64> = encoding
109            .get_attention_mask()
110            .iter()
111            .map(|&m| m as i64)
112            .collect();
113
114        // Truncate to MAX_SEQ_LENGTH if needed
115        if input_ids.len() > MAX_SEQ_LENGTH {
116            input_ids.truncate(MAX_SEQ_LENGTH);
117            attention_mask.truncate(MAX_SEQ_LENGTH);
118            // Ensure the last token is [SEP] (id = 102 for BERT-based models)
119            if let Some(last) = input_ids.last_mut() {
120                *last = 102;
121            }
122        }
123
124        // Pad to fixed length for batching
125        while input_ids.len() < MAX_SEQ_LENGTH {
126            input_ids.push(0);
127            attention_mask.push(0);
128        }
129
130        (input_ids, attention_mask)
131    }
132
133    /// Run inference on tokenized input.
134    fn run_inference(
135        &self,
136        input_ids: &[i64],
137        attention_mask: &[i64],
138    ) -> Result<Vec<f32>, EmbedderError> {
139        let seq_len = input_ids.len();
140
141        let input_ids_array = ndarray::Array2::from_shape_vec((1, seq_len), input_ids.to_vec())
142            .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
143        let attention_mask_array =
144            ndarray::Array2::from_shape_vec((1, seq_len), attention_mask.to_vec())
145                .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
146        let token_type_ids_array = ndarray::Array2::<i64>::zeros((1, seq_len));
147
148        // Create TensorRef inputs
149        let input_ids_tensor = TensorRef::from_array_view(&input_ids_array)
150            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
151        let attention_mask_tensor = TensorRef::from_array_view(&attention_mask_array)
152            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
153        let token_type_ids_tensor = TensorRef::from_array_view(&token_type_ids_array)
154            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
155
156        let mut session = self
157            .session
158            .lock()
159            .map_err(|e| EmbedderError::OnnxError(format!("Session lock poisoned: {}", e)))?;
160
161        let outputs = session
162            .run(ort::inputs![
163                "input_ids" => input_ids_tensor,
164                "attention_mask" => attention_mask_tensor,
165                "token_type_ids" => token_type_ids_tensor
166            ])
167            .map_err(|e| EmbedderError::OnnxError(format!("Inference error: {}", e)))?;
168
169        // Extract output tensor — try common output names, then fall back to first output
170        let output = if outputs.contains_key("last_hidden_state") {
171            &outputs["last_hidden_state"]
172        } else if outputs.contains_key("token_embeddings") {
173            &outputs["token_embeddings"]
174        } else {
175            &outputs[0]
176        };
177
178        let tensor = output
179            .try_extract_array::<f32>()
180            .map_err(|e| EmbedderError::OnnxError(format!("Extract error: {}", e)))?;
181
182        // Mean pooling: average over the sequence dimension (dim 1)
183        // tensor shape: [1, seq_len, hidden_size]
184        let shape = tensor.shape();
185        if shape.len() != 3 {
186            return Err(EmbedderError::OnnxError(format!(
187                "Unexpected output shape: {:?}",
188                shape
189            )));
190        }
191
192        let hidden_size = shape[2];
193        let mut pooled = vec![0.0f32; hidden_size];
194        let active_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();
195
196        if active_tokens > 0.0 {
197            for seq_idx in 0..shape[1] {
198                let mask = attention_mask.get(seq_idx).copied().unwrap_or(0) as f32;
199                if mask > 0.0 {
200                    for dim in 0..hidden_size {
201                        pooled[dim] += tensor[[0, seq_idx, dim]];
202                    }
203                }
204            }
205            for val in &mut pooled {
206                *val /= active_tokens;
207            }
208        }
209
210        // L2 normalize
211        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
212        if norm > 0.0 {
213            for x in &mut pooled {
214                *x /= norm;
215            }
216        }
217
218        Ok(pooled)
219    }
220}
221
222impl Embedder for OnnxEmbedder {
223    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
224        let (input_ids, attention_mask) = self.tokenize(text);
225        self.run_inference(&input_ids, &attention_mask)
226    }
227
228    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedderError> {
229        // For now, process sequentially. A proper implementation would
230        // batch inputs together for a single session.run() call.
231        texts.iter().map(|text| self.embed(text)).collect()
232    }
233
234    fn dimension(&self) -> usize {
235        EMBEDDING_DIM
236    }
237}
238
239/// Download a file from a URL to a local path.
240fn download_file(url: &str, dest: &Path) -> Result<(), EmbedderError> {
241    let response = reqwest::blocking::get(url)
242        .map_err(|e| EmbedderError::DownloadFailed(format!("HTTP request failed: {}", e)))?;
243
244    if !response.status().is_success() {
245        return Err(EmbedderError::DownloadFailed(format!(
246            "HTTP {} for {}",
247            response.status(),
248            url
249        )));
250    }
251
252    let bytes = response
253        .bytes()
254        .map_err(|e| EmbedderError::DownloadFailed(format!("Failed to read response: {}", e)))?;
255
256    // Verify download isn't empty
257    if bytes.is_empty() {
258        return Err(EmbedderError::DownloadFailed(
259            "Downloaded file is empty".to_string(),
260        ));
261    }
262
263    // Write to temporary file then rename (atomic-ish)
264    let tmp_path = dest.with_extension("tmp");
265    std::fs::write(&tmp_path, &bytes).map_err(EmbedderError::Io)?;
266    std::fs::rename(&tmp_path, dest).map_err(EmbedderError::Io)?;
267
268    tracing::info!("Downloaded {} bytes to {}", bytes.len(), dest.display());
269
270    Ok(())
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenization must always yield fixed-length (MAX_SEQ_LENGTH) output.
    ///
    /// NOTE(review): `OnnxEmbedder::new` downloads the model and tokenizer
    /// when the cache is cold, so this test hits the network on first run;
    /// the `if let Ok(..)` guard turns setup failure into a silent no-op
    /// rather than a test failure.
    #[test]
    fn test_tokenize_output_length() {
        // Verify that tokenize always produces MAX_SEQ_LENGTH outputs
        // This test requires a tokenizer.json file, so we create a minimal one
        // In integration tests, the real tokenizer would be used
        let model_dir = std::env::temp_dir().join("seekr_test_tokenizer");
        if let Ok(embedder) = OnnxEmbedder::new(&model_dir) {
            let (ids, mask) = embedder.tokenize("hello world");
            assert_eq!(ids.len(), MAX_SEQ_LENGTH);
            assert_eq!(mask.len(), MAX_SEQ_LENGTH);

            // First token should be [CLS] = 101 in the BERT vocabulary
            assert_eq!(ids[0], 101);

            // There should be active tokens (attention_mask = 1)
            let active: i64 = mask.iter().sum();
            assert!(active > 0, "Should have at least some active tokens");
        }
    }

    // Guard against accidental edits to the model constants below.

    #[test]
    fn test_embedding_dimension() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    #[test]
    fn test_max_seq_length() {
        assert_eq!(MAX_SEQ_LENGTH, 256);
    }
}