1use scirs2_core::ndarray::Array2;
29use sklears_core::error::Result;
30use sklears_core::traits::{Fit, Transform};
31use std::collections::{HashMap, HashSet, VecDeque};
32
/// Simple graph stored as adjacency lists.
///
/// Nodes are the indices `0..num_nodes`. Undirected edges (added with
/// [`Graph::add_edge`]) are stored in both endpoints' lists; directed edges
/// (added with [`Graph::add_directed_edge`]) only in the source's list.
#[derive(Debug, Clone)]
pub struct Graph {
    /// Outgoing neighbor list per node; nodes without edges may be absent.
    pub adjacency: HashMap<usize, Vec<usize>>,
    /// Optional label for each node id.
    pub node_labels: Option<HashMap<usize, String>>,
    /// Optional label for each `(from, to)` edge key.
    pub edge_labels: Option<HashMap<(usize, usize), String>>,
    /// Number of nodes; [`Graph::nodes`] reports `0..num_nodes`.
    pub num_nodes: usize,
}

impl Graph {
    /// Creates a graph with `num_nodes` nodes, no edges and no labels.
    pub fn new(num_nodes: usize) -> Self {
        Self {
            adjacency: HashMap::new(),
            node_labels: None,
            edge_labels: None,
            num_nodes,
        }
    }

    /// Adds an undirected edge between `from` and `to`.
    ///
    /// A self-loop (`from == to`) is recorded once, so it appears a single time
    /// in both [`Graph::neighbors`] and [`Graph::edges`]. Endpoints are not
    /// checked against `num_nodes`.
    pub fn add_edge(&mut self, from: usize, to: usize) {
        self.adjacency.entry(from).or_default().push(to);
        if from != to {
            self.adjacency.entry(to).or_default().push(from);
        }
    }

    /// Adds a directed edge `from -> to` (stored only in `from`'s list).
    pub fn add_directed_edge(&mut self, from: usize, to: usize) {
        self.adjacency.entry(from).or_default().push(to);
    }

    /// Replaces the node labels.
    pub fn set_node_labels(&mut self, labels: HashMap<usize, String>) {
        self.node_labels = Some(labels);
    }

    /// Replaces the edge labels, keyed by `(from, to)`.
    pub fn set_edge_labels(&mut self, labels: HashMap<(usize, usize), String>) {
        self.edge_labels = Some(labels);
    }

    /// Returns a copy of `node`'s outgoing neighbors in insertion order
    /// (empty when the node has no edges).
    pub fn neighbors(&self, node: usize) -> Vec<usize> {
        self.adjacency.get(&node).cloned().unwrap_or_default()
    }

    /// Returns the node ids `0..num_nodes`.
    pub fn nodes(&self) -> Vec<usize> {
        (0..self.num_nodes).collect()
    }

    /// Returns the stored edges as `(from, to)` pairs with `from <= to`, sorted.
    ///
    /// Intended for undirected graphs, where each edge is stored in both
    /// directions and is therefore reported once. A directed edge `from -> to`
    /// with `from > to` is omitted unless the reverse edge is also stored.
    /// An edge inserted several times is reported once per insertion.
    pub fn edges(&self) -> Vec<(usize, usize)> {
        let mut edges: Vec<(usize, usize)> = self
            .adjacency
            .iter()
            .flat_map(|(&from, neighbors)| {
                neighbors
                    .iter()
                    .filter(move |&&to| from <= to)
                    .map(move |&to| (from, to))
            })
            .collect();
        // HashMap iteration order is unspecified; sort so the result is stable.
        edges.sort_unstable();
        edges
    }
}
103
104#[derive(Debug, Clone)]
106pub struct RandomWalkKernel {
108 max_length: usize,
110 lambda: f64,
112 use_node_labels: bool,
114 use_edge_labels: bool,
116}
117
118impl RandomWalkKernel {
119 pub fn new(max_length: usize, lambda: f64) -> Self {
120 Self {
121 max_length,
122 lambda,
123 use_node_labels: false,
124 use_edge_labels: false,
125 }
126 }
127
128 pub fn use_node_labels(mut self, use_labels: bool) -> Self {
130 self.use_node_labels = use_labels;
131 self
132 }
133
134 pub fn use_edge_labels(mut self, use_labels: bool) -> Self {
136 self.use_edge_labels = use_labels;
137 self
138 }
139
140 fn product_graph(&self, g1: &Graph, g2: &Graph) -> Graph {
142 let mut product = Graph::new(g1.num_nodes * g2.num_nodes);
143
144 for i in 0..g1.num_nodes {
146 for j in 0..g2.num_nodes {
147 let node_ij = i * g2.num_nodes + j;
148
149 let nodes_match = if self.use_node_labels {
151 if let (Some(labels1), Some(labels2)) = (&g1.node_labels, &g2.node_labels) {
152 labels1.get(&i) == labels2.get(&j)
153 } else {
154 true
155 }
156 } else {
157 true
158 };
159
160 if !nodes_match {
161 continue;
162 }
163
164 for &neighbor_i in &g1.neighbors(i) {
166 for &neighbor_j in &g2.neighbors(j) {
167 let neighbor_ij = neighbor_i * g2.num_nodes + neighbor_j;
168
169 let edges_match = if self.use_edge_labels {
171 if let (Some(labels1), Some(labels2)) =
172 (&g1.edge_labels, &g2.edge_labels)
173 {
174 labels1.get(&(i, neighbor_i)) == labels2.get(&(j, neighbor_j))
175 } else {
176 true
177 }
178 } else {
179 true
180 };
181
182 if edges_match {
183 product.add_directed_edge(node_ij, neighbor_ij);
184 }
185 }
186 }
187 }
188 }
189
190 product
191 }
192
193 fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
195 let product = self.product_graph(g1, g2);
196
197 let n = product.num_nodes;
199 if n == 0 {
200 return 0.0;
201 }
202
203 let mut adj = Array2::zeros((n, n));
205 for (&from, neighbors) in &product.adjacency {
206 for &to in neighbors {
207 adj[(from, to)] = 1.0;
208 }
209 }
210
211 let result = Array2::eye(n);
213 let mut current_power = Array2::eye(n);
214 let mut total = result.clone();
215
216 for k in 1..=self.max_length {
217 current_power = current_power.dot(&adj);
218 total = total + self.lambda.powi(k as i32) * ¤t_power;
219 }
220
221 total.sum()
223 }
224}
225
226#[derive(Debug, Clone)]
228pub struct FittedRandomWalkKernel {
230 training_graphs: Vec<Graph>,
232 max_length: usize,
234 lambda: f64,
235 use_node_labels: bool,
236 use_edge_labels: bool,
237}
238
239impl Fit<Vec<Graph>, ()> for RandomWalkKernel {
240 type Fitted = FittedRandomWalkKernel;
241 fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
242 Ok(FittedRandomWalkKernel {
243 training_graphs: graphs.clone(),
244 max_length: self.max_length,
245 lambda: self.lambda,
246 use_node_labels: self.use_node_labels,
247 use_edge_labels: self.use_edge_labels,
248 })
249 }
250}
251
252impl Transform<Vec<Graph>, Array2<f64>> for FittedRandomWalkKernel {
253 fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
254 let n_test = graphs.len();
255 let n_train = self.training_graphs.len();
256 let mut kernel_matrix = Array2::zeros((n_test, n_train));
257
258 let kernel = RandomWalkKernel {
259 max_length: self.max_length,
260 lambda: self.lambda,
261 use_node_labels: self.use_node_labels,
262 use_edge_labels: self.use_edge_labels,
263 };
264
265 for i in 0..n_test {
266 for j in 0..n_train {
267 kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
268 }
269 }
270
271 Ok(kernel_matrix)
272 }
273}
274
275#[derive(Debug, Clone)]
277pub struct ShortestPathKernel {
279 use_node_labels: bool,
281 normalize: bool,
283}
284
285impl ShortestPathKernel {
286 pub fn new() -> Self {
287 Self {
288 use_node_labels: false,
289 normalize: true,
290 }
291 }
292
293 pub fn use_node_labels(mut self, use_labels: bool) -> Self {
295 self.use_node_labels = use_labels;
296 self
297 }
298
299 pub fn normalize(mut self, normalize: bool) -> Self {
301 self.normalize = normalize;
302 self
303 }
304
305 fn all_pairs_shortest_paths(&self, graph: &Graph) -> HashMap<(usize, usize), usize> {
307 let mut distances = HashMap::new();
308 let nodes = graph.nodes();
309
310 for &i in &nodes {
312 for &j in &nodes {
313 if i == j {
314 distances.insert((i, j), 0);
315 } else {
316 distances.insert((i, j), usize::MAX);
317 }
318 }
319 }
320
321 for (&from, neighbors) in &graph.adjacency {
323 for &to in neighbors {
324 distances.insert((from, to), 1);
325 }
326 }
327
328 for &k in &nodes {
330 for &i in &nodes {
331 for &j in &nodes {
332 if let (Some(&dist_ik), Some(&dist_kj)) =
333 (distances.get(&(i, k)), distances.get(&(k, j)))
334 {
335 if dist_ik != usize::MAX && dist_kj != usize::MAX {
336 let new_dist = dist_ik + dist_kj;
337 if let Some(current_dist) = distances.get_mut(&(i, j)) {
338 if new_dist < *current_dist {
339 *current_dist = new_dist;
340 }
341 }
342 }
343 }
344 }
345 }
346 }
347
348 distances
349 }
350
351 fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
353 let distances = self.all_pairs_shortest_paths(graph);
354 let mut features = HashMap::new();
355
356 for ((i, j), &dist) in &distances {
357 if dist != usize::MAX {
358 let feature = if self.use_node_labels {
359 if let Some(ref labels) = graph.node_labels {
360 let label_i = labels.get(i).cloned().unwrap_or_default();
361 let label_j = labels.get(j).cloned().unwrap_or_default();
362 format!("{}:{}:{}", label_i, label_j, dist)
363 } else {
364 format!("path:{}", dist)
365 }
366 } else {
367 format!("path:{}", dist)
368 };
369
370 *features.entry(feature).or_insert(0) += 1;
371 }
372 }
373
374 features
375 }
376
377 fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
379 let features1 = self.extract_features(g1);
380 let features2 = self.extract_features(g2);
381
382 let mut dot_product = 0.0;
383 for (feature, &count1) in &features1 {
384 if let Some(&count2) = features2.get(feature) {
385 dot_product += (count1 * count2) as f64;
386 }
387 }
388
389 if self.normalize {
390 let norm1 = features1
391 .values()
392 .map(|&x| (x * x) as f64)
393 .sum::<f64>()
394 .sqrt();
395 let norm2 = features2
396 .values()
397 .map(|&x| (x * x) as f64)
398 .sum::<f64>()
399 .sqrt();
400 if norm1 > 0.0 && norm2 > 0.0 {
401 dot_product / (norm1 * norm2)
402 } else {
403 0.0
404 }
405 } else {
406 dot_product
407 }
408 }
409}
410
411#[derive(Debug, Clone)]
413pub struct FittedShortestPathKernel {
415 training_graphs: Vec<Graph>,
417 use_node_labels: bool,
419 normalize: bool,
421}
422
423impl Fit<Vec<Graph>, ()> for ShortestPathKernel {
424 type Fitted = FittedShortestPathKernel;
425
426 fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
427 Ok(FittedShortestPathKernel {
428 training_graphs: graphs.clone(),
429 use_node_labels: self.use_node_labels,
430 normalize: self.normalize,
431 })
432 }
433}
434
435impl Transform<Vec<Graph>, Array2<f64>> for FittedShortestPathKernel {
436 fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
437 let n_test = graphs.len();
438 let n_train = self.training_graphs.len();
439 let mut kernel_matrix = Array2::zeros((n_test, n_train));
440
441 let kernel = ShortestPathKernel {
442 use_node_labels: self.use_node_labels,
443 normalize: self.normalize,
444 };
445
446 for i in 0..n_test {
447 for j in 0..n_train {
448 kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
449 }
450 }
451
452 Ok(kernel_matrix)
453 }
454}
455
456#[derive(Debug, Clone)]
458pub struct WeisfeilerLehmanKernel {
460 iterations: usize,
462 use_node_labels: bool,
464 normalize: bool,
466}
467
468impl WeisfeilerLehmanKernel {
469 pub fn new(iterations: usize) -> Self {
470 Self {
471 iterations,
472 use_node_labels: false,
473 normalize: true,
474 }
475 }
476
477 pub fn use_node_labels(mut self, use_labels: bool) -> Self {
479 self.use_node_labels = use_labels;
480 self
481 }
482
483 pub fn normalize(mut self, normalize: bool) -> Self {
485 self.normalize = normalize;
486 self
487 }
488
489 fn wl_relabel(&self, graph: &Graph) -> Vec<HashMap<usize, String>> {
491 let mut labelings = Vec::new();
492 let nodes = graph.nodes();
493
494 let mut current_labels = HashMap::new();
496 for &node in &nodes {
497 let initial_label = if self.use_node_labels {
498 graph
499 .node_labels
500 .as_ref()
501 .and_then(|labels| labels.get(&node))
502 .cloned()
503 .unwrap_or_else(|| "default".to_string())
504 } else {
505 "1".to_string()
506 };
507 current_labels.insert(node, initial_label);
508 }
509 labelings.push(current_labels.clone());
510
511 for _iter in 0..self.iterations {
513 let mut new_labels = HashMap::new();
514
515 for &node in &nodes {
516 let mut neighbor_labels = Vec::new();
517 for &neighbor in &graph.neighbors(node) {
518 if let Some(label) = current_labels.get(&neighbor) {
519 neighbor_labels.push(label.clone());
520 }
521 }
522 neighbor_labels.sort();
523
524 let current_label = current_labels.get(&node).cloned().unwrap_or_default();
525 let new_label = format!("{}:{}", current_label, neighbor_labels.join(","));
526 new_labels.insert(node, new_label);
527 }
528
529 labelings.push(new_labels.clone());
530 current_labels = new_labels;
531 }
532
533 labelings
534 }
535
536 fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
538 let labelings = self.wl_relabel(graph);
539 let mut features = HashMap::new();
540
541 for labeling in labelings {
542 for (_, label) in labeling {
543 *features.entry(label).or_insert(0) += 1;
544 }
545 }
546
547 features
548 }
549
550 fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
552 let features1 = self.extract_features(g1);
553 let features2 = self.extract_features(g2);
554
555 let mut dot_product = 0.0;
556 for (feature, &count1) in &features1 {
557 if let Some(&count2) = features2.get(feature) {
558 dot_product += (count1 * count2) as f64;
559 }
560 }
561
562 if self.normalize {
563 let norm1 = features1
564 .values()
565 .map(|&x| (x * x) as f64)
566 .sum::<f64>()
567 .sqrt();
568 let norm2 = features2
569 .values()
570 .map(|&x| (x * x) as f64)
571 .sum::<f64>()
572 .sqrt();
573 if norm1 > 0.0 && norm2 > 0.0 {
574 dot_product / (norm1 * norm2)
575 } else {
576 0.0
577 }
578 } else {
579 dot_product
580 }
581 }
582}
583
584#[derive(Debug, Clone)]
586pub struct FittedWeisfeilerLehmanKernel {
588 training_graphs: Vec<Graph>,
590 iterations: usize,
592 use_node_labels: bool,
594 normalize: bool,
596}
597
598impl Fit<Vec<Graph>, ()> for WeisfeilerLehmanKernel {
599 type Fitted = FittedWeisfeilerLehmanKernel;
600 fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
601 Ok(FittedWeisfeilerLehmanKernel {
602 training_graphs: graphs.clone(),
603 iterations: self.iterations,
604 use_node_labels: self.use_node_labels,
605 normalize: self.normalize,
606 })
607 }
608}
609
610impl Transform<Vec<Graph>, Array2<f64>> for FittedWeisfeilerLehmanKernel {
611 fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
612 let n_test = graphs.len();
613 let n_train = self.training_graphs.len();
614 let mut kernel_matrix = Array2::zeros((n_test, n_train));
615
616 let kernel = WeisfeilerLehmanKernel {
617 iterations: self.iterations,
618 use_node_labels: self.use_node_labels,
619 normalize: self.normalize,
620 };
621
622 for i in 0..n_test {
623 for j in 0..n_train {
624 kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
625 }
626 }
627
628 Ok(kernel_matrix)
629 }
630}
631
632#[derive(Debug, Clone)]
634pub struct SubgraphKernel {
636 max_size: usize,
638 connected_only: bool,
640 normalize: bool,
642}
643
644impl SubgraphKernel {
645 pub fn new(max_size: usize) -> Self {
646 Self {
647 max_size,
648 connected_only: true,
649 normalize: true,
650 }
651 }
652
653 pub fn connected_only(mut self, connected: bool) -> Self {
655 self.connected_only = connected;
656 self
657 }
658
659 pub fn normalize(mut self, normalize: bool) -> Self {
661 self.normalize = normalize;
662 self
663 }
664
665 fn find_connected_subgraphs(&self, graph: &Graph, size: usize) -> Vec<Vec<usize>> {
667 if size == 0 {
668 return vec![];
669 }
670
671 let mut subgraphs = Vec::new();
672 let nodes = graph.nodes();
673
674 let combinations = self.combinations(&nodes, size);
676
677 for combination in combinations {
678 if self.is_connected_subgraph(graph, &combination) {
679 subgraphs.push(combination);
680 }
681 }
682
683 subgraphs
684 }
685
686 fn is_connected_subgraph(&self, graph: &Graph, nodes: &[usize]) -> bool {
688 if nodes.len() <= 1 {
689 return true;
690 }
691
692 let node_set: HashSet<_> = nodes.iter().collect();
693 let mut visited = HashSet::new();
694 let mut queue = VecDeque::new();
695
696 queue.push_back(nodes[0]);
698 visited.insert(nodes[0]);
699
700 while let Some(current) = queue.pop_front() {
701 for &neighbor in &graph.neighbors(current) {
702 if node_set.contains(&neighbor) && !visited.contains(&neighbor) {
703 visited.insert(neighbor);
704 queue.push_back(neighbor);
705 }
706 }
707 }
708
709 visited.len() == nodes.len()
710 }
711
712 fn combinations(&self, items: &[usize], k: usize) -> Vec<Vec<usize>> {
714 if k == 0 {
715 return vec![vec![]];
716 }
717 if k > items.len() {
718 return vec![];
719 }
720 if k == items.len() {
721 return vec![items.to_vec()];
722 }
723
724 let mut result = Vec::new();
725
726 let with_first = self.combinations(&items[1..], k - 1);
728 for mut combo in with_first {
729 combo.insert(0, items[0]);
730 result.push(combo);
731 }
732
733 let without_first = self.combinations(&items[1..], k);
735 result.extend(without_first);
736
737 result
738 }
739
740 fn subgraph_to_string(&self, graph: &Graph, nodes: &[usize]) -> String {
742 let mut edges = Vec::new();
743 let node_set: HashSet<_> = nodes.iter().collect();
744
745 for &node in nodes {
746 for &neighbor in &graph.neighbors(node) {
747 if node_set.contains(&neighbor) && node < neighbor {
748 edges.push((node, neighbor));
749 }
750 }
751 }
752
753 edges.sort();
754 format!("nodes:{},edges:{:?}", nodes.len(), edges)
755 }
756
757 fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
759 let mut features = HashMap::new();
760
761 for size in 1..=self.max_size {
762 let subgraphs = if self.connected_only {
763 self.find_connected_subgraphs(graph, size)
764 } else {
765 self.find_connected_subgraphs(graph, size)
767 };
768
769 for subgraph in subgraphs {
770 let feature = self.subgraph_to_string(graph, &subgraph);
771 *features.entry(feature).or_insert(0) += 1;
772 }
773 }
774
775 features
776 }
777
778 fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
780 let features1 = self.extract_features(g1);
781 let features2 = self.extract_features(g2);
782
783 let mut dot_product = 0.0;
784 for (feature, &count1) in &features1 {
785 if let Some(&count2) = features2.get(feature) {
786 dot_product += (count1 * count2) as f64;
787 }
788 }
789
790 if self.normalize {
791 let norm1 = features1
792 .values()
793 .map(|&x| (x * x) as f64)
794 .sum::<f64>()
795 .sqrt();
796 let norm2 = features2
797 .values()
798 .map(|&x| (x * x) as f64)
799 .sum::<f64>()
800 .sqrt();
801 if norm1 > 0.0 && norm2 > 0.0 {
802 dot_product / (norm1 * norm2)
803 } else {
804 0.0
805 }
806 } else {
807 dot_product
808 }
809 }
810}
811
812#[derive(Debug, Clone)]
814pub struct FittedSubgraphKernel {
816 training_graphs: Vec<Graph>,
818 max_size: usize,
820 connected_only: bool,
822 normalize: bool,
824}
825
826impl Fit<Vec<Graph>, ()> for SubgraphKernel {
827 type Fitted = FittedSubgraphKernel;
828
829 fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
830 Ok(FittedSubgraphKernel {
831 training_graphs: graphs.clone(),
832 max_size: self.max_size,
833 connected_only: self.connected_only,
834 normalize: self.normalize,
835 })
836 }
837}
838
839impl Transform<Vec<Graph>, Array2<f64>> for FittedSubgraphKernel {
840 fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
841 let n_test = graphs.len();
842 let n_train = self.training_graphs.len();
843 let mut kernel_matrix = Array2::zeros((n_test, n_train));
844
845 let kernel = SubgraphKernel {
846 max_size: self.max_size,
847 connected_only: self.connected_only,
848 normalize: self.normalize,
849 };
850
851 for i in 0..n_test {
852 for j in 0..n_train {
853 kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
854 }
855 }
856
857 Ok(kernel_matrix)
858 }
859}
860
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds an undirected test graph with `num_nodes` nodes and the listed edges.
    fn create_test_graph(edges: Vec<(usize, usize)>, num_nodes: usize) -> Graph {
        let mut graph = Graph::new(num_nodes);
        edges
            .into_iter()
            .for_each(|(from, to)| graph.add_edge(from, to));
        graph
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new(3);
        for &(from, to) in &[(0, 1), (1, 2)] {
            graph.add_edge(from, to);
        }

        assert_eq!(graph.neighbors(0), vec![1]);
        assert_eq!(graph.neighbors(1), vec![0, 2]);
        assert_eq!(graph.neighbors(2), vec![1]);
        assert_eq!(graph.nodes(), vec![0, 1, 2]);
    }

    #[test]
    fn test_random_walk_kernel() {
        let path = create_test_graph(vec![(0, 1), (1, 2)], 3);
        let triangle = create_test_graph(vec![(0, 1), (0, 2), (1, 2)], 3);
        let graphs = vec![path.clone(), path, triangle];

        let kernel_matrix = RandomWalkKernel::new(3, 0.1)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(kernel_matrix.shape(), &[3, 3]);
        // Two copies of one graph have equal self-similarity...
        assert_abs_diff_eq!(
            kernel_matrix[(0, 0)],
            kernel_matrix[(1, 1)],
            epsilon = 1e-10
        );
        // ...and the kernel is symmetric.
        assert_abs_diff_eq!(
            kernel_matrix[(0, 1)],
            kernel_matrix[(1, 0)],
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_shortest_path_kernel() {
        let graphs = vec![
            create_test_graph(vec![(0, 1), (1, 2)], 3),
            create_test_graph(vec![(0, 1), (1, 2), (0, 2)], 3),
        ];
        let fitted = ShortestPathKernel::new().fit(&graphs, &()).unwrap();
        let kernel_matrix = fitted.transform(&graphs).unwrap();

        assert_eq!(kernel_matrix.shape(), &[2, 2]);
        assert!(kernel_matrix
            .iter()
            .all(|&value| value.is_finite() && value >= 0.0));

        // Normalized self-similarity is exactly one.
        for index in 0..2 {
            assert_abs_diff_eq!(kernel_matrix[(index, index)], 1.0, epsilon = 1e-10);
        }
        assert_abs_diff_eq!(
            kernel_matrix[(0, 1)],
            kernel_matrix[(1, 0)],
            epsilon = 1e-10
        );

        let cross = kernel_matrix[(0, 1)];
        assert!((0.0..=1.0).contains(&cross));
    }

    #[test]
    fn test_weisfeiler_lehman_kernel() {
        let graphs = vec![
            create_test_graph(vec![(0, 1), (1, 2)], 3),
            create_test_graph(vec![(0, 1), (1, 2)], 3),
            create_test_graph(vec![(0, 1), (0, 2), (1, 2)], 3),
        ];
        let kernel_matrix = WeisfeilerLehmanKernel::new(2)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(kernel_matrix.shape(), &[3, 3]);
        assert!(kernel_matrix
            .iter()
            .all(|&value| value.is_finite() && (0.0..=1.0).contains(&value)));

        // Two identical paths are indistinguishable after normalization.
        assert_abs_diff_eq!(kernel_matrix[(0, 1)], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_subgraph_kernel() {
        let graphs = vec![
            create_test_graph(vec![(0, 1)], 2),
            create_test_graph(vec![(0, 1), (1, 2)], 3),
        ];
        let kernel_matrix = SubgraphKernel::new(2)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(kernel_matrix.shape(), &[2, 2]);
        assert!(kernel_matrix
            .iter()
            .all(|&value| value.is_finite() && value >= 0.0));
    }

    #[test]
    fn test_graph_with_labels() {
        let mut graph = create_test_graph(vec![(0, 1), (1, 2)], 3);
        let labels: HashMap<usize, String> = [(0, "A"), (1, "B"), (2, "A")]
            .iter()
            .map(|&(node, label)| (node, label.to_string()))
            .collect();
        graph.set_node_labels(labels);

        let graphs = vec![graph];
        let kernel_matrix = WeisfeilerLehmanKernel::new(1)
            .use_node_labels(true)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(kernel_matrix.shape(), &[1, 1]);
        assert_abs_diff_eq!(kernel_matrix[(0, 0)], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_shortest_path_computation() {
        let chain = create_test_graph(vec![(0, 1), (1, 2), (2, 3)], 4);
        let distances = ShortestPathKernel::new().all_pairs_shortest_paths(&chain);

        assert_eq!(distances[&(0, 3)], 3);
        assert_eq!(distances[&(0, 1)], 1);
        assert_eq!(distances[&(1, 3)], 2);
    }

    #[test]
    fn test_subgraph_connectivity() {
        let kernel = SubgraphKernel::new(3);
        let two_components = create_test_graph(vec![(0, 1), (2, 3)], 4);

        assert!(!kernel.is_connected_subgraph(&two_components, &[0, 1, 2]));
        assert!(kernel.is_connected_subgraph(&two_components, &[0, 1]));
        assert!(kernel.is_connected_subgraph(&two_components, &[2, 3]));
    }
}