1use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Configuration for hash-based text embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Length of each output embedding vector.
    pub dimension: usize,
    /// Maximum number of whitespace-split tokens taken from each text.
    pub max_seq_length: usize,
    /// Whether to L2-normalize the pooled embedding.
    pub normalize: bool,
    /// Strategy used to pool per-token embeddings into one vector.
    pub pooling: PoolingStrategy,
    /// Modulus applied by the token-hashing projection.
    pub vocab_size: usize,
}
29
30impl Default for EmbeddingConfig {
31 fn default() -> Self {
32 Self {
33 dimension: 384,
34 max_seq_length: 512,
35 normalize: true,
36 pooling: PoolingStrategy::Mean,
37 vocab_size: 50000,
38 }
39 }
40}
41
/// How per-token embeddings are combined into a single text embedding.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Element-wise average over all token embeddings.
    Mean,
    /// Element-wise maximum over all token embeddings.
    Max,
    /// Use only the first token's embedding (CLS-token style).
    CLS,
    /// Position-weighted average; earlier tokens weigh more
    /// (weight = 1 / (1 + 0.1 * position)).
    AttentionWeighted,
}
54
/// Output of [`EmbeddingGeneration::compute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResult {
    /// One embedding vector per input text, in input order.
    pub embeddings: Vec<Vec<f64>>,
    /// Number of tokens kept for each input text, in input order.
    pub token_counts: Vec<usize>,
    /// Dimension of each embedding vector (copied from the config).
    pub dimension: usize,
}
65
/// Hash-based text embedding kernel.
///
/// Produces deterministic embeddings via a feature-hashing projection;
/// no trained model weights are involved.
#[derive(Debug, Clone)]
pub struct EmbeddingGeneration {
    // Kernel registration metadata (id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
75
76impl Default for EmbeddingGeneration {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl EmbeddingGeneration {
83 #[must_use]
85 pub fn new() -> Self {
86 Self {
87 metadata: KernelMetadata::batch("ml/embedding-generation", Domain::StatisticalML)
88 .with_description("GPU-accelerated text embedding generation")
89 .with_throughput(10_000)
90 .with_latency_us(50.0),
91 }
92 }
93
94 pub fn compute(texts: &[&str], config: &EmbeddingConfig) -> EmbeddingResult {
96 if texts.is_empty() {
97 return EmbeddingResult {
98 embeddings: Vec::new(),
99 token_counts: Vec::new(),
100 dimension: config.dimension,
101 };
102 }
103
104 let mut embeddings = Vec::with_capacity(texts.len());
105 let mut token_counts = Vec::with_capacity(texts.len());
106
107 for text in texts {
108 let tokens = Self::tokenize(text, config.max_seq_length);
109 token_counts.push(tokens.len());
110
111 let token_embeddings: Vec<Vec<f64>> = tokens
112 .iter()
113 .map(|token| Self::hash_embedding(token, config.dimension, config.vocab_size))
114 .collect();
115
116 let pooled = Self::pool_embeddings(&token_embeddings, config);
117
118 let final_embedding = if config.normalize {
119 Self::normalize_vector(&pooled)
120 } else {
121 pooled
122 };
123
124 embeddings.push(final_embedding);
125 }
126
127 EmbeddingResult {
128 embeddings,
129 token_counts,
130 dimension: config.dimension,
131 }
132 }
133
134 fn tokenize(text: &str, max_length: usize) -> Vec<String> {
136 text.to_lowercase()
137 .split_whitespace()
138 .take(max_length)
139 .map(|s| s.chars().filter(|c| c.is_alphanumeric()).collect())
140 .filter(|s: &String| !s.is_empty())
141 .collect()
142 }
143
144 #[allow(clippy::needless_range_loop)]
146 fn hash_embedding(token: &str, dimension: usize, vocab_size: usize) -> Vec<f64> {
147 let mut embedding = vec![0.0; dimension];
148
149 let hash1 = Self::hash_token(token, 0) as usize;
151 let hash2 = Self::hash_token(token, 1) as usize;
152 let hash3 = Self::hash_token(token, 2) as usize;
153
154 for i in 0..dimension {
156 let idx1 = (hash1 + i * 31) % vocab_size;
157 let idx2 = (hash2 + i * 37) % vocab_size;
158 let idx3 = (hash3 + i * 41) % vocab_size;
159
160 let sign1 = if (idx1 % 2) == 0 { 1.0 } else { -1.0 };
162 let sign2 = if (idx2 % 2) == 0 { 1.0 } else { -1.0 };
163
164 embedding[i] = sign1 * ((idx1 as f64 / vocab_size as f64) - 0.5)
165 + sign2 * ((idx2 as f64 / vocab_size as f64) - 0.5) * 0.5
166 + ((idx3 as f64 / vocab_size as f64) - 0.5) * 0.25;
167 }
168
169 embedding
170 }
171
172 fn hash_token(token: &str, seed: u64) -> u64 {
174 let mut hash: u64 = seed.wrapping_mul(0x517cc1b727220a95);
175 for byte in token.bytes() {
176 hash = hash.wrapping_mul(31).wrapping_add(byte as u64);
177 }
178 hash
179 }
180
181 fn pool_embeddings(embeddings: &[Vec<f64>], config: &EmbeddingConfig) -> Vec<f64> {
183 if embeddings.is_empty() {
184 return vec![0.0; config.dimension];
185 }
186
187 match config.pooling {
188 PoolingStrategy::Mean => {
189 let mut result = vec![0.0; config.dimension];
190 for emb in embeddings {
191 for (i, &v) in emb.iter().enumerate() {
192 result[i] += v;
193 }
194 }
195 let n = embeddings.len() as f64;
196 result.iter_mut().for_each(|v| *v /= n);
197 result
198 }
199 PoolingStrategy::Max => {
200 let mut result = vec![f64::NEG_INFINITY; config.dimension];
201 for emb in embeddings {
202 for (i, &v) in emb.iter().enumerate() {
203 result[i] = result[i].max(v);
204 }
205 }
206 result
207 }
208 PoolingStrategy::CLS => embeddings[0].clone(),
209 PoolingStrategy::AttentionWeighted => {
210 let mut result = vec![0.0; config.dimension];
212 let mut total_weight = 0.0;
213
214 for (pos, emb) in embeddings.iter().enumerate() {
215 let weight = 1.0 / (1.0 + pos as f64 * 0.1);
216 total_weight += weight;
217 for (i, &v) in emb.iter().enumerate() {
218 result[i] += v * weight;
219 }
220 }
221
222 result.iter_mut().for_each(|v| *v /= total_weight);
223 result
224 }
225 }
226 }
227
228 fn normalize_vector(v: &[f64]) -> Vec<f64> {
230 let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
231 if norm < 1e-10 {
232 v.to_vec()
233 } else {
234 v.iter().map(|x| x / norm).collect()
235 }
236 }
237}
238
impl GpuKernel for EmbeddingGeneration {
    /// Returns the kernel's registration metadata.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
244
/// Configuration for semantic similarity search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityConfig {
    /// Metric used to score query/corpus vector pairs.
    pub metric: SimilarityMetric,
    /// Minimum score a pair must reach to be reported as a match.
    pub threshold: f64,
    /// Maximum number of matches reported per query.
    pub top_k: usize,
    /// When false, the pair at equal query/corpus index is skipped
    /// (presumably for self-similarity searches where queries == corpus —
    /// TODO confirm intended usage).
    pub include_self: bool,
}
261
262impl Default for SimilarityConfig {
263 fn default() -> Self {
264 Self {
265 metric: SimilarityMetric::Cosine,
266 threshold: 0.5,
267 top_k: 10,
268 include_self: false,
269 }
270 }
271}
272
/// Scoring function for comparing two embedding vectors.
///
/// Distance-based metrics (Euclidean, Manhattan) are mapped to similarity
/// scores via `1 / (1 + distance)`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum SimilarityMetric {
    /// Cosine of the angle between the vectors; in [-1, 1].
    Cosine,
    /// Inverse Euclidean (L2) distance; in (0, 1].
    Euclidean,
    /// Raw dot product; unbounded.
    DotProduct,
    /// Inverse Manhattan (L1) distance; in (0, 1].
    Manhattan,
}
285
/// A single query/corpus pairing that passed the top-k and threshold filters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityMatch {
    /// Index into the query slice.
    pub query_idx: usize,
    /// Index into the corpus slice.
    pub match_idx: usize,
    /// Similarity score under the configured metric.
    pub score: f64,
}
296
/// Output of [`SemanticSimilarity::compute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityResult {
    /// Matches that passed the top-k and threshold filters, grouped by
    /// query in descending score order.
    pub matches: Vec<SimilarityMatch>,
    /// Full query x corpus score matrix; pairs skipped by `include_self`
    /// remain 0.0. `None` when either input was empty.
    pub similarity_matrix: Option<Vec<Vec<f64>>>,
    /// Number of query vectors.
    pub query_count: usize,
    /// Number of corpus vectors.
    pub corpus_count: usize,
}
309
/// Semantic similarity matching kernel for embedding vectors.
#[derive(Debug, Clone)]
pub struct SemanticSimilarity {
    // Kernel registration metadata (id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
318
319impl Default for SemanticSimilarity {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325impl SemanticSimilarity {
326 #[must_use]
328 pub fn new() -> Self {
329 Self {
330 metadata: KernelMetadata::batch("ml/semantic-similarity", Domain::StatisticalML)
331 .with_description("Semantic similarity matching for documents and entities")
332 .with_throughput(50_000)
333 .with_latency_us(20.0),
334 }
335 }
336
337 pub fn compute(
339 queries: &[Vec<f64>],
340 corpus: &[Vec<f64>],
341 config: &SimilarityConfig,
342 ) -> SimilarityResult {
343 if queries.is_empty() || corpus.is_empty() {
344 return SimilarityResult {
345 matches: Vec::new(),
346 similarity_matrix: None,
347 query_count: queries.len(),
348 corpus_count: corpus.len(),
349 };
350 }
351
352 let mut all_matches: Vec<SimilarityMatch> = Vec::new();
353 let mut similarity_matrix: Vec<Vec<f64>> = Vec::with_capacity(queries.len());
354
355 for (q_idx, query) in queries.iter().enumerate() {
356 let mut row_scores: Vec<(usize, f64)> = Vec::with_capacity(corpus.len());
357
358 for (c_idx, doc) in corpus.iter().enumerate() {
359 if !config.include_self && q_idx == c_idx {
360 continue;
361 }
362
363 let score = Self::compute_similarity(query, doc, config.metric);
364 row_scores.push((c_idx, score));
365 }
366
367 row_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
369
370 for (c_idx, score) in row_scores.iter().take(config.top_k) {
372 if *score >= config.threshold {
373 all_matches.push(SimilarityMatch {
374 query_idx: q_idx,
375 match_idx: *c_idx,
376 score: *score,
377 });
378 }
379 }
380
381 let mut full_row = vec![0.0; corpus.len()];
383 for (c_idx, score) in row_scores {
384 full_row[c_idx] = score;
385 }
386 similarity_matrix.push(full_row);
387 }
388
389 SimilarityResult {
390 matches: all_matches,
391 similarity_matrix: Some(similarity_matrix),
392 query_count: queries.len(),
393 corpus_count: corpus.len(),
394 }
395 }
396
397 pub fn find_similar(
399 queries: &[Vec<f64>],
400 corpus: &[Vec<f64>],
401 labels: Option<&[String]>,
402 config: &SimilarityConfig,
403 ) -> Vec<Vec<(usize, f64, Option<String>)>> {
404 let result = Self::compute(queries, corpus, config);
405
406 let mut grouped: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
407 for m in result.matches {
408 grouped
409 .entry(m.query_idx)
410 .or_default()
411 .push((m.match_idx, m.score));
412 }
413
414 for matches in grouped.values_mut() {
416 matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
417 }
418
419 queries
420 .iter()
421 .enumerate()
422 .map(|(q_idx, _)| {
423 grouped
424 .get(&q_idx)
425 .map(|matches| {
426 matches
427 .iter()
428 .map(|(idx, score)| {
429 let label = labels.and_then(|l| l.get(*idx).cloned());
430 (*idx, *score, label)
431 })
432 .collect()
433 })
434 .unwrap_or_default()
435 })
436 .collect()
437 }
438
439 fn compute_similarity(a: &[f64], b: &[f64], metric: SimilarityMetric) -> f64 {
441 if a.len() != b.len() || a.is_empty() {
442 return 0.0;
443 }
444
445 match metric {
446 SimilarityMetric::Cosine => {
447 let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
448 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
449 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
450 if norm_a < 1e-10 || norm_b < 1e-10 {
451 0.0
452 } else {
453 dot / (norm_a * norm_b)
454 }
455 }
456 SimilarityMetric::Euclidean => {
457 let dist: f64 = a
458 .iter()
459 .zip(b.iter())
460 .map(|(x, y)| (x - y).powi(2))
461 .sum::<f64>()
462 .sqrt();
463 1.0 / (1.0 + dist)
464 }
465 SimilarityMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
466 SimilarityMetric::Manhattan => {
467 let dist: f64 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
468 1.0 / (1.0 + dist)
469 }
470 }
471 }
472
473 pub fn deduplicate(embeddings: &[Vec<f64>], threshold: f64) -> Vec<usize> {
475 if embeddings.is_empty() {
476 return Vec::new();
477 }
478
479 let mut keep: Vec<usize> = vec![0]; for i in 1..embeddings.len() {
482 let is_duplicate = keep.iter().any(|&j| {
483 let sim = Self::compute_similarity(
484 &embeddings[i],
485 &embeddings[j],
486 SimilarityMetric::Cosine,
487 );
488 sim >= threshold
489 });
490
491 if !is_duplicate {
492 keep.push(i);
493 }
494 }
495
496 keep
497 }
498}
499
impl GpuKernel for SemanticSimilarity {
    /// Returns the kernel's registration metadata.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
505
#[cfg(test)]
mod tests {
    use super::*;

    // Kernel id must match the registry path used at registration.
    #[test]
    fn test_embedding_generation_metadata() {
        let kernel = EmbeddingGeneration::new();
        assert_eq!(kernel.metadata().id, "ml/embedding-generation");
    }

    // One embedding per text, each with the configured dimension;
    // token counts reflect whitespace tokenization.
    #[test]
    fn test_embedding_generation_basic() {
        let config = EmbeddingConfig::default();
        let texts = vec!["hello world", "machine learning"];

        let result = EmbeddingGeneration::compute(&texts, &config);

        assert_eq!(result.embeddings.len(), 2);
        assert_eq!(result.embeddings[0].len(), config.dimension);
        assert_eq!(result.token_counts, vec![2, 2]);
    }

    // With normalize=true the output must have unit L2 norm.
    #[test]
    fn test_embedding_normalization() {
        let config = EmbeddingConfig {
            normalize: true,
            ..Default::default()
        };

        let result = EmbeddingGeneration::compute(&["test text"], &config);

        let norm: f64 = result.embeddings[0]
            .iter()
            .map(|x| x * x)
            .sum::<f64>()
            .sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    // Empty input must yield an empty result (no panic).
    #[test]
    fn test_embedding_empty() {
        let config = EmbeddingConfig::default();
        let result = EmbeddingGeneration::compute(&[], &config);
        assert!(result.embeddings.is_empty());
    }

    // Every pooling strategy must produce one embedding of the
    // configured dimension.
    #[test]
    fn test_pooling_strategies() {
        let texts = vec!["a b c d e"];

        for pooling in [
            PoolingStrategy::Mean,
            PoolingStrategy::Max,
            PoolingStrategy::CLS,
            PoolingStrategy::AttentionWeighted,
        ] {
            let config = EmbeddingConfig {
                pooling,
                ..Default::default()
            };
            let result = EmbeddingGeneration::compute(&texts, &config);
            assert_eq!(result.embeddings.len(), 1);
            assert_eq!(result.embeddings[0].len(), config.dimension);
        }
    }

    // Kernel id must match the registry path used at registration.
    #[test]
    fn test_semantic_similarity_metadata() {
        let kernel = SemanticSimilarity::new();
        assert_eq!(kernel.metadata().id, "ml/semantic-similarity");
    }

    // The identical corpus vector must rank first with score ~1.0.
    #[test]
    fn test_semantic_similarity_basic() {
        let queries = vec![vec![1.0, 0.0, 0.0]];
        let corpus = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.7, 0.7, 0.0],
        ];

        let config = SimilarityConfig {
            threshold: 0.0,
            include_self: true,
            ..Default::default()
        };

        let result = SemanticSimilarity::compute(&queries, &corpus, &config);

        assert!(!result.matches.is_empty());
        assert_eq!(result.matches[0].match_idx, 0);
        assert!((result.matches[0].score - 1.0).abs() < 0.001);
    }

    // Identical vectors must score positively under every metric.
    #[test]
    fn test_similarity_metrics() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];

        for metric in [
            SimilarityMetric::Cosine,
            SimilarityMetric::Euclidean,
            SimilarityMetric::DotProduct,
            SimilarityMetric::Manhattan,
        ] {
            let sim = SemanticSimilarity::compute_similarity(&a, &b, metric);
            assert!(
                sim > 0.0,
                "Identical vectors should have positive similarity for {:?}",
                metric
            );
        }
    }

    // Two near-duplicate pairs: only the first of each pair is kept.
    #[test]
    fn test_deduplicate() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.99, 0.01],
            vec![0.0, 1.0],
            vec![0.01, 0.99],
        ];

        let kept = SemanticSimilarity::deduplicate(&embeddings, 0.95);

        assert_eq!(kept.len(), 2);
        assert!(kept.contains(&0));
        assert!(kept.contains(&2));
    }

    // Labels are attached by corpus index; the closest doc comes first.
    #[test]
    fn test_find_similar_with_labels() {
        let queries = vec![vec![1.0, 0.0]];
        let corpus = vec![vec![0.9, 0.1], vec![0.0, 1.0]];
        let labels = vec!["doc_a".to_string(), "doc_b".to_string()];

        let config = SimilarityConfig {
            threshold: 0.0,
            include_self: true,
            ..Default::default()
        };

        let results = SemanticSimilarity::find_similar(&queries, &corpus, Some(&labels), &config);

        assert_eq!(results.len(), 1);
        assert!(!results[0].is_empty());
        assert_eq!(results[0][0].2, Some("doc_a".to_string()));
    }
}