1use crate::error::{KernelError, Result};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use tensorlogic_ir::TLExpr;
27
28#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
30pub struct TreeNode {
31 pub label: String,
33 pub children: Vec<TreeNode>,
35}
36
37impl TreeNode {
38 pub fn new(label: impl Into<String>) -> Self {
40 Self {
41 label: label.into(),
42 children: Vec::new(),
43 }
44 }
45
46 pub fn with_children(label: impl Into<String>, children: Vec<TreeNode>) -> Self {
48 Self {
49 label: label.into(),
50 children,
51 }
52 }
53
54 pub fn height(&self) -> usize {
56 if self.children.is_empty() {
57 1
58 } else {
59 1 + self.children.iter().map(|c| c.height()).max().unwrap_or(0)
60 }
61 }
62
63 pub fn num_nodes(&self) -> usize {
65 1 + self.children.iter().map(|c| c.num_nodes()).sum::<usize>()
66 }
67
68 pub fn is_leaf(&self) -> bool {
70 self.children.is_empty()
71 }
72
73 pub fn from_tlexpr(expr: &TLExpr) -> Self {
75 match expr {
76 TLExpr::Pred { name, .. } => TreeNode::new(format!("Pred({})", name)),
77 TLExpr::And(left, right) => TreeNode::with_children(
78 "And",
79 vec![TreeNode::from_tlexpr(left), TreeNode::from_tlexpr(right)],
80 ),
81 TLExpr::Or(left, right) => TreeNode::with_children(
82 "Or",
83 vec![TreeNode::from_tlexpr(left), TreeNode::from_tlexpr(right)],
84 ),
85 TLExpr::Not(expr) => TreeNode::with_children("Not", vec![TreeNode::from_tlexpr(expr)]),
86 TLExpr::Imply(left, right) => TreeNode::with_children(
87 "Imply",
88 vec![TreeNode::from_tlexpr(left), TreeNode::from_tlexpr(right)],
89 ),
90 TLExpr::Exists { var, domain, body } => TreeNode::with_children(
91 format!("Exists({}, {})", var, domain),
92 vec![TreeNode::from_tlexpr(body)],
93 ),
94 TLExpr::ForAll { var, domain, body } => TreeNode::with_children(
95 format!("ForAll({}, {})", var, domain),
96 vec![TreeNode::from_tlexpr(body)],
97 ),
98 _ => TreeNode::new("Expr"),
99 }
100 }
101
102 fn get_all_subtrees(&self) -> Vec<TreeNode> {
104 let mut subtrees = vec![self.clone()];
105 for child in &self.children {
106 subtrees.extend(child.get_all_subtrees());
107 }
108 subtrees
109 }
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct SubtreeKernelConfig {
115 pub normalize: bool,
117}
118
119impl SubtreeKernelConfig {
120 pub fn new() -> Self {
122 Self { normalize: true }
123 }
124
125 pub fn with_normalize(mut self, normalize: bool) -> Self {
127 self.normalize = normalize;
128 self
129 }
130}
131
132impl Default for SubtreeKernelConfig {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138pub struct SubtreeKernel {
151 config: SubtreeKernelConfig,
152}
153
154impl SubtreeKernel {
155 pub fn new(config: SubtreeKernelConfig) -> Self {
157 Self { config }
158 }
159
160 pub fn compute_trees(&self, tree1: &TreeNode, tree2: &TreeNode) -> Result<f64> {
162 let subtrees1 = tree1.get_all_subtrees();
163 let subtrees2 = tree2.get_all_subtrees();
164
165 let mut count = 0;
167 for st1 in &subtrees1 {
168 for st2 in &subtrees2 {
169 if st1 == st2 {
170 count += 1;
171 }
172 }
173 }
174
175 let similarity = count as f64;
176
177 if self.config.normalize {
178 let self_sim1 = subtrees1.len() as f64;
180 let self_sim2 = subtrees2.len() as f64;
181 let norm = (self_sim1 * self_sim2).sqrt();
182 if norm > 0.0 {
183 Ok(similarity / norm)
184 } else {
185 Ok(0.0)
186 }
187 } else {
188 Ok(similarity)
189 }
190 }
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct SubsetTreeKernelConfig {
196 pub normalize: bool,
198 pub decay: f64,
200}
201
202impl SubsetTreeKernelConfig {
203 pub fn new() -> Result<Self> {
205 Ok(Self {
206 normalize: true,
207 decay: 1.0,
208 })
209 }
210
211 pub fn with_normalize(mut self, normalize: bool) -> Self {
213 self.normalize = normalize;
214 self
215 }
216
217 pub fn with_decay(mut self, decay: f64) -> Result<Self> {
219 if !(0.0..=1.0).contains(&decay) {
220 return Err(KernelError::InvalidParameter {
221 parameter: "decay".to_string(),
222 value: decay.to_string(),
223 reason: "must be between 0.0 and 1.0".to_string(),
224 });
225 }
226 self.decay = decay;
227 Ok(self)
228 }
229}
230
231impl Default for SubsetTreeKernelConfig {
232 fn default() -> Self {
233 Self::new().unwrap()
234 }
235}
236
237pub struct SubsetTreeKernel {
242 config: SubsetTreeKernelConfig,
243}
244
245impl SubsetTreeKernel {
246 pub fn new(config: SubsetTreeKernelConfig) -> Self {
248 Self { config }
249 }
250
251 pub fn compute_trees(&self, tree1: &TreeNode, tree2: &TreeNode) -> Result<f64> {
253 let similarity = self.compute_recursive(tree1, tree2, &mut HashMap::new());
254
255 if self.config.normalize {
256 let self_sim1 = self.compute_recursive(tree1, tree1, &mut HashMap::new());
257 let self_sim2 = self.compute_recursive(tree2, tree2, &mut HashMap::new());
258 let norm = (self_sim1 * self_sim2).sqrt();
259 if norm > 0.0 {
260 Ok(similarity / norm)
261 } else {
262 Ok(0.0)
263 }
264 } else {
265 Ok(similarity)
266 }
267 }
268
269 fn compute_recursive(
271 &self,
272 n1: &TreeNode,
273 n2: &TreeNode,
274 cache: &mut HashMap<(usize, usize), f64>,
275 ) -> f64 {
276 let key = (n1.num_nodes(), n2.num_nodes());
278
279 if let Some(&cached) = cache.get(&key) {
280 return cached;
281 }
282
283 let mut result = 0.0;
284
285 if n1.label == n2.label {
287 result += self.config.decay;
289
290 if !n1.children.is_empty() && !n2.children.is_empty() {
292 for c1 in &n1.children {
294 for c2 in &n2.children {
295 result += self.config.decay * self.compute_recursive(c1, c2, cache);
296 }
297 }
298 }
299 }
300
301 cache.insert(key, result);
302 result
303 }
304}
305
306#[derive(Debug, Clone, Serialize, Deserialize)]
308pub struct PartialTreeKernelConfig {
309 pub normalize: bool,
311 pub decay: f64,
313 pub threshold: f64,
315}
316
317impl PartialTreeKernelConfig {
318 pub fn new() -> Result<Self> {
320 Ok(Self {
321 normalize: true,
322 decay: 0.8,
323 threshold: 0.0,
324 })
325 }
326
327 pub fn with_normalize(mut self, normalize: bool) -> Self {
329 self.normalize = normalize;
330 self
331 }
332
333 pub fn with_decay(mut self, decay: f64) -> Result<Self> {
335 if !(0.0..=1.0).contains(&decay) {
336 return Err(KernelError::InvalidParameter {
337 parameter: "decay".to_string(),
338 value: decay.to_string(),
339 reason: "must be between 0.0 and 1.0".to_string(),
340 });
341 }
342 self.decay = decay;
343 Ok(self)
344 }
345
346 pub fn with_threshold(mut self, threshold: f64) -> Result<Self> {
348 if !(0.0..=1.0).contains(&threshold) {
349 return Err(KernelError::InvalidParameter {
350 parameter: "threshold".to_string(),
351 value: threshold.to_string(),
352 reason: "must be between 0.0 and 1.0".to_string(),
353 });
354 }
355 self.threshold = threshold;
356 Ok(self)
357 }
358}
359
360impl Default for PartialTreeKernelConfig {
361 fn default() -> Self {
362 Self::new().unwrap()
363 }
364}
365
366pub struct PartialTreeKernel {
371 config: PartialTreeKernelConfig,
372}
373
374impl PartialTreeKernel {
375 pub fn new(config: PartialTreeKernelConfig) -> Self {
377 Self { config }
378 }
379
380 pub fn compute_trees(&self, tree1: &TreeNode, tree2: &TreeNode) -> Result<f64> {
382 let similarity = self.compute_partial_match(tree1, tree2, 1.0);
383
384 if similarity < self.config.threshold {
385 return Ok(0.0);
386 }
387
388 if self.config.normalize {
389 let self_sim1 = self.compute_partial_match(tree1, tree1, 1.0);
390 let self_sim2 = self.compute_partial_match(tree2, tree2, 1.0);
391 let norm = (self_sim1 * self_sim2).sqrt();
392 if norm > 0.0 {
393 Ok(similarity / norm)
394 } else {
395 Ok(0.0)
396 }
397 } else {
398 Ok(similarity)
399 }
400 }
401
402 fn compute_partial_match(&self, n1: &TreeNode, n2: &TreeNode, weight: f64) -> f64 {
404 let mut score = 0.0;
405
406 if n1.label == n2.label {
408 score += weight;
409
410 let min_children = n1.children.len().min(n2.children.len());
412 for i in 0..min_children {
413 score += self.compute_partial_match(
414 &n1.children[i],
415 &n2.children[i],
416 weight * self.config.decay,
417 );
418 }
419 } else {
420 let label_sim = self.label_similarity(&n1.label, &n2.label);
422 score += weight * label_sim * 0.5; let min_children = n1.children.len().min(n2.children.len());
426 for i in 0..min_children {
427 score += self.compute_partial_match(
428 &n1.children[i],
429 &n2.children[i],
430 weight * self.config.decay * 0.5,
431 );
432 }
433 }
434
435 score
436 }
437
438 fn label_similarity(&self, label1: &str, label2: &str) -> f64 {
440 if label1 == label2 {
441 1.0
442 } else {
443 let chars1: std::collections::HashSet<char> = label1.chars().collect();
445 let chars2: std::collections::HashSet<char> = label2.chars().collect();
446 let intersection = chars1.intersection(&chars2).count();
447 let union = chars1.union(&chars2).count();
448 if union > 0 {
449 intersection as f64 / union as f64
450 } else {
451 0.0
452 }
453 }
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for kernel values that are exactly representable in `f64`.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_tree_node_creation() {
        let node = TreeNode::new("root");
        assert_eq!(node.label, "root");
        assert!(node.children.is_empty());
        assert!(node.is_leaf());
    }

    #[test]
    fn test_tree_node_with_children() {
        let child1 = TreeNode::new("child1");
        let child2 = TreeNode::new("child2");
        let parent = TreeNode::with_children("parent", vec![child1, child2]);

        assert_eq!(parent.label, "parent");
        assert_eq!(parent.children.len(), 2);
        assert!(!parent.is_leaf());
    }

    #[test]
    fn test_tree_height() {
        let leaf = TreeNode::new("leaf");
        assert_eq!(leaf.height(), 1);

        let tree = TreeNode::with_children(
            "root",
            vec![
                TreeNode::new("child1"),
                TreeNode::with_children("child2", vec![TreeNode::new("grandchild")]),
            ],
        );
        assert_eq!(tree.height(), 3);
    }

    #[test]
    fn test_tree_num_nodes() {
        let tree = TreeNode::with_children(
            "root",
            vec![
                TreeNode::new("child1"),
                TreeNode::with_children("child2", vec![TreeNode::new("grandchild")]),
            ],
        );
        assert_eq!(tree.num_nodes(), 4);
    }

    #[test]
    fn test_tree_from_tlexpr() {
        let expr = TLExpr::and(TLExpr::pred("p1", vec![]), TLExpr::pred("p2", vec![]));
        let tree = TreeNode::from_tlexpr(&expr);

        assert_eq!(tree.label, "And");
        assert_eq!(tree.children.len(), 2);
        assert_eq!(tree.children[0].label, "Pred(p1)");
        assert_eq!(tree.children[1].label, "Pred(p2)");
    }

    #[test]
    fn test_subtree_kernel_config_default() {
        assert!(SubtreeKernelConfig::default().normalize);
        assert!(!SubtreeKernelConfig::new().with_normalize(false).normalize);
    }

    #[test]
    fn test_subtree_kernel_identical() {
        // root(child1, child2) has three distinct subtrees, each matching itself once.
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = tree1.clone();

        let config = SubtreeKernelConfig::new().with_normalize(false);
        let kernel = SubtreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        assert!((sim - 3.0).abs() < EPS);
    }

    #[test]
    fn test_subtree_kernel_different() {
        // The leaves differ, and so do the roots (their children differ).
        let tree1 = TreeNode::with_children("root", vec![TreeNode::new("child1")]);
        let tree2 = TreeNode::with_children("root", vec![TreeNode::new("child2")]);

        let config = SubtreeKernelConfig::new().with_normalize(false);
        let kernel = SubtreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        assert!(sim.abs() < EPS);
    }

    #[test]
    fn test_subtree_kernel_partial_match() {
        // Only the `child1` leaf is shared between the two trees.
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child3")],
        );

        let config = SubtreeKernelConfig::new().with_normalize(false);
        let kernel = SubtreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn test_subtree_kernel_normalized() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = tree1.clone();

        let config = SubtreeKernelConfig::new().with_normalize(true);
        let kernel = SubtreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_subset_tree_kernel() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = TreeNode::with_children("root", vec![TreeNode::new("child1")]);

        let config = SubsetTreeKernelConfig::new().unwrap();
        let kernel = SubsetTreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        assert!(sim > 0.0);
    }

    #[test]
    fn test_subset_tree_kernel_decay() {
        let tree1 = TreeNode::with_children("root", vec![TreeNode::new("child")]);
        let tree2 = tree1.clone();

        let config1 = SubsetTreeKernelConfig::new()
            .unwrap()
            .with_decay(1.0)
            .unwrap()
            .with_normalize(false);
        let kernel1 = SubsetTreeKernel::new(config1);

        let config2 = SubsetTreeKernelConfig::new()
            .unwrap()
            .with_decay(0.5)
            .unwrap()
            .with_normalize(false);
        let kernel2 = SubsetTreeKernel::new(config2);

        let sim1 = kernel1.compute_trees(&tree1, &tree2).unwrap();
        let sim2 = kernel2.compute_trees(&tree1, &tree2).unwrap();

        // root: decay + decay * C(child, child), where C(child, child) = decay.
        assert!((sim1 - 2.0).abs() < EPS);
        assert!((sim2 - 0.75).abs() < EPS);
        assert!(sim2 < sim1);
    }

    #[test]
    fn test_subset_tree_kernel_config_invalid_decay() {
        assert!(SubsetTreeKernelConfig::new().unwrap().with_decay(1.5).is_err());
        assert!(SubsetTreeKernelConfig::new()
            .unwrap()
            .with_decay(f64::NAN)
            .is_err());
    }

    #[test]
    fn test_partial_tree_kernel() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child3")],
        );

        let config = PartialTreeKernelConfig::new().unwrap();
        let kernel = PartialTreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        // Similar but not identical trees.
        assert!(sim > 0.0);
        assert!(sim < 1.0);
    }

    #[test]
    fn test_partial_tree_kernel_threshold() {
        let tree1 = TreeNode::with_children("root1", vec![TreeNode::new("child")]);
        let tree2 = TreeNode::with_children("root2", vec![TreeNode::new("child")]);

        let config = PartialTreeKernelConfig::new()
            .unwrap()
            .with_threshold(0.9)
            .unwrap();
        let kernel = PartialTreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        // The score is well below the threshold, so it is reported as zero.
        assert!(sim.abs() < EPS);
    }

    #[test]
    fn test_partial_tree_kernel_config_invalid_decay() {
        let result = PartialTreeKernelConfig::new().unwrap().with_decay(1.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_partial_tree_kernel_config_invalid_threshold() {
        let result = PartialTreeKernelConfig::new().unwrap().with_threshold(-0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_tree_kernel_with_tlexpr() {
        let expr1 = TLExpr::and(TLExpr::pred("p1", vec![]), TLExpr::pred("p2", vec![]));
        let expr2 = TLExpr::and(TLExpr::pred("p1", vec![]), TLExpr::pred("p3", vec![]));

        let tree1 = TreeNode::from_tlexpr(&expr1);
        let tree2 = TreeNode::from_tlexpr(&expr2);

        let config = SubtreeKernelConfig::new();
        let kernel = SubtreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        // One shared subtree (Pred(p1)) out of three in each tree.
        assert!((sim - 1.0 / 3.0).abs() < EPS);
    }
}