1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9use tensorlogic_ir::TLExpr;
10
11use crate::error::{KernelError, Result};
12use crate::types::Kernel;
13
/// A labelled, directed multigraph whose edges carry a type string.
///
/// Nodes are identified by their index in `0..n_nodes`, and `node_labels[i]` is the
/// label of node `i`. Parallel edges are allowed; nothing deduplicates `edges`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Graph {
    /// Number of nodes in the graph.
    pub n_nodes: usize,
    /// Directed edges stored as `(source, target, edge_type)` triples.
    pub edges: Vec<(usize, usize, String)>,
    /// One label per node, indexed by node id.
    pub node_labels: Vec<String>,
}
24
25impl Graph {
26 pub fn new(n_nodes: usize) -> Self {
28 Self {
29 n_nodes,
30 edges: Vec::new(),
31 node_labels: vec!["node".to_string(); n_nodes],
32 }
33 }
34
35 pub fn add_edge(&mut self, from: usize, to: usize, edge_type: String) {
37 if from < self.n_nodes && to < self.n_nodes {
38 self.edges.push((from, to, edge_type));
39 }
40 }
41
42 pub fn set_node_label(&mut self, node: usize, label: String) {
44 if node < self.n_nodes {
45 self.node_labels[node] = label;
46 }
47 }
48
49 pub fn adjacency_list(&self) -> Vec<Vec<usize>> {
51 let mut adj = vec![Vec::new(); self.n_nodes];
52 for &(from, to, _) in &self.edges {
53 adj[from].push(to);
54 }
55 adj
56 }
57
58 pub fn neighbors(&self, node: usize) -> Vec<usize> {
60 self.edges
61 .iter()
62 .filter(|(from, _, _)| *from == node)
63 .map(|(_, to, _)| *to)
64 .collect()
65 }
66
67 pub fn from_tlexpr(expr: &TLExpr) -> Self {
69 let mut graph = Graph::new(0);
70 let mut node_id = 0;
71 Self::build_graph_recursive(expr, &mut graph, &mut node_id, None);
72 graph
73 }
74
75 fn build_graph_recursive(
76 expr: &TLExpr,
77 graph: &mut Graph,
78 node_id: &mut usize,
79 parent: Option<usize>,
80 ) -> usize {
81 let current_id = *node_id;
82 *node_id += 1;
83 graph.n_nodes += 1;
84
85 let label = match expr {
87 TLExpr::Pred { name, .. } => format!("pred:{}", name),
88 TLExpr::And(_, _) => "and".to_string(),
89 TLExpr::Or(_, _) => "or".to_string(),
90 TLExpr::Not(_) => "not".to_string(),
91 TLExpr::Exists { domain, .. } => format!("exists:{}", domain),
92 TLExpr::ForAll { domain, .. } => format!("forall:{}", domain),
93 TLExpr::Imply(_, _) => "imply".to_string(),
94 _ => "unknown".to_string(),
95 };
96
97 graph.node_labels.push(label.clone());
98
99 if let Some(parent_id) = parent {
101 graph.add_edge(parent_id, current_id, "child".to_string());
102 }
103
104 match expr {
106 TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
107 Self::build_graph_recursive(left, graph, node_id, Some(current_id));
108 Self::build_graph_recursive(right, graph, node_id, Some(current_id));
109 }
110 TLExpr::Not(inner) => {
111 Self::build_graph_recursive(inner, graph, node_id, Some(current_id));
112 }
113 TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
114 Self::build_graph_recursive(body, graph, node_id, Some(current_id));
115 }
116 _ => {}
117 }
118
119 current_id
120 }
121}
122
123#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
125pub struct SubgraphMatchingConfig {
126 pub max_subgraph_size: usize,
128 pub normalize: bool,
130}
131
132impl SubgraphMatchingConfig {
133 pub fn new() -> Self {
135 Self {
136 max_subgraph_size: 3,
137 normalize: true,
138 }
139 }
140
141 pub fn with_max_size(mut self, size: usize) -> Self {
143 self.max_subgraph_size = size;
144 self
145 }
146}
147
148impl Default for SubgraphMatchingConfig {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154pub struct SubgraphMatchingKernel {
158 config: SubgraphMatchingConfig,
159}
160
161impl SubgraphMatchingKernel {
162 pub fn new(config: SubgraphMatchingConfig) -> Self {
164 Self { config }
165 }
166
167 fn count_subgraphs(&self, graph: &Graph, size: usize) -> HashMap<String, usize> {
169 let mut subgraph_counts = HashMap::new();
170
171 if size > graph.n_nodes {
172 return subgraph_counts;
173 }
174
175 for node in 0..graph.n_nodes {
178 let pattern = self.extract_pattern(graph, node, size);
179 *subgraph_counts.entry(pattern).or_insert(0) += 1;
180 }
181
182 subgraph_counts
183 }
184
185 fn extract_pattern(&self, graph: &Graph, start: usize, depth: usize) -> String {
187 let mut pattern_parts = vec![graph.node_labels[start].clone()];
188
189 if depth > 1 {
190 let neighbors = graph.neighbors(start);
191 let mut neighbor_labels: Vec<_> = neighbors
192 .iter()
193 .map(|&n| graph.node_labels[n].clone())
194 .collect();
195 neighbor_labels.sort();
196 pattern_parts.extend(neighbor_labels);
197 }
198
199 pattern_parts.join("|")
200 }
201
202 pub fn compute_graphs(&self, g1: &Graph, g2: &Graph) -> Result<f64> {
204 let mut total_similarity = 0.0;
205
206 for size in 1..=self.config.max_subgraph_size {
207 let counts1 = self.count_subgraphs(g1, size);
208 let counts2 = self.count_subgraphs(g2, size);
209
210 let mut intersection = 0.0;
212 for (pattern, count1) in &counts1 {
213 if let Some(count2) = counts2.get(pattern) {
214 intersection += (*count1).min(*count2) as f64;
215 }
216 }
217
218 total_similarity += intersection;
219 }
220
221 if self.config.normalize {
222 let max_size = (g1.n_nodes.max(g2.n_nodes)) as f64;
223 if max_size > 0.0 {
224 total_similarity /= max_size;
225 }
226 }
227
228 Ok(total_similarity)
229 }
230}
231
232impl Kernel for SubgraphMatchingKernel {
233 fn compute(&self, x: &[f64], _y: &[f64]) -> Result<f64> {
234 Ok(x.iter().sum::<f64>())
237 }
238
239 fn name(&self) -> &str {
240 "SubgraphMatching"
241 }
242}
243
244#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
246pub struct WalkKernelConfig {
247 pub max_walk_length: usize,
249 pub decay_factor: f64,
251 pub normalize: bool,
253}
254
255impl WalkKernelConfig {
256 pub fn new() -> Self {
258 Self {
259 max_walk_length: 4,
260 decay_factor: 0.8,
261 normalize: true,
262 }
263 }
264
265 pub fn with_max_length(mut self, length: usize) -> Self {
267 self.max_walk_length = length;
268 self
269 }
270
271 pub fn with_decay(mut self, decay: f64) -> Self {
273 self.decay_factor = decay;
274 self
275 }
276}
277
278impl Default for WalkKernelConfig {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
284pub struct RandomWalkKernel {
288 config: WalkKernelConfig,
289}
290
291impl RandomWalkKernel {
292 pub fn new(config: WalkKernelConfig) -> Result<Self> {
294 if config.decay_factor <= 0.0 || config.decay_factor > 1.0 {
295 return Err(KernelError::InvalidParameter {
296 parameter: "decay_factor".to_string(),
297 value: config.decay_factor.to_string(),
298 reason: "must be in (0, 1]".to_string(),
299 });
300 }
301
302 Ok(Self { config })
303 }
304
305 fn extract_walks(&self, graph: &Graph) -> HashMap<Vec<String>, usize> {
307 let mut walk_counts = HashMap::new();
308 let adj = graph.adjacency_list();
309
310 for start in 0..graph.n_nodes {
311 self.dfs_walks(
312 graph,
313 &adj,
314 start,
315 vec![graph.node_labels[start].clone()],
316 &mut walk_counts,
317 );
318 }
319
320 walk_counts
321 }
322
323 fn dfs_walks(
325 &self,
326 graph: &Graph,
327 adj: &[Vec<usize>],
328 current: usize,
329 path: Vec<String>,
330 walk_counts: &mut HashMap<Vec<String>, usize>,
331 ) {
332 if path.len() >= self.config.max_walk_length {
333 *walk_counts.entry(path).or_insert(0) += 1;
334 return;
335 }
336
337 *walk_counts.entry(path.clone()).or_insert(0) += 1;
339
340 for &neighbor in &adj[current] {
342 let mut new_path = path.clone();
343 new_path.push(graph.node_labels[neighbor].clone());
344 self.dfs_walks(graph, adj, neighbor, new_path, walk_counts);
345 }
346 }
347
348 pub fn compute_graphs(&self, g1: &Graph, g2: &Graph) -> Result<f64> {
350 let walks1 = self.extract_walks(g1);
351 let walks2 = self.extract_walks(g2);
352
353 let mut similarity = 0.0;
354
355 for (walk, count1) in &walks1 {
356 if let Some(count2) = walks2.get(walk) {
357 let walk_sim = (*count1).min(*count2) as f64;
358 let decay = self.config.decay_factor.powi(walk.len() as i32);
359 similarity += walk_sim * decay;
360 }
361 }
362
363 if self.config.normalize {
364 let total1: usize = walks1.values().sum();
365 let total2: usize = walks2.values().sum();
366 let normalizer = ((total1 * total2) as f64).sqrt();
367 if normalizer > 0.0 {
368 similarity /= normalizer;
369 }
370 }
371
372 Ok(similarity)
373 }
374}
375
376impl Kernel for RandomWalkKernel {
377 fn compute(&self, x: &[f64], _y: &[f64]) -> Result<f64> {
378 Ok(x.iter().sum::<f64>())
380 }
381
382 fn name(&self) -> &str {
383 "RandomWalk"
384 }
385}
386
387#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
389pub struct WeisfeilerLehmanConfig {
390 pub n_iterations: usize,
392 pub normalize: bool,
394}
395
396impl WeisfeilerLehmanConfig {
397 pub fn new() -> Self {
399 Self {
400 n_iterations: 3,
401 normalize: true,
402 }
403 }
404
405 pub fn with_iterations(mut self, iterations: usize) -> Self {
407 self.n_iterations = iterations;
408 self
409 }
410}
411
412impl Default for WeisfeilerLehmanConfig {
413 fn default() -> Self {
414 Self::new()
415 }
416}
417
418pub struct WeisfeilerLehmanKernel {
423 config: WeisfeilerLehmanConfig,
424}
425
426impl WeisfeilerLehmanKernel {
427 pub fn new(config: WeisfeilerLehmanConfig) -> Self {
429 Self { config }
430 }
431
432 fn wl_iteration(&self, graph: &Graph, labels: &[String]) -> Vec<String> {
434 let mut new_labels = Vec::with_capacity(graph.n_nodes);
435 let adj = graph.adjacency_list();
436
437 for node in 0..graph.n_nodes {
438 let mut neighbor_labels: Vec<String> =
440 adj[node].iter().map(|&n| labels[n].clone()).collect();
441
442 neighbor_labels.sort();
443
444 let mut new_label = labels[node].clone();
446 for neighbor_label in neighbor_labels {
447 new_label.push('_');
448 new_label.push_str(&neighbor_label);
449 }
450
451 new_labels.push(new_label);
452 }
453
454 new_labels
455 }
456
457 fn extract_label_histograms(&self, graph: &Graph) -> Vec<HashMap<String, usize>> {
459 let mut histograms = Vec::new();
460 let mut labels = graph.node_labels.clone();
461
462 for _ in 0..self.config.n_iterations {
463 let mut histogram = HashMap::new();
465 for label in &labels {
466 *histogram.entry(label.clone()).or_insert(0) += 1;
467 }
468 histograms.push(histogram);
469
470 labels = self.wl_iteration(graph, &labels);
472 }
473
474 histograms
475 }
476
477 pub fn compute_graphs(&self, g1: &Graph, g2: &Graph) -> Result<f64> {
479 let hists1 = self.extract_label_histograms(g1);
480 let hists2 = self.extract_label_histograms(g2);
481
482 let mut total_similarity = 0.0;
483
484 for (hist1, hist2) in hists1.iter().zip(hists2.iter()) {
485 let mut intersection = 0.0;
487 for (label, count1) in hist1 {
488 if let Some(count2) = hist2.get(label) {
489 intersection += (*count1).min(*count2) as f64;
490 }
491 }
492 total_similarity += intersection;
493 }
494
495 if self.config.normalize {
496 let size1 = g1.n_nodes as f64;
497 let size2 = g2.n_nodes as f64;
498 let normalizer = (size1 * size2).sqrt();
499 if normalizer > 0.0 {
500 total_similarity /= normalizer;
501 }
502 }
503
504 Ok(total_similarity)
505 }
506}
507
508impl Kernel for WeisfeilerLehmanKernel {
509 fn compute(&self, x: &[f64], _y: &[f64]) -> Result<f64> {
510 Ok(x.iter().sum::<f64>())
512 }
513
514 fn name(&self) -> &str {
515 "WeisfeilerLehman"
516 }
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph of `n` default-labelled nodes joined by `"edge"`-typed edges.
    fn graph_with_edges(n: usize, edges: &[(usize, usize)]) -> Graph {
        let mut graph = Graph::new(n);
        for &(from, to) in edges {
            graph.add_edge(from, to, "edge".to_string());
        }
        graph
    }

    /// Assigns `labels[i]` to node `i`.
    fn apply_labels(graph: &mut Graph, labels: &[&str]) {
        for (node, label) in labels.iter().enumerate() {
            graph.set_node_label(node, (*label).to_string());
        }
    }

    /// The labelled 4-node path A -> B -> B -> A used by the WL test.
    fn abba_path() -> Graph {
        let mut graph = graph_with_edges(4, &[(0, 1), (1, 2), (2, 3)]);
        apply_labels(&mut graph, &["A", "B", "B", "A"]);
        graph
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = graph_with_edges(3, &[(0, 1), (1, 2)]);
        apply_labels(&mut graph, &["A", "B", "C"]);

        assert_eq!(graph.n_nodes, 3);
        assert_eq!(graph.edges.len(), 2);
        assert_eq!(graph.node_labels[0], "A");
    }

    #[test]
    fn test_graph_from_tlexpr() {
        let left = TLExpr::pred("p1", vec![]);
        let right = TLExpr::pred("p2", vec![]);
        let graph = Graph::from_tlexpr(&TLExpr::and(left, right));

        assert!(graph.n_nodes > 0);
        assert!(!graph.node_labels.is_empty());
    }

    #[test]
    fn test_subgraph_matching_kernel() {
        let kernel = SubgraphMatchingKernel::new(SubgraphMatchingConfig::new().with_max_size(2));
        let path = graph_with_edges(3, &[(0, 1), (1, 2)]);
        let star = graph_with_edges(3, &[(0, 1), (0, 2)]);

        let sim = kernel.compute_graphs(&path, &star).unwrap();
        assert!(sim >= 0.0);
    }

    #[test]
    fn test_random_walk_kernel() {
        let kernel = RandomWalkKernel::new(WalkKernelConfig::new().with_max_length(3)).unwrap();
        let first = graph_with_edges(3, &[(0, 1), (1, 2)]);
        let second = graph_with_edges(3, &[(0, 1), (1, 2)]);

        let sim = kernel.compute_graphs(&first, &second).unwrap();
        assert!(sim > 0.0);
    }

    #[test]
    fn test_random_walk_kernel_invalid_decay() {
        let config = WalkKernelConfig::new().with_decay(1.5);
        assert!(RandomWalkKernel::new(config).is_err());
    }

    #[test]
    fn test_weisfeiler_lehman_kernel() {
        let kernel = WeisfeilerLehmanKernel::new(WeisfeilerLehmanConfig::new().with_iterations(2));

        let sim = kernel.compute_graphs(&abba_path(), &abba_path()).unwrap();
        assert!(sim > 0.0);
    }

    #[test]
    fn test_wl_self_similarity() {
        let kernel = WeisfeilerLehmanKernel::new(WeisfeilerLehmanConfig::new());
        let graph = graph_with_edges(3, &[(0, 1), (1, 2)]);

        let sim = kernel.compute_graphs(&graph, &graph).unwrap();
        assert!(sim > 0.0);
    }

    #[test]
    fn test_graph_neighbors() {
        let graph = graph_with_edges(3, &[(0, 1), (0, 2)]);

        let neighbors = graph.neighbors(0);
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.contains(&1));
        assert!(neighbors.contains(&2));
    }

    #[test]
    fn test_graph_adjacency_list() {
        let graph = graph_with_edges(3, &[(0, 1), (1, 2)]);

        let expected: Vec<Vec<usize>> = vec![vec![1], vec![2], Vec::new()];
        assert_eq!(graph.adjacency_list(), expected);
    }
}