Skip to main content

tensorlogic_sklears_kernels/
tree_kernel.rs

1//! Tree kernels for structured data similarity
2//!
3//! This module provides kernel functions for tree-structured data, particularly
4//! useful for measuring similarity between hierarchical expressions like TLExpr.
5//!
6//! ## Tree Representations
7//!
8//! Trees are represented as labeled nodes with children, where each node has:
9//! - A label (string identifier)
10//! - A list of child nodes
11//!
12//! ## Kernel Types
13//!
14//! - **SubtreeKernel**: Counts common subtrees between two trees
15//! - **SubsetTreeKernel**: Counts common tree fragments (allows gaps)
16//! - **PartialTreeKernel**: Partial subtree matching with decay factors
17//!
18//! ## References
19//!
20//! - Collins & Duffy (2001): "Convolution Kernels for Natural Language"
21//! - Moschitti (2006): "Making Tree Kernels Practical for Natural Language Learning"
22
23use crate::error::{KernelError, Result};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use tensorlogic_ir::TLExpr;
27
/// A labeled, ordered tree node.
///
/// Equality and hashing are derived, so two nodes compare equal only when
/// their labels are equal and their `children` vectors are equal element by
/// element (same order, recursively).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TreeNode {
    /// Label identifying this node
    pub label: String,
    /// Child nodes, in order; empty for a leaf
    pub children: Vec<TreeNode>,
}
36
37impl TreeNode {
38    /// Create a new tree node
39    pub fn new(label: impl Into<String>) -> Self {
40        Self {
41            label: label.into(),
42            children: Vec::new(),
43        }
44    }
45
46    /// Create a tree node with children
47    pub fn with_children(label: impl Into<String>, children: Vec<TreeNode>) -> Self {
48        Self {
49            label: label.into(),
50            children,
51        }
52    }
53
54    /// Get the height of the tree
55    pub fn height(&self) -> usize {
56        if self.children.is_empty() {
57            1
58        } else {
59            1 + self.children.iter().map(|c| c.height()).max().unwrap_or(0)
60        }
61    }
62
63    /// Get the number of nodes in the tree
64    pub fn num_nodes(&self) -> usize {
65        1 + self.children.iter().map(|c| c.num_nodes()).sum::<usize>()
66    }
67
68    /// Check if this is a leaf node
69    pub fn is_leaf(&self) -> bool {
70        self.children.is_empty()
71    }
72
73    /// Convert from TLExpr to TreeNode
74    pub fn from_tlexpr(expr: &TLExpr) -> Self {
75        match expr {
76            TLExpr::Pred { name, .. } => TreeNode::new(format!("Pred({})", name)),
77            TLExpr::And(left, right) => TreeNode::with_children(
78                "And",
79                vec![TreeNode::from_tlexpr(left), TreeNode::from_tlexpr(right)],
80            ),
81            TLExpr::Or(left, right) => TreeNode::with_children(
82                "Or",
83                vec![TreeNode::from_tlexpr(left), TreeNode::from_tlexpr(right)],
84            ),
85            TLExpr::Not(expr) => TreeNode::with_children("Not", vec![TreeNode::from_tlexpr(expr)]),
86            TLExpr::Imply(left, right) => TreeNode::with_children(
87                "Imply",
88                vec![TreeNode::from_tlexpr(left), TreeNode::from_tlexpr(right)],
89            ),
90            TLExpr::Exists { var, domain, body } => TreeNode::with_children(
91                format!("Exists({}, {})", var, domain),
92                vec![TreeNode::from_tlexpr(body)],
93            ),
94            TLExpr::ForAll { var, domain, body } => TreeNode::with_children(
95                format!("ForAll({}, {})", var, domain),
96                vec![TreeNode::from_tlexpr(body)],
97            ),
98            _ => TreeNode::new("Expr"),
99        }
100    }
101
102    /// Get all subtrees (including the tree itself)
103    fn get_all_subtrees(&self) -> Vec<TreeNode> {
104        let mut subtrees = vec![self.clone()];
105        for child in &self.children {
106            subtrees.extend(child.get_all_subtrees());
107        }
108        subtrees
109    }
110}
111
/// Configuration for [`SubtreeKernel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubtreeKernelConfig {
    /// Whether to normalize the kernel value
    /// (enabled by [`SubtreeKernelConfig::new`])
    pub normalize: bool,
}
118
119impl SubtreeKernelConfig {
120    /// Create a new configuration
121    pub fn new() -> Self {
122        Self { normalize: true }
123    }
124
125    /// Set normalization flag
126    pub fn with_normalize(mut self, normalize: bool) -> Self {
127        self.normalize = normalize;
128        self
129    }
130}
131
132impl Default for SubtreeKernelConfig {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138/// Subtree kernel - counts common subtrees
139///
140/// This kernel counts the number of subtrees that are common between two trees.
141/// It provides a measure of structural similarity.
142///
143/// ## Formula
144///
145/// ```text
146/// K(T1, T2) = Σ_i Σ_j I(subtree_i(T1) == subtree_j(T2))
147/// ```
148///
149/// where I is the indicator function.
150pub struct SubtreeKernel {
151    config: SubtreeKernelConfig,
152}
153
154impl SubtreeKernel {
155    /// Create a new subtree kernel
156    pub fn new(config: SubtreeKernelConfig) -> Self {
157        Self { config }
158    }
159
160    /// Compute kernel between two trees
161    pub fn compute_trees(&self, tree1: &TreeNode, tree2: &TreeNode) -> Result<f64> {
162        let subtrees1 = tree1.get_all_subtrees();
163        let subtrees2 = tree2.get_all_subtrees();
164
165        // Count matches
166        let mut count = 0;
167        for st1 in &subtrees1 {
168            for st2 in &subtrees2 {
169                if st1 == st2 {
170                    count += 1;
171                }
172            }
173        }
174
175        let similarity = count as f64;
176
177        if self.config.normalize {
178            // Normalize by geometric mean of self-similarities
179            let self_sim1 = subtrees1.len() as f64;
180            let self_sim2 = subtrees2.len() as f64;
181            let norm = (self_sim1 * self_sim2).sqrt();
182            if norm > 0.0 {
183                Ok(similarity / norm)
184            } else {
185                Ok(0.0)
186            }
187        } else {
188            Ok(similarity)
189        }
190    }
191}
192
/// Configuration for [`SubsetTreeKernel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubsetTreeKernelConfig {
    /// Whether to normalize the kernel value
    pub normalize: bool,
    /// Decay factor for tree fragments (0.0 to 1.0).
    ///
    /// [`SubsetTreeKernelConfig::with_decay`] rejects values outside that
    /// range; writing the field directly is not checked.
    pub decay: f64,
}
201
202impl SubsetTreeKernelConfig {
203    /// Create a new configuration
204    pub fn new() -> Result<Self> {
205        Ok(Self {
206            normalize: true,
207            decay: 1.0,
208        })
209    }
210
211    /// Set normalization flag
212    pub fn with_normalize(mut self, normalize: bool) -> Self {
213        self.normalize = normalize;
214        self
215    }
216
217    /// Set decay factor
218    pub fn with_decay(mut self, decay: f64) -> Result<Self> {
219        if !(0.0..=1.0).contains(&decay) {
220            return Err(KernelError::InvalidParameter {
221                parameter: "decay".to_string(),
222                value: decay.to_string(),
223                reason: "must be between 0.0 and 1.0".to_string(),
224            });
225        }
226        self.decay = decay;
227        Ok(self)
228    }
229}
230
231impl Default for SubsetTreeKernelConfig {
232    fn default() -> Self {
233        Self::new().unwrap()
234    }
235}
236
237/// Subset tree kernel - allows gaps in tree fragments
238///
239/// This kernel is more flexible than the subtree kernel, allowing matching
240/// of tree fragments even when intermediate nodes are skipped.
241pub struct SubsetTreeKernel {
242    config: SubsetTreeKernelConfig,
243}
244
245impl SubsetTreeKernel {
246    /// Create a new subset tree kernel
247    pub fn new(config: SubsetTreeKernelConfig) -> Self {
248        Self { config }
249    }
250
251    /// Compute kernel between two trees
252    pub fn compute_trees(&self, tree1: &TreeNode, tree2: &TreeNode) -> Result<f64> {
253        let similarity = self.compute_recursive(tree1, tree2, &mut HashMap::new());
254
255        if self.config.normalize {
256            let self_sim1 = self.compute_recursive(tree1, tree1, &mut HashMap::new());
257            let self_sim2 = self.compute_recursive(tree2, tree2, &mut HashMap::new());
258            let norm = (self_sim1 * self_sim2).sqrt();
259            if norm > 0.0 {
260                Ok(similarity / norm)
261            } else {
262                Ok(0.0)
263            }
264        } else {
265            Ok(similarity)
266        }
267    }
268
269    /// Recursive kernel computation with memoization
270    fn compute_recursive(
271        &self,
272        n1: &TreeNode,
273        n2: &TreeNode,
274        cache: &mut HashMap<(usize, usize), f64>,
275    ) -> f64 {
276        // Simple hash for caching (not perfect but good enough)
277        let key = (n1.num_nodes(), n2.num_nodes());
278
279        if let Some(&cached) = cache.get(&key) {
280            return cached;
281        }
282
283        let mut result = 0.0;
284
285        // If labels match
286        if n1.label == n2.label {
287            // Add contribution from this node
288            result += self.config.decay;
289
290            // If both have children, recursively compute
291            if !n1.children.is_empty() && !n2.children.is_empty() {
292                // Compute kernel for all pairs of children
293                for c1 in &n1.children {
294                    for c2 in &n2.children {
295                        result += self.config.decay * self.compute_recursive(c1, c2, cache);
296                    }
297                }
298            }
299        }
300
301        cache.insert(key, result);
302        result
303    }
304}
305
/// Configuration for [`PartialTreeKernel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialTreeKernelConfig {
    /// Whether to normalize the kernel value
    pub normalize: bool,
    /// Decay factor for partial matches, applied once per level of depth.
    ///
    /// [`PartialTreeKernelConfig::with_decay`] accepts `0.0..=1.0` only.
    pub decay: f64,
    /// Minimum similarity threshold for partial matches.
    ///
    /// [`PartialTreeKernelConfig::with_threshold`] accepts `0.0..=1.0` only.
    pub threshold: f64,
}
316
317impl PartialTreeKernelConfig {
318    /// Create a new configuration
319    pub fn new() -> Result<Self> {
320        Ok(Self {
321            normalize: true,
322            decay: 0.8,
323            threshold: 0.0,
324        })
325    }
326
327    /// Set normalization flag
328    pub fn with_normalize(mut self, normalize: bool) -> Self {
329        self.normalize = normalize;
330        self
331    }
332
333    /// Set decay factor
334    pub fn with_decay(mut self, decay: f64) -> Result<Self> {
335        if !(0.0..=1.0).contains(&decay) {
336            return Err(KernelError::InvalidParameter {
337                parameter: "decay".to_string(),
338                value: decay.to_string(),
339                reason: "must be between 0.0 and 1.0".to_string(),
340            });
341        }
342        self.decay = decay;
343        Ok(self)
344    }
345
346    /// Set threshold
347    pub fn with_threshold(mut self, threshold: f64) -> Result<Self> {
348        if !(0.0..=1.0).contains(&threshold) {
349            return Err(KernelError::InvalidParameter {
350                parameter: "threshold".to_string(),
351                value: threshold.to_string(),
352                reason: "must be between 0.0 and 1.0".to_string(),
353            });
354        }
355        self.threshold = threshold;
356        Ok(self)
357    }
358}
359
360impl Default for PartialTreeKernelConfig {
361    fn default() -> Self {
362        Self::new().unwrap()
363    }
364}
365
366/// Partial tree kernel - allows partial subtree matching
367///
368/// This kernel measures similarity by allowing partial matches with decay factors.
369/// It's useful when trees have similar structure but not exact matches.
370pub struct PartialTreeKernel {
371    config: PartialTreeKernelConfig,
372}
373
374impl PartialTreeKernel {
375    /// Create a new partial tree kernel
376    pub fn new(config: PartialTreeKernelConfig) -> Self {
377        Self { config }
378    }
379
380    /// Compute kernel between two trees
381    pub fn compute_trees(&self, tree1: &TreeNode, tree2: &TreeNode) -> Result<f64> {
382        let similarity = self.compute_partial_match(tree1, tree2, 1.0);
383
384        if similarity < self.config.threshold {
385            return Ok(0.0);
386        }
387
388        if self.config.normalize {
389            let self_sim1 = self.compute_partial_match(tree1, tree1, 1.0);
390            let self_sim2 = self.compute_partial_match(tree2, tree2, 1.0);
391            let norm = (self_sim1 * self_sim2).sqrt();
392            if norm > 0.0 {
393                Ok(similarity / norm)
394            } else {
395                Ok(0.0)
396            }
397        } else {
398            Ok(similarity)
399        }
400    }
401
402    /// Compute partial match score with decay
403    fn compute_partial_match(&self, n1: &TreeNode, n2: &TreeNode, weight: f64) -> f64 {
404        let mut score = 0.0;
405
406        // Exact label match
407        if n1.label == n2.label {
408            score += weight;
409
410            // Recursively match children with decay
411            let min_children = n1.children.len().min(n2.children.len());
412            for i in 0..min_children {
413                score += self.compute_partial_match(
414                    &n1.children[i],
415                    &n2.children[i],
416                    weight * self.config.decay,
417                );
418            }
419        } else {
420            // Partial match based on label similarity (simple heuristic)
421            let label_sim = self.label_similarity(&n1.label, &n2.label);
422            score += weight * label_sim * 0.5; // Partial credit
423
424            // Try matching children even if labels differ
425            let min_children = n1.children.len().min(n2.children.len());
426            for i in 0..min_children {
427                score += self.compute_partial_match(
428                    &n1.children[i],
429                    &n2.children[i],
430                    weight * self.config.decay * 0.5,
431                );
432            }
433        }
434
435        score
436    }
437
438    /// Simple label similarity (can be improved with more sophisticated methods)
439    fn label_similarity(&self, label1: &str, label2: &str) -> f64 {
440        if label1 == label2 {
441            1.0
442        } else {
443            // Simple Jaccard similarity on characters
444            let chars1: std::collections::HashSet<char> = label1.chars().collect();
445            let chars2: std::collections::HashSet<char> = label2.chars().collect();
446            let intersection = chars1.intersection(&chars2).count();
447            let union = chars1.union(&chars2).count();
448            if union > 0 {
449                intersection as f64 / union as f64
450            } else {
451                0.0
452            }
453        }
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tree_node_creation() {
        let node = TreeNode::new("root");
        assert_eq!(node.label, "root");
        assert!(node.children.is_empty());
        assert!(node.is_leaf());
    }

    #[test]
    fn test_tree_node_with_children() {
        let child1 = TreeNode::new("child1");
        let child2 = TreeNode::new("child2");
        let parent = TreeNode::with_children("parent", vec![child1, child2]);

        assert_eq!(parent.label, "parent");
        assert_eq!(parent.children.len(), 2);
        assert!(!parent.is_leaf());
    }

    #[test]
    fn test_tree_height() {
        let leaf = TreeNode::new("leaf");
        assert_eq!(leaf.height(), 1);

        let tree = TreeNode::with_children(
            "root",
            vec![
                TreeNode::new("child1"),
                TreeNode::with_children("child2", vec![TreeNode::new("grandchild")]),
            ],
        );
        assert_eq!(tree.height(), 3);
    }

    #[test]
    fn test_tree_num_nodes() {
        let tree = TreeNode::with_children(
            "root",
            vec![
                TreeNode::new("child1"),
                TreeNode::with_children("child2", vec![TreeNode::new("grandchild")]),
            ],
        );
        assert_eq!(tree.num_nodes(), 4);
    }

    #[test]
    fn test_tree_from_tlexpr() {
        let expr = TLExpr::and(TLExpr::pred("p1", vec![]), TLExpr::pred("p2", vec![]));
        let tree = TreeNode::from_tlexpr(&expr);

        assert_eq!(tree.label, "And");
        assert_eq!(tree.children.len(), 2);
    }

    #[test]
    fn test_subtree_kernel_identical() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = tree1.clone();

        let config = SubtreeKernelConfig::new().with_normalize(false);
        let kernel = SubtreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        // Three distinct subtrees (root, child1, child2), each matching only
        // its own copy in the other tree.
        assert!((sim - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_subtree_kernel_different() {
        // Create trees with completely different children
        let tree1 = TreeNode::with_children("root", vec![TreeNode::new("child1")]);
        let tree2 = TreeNode::with_children("root", vec![TreeNode::new("child2")]);

        let config = SubtreeKernelConfig::new().with_normalize(false);
        let kernel = SubtreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        // Trees with different children have no matching subtrees: the roots
        // differ too, because subtree equality includes the children.
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_subtree_kernel_partial_match() {
        // Create trees with same root label and one matching child
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child3")],
        );

        let config = SubtreeKernelConfig::new().with_normalize(false);
        let kernel = SubtreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        // The shared "child1" leaf is the only common subtree.
        assert!((sim - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_subtree_kernel_normalized() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = tree1.clone();

        let config = SubtreeKernelConfig::new().with_normalize(true);
        let kernel = SubtreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        assert!((sim - 1.0).abs() < 1e-6); // Self-similarity should be 1.0 when normalized
    }

    #[test]
    fn test_subset_tree_kernel() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = TreeNode::with_children("root", vec![TreeNode::new("child1")]);

        let config = SubsetTreeKernelConfig::new().unwrap();
        let kernel = SubsetTreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        assert!(sim > 0.0);
    }

    #[test]
    fn test_subset_tree_kernel_decay() {
        let tree1 = TreeNode::with_children("root", vec![TreeNode::new("child")]);
        let tree2 = tree1.clone();

        let config1 = SubsetTreeKernelConfig::new()
            .unwrap()
            .with_decay(1.0)
            .unwrap()
            .with_normalize(false);
        let kernel1 = SubsetTreeKernel::new(config1);

        let config2 = SubsetTreeKernelConfig::new()
            .unwrap()
            .with_decay(0.5)
            .unwrap()
            .with_normalize(false);
        let kernel2 = SubsetTreeKernel::new(config2);

        let sim1 = kernel1.compute_trees(&tree1, &tree2).unwrap();
        let sim2 = kernel2.compute_trees(&tree1, &tree2).unwrap();

        // Lower decay should give lower similarity
        assert!(sim2 < sim1);
    }

    #[test]
    fn test_partial_tree_kernel() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child3")],
        );

        let config = PartialTreeKernelConfig::new().unwrap();
        let kernel = PartialTreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        assert!(sim > 0.0); // Should have partial similarity
    }

    #[test]
    fn test_partial_tree_kernel_threshold() {
        let tree1 = TreeNode::with_children("root1", vec![TreeNode::new("child")]);
        let tree2 = TreeNode::with_children("root2", vec![TreeNode::new("child")]);

        let config = PartialTreeKernelConfig::new()
            .unwrap()
            .with_threshold(0.9)
            .unwrap();
        let kernel = PartialTreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        // Raw score: 0.3 for the differing roots (label similarity 3/5, half
        // credit) plus 0.4 for the matching child (decay 0.8, halved). That
        // is below the 0.9 threshold, so the kernel reports exactly zero.
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_partial_tree_kernel_config_invalid_decay() {
        let result = PartialTreeKernelConfig::new().unwrap().with_decay(1.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_partial_tree_kernel_config_invalid_threshold() {
        let result = PartialTreeKernelConfig::new().unwrap().with_threshold(-0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_tree_kernel_with_tlexpr() {
        let expr1 = TLExpr::and(TLExpr::pred("p1", vec![]), TLExpr::pred("p2", vec![]));
        let expr2 = TLExpr::and(TLExpr::pred("p1", vec![]), TLExpr::pred("p3", vec![]));

        let tree1 = TreeNode::from_tlexpr(&expr1);
        let tree2 = TreeNode::from_tlexpr(&expr2);

        let config = SubtreeKernelConfig::new();
        let kernel = SubtreeKernel::new(config);

        let sim = kernel.compute_trees(&tree1, &tree2).unwrap();
        // Should have some similarity from the shared p1 subtree. The roots
        // themselves do not match, because their second operands differ.
        assert!(sim > 0.0);
    }
}