Skip to main content

trueno_rag/embed/
nemotron.rs

1//! NVIDIA Embed Nemotron 8B embedder (GH-3: via realizar)
2
3use super::Embedder;
4use crate::{Chunk, Error, Result};
5
/// Settings for the NVIDIA Embed Nemotron 8B embedder.
///
/// The model is built on Llama 3.1 8B and yields 4096-dimensional vectors.
/// Queries and passages may each carry their own prefix, which enables
/// asymmetric retrieval.
///
/// Start from [`Default`] (or `NemotronConfig::new`) and refine with the
/// `with_*` builder methods.
#[cfg(feature = "nemotron")]
#[derive(Debug, Clone)]
pub struct NemotronConfig {
    /// Filesystem location of the GGUF model to load.
    pub model_path: std::path::PathBuf,
    /// Request GPU acceleration when one is available.
    ///
    /// NOTE(review): nothing in this module reads this flag yet.
    pub use_gpu: bool,
    /// Intended number of texts to embed per parallel batch.
    ///
    /// NOTE(review): batch embedding in this module is sequential and does not
    /// read this value yet.
    pub batch_size: usize,
    /// Instruction text prepended to queries (asymmetric retrieval).
    pub query_prefix: String,
    /// Text prepended to passages/documents; normally empty for Nemotron.
    pub passage_prefix: String,
    /// Token budget per input. Longer inputs are truncated, keeping the
    /// leading tokens (the prefix comes first).
    pub max_length: usize,
    /// Scale each output embedding to unit L2 norm.
    pub normalize: bool,
}
28
#[cfg(feature = "nemotron")]
impl Default for NemotronConfig {
    /// Baseline settings.
    ///
    /// - no model path
    /// - GPU requested
    /// - batch size 8
    /// - the Nemotron retrieval instruction as the query prefix
    /// - no passage prefix
    /// - an 8192-token limit
    /// - L2 normalization enabled
    fn default() -> Self {
        // Only the query side carries an instruction; passages are embedded as-is.
        let query_prefix =
            String::from("Instruct: Given a query, retrieve relevant documents\nQuery: ");

        Self {
            model_path: std::path::PathBuf::default(),
            use_gpu: true,
            batch_size: 8,
            query_prefix,
            passage_prefix: String::default(),
            max_length: 8192,
            normalize: true,
        }
    }
}
45
#[cfg(feature = "nemotron")]
impl NemotronConfig {
    /// Default configuration whose `model_path` points at `model_path`.
    #[must_use]
    pub fn new(model_path: impl AsRef<std::path::Path>) -> Self {
        Self::default().with_model_path(model_path)
    }

    /// Replace the GGUF model location.
    #[must_use]
    pub fn with_model_path(self, path: impl AsRef<std::path::Path>) -> Self {
        Self { model_path: path.as_ref().to_path_buf(), ..self }
    }

    /// Toggle the GPU-acceleration request.
    #[must_use]
    pub fn with_gpu(self, use_gpu: bool) -> Self {
        Self { use_gpu, ..self }
    }

    /// Replace the batch size used for parallel embedding.
    #[must_use]
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Replace the instruction text prepended to queries.
    #[must_use]
    pub fn with_query_prefix(self, prefix: impl Into<String>) -> Self {
        Self { query_prefix: prefix.into(), ..self }
    }

    /// Replace the text prepended to passages/documents.
    #[must_use]
    pub fn with_passage_prefix(self, prefix: impl Into<String>) -> Self {
        Self { passage_prefix: prefix.into(), ..self }
    }

    /// Replace the per-input token limit.
    #[must_use]
    pub fn with_max_length(self, max_length: usize) -> Self {
        Self { max_length, ..self }
    }

    /// Toggle L2 normalization of the output vectors.
    #[must_use]
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }
}
103
/// NVIDIA Embed Nemotron 8B embedder backed by realizar's GGUF loader.
///
/// The vector width equals the model's `hidden_dim` as read from the GGUF file
/// (4096 for the Llama 3.1 8B-based Nemotron). Queries and passages get separate
/// prefixes from [`NemotronConfig`] so that retrieval can be asymmetric.
///
/// Only available with the `nemotron` feature.
///
/// # Limitations
///
/// Transformer layers are currently passed through without running attention or
/// the feed-forward network. Each embedding is therefore derived solely from the
/// token-embedding row of the last token (after truncation).
///
/// # Example
///
/// ```rust,ignore
/// use trueno_rag::embed::{NemotronEmbedder, NemotronConfig, Embedder};
///
/// let config = NemotronConfig::new("models/NV-Embed-v2-Q4_K.gguf")
///     .with_gpu(true);
/// let embedder = NemotronEmbedder::new(config)?;
///
/// let query_emb = embedder.embed_query("What is machine learning?")?;
/// let doc_emb = embedder.embed_document("Machine learning is a branch of AI...")?;
/// ```
#[cfg(feature = "nemotron")]
pub struct NemotronEmbedder {
    /// Transformer weights built from the GGUF file.
    transformer: realizar::gguf::GGUFTransformer,
    /// Parsed GGUF model; kept around for its tokenizer (`encode`).
    model: realizar::gguf::GGUFModel,
    /// Settings supplied at construction.
    config: NemotronConfig,
    /// Output vector width, copied from the model's `hidden_dim`.
    dimension: usize,
}
134
#[cfg(feature = "nemotron")]
impl std::fmt::Debug for NemotronEmbedder {
    /// Print the dimension and configuration only; the (huge) model weights
    /// are elided and shown as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { dimension, config, .. } = self;
        let mut builder = f.debug_struct("NemotronEmbedder");
        builder.field("dimension", dimension);
        builder.field("config", config);
        builder.finish_non_exhaustive()
    }
}
144
#[cfg(feature = "nemotron")]
impl NemotronEmbedder {
    /// Create a new Nemotron embedder from configuration.
    ///
    /// The whole GGUF file is read into memory, parsed, and used to build the
    /// transformer. The raw byte buffer is not kept afterwards. The embedding
    /// dimension is taken from the model's `hidden_dim`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidConfig`] if:
    /// - The model file doesn't exist or can't be read
    /// - The model is not a valid GGUF file
    /// - The transformer cannot be built from the parsed model
    pub fn new(config: NemotronConfig) -> Result<Self> {
        if !config.model_path.exists() {
            return Err(Error::InvalidConfig(format!(
                "Model file not found: {}",
                config.model_path.display()
            )));
        }

        let file_data = std::fs::read(&config.model_path).map_err(|e| {
            Error::InvalidConfig(format!(
                "Failed to read model file {}: {e}",
                config.model_path.display()
            ))
        })?;

        let model = realizar::gguf::GGUFModel::from_bytes(&file_data)
            .map_err(|e| Error::InvalidConfig(format!("Failed to parse GGUF model: {e}")))?;

        let transformer = realizar::gguf::GGUFTransformer::from_gguf(&model, &file_data)
            .map_err(|e| Error::InvalidConfig(format!("Failed to create transformer: {e}")))?;

        // Expected to be 4096 for Nemotron 8B; read it from the file rather than
        // hard-coding it.
        let dimension = transformer.config.hidden_dim;

        Ok(Self { transformer, model, config, dimension })
    }

    /// Get the configuration.
    #[must_use]
    pub fn config(&self) -> &NemotronConfig {
        &self.config
    }

    /// Embed `prefix` + `text`.
    ///
    /// Steps:
    /// 1. Tokenize the concatenated text.
    /// 2. Keep at most `config.max_length` leading tokens.
    /// 3. Extract the embedding.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Embedding`] if:
    /// - tokenization fails,
    /// - the (possibly truncated) token sequence is empty, or
    /// - embedding extraction fails.
    fn embed_with_prefix(&self, text: &str, prefix: &str) -> Result<Vec<f32>> {
        let prefixed = if prefix.is_empty() { text.to_string() } else { format!("{prefix}{text}") };

        let mut tokens = self
            .model
            .encode(&prefixed)
            .ok_or_else(|| Error::Embedding("Failed to tokenize text".to_string()))?;

        // Truncation keeps the leading tokens, so the instruction prefix survives;
        // the tail of over-long inputs is dropped.
        tokens.truncate(self.config.max_length);

        if tokens.is_empty() {
            return Err(Error::Embedding("Empty token sequence".to_string()));
        }

        // Hidden states are computed directly: `forward()` yields vocab-sized
        // logits, whereas an embedding needs hidden_dim-sized hidden states.
        self.extract_embedding_from_model(&tokens)
    }

    /// Compute one embedding vector for `tokens`.
    ///
    /// Pipeline:
    /// 1. Token-embedding lookup.
    /// 2. Per-layer processing.
    /// 3. Hidden state of the **last** position.
    /// 4. RMSNorm with `output_norm_weight`.
    /// 5. Optional L2 normalization.
    ///
    /// NOTE(review): `process_layer` is currently an identity pass-through, so
    /// the result depends only on the embedding row of the final token. Texts
    /// ending in the same token get identical vectors. A real forward pass is
    /// needed before these embeddings carry contextual meaning.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Embedding`] if:
    /// - `tokens` is empty,
    /// - a token id has no row in the token-embedding table, or
    /// - layer processing fails.
    fn extract_embedding_from_model(&self, tokens: &[u32]) -> Result<Vec<f32>> {
        let hidden_dim = self.dimension;
        let seq_len = tokens.len();
        if seq_len == 0 {
            return Err(Error::Embedding("Empty token sequence".to_string()));
        }

        // Token embedding lookup: one `hidden_dim` row per token, flattened.
        // Checked so that an id outside the table (tokenizer/weights mismatch,
        // malformed model) is reported as an error instead of panicking.
        let table = &self.transformer.token_embedding;
        let mut hidden: Vec<f32> = Vec::with_capacity(seq_len * hidden_dim);
        for &token_id in tokens {
            let row = (token_id as usize)
                .checked_mul(hidden_dim)
                .and_then(|start| start.checked_add(hidden_dim).map(|end| start..end))
                .and_then(|range| table.get(range))
                .ok_or_else(|| {
                    Error::Embedding(format!(
                        "Token id {token_id} has no row in the token embedding table"
                    ))
                })?;
            hidden.extend_from_slice(row);
        }

        for layer in &self.transformer.layers {
            hidden = self.process_layer(layer, hidden, seq_len)?;
        }

        // Pool by taking the final position's hidden state. `hidden` holds exactly
        // seq_len * hidden_dim values (built above; re-checked by process_layer),
        // so splitting here yields exactly one row without an extra copy.
        let last_token_start = (seq_len - 1) * hidden_dim;
        let mut embedding = hidden.split_off(last_token_start);

        // Output RMSNorm (Llama-style) on the pooled vector.
        Self::rms_normalize(&mut embedding, &self.transformer.output_norm_weight);

        if self.config.normalize {
            Self::l2_normalize(&mut embedding);
        }

        Ok(embedding)
    }

    /// Run one transformer layer over the flattened `seq_len * hidden_dim`
    /// hidden states.
    ///
    /// Currently a validated identity pass-through. The following are not
    /// implemented, so `hidden` is returned unchanged once its length has been
    /// checked:
    /// - the Q/K/V projections and attention,
    /// - the FFN,
    /// - their residual additions.
    ///
    /// The buffer is taken by value so the pass-through costs no copy per layer.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Embedding`] if `hidden.len() != seq_len * hidden_dim`.
    fn process_layer(
        &self,
        _layer: &realizar::gguf::GGUFTransformerLayer,
        hidden: Vec<f32>,
        seq_len: usize,
    ) -> Result<Vec<f32>> {
        let hidden_dim = self.dimension;
        if seq_len.checked_mul(hidden_dim) != Some(hidden.len()) {
            return Err(Error::Embedding(format!(
                "Layer processing shape mismatch: got {} values, expected seq_len={seq_len} x dim={hidden_dim}",
                hidden.len()
            )));
        }

        // TODO: a full implementation would, per layer:
        // 1. RMSNorm with `attn_norm_weight`, then compute Q, K, V projections.
        // 2. Apply attention and add the residual.
        // 3. Normalize, apply the FFN, and add the residual.
        Ok(hidden)
    }

    /// Apply RMS normalization in place.
    ///
    /// Each element becomes `v[i] * weight[i] / sqrt(mean(v^2) + eps)`.
    ///
    /// Elements of `vector` beyond `weight.len()` are left unchanged, because the
    /// update zips the two slices. An empty `vector` is a no-op.
    ///
    /// NOTE(review): `eps` is hard-coded to 1e-6. The model's own RMSNorm
    /// epsilon may differ; confirm against the GGUF metadata.
    fn rms_normalize(vector: &mut [f32], weight: &[f32]) {
        let eps = 1e-6;
        let ss: f32 = vector.iter().map(|x| x * x).sum::<f32>() / vector.len().max(1) as f32;
        let scale = 1.0 / (ss + eps).sqrt();

        for (v, w) in vector.iter_mut().zip(weight.iter()) {
            *v = *v * scale * w;
        }
    }

    /// Scale `vector` in place to unit L2 norm.
    ///
    /// An all-zero vector is left unchanged to avoid dividing by zero.
    fn l2_normalize(vector: &mut [f32]) {
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in vector.iter_mut() {
                *x /= norm;
            }
        }
    }
}
320
#[cfg(feature = "nemotron")]
impl Embedder for NemotronEmbedder {
    // With no query/document hint, the input is treated as a passage.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.embed_document(text)
    }

    // Sequential: each text goes through `embed` in order, and the first failure
    // aborts the whole batch. `config.batch_size` is not consulted.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for &text in texts {
            embeddings.push(self.embed(text)?);
        }
        Ok(embeddings)
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    // NOTE(review): this id names NV-Embed-v2, while the module docs describe a
    // Llama 3.1 8B-based "Embed Nemotron 8B" model. Confirm which is intended.
    fn model_id(&self) -> &str {
        "nvidia/NV-Embed-v2"
    }

    // Queries get the instruction prefix; empty input is rejected up front.
    fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
        if !query.is_empty() {
            return self.embed_with_prefix(query, &self.config.query_prefix);
        }
        Err(Error::Query("empty query".to_string()))
    }

    // Documents get the passage prefix; empty input is rejected up front.
    fn embed_document(&self, document: &str) -> Result<Vec<f32>> {
        if !document.is_empty() {
            return self.embed_with_prefix(document, &self.config.passage_prefix);
        }
        Err(Error::EmptyDocument("empty document for embedding".to_string()))
    }

    // Embeds chunks in order. On error, chunks already processed keep their
    // new embeddings and the rest are left untouched.
    fn embed_chunks(&self, chunks: &mut [Chunk]) -> Result<()> {
        chunks.iter_mut().try_for_each(|chunk| -> Result<()> {
            let vector = self.embed_document(&chunk.content)?;
            chunk.set_embedding(vector);
            Ok(())
        })
    }
}
362
#[cfg(test)]
#[cfg(feature = "nemotron")]
mod tests {
    use super::*;

    #[test]
    fn test_nemotron_config_default() {
        let config = NemotronConfig::default();
        assert!(config.use_gpu);
        assert_eq!(config.batch_size, 8);
        assert_eq!(config.max_length, 8192);
        assert!(config.normalize);
        assert!(config.query_prefix.contains("Instruct"));
        assert!(config.passage_prefix.is_empty());
    }

    #[test]
    fn test_nemotron_config_new() {
        let config = NemotronConfig::new("/tmp/model.gguf");
        assert_eq!(config.model_path, std::path::PathBuf::from("/tmp/model.gguf"));
        assert!(config.use_gpu);
    }

    #[test]
    fn test_nemotron_config_builder() {
        let config = NemotronConfig::default()
            .with_model_path("/tmp/model.gguf")
            .with_gpu(false)
            .with_batch_size(16)
            .with_max_length(4096)
            .with_normalize(false)
            .with_query_prefix("Query: ")
            .with_passage_prefix("Passage: ");

        assert_eq!(config.model_path, std::path::PathBuf::from("/tmp/model.gguf"));
        assert!(!config.use_gpu);
        assert_eq!(config.batch_size, 16);
        assert_eq!(config.max_length, 4096);
        assert!(!config.normalize);
        assert_eq!(config.query_prefix, "Query: ");
        assert_eq!(config.passage_prefix, "Passage: ");
    }

    #[test]
    fn test_nemotron_embedder_missing_model() {
        let config = NemotronConfig::new("/nonexistent/model.gguf");
        let result = NemotronEmbedder::new(config);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn test_nemotron_embedder_invalid_gguf() {
        // Write deliberately invalid GGUF bytes to a temp file. The name includes
        // the process id so concurrent test runs sharing the temp directory
        // (e.g. parallel `cargo test` invocations) don't clobber each other.
        let temp_file = std::env::temp_dir()
            .join(format!("trueno_rag_invalid_model_{}.gguf", std::process::id()));
        std::fs::write(&temp_file, b"not a valid gguf file").unwrap();

        let config = NemotronConfig::new(&temp_file);
        let result = NemotronEmbedder::new(config);

        // Clean up before asserting so a failure doesn't leak the file.
        let _ = std::fs::remove_file(&temp_file);

        // Should fail with a parse error.
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("parse") || err.to_string().contains("GGUF"),
            "Expected parse error, got: {}",
            err
        );
    }

    #[test]
    fn test_nemotron_l2_normalize() {
        let mut vector = vec![3.0, 4.0];
        NemotronEmbedder::l2_normalize(&mut vector);
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
        assert!((vector[0] - 0.6).abs() < 1e-5);
        assert!((vector[1] - 0.8).abs() < 1e-5);
    }

    #[test]
    fn test_nemotron_l2_normalize_zero() {
        let mut vector = vec![0.0, 0.0, 0.0];
        NemotronEmbedder::l2_normalize(&mut vector);
        assert_eq!(vector, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_nemotron_rms_normalize() {
        let original = [1.0f32, 2.0, 3.0, 4.0];
        let mut vector = original.to_vec();
        let weight = vec![1.0f32; 4];
        NemotronEmbedder::rms_normalize(&mut vector, &weight);
        // mean(x^2) = (1 + 4 + 9 + 16) / 4 = 7.5, so with unit weights every
        // element is divided by sqrt(7.5 + 1e-6) ≈ 2.739.
        let scale = 1.0 / (7.5f32 + 1e-6).sqrt();
        for (got, x) in vector.iter().zip(original.iter()) {
            let expected = x * scale;
            assert!((got - expected).abs() < 1e-5, "got {got}, expected {expected}");
        }
    }

    #[test]
    fn test_nemotron_rms_normalize_empty() {
        // An empty vector must not panic or divide by zero.
        let mut vector: Vec<f32> = Vec::new();
        NemotronEmbedder::rms_normalize(&mut vector, &[]);
        assert!(vector.is_empty());
    }

    #[test]
    fn test_nemotron_config_debug() {
        let config = NemotronConfig::new("/tmp/test.gguf");
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("NemotronConfig"));
        assert!(debug_str.contains("model_path"));
    }

    #[test]
    fn test_nemotron_config_clone() {
        let config = NemotronConfig::new("/tmp/test.gguf").with_batch_size(32);
        let cloned = config.clone();
        assert_eq!(cloned.batch_size, 32);
        assert_eq!(cloned.model_path, config.model_path);
    }

    #[test]
    fn test_nemotron_rms_normalize_with_weights() {
        let mut vector = vec![2.0, 2.0];
        let weight = vec![0.5, 2.0];
        NemotronEmbedder::rms_normalize(&mut vector, &weight);
        // RMS for [2.0, 2.0] = sqrt((4+4)/2) = 2
        // Scale = 1/sqrt(4 + 1e-6) ≈ 0.5
        // Result[0] = 2.0 * 0.5 * 0.5 = 0.5
        // Result[1] = 2.0 * 0.5 * 2.0 = 2.0
        assert!((vector[0] - 0.5).abs() < 1e-5);
        assert!((vector[1] - 2.0).abs() < 1e-5);
    }
}