// seekr_code/embedder/onnx.rs

use std::path::{Path, PathBuf};
use std::sync::Mutex;

use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::TensorRef;

use crate::embedder::traits::Embedder;
use crate::error::EmbedderError;
/// Quantized ONNX export of the all-MiniLM-L6-v2 sentence-transformer model.
const MODEL_URL: &str = "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx";

/// Local filename the downloaded model is stored under inside `model_dir`.
const MODEL_FILENAME: &str = "all-MiniLM-L6-v2-quantized.onnx";

/// Tokenizer definition published alongside the model.
const TOKENIZER_URL: &str = "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json";

/// Local filename the downloaded tokenizer is stored under inside `model_dir`.
const TOKENIZER_FILENAME: &str = "tokenizer.json";

/// Dimensionality of the pooled output embedding (all-MiniLM-L6-v2 is 384-d).
const EMBEDDING_DIM: usize = 384;

/// Fixed token sequence length fed to the model; shorter inputs are padded,
/// longer inputs are truncated (see `tokenize`).
const MAX_SEQ_LENGTH: usize = 256;
/// Sentence embedder backed by a local ONNX Runtime session running the
/// quantized all-MiniLM-L6-v2 model.
pub struct OnnxEmbedder {
    // Wrapped in a Mutex because inference needs mutable access to the
    // session (`run_inference` locks and binds it `mut`) while the
    // `Embedder` trait methods take `&self`.
    session: Mutex<Session>,
    // Directory holding the downloaded model and tokenizer files.
    model_dir: PathBuf,
}
40
41impl OnnxEmbedder {
42 pub fn new(model_dir: &Path) -> Result<Self, EmbedderError> {
47 std::fs::create_dir_all(model_dir).map_err(EmbedderError::Io)?;
48
49 let model_path = model_dir.join(MODEL_FILENAME);
50
51 if !model_path.exists() {
53 tracing::info!("Downloading ONNX model to {}...", model_path.display());
54 download_file(MODEL_URL, &model_path)?;
55 tracing::info!("Model downloaded successfully.");
56 }
57
58 let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
60 if !tokenizer_path.exists() {
61 tracing::info!("Downloading tokenizer...");
62 download_file(TOKENIZER_URL, &tokenizer_path)?;
63 tracing::info!("Tokenizer downloaded successfully.");
64 }
65
66 let session = Session::builder()
68 .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
69 .with_optimization_level(GraphOptimizationLevel::Level3)
70 .unwrap_or_else(|e| e.recover())
71 .with_intra_threads(4)
72 .unwrap_or_else(|e| e.recover())
73 .commit_from_file(&model_path)
74 .map_err(|e| EmbedderError::OnnxError(format!("Failed to load model: {}", e)))?;
75
76 Ok(Self {
77 session: Mutex::new(session),
78 model_dir: model_dir.to_path_buf(),
79 })
80 }
81
82 pub fn model_dir(&self) -> &Path {
84 &self.model_dir
85 }
86
87 fn tokenize(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
92 let words: Vec<&str> = text.split_whitespace().collect();
94
95 let mut input_ids = vec![101i64]; let mut attention_mask = vec![1i64];
98
99 for word in words {
100 if input_ids.len() >= MAX_SEQ_LENGTH - 1 {
101 break;
102 }
103 let token_id = simple_hash(word) % 30000 + 1000;
105 input_ids.push(token_id as i64);
106 attention_mask.push(1);
107 }
108
109 input_ids.push(102); attention_mask.push(1);
111
112 while input_ids.len() < MAX_SEQ_LENGTH {
114 input_ids.push(0);
115 attention_mask.push(0);
116 }
117
118 (input_ids, attention_mask)
119 }
120
121 fn run_inference(
123 &self,
124 input_ids: &[i64],
125 attention_mask: &[i64],
126 ) -> Result<Vec<f32>, EmbedderError> {
127 let seq_len = input_ids.len();
128
129 let input_ids_array =
130 ndarray::Array2::from_shape_vec((1, seq_len), input_ids.to_vec())
131 .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
132 let attention_mask_array =
133 ndarray::Array2::from_shape_vec((1, seq_len), attention_mask.to_vec())
134 .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
135 let token_type_ids_array = ndarray::Array2::<i64>::zeros((1, seq_len));
136
137 let input_ids_tensor = TensorRef::from_array_view(&input_ids_array)
139 .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
140 let attention_mask_tensor = TensorRef::from_array_view(&attention_mask_array)
141 .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
142 let token_type_ids_tensor = TensorRef::from_array_view(&token_type_ids_array)
143 .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
144
145 let mut session = self.session.lock().map_err(|e| {
146 EmbedderError::OnnxError(format!("Session lock poisoned: {}", e))
147 })?;
148
149 let outputs = session
150 .run(ort::inputs![
151 "input_ids" => input_ids_tensor,
152 "attention_mask" => attention_mask_tensor,
153 "token_type_ids" => token_type_ids_tensor
154 ])
155 .map_err(|e| EmbedderError::OnnxError(format!("Inference error: {}", e)))?;
156
157 let output = if outputs.contains_key("last_hidden_state") {
159 &outputs["last_hidden_state"]
160 } else if outputs.contains_key("token_embeddings") {
161 &outputs["token_embeddings"]
162 } else {
163 &outputs[0]
164 };
165
166 let tensor = output
167 .try_extract_array::<f32>()
168 .map_err(|e| EmbedderError::OnnxError(format!("Extract error: {}", e)))?;
169
170 let shape = tensor.shape();
173 if shape.len() != 3 {
174 return Err(EmbedderError::OnnxError(format!(
175 "Unexpected output shape: {:?}",
176 shape
177 )));
178 }
179
180 let hidden_size = shape[2];
181 let mut pooled = vec![0.0f32; hidden_size];
182 let active_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();
183
184 if active_tokens > 0.0 {
185 for seq_idx in 0..shape[1] {
186 let mask = attention_mask.get(seq_idx).copied().unwrap_or(0) as f32;
187 if mask > 0.0 {
188 for dim in 0..hidden_size {
189 pooled[dim] += tensor[[0, seq_idx, dim]];
190 }
191 }
192 }
193 for val in &mut pooled {
194 *val /= active_tokens;
195 }
196 }
197
198 let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
200 if norm > 0.0 {
201 for x in &mut pooled {
202 *x /= norm;
203 }
204 }
205
206 Ok(pooled)
207 }
208}
209
210impl Embedder for OnnxEmbedder {
211 fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
212 let (input_ids, attention_mask) = self.tokenize(text);
213 self.run_inference(&input_ids, &attention_mask)
214 }
215
216 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedderError> {
217 texts.iter().map(|text| self.embed(text)).collect()
220 }
221
222 fn dimension(&self) -> usize {
223 EMBEDDING_DIM
224 }
225}
226
/// djb2 string hash (Bernstein): seed 5381, then `h = h * 33 + byte` per
/// input byte, with wrapping arithmetic. Used by the placeholder tokenizer
/// to map words to pseudo token ids.
fn simple_hash(s: &str) -> u64 {
    s.bytes()
        .fold(5381u64, |h, b| h.wrapping_mul(33).wrapping_add(u64::from(b)))
}
235
236fn download_file(url: &str, dest: &Path) -> Result<(), EmbedderError> {
238 let response = reqwest::blocking::get(url)
239 .map_err(|e| EmbedderError::DownloadFailed(format!("HTTP request failed: {}", e)))?;
240
241 if !response.status().is_success() {
242 return Err(EmbedderError::DownloadFailed(format!(
243 "HTTP {} for {}",
244 response.status(),
245 url
246 )));
247 }
248
249 let bytes = response
250 .bytes()
251 .map_err(|e| EmbedderError::DownloadFailed(format!("Failed to read response: {}", e)))?;
252
253 if bytes.is_empty() {
255 return Err(EmbedderError::DownloadFailed(
256 "Downloaded file is empty".to_string(),
257 ));
258 }
259
260 let tmp_path = dest.with_extension("tmp");
262 std::fs::write(&tmp_path, &bytes).map_err(EmbedderError::Io)?;
263 std::fs::rename(&tmp_path, dest).map_err(EmbedderError::Io)?;
264
265 tracing::info!(
266 "Downloaded {} bytes to {}",
267 bytes.len(),
268 dest.display()
269 );
270
271 Ok(())
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// The hash must be deterministic and distinguish distinct inputs.
    #[test]
    fn test_simple_hash() {
        assert_ne!(simple_hash("hello"), simple_hash("world"));
        assert_eq!(simple_hash("test"), simple_hash("test"));
    }

    /// Placeholder token ids must land in the reserved range [1000, 31000).
    #[test]
    fn test_tokenize_basic() {
        let id = simple_hash("authentication") % 30000 + 1000;
        assert!((1000..31000).contains(&id));
    }
}