use super::Embedder;
use crate::{Chunk, Error, Result};

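/// Configuration for a GGUF-based Nemotron embedding model.
///
/// A minimal construction sketch (the model path is illustrative):
///
/// ```ignore
/// let config = NemotronConfig::new("/path/to/model.gguf")
///     .with_gpu(false)
///     .with_max_length(4096);
/// ```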
#[cfg(feature = "nemotron")]
#[derive(Debug, Clone)]
pub struct NemotronConfig {
    /// Path to the GGUF model file.
    pub model_path: std::path::PathBuf,
    /// Whether to run inference on the GPU.
    pub use_gpu: bool,
    /// Number of texts to embed per batch.
    pub batch_size: usize,
    /// Instruction prefix prepended to queries.
    pub query_prefix: String,
    /// Prefix prepended to documents/passages.
    pub passage_prefix: String,
    /// Maximum token count; longer inputs are truncated.
    pub max_length: usize,
    /// Whether to L2-normalize output embeddings.
    pub normalize: bool,
}

#[cfg(feature = "nemotron")]
impl Default for NemotronConfig {
    fn default() -> Self {
        Self {
            model_path: std::path::PathBuf::new(),
            use_gpu: true,
            batch_size: 8,
            query_prefix: "Instruct: Given a query, retrieve relevant documents\nQuery: "
                .to_string(),
            passage_prefix: String::new(),
            max_length: 8192,
            normalize: true,
        }
    }
}

#[cfg(feature = "nemotron")]
impl NemotronConfig {
    /// Creates a config for the model at `model_path`; all other fields use defaults.
    #[must_use]
    pub fn new(model_path: impl AsRef<std::path::Path>) -> Self {
        Self { model_path: model_path.as_ref().to_path_buf(), ..Default::default() }
    }

    /// Sets the model path.
    #[must_use]
    pub fn with_model_path(mut self, path: impl AsRef<std::path::Path>) -> Self {
        self.model_path = path.as_ref().to_path_buf();
        self
    }

    /// Enables or disables GPU inference.
    #[must_use]
    pub fn with_gpu(mut self, use_gpu: bool) -> Self {
        self.use_gpu = use_gpu;
        self
    }

    /// Sets the number of texts embedded per batch.
    #[must_use]
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Sets the prefix prepended to queries.
    #[must_use]
    pub fn with_query_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.query_prefix = prefix.into();
        self
    }

    /// Sets the prefix prepended to documents.
    #[must_use]
    pub fn with_passage_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.passage_prefix = prefix.into();
        self
    }

    /// Sets the maximum token length.
    #[must_use]
    pub fn with_max_length(mut self, max_length: usize) -> Self {
        self.max_length = max_length;
        self
    }

    /// Enables or disables L2 normalization of the output embedding.
    #[must_use]
    pub fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }
}

/// Embedder backed by a GGUF Nemotron model loaded via `realizar`.
#[cfg(feature = "nemotron")]
pub struct NemotronEmbedder {
    /// Transformer weights used for the forward pass.
    transformer: realizar::gguf::GGUFTransformer,
    /// Parsed GGUF model, used for tokenization.
    model: realizar::gguf::GGUFModel,
    /// Embedding configuration.
    config: NemotronConfig,
    /// Output embedding dimension (the model's hidden size).
    dimension: usize,
}

#[cfg(feature = "nemotron")]
impl std::fmt::Debug for NemotronEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NemotronEmbedder")
            .field("dimension", &self.dimension)
            .field("config", &self.config)
            .finish_non_exhaustive()
    }
}

#[cfg(feature = "nemotron")]
impl NemotronEmbedder {
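    /// Loads the GGUF model at `config.model_path` and builds the embedder.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidConfig` if the model file is missing or
    /// unreadable, fails to parse as GGUF, or the transformer cannot be built.
    ///
    /// A construction sketch (the path is illustrative):
    ///
    /// ```ignore
    /// let embedder = NemotronEmbedder::new(NemotronConfig::new("/path/to/model.gguf"))?;
    /// assert!(embedder.dimension() > 0);
    /// ```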
    pub fn new(config: NemotronConfig) -> Result<Self> {
        if !config.model_path.exists() {
            return Err(Error::InvalidConfig(format!(
                "Model file not found: {}",
                config.model_path.display()
            )));
        }

        let file_data = std::fs::read(&config.model_path).map_err(|e| {
            Error::InvalidConfig(format!(
                "Failed to read model file {}: {e}",
                config.model_path.display()
            ))
        })?;

        let model = realizar::gguf::GGUFModel::from_bytes(&file_data)
            .map_err(|e| Error::InvalidConfig(format!("Failed to parse GGUF model: {e}")))?;

        let transformer = realizar::gguf::GGUFTransformer::from_gguf(&model, &file_data)
            .map_err(|e| Error::InvalidConfig(format!("Failed to create transformer: {e}")))?;

        // The embedding dimension is the model's hidden size.
        let dimension = transformer.config.hidden_dim;

        Ok(Self { transformer, model, config, dimension })
    }

    /// Returns the embedder's configuration.
    #[must_use]
    pub fn config(&self) -> &NemotronConfig {
        &self.config
    }

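    /// Embeds `text` after prepending `prefix`: tokenizes the prefixed text,
    /// truncates it to `max_length` tokens, and runs the forward pass. Queries
    /// and documents differ only in the prefix they pass here.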
    fn embed_with_prefix(&self, text: &str, prefix: &str) -> Result<Vec<f32>> {
        let prefixed = if prefix.is_empty() { text.to_string() } else { format!("{prefix}{text}") };

        let tokens = self
            .model
            .encode(&prefixed)
            .ok_or_else(|| Error::Embedding("Failed to tokenize text".to_string()))?;

        // Truncate to the configured maximum sequence length.
        let tokens: Vec<u32> = if tokens.len() > self.config.max_length {
            tokens[..self.config.max_length].to_vec()
        } else {
            tokens
        };

        if tokens.is_empty() {
            return Err(Error::Embedding("Empty token sequence".to_string()));
        }

        let embedding = self.extract_embedding_from_model(&tokens)?;

        Ok(embedding)
    }

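    /// Pools a single embedding from a forward pass over `tokens`:
    /// token-embedding lookup, per-layer processing, last-token pooling,
    /// final RMS norm, and (if configured) L2 normalization.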
    fn extract_embedding_from_model(&self, tokens: &[u32]) -> Result<Vec<f32>> {
        let hidden_dim = self.dimension;

        // Look up the embedding row for each token id.
        let mut hidden: Vec<f32> = tokens
            .iter()
            .flat_map(|&token_id| {
                let start = (token_id as usize) * hidden_dim;
                let end = start + hidden_dim;
                self.transformer.token_embedding[start..end].to_vec()
            })
            .collect();

        for layer in &self.transformer.layers {
            hidden = self.process_layer(layer, &hidden, tokens.len())?;
        }

        // Last-token pooling: use the final position's hidden state.
        let seq_len = tokens.len();
        let last_token_start = (seq_len - 1) * hidden_dim;
        let mut embedding = hidden[last_token_start..last_token_start + hidden_dim].to_vec();

        Self::rms_normalize(&mut embedding, &self.transformer.output_norm_weight);

        if self.config.normalize {
            Self::l2_normalize(&mut embedding);
        }

        Ok(embedding)
    }

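    /// Processes one transformer layer over the flattened hidden states.
    ///
    /// Note: this is a simplified pass. Each position is bounds-checked and its
    /// pre-attention RMS norm is computed, but the attention and feed-forward
    /// outputs are not folded back in; `output` is returned unchanged.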
    fn process_layer(
        &self,
        layer: &realizar::gguf::GGUFTransformerLayer,
        hidden: &[f32],
        seq_len: usize,
    ) -> Result<Vec<f32>> {
        let hidden_dim = self.dimension;
        let output = hidden.to_vec();

        for pos in 0..seq_len {
            let start = pos * hidden_dim;
            let end = start + hidden_dim;

            if end > output.len() {
                return Err(Error::Embedding(format!(
                    "Layer processing out of bounds: pos={pos}, dim={hidden_dim}"
                )));
            }

            let mut normed = output[start..end].to_vec();
            Self::rms_normalize(&mut normed, &layer.attn_norm_weight);
        }

        Ok(output)
    }

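    /// RMS-normalizes `vector` in place against the learned `weight`:
    /// `v[i] = v[i] * w[i] / sqrt(mean(v^2) + eps)` with `eps = 1e-6`.
    ///
    /// Worked example: `[2.0, 2.0]` has `mean(v^2) = 4.0`, so each element is
    /// scaled by about `1/2` before its per-element weight is applied.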
    fn rms_normalize(vector: &mut [f32], weight: &[f32]) {
        let eps = 1e-6;
        // Mean of squares; `.max(1)` guards against an empty slice.
        let ss: f32 = vector.iter().map(|x| x * x).sum::<f32>() / vector.len().max(1) as f32;
        let scale = 1.0 / (ss + eps).sqrt();

        for (v, w) in vector.iter_mut().zip(weight.iter()) {
            *v = *v * scale * w;
        }
    }

    /// Scales `vector` in place to unit L2 norm; zero vectors are left as-is.
    fn l2_normalize(vector: &mut [f32]) {
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in vector.iter_mut() {
                *x /= norm;
            }
        }
    }
}

#[cfg(feature = "nemotron")]
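/// Queries and documents are embedded asymmetrically: `embed_query` applies
/// `query_prefix`, while `embed` and `embed_document` apply `passage_prefix`.
/// A usage sketch (the path is illustrative):
///
/// ```ignore
/// let embedder = NemotronEmbedder::new(NemotronConfig::new("/path/to/model.gguf"))?;
/// let q = embedder.embed_query("how is the build configured?")?;
/// let d = embedder.embed_document("The build is configured in Cargo.toml.")?;
/// assert_eq!(q.len(), embedder.dimension());
/// ```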
impl Embedder for NemotronEmbedder {
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.embed_document(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        // Texts are embedded one at a time; `config.batch_size` is not used here.
        texts.iter().map(|t| self.embed(t)).collect()
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn model_id(&self) -> &str {
        "nvidia/NV-Embed-v2"
    }

    fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
        if query.is_empty() {
            return Err(Error::Query("empty query".to_string()));
        }
        self.embed_with_prefix(query, &self.config.query_prefix)
    }

    fn embed_document(&self, document: &str) -> Result<Vec<f32>> {
        if document.is_empty() {
            return Err(Error::EmptyDocument("empty document for embedding".to_string()));
        }
        self.embed_with_prefix(document, &self.config.passage_prefix)
    }

    fn embed_chunks(&self, chunks: &mut [Chunk]) -> Result<()> {
        for chunk in chunks.iter_mut() {
            let embedding = self.embed_document(&chunk.content)?;
            chunk.set_embedding(embedding);
        }
        Ok(())
    }
}

#[cfg(test)]
#[cfg(feature = "nemotron")]
mod tests {
    use super::*;

    #[test]
    fn test_nemotron_config_default() {
        let config = NemotronConfig::default();
        assert!(config.use_gpu);
        assert_eq!(config.batch_size, 8);
        assert_eq!(config.max_length, 8192);
        assert!(config.normalize);
        assert!(config.query_prefix.contains("Instruct"));
        assert!(config.passage_prefix.is_empty());
    }

    #[test]
    fn test_nemotron_config_new() {
        let config = NemotronConfig::new("/tmp/model.gguf");
        assert_eq!(config.model_path, std::path::PathBuf::from("/tmp/model.gguf"));
        assert!(config.use_gpu);
    }

    #[test]
    fn test_nemotron_config_builder() {
        let config = NemotronConfig::default()
            .with_model_path("/tmp/model.gguf")
            .with_gpu(false)
            .with_batch_size(16)
            .with_max_length(4096)
            .with_normalize(false)
            .with_query_prefix("Query: ")
            .with_passage_prefix("Passage: ");

        assert_eq!(config.model_path, std::path::PathBuf::from("/tmp/model.gguf"));
        assert!(!config.use_gpu);
        assert_eq!(config.batch_size, 16);
        assert_eq!(config.max_length, 4096);
        assert!(!config.normalize);
        assert_eq!(config.query_prefix, "Query: ");
        assert_eq!(config.passage_prefix, "Passage: ");
    }

    #[test]
    fn test_nemotron_embedder_missing_model() {
        let config = NemotronConfig::new("/nonexistent/model.gguf");
        let result = NemotronEmbedder::new(config);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn test_nemotron_embedder_invalid_gguf() {
        let temp_dir = std::env::temp_dir();
        let temp_file = temp_dir.join("invalid_model.gguf");
        std::fs::write(&temp_file, b"not a valid gguf file").unwrap();

        let config = NemotronConfig::new(&temp_file);
        let result = NemotronEmbedder::new(config);

        let _ = std::fs::remove_file(&temp_file);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("parse") || err.to_string().contains("GGUF"),
            "Expected parse error, got: {err}"
        );
    }

    #[test]
    fn test_nemotron_l2_normalize() {
        let mut vector = vec![3.0, 4.0];
        NemotronEmbedder::l2_normalize(&mut vector);
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
        assert!((vector[0] - 0.6).abs() < 1e-5);
        assert!((vector[1] - 0.8).abs() < 1e-5);
    }

    #[test]
    fn test_nemotron_l2_normalize_zero() {
        let mut vector = vec![0.0, 0.0, 0.0];
        NemotronEmbedder::l2_normalize(&mut vector);
        assert_eq!(vector, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_nemotron_rms_normalize() {
        let mut vector = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0, 1.0, 1.0, 1.0];
        NemotronEmbedder::rms_normalize(&mut vector, &weight);
        // Mean of squares: (1 + 4 + 9 + 16) / 4 = 7.5.
        let ss = 30.0f32 / 4.0;
        let expected_scale = 1.0 / (ss + 1e-6).sqrt();
        assert!((vector[0] - expected_scale).abs() < 0.1);
    }

    #[test]
    fn test_nemotron_config_debug() {
        let config = NemotronConfig::new("/tmp/test.gguf");
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("NemotronConfig"));
        assert!(debug_str.contains("model_path"));
    }

    #[test]
    fn test_nemotron_config_clone() {
        let config = NemotronConfig::new("/tmp/test.gguf").with_batch_size(32);
        let cloned = config.clone();
        assert_eq!(cloned.batch_size, 32);
        assert_eq!(cloned.model_path, config.model_path);
    }

    #[test]
    fn test_nemotron_rms_normalize_with_weights() {
        let mut vector = vec![2.0, 2.0];
        let weight = vec![0.5, 2.0];
        NemotronEmbedder::rms_normalize(&mut vector, &weight);
        // mean(v^2) = 4.0, so the scale is ~0.5; weights then apply per element.
        assert!((vector[0] - 0.5).abs() < 0.01);
        assert!((vector[1] - 2.0).abs() < 0.01);
    }
}