// rustkernel_ml/nlp.rs

//! Natural Language Processing and LLM integration kernels.
//!
//! This module provides GPU-accelerated NLP algorithms:
//! - EmbeddingGeneration - Text to vector embeddings
//! - SemanticSimilarity - Document/entity similarity matching
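//!
//! # Example
//!
//! A sketch of the intended embed-then-match pipeline (import path assumed;
//! marked `ignore` so the doc-test does not depend on crate layout):
//!
//! ```ignore
//! use rustkernel_ml::nlp::*;
//!
//! let emb_cfg = EmbeddingConfig::default();
//! let docs = EmbeddingGeneration::compute(&["rust gpu kernels", "gpu compute"], &emb_cfg);
//!
//! // Self-similarity over one set: enable include_self to keep the diagonal.
//! let sim_cfg = SimilarityConfig { include_self: true, ..Default::default() };
//! let result = SemanticSimilarity::compute(&docs.embeddings, &docs.embeddings, &sim_cfg);
//! ```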

use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

// ============================================================================
// Embedding Generation Kernel
// ============================================================================

/// Configuration for embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Embedding dimension.
    pub dimension: usize,
    /// Maximum sequence length.
    pub max_seq_length: usize,
    /// Whether to normalize embeddings.
    pub normalize: bool,
    /// Pooling strategy for sequence embeddings.
    pub pooling: PoolingStrategy,
    /// Vocabulary size for hash-based embeddings.
    pub vocab_size: usize,
}

impl Default for EmbeddingConfig {
    fn default() -> Self {
        Self {
            dimension: 384,
            max_seq_length: 512,
            normalize: true,
            pooling: PoolingStrategy::Mean,
            vocab_size: 50_000,
        }
    }
}

/// Pooling strategy for combining token embeddings.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Average of all token embeddings.
    Mean,
    /// Max pooling across tokens.
    Max,
    /// Use CLS token embedding (first token).
    CLS,
    /// Position-weighted average (earlier tokens get higher weight);
    /// a lightweight stand-in for true attention weighting.
    AttentionWeighted,
}

/// Result of embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResult {
    /// Generated embeddings (one per input text).
    pub embeddings: Vec<Vec<f64>>,
    /// Token counts per input.
    pub token_counts: Vec<usize>,
    /// Embedding dimension.
    pub dimension: usize,
}

/// Embedding Generation kernel.
///
/// Generates dense vector embeddings from text using hash-based
/// token embeddings with configurable pooling strategies.
/// Suitable for semantic search, clustering, and similarity tasks.
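///
/// # Example
///
/// A minimal usage sketch (marked `ignore` because the import path depends
/// on the crate layout):
///
/// ```ignore
/// use rustkernel_ml::nlp::{EmbeddingConfig, EmbeddingGeneration};
///
/// let config = EmbeddingConfig::default();
/// let result = EmbeddingGeneration::compute(&["hello world"], &config);
/// assert_eq!(result.embeddings.len(), 1);
/// assert_eq!(result.embeddings[0].len(), config.dimension);
/// ```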
#[derive(Debug, Clone)]
pub struct EmbeddingGeneration {
    metadata: KernelMetadata,
}

impl Default for EmbeddingGeneration {
    fn default() -> Self {
        Self::new()
    }
}

impl EmbeddingGeneration {
    /// Create a new Embedding Generation kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/embedding-generation", Domain::StatisticalML)
                .with_description("GPU-accelerated text embedding generation")
                .with_throughput(10_000)
                .with_latency_us(50.0),
        }
    }

    /// Generate embeddings for a batch of texts.
    pub fn compute(texts: &[&str], config: &EmbeddingConfig) -> EmbeddingResult {
        if texts.is_empty() {
            return EmbeddingResult {
                embeddings: Vec::new(),
                token_counts: Vec::new(),
                dimension: config.dimension,
            };
        }

        let mut embeddings = Vec::with_capacity(texts.len());
        let mut token_counts = Vec::with_capacity(texts.len());

        for text in texts {
            let tokens = Self::tokenize(text, config.max_seq_length);
            token_counts.push(tokens.len());

            let token_embeddings: Vec<Vec<f64>> = tokens
                .iter()
                .map(|token| Self::hash_embedding(token, config.dimension, config.vocab_size))
                .collect();

            let pooled = Self::pool_embeddings(&token_embeddings, config);

            let final_embedding = if config.normalize {
                Self::normalize_vector(&pooled)
            } else {
                pooled
            };

            embeddings.push(final_embedding);
        }

        EmbeddingResult {
            embeddings,
            token_counts,
            dimension: config.dimension,
        }
    }

    /// Whitespace tokenization with lowercasing; strips non-alphanumeric
    /// characters and truncates to `max_length` tokens.
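    ///
    /// E.g. `"Hello, World!"` tokenizes to `["hello", "world"]`.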
    fn tokenize(text: &str, max_length: usize) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .take(max_length)
            .map(|s| s.chars().filter(|c| c.is_alphanumeric()).collect())
            .filter(|s: &String| !s.is_empty())
            .collect()
    }

    /// Generate a deterministic pseudo-random embedding for a token via the
    /// hashing trick: three seeded hashes are combined with alternating signs,
    /// so the same token always maps to the same vector with no learned weights.
    #[allow(clippy::needless_range_loop)]
    fn hash_embedding(token: &str, dimension: usize, vocab_size: usize) -> Vec<f64> {
        let mut embedding = vec![0.0; dimension];

        // Use multiple hash functions for better distribution
        let hash1 = Self::hash_token(token, 0) as usize;
        let hash2 = Self::hash_token(token, 1) as usize;
        let hash3 = Self::hash_token(token, 2) as usize;

        // Sparse embedding based on hashes
        for i in 0..dimension {
            let idx1 = (hash1 + i * 31) % vocab_size;
            let idx2 = (hash2 + i * 37) % vocab_size;
            let idx3 = (hash3 + i * 41) % vocab_size;

            // Combine hashes to create embedding value
            let sign1 = if (idx1 % 2) == 0 { 1.0 } else { -1.0 };
            let sign2 = if (idx2 % 2) == 0 { 1.0 } else { -1.0 };

            embedding[i] = sign1 * ((idx1 as f64 / vocab_size as f64) - 0.5)
                + sign2 * ((idx2 as f64 / vocab_size as f64) - 0.5) * 0.5
                + ((idx3 as f64 / vocab_size as f64) - 0.5) * 0.25;
        }

        embedding
    }

    /// Simple seeded multiplicative hash (31-multiplier, Java-style).
    /// Non-cryptographic; it only needs to be fast and deterministic.
    fn hash_token(token: &str, seed: u64) -> u64 {
        let mut hash: u64 = seed.wrapping_mul(0x517cc1b727220a95);
        for byte in token.bytes() {
            hash = hash.wrapping_mul(31).wrapping_add(byte as u64);
        }
        hash
    }

    /// Pool token embeddings according to strategy.
    fn pool_embeddings(embeddings: &[Vec<f64>], config: &EmbeddingConfig) -> Vec<f64> {
        if embeddings.is_empty() {
            return vec![0.0; config.dimension];
        }

        match config.pooling {
            PoolingStrategy::Mean => {
                let mut result = vec![0.0; config.dimension];
                for emb in embeddings {
                    for (i, &v) in emb.iter().enumerate() {
                        result[i] += v;
                    }
                }
                let n = embeddings.len() as f64;
                result.iter_mut().for_each(|v| *v /= n);
                result
            }
            PoolingStrategy::Max => {
                let mut result = vec![f64::NEG_INFINITY; config.dimension];
                for emb in embeddings {
                    for (i, &v) in emb.iter().enumerate() {
                        result[i] = result[i].max(v);
                    }
                }
                result
            }
            PoolingStrategy::CLS => embeddings[0].clone(),
            PoolingStrategy::AttentionWeighted => {
                // Simple attention: weight by position (earlier = higher weight)
                let mut result = vec![0.0; config.dimension];
                let mut total_weight = 0.0;

                for (pos, emb) in embeddings.iter().enumerate() {
                    let weight = 1.0 / (1.0 + pos as f64 * 0.1);
                    total_weight += weight;
                    for (i, &v) in emb.iter().enumerate() {
                        result[i] += v * weight;
                    }
                }

                result.iter_mut().for_each(|v| *v /= total_weight);
                result
            }
        }
    }

    /// Normalize vector to unit length.
    fn normalize_vector(v: &[f64]) -> Vec<f64> {
        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm < 1e-10 {
            v.to_vec()
        } else {
            v.iter().map(|x| x / norm).collect()
        }
    }
}

impl GpuKernel for EmbeddingGeneration {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

// ============================================================================
// Semantic Similarity Kernel
// ============================================================================

/// Configuration for semantic similarity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityConfig {
    /// Similarity metric to use.
    pub metric: SimilarityMetric,
    /// Minimum similarity threshold for matches.
    pub threshold: f64,
    /// Maximum number of matches to return per query.
    pub top_k: usize,
    /// Whether to include self-matches. Comparison is index-based, so this
    /// flag only matters when `queries` and `corpus` are the same collection;
    /// set it to `true` when they are independent sets.
    pub include_self: bool,
}

impl Default for SimilarityConfig {
    fn default() -> Self {
        Self {
            metric: SimilarityMetric::Cosine,
            threshold: 0.5,
            top_k: 10,
            include_self: false,
        }
    }
}

/// Similarity metric.
///
/// Scores are not on a common scale: cosine lies in `[-1, 1]`, Euclidean and
/// Manhattan distances are mapped into `(0, 1]` via `1 / (1 + distance)`, and
/// dot product is unbounded, so `threshold` values are metric-specific.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum SimilarityMetric {
    /// Cosine similarity (dot product of normalized vectors).
    Cosine,
    /// Euclidean distance (converted to similarity).
    Euclidean,
    /// Dot product (unnormalized).
    DotProduct,
    /// Manhattan distance (converted to similarity).
    Manhattan,
}

/// A similarity match result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityMatch {
    /// Index of the query item.
    pub query_idx: usize,
    /// Index of the matched item.
    pub match_idx: usize,
    /// Similarity score.
    pub score: f64,
}

/// Result of semantic similarity computation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityResult {
    /// All matches above threshold.
    pub matches: Vec<SimilarityMatch>,
    /// Full similarity matrix (if computed).
    pub similarity_matrix: Option<Vec<Vec<f64>>>,
    /// Number of query embeddings.
    pub query_count: usize,
    /// Number of corpus embeddings.
    pub corpus_count: usize,
}

/// Semantic Similarity kernel.
///
/// Computes semantic similarity between text embeddings for
/// document matching, entity resolution, and semantic search.
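///
/// # Example
///
/// A minimal sketch (marked `ignore` because the import path depends on the
/// crate layout):
///
/// ```ignore
/// use rustkernel_ml::nlp::{SemanticSimilarity, SimilarityConfig};
///
/// let queries = vec![vec![1.0, 0.0]];
/// let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
/// let config = SimilarityConfig { include_self: true, ..Default::default() };
/// let result = SemanticSimilarity::compute(&queries, &corpus, &config);
/// assert_eq!(result.matches[0].match_idx, 0); // the identical vector wins
/// ```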
#[derive(Debug, Clone)]
pub struct SemanticSimilarity {
    metadata: KernelMetadata,
}

impl Default for SemanticSimilarity {
    fn default() -> Self {
        Self::new()
    }
}

impl SemanticSimilarity {
    /// Create a new Semantic Similarity kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/semantic-similarity", Domain::StatisticalML)
                .with_description("Semantic similarity matching for documents and entities")
                .with_throughput(50_000)
                .with_latency_us(20.0),
        }
    }

    /// Compute similarity between query embeddings and corpus embeddings.
    ///
    /// Returns the top-`top_k` matches per query that clear `threshold`,
    /// along with the full query-by-corpus similarity matrix.
    pub fn compute(
        queries: &[Vec<f64>],
        corpus: &[Vec<f64>],
        config: &SimilarityConfig,
    ) -> SimilarityResult {
        if queries.is_empty() || corpus.is_empty() {
            return SimilarityResult {
                matches: Vec::new(),
                similarity_matrix: None,
                query_count: queries.len(),
                corpus_count: corpus.len(),
            };
        }

        let mut all_matches: Vec<SimilarityMatch> = Vec::new();
        let mut similarity_matrix: Vec<Vec<f64>> = Vec::with_capacity(queries.len());

        for (q_idx, query) in queries.iter().enumerate() {
            let mut row_scores: Vec<(usize, f64)> = Vec::with_capacity(corpus.len());

            for (c_idx, doc) in corpus.iter().enumerate() {
                if !config.include_self && q_idx == c_idx {
                    continue;
                }

                let score = Self::compute_similarity(query, doc, config.metric);
                row_scores.push((c_idx, score));
            }

            // Sort by score descending
            row_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            // Take top-k above threshold
            for (c_idx, score) in row_scores.iter().take(config.top_k) {
                if *score >= config.threshold {
                    all_matches.push(SimilarityMatch {
                        query_idx: q_idx,
                        match_idx: *c_idx,
                        score: *score,
                    });
                }
            }

            // Build the full row for the matrix (entries skipped via
            // `include_self` remain 0.0)
            let mut full_row = vec![0.0; corpus.len()];
            for (c_idx, score) in row_scores {
                full_row[c_idx] = score;
            }
            similarity_matrix.push(full_row);
        }

        SimilarityResult {
            matches: all_matches,
            similarity_matrix: Some(similarity_matrix),
            query_count: queries.len(),
            corpus_count: corpus.len(),
        }
    }

    /// Find the most similar corpus entries for each query.
    ///
    /// Returns one score-sorted list per query of
    /// `(corpus index, score, optional label)` tuples.
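    ///
    /// A usage sketch (`queries`, `corpus`, `labels`, and `config` assumed
    /// bound as in the tests below):
    ///
    /// ```ignore
    /// let hits = SemanticSimilarity::find_similar(&queries, &corpus, Some(&labels), &config);
    /// let (idx, score, label) = &hits[0][0]; // best match for the first query
    /// ```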
    pub fn find_similar(
        queries: &[Vec<f64>],
        corpus: &[Vec<f64>],
        labels: Option<&[String]>,
        config: &SimilarityConfig,
    ) -> Vec<Vec<(usize, f64, Option<String>)>> {
        let result = Self::compute(queries, corpus, config);

        let mut grouped: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
        for m in result.matches {
            grouped
                .entry(m.query_idx)
                .or_default()
                .push((m.match_idx, m.score));
        }

        // Sort each group by score descending
        for matches in grouped.values_mut() {
            matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        }

        queries
            .iter()
            .enumerate()
            .map(|(q_idx, _)| {
                grouped
                    .get(&q_idx)
                    .map(|matches| {
                        matches
                            .iter()
                            .map(|(idx, score)| {
                                let label = labels.and_then(|l| l.get(*idx).cloned());
                                (*idx, *score, label)
                            })
                            .collect()
                    })
                    .unwrap_or_default()
            })
            .collect()
    }

    /// Compute pairwise similarity between two vectors.
    /// Returns `0.0` for empty or length-mismatched inputs.
    fn compute_similarity(a: &[f64], b: &[f64], metric: SimilarityMetric) -> f64 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        match metric {
            SimilarityMetric::Cosine => {
                let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
                let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
                let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
                if norm_a < 1e-10 || norm_b < 1e-10 {
                    0.0
                } else {
                    dot / (norm_a * norm_b)
                }
            }
            SimilarityMetric::Euclidean => {
                let dist: f64 = a
                    .iter()
                    .zip(b.iter())
                    .map(|(x, y)| (x - y).powi(2))
                    .sum::<f64>()
                    .sqrt();
                1.0 / (1.0 + dist)
            }
            SimilarityMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
            SimilarityMetric::Manhattan => {
                let dist: f64 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
                1.0 / (1.0 + dist)
            }
        }
    }

    /// Greedily deduplicate a corpus by cosine similarity: an embedding is
    /// dropped if it is at least `threshold`-similar to any already-kept
    /// embedding. Returns the indices of the kept embeddings, in order.
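    ///
    /// A small sketch (marked `ignore` since the doc-test path depends on the
    /// crate layout):
    ///
    /// ```ignore
    /// // Two near-identical vectors collapse to one representative.
    /// let embs = vec![vec![1.0, 0.0], vec![0.99, 0.01], vec![0.0, 1.0]];
    /// let kept = SemanticSimilarity::deduplicate(&embs, 0.95);
    /// assert_eq!(kept, vec![0, 2]);
    /// ```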
    pub fn deduplicate(embeddings: &[Vec<f64>], threshold: f64) -> Vec<usize> {
        if embeddings.is_empty() {
            return Vec::new();
        }

        let mut keep: Vec<usize> = vec![0]; // Always keep first

        for i in 1..embeddings.len() {
            let is_duplicate = keep.iter().any(|&j| {
                let sim = Self::compute_similarity(
                    &embeddings[i],
                    &embeddings[j],
                    SimilarityMetric::Cosine,
                );
                sim >= threshold
            });

            if !is_duplicate {
                keep.push(i);
            }
        }

        keep
    }
}

impl GpuKernel for SemanticSimilarity {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_generation_metadata() {
        let kernel = EmbeddingGeneration::new();
        assert_eq!(kernel.metadata().id, "ml/embedding-generation");
    }

    #[test]
    fn test_embedding_generation_basic() {
        let config = EmbeddingConfig::default();
        let texts = vec!["hello world", "machine learning"];

        let result = EmbeddingGeneration::compute(&texts, &config);

        assert_eq!(result.embeddings.len(), 2);
        assert_eq!(result.embeddings[0].len(), config.dimension);
        assert_eq!(result.token_counts, vec![2, 2]);
    }

    #[test]
    fn test_embedding_normalization() {
        let config = EmbeddingConfig {
            normalize: true,
            ..Default::default()
        };

        let result = EmbeddingGeneration::compute(&["test text"], &config);

        let norm: f64 = result.embeddings[0]
            .iter()
            .map(|x| x * x)
            .sum::<f64>()
            .sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_embedding_empty() {
        let config = EmbeddingConfig::default();
        let result = EmbeddingGeneration::compute(&[], &config);
        assert!(result.embeddings.is_empty());
    }

    #[test]
    fn test_pooling_strategies() {
        let texts = vec!["a b c d e"];

        for pooling in [
            PoolingStrategy::Mean,
            PoolingStrategy::Max,
            PoolingStrategy::CLS,
            PoolingStrategy::AttentionWeighted,
        ] {
            let config = EmbeddingConfig {
                pooling,
                ..Default::default()
            };
            let result = EmbeddingGeneration::compute(&texts, &config);
            assert_eq!(result.embeddings.len(), 1);
            assert_eq!(result.embeddings[0].len(), config.dimension);
        }
    }

    #[test]
    fn test_semantic_similarity_metadata() {
        let kernel = SemanticSimilarity::new();
        assert_eq!(kernel.metadata().id, "ml/semantic-similarity");
    }

    #[test]
    fn test_semantic_similarity_basic() {
        let queries = vec![vec![1.0, 0.0, 0.0]];
        let corpus = vec![
            vec![1.0, 0.0, 0.0], // Same as query
            vec![0.0, 1.0, 0.0], // Orthogonal
            vec![0.7, 0.7, 0.0], // Partially similar
        ];

        let config = SimilarityConfig {
            threshold: 0.0,
            include_self: true,
            ..Default::default()
        };

        let result = SemanticSimilarity::compute(&queries, &corpus, &config);

        assert!(!result.matches.is_empty());
        // First match should be the identical vector
        assert_eq!(result.matches[0].match_idx, 0);
        assert!((result.matches[0].score - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_similarity_metrics() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];

        for metric in [
            SimilarityMetric::Cosine,
            SimilarityMetric::Euclidean,
            SimilarityMetric::DotProduct,
            SimilarityMetric::Manhattan,
        ] {
            let sim = SemanticSimilarity::compute_similarity(&a, &b, metric);
            assert!(
                sim > 0.0,
                "Identical vectors should have positive similarity for {:?}",
                metric
            );
        }
    }

    #[test]
    fn test_deduplicate() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.99, 0.01], // Very similar to first
            vec![0.0, 1.0],   // Different
            vec![0.01, 0.99], // Very similar to third
        ];

        let kept = SemanticSimilarity::deduplicate(&embeddings, 0.95);

        assert_eq!(kept.len(), 2);
        assert!(kept.contains(&0));
        assert!(kept.contains(&2));
    }

    #[test]
    fn test_find_similar_with_labels() {
        let queries = vec![vec![1.0, 0.0]];
        let corpus = vec![vec![0.9, 0.1], vec![0.0, 1.0]];
        let labels = vec!["doc_a".to_string(), "doc_b".to_string()];

        let config = SimilarityConfig {
            threshold: 0.0,
            include_self: true, // Include all comparisons since query != corpus
            ..Default::default()
        };

        let results = SemanticSimilarity::find_similar(&queries, &corpus, Some(&labels), &config);

        assert_eq!(results.len(), 1);
        assert!(!results[0].is_empty());
        // The highest similarity should come first (doc_a has higher cosine sim to query)
        assert_eq!(results[0][0].2, Some("doc_a".to_string()));
    }
}