Skip to main content

tensorlogic_sklears_kernels/
graph_kernel.rs

1//! Graph kernels for measuring similarity between structured data.
2//!
3//! These kernels operate on graph representations of logical expressions
4//! and measure structural similarity through various graph-theoretic properties.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9use tensorlogic_ir::TLExpr;
10
11use crate::error::{KernelError, Result};
12use crate::types::Kernel;
13
14/// Simple graph representation for kernel computation
15#[derive(Clone, Debug, PartialEq, Eq)]
16pub struct Graph {
17    /// Number of nodes
18    pub n_nodes: usize,
19    /// Edge list (from, to, edge_type)
20    pub edges: Vec<(usize, usize, String)>,
21    /// Node labels
22    pub node_labels: Vec<String>,
23}
24
25impl Graph {
26    /// Create a new graph
27    pub fn new(n_nodes: usize) -> Self {
28        Self {
29            n_nodes,
30            edges: Vec::new(),
31            node_labels: vec!["node".to_string(); n_nodes],
32        }
33    }
34
35    /// Add an edge to the graph
36    pub fn add_edge(&mut self, from: usize, to: usize, edge_type: String) {
37        if from < self.n_nodes && to < self.n_nodes {
38            self.edges.push((from, to, edge_type));
39        }
40    }
41
42    /// Set node label
43    pub fn set_node_label(&mut self, node: usize, label: String) {
44        if node < self.n_nodes {
45            self.node_labels[node] = label;
46        }
47    }
48
49    /// Get adjacency list representation
50    pub fn adjacency_list(&self) -> Vec<Vec<usize>> {
51        let mut adj = vec![Vec::new(); self.n_nodes];
52        for &(from, to, _) in &self.edges {
53            adj[from].push(to);
54        }
55        adj
56    }
57
58    /// Get neighbors of a node
59    pub fn neighbors(&self, node: usize) -> Vec<usize> {
60        self.edges
61            .iter()
62            .filter(|(from, _, _)| *from == node)
63            .map(|(_, to, _)| *to)
64            .collect()
65    }
66
67    /// Convert TLExpr to graph representation
68    pub fn from_tlexpr(expr: &TLExpr) -> Self {
69        let mut graph = Graph::new(0);
70        let mut node_id = 0;
71        Self::build_graph_recursive(expr, &mut graph, &mut node_id, None);
72        graph
73    }
74
75    fn build_graph_recursive(
76        expr: &TLExpr,
77        graph: &mut Graph,
78        node_id: &mut usize,
79        parent: Option<usize>,
80    ) -> usize {
81        let current_id = *node_id;
82        *node_id += 1;
83        graph.n_nodes += 1;
84
85        // Set node label based on expression type
86        let label = match expr {
87            TLExpr::Pred { name, .. } => format!("pred:{}", name),
88            TLExpr::And(_, _) => "and".to_string(),
89            TLExpr::Or(_, _) => "or".to_string(),
90            TLExpr::Not(_) => "not".to_string(),
91            TLExpr::Exists { domain, .. } => format!("exists:{}", domain),
92            TLExpr::ForAll { domain, .. } => format!("forall:{}", domain),
93            TLExpr::Imply(_, _) => "imply".to_string(),
94            _ => "unknown".to_string(),
95        };
96
97        graph.node_labels.push(label.clone());
98
99        // Add edge from parent if it exists
100        if let Some(parent_id) = parent {
101            graph.add_edge(parent_id, current_id, "child".to_string());
102        }
103
104        // Recursively process children
105        match expr {
106            TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
107                Self::build_graph_recursive(left, graph, node_id, Some(current_id));
108                Self::build_graph_recursive(right, graph, node_id, Some(current_id));
109            }
110            TLExpr::Not(inner) => {
111                Self::build_graph_recursive(inner, graph, node_id, Some(current_id));
112            }
113            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
114                Self::build_graph_recursive(body, graph, node_id, Some(current_id));
115            }
116            _ => {}
117        }
118
119        current_id
120    }
121}
122
123/// Subgraph matching kernel configuration
124#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
125pub struct SubgraphMatchingConfig {
126    /// Maximum subgraph size to consider
127    pub max_subgraph_size: usize,
128    /// Whether to normalize by graph sizes
129    pub normalize: bool,
130}
131
132impl SubgraphMatchingConfig {
133    /// Create default configuration
134    pub fn new() -> Self {
135        Self {
136            max_subgraph_size: 3,
137            normalize: true,
138        }
139    }
140
141    /// Set maximum subgraph size
142    pub fn with_max_size(mut self, size: usize) -> Self {
143        self.max_subgraph_size = size;
144        self
145    }
146}
147
148impl Default for SubgraphMatchingConfig {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154/// Subgraph matching kernel
155///
156/// Measures similarity by counting common subgraphs between two graphs.
157pub struct SubgraphMatchingKernel {
158    config: SubgraphMatchingConfig,
159}
160
161impl SubgraphMatchingKernel {
162    /// Create a new subgraph matching kernel
163    pub fn new(config: SubgraphMatchingConfig) -> Self {
164        Self { config }
165    }
166
167    /// Count subgraphs of given size in a graph
168    fn count_subgraphs(&self, graph: &Graph, size: usize) -> HashMap<String, usize> {
169        let mut subgraph_counts = HashMap::new();
170
171        if size > graph.n_nodes {
172            return subgraph_counts;
173        }
174
175        // For simplicity, count node label patterns
176        // More sophisticated: enumerate all connected subgraphs
177        for node in 0..graph.n_nodes {
178            let pattern = self.extract_pattern(graph, node, size);
179            *subgraph_counts.entry(pattern).or_insert(0) += 1;
180        }
181
182        subgraph_counts
183    }
184
185    /// Extract local pattern around a node
186    fn extract_pattern(&self, graph: &Graph, start: usize, depth: usize) -> String {
187        let mut pattern_parts = vec![graph.node_labels[start].clone()];
188
189        if depth > 1 {
190            let neighbors = graph.neighbors(start);
191            let mut neighbor_labels: Vec<_> = neighbors
192                .iter()
193                .map(|&n| graph.node_labels[n].clone())
194                .collect();
195            neighbor_labels.sort();
196            pattern_parts.extend(neighbor_labels);
197        }
198
199        pattern_parts.join("|")
200    }
201
202    /// Compute similarity between two graphs
203    pub fn compute_graphs(&self, g1: &Graph, g2: &Graph) -> Result<f64> {
204        let mut total_similarity = 0.0;
205
206        for size in 1..=self.config.max_subgraph_size {
207            let counts1 = self.count_subgraphs(g1, size);
208            let counts2 = self.count_subgraphs(g2, size);
209
210            // Compute intersection
211            let mut intersection = 0.0;
212            for (pattern, count1) in &counts1 {
213                if let Some(count2) = counts2.get(pattern) {
214                    intersection += (*count1).min(*count2) as f64;
215                }
216            }
217
218            total_similarity += intersection;
219        }
220
221        if self.config.normalize {
222            let max_size = (g1.n_nodes.max(g2.n_nodes)) as f64;
223            if max_size > 0.0 {
224                total_similarity /= max_size;
225            }
226        }
227
228        Ok(total_similarity)
229    }
230}
231
232impl Kernel for SubgraphMatchingKernel {
233    fn compute(&self, x: &[f64], _y: &[f64]) -> Result<f64> {
234        // For basic kernel trait compatibility, return a placeholder
235        // Real usage should use compute_graphs
236        Ok(x.iter().sum::<f64>())
237    }
238
239    fn name(&self) -> &str {
240        "SubgraphMatching"
241    }
242}
243
244/// Walk-based kernel configuration
245#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
246pub struct WalkKernelConfig {
247    /// Maximum walk length
248    pub max_walk_length: usize,
249    /// Decay factor for longer walks
250    pub decay_factor: f64,
251    /// Whether to normalize
252    pub normalize: bool,
253}
254
255impl WalkKernelConfig {
256    /// Create default configuration
257    pub fn new() -> Self {
258        Self {
259            max_walk_length: 4,
260            decay_factor: 0.8,
261            normalize: true,
262        }
263    }
264
265    /// Set maximum walk length
266    pub fn with_max_length(mut self, length: usize) -> Self {
267        self.max_walk_length = length;
268        self
269    }
270
271    /// Set decay factor
272    pub fn with_decay(mut self, decay: f64) -> Self {
273        self.decay_factor = decay;
274        self
275    }
276}
277
278impl Default for WalkKernelConfig {
279    fn default() -> Self {
280        Self::new()
281    }
282}
283
284/// Random walk kernel
285///
286/// Measures similarity by counting common random walks between graphs.
287pub struct RandomWalkKernel {
288    config: WalkKernelConfig,
289}
290
291impl RandomWalkKernel {
292    /// Create a new random walk kernel
293    pub fn new(config: WalkKernelConfig) -> Result<Self> {
294        if config.decay_factor <= 0.0 || config.decay_factor > 1.0 {
295            return Err(KernelError::InvalidParameter {
296                parameter: "decay_factor".to_string(),
297                value: config.decay_factor.to_string(),
298                reason: "must be in (0, 1]".to_string(),
299            });
300        }
301
302        Ok(Self { config })
303    }
304
305    /// Extract walks from a graph
306    fn extract_walks(&self, graph: &Graph) -> HashMap<Vec<String>, usize> {
307        let mut walk_counts = HashMap::new();
308        let adj = graph.adjacency_list();
309
310        for start in 0..graph.n_nodes {
311            self.dfs_walks(
312                graph,
313                &adj,
314                start,
315                vec![graph.node_labels[start].clone()],
316                &mut walk_counts,
317            );
318        }
319
320        walk_counts
321    }
322
323    /// DFS to enumerate walks
324    fn dfs_walks(
325        &self,
326        graph: &Graph,
327        adj: &[Vec<usize>],
328        current: usize,
329        path: Vec<String>,
330        walk_counts: &mut HashMap<Vec<String>, usize>,
331    ) {
332        if path.len() >= self.config.max_walk_length {
333            *walk_counts.entry(path).or_insert(0) += 1;
334            return;
335        }
336
337        // Add current path
338        *walk_counts.entry(path.clone()).or_insert(0) += 1;
339
340        // Continue walk
341        for &neighbor in &adj[current] {
342            let mut new_path = path.clone();
343            new_path.push(graph.node_labels[neighbor].clone());
344            self.dfs_walks(graph, adj, neighbor, new_path, walk_counts);
345        }
346    }
347
348    /// Compute similarity between two graphs
349    pub fn compute_graphs(&self, g1: &Graph, g2: &Graph) -> Result<f64> {
350        let walks1 = self.extract_walks(g1);
351        let walks2 = self.extract_walks(g2);
352
353        let mut similarity = 0.0;
354
355        for (walk, count1) in &walks1 {
356            if let Some(count2) = walks2.get(walk) {
357                let walk_sim = (*count1).min(*count2) as f64;
358                let decay = self.config.decay_factor.powi(walk.len() as i32);
359                similarity += walk_sim * decay;
360            }
361        }
362
363        if self.config.normalize {
364            let total1: usize = walks1.values().sum();
365            let total2: usize = walks2.values().sum();
366            let normalizer = ((total1 * total2) as f64).sqrt();
367            if normalizer > 0.0 {
368                similarity /= normalizer;
369            }
370        }
371
372        Ok(similarity)
373    }
374}
375
376impl Kernel for RandomWalkKernel {
377    fn compute(&self, x: &[f64], _y: &[f64]) -> Result<f64> {
378        // Placeholder for trait compatibility
379        Ok(x.iter().sum::<f64>())
380    }
381
382    fn name(&self) -> &str {
383        "RandomWalk"
384    }
385}
386
387/// Weisfeiler-Lehman kernel configuration
388#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
389pub struct WeisfeilerLehmanConfig {
390    /// Number of WL iterations
391    pub n_iterations: usize,
392    /// Whether to normalize
393    pub normalize: bool,
394}
395
396impl WeisfeilerLehmanConfig {
397    /// Create default configuration
398    pub fn new() -> Self {
399        Self {
400            n_iterations: 3,
401            normalize: true,
402        }
403    }
404
405    /// Set number of iterations
406    pub fn with_iterations(mut self, iterations: usize) -> Self {
407        self.n_iterations = iterations;
408        self
409    }
410}
411
412impl Default for WeisfeilerLehmanConfig {
413    fn default() -> Self {
414        Self::new()
415    }
416}
417
418/// Weisfeiler-Lehman (WL) kernel
419///
420/// Iteratively refines node labels based on neighborhood structure,
421/// then compares label histograms.
422pub struct WeisfeilerLehmanKernel {
423    config: WeisfeilerLehmanConfig,
424}
425
426impl WeisfeilerLehmanKernel {
427    /// Create a new WL kernel
428    pub fn new(config: WeisfeilerLehmanConfig) -> Self {
429        Self { config }
430    }
431
432    /// Perform one WL iteration
433    fn wl_iteration(&self, graph: &Graph, labels: &[String]) -> Vec<String> {
434        let mut new_labels = Vec::with_capacity(graph.n_nodes);
435        let adj = graph.adjacency_list();
436
437        for node in 0..graph.n_nodes {
438            // Collect neighbor labels
439            let mut neighbor_labels: Vec<String> =
440                adj[node].iter().map(|&n| labels[n].clone()).collect();
441
442            neighbor_labels.sort();
443
444            // Create new label by concatenating
445            let mut new_label = labels[node].clone();
446            for neighbor_label in neighbor_labels {
447                new_label.push('_');
448                new_label.push_str(&neighbor_label);
449            }
450
451            new_labels.push(new_label);
452        }
453
454        new_labels
455    }
456
457    /// Extract label histograms across all iterations
458    fn extract_label_histograms(&self, graph: &Graph) -> Vec<HashMap<String, usize>> {
459        let mut histograms = Vec::new();
460        let mut labels = graph.node_labels.clone();
461
462        for _ in 0..self.config.n_iterations {
463            // Count labels
464            let mut histogram = HashMap::new();
465            for label in &labels {
466                *histogram.entry(label.clone()).or_insert(0) += 1;
467            }
468            histograms.push(histogram);
469
470            // Update labels
471            labels = self.wl_iteration(graph, &labels);
472        }
473
474        histograms
475    }
476
477    /// Compute similarity between two graphs
478    pub fn compute_graphs(&self, g1: &Graph, g2: &Graph) -> Result<f64> {
479        let hists1 = self.extract_label_histograms(g1);
480        let hists2 = self.extract_label_histograms(g2);
481
482        let mut total_similarity = 0.0;
483
484        for (hist1, hist2) in hists1.iter().zip(hists2.iter()) {
485            // Compute histogram intersection
486            let mut intersection = 0.0;
487            for (label, count1) in hist1 {
488                if let Some(count2) = hist2.get(label) {
489                    intersection += (*count1).min(*count2) as f64;
490                }
491            }
492            total_similarity += intersection;
493        }
494
495        if self.config.normalize {
496            let size1 = g1.n_nodes as f64;
497            let size2 = g2.n_nodes as f64;
498            let normalizer = (size1 * size2).sqrt();
499            if normalizer > 0.0 {
500                total_similarity /= normalizer;
501            }
502        }
503
504        Ok(total_similarity)
505    }
506}
507
508impl Kernel for WeisfeilerLehmanKernel {
509    fn compute(&self, x: &[f64], _y: &[f64]) -> Result<f64> {
510        // Placeholder for trait compatibility
511        Ok(x.iter().sum::<f64>())
512    }
513
514    fn name(&self) -> &str {
515        "WeisfeilerLehman"
516    }
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n`-node graph (every label `"node"`) with the given `"edge"`-typed edges.
    fn graph_with_edges(n: usize, edges: &[(usize, usize)]) -> Graph {
        let mut g = Graph::new(n);
        for &(from, to) in edges {
            g.add_edge(from, to, "edge".to_string());
        }
        g
    }

    /// Assigns `labels[i]` to node `i` for every provided label.
    fn relabel(g: &mut Graph, labels: &[&str]) {
        for (idx, lbl) in labels.iter().enumerate() {
            g.set_node_label(idx, lbl.to_string());
        }
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = graph_with_edges(3, &[(0, 1), (1, 2)]);
        relabel(&mut graph, &["A", "B", "C"]);

        assert_eq!(graph.n_nodes, 3);
        assert_eq!(graph.edges.len(), 2);
        assert_eq!(graph.node_labels[0], "A");
    }

    #[test]
    fn test_graph_from_tlexpr() {
        let p1 = TLExpr::pred("p1", vec![]);
        let p2 = TLExpr::pred("p2", vec![]);
        let graph = Graph::from_tlexpr(&TLExpr::and(p1, p2));

        assert!(graph.n_nodes > 0);
        assert!(!graph.node_labels.is_empty());
    }

    #[test]
    fn test_subgraph_matching_kernel() {
        let kernel = SubgraphMatchingKernel::new(SubgraphMatchingConfig::new().with_max_size(2));
        let chain = graph_with_edges(3, &[(0, 1), (1, 2)]);
        let star = graph_with_edges(3, &[(0, 1), (0, 2)]);

        let sim = kernel.compute_graphs(&chain, &star).unwrap();
        assert!(sim >= 0.0);
    }

    #[test]
    fn test_random_walk_kernel() {
        let kernel = RandomWalkKernel::new(WalkKernelConfig::new().with_max_length(3)).unwrap();
        let g1 = graph_with_edges(3, &[(0, 1), (1, 2)]);
        let g2 = graph_with_edges(3, &[(0, 1), (1, 2)]);

        let sim = kernel.compute_graphs(&g1, &g2).unwrap();
        assert!(sim > 0.0);
    }

    #[test]
    fn test_random_walk_kernel_invalid_decay() {
        let result = RandomWalkKernel::new(WalkKernelConfig::new().with_decay(1.5));
        assert!(result.is_err());
    }

    #[test]
    fn test_weisfeiler_lehman_kernel() {
        let kernel =
            WeisfeilerLehmanKernel::new(WeisfeilerLehmanConfig::new().with_iterations(2));
        let build = || {
            let mut g = graph_with_edges(4, &[(0, 1), (1, 2), (2, 3)]);
            relabel(&mut g, &["A", "B", "B", "A"]);
            g
        };
        let (g1, g2) = (build(), build());

        let sim = kernel.compute_graphs(&g1, &g2).unwrap();
        assert!(sim > 0.0);
    }

    #[test]
    fn test_wl_self_similarity() {
        let kernel = WeisfeilerLehmanKernel::new(WeisfeilerLehmanConfig::new());
        let graph = graph_with_edges(3, &[(0, 1), (1, 2)]);

        let sim = kernel.compute_graphs(&graph, &graph).unwrap();
        assert!(sim > 0.0);
    }

    #[test]
    fn test_graph_neighbors() {
        let graph = graph_with_edges(3, &[(0, 1), (0, 2)]);

        let neighbors = graph.neighbors(0);
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.contains(&1));
        assert!(neighbors.contains(&2));
    }

    #[test]
    fn test_graph_adjacency_list() {
        let adj = graph_with_edges(3, &[(0, 1), (1, 2)]).adjacency_list();

        assert_eq!(adj.len(), 3);
        assert_eq!(adj[0], vec![1]);
        assert_eq!(adj[1], vec![2]);
        assert_eq!(adj[2], Vec::<usize>::new());
    }
}