1use crate::types::{CsrGraph, SimilarityScore};
9use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
10use std::collections::HashSet;
11
12#[derive(Debug, Clone)]
21pub struct JaccardSimilarity {
22 metadata: KernelMetadata,
23}
24
25impl JaccardSimilarity {
26 #[must_use]
28 pub fn new() -> Self {
29 Self {
30 metadata: KernelMetadata::batch("graph/jaccard-similarity", Domain::GraphAnalytics)
31 .with_description("Jaccard similarity (neighbor set overlap)")
32 .with_throughput(100_000)
33 .with_latency_us(10.0),
34 }
35 }
36
37 pub fn compute_pair(graph: &CsrGraph, node_a: u64, node_b: u64) -> f64 {
39 let neighbors_a: HashSet<u64> = graph.neighbors(node_a).iter().copied().collect();
40 let neighbors_b: HashSet<u64> = graph.neighbors(node_b).iter().copied().collect();
41
42 let intersection = neighbors_a.intersection(&neighbors_b).count();
43 let union = neighbors_a.union(&neighbors_b).count();
44
45 if union == 0 {
46 0.0
47 } else {
48 intersection as f64 / union as f64
49 }
50 }
51
52 pub fn compute_all_pairs(
59 graph: &CsrGraph,
60 min_similarity: f64,
61 max_pairs: usize,
62 ) -> Vec<SimilarityScore> {
63 let n = graph.num_nodes;
64 let mut results = Vec::new();
65
66 for i in 0..n {
67 for j in (i + 1)..n {
68 let similarity = Self::compute_pair(graph, i as u64, j as u64);
69
70 if similarity >= min_similarity {
71 results.push(SimilarityScore {
72 id_a: i as u64,
73 id_b: j as u64,
74 similarity,
75 });
76
77 if results.len() >= max_pairs {
78 return results;
79 }
80 }
81 }
82 }
83
84 results.sort_by(|a, b| {
86 b.similarity
87 .partial_cmp(&a.similarity)
88 .unwrap_or(std::cmp::Ordering::Equal)
89 });
90 results
91 }
92
93 pub fn top_k_pairs(graph: &CsrGraph, k: usize) -> Vec<SimilarityScore> {
95 Self::compute_all_pairs(graph, 0.0, k)
96 }
97}
98
99impl Default for JaccardSimilarity {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl GpuKernel for JaccardSimilarity {
106 fn metadata(&self) -> &KernelMetadata {
107 &self.metadata
108 }
109}
110
111#[derive(Debug, Clone)]
120pub struct CosineSimilarity {
121 metadata: KernelMetadata,
122}
123
124impl CosineSimilarity {
125 #[must_use]
127 pub fn new() -> Self {
128 Self {
129 metadata: KernelMetadata::batch("graph/cosine-similarity", Domain::GraphAnalytics)
130 .with_description("Cosine similarity (normalized dot product)")
131 .with_throughput(100_000)
132 .with_latency_us(10.0),
133 }
134 }
135
136 pub fn compute_pair(graph: &CsrGraph, node_a: u64, node_b: u64) -> f64 {
138 let neighbors_a: HashSet<u64> = graph.neighbors(node_a).iter().copied().collect();
139 let neighbors_b: HashSet<u64> = graph.neighbors(node_b).iter().copied().collect();
140
141 let intersection = neighbors_a.intersection(&neighbors_b).count() as f64;
142 let norm = (neighbors_a.len() as f64 * neighbors_b.len() as f64).sqrt();
143
144 if norm == 0.0 {
145 0.0
146 } else {
147 intersection / norm
148 }
149 }
150
151 pub fn compute_all_pairs(
153 graph: &CsrGraph,
154 min_similarity: f64,
155 max_pairs: usize,
156 ) -> Vec<SimilarityScore> {
157 let n = graph.num_nodes;
158 let mut results = Vec::new();
159
160 for i in 0..n {
161 for j in (i + 1)..n {
162 let similarity = Self::compute_pair(graph, i as u64, j as u64);
163
164 if similarity >= min_similarity {
165 results.push(SimilarityScore {
166 id_a: i as u64,
167 id_b: j as u64,
168 similarity,
169 });
170
171 if results.len() >= max_pairs {
172 return results;
173 }
174 }
175 }
176 }
177
178 results.sort_by(|a, b| {
179 b.similarity
180 .partial_cmp(&a.similarity)
181 .unwrap_or(std::cmp::Ordering::Equal)
182 });
183 results
184 }
185}
186
187impl Default for CosineSimilarity {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
193impl GpuKernel for CosineSimilarity {
194 fn metadata(&self) -> &KernelMetadata {
195 &self.metadata
196 }
197}
198
199#[derive(Debug, Clone)]
208pub struct AdamicAdarIndex {
209 metadata: KernelMetadata,
210}
211
212impl AdamicAdarIndex {
213 #[must_use]
215 pub fn new() -> Self {
216 Self {
217 metadata: KernelMetadata::batch("graph/adamic-adar", Domain::GraphAnalytics)
218 .with_description("Adamic-Adar index (weighted common neighbors)")
219 .with_throughput(100_000)
220 .with_latency_us(10.0),
221 }
222 }
223
224 pub fn compute_pair(graph: &CsrGraph, node_a: u64, node_b: u64) -> f64 {
226 let neighbors_a: HashSet<u64> = graph.neighbors(node_a).iter().copied().collect();
227 let neighbors_b: HashSet<u64> = graph.neighbors(node_b).iter().copied().collect();
228
229 let common_neighbors = neighbors_a.intersection(&neighbors_b);
230
231 common_neighbors
232 .map(|&z| {
233 let degree = graph.out_degree(z) as f64;
234 if degree > 1.0 { 1.0 / degree.ln() } else { 0.0 }
235 })
236 .sum()
237 }
238
239 pub fn compute_all_pairs(
241 graph: &CsrGraph,
242 min_score: f64,
243 max_pairs: usize,
244 ) -> Vec<SimilarityScore> {
245 let n = graph.num_nodes;
246 let mut results = Vec::new();
247
248 for i in 0..n {
249 for j in (i + 1)..n {
250 let score = Self::compute_pair(graph, i as u64, j as u64);
251
252 if score >= min_score {
253 results.push(SimilarityScore {
254 id_a: i as u64,
255 id_b: j as u64,
256 similarity: score,
257 });
258
259 if results.len() >= max_pairs {
260 return results;
261 }
262 }
263 }
264 }
265
266 results.sort_by(|a, b| {
267 b.similarity
268 .partial_cmp(&a.similarity)
269 .unwrap_or(std::cmp::Ordering::Equal)
270 });
271 results
272 }
273
274 pub fn top_k_pairs(graph: &CsrGraph, k: usize) -> Vec<SimilarityScore> {
276 Self::compute_all_pairs(graph, 0.0, k)
277 }
278}
279
280impl Default for AdamicAdarIndex {
281 fn default() -> Self {
282 Self::new()
283 }
284}
285
286impl GpuKernel for AdamicAdarIndex {
287 fn metadata(&self) -> &KernelMetadata {
288 &self.metadata
289 }
290}
291
292#[derive(Debug, Clone)]
300pub struct CommonNeighbors {
301 metadata: KernelMetadata,
302}
303
304impl CommonNeighbors {
305 #[must_use]
307 pub fn new() -> Self {
308 Self {
309 metadata: KernelMetadata::batch("graph/common-neighbors", Domain::GraphAnalytics)
310 .with_description("Common neighbors count")
311 .with_throughput(200_000)
312 .with_latency_us(5.0),
313 }
314 }
315
316 pub fn compute_pair(graph: &CsrGraph, node_a: u64, node_b: u64) -> usize {
318 let neighbors_a: HashSet<u64> = graph.neighbors(node_a).iter().copied().collect();
319 let neighbors_b: HashSet<u64> = graph.neighbors(node_b).iter().copied().collect();
320
321 neighbors_a.intersection(&neighbors_b).count()
322 }
323
324 pub fn compute_all_pairs(
326 graph: &CsrGraph,
327 min_count: usize,
328 max_pairs: usize,
329 ) -> Vec<SimilarityScore> {
330 let n = graph.num_nodes;
331 let mut results = Vec::new();
332
333 for i in 0..n {
334 for j in (i + 1)..n {
335 let count = Self::compute_pair(graph, i as u64, j as u64);
336
337 if count >= min_count {
338 results.push(SimilarityScore {
339 id_a: i as u64,
340 id_b: j as u64,
341 similarity: count as f64,
342 });
343
344 if results.len() >= max_pairs {
345 return results;
346 }
347 }
348 }
349 }
350
351 results.sort_by(|a, b| {
352 b.similarity
353 .partial_cmp(&a.similarity)
354 .unwrap_or(std::cmp::Ordering::Equal)
355 });
356 results
357 }
358}
359
360impl Default for CommonNeighbors {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
366impl GpuKernel for CommonNeighbors {
367 fn metadata(&self) -> &KernelMetadata {
368 &self.metadata
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Six-node undirected graph with edges 0-1, 1-2, 0-3, 1-4, 2-5, 3-4, 4-5.
    /// Every edge is stored in both directions.
    fn create_test_graph() -> CsrGraph {
        CsrGraph::from_edges(
            6,
            &[
                (0, 1), (1, 0),
                (1, 2), (2, 1),
                (0, 3), (3, 0),
                (1, 4), (4, 1),
                (2, 5), (5, 2),
                (3, 4), (4, 3),
                (4, 5), (5, 4),
            ],
        )
    }

    #[test]
    fn test_jaccard_similarity_metadata() {
        let kernel = JaccardSimilarity::new();
        let metadata = kernel.metadata();
        assert_eq!(metadata.id, "graph/jaccard-similarity");
        assert_eq!(metadata.domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_jaccard_similarity_pair() {
        // N(0) = {1, 3} and N(2) = {1, 5}: one shared neighbour out of three distinct.
        let sim = JaccardSimilarity::compute_pair(&create_test_graph(), 0, 2);
        let expected: f64 = 1.0 / 3.0;
        assert!((sim - expected).abs() < 0.01, "Expected ~0.33, got {}", sim);
    }

    #[test]
    fn test_cosine_similarity_pair() {
        // One shared neighbour divided by sqrt(2 * 2) = 2.
        let sim = CosineSimilarity::compute_pair(&create_test_graph(), 0, 2);
        assert!((sim - 0.5).abs() < 0.01, "Expected 0.5, got {}", sim);
    }

    #[test]
    fn test_adamic_adar_pair() {
        let graph = create_test_graph();

        // Nodes 0 and 2 share neighbour 1 (degree 3), so the score is 1 / ln 3.
        let aa = AdamicAdarIndex::compute_pair(&graph, 0, 2);
        assert!(aa > 0.0, "Expected positive Adamic-Adar score, got {}", aa);

        // N(0) = {1, 3} and N(5) = {2, 4} are disjoint.
        assert_eq!(AdamicAdarIndex::compute_pair(&graph, 0, 5), 0.0);
    }

    #[test]
    fn test_common_neighbors_pair() {
        let graph = create_test_graph();

        // 0 and 2 share only node 1.
        assert_eq!(CommonNeighbors::compute_pair(&graph, 0, 2), 1);

        // 0 and 1 are adjacent, yet N(0) = {1, 3} and N(1) = {0, 2, 4} do not overlap.
        assert_eq!(CommonNeighbors::compute_pair(&graph, 0, 1), 0);
    }

    #[test]
    fn test_jaccard_all_pairs() {
        let pairs = JaccardSimilarity::compute_all_pairs(&create_test_graph(), 0.0, 100);
        assert!(!pairs.is_empty());

        // Results must come back in non-increasing similarity order.
        for window in pairs.windows(2) {
            assert!(window[0].similarity >= window[1].similarity);
        }
    }
}
481
/// Per-node histograms of observed values, sharing one set of bin edges.
#[derive(Debug, Clone)]
pub struct ValueDistribution {
    /// Number of nodes (histogram rows).
    pub node_count: usize,
    /// Number of bins per histogram.
    pub bin_count: usize,
    /// Row-major `node_count × bin_count` buffer; row `i` is node `i`'s histogram.
    /// `from_values` normalises each non-empty row to sum to 1.
    pub distributions: Vec<f64>,
    /// `bin_count + 1` bin boundaries shared by every node.
    pub bin_edges: Vec<f64>,
    /// Strategy used to place the bin edges.
    pub strategy: BinningStrategy,
}

/// How the bin edges of a [`ValueDistribution`] are chosen.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinningStrategy {
    /// Bins of identical width spanning the observed value range.
    EqualWidth,
    /// Logarithmic binning (not produced by `from_values`, which always uses `EqualWidth`).
    Logarithmic,
    /// Quantile binning (not produced by `from_values`, which always uses `EqualWidth`).
    Quantile,
}

impl ValueDistribution {
    /// Creates an all-zero distribution with `EqualWidth` strategy and zeroed edges.
    #[must_use]
    pub fn new(node_count: usize, bin_count: usize) -> Self {
        Self {
            node_count,
            bin_count,
            distributions: vec![0.0; node_count * bin_count],
            bin_edges: vec![0.0; bin_count + 1],
            strategy: BinningStrategy::EqualWidth,
        }
    }

    /// Builds equal-width histograms, one per entry of `values`, normalised to sum to 1.
    ///
    /// - The bins span `[min, max]` of all finite values across all nodes.
    /// - Non-finite values (NaN, ±inf) are ignored.
    /// - Nodes without finite values keep an all-zero row.
    /// - When every value is identical, bins of width 1 starting at that value are used.
    /// - With `bin_count == 0`, or with no finite values at all, the result is `Self::new(values.len(), bin_count)`.
    #[must_use]
    pub fn from_values(values: &[Vec<f64>], bin_count: usize) -> Self {
        let node_count = values.len();
        let mut dist = Self::new(node_count, bin_count);

        // Nothing can be binned, and `bin_count - 1` below would underflow.
        if bin_count == 0 {
            return dist;
        }

        // NaN/±inf would otherwise poison the range (and the bin width) or be
        // silently counted in bin 0.
        let (min_val, max_val) = values
            .iter()
            .flatten()
            .copied()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });

        // No finite observation anywhere: leave histograms and edges zeroed.
        if min_val > max_val {
            return dist;
        }

        let range = max_val - min_val;
        let bin_width = if range > 0.0 {
            range / bin_count as f64
        } else {
            1.0
        };

        for (i, edge) in dist.bin_edges.iter_mut().enumerate() {
            *edge = min_val + i as f64 * bin_width;
        }
        // Nudge the last edge just past the maximum. Never move it below the
        // computed edge: in the `range == 0` case that would make edges non-monotonic.
        dist.bin_edges[bin_count] = dist.bin_edges[bin_count].max(max_val + 0.001);

        for (node_values, row) in values
            .iter()
            .zip(dist.distributions.chunks_exact_mut(bin_count))
        {
            for &v in node_values.iter().filter(|v| v.is_finite()) {
                // The maximum lands exactly on `bin_count`; clamp it into the last bin.
                let bin = (((v - min_val) / bin_width).floor() as usize).min(bin_count - 1);
                row[bin] += 1.0;
            }

            let total: f64 = row.iter().sum();
            if total > 0.0 {
                row.iter_mut().for_each(|count| *count /= total);
            }
        }

        dist
    }

    /// The histogram (one value per bin) for `node`.
    ///
    /// # Panics
    /// Panics if `node` is out of range (when `bin_count > 0`).
    #[must_use]
    pub fn get_distribution(&self, node: usize) -> &[f64] {
        let start = node * self.bin_count;
        &self.distributions[start..start + self.bin_count]
    }
}
584
585#[derive(Debug, Clone)]
587pub struct ValueSimilarityResult {
588 pub node_a: usize,
590 pub node_b: usize,
592 pub similarity: f64,
594 pub distance: f64,
596}
597
598#[derive(Debug, Clone)]
604pub struct ValueSimilarity {
605 metadata: KernelMetadata,
606}
607
608impl Default for ValueSimilarity {
609 fn default() -> Self {
610 Self::new()
611 }
612}
613
614impl ValueSimilarity {
615 #[must_use]
617 pub fn new() -> Self {
618 Self {
619 metadata: KernelMetadata::batch("graph/value-similarity", Domain::GraphAnalytics)
620 .with_description("Value distribution similarity via JSD/Wasserstein")
621 .with_throughput(25_000)
622 .with_latency_us(800.0),
623 }
624 }
625
626 pub fn jensen_shannon_divergence(p: &[f64], q: &[f64]) -> f64 {
631 assert_eq!(p.len(), q.len(), "Distributions must have same length");
632
633 let epsilon = 1e-10;
634
635 let mut kl_pm = 0.0;
636 let mut kl_qm = 0.0;
637
638 for i in 0..p.len() {
639 let m = 0.5 * (p[i] + q[i]);
640
641 if p[i] > epsilon && m > epsilon {
642 kl_pm += p[i] * (p[i] / m).ln();
643 }
644 if q[i] > epsilon && m > epsilon {
645 kl_qm += q[i] * (q[i] / m).ln();
646 }
647 }
648
649 0.5 * kl_pm + 0.5 * kl_qm
650 }
651
652 pub fn jsd_similarity(p: &[f64], q: &[f64]) -> f64 {
654 let jsd = Self::jensen_shannon_divergence(p, q);
655 1.0 - (jsd / 2.0_f64.ln()).sqrt()
657 }
658
659 pub fn wasserstein_distance(p: &[f64], q: &[f64]) -> f64 {
663 assert_eq!(p.len(), q.len(), "Distributions must have same length");
664
665 let mut cdf_p = 0.0;
666 let mut cdf_q = 0.0;
667 let mut w1 = 0.0;
668
669 for i in 0..p.len() {
670 cdf_p += p[i];
671 cdf_q += q[i];
672 w1 += (cdf_p - cdf_q).abs();
673 }
674
675 w1
676 }
677
678 pub fn wasserstein_similarity(p: &[f64], q: &[f64]) -> f64 {
680 let w1 = Self::wasserstein_distance(p, q);
681 1.0 / (1.0 + w1)
683 }
684
685 pub fn compute_all_pairs_jsd(
687 distributions: &ValueDistribution,
688 min_similarity: f64,
689 max_pairs: usize,
690 ) -> Vec<ValueSimilarityResult> {
691 let n = distributions.node_count;
692 let mut results = Vec::new();
693
694 for i in 0..n {
695 for j in (i + 1)..n {
696 let p = distributions.get_distribution(i);
697 let q = distributions.get_distribution(j);
698
699 let jsd = Self::jensen_shannon_divergence(p, q);
700 let similarity = 1.0 - (jsd / 2.0_f64.ln()).sqrt();
701
702 if similarity >= min_similarity {
703 results.push(ValueSimilarityResult {
704 node_a: i,
705 node_b: j,
706 similarity,
707 distance: jsd,
708 });
709
710 if results.len() >= max_pairs {
711 return results;
712 }
713 }
714 }
715 }
716
717 results.sort_by(|a, b| {
719 b.similarity
720 .partial_cmp(&a.similarity)
721 .unwrap_or(std::cmp::Ordering::Equal)
722 });
723
724 results
725 }
726
727 pub fn compute_all_pairs_wasserstein(
729 distributions: &ValueDistribution,
730 min_similarity: f64,
731 max_pairs: usize,
732 ) -> Vec<ValueSimilarityResult> {
733 let n = distributions.node_count;
734 let mut results = Vec::new();
735
736 for i in 0..n {
737 for j in (i + 1)..n {
738 let p = distributions.get_distribution(i);
739 let q = distributions.get_distribution(j);
740
741 let w1 = Self::wasserstein_distance(p, q);
742 let similarity = 1.0 / (1.0 + w1);
743
744 if similarity >= min_similarity {
745 results.push(ValueSimilarityResult {
746 node_a: i,
747 node_b: j,
748 similarity,
749 distance: w1,
750 });
751
752 if results.len() >= max_pairs {
753 return results;
754 }
755 }
756 }
757 }
758
759 results.sort_by(|a, b| {
760 b.similarity
761 .partial_cmp(&a.similarity)
762 .unwrap_or(std::cmp::Ordering::Equal)
763 });
764
765 results
766 }
767
768 pub fn find_similar_nodes(
770 distributions: &ValueDistribution,
771 target_node: usize,
772 min_similarity: f64,
773 top_k: usize,
774 ) -> Vec<ValueSimilarityResult> {
775 let n = distributions.node_count;
776 let p = distributions.get_distribution(target_node);
777 let mut results = Vec::new();
778
779 for i in 0..n {
780 if i == target_node {
781 continue;
782 }
783
784 let q = distributions.get_distribution(i);
785 let similarity = Self::jsd_similarity(p, q);
786
787 if similarity >= min_similarity {
788 results.push(ValueSimilarityResult {
789 node_a: target_node,
790 node_b: i,
791 similarity,
792 distance: Self::jensen_shannon_divergence(p, q),
793 });
794 }
795 }
796
797 results.sort_by(|a, b| {
798 b.similarity
799 .partial_cmp(&a.similarity)
800 .unwrap_or(std::cmp::Ordering::Equal)
801 });
802
803 results.into_iter().take(top_k).collect()
804 }
805}
806
807impl GpuKernel for ValueSimilarity {
808 fn metadata(&self) -> &KernelMetadata {
809 &self.metadata
810 }
811}
812
#[cfg(test)]
mod value_similarity_tests {
    use super::*;

    /// Uniform distribution over four bins.
    const UNIFORM: [f64; 4] = [0.25; 4];

    #[test]
    fn test_value_similarity_metadata() {
        let kernel = ValueSimilarity::new();
        let metadata = kernel.metadata();
        assert_eq!(metadata.id, "graph/value-similarity");
        assert_eq!(metadata.domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_jsd_identical_distributions() {
        let jsd = ValueSimilarity::jensen_shannon_divergence(&UNIFORM, &UNIFORM);
        assert!(jsd.abs() < 0.001, "JSD of identical distributions should be 0");
    }

    #[test]
    fn test_jsd_different_distributions() {
        // Disjoint supports push the divergence to its maximum, ln 2 ≈ 0.693.
        let first_bin = [1.0, 0.0, 0.0, 0.0];
        let last_bin = [0.0, 0.0, 0.0, 1.0];
        let jsd = ValueSimilarity::jensen_shannon_divergence(&first_bin, &last_bin);
        assert!(jsd > 0.6, "JSD should be high for very different distributions");
    }

    #[test]
    fn test_jsd_similarity() {
        let sim = ValueSimilarity::jsd_similarity(&UNIFORM, &UNIFORM);
        assert!(
            (sim - 1.0).abs() < 0.01,
            "Identical distributions should have similarity 1.0"
        );
    }

    #[test]
    fn test_wasserstein_identical() {
        let w1 = ValueSimilarity::wasserstein_distance(&UNIFORM, &UNIFORM);
        assert!(w1.abs() < 0.001, "Wasserstein of identical distributions should be 0");
    }

    #[test]
    fn test_wasserstein_shifted() {
        // Moving all mass by one bin costs exactly one bin width.
        let at_zero = [1.0, 0.0, 0.0, 0.0];
        let at_one = [0.0, 1.0, 0.0, 0.0];
        let w1 = ValueSimilarity::wasserstein_distance(&at_zero, &at_one);
        assert!((w1 - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_value_distribution_from_values() {
        let values = vec![vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0]];
        let dist = ValueDistribution::from_values(&values, 4);

        assert_eq!(dist.node_count, 2);
        assert_eq!(dist.bin_count, 4);

        // Every populated histogram is normalised to unit mass.
        for node in 0..dist.node_count {
            let total: f64 = dist.get_distribution(node).iter().sum();
            assert!((total - 1.0).abs() < 0.01);
        }
    }

    #[test]
    fn test_find_similar_nodes() {
        // Node 1 duplicates node 0; node 2 sits at the far end of the value range.
        let values = vec![
            vec![1.0, 2.0, 3.0],
            vec![1.0, 2.0, 3.0],
            vec![10.0, 11.0, 12.0],
        ];
        let dist = ValueDistribution::from_values(&values, 5);
        let similar = ValueSimilarity::find_similar_nodes(&dist, 0, 0.5, 5);

        assert!(!similar.is_empty());
        assert_eq!(similar[0].node_b, 1);
    }
}