sklears_kernel_approximation/
graph_kernels.rs

1//! Graph Kernel Approximations
2//!
3//! This module implements various graph kernel approximation methods for
4//! analyzing graph-structured data such as molecular graphs, social networks,
5//! and other relational data structures.
6//!
7//! # Key Features
8//!
9//! - **Random Walk Kernels**: Count common random walks between graphs
10//! - **Shortest Path Kernels**: Compare shortest path distributions
11//! - **Weisfeiler-Lehman Kernels**: Graph isomorphism-based kernels
12//! - **Subgraph Kernels**: Count common subgraph patterns
13//! - **Graphlet Kernels**: Count small connected subgraphs
14//! - **Graph Laplacian Kernels**: Use graph Laplacian spectrum
15//!
16//! # Mathematical Background
17//!
18//! Graph kernel between graphs G₁ and G₂:
19//! K(G₁, G₂) = Σ φ(G₁)[f] * φ(G₂)[f]
20//!
21//! Where φ(G)[f] is the feature map counting occurrences of feature f in graph G.
22//!
23//! # References
24//!
25//! - Vishwanathan, S. V. N., et al. (2010). Graph kernels
26//! - Weisfeiler, B., & Lehman, A. A. (1968). The reduction of a graph to canonical form
27
28use scirs2_core::ndarray::Array2;
29use sklears_core::error::Result;
30use sklears_core::traits::{Fit, Transform};
31use std::collections::{HashMap, HashSet, VecDeque};
32
/// Simple graph stored as an adjacency list.
///
/// Nodes are identified by the indices `0..num_nodes`. Undirected edges are
/// stored as two directed entries (one per direction); directed edges are
/// stored as a single entry. Optional string labels can be attached to nodes
/// and to edges.
#[derive(Debug, Clone)]
pub struct Graph {
    /// Adjacency list: maps a node to the nodes it has an edge to
    pub adjacency: HashMap<usize, Vec<usize>>,
    /// Node labels (optional)
    pub node_labels: Option<HashMap<usize, String>>,
    /// Edge labels (optional), keyed by `(from, to)`
    pub edge_labels: Option<HashMap<(usize, usize), String>>,
    /// Number of nodes; every edge endpoint is smaller than this value
    pub num_nodes: usize,
}

impl Graph {
    /// Create a graph with `num_nodes` nodes and no edges or labels.
    pub fn new(num_nodes: usize) -> Self {
        Self {
            adjacency: HashMap::new(),
            node_labels: None,
            edge_labels: None,
            num_nodes,
        }
    }

    /// Add an undirected edge, i.e. one directed entry in each direction.
    ///
    /// `num_nodes` grows if either endpoint lies outside the current range.
    pub fn add_edge(&mut self, from: usize, to: usize) {
        self.add_directed_edge(from, to);
        self.add_directed_edge(to, from);
    }

    /// Add a directed edge `from -> to`.
    ///
    /// `num_nodes` grows if either endpoint lies outside the current range,
    /// so that `nodes()` always covers every node referenced by an edge.
    pub fn add_directed_edge(&mut self, from: usize, to: usize) {
        self.adjacency.entry(from).or_default().push(to);
        let required = from.max(to).saturating_add(1);
        self.num_nodes = self.num_nodes.max(required);
    }

    /// Replace the node labels.
    pub fn set_node_labels(&mut self, labels: HashMap<usize, String>) {
        self.node_labels = Some(labels);
    }

    /// Replace the edge labels.
    pub fn set_edge_labels(&mut self, labels: HashMap<(usize, usize), String>) {
        self.edge_labels = Some(labels);
    }

    /// Get the neighbors of a node in insertion order (empty if the node has
    /// no outgoing edges).
    pub fn neighbors(&self, node: usize) -> Vec<usize> {
        self.adjacency.get(&node).cloned().unwrap_or_default()
    }

    /// Get all nodes, i.e. `0..num_nodes`.
    pub fn nodes(&self) -> Vec<usize> {
        (0..self.num_nodes).collect()
    }

    /// Get all edges in ascending order.
    ///
    /// An edge stored in both directions (as `add_edge` does) is reported once
    /// as `(smaller, larger)`. An edge stored in one direction only is always
    /// reported as `(from, to)`, even when `from > to`.
    pub fn edges(&self) -> Vec<(usize, usize)> {
        let mut edges: Vec<(usize, usize)> = self
            .adjacency
            .iter()
            .flat_map(|(&from, neighbors)| neighbors.iter().map(move |&to| (from, to)))
            // Skip the "downward" copy only when the "upward" copy exists.
            .filter(|&(from, to)| from <= to || !self.has_directed_edge(to, from))
            .collect();
        // Hash-map iteration order is arbitrary; sort for a deterministic result.
        edges.sort_unstable();
        edges
    }

    /// Whether the adjacency list contains the directed entry `from -> to`.
    fn has_directed_edge(&self, from: usize, to: usize) -> bool {
        self.adjacency
            .get(&from)
            .is_some_and(|neighbors| neighbors.contains(&to))
    }
}
103
104/// Random walk kernel for graphs
105#[derive(Debug, Clone)]
106/// RandomWalkKernel
107pub struct RandomWalkKernel {
108    /// Maximum walk length
109    max_length: usize,
110    /// Convergence parameter (lambda)
111    lambda: f64,
112    /// Whether to use node labels
113    use_node_labels: bool,
114    /// Whether to use edge labels
115    use_edge_labels: bool,
116}
117
118impl RandomWalkKernel {
119    pub fn new(max_length: usize, lambda: f64) -> Self {
120        Self {
121            max_length,
122            lambda,
123            use_node_labels: false,
124            use_edge_labels: false,
125        }
126    }
127
128    /// Enable node labels
129    pub fn use_node_labels(mut self, use_labels: bool) -> Self {
130        self.use_node_labels = use_labels;
131        self
132    }
133
134    /// Enable edge labels
135    pub fn use_edge_labels(mut self, use_labels: bool) -> Self {
136        self.use_edge_labels = use_labels;
137        self
138    }
139
140    /// Compute direct product graph for random walk kernel
141    fn product_graph(&self, g1: &Graph, g2: &Graph) -> Graph {
142        let mut product = Graph::new(g1.num_nodes * g2.num_nodes);
143
144        // Create nodes in product graph: (i, j) -> i * g2.num_nodes + j
145        for i in 0..g1.num_nodes {
146            for j in 0..g2.num_nodes {
147                let node_ij = i * g2.num_nodes + j;
148
149                // Check if nodes match (labels if available)
150                let nodes_match = if self.use_node_labels {
151                    if let (Some(labels1), Some(labels2)) = (&g1.node_labels, &g2.node_labels) {
152                        labels1.get(&i) == labels2.get(&j)
153                    } else {
154                        true
155                    }
156                } else {
157                    true
158                };
159
160                if !nodes_match {
161                    continue;
162                }
163
164                // Add edges in product graph
165                for &neighbor_i in &g1.neighbors(i) {
166                    for &neighbor_j in &g2.neighbors(j) {
167                        let neighbor_ij = neighbor_i * g2.num_nodes + neighbor_j;
168
169                        // Check if edges match (labels if available)
170                        let edges_match = if self.use_edge_labels {
171                            if let (Some(labels1), Some(labels2)) =
172                                (&g1.edge_labels, &g2.edge_labels)
173                            {
174                                labels1.get(&(i, neighbor_i)) == labels2.get(&(j, neighbor_j))
175                            } else {
176                                true
177                            }
178                        } else {
179                            true
180                        };
181
182                        if edges_match {
183                            product.add_directed_edge(node_ij, neighbor_ij);
184                        }
185                    }
186                }
187            }
188        }
189
190        product
191    }
192
193    /// Compute random walk kernel value between two graphs
194    fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
195        let product = self.product_graph(g1, g2);
196
197        // Count walks in product graph using matrix powers
198        let n = product.num_nodes;
199        if n == 0 {
200            return 0.0;
201        }
202
203        // Build adjacency matrix
204        let mut adj = Array2::zeros((n, n));
205        for (&from, neighbors) in &product.adjacency {
206            for &to in neighbors {
207                adj[(from, to)] = 1.0;
208            }
209        }
210
211        // Compute sum of matrix powers: I + λA + λ²A² + ... + λᵏAᵏ
212        let result = Array2::eye(n);
213        let mut current_power = Array2::eye(n);
214        let mut total = result.clone();
215
216        for k in 1..=self.max_length {
217            current_power = current_power.dot(&adj);
218            total = total + self.lambda.powi(k as i32) * &current_power;
219        }
220
221        // Sum all entries (total number of walks)
222        total.sum()
223    }
224}
225
226/// Fitted random walk kernel
227#[derive(Debug, Clone)]
228/// FittedRandomWalkKernel
229pub struct FittedRandomWalkKernel {
230    /// Training graphs
231    training_graphs: Vec<Graph>,
232    /// Kernel parameters
233    max_length: usize,
234    lambda: f64,
235    use_node_labels: bool,
236    use_edge_labels: bool,
237}
238
239impl Fit<Vec<Graph>, ()> for RandomWalkKernel {
240    type Fitted = FittedRandomWalkKernel;
241    fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
242        Ok(FittedRandomWalkKernel {
243            training_graphs: graphs.clone(),
244            max_length: self.max_length,
245            lambda: self.lambda,
246            use_node_labels: self.use_node_labels,
247            use_edge_labels: self.use_edge_labels,
248        })
249    }
250}
251
252impl Transform<Vec<Graph>, Array2<f64>> for FittedRandomWalkKernel {
253    fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
254        let n_test = graphs.len();
255        let n_train = self.training_graphs.len();
256        let mut kernel_matrix = Array2::zeros((n_test, n_train));
257
258        let kernel = RandomWalkKernel {
259            max_length: self.max_length,
260            lambda: self.lambda,
261            use_node_labels: self.use_node_labels,
262            use_edge_labels: self.use_edge_labels,
263        };
264
265        for i in 0..n_test {
266            for j in 0..n_train {
267                kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
268            }
269        }
270
271        Ok(kernel_matrix)
272    }
273}
274
275/// Shortest path kernel for graphs
276#[derive(Debug, Clone)]
277/// ShortestPathKernel
278pub struct ShortestPathKernel {
279    /// Whether to use node labels
280    use_node_labels: bool,
281    /// Whether to normalize by graph size
282    normalize: bool,
283}
284
285impl ShortestPathKernel {
286    pub fn new() -> Self {
287        Self {
288            use_node_labels: false,
289            normalize: true,
290        }
291    }
292
293    /// Enable node labels
294    pub fn use_node_labels(mut self, use_labels: bool) -> Self {
295        self.use_node_labels = use_labels;
296        self
297    }
298
299    /// Set normalization
300    pub fn normalize(mut self, normalize: bool) -> Self {
301        self.normalize = normalize;
302        self
303    }
304
305    /// Compute shortest paths between all pairs of nodes
306    fn all_pairs_shortest_paths(&self, graph: &Graph) -> HashMap<(usize, usize), usize> {
307        let mut distances = HashMap::new();
308        let nodes = graph.nodes();
309
310        // Initialize distances
311        for &i in &nodes {
312            for &j in &nodes {
313                if i == j {
314                    distances.insert((i, j), 0);
315                } else {
316                    distances.insert((i, j), usize::MAX);
317                }
318            }
319        }
320
321        // Set direct edge distances
322        for (&from, neighbors) in &graph.adjacency {
323            for &to in neighbors {
324                distances.insert((from, to), 1);
325            }
326        }
327
328        // Floyd-Warshall algorithm
329        for &k in &nodes {
330            for &i in &nodes {
331                for &j in &nodes {
332                    if let (Some(&dist_ik), Some(&dist_kj)) =
333                        (distances.get(&(i, k)), distances.get(&(k, j)))
334                    {
335                        if dist_ik != usize::MAX && dist_kj != usize::MAX {
336                            let new_dist = dist_ik + dist_kj;
337                            if let Some(current_dist) = distances.get_mut(&(i, j)) {
338                                if new_dist < *current_dist {
339                                    *current_dist = new_dist;
340                                }
341                            }
342                        }
343                    }
344                }
345            }
346        }
347
348        distances
349    }
350
351    /// Extract shortest path features
352    fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
353        let distances = self.all_pairs_shortest_paths(graph);
354        let mut features = HashMap::new();
355
356        for ((i, j), &dist) in &distances {
357            if dist != usize::MAX {
358                let feature = if self.use_node_labels {
359                    if let Some(ref labels) = graph.node_labels {
360                        let label_i = labels.get(i).cloned().unwrap_or_default();
361                        let label_j = labels.get(j).cloned().unwrap_or_default();
362                        format!("{}:{}:{}", label_i, label_j, dist)
363                    } else {
364                        format!("path:{}", dist)
365                    }
366                } else {
367                    format!("path:{}", dist)
368                };
369
370                *features.entry(feature).or_insert(0) += 1;
371            }
372        }
373
374        features
375    }
376
377    /// Compute kernel value between two graphs
378    fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
379        let features1 = self.extract_features(g1);
380        let features2 = self.extract_features(g2);
381
382        let mut dot_product = 0.0;
383        for (feature, &count1) in &features1 {
384            if let Some(&count2) = features2.get(feature) {
385                dot_product += (count1 * count2) as f64;
386            }
387        }
388
389        if self.normalize {
390            let norm1 = features1
391                .values()
392                .map(|&x| (x * x) as f64)
393                .sum::<f64>()
394                .sqrt();
395            let norm2 = features2
396                .values()
397                .map(|&x| (x * x) as f64)
398                .sum::<f64>()
399                .sqrt();
400            if norm1 > 0.0 && norm2 > 0.0 {
401                dot_product / (norm1 * norm2)
402            } else {
403                0.0
404            }
405        } else {
406            dot_product
407        }
408    }
409}
410
411/// Fitted shortest path kernel
412#[derive(Debug, Clone)]
413/// FittedShortestPathKernel
414pub struct FittedShortestPathKernel {
415    /// Training graphs
416    training_graphs: Vec<Graph>,
417    /// Use node labels
418    use_node_labels: bool,
419    /// Normalize flag
420    normalize: bool,
421}
422
423impl Fit<Vec<Graph>, ()> for ShortestPathKernel {
424    type Fitted = FittedShortestPathKernel;
425
426    fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
427        Ok(FittedShortestPathKernel {
428            training_graphs: graphs.clone(),
429            use_node_labels: self.use_node_labels,
430            normalize: self.normalize,
431        })
432    }
433}
434
435impl Transform<Vec<Graph>, Array2<f64>> for FittedShortestPathKernel {
436    fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
437        let n_test = graphs.len();
438        let n_train = self.training_graphs.len();
439        let mut kernel_matrix = Array2::zeros((n_test, n_train));
440
441        let kernel = ShortestPathKernel {
442            use_node_labels: self.use_node_labels,
443            normalize: self.normalize,
444        };
445
446        for i in 0..n_test {
447            for j in 0..n_train {
448                kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
449            }
450        }
451
452        Ok(kernel_matrix)
453    }
454}
455
456/// Weisfeiler-Lehman kernel for graphs
457#[derive(Debug, Clone)]
458/// WeisfeilerLehmanKernel
459pub struct WeisfeilerLehmanKernel {
460    /// Number of iterations
461    iterations: usize,
462    /// Whether to use original node labels
463    use_node_labels: bool,
464    /// Whether to normalize features
465    normalize: bool,
466}
467
468impl WeisfeilerLehmanKernel {
469    pub fn new(iterations: usize) -> Self {
470        Self {
471            iterations,
472            use_node_labels: false,
473            normalize: true,
474        }
475    }
476
477    /// Enable node labels
478    pub fn use_node_labels(mut self, use_labels: bool) -> Self {
479        self.use_node_labels = use_labels;
480        self
481    }
482
483    /// Set normalization
484    pub fn normalize(mut self, normalize: bool) -> Self {
485        self.normalize = normalize;
486        self
487    }
488
489    /// Perform Weisfeiler-Lehman relabeling
490    fn wl_relabel(&self, graph: &Graph) -> Vec<HashMap<usize, String>> {
491        let mut labelings = Vec::new();
492        let nodes = graph.nodes();
493
494        // Initial labeling
495        let mut current_labels = HashMap::new();
496        for &node in &nodes {
497            let initial_label = if self.use_node_labels {
498                graph
499                    .node_labels
500                    .as_ref()
501                    .and_then(|labels| labels.get(&node))
502                    .cloned()
503                    .unwrap_or_else(|| "default".to_string())
504            } else {
505                "1".to_string()
506            };
507            current_labels.insert(node, initial_label);
508        }
509        labelings.push(current_labels.clone());
510
511        // Iterative relabeling
512        for _iter in 0..self.iterations {
513            let mut new_labels = HashMap::new();
514
515            for &node in &nodes {
516                let mut neighbor_labels = Vec::new();
517                for &neighbor in &graph.neighbors(node) {
518                    if let Some(label) = current_labels.get(&neighbor) {
519                        neighbor_labels.push(label.clone());
520                    }
521                }
522                neighbor_labels.sort();
523
524                let current_label = current_labels.get(&node).cloned().unwrap_or_default();
525                let new_label = format!("{}:{}", current_label, neighbor_labels.join(","));
526                new_labels.insert(node, new_label);
527            }
528
529            labelings.push(new_labels.clone());
530            current_labels = new_labels;
531        }
532
533        labelings
534    }
535
536    /// Extract features from WL labelings
537    fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
538        let labelings = self.wl_relabel(graph);
539        let mut features = HashMap::new();
540
541        for labeling in labelings {
542            for (_, label) in labeling {
543                *features.entry(label).or_insert(0) += 1;
544            }
545        }
546
547        features
548    }
549
550    /// Compute kernel value between two graphs
551    fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
552        let features1 = self.extract_features(g1);
553        let features2 = self.extract_features(g2);
554
555        let mut dot_product = 0.0;
556        for (feature, &count1) in &features1 {
557            if let Some(&count2) = features2.get(feature) {
558                dot_product += (count1 * count2) as f64;
559            }
560        }
561
562        if self.normalize {
563            let norm1 = features1
564                .values()
565                .map(|&x| (x * x) as f64)
566                .sum::<f64>()
567                .sqrt();
568            let norm2 = features2
569                .values()
570                .map(|&x| (x * x) as f64)
571                .sum::<f64>()
572                .sqrt();
573            if norm1 > 0.0 && norm2 > 0.0 {
574                dot_product / (norm1 * norm2)
575            } else {
576                0.0
577            }
578        } else {
579            dot_product
580        }
581    }
582}
583
584/// Fitted Weisfeiler-Lehman kernel
585#[derive(Debug, Clone)]
586/// FittedWeisfeilerLehmanKernel
587pub struct FittedWeisfeilerLehmanKernel {
588    /// Training graphs
589    training_graphs: Vec<Graph>,
590    /// Number of iterations
591    iterations: usize,
592    /// Use node labels
593    use_node_labels: bool,
594    /// Normalize flag
595    normalize: bool,
596}
597
598impl Fit<Vec<Graph>, ()> for WeisfeilerLehmanKernel {
599    type Fitted = FittedWeisfeilerLehmanKernel;
600    fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
601        Ok(FittedWeisfeilerLehmanKernel {
602            training_graphs: graphs.clone(),
603            iterations: self.iterations,
604            use_node_labels: self.use_node_labels,
605            normalize: self.normalize,
606        })
607    }
608}
609
610impl Transform<Vec<Graph>, Array2<f64>> for FittedWeisfeilerLehmanKernel {
611    fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
612        let n_test = graphs.len();
613        let n_train = self.training_graphs.len();
614        let mut kernel_matrix = Array2::zeros((n_test, n_train));
615
616        let kernel = WeisfeilerLehmanKernel {
617            iterations: self.iterations,
618            use_node_labels: self.use_node_labels,
619            normalize: self.normalize,
620        };
621
622        for i in 0..n_test {
623            for j in 0..n_train {
624                kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
625            }
626        }
627
628        Ok(kernel_matrix)
629    }
630}
631
632/// Subgraph kernel that counts common subgraph patterns
633#[derive(Debug, Clone)]
634/// SubgraphKernel
635pub struct SubgraphKernel {
636    /// Maximum subgraph size
637    max_size: usize,
638    /// Whether to use connected subgraphs only
639    connected_only: bool,
640    /// Whether to normalize features
641    normalize: bool,
642}
643
644impl SubgraphKernel {
645    pub fn new(max_size: usize) -> Self {
646        Self {
647            max_size,
648            connected_only: true,
649            normalize: true,
650        }
651    }
652
653    /// Set connected subgraphs only
654    pub fn connected_only(mut self, connected: bool) -> Self {
655        self.connected_only = connected;
656        self
657    }
658
659    /// Set normalization
660    pub fn normalize(mut self, normalize: bool) -> Self {
661        self.normalize = normalize;
662        self
663    }
664
665    /// Find all connected subgraphs of given size
666    fn find_connected_subgraphs(&self, graph: &Graph, size: usize) -> Vec<Vec<usize>> {
667        if size == 0 {
668            return vec![];
669        }
670
671        let mut subgraphs = Vec::new();
672        let nodes = graph.nodes();
673
674        // Generate all combinations of nodes of given size
675        let combinations = self.combinations(&nodes, size);
676
677        for combination in combinations {
678            if self.is_connected_subgraph(graph, &combination) {
679                subgraphs.push(combination);
680            }
681        }
682
683        subgraphs
684    }
685
686    /// Check if a set of nodes forms a connected subgraph
687    fn is_connected_subgraph(&self, graph: &Graph, nodes: &[usize]) -> bool {
688        if nodes.len() <= 1 {
689            return true;
690        }
691
692        let node_set: HashSet<_> = nodes.iter().collect();
693        let mut visited = HashSet::new();
694        let mut queue = VecDeque::new();
695
696        // Start BFS from first node
697        queue.push_back(nodes[0]);
698        visited.insert(nodes[0]);
699
700        while let Some(current) = queue.pop_front() {
701            for &neighbor in &graph.neighbors(current) {
702                if node_set.contains(&neighbor) && !visited.contains(&neighbor) {
703                    visited.insert(neighbor);
704                    queue.push_back(neighbor);
705                }
706            }
707        }
708
709        visited.len() == nodes.len()
710    }
711
712    /// Generate all combinations of k elements from a vector
713    fn combinations(&self, items: &[usize], k: usize) -> Vec<Vec<usize>> {
714        if k == 0 {
715            return vec![vec![]];
716        }
717        if k > items.len() {
718            return vec![];
719        }
720        if k == items.len() {
721            return vec![items.to_vec()];
722        }
723
724        let mut result = Vec::new();
725
726        // Include first element
727        let with_first = self.combinations(&items[1..], k - 1);
728        for mut combo in with_first {
729            combo.insert(0, items[0]);
730            result.push(combo);
731        }
732
733        // Exclude first element
734        let without_first = self.combinations(&items[1..], k);
735        result.extend(without_first);
736
737        result
738    }
739
740    /// Convert subgraph to canonical string representation
741    fn subgraph_to_string(&self, graph: &Graph, nodes: &[usize]) -> String {
742        let mut edges = Vec::new();
743        let node_set: HashSet<_> = nodes.iter().collect();
744
745        for &node in nodes {
746            for &neighbor in &graph.neighbors(node) {
747                if node_set.contains(&neighbor) && node < neighbor {
748                    edges.push((node, neighbor));
749                }
750            }
751        }
752
753        edges.sort();
754        format!("nodes:{},edges:{:?}", nodes.len(), edges)
755    }
756
757    /// Extract subgraph features
758    fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
759        let mut features = HashMap::new();
760
761        for size in 1..=self.max_size {
762            let subgraphs = if self.connected_only {
763                self.find_connected_subgraphs(graph, size)
764            } else {
765                // For simplicity, just use connected subgraphs
766                self.find_connected_subgraphs(graph, size)
767            };
768
769            for subgraph in subgraphs {
770                let feature = self.subgraph_to_string(graph, &subgraph);
771                *features.entry(feature).or_insert(0) += 1;
772            }
773        }
774
775        features
776    }
777
778    /// Compute kernel value between two graphs
779    fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
780        let features1 = self.extract_features(g1);
781        let features2 = self.extract_features(g2);
782
783        let mut dot_product = 0.0;
784        for (feature, &count1) in &features1 {
785            if let Some(&count2) = features2.get(feature) {
786                dot_product += (count1 * count2) as f64;
787            }
788        }
789
790        if self.normalize {
791            let norm1 = features1
792                .values()
793                .map(|&x| (x * x) as f64)
794                .sum::<f64>()
795                .sqrt();
796            let norm2 = features2
797                .values()
798                .map(|&x| (x * x) as f64)
799                .sum::<f64>()
800                .sqrt();
801            if norm1 > 0.0 && norm2 > 0.0 {
802                dot_product / (norm1 * norm2)
803            } else {
804                0.0
805            }
806        } else {
807            dot_product
808        }
809    }
810}
811
812/// Fitted subgraph kernel
813#[derive(Debug, Clone)]
814/// FittedSubgraphKernel
815pub struct FittedSubgraphKernel {
816    /// Training graphs
817    training_graphs: Vec<Graph>,
818    /// Max subgraph size
819    max_size: usize,
820    /// Connected only flag
821    connected_only: bool,
822    /// Normalize flag
823    normalize: bool,
824}
825
826impl Fit<Vec<Graph>, ()> for SubgraphKernel {
827    type Fitted = FittedSubgraphKernel;
828
829    fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
830        Ok(FittedSubgraphKernel {
831            training_graphs: graphs.clone(),
832            max_size: self.max_size,
833            connected_only: self.connected_only,
834            normalize: self.normalize,
835        })
836    }
837}
838
839impl Transform<Vec<Graph>, Array2<f64>> for FittedSubgraphKernel {
840    fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
841        let n_test = graphs.len();
842        let n_train = self.training_graphs.len();
843        let mut kernel_matrix = Array2::zeros((n_test, n_train));
844
845        let kernel = SubgraphKernel {
846            max_size: self.max_size,
847            connected_only: self.connected_only,
848            normalize: self.normalize,
849        };
850
851        for i in 0..n_test {
852            for j in 0..n_train {
853                kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
854            }
855        }
856
857        Ok(kernel_matrix)
858    }
859}
860
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Build an undirected graph with `num_nodes` nodes from an edge list.
    fn create_test_graph(edges: Vec<(usize, usize)>, num_nodes: usize) -> Graph {
        edges
            .into_iter()
            .fold(Graph::new(num_nodes), |mut graph, (from, to)| {
                graph.add_edge(from, to);
                graph
            })
    }

    /// Path 0 - 1 - 2.
    fn path_graph() -> Graph {
        create_test_graph(vec![(0, 1), (1, 2)], 3)
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new(3);
        graph.add_edge(0, 1);
        graph.add_edge(1, 2);

        let expected_neighbors = [vec![1], vec![0, 2], vec![1]];
        for (node, expected) in expected_neighbors.iter().enumerate() {
            assert_eq!(&graph.neighbors(node), expected);
        }
        assert_eq!(graph.nodes(), vec![0, 1, 2]);
    }

    #[test]
    fn test_random_walk_kernel() {
        let triangle = create_test_graph(vec![(0, 1), (0, 2), (1, 2)], 3);
        let graphs = vec![path_graph(), path_graph(), triangle];

        let fitted = RandomWalkKernel::new(3, 0.1).fit(&graphs, &()).unwrap();
        let kernel_matrix = fitted.transform(&graphs).unwrap();

        assert_eq!(kernel_matrix.shape(), &[3, 3]);

        // Identical graphs must yield identical kernel values.
        let identical_pairs = [((0, 0), (1, 1)), ((0, 1), (1, 0))];
        for (left, right) in identical_pairs {
            assert_abs_diff_eq!(kernel_matrix[left], kernel_matrix[right], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_shortest_path_kernel() {
        let triangle = create_test_graph(vec![(0, 1), (1, 2), (0, 2)], 3);
        let graphs = vec![path_graph(), triangle];

        let fitted = ShortestPathKernel::new().fit(&graphs, &()).unwrap();
        let kernel_matrix = fitted.transform(&graphs).unwrap();

        assert_eq!(kernel_matrix.shape(), &[2, 2]);

        // Every kernel value is non-negative and finite.
        for &value in kernel_matrix.iter() {
            assert!(value >= 0.0 && value.is_finite());
        }

        // A normalized kernel has self-similarity 1.0.
        for diagonal in 0..2 {
            assert_abs_diff_eq!(kernel_matrix[(diagonal, diagonal)], 1.0, epsilon = 1e-10);
        }

        // The matrix is symmetric.
        assert_abs_diff_eq!(
            kernel_matrix[(0, 1)],
            kernel_matrix[(1, 0)],
            epsilon = 1e-10
        );

        // Normalized cosine similarity stays within [0, 1].
        let off_diagonal = kernel_matrix[(0, 1)];
        assert!(off_diagonal >= 0.0 && off_diagonal <= 1.0);
    }

    #[test]
    fn test_weisfeiler_lehman_kernel() {
        let triangle = create_test_graph(vec![(0, 1), (0, 2), (1, 2)], 3);
        let graphs = vec![path_graph(), path_graph(), triangle];

        let fitted = WeisfeilerLehmanKernel::new(2).fit(&graphs, &()).unwrap();
        let kernel_matrix = fitted.transform(&graphs).unwrap();

        assert_eq!(kernel_matrix.shape(), &[3, 3]);
        for &value in kernel_matrix.iter() {
            assert!(value >= 0.0 && value <= 1.0 && value.is_finite());
        }

        // Identical graphs have similarity 1.0.
        assert_abs_diff_eq!(kernel_matrix[(0, 1)], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_subgraph_kernel() {
        let single_edge = create_test_graph(vec![(0, 1)], 2);
        let graphs = vec![single_edge, path_graph()];

        let fitted = SubgraphKernel::new(2).fit(&graphs, &()).unwrap();
        let kernel_matrix = fitted.transform(&graphs).unwrap();

        assert_eq!(kernel_matrix.shape(), &[2, 2]);
        for &value in kernel_matrix.iter() {
            assert!(value >= 0.0 && value.is_finite());
        }
    }

    #[test]
    fn test_graph_with_labels() {
        let mut graph = path_graph();
        let labels: HashMap<usize, String> = [(0, "A"), (1, "B"), (2, "A")]
            .into_iter()
            .map(|(node, label)| (node, label.to_string()))
            .collect();
        graph.set_node_labels(labels);

        let graphs = vec![graph];
        let fitted = WeisfeilerLehmanKernel::new(1)
            .use_node_labels(true)
            .fit(&graphs, &())
            .unwrap();
        let kernel_matrix = fitted.transform(&graphs).unwrap();

        assert_eq!(kernel_matrix.shape(), &[1, 1]);
        assert_abs_diff_eq!(kernel_matrix[(0, 0)], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_shortest_path_computation() {
        let graph = create_test_graph(vec![(0, 1), (1, 2), (2, 3)], 4);
        let distances = ShortestPathKernel::new().all_pairs_shortest_paths(&graph);

        let expected = [((0, 3), 3), ((0, 1), 1), ((1, 3), 2)];
        for (pair, distance) in expected {
            assert_eq!(distances[&pair], distance);
        }
    }

    #[test]
    fn test_subgraph_connectivity() {
        let kernel = SubgraphKernel::new(3);
        // Two separate components: {0, 1} and {2, 3}.
        let graph = create_test_graph(vec![(0, 1), (2, 3)], 4);

        assert!(!kernel.is_connected_subgraph(&graph, &[0, 1, 2]));
        for component in [[0, 1], [2, 3]] {
            assert!(kernel.is_connected_subgraph(&graph, &component));
        }
    }
}