sklears_kernel_approximation/
graph_kernels.rs

1//! Graph Kernel Approximations
2//!
3//! This module implements various graph kernel approximation methods for
4//! analyzing graph-structured data such as molecular graphs, social networks,
5//! and other relational data structures.
6//!
7//! # Key Features
8//!
9//! - **Random Walk Kernels**: Count common random walks between graphs
10//! - **Shortest Path Kernels**: Compare shortest path distributions
11//! - **Weisfeiler-Lehman Kernels**: Graph isomorphism-based kernels
12//! - **Subgraph Kernels**: Count common subgraph patterns
13//! - **Graphlet Kernels**: Count small connected subgraphs
14//! - **Graph Laplacian Kernels**: Use graph Laplacian spectrum
15//!
16//! # Mathematical Background
17//!
18//! Graph kernel between graphs G₁ and G₂:
19//! K(G₁, G₂) = Σ φ(G₁)[f] * φ(G₂)[f]
20//!
21//! Where φ(G)[f] is the feature map counting occurrences of feature f in graph G.
22//!
23//! # References
24//!
25//! - Vishwanathan, S. V. N., et al. (2010). Graph kernels
26//! - Weisfeiler, B., & Lehman, A. A. (1968). The reduction of a graph to canonical form
27
28use scirs2_core::ndarray::Array2;
29use sklears_core::error::Result;
30use sklears_core::traits::{Fit, Transform};
31use std::collections::{HashMap, HashSet, VecDeque};
32
/// Simple graph representation backed by an adjacency list.
///
/// Nodes are identified by the indices `0..num_nodes`. Edge insertion does not
/// validate its endpoints against `num_nodes`; callers are expected to stay in
/// range.
#[derive(Debug, Clone)]
pub struct Graph {
    /// Adjacency list: maps a node to the targets of its outgoing edges
    pub adjacency: HashMap<usize, Vec<usize>>,
    /// Node labels (optional)
    pub node_labels: Option<HashMap<usize, String>>,
    /// Edge labels (optional), keyed by `(from, to)`
    pub edge_labels: Option<HashMap<(usize, usize), String>>,
    /// Number of nodes
    pub num_nodes: usize,
}

impl Graph {
    /// Create a graph with `num_nodes` nodes and no edges or labels.
    pub fn new(num_nodes: usize) -> Self {
        Self {
            adjacency: HashMap::new(),
            node_labels: None,
            edge_labels: None,
            num_nodes,
        }
    }

    /// Add an undirected edge between `from` and `to`.
    ///
    /// The edge is stored in both adjacency lists. A self-loop (`from == to`)
    /// is stored once, so a node never lists itself twice as its own neighbor.
    pub fn add_edge(&mut self, from: usize, to: usize) {
        self.adjacency.entry(from).or_default().push(to);
        if from != to {
            self.adjacency.entry(to).or_default().push(from);
        }
    }

    /// Add a directed edge `from -> to`.
    pub fn add_directed_edge(&mut self, from: usize, to: usize) {
        self.adjacency.entry(from).or_default().push(to);
    }

    /// Set node labels
    pub fn set_node_labels(&mut self, labels: HashMap<usize, String>) {
        self.node_labels = Some(labels);
    }

    /// Set edge labels
    pub fn set_edge_labels(&mut self, labels: HashMap<(usize, usize), String>) {
        self.edge_labels = Some(labels);
    }

    /// Get the neighbors of `node` (empty if the node has no outgoing edges).
    pub fn neighbors(&self, node: usize) -> Vec<usize> {
        self.adjacency.get(&node).cloned().unwrap_or_default()
    }

    /// Get all nodes, i.e. the indices `0..num_nodes`.
    pub fn nodes(&self) -> Vec<usize> {
        (0..self.num_nodes).collect()
    }

    /// Get all edges in ascending order.
    ///
    /// An edge stored in both directions (an undirected edge) is reported once
    /// as `(min, max)`. An edge stored in one direction only is always
    /// reported as `(from, to)`, even when `from > to`.
    pub fn edges(&self) -> Vec<(usize, usize)> {
        let mut edges = Vec::new();
        for (&from, neighbors) in &self.adjacency {
            for &to in neighbors {
                // `from <= to` picks one representative of a mirrored pair. A
                // one-way edge with `from > to` has no mirrored entry and
                // would otherwise be lost.
                if from <= to || !self.has_edge(to, from) {
                    edges.push((from, to));
                }
            }
        }
        // The adjacency map iterates in arbitrary order; sort for a stable result.
        edges.sort_unstable();
        edges
    }

    /// Whether the adjacency list contains the entry `from -> to`.
    fn has_edge(&self, from: usize, to: usize) -> bool {
        self.adjacency
            .get(&from)
            .map_or(false, |targets| targets.contains(&to))
    }
}
103
104/// Random walk kernel for graphs
105#[derive(Debug, Clone)]
106/// RandomWalkKernel
107pub struct RandomWalkKernel {
108    /// Maximum walk length
109    max_length: usize,
110    /// Convergence parameter (lambda)
111    lambda: f64,
112    /// Whether to use node labels
113    use_node_labels: bool,
114    /// Whether to use edge labels
115    use_edge_labels: bool,
116}
117
118impl RandomWalkKernel {
119    pub fn new(max_length: usize, lambda: f64) -> Self {
120        Self {
121            max_length,
122            lambda,
123            use_node_labels: false,
124            use_edge_labels: false,
125        }
126    }
127
128    /// Enable node labels
129    pub fn use_node_labels(mut self, use_labels: bool) -> Self {
130        self.use_node_labels = use_labels;
131        self
132    }
133
134    /// Enable edge labels
135    pub fn use_edge_labels(mut self, use_labels: bool) -> Self {
136        self.use_edge_labels = use_labels;
137        self
138    }
139
140    /// Compute direct product graph for random walk kernel
141    fn product_graph(&self, g1: &Graph, g2: &Graph) -> Graph {
142        let mut product = Graph::new(g1.num_nodes * g2.num_nodes);
143
144        // Create nodes in product graph: (i, j) -> i * g2.num_nodes + j
145        for i in 0..g1.num_nodes {
146            for j in 0..g2.num_nodes {
147                let node_ij = i * g2.num_nodes + j;
148
149                // Check if nodes match (labels if available)
150                let nodes_match = if self.use_node_labels {
151                    if let (Some(labels1), Some(labels2)) = (&g1.node_labels, &g2.node_labels) {
152                        labels1.get(&i) == labels2.get(&j)
153                    } else {
154                        true
155                    }
156                } else {
157                    true
158                };
159
160                if !nodes_match {
161                    continue;
162                }
163
164                // Add edges in product graph
165                for &neighbor_i in &g1.neighbors(i) {
166                    for &neighbor_j in &g2.neighbors(j) {
167                        let neighbor_ij = neighbor_i * g2.num_nodes + neighbor_j;
168
169                        // Check if edges match (labels if available)
170                        let edges_match = if self.use_edge_labels {
171                            if let (Some(labels1), Some(labels2)) =
172                                (&g1.edge_labels, &g2.edge_labels)
173                            {
174                                labels1.get(&(i, neighbor_i)) == labels2.get(&(j, neighbor_j))
175                            } else {
176                                true
177                            }
178                        } else {
179                            true
180                        };
181
182                        if edges_match {
183                            product.add_directed_edge(node_ij, neighbor_ij);
184                        }
185                    }
186                }
187            }
188        }
189
190        product
191    }
192
193    /// Compute random walk kernel value between two graphs
194    fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
195        let product = self.product_graph(g1, g2);
196
197        // Count walks in product graph using matrix powers
198        let n = product.num_nodes;
199        if n == 0 {
200            return 0.0;
201        }
202
203        // Build adjacency matrix
204        let mut adj = Array2::zeros((n, n));
205        for (&from, neighbors) in &product.adjacency {
206            for &to in neighbors {
207                adj[(from, to)] = 1.0;
208            }
209        }
210
211        // Compute sum of matrix powers: I + λA + λ²A² + ... + λᵏAᵏ
212        let result = Array2::eye(n);
213        let mut current_power = Array2::eye(n);
214        let mut total = result.clone();
215
216        for k in 1..=self.max_length {
217            current_power = current_power.dot(&adj);
218            total = total + self.lambda.powi(k as i32) * &current_power;
219        }
220
221        // Sum all entries (total number of walks)
222        total.sum()
223    }
224}
225
226/// Fitted random walk kernel
227#[derive(Debug, Clone)]
228/// FittedRandomWalkKernel
229pub struct FittedRandomWalkKernel {
230    /// Training graphs
231    training_graphs: Vec<Graph>,
232    /// Kernel parameters
233    max_length: usize,
234    lambda: f64,
235    use_node_labels: bool,
236    use_edge_labels: bool,
237}
238
239impl Fit<Vec<Graph>, ()> for RandomWalkKernel {
240    type Fitted = FittedRandomWalkKernel;
241    fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
242        Ok(FittedRandomWalkKernel {
243            training_graphs: graphs.clone(),
244            max_length: self.max_length,
245            lambda: self.lambda,
246            use_node_labels: self.use_node_labels,
247            use_edge_labels: self.use_edge_labels,
248        })
249    }
250}
251
252impl Transform<Vec<Graph>, Array2<f64>> for FittedRandomWalkKernel {
253    fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
254        let n_test = graphs.len();
255        let n_train = self.training_graphs.len();
256        let mut kernel_matrix = Array2::zeros((n_test, n_train));
257
258        let kernel = RandomWalkKernel {
259            max_length: self.max_length,
260            lambda: self.lambda,
261            use_node_labels: self.use_node_labels,
262            use_edge_labels: self.use_edge_labels,
263        };
264
265        for i in 0..n_test {
266            for j in 0..n_train {
267                kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
268            }
269        }
270
271        Ok(kernel_matrix)
272    }
273}
274
275/// Shortest path kernel for graphs
276#[derive(Debug, Clone)]
277/// ShortestPathKernel
278pub struct ShortestPathKernel {
279    /// Whether to use node labels
280    use_node_labels: bool,
281    /// Whether to normalize by graph size
282    normalize: bool,
283}
284
285impl Default for ShortestPathKernel {
286    fn default() -> Self {
287        Self::new()
288    }
289}
290
291impl ShortestPathKernel {
292    pub fn new() -> Self {
293        Self {
294            use_node_labels: false,
295            normalize: true,
296        }
297    }
298
299    /// Enable node labels
300    pub fn use_node_labels(mut self, use_labels: bool) -> Self {
301        self.use_node_labels = use_labels;
302        self
303    }
304
305    /// Set normalization
306    pub fn normalize(mut self, normalize: bool) -> Self {
307        self.normalize = normalize;
308        self
309    }
310
311    /// Compute shortest paths between all pairs of nodes
312    fn all_pairs_shortest_paths(&self, graph: &Graph) -> HashMap<(usize, usize), usize> {
313        let mut distances = HashMap::new();
314        let nodes = graph.nodes();
315
316        // Initialize distances
317        for &i in &nodes {
318            for &j in &nodes {
319                if i == j {
320                    distances.insert((i, j), 0);
321                } else {
322                    distances.insert((i, j), usize::MAX);
323                }
324            }
325        }
326
327        // Set direct edge distances
328        for (&from, neighbors) in &graph.adjacency {
329            for &to in neighbors {
330                distances.insert((from, to), 1);
331            }
332        }
333
334        // Floyd-Warshall algorithm
335        for &k in &nodes {
336            for &i in &nodes {
337                for &j in &nodes {
338                    if let (Some(&dist_ik), Some(&dist_kj)) =
339                        (distances.get(&(i, k)), distances.get(&(k, j)))
340                    {
341                        if dist_ik != usize::MAX && dist_kj != usize::MAX {
342                            let new_dist = dist_ik + dist_kj;
343                            if let Some(current_dist) = distances.get_mut(&(i, j)) {
344                                if new_dist < *current_dist {
345                                    *current_dist = new_dist;
346                                }
347                            }
348                        }
349                    }
350                }
351            }
352        }
353
354        distances
355    }
356
357    /// Extract shortest path features
358    fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
359        let distances = self.all_pairs_shortest_paths(graph);
360        let mut features = HashMap::new();
361
362        for ((i, j), &dist) in &distances {
363            if dist != usize::MAX {
364                let feature = if self.use_node_labels {
365                    if let Some(ref labels) = graph.node_labels {
366                        let label_i = labels.get(i).cloned().unwrap_or_default();
367                        let label_j = labels.get(j).cloned().unwrap_or_default();
368                        format!("{}:{}:{}", label_i, label_j, dist)
369                    } else {
370                        format!("path:{}", dist)
371                    }
372                } else {
373                    format!("path:{}", dist)
374                };
375
376                *features.entry(feature).or_insert(0) += 1;
377            }
378        }
379
380        features
381    }
382
383    /// Compute kernel value between two graphs
384    fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
385        let features1 = self.extract_features(g1);
386        let features2 = self.extract_features(g2);
387
388        let mut dot_product = 0.0;
389        for (feature, &count1) in &features1 {
390            if let Some(&count2) = features2.get(feature) {
391                dot_product += (count1 * count2) as f64;
392            }
393        }
394
395        if self.normalize {
396            let norm1 = features1
397                .values()
398                .map(|&x| (x * x) as f64)
399                .sum::<f64>()
400                .sqrt();
401            let norm2 = features2
402                .values()
403                .map(|&x| (x * x) as f64)
404                .sum::<f64>()
405                .sqrt();
406            if norm1 > 0.0 && norm2 > 0.0 {
407                dot_product / (norm1 * norm2)
408            } else {
409                0.0
410            }
411        } else {
412            dot_product
413        }
414    }
415}
416
417/// Fitted shortest path kernel
418#[derive(Debug, Clone)]
419/// FittedShortestPathKernel
420pub struct FittedShortestPathKernel {
421    /// Training graphs
422    training_graphs: Vec<Graph>,
423    /// Use node labels
424    use_node_labels: bool,
425    /// Normalize flag
426    normalize: bool,
427}
428
429impl Fit<Vec<Graph>, ()> for ShortestPathKernel {
430    type Fitted = FittedShortestPathKernel;
431
432    fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
433        Ok(FittedShortestPathKernel {
434            training_graphs: graphs.clone(),
435            use_node_labels: self.use_node_labels,
436            normalize: self.normalize,
437        })
438    }
439}
440
441impl Transform<Vec<Graph>, Array2<f64>> for FittedShortestPathKernel {
442    fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
443        let n_test = graphs.len();
444        let n_train = self.training_graphs.len();
445        let mut kernel_matrix = Array2::zeros((n_test, n_train));
446
447        let kernel = ShortestPathKernel {
448            use_node_labels: self.use_node_labels,
449            normalize: self.normalize,
450        };
451
452        for i in 0..n_test {
453            for j in 0..n_train {
454                kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
455            }
456        }
457
458        Ok(kernel_matrix)
459    }
460}
461
462/// Weisfeiler-Lehman kernel for graphs
463#[derive(Debug, Clone)]
464/// WeisfeilerLehmanKernel
465pub struct WeisfeilerLehmanKernel {
466    /// Number of iterations
467    iterations: usize,
468    /// Whether to use original node labels
469    use_node_labels: bool,
470    /// Whether to normalize features
471    normalize: bool,
472}
473
474impl WeisfeilerLehmanKernel {
475    pub fn new(iterations: usize) -> Self {
476        Self {
477            iterations,
478            use_node_labels: false,
479            normalize: true,
480        }
481    }
482
483    /// Enable node labels
484    pub fn use_node_labels(mut self, use_labels: bool) -> Self {
485        self.use_node_labels = use_labels;
486        self
487    }
488
489    /// Set normalization
490    pub fn normalize(mut self, normalize: bool) -> Self {
491        self.normalize = normalize;
492        self
493    }
494
495    /// Perform Weisfeiler-Lehman relabeling
496    fn wl_relabel(&self, graph: &Graph) -> Vec<HashMap<usize, String>> {
497        let mut labelings = Vec::new();
498        let nodes = graph.nodes();
499
500        // Initial labeling
501        let mut current_labels = HashMap::new();
502        for &node in &nodes {
503            let initial_label = if self.use_node_labels {
504                graph
505                    .node_labels
506                    .as_ref()
507                    .and_then(|labels| labels.get(&node))
508                    .cloned()
509                    .unwrap_or_else(|| "default".to_string())
510            } else {
511                "1".to_string()
512            };
513            current_labels.insert(node, initial_label);
514        }
515        labelings.push(current_labels.clone());
516
517        // Iterative relabeling
518        for _iter in 0..self.iterations {
519            let mut new_labels = HashMap::new();
520
521            for &node in &nodes {
522                let mut neighbor_labels = Vec::new();
523                for &neighbor in &graph.neighbors(node) {
524                    if let Some(label) = current_labels.get(&neighbor) {
525                        neighbor_labels.push(label.clone());
526                    }
527                }
528                neighbor_labels.sort();
529
530                let current_label = current_labels.get(&node).cloned().unwrap_or_default();
531                let new_label = format!("{}:{}", current_label, neighbor_labels.join(","));
532                new_labels.insert(node, new_label);
533            }
534
535            labelings.push(new_labels.clone());
536            current_labels = new_labels;
537        }
538
539        labelings
540    }
541
542    /// Extract features from WL labelings
543    fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
544        let labelings = self.wl_relabel(graph);
545        let mut features = HashMap::new();
546
547        for labeling in labelings {
548            for (_, label) in labeling {
549                *features.entry(label).or_insert(0) += 1;
550            }
551        }
552
553        features
554    }
555
556    /// Compute kernel value between two graphs
557    fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
558        let features1 = self.extract_features(g1);
559        let features2 = self.extract_features(g2);
560
561        let mut dot_product = 0.0;
562        for (feature, &count1) in &features1 {
563            if let Some(&count2) = features2.get(feature) {
564                dot_product += (count1 * count2) as f64;
565            }
566        }
567
568        if self.normalize {
569            let norm1 = features1
570                .values()
571                .map(|&x| (x * x) as f64)
572                .sum::<f64>()
573                .sqrt();
574            let norm2 = features2
575                .values()
576                .map(|&x| (x * x) as f64)
577                .sum::<f64>()
578                .sqrt();
579            if norm1 > 0.0 && norm2 > 0.0 {
580                dot_product / (norm1 * norm2)
581            } else {
582                0.0
583            }
584        } else {
585            dot_product
586        }
587    }
588}
589
590/// Fitted Weisfeiler-Lehman kernel
591#[derive(Debug, Clone)]
592/// FittedWeisfeilerLehmanKernel
593pub struct FittedWeisfeilerLehmanKernel {
594    /// Training graphs
595    training_graphs: Vec<Graph>,
596    /// Number of iterations
597    iterations: usize,
598    /// Use node labels
599    use_node_labels: bool,
600    /// Normalize flag
601    normalize: bool,
602}
603
604impl Fit<Vec<Graph>, ()> for WeisfeilerLehmanKernel {
605    type Fitted = FittedWeisfeilerLehmanKernel;
606    fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
607        Ok(FittedWeisfeilerLehmanKernel {
608            training_graphs: graphs.clone(),
609            iterations: self.iterations,
610            use_node_labels: self.use_node_labels,
611            normalize: self.normalize,
612        })
613    }
614}
615
616impl Transform<Vec<Graph>, Array2<f64>> for FittedWeisfeilerLehmanKernel {
617    fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
618        let n_test = graphs.len();
619        let n_train = self.training_graphs.len();
620        let mut kernel_matrix = Array2::zeros((n_test, n_train));
621
622        let kernel = WeisfeilerLehmanKernel {
623            iterations: self.iterations,
624            use_node_labels: self.use_node_labels,
625            normalize: self.normalize,
626        };
627
628        for i in 0..n_test {
629            for j in 0..n_train {
630                kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
631            }
632        }
633
634        Ok(kernel_matrix)
635    }
636}
637
638/// Subgraph kernel that counts common subgraph patterns
639#[derive(Debug, Clone)]
640/// SubgraphKernel
641pub struct SubgraphKernel {
642    /// Maximum subgraph size
643    max_size: usize,
644    /// Whether to use connected subgraphs only
645    connected_only: bool,
646    /// Whether to normalize features
647    normalize: bool,
648}
649
650impl SubgraphKernel {
651    pub fn new(max_size: usize) -> Self {
652        Self {
653            max_size,
654            connected_only: true,
655            normalize: true,
656        }
657    }
658
659    /// Set connected subgraphs only
660    pub fn connected_only(mut self, connected: bool) -> Self {
661        self.connected_only = connected;
662        self
663    }
664
665    /// Set normalization
666    pub fn normalize(mut self, normalize: bool) -> Self {
667        self.normalize = normalize;
668        self
669    }
670
671    /// Find all connected subgraphs of given size
672    fn find_connected_subgraphs(&self, graph: &Graph, size: usize) -> Vec<Vec<usize>> {
673        if size == 0 {
674            return vec![];
675        }
676
677        let mut subgraphs = Vec::new();
678        let nodes = graph.nodes();
679
680        // Generate all combinations of nodes of given size
681        let combinations = self.combinations(&nodes, size);
682
683        for combination in combinations {
684            if self.is_connected_subgraph(graph, &combination) {
685                subgraphs.push(combination);
686            }
687        }
688
689        subgraphs
690    }
691
692    /// Check if a set of nodes forms a connected subgraph
693    fn is_connected_subgraph(&self, graph: &Graph, nodes: &[usize]) -> bool {
694        if nodes.len() <= 1 {
695            return true;
696        }
697
698        let node_set: HashSet<_> = nodes.iter().collect();
699        let mut visited = HashSet::new();
700        let mut queue = VecDeque::new();
701
702        // Start BFS from first node
703        queue.push_back(nodes[0]);
704        visited.insert(nodes[0]);
705
706        while let Some(current) = queue.pop_front() {
707            for &neighbor in &graph.neighbors(current) {
708                if node_set.contains(&neighbor) && !visited.contains(&neighbor) {
709                    visited.insert(neighbor);
710                    queue.push_back(neighbor);
711                }
712            }
713        }
714
715        visited.len() == nodes.len()
716    }
717
718    /// Generate all combinations of k elements from a vector
719    fn combinations(&self, items: &[usize], k: usize) -> Vec<Vec<usize>> {
720        if k == 0 {
721            return vec![vec![]];
722        }
723        if k > items.len() {
724            return vec![];
725        }
726        if k == items.len() {
727            return vec![items.to_vec()];
728        }
729
730        let mut result = Vec::new();
731
732        // Include first element
733        let with_first = self.combinations(&items[1..], k - 1);
734        for mut combo in with_first {
735            combo.insert(0, items[0]);
736            result.push(combo);
737        }
738
739        // Exclude first element
740        let without_first = self.combinations(&items[1..], k);
741        result.extend(without_first);
742
743        result
744    }
745
746    /// Convert subgraph to canonical string representation
747    fn subgraph_to_string(&self, graph: &Graph, nodes: &[usize]) -> String {
748        let mut edges = Vec::new();
749        let node_set: HashSet<_> = nodes.iter().collect();
750
751        for &node in nodes {
752            for &neighbor in &graph.neighbors(node) {
753                if node_set.contains(&neighbor) && node < neighbor {
754                    edges.push((node, neighbor));
755                }
756            }
757        }
758
759        edges.sort();
760        format!("nodes:{},edges:{:?}", nodes.len(), edges)
761    }
762
763    /// Extract subgraph features
764    fn extract_features(&self, graph: &Graph) -> HashMap<String, usize> {
765        let mut features = HashMap::new();
766
767        for size in 1..=self.max_size {
768            let subgraphs = if self.connected_only {
769                self.find_connected_subgraphs(graph, size)
770            } else {
771                // For simplicity, just use connected subgraphs
772                self.find_connected_subgraphs(graph, size)
773            };
774
775            for subgraph in subgraphs {
776                let feature = self.subgraph_to_string(graph, &subgraph);
777                *features.entry(feature).or_insert(0) += 1;
778            }
779        }
780
781        features
782    }
783
784    /// Compute kernel value between two graphs
785    fn kernel_value(&self, g1: &Graph, g2: &Graph) -> f64 {
786        let features1 = self.extract_features(g1);
787        let features2 = self.extract_features(g2);
788
789        let mut dot_product = 0.0;
790        for (feature, &count1) in &features1 {
791            if let Some(&count2) = features2.get(feature) {
792                dot_product += (count1 * count2) as f64;
793            }
794        }
795
796        if self.normalize {
797            let norm1 = features1
798                .values()
799                .map(|&x| (x * x) as f64)
800                .sum::<f64>()
801                .sqrt();
802            let norm2 = features2
803                .values()
804                .map(|&x| (x * x) as f64)
805                .sum::<f64>()
806                .sqrt();
807            if norm1 > 0.0 && norm2 > 0.0 {
808                dot_product / (norm1 * norm2)
809            } else {
810                0.0
811            }
812        } else {
813            dot_product
814        }
815    }
816}
817
818/// Fitted subgraph kernel
819#[derive(Debug, Clone)]
820/// FittedSubgraphKernel
821pub struct FittedSubgraphKernel {
822    /// Training graphs
823    training_graphs: Vec<Graph>,
824    /// Max subgraph size
825    max_size: usize,
826    /// Connected only flag
827    connected_only: bool,
828    /// Normalize flag
829    normalize: bool,
830}
831
832impl Fit<Vec<Graph>, ()> for SubgraphKernel {
833    type Fitted = FittedSubgraphKernel;
834
835    fn fit(self, graphs: &Vec<Graph>, _y: &()) -> Result<Self::Fitted> {
836        Ok(FittedSubgraphKernel {
837            training_graphs: graphs.clone(),
838            max_size: self.max_size,
839            connected_only: self.connected_only,
840            normalize: self.normalize,
841        })
842    }
843}
844
845impl Transform<Vec<Graph>, Array2<f64>> for FittedSubgraphKernel {
846    fn transform(&self, graphs: &Vec<Graph>) -> Result<Array2<f64>> {
847        let n_test = graphs.len();
848        let n_train = self.training_graphs.len();
849        let mut kernel_matrix = Array2::zeros((n_test, n_train));
850
851        let kernel = SubgraphKernel {
852            max_size: self.max_size,
853            connected_only: self.connected_only,
854            normalize: self.normalize,
855        };
856
857        for i in 0..n_test {
858            for j in 0..n_train {
859                kernel_matrix[(i, j)] = kernel.kernel_value(&graphs[i], &self.training_graphs[j]);
860            }
861        }
862
863        Ok(kernel_matrix)
864    }
865}
866
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Build an undirected graph with `num_nodes` nodes from an edge list.
    fn create_test_graph(edges: Vec<(usize, usize)>, num_nodes: usize) -> Graph {
        let mut graph = Graph::new(num_nodes);
        edges
            .into_iter()
            .for_each(|(from, to)| graph.add_edge(from, to));
        graph
    }

    /// Adjacency lists and the node list reflect the inserted edges.
    #[test]
    fn test_graph_creation() {
        let path = create_test_graph(vec![(0, 1), (1, 2)], 3);

        assert_eq!(path.neighbors(0), vec![1]);
        assert_eq!(path.neighbors(1), vec![0, 2]);
        assert_eq!(path.neighbors(2), vec![1]);
        assert_eq!(path.nodes(), vec![0, 1, 2]);
    }

    /// Equal graphs yield equal random walk kernel values.
    #[test]
    fn test_random_walk_kernel() {
        let graphs = vec![
            create_test_graph(vec![(0, 1), (1, 2)], 3),
            create_test_graph(vec![(0, 1), (1, 2)], 3),
            create_test_graph(vec![(0, 1), (0, 2), (1, 2)], 3),
        ];

        let gram = RandomWalkKernel::new(3, 0.1)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(gram.shape(), &[3, 3]);

        // Identical graphs should have same kernel value
        assert_abs_diff_eq!(gram[(0, 0)], gram[(1, 1)], epsilon = 1e-10);
        assert_abs_diff_eq!(gram[(0, 1)], gram[(1, 0)], epsilon = 1e-10);
    }

    /// The normalized shortest path kernel behaves like a cosine similarity.
    #[test]
    fn test_shortest_path_kernel() {
        let graphs = vec![
            create_test_graph(vec![(0, 1), (1, 2)], 3),
            create_test_graph(vec![(0, 1), (1, 2), (0, 2)], 3),
        ];

        let gram = ShortestPathKernel::new()
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(gram.shape(), &[2, 2]);

        // Kernel values should be non-negative and finite
        for &value in gram.iter() {
            assert!(value >= 0.0 && value.is_finite());
        }

        // Self-similarity should be 1.0 for normalized kernels
        assert_abs_diff_eq!(gram[(0, 0)], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(gram[(1, 1)], 1.0, epsilon = 1e-10);

        // Matrix should be symmetric
        assert_abs_diff_eq!(gram[(0, 1)], gram[(1, 0)], epsilon = 1e-10);

        // Off-diagonal elements of a cosine similarity lie between 0 and 1
        assert!((0.0..=1.0).contains(&gram[(0, 1)]));
    }

    /// Normalized Weisfeiler-Lehman values stay in `[0, 1]`, and equal graphs
    /// reach 1.
    #[test]
    fn test_weisfeiler_lehman_kernel() {
        let graphs = vec![
            create_test_graph(vec![(0, 1), (1, 2)], 3),
            create_test_graph(vec![(0, 1), (1, 2)], 3),
            create_test_graph(vec![(0, 1), (0, 2), (1, 2)], 3),
        ];

        let gram = WeisfeilerLehmanKernel::new(2)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(gram.shape(), &[3, 3]);
        for &value in gram.iter() {
            assert!((0.0..=1.0).contains(&value) && value.is_finite());
        }

        // Identical graphs should have similarity 1.0
        assert_abs_diff_eq!(gram[(0, 1)], 1.0, epsilon = 1e-10);
    }

    /// The subgraph kernel produces finite, non-negative values.
    #[test]
    fn test_subgraph_kernel() {
        let graphs = vec![
            create_test_graph(vec![(0, 1)], 2),
            create_test_graph(vec![(0, 1), (1, 2)], 3),
        ];

        let gram = SubgraphKernel::new(2)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(gram.shape(), &[2, 2]);
        for &value in gram.iter() {
            assert!(value >= 0.0 && value.is_finite());
        }
    }

    /// A labeled graph is perfectly similar to itself.
    #[test]
    fn test_graph_with_labels() {
        let mut labeled = create_test_graph(vec![(0, 1), (1, 2)], 3);

        let labels: HashMap<usize, String> = [(0, "A"), (1, "B"), (2, "A")]
            .iter()
            .map(|&(node, label)| (node, label.to_string()))
            .collect();
        labeled.set_node_labels(labels);

        let graphs = vec![labeled];
        let gram = WeisfeilerLehmanKernel::new(1)
            .use_node_labels(true)
            .fit(&graphs, &())
            .unwrap()
            .transform(&graphs)
            .unwrap();

        assert_eq!(gram.shape(), &[1, 1]);
        assert_abs_diff_eq!(gram[(0, 0)], 1.0, epsilon = 1e-10);
    }

    /// Distances along a four-node path equal the number of edges in between.
    #[test]
    fn test_shortest_path_computation() {
        let path = create_test_graph(vec![(0, 1), (1, 2), (2, 3)], 4);

        let distances = ShortestPathKernel::new().all_pairs_shortest_paths(&path);

        assert_eq!(distances[&(0, 3)], 3);
        assert_eq!(distances[&(0, 1)], 1);
        assert_eq!(distances[&(1, 3)], 2);
    }

    /// Node sets spanning two components are not connected.
    #[test]
    fn test_subgraph_connectivity() {
        // Two separate edges: a disconnected graph
        let two_components = create_test_graph(vec![(0, 1), (2, 3)], 4);
        let kernel = SubgraphKernel::new(3);

        assert!(!kernel.is_connected_subgraph(&two_components, &[0, 1, 2]));
        assert!(kernel.is_connected_subgraph(&two_components, &[0, 1]));
        assert!(kernel.is_connected_subgraph(&two_components, &[2, 3]));
    }
}