// seekr_code/embedder/onnx.rs

1//! ONNX Runtime embedding backend.
2//!
3//! Loads a local ONNX model (all-MiniLM-L6-v2 quantized) and provides
4//! embedding inference. Supports automatic model download and caching.
5
6use std::path::{Path, PathBuf};
7use std::sync::Mutex;
8
9use ort::session::builder::GraphOptimizationLevel;
10use ort::session::Session;
11use ort::value::TensorRef;
12
13use crate::embedder::traits::Embedder;
14use crate::error::EmbedderError;
15
/// HuggingFace URL for the quantized all-MiniLM-L6-v2 ONNX export.
const MODEL_URL: &str = "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx";

/// Filename under which the downloaded model is cached in the model dir.
const MODEL_FILENAME: &str = "all-MiniLM-L6-v2-quantized.onnx";

/// HuggingFace URL for the matching tokenizer definition.
const TOKENIZER_URL: &str = "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json";

/// Filename under which the tokenizer is cached in the model dir.
const TOKENIZER_FILENAME: &str = "tokenizer.json";

/// Embedding dimension produced by all-MiniLM-L6-v2.
const EMBEDDING_DIM: usize = 384;

/// Maximum sequence length, counting the [CLS]/[SEP] special tokens;
/// `tokenize` pads every input up to exactly this length.
const MAX_SEQ_LENGTH: usize = 256;
33
/// ONNX-based embedding backend using all-MiniLM-L6-v2.
///
/// The session sits behind a `Mutex`, so concurrent `embed` calls are
/// serialized rather than run in parallel.
pub struct OnnxEmbedder {
    /// Session wrapped in Mutex because `session.run()` requires `&mut self`.
    session: Mutex<Session>,
    /// Directory holding the cached model and tokenizer files.
    model_dir: PathBuf,
}
40
41impl OnnxEmbedder {
42    /// Create a new OnnxEmbedder.
43    ///
44    /// If the model is not found in `model_dir`, it will be downloaded
45    /// automatically from HuggingFace.
46    pub fn new(model_dir: &Path) -> Result<Self, EmbedderError> {
47        std::fs::create_dir_all(model_dir).map_err(EmbedderError::Io)?;
48
49        let model_path = model_dir.join(MODEL_FILENAME);
50
51        // Download model if not present
52        if !model_path.exists() {
53            tracing::info!("Downloading ONNX model to {}...", model_path.display());
54            download_file(MODEL_URL, &model_path)?;
55            tracing::info!("Model downloaded successfully.");
56        }
57
58        // Download tokenizer if not present
59        let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
60        if !tokenizer_path.exists() {
61            tracing::info!("Downloading tokenizer...");
62            download_file(TOKENIZER_URL, &tokenizer_path)?;
63            tracing::info!("Tokenizer downloaded successfully.");
64        }
65
66        // Create ONNX Runtime session
67        let session = Session::builder()
68            .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
69            .with_optimization_level(GraphOptimizationLevel::Level3)
70            .unwrap_or_else(|e| e.recover())
71            .with_intra_threads(4)
72            .unwrap_or_else(|e| e.recover())
73            .commit_from_file(&model_path)
74            .map_err(|e| EmbedderError::OnnxError(format!("Failed to load model: {}", e)))?;
75
76        Ok(Self {
77            session: Mutex::new(session),
78            model_dir: model_dir.to_path_buf(),
79        })
80    }
81
82    /// Get the model directory path.
83    pub fn model_dir(&self) -> &Path {
84        &self.model_dir
85    }
86
87    /// Simple tokenization for the embedding model.
88    ///
89    /// This is a simplified tokenizer that creates word-piece-like tokens.
90    /// For production use, a proper HuggingFace tokenizer should be used.
91    fn tokenize(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
92        // Simple whitespace + punctuation tokenization
93        let words: Vec<&str> = text.split_whitespace().collect();
94
95        // CLS token = 101, SEP token = 102
96        let mut input_ids = vec![101i64]; // [CLS]
97        let mut attention_mask = vec![1i64];
98
99        for word in words {
100            if input_ids.len() >= MAX_SEQ_LENGTH - 1 {
101                break;
102            }
103            // Simple hash-based token ID (simplified tokenization)
104            let token_id = simple_hash(word) % 30000 + 1000;
105            input_ids.push(token_id as i64);
106            attention_mask.push(1);
107        }
108
109        input_ids.push(102); // [SEP]
110        attention_mask.push(1);
111
112        // Pad to fixed length for batching
113        while input_ids.len() < MAX_SEQ_LENGTH {
114            input_ids.push(0);
115            attention_mask.push(0);
116        }
117
118        (input_ids, attention_mask)
119    }
120
121    /// Run inference on tokenized input.
122    fn run_inference(
123        &self,
124        input_ids: &[i64],
125        attention_mask: &[i64],
126    ) -> Result<Vec<f32>, EmbedderError> {
127        let seq_len = input_ids.len();
128
129        let input_ids_array =
130            ndarray::Array2::from_shape_vec((1, seq_len), input_ids.to_vec())
131                .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
132        let attention_mask_array =
133            ndarray::Array2::from_shape_vec((1, seq_len), attention_mask.to_vec())
134                .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
135        let token_type_ids_array = ndarray::Array2::<i64>::zeros((1, seq_len));
136
137        // Create TensorRef inputs
138        let input_ids_tensor = TensorRef::from_array_view(&input_ids_array)
139            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
140        let attention_mask_tensor = TensorRef::from_array_view(&attention_mask_array)
141            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
142        let token_type_ids_tensor = TensorRef::from_array_view(&token_type_ids_array)
143            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
144
145        let mut session = self.session.lock().map_err(|e| {
146            EmbedderError::OnnxError(format!("Session lock poisoned: {}", e))
147        })?;
148
149        let outputs = session
150            .run(ort::inputs![
151                "input_ids" => input_ids_tensor,
152                "attention_mask" => attention_mask_tensor,
153                "token_type_ids" => token_type_ids_tensor
154            ])
155            .map_err(|e| EmbedderError::OnnxError(format!("Inference error: {}", e)))?;
156
157        // Extract output tensor — try common output names, then fall back to first output
158        let output = if outputs.contains_key("last_hidden_state") {
159            &outputs["last_hidden_state"]
160        } else if outputs.contains_key("token_embeddings") {
161            &outputs["token_embeddings"]
162        } else {
163            &outputs[0]
164        };
165
166        let tensor = output
167            .try_extract_array::<f32>()
168            .map_err(|e| EmbedderError::OnnxError(format!("Extract error: {}", e)))?;
169
170        // Mean pooling: average over the sequence dimension (dim 1)
171        // tensor shape: [1, seq_len, hidden_size]
172        let shape = tensor.shape();
173        if shape.len() != 3 {
174            return Err(EmbedderError::OnnxError(format!(
175                "Unexpected output shape: {:?}",
176                shape
177            )));
178        }
179
180        let hidden_size = shape[2];
181        let mut pooled = vec![0.0f32; hidden_size];
182        let active_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();
183
184        if active_tokens > 0.0 {
185            for seq_idx in 0..shape[1] {
186                let mask = attention_mask.get(seq_idx).copied().unwrap_or(0) as f32;
187                if mask > 0.0 {
188                    for dim in 0..hidden_size {
189                        pooled[dim] += tensor[[0, seq_idx, dim]];
190                    }
191                }
192            }
193            for val in &mut pooled {
194                *val /= active_tokens;
195            }
196        }
197
198        // L2 normalize
199        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
200        if norm > 0.0 {
201            for x in &mut pooled {
202                *x /= norm;
203            }
204        }
205
206        Ok(pooled)
207    }
208}
209
210impl Embedder for OnnxEmbedder {
211    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
212        let (input_ids, attention_mask) = self.tokenize(text);
213        self.run_inference(&input_ids, &attention_mask)
214    }
215
216    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedderError> {
217        // For now, process sequentially. A proper implementation would
218        // batch inputs together for a single session.run() call.
219        texts.iter().map(|text| self.embed(text)).collect()
220    }
221
222    fn dimension(&self) -> usize {
223        EMBEDDING_DIM
224    }
225}
226
/// djb2-style string hash used to derive pseudo token IDs.
///
/// Deterministic across runs; wrapping arithmetic keeps it panic-free in
/// debug builds.
fn simple_hash(s: &str) -> u64 {
    s.bytes()
        .fold(5381u64, |hash, byte| hash.wrapping_mul(33).wrapping_add(byte as u64))
}
235
236/// Download a file from a URL to a local path.
237fn download_file(url: &str, dest: &Path) -> Result<(), EmbedderError> {
238    let response = reqwest::blocking::get(url)
239        .map_err(|e| EmbedderError::DownloadFailed(format!("HTTP request failed: {}", e)))?;
240
241    if !response.status().is_success() {
242        return Err(EmbedderError::DownloadFailed(format!(
243            "HTTP {} for {}",
244            response.status(),
245            url
246        )));
247    }
248
249    let bytes = response
250        .bytes()
251        .map_err(|e| EmbedderError::DownloadFailed(format!("Failed to read response: {}", e)))?;
252
253    // Verify download isn't empty
254    if bytes.is_empty() {
255        return Err(EmbedderError::DownloadFailed(
256            "Downloaded file is empty".to_string(),
257        ));
258    }
259
260    // Write to temporary file then rename (atomic-ish)
261    let tmp_path = dest.with_extension("tmp");
262    std::fs::write(&tmp_path, &bytes).map_err(EmbedderError::Io)?;
263    std::fs::rename(&tmp_path, dest).map_err(EmbedderError::Io)?;
264
265    tracing::info!(
266        "Downloaded {} bytes to {}",
267        bytes.len(),
268        dest.display()
269    );
270
271    Ok(())
272}
273
274#[cfg(test)]
275mod tests {
276    use super::*;
277
278    #[test]
279    fn test_simple_hash() {
280        let h1 = simple_hash("hello");
281        let h2 = simple_hash("world");
282        assert_ne!(h1, h2);
283
284        // Same input should give same hash
285        assert_eq!(simple_hash("test"), simple_hash("test"));
286    }
287
288    #[test]
289    fn test_tokenize_basic() {
290        // We just test the hash function used for tokenization
291        let token = simple_hash("authentication") % 30000 + 1000;
292        assert!(token >= 1000);
293        assert!(token < 31000);
294    }
295}