1use scirs2_core::ndarray::Array2;
29use sklears_core::error::Result;
30use sklears_core::traits::{Fit, Transform};
31use std::collections::{HashMap, HashSet, VecDeque};
32
/// A graph stored as adjacency lists, with optional node and edge labels.
///
/// Nodes are the integers `0..num_nodes`. [`Graph::add_edge`] records an undirected edge by
/// storing both directions; [`Graph::add_directed_edge`] stores only one.
#[derive(Debug, Clone)]
pub struct Graph {
    /// Out-neighbour lists keyed by node index; nodes without edges may be absent.
    pub adjacency: HashMap<usize, Vec<usize>>,
    /// Optional label for each node.
    pub node_labels: Option<HashMap<usize, String>>,
    /// Optional label for each edge, keyed by `(from, to)`.
    pub edge_labels: Option<HashMap<(usize, usize), String>>,
    /// Number of nodes; valid node indices are `0..num_nodes`.
    pub num_nodes: usize,
}

impl Graph {
    /// Creates a graph with `num_nodes` nodes, no edges and no labels.
    pub fn new(num_nodes: usize) -> Self {
        Self {
            adjacency: HashMap::new(),
            node_labels: None,
            edge_labels: None,
            num_nodes,
        }
    }

    /// Adds an undirected edge between `from` and `to` by recording both directions.
    ///
    /// A self-loop (`from == to`) is recorded once, so the node does not appear twice in its
    /// own neighbour list. Indices are not checked against `num_nodes`.
    pub fn add_edge(&mut self, from: usize, to: usize) {
        self.adjacency.entry(from).or_default().push(to);
        if from != to {
            self.adjacency.entry(to).or_default().push(from);
        }
    }

    /// Adds a single directed edge `from -> to`.
    pub fn add_directed_edge(&mut self, from: usize, to: usize) {
        self.adjacency.entry(from).or_default().push(to);
    }

    /// Replaces the node labels.
    pub fn set_node_labels(&mut self, labels: HashMap<usize, String>) {
        self.node_labels = Some(labels);
    }

    /// Replaces the edge labels.
    pub fn set_edge_labels(&mut self, labels: HashMap<(usize, usize), String>) {
        self.edge_labels = Some(labels);
    }

    /// Returns a copy of `node`'s out-neighbour list in insertion order (empty if none).
    pub fn neighbors(&self, node: usize) -> Vec<usize> {
        self.adjacency.get(&node).cloned().unwrap_or_default()
    }

    /// Returns the node indices `0..num_nodes`.
    pub fn nodes(&self) -> Vec<usize> {
        (0..self.num_nodes).collect()
    }

    /// Returns the stored edges as `(a, b)` pairs with `a <= b`, sorted ascending.
    ///
    /// For graphs built with [`Graph::add_edge`] this lists each undirected edge once
    /// (parallel edges once per insertion). A directed edge added with `from > to` has no
    /// `a <= b` orientation stored and is therefore not listed.
    pub fn edges(&self) -> Vec<(usize, usize)> {
        let mut edges: Vec<(usize, usize)> = self
            .adjacency
            .iter()
            .flat_map(|(&from, neighbors)| neighbors.iter().map(move |&to| (from, to)))
            .filter(|&(from, to)| from <= to)
            .collect();
        // HashMap iteration order is unspecified; sorting makes the result deterministic.
        edges.sort_unstable();
        edges
    }
}
103
104#[derive(Debug, Clone)]
106pub struct RandomWalkKernel {
108 max_length: usize,
110 lambda: f64,
112 use_node_labels: bool,
114 use_edge_labels: bool,
116}
117
118impl RandomWalkKernel {
119 pub fn new(max_length: usize, lambda: f64) -> Self {
120 Self {
121 max_length,
122 lambda,
123 use_node_labels: false,
124 use_edge_labels: false,
125 }
126 }
127
128 pub fn use_node_labels(mut self, use_labels: bool) -> Self {
130 self.use_node_labels = use_labels;
131 self
132 }
133
134 pub fn use_edge_labels(mut self, use_labels: bool) -> Self {
136 self.use_edge_labels = use_labels;
137 self
138 }
139
140 fn product_graph(&self, g1: &Graph, g2: &Graph) -> Graph {
142 let mut product = Graph::new(g1.num_nodes * g2.num_nodes);
143
144 for i in 0..g1.num_nodes {
146 for j in 0..g2.num_nodes {
147 let node_ij = i * g2.num_nodes + j;
148
149 let nodes_match = if self.use_node_labels {
151 if let (Some(labels1), Some(labels2)) = (&g1.node_labels, &g2.node_labels) {
152 labels1.get(&i) == labels2.get(&j)
153 } else {
154 true
155 }
156 } else {
157 true
158 };
159
160 if !nodes_match {
161 continue;
162 }
163
164 for &neighbor_i in &g1.neighbors(i) {
166 for &neighbor_j in &g2.neighbors(j) {
167 let neighbor_ij = neighbor_i * g2.num_nodes + neighbor_j;
168
169 let edges_match = if self.use_edge_labels {
171 if let (Some(labels1), Some(labels2)) =
172 (&g1.edge_labels, &g2.edge_labels)
173 {
174 labels1.get(&(i, neighbor_i)) == labels2.get(&(j, neighbor_j))
175 } else {
176 true
177 }
178 } else {
179 true
180 };
181
182 if edges_match {
183 product.add_directed_edge(node_ij, neighbor_ij);
184 }
185 }
186 }
187 }
188 }
189
190 product
191 }
192
193 fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
195 let product = self.product_graph(g1, g2);
196
197 let n = product.num_nodes;
199 if n == 0 {
200 return 0.0;
201 }
202
203 let mut adj = Array2::zeros((n, n));
205 for (&from, neighbors) in &product.adjacency {
206 for &to in neighbors {
207 adj[(from, to)] = 1.0;
208 }
209 }
210
211 let result = Array2::eye(n);
213 let mut current_power = Array2::eye(n);
214 let mut total = result.clone();
215
216 for k in 1..=self.max_length {
217 current_power = current_power.dot(&adj);
218 total = total + self.lambda.powi(k as i32) * ¤t_power;
219 }
220
221 total.sum()
223 }
224}
225
226#[derive(Debug, Clone)]
228pub struct FittedRandomWalkKernel {
230 training_graphs: Vec<Graph>,
232 max_length: usize,
234 lambda: f64,
235 use_node_labels: bool,
236 use_edge_labels: bool,
237}
238
239impl Fit<Vec<Graph>, ()> for RandomWalkKernel {
240 type Fitted = FittedRandomWalkKernel;
241 fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
242 Ok(FittedRandomWalkKernel {
243 training_graphs: graphs.clone(),
244 max_length: self.max_length,
245 lambda: self.lambda,
246 use_node_labels: self.use_node_labels,
247 use_edge_labels: self.use_edge_labels,
248 })
249 }
250}
251
252impl Transform<Vec<Graph>, Array2<f64>> for FittedRandomWalkKernel {
253 fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
254 let n_test = graphs.len();
255 let n_train = self.training_graphs.len();
256 let mut kernel_matrix = Array2::zeros((n_test, n_train));
257
258 let kernel = RandomWalkKernel {
259 max_length: self.max_length,
260 lambda: self.lambda,
261 use_node_labels: self.use_node_labels,
262 use_edge_labels: self.use_edge_labels,
263 };
264
265 for i in 0..n_test {
266 for j in 0..n_train {
267 kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
268 }
269 }
270
271 Ok(kernel_matrix)
272 }
273}
274
275#[derive(Debug, Clone)]
277pub struct ShortestPathKernel {
279 use_node_labels: bool,
281 normalize: bool,
283}
284
285impl Default for ShortestPathKernel {
286 fn default() -> Self {
287 Self::new()
288 }
289}
290
291impl ShortestPathKernel {
292 pub fn new() -> Self {
293 Self {
294 use_node_labels: false,
295 normalize: true,
296 }
297 }
298
299 pub fn use_node_labels(mut self, use_labels: bool) -> Self {
301 self.use_node_labels = use_labels;
302 self
303 }
304
305 pub fn normalize(mut self, normalize: bool) -> Self {
307 self.normalize = normalize;
308 self
309 }
310
311 fn all_pairs_shortest_paths(&self, graph: &Graph) -> HashMap<(usize, usize), usize> {
313 let mut distances = HashMap::new();
314 let nodes = graph.nodes();
315
316 for &i in &nodes {
318 for &j in &nodes {
319 if i == j {
320 distances.insert((i, j), 0);
321 } else {
322 distances.insert((i, j), usize::MAX);
323 }
324 }
325 }
326
327 for (&from, neighbors) in &graph.adjacency {
329 for &to in neighbors {
330 distances.insert((from, to), 1);
331 }
332 }
333
334 for &k in &nodes {
336 for &i in &nodes {
337 for &j in &nodes {
338 if let (Some(&dist_ik), Some(&dist_kj)) =
339 (distances.get(&(i, k)), distances.get(&(k, j)))
340 {
341 if dist_ik != usize::MAX && dist_kj != usize::MAX {
342 let new_dist = dist_ik + dist_kj;
343 if let Some(current_dist) = distances.get_mut(&(i, j)) {
344 if new_dist < *current_dist {
345 *current_dist = new_dist;
346 }
347 }
348 }
349 }
350 }
351 }
352 }
353
354 distances
355 }
356
357 fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
359 let distances = self.all_pairs_shortest_paths(graph);
360 let mut features = HashMap::new();
361
362 for ((i, j), &dist) in &distances {
363 if dist != usize::MAX {
364 let feature = if self.use_node_labels {
365 if let Some(ref labels) = graph.node_labels {
366 let label_i = labels.get(i).cloned().unwrap_or_default();
367 let label_j = labels.get(j).cloned().unwrap_or_default();
368 format!("{}:{}:{}", label_i, label_j, dist)
369 } else {
370 format!("path:{}", dist)
371 }
372 } else {
373 format!("path:{}", dist)
374 };
375
376 *features.entry(feature).or_insert(0) += 1;
377 }
378 }
379
380 features
381 }
382
383 fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
385 let features1 = self.extract_features(g1);
386 let features2 = self.extract_features(g2);
387
388 let mut dot_product = 0.0;
389 for (feature, &count1) in &features1 {
390 if let Some(&count2) = features2.get(feature) {
391 dot_product += (count1 * count2) as f64;
392 }
393 }
394
395 if self.normalize {
396 let norm1 = features1
397 .values()
398 .map(|&x| (x * x) as f64)
399 .sum::<f64>()
400 .sqrt();
401 let norm2 = features2
402 .values()
403 .map(|&x| (x * x) as f64)
404 .sum::<f64>()
405 .sqrt();
406 if norm1 > 0.0 && norm2 > 0.0 {
407 dot_product / (norm1 * norm2)
408 } else {
409 0.0
410 }
411 } else {
412 dot_product
413 }
414 }
415}
416
417#[derive(Debug, Clone)]
419pub struct FittedShortestPathKernel {
421 training_graphs: Vec<Graph>,
423 use_node_labels: bool,
425 normalize: bool,
427}
428
429impl Fit<Vec<Graph>, ()> for ShortestPathKernel {
430 type Fitted = FittedShortestPathKernel;
431
432 fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
433 Ok(FittedShortestPathKernel {
434 training_graphs: graphs.clone(),
435 use_node_labels: self.use_node_labels,
436 normalize: self.normalize,
437 })
438 }
439}
440
441impl Transform<Vec<Graph>, Array2<f64>> for FittedShortestPathKernel {
442 fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
443 let n_test = graphs.len();
444 let n_train = self.training_graphs.len();
445 let mut kernel_matrix = Array2::zeros((n_test, n_train));
446
447 let kernel = ShortestPathKernel {
448 use_node_labels: self.use_node_labels,
449 normalize: self.normalize,
450 };
451
452 for i in 0..n_test {
453 for j in 0..n_train {
454 kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
455 }
456 }
457
458 Ok(kernel_matrix)
459 }
460}
461
462#[derive(Debug, Clone)]
464pub struct WeisfeilerLehmanKernel {
466 iterations: usize,
468 use_node_labels: bool,
470 normalize: bool,
472}
473
474impl WeisfeilerLehmanKernel {
475 pub fn new(iterations: usize) -> Self {
476 Self {
477 iterations,
478 use_node_labels: false,
479 normalize: true,
480 }
481 }
482
483 pub fn use_node_labels(mut self, use_labels: bool) -> Self {
485 self.use_node_labels = use_labels;
486 self
487 }
488
489 pub fn normalize(mut self, normalize: bool) -> Self {
491 self.normalize = normalize;
492 self
493 }
494
495 fn wl_relabel(&self, graph: &Graph) -> Vec<HashMap<usize, String>> {
497 let mut labelings = Vec::new();
498 let nodes = graph.nodes();
499
500 let mut current_labels = HashMap::new();
502 for &node in &nodes {
503 let initial_label = if self.use_node_labels {
504 graph
505 .node_labels
506 .as_ref()
507 .and_then(|labels| labels.get(&node))
508 .cloned()
509 .unwrap_or_else(|| "default".to_string())
510 } else {
511 "1".to_string()
512 };
513 current_labels.insert(node, initial_label);
514 }
515 labelings.push(current_labels.clone());
516
517 for _iter in 0..self.iterations {
519 let mut new_labels = HashMap::new();
520
521 for &node in &nodes {
522 let mut neighbor_labels = Vec::new();
523 for &neighbor in &graph.neighbors(node) {
524 if let Some(label) = current_labels.get(&neighbor) {
525 neighbor_labels.push(label.clone());
526 }
527 }
528 neighbor_labels.sort();
529
530 let current_label = current_labels.get(&node).cloned().unwrap_or_default();
531 let new_label = format!("{}:{}", current_label, neighbor_labels.join(","));
532 new_labels.insert(node, new_label);
533 }
534
535 labelings.push(new_labels.clone());
536 current_labels = new_labels;
537 }
538
539 labelings
540 }
541
542 fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
544 let labelings = self.wl_relabel(graph);
545 let mut features = HashMap::new();
546
547 for labeling in labelings {
548 for (_, label) in labeling {
549 *features.entry(label).or_insert(0) += 1;
550 }
551 }
552
553 features
554 }
555
556 fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
558 let features1 = self.extract_features(g1);
559 let features2 = self.extract_features(g2);
560
561 let mut dot_product = 0.0;
562 for (feature, &count1) in &features1 {
563 if let Some(&count2) = features2.get(feature) {
564 dot_product += (count1 * count2) as f64;
565 }
566 }
567
568 if self.normalize {
569 let norm1 = features1
570 .values()
571 .map(|&x| (x * x) as f64)
572 .sum::<f64>()
573 .sqrt();
574 let norm2 = features2
575 .values()
576 .map(|&x| (x * x) as f64)
577 .sum::<f64>()
578 .sqrt();
579 if norm1 > 0.0 && norm2 > 0.0 {
580 dot_product / (norm1 * norm2)
581 } else {
582 0.0
583 }
584 } else {
585 dot_product
586 }
587 }
588}
589
590#[derive(Debug, Clone)]
592pub struct FittedWeisfeilerLehmanKernel {
594 training_graphs: Vec<Graph>,
596 iterations: usize,
598 use_node_labels: bool,
600 normalize: bool,
602}
603
604impl Fit<Vec<Graph>, ()> for WeisfeilerLehmanKernel {
605 type Fitted = FittedWeisfeilerLehmanKernel;
606 fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
607 Ok(FittedWeisfeilerLehmanKernel {
608 training_graphs: graphs.clone(),
609 iterations: self.iterations,
610 use_node_labels: self.use_node_labels,
611 normalize: self.normalize,
612 })
613 }
614}
615
616impl Transform<Vec<Graph>, Array2<f64>> for FittedWeisfeilerLehmanKernel {
617 fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
618 let n_test = graphs.len();
619 let n_train = self.training_graphs.len();
620 let mut kernel_matrix = Array2::zeros((n_test, n_train));
621
622 let kernel = WeisfeilerLehmanKernel {
623 iterations: self.iterations,
624 use_node_labels: self.use_node_labels,
625 normalize: self.normalize,
626 };
627
628 for i in 0..n_test {
629 for j in 0..n_train {
630 kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
631 }
632 }
633
634 Ok(kernel_matrix)
635 }
636}
637
638#[derive(Debug, Clone)]
640pub struct SubgraphKernel {
642 max_size: usize,
644 connected_only: bool,
646 normalize: bool,
648}
649
650impl SubgraphKernel {
651 pub fn new(max_size: usize) -> Self {
652 Self {
653 max_size,
654 connected_only: true,
655 normalize: true,
656 }
657 }
658
659 pub fn connected_only(mut self, connected: bool) -> Self {
661 self.connected_only = connected;
662 self
663 }
664
665 pub fn normalize(mut self, normalize: bool) -> Self {
667 self.normalize = normalize;
668 self
669 }
670
671 fn find_connected_subgraphs(&self, graph: &Graph, size: usize) -> Vec<Vec<usize>> {
673 if size == 0 {
674 return vec![];
675 }
676
677 let mut subgraphs = Vec::new();
678 let nodes = graph.nodes();
679
680 let combinations = self.combinations(&nodes, size);
682
683 for combination in combinations {
684 if self.is_connected_subgraph(graph, &combination) {
685 subgraphs.push(combination);
686 }
687 }
688
689 subgraphs
690 }
691
692 fn is_connected_subgraph(&self, graph: &Graph, nodes: &[usize]) -> bool {
694 if nodes.len() <= 1 {
695 return true;
696 }
697
698 let node_set: HashSet<_> = nodes.iter().collect();
699 let mut visited = HashSet::new();
700 let mut queue = VecDeque::new();
701
702 queue.push_back(nodes[0]);
704 visited.insert(nodes[0]);
705
706 while let Some(current) = queue.pop_front() {
707 for &neighbor in &graph.neighbors(current) {
708 if node_set.contains(&neighbor) && !visited.contains(&neighbor) {
709 visited.insert(neighbor);
710 queue.push_back(neighbor);
711 }
712 }
713 }
714
715 visited.len() == nodes.len()
716 }
717
718 fn combinations(&self, items: &[usize], k: usize) -> Vec<Vec<usize>> {
720 if k == 0 {
721 return vec![vec![]];
722 }
723 if k > items.len() {
724 return vec![];
725 }
726 if k == items.len() {
727 return vec![items.to_vec()];
728 }
729
730 let mut result = Vec::new();
731
732 let with_first = self.combinations(&items[1..], k - 1);
734 for mut combo in with_first {
735 combo.insert(0, items[0]);
736 result.push(combo);
737 }
738
739 let without_first = self.combinations(&items[1..], k);
741 result.extend(without_first);
742
743 result
744 }
745
746 fn subgraph_to_string(&self, graph: &Graph, nodes: &[usize]) -> String {
748 let mut edges = Vec::new();
749 let node_set: HashSet<_> = nodes.iter().collect();
750
751 for &node in nodes {
752 for &neighbor in &graph.neighbors(node) {
753 if node_set.contains(&neighbor) && node < neighbor {
754 edges.push((node, neighbor));
755 }
756 }
757 }
758
759 edges.sort();
760 format!("nodes:{},edges:{:?}", nodes.len(), edges)
761 }
762
763 fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
765 let mut features = HashMap::new();
766
767 for size in 1..=self.max_size {
768 let subgraphs = if self.connected_only {
769 self.find_connected_subgraphs(graph, size)
770 } else {
771 self.find_connected_subgraphs(graph, size)
773 };
774
775 for subgraph in subgraphs {
776 let feature = self.subgraph_to_string(graph, &subgraph);
777 *features.entry(feature).or_insert(0) += 1;
778 }
779 }
780
781 features
782 }
783
784 fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
786 let features1 = self.extract_features(g1);
787 let features2 = self.extract_features(g2);
788
789 let mut dot_product = 0.0;
790 for (feature, &count1) in &features1 {
791 if let Some(&count2) = features2.get(feature) {
792 dot_product += (count1 * count2) as f64;
793 }
794 }
795
796 if self.normalize {
797 let norm1 = features1
798 .values()
799 .map(|&x| (x * x) as f64)
800 .sum::<f64>()
801 .sqrt();
802 let norm2 = features2
803 .values()
804 .map(|&x| (x * x) as f64)
805 .sum::<f64>()
806 .sqrt();
807 if norm1 > 0.0 && norm2 > 0.0 {
808 dot_product / (norm1 * norm2)
809 } else {
810 0.0
811 }
812 } else {
813 dot_product
814 }
815 }
816}
817
818#[derive(Debug, Clone)]
820pub struct FittedSubgraphKernel {
822 training_graphs: Vec<Graph>,
824 max_size: usize,
826 connected_only: bool,
828 normalize: bool,
830}
831
832impl Fit<Vec<Graph>, ()> for SubgraphKernel {
833 type Fitted = FittedSubgraphKernel;
834
835 fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
836 Ok(FittedSubgraphKernel {
837 training_graphs: graphs.clone(),
838 max_size: self.max_size,
839 connected_only: self.connected_only,
840 normalize: self.normalize,
841 })
842 }
843}
844
845impl Transform<Vec<Graph>, Array2<f64>> for FittedSubgraphKernel {
846 fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
847 let n_test = graphs.len();
848 let n_train = self.training_graphs.len();
849 let mut kernel_matrix = Array2::zeros((n_test, n_train));
850
851 let kernel = SubgraphKernel {
852 max_size: self.max_size,
853 connected_only: self.connected_only,
854 normalize: self.normalize,
855 };
856
857 for i in 0..n_test {
858 for j in 0..n_train {
859 kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
860 }
861 }
862
863 Ok(kernel_matrix)
864 }
865}
866
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds an undirected graph with `num_nodes` nodes from an edge list.
    fn create_test_graph(edges: Vec<(usize, usize)>, num_nodes: usize) -> Graph {
        edges
            .into_iter()
            .fold(Graph::new(num_nodes), |mut graph, (from, to)| {
                graph.add_edge(from, to);
                graph
            })
    }

    /// Path graph 0 - 1 - 2.
    fn path3() -> Graph {
        create_test_graph(vec![(0, 1), (1, 2)], 3)
    }

    #[test]
    fn test_graph_creation() {
        let graph = path3();
        let expected: [Vec<usize>; 3] = [vec![1], vec![0, 2], vec![1]];
        for (node, want) in expected.iter().enumerate() {
            assert_eq!(&graph.neighbors(node), want);
        }
        assert_eq!(graph.nodes(), vec![0, 1, 2]);
    }

    #[test]
    fn test_random_walk_kernel() {
        let triangle = create_test_graph(vec![(0, 1), (0, 2), (1, 2)], 3);
        let graphs = vec![path3(), path3(), triangle];
        let kernel_matrix = RandomWalkKernel::new(3, 0.1)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(kernel_matrix.shape(), &[3, 3]);
        // Identical graphs share a self-similarity, and the matrix is symmetric.
        assert_abs_diff_eq!(kernel_matrix[(0, 0)], kernel_matrix[(1, 1)], epsilon = 1e-10);
        assert_abs_diff_eq!(kernel_matrix[(0, 1)], kernel_matrix[(1, 0)], epsilon = 1e-10);
    }

    #[test]
    fn test_shortest_path_kernel() {
        let graphs = vec![
            path3(),
            create_test_graph(vec![(0, 1), (1, 2), (0, 2)], 3),
        ];
        let kernel_matrix = ShortestPathKernel::new()
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(kernel_matrix.shape(), &[2, 2]);
        assert!(kernel_matrix.iter().all(|&x| x.is_finite() && x >= 0.0));
        // Normalised self-similarity is exactly one.
        for i in 0..2 {
            assert_abs_diff_eq!(kernel_matrix[(i, i)], 1.0, epsilon = 1e-10);
        }
        assert_abs_diff_eq!(kernel_matrix[(0, 1)], kernel_matrix[(1, 0)], epsilon = 1e-10);
        assert!((0.0..=1.0).contains(&kernel_matrix[(0, 1)]));
    }

    #[test]
    fn test_weisfeiler_lehman_kernel() {
        let triangle = create_test_graph(vec![(0, 1), (0, 2), (1, 2)], 3);
        let graphs = vec![path3(), path3(), triangle];
        let kernel_matrix = WeisfeilerLehmanKernel::new(2)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(kernel_matrix.shape(), &[3, 3]);
        assert!(kernel_matrix
            .iter()
            .all(|&x| x.is_finite() && (0.0..=1.0).contains(&x)));
        assert_abs_diff_eq!(kernel_matrix[(0, 1)], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_subgraph_kernel() {
        let graphs = vec![create_test_graph(vec![(0, 1)], 2), path3()];
        let kernel_matrix = SubgraphKernel::new(2)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(kernel_matrix.shape(), &[2, 2]);
        assert!(kernel_matrix.iter().all(|&x| x.is_finite() && x >= 0.0));
    }

    #[test]
    fn test_graph_with_labels() {
        let mut graph = path3();
        let labels: HashMap<usize, String> = [(0, "A"), (1, "B"), (2, "A")]
            .iter()
            .map(|&(node, label)| (node, label.to_string()))
            .collect();
        graph.set_node_labels(labels);

        let graphs = vec![graph];
        let kernel_matrix = WeisfeilerLehmanKernel::new(1)
            .use_node_labels(true)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(kernel_matrix.shape(), &[1, 1]);
        assert_abs_diff_eq!(kernel_matrix[(0, 0)], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_shortest_path_computation() {
        let kernel = ShortestPathKernel::new();
        let graph = create_test_graph(vec![(0, 1), (1, 2), (2, 3)], 4);
        let distances = kernel.all_pairs_shortest_paths(&graph);

        for &(pair, want) in &[((0, 3), 3), ((0, 1), 1), ((1, 3), 2)] {
            assert_eq!(distances[&pair], want);
        }
    }

    #[test]
    fn test_subgraph_connectivity() {
        let kernel = SubgraphKernel::new(3);
        let graph = create_test_graph(vec![(0, 1), (2, 3)], 4);

        assert!(!kernel.is_connected_subgraph(&graph, &[0, 1, 2]));
        assert!(kernel.is_connected_subgraph(&graph, &[0, 1]));
        assert!(kernel.is_connected_subgraph(&graph, &[2, 3]));
    }
}