Skip to main content

tensorlogic_sklears_kernels/
logic_kernel.rs

1//! Logic-derived similarity kernels.
2//!
3//! These kernels measure similarity based on logical rule satisfaction patterns.
4//! Two inputs are similar if they satisfy the same set of logical rules.
5
6use tensorlogic_ir::TLExpr;
7
8use crate::error::{KernelError, Result};
9use crate::types::{Kernel, RuleSimilarityConfig};
10
11/// Logic-based similarity kernel.
12///
13/// Measures similarity based on which logical rules are satisfied by each input.
14///
15/// ## Formula
16///
17/// ```text
18/// K(x, y) = Σ_r w_r * agreement(x, y, r)
19/// ```
20///
21/// Where:
22/// - `r` ranges over logical rules
23/// - `w_r` is the weight for rule r
24/// - `agreement(x, y, r)` measures if x and y agree on rule r:
25///   - Both satisfy: satisfied_weight
26///   - Both violate: violated_weight
27///   - Disagree: mixed_weight
28pub struct RuleSimilarityKernel {
29    /// Logical rules for comparison
30    rules: Vec<TLExpr>,
31    /// Configuration
32    config: RuleSimilarityConfig,
33}
34
35impl RuleSimilarityKernel {
36    /// Create a new rule similarity kernel
37    pub fn new(rules: Vec<TLExpr>, config: RuleSimilarityConfig) -> Result<Self> {
38        if rules.is_empty() {
39            return Err(KernelError::InvalidParameter {
40                parameter: "rules".to_string(),
41                value: "empty".to_string(),
42                reason: "at least one rule required".to_string(),
43            });
44        }
45
46        Ok(Self { rules, config })
47    }
48
49    /// Evaluate if input satisfies a rule (simplified: uses feature index as rule ID)
50    fn evaluate_rule(&self, input: &[f64], rule_idx: usize) -> bool {
51        // Simplified: Check if the feature value at rule_idx > 0.5
52        // Real implementation would compile and execute the TLExpr
53        if rule_idx < input.len() {
54            input[rule_idx] > 0.5
55        } else {
56            false
57        }
58    }
59
60    /// Compute agreement between two inputs on a rule
61    fn compute_agreement(&self, x: &[f64], y: &[f64], rule_idx: usize) -> f64 {
62        let x_satisfies = self.evaluate_rule(x, rule_idx);
63        let y_satisfies = self.evaluate_rule(y, rule_idx);
64
65        match (x_satisfies, y_satisfies) {
66            (true, true) => self.config.satisfied_weight,
67            (false, false) => self.config.violated_weight,
68            _ => self.config.mixed_weight,
69        }
70    }
71}
72
73impl Kernel for RuleSimilarityKernel {
74    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
75        if x.len() != y.len() {
76            return Err(KernelError::DimensionMismatch {
77                expected: vec![x.len()],
78                got: vec![y.len()],
79                context: "rule similarity kernel".to_string(),
80            });
81        }
82
83        let mut similarity = 0.0;
84        for rule_idx in 0..self.rules.len() {
85            similarity += self.compute_agreement(x, y, rule_idx);
86        }
87
88        if self.config.normalize {
89            similarity /= self.rules.len() as f64;
90        }
91
92        Ok(similarity)
93    }
94
95    fn name(&self) -> &str {
96        "RuleSimilarity"
97    }
98}
99
100/// Predicate overlap kernel.
101///
102/// Measures similarity based on how many predicates are true for both inputs.
103pub struct PredicateOverlapKernel {
104    /// Number of predicates to consider
105    n_predicates: usize,
106    /// Weight for each predicate
107    predicate_weights: Vec<f64>,
108}
109
110impl PredicateOverlapKernel {
111    /// Create a new predicate overlap kernel
112    pub fn new(n_predicates: usize) -> Self {
113        Self {
114            n_predicates,
115            predicate_weights: vec![1.0; n_predicates],
116        }
117    }
118
119    /// Create with custom predicate weights
120    pub fn with_weights(n_predicates: usize, weights: Vec<f64>) -> Result<Self> {
121        if weights.len() != n_predicates {
122            return Err(KernelError::DimensionMismatch {
123                expected: vec![n_predicates],
124                got: vec![weights.len()],
125                context: "predicate weights".to_string(),
126            });
127        }
128
129        Ok(Self {
130            n_predicates,
131            predicate_weights: weights,
132        })
133    }
134
135    /// Check if predicate is satisfied (threshold at 0.5)
136    fn is_predicate_true(&self, input: &[f64], pred_idx: usize) -> bool {
137        if pred_idx < input.len() {
138            input[pred_idx] > 0.5
139        } else {
140            false
141        }
142    }
143}
144
145impl Kernel for PredicateOverlapKernel {
146    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
147        if x.len() < self.n_predicates || y.len() < self.n_predicates {
148            return Err(KernelError::DimensionMismatch {
149                expected: vec![self.n_predicates],
150                got: vec![x.len().min(y.len())],
151                context: "predicate overlap kernel".to_string(),
152            });
153        }
154
155        let mut overlap = 0.0;
156        for pred_idx in 0..self.n_predicates {
157            if self.is_predicate_true(x, pred_idx) && self.is_predicate_true(y, pred_idx) {
158                overlap += self.predicate_weights[pred_idx];
159            }
160        }
161
162        // Normalize by total weight
163        let total_weight: f64 = self.predicate_weights.iter().sum();
164        Ok(overlap / total_weight)
165    }
166
167    fn name(&self) -> &str {
168        "PredicateOverlap"
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for all floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Assert that `actual` lies within `EPS` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    /// Build `count` placeholder rules named `rule_0`, `rule_1`, ...
    fn dummy_rules(count: usize) -> Vec<TLExpr> {
        (0..count)
            .map(|i| TLExpr::pred(format!("rule_{i}"), vec![]))
            .collect()
    }

    #[test]
    fn test_rule_similarity_kernel_creation() {
        let kernel = RuleSimilarityKernel::new(dummy_rules(5), RuleSimilarityConfig::new())
            .expect("five rules form a valid kernel");
        assert_eq!(kernel.name(), "RuleSimilarity");
    }

    #[test]
    fn test_rule_similarity_kernel_empty_rules() {
        let outcome = RuleSimilarityKernel::new(Vec::new(), RuleSimilarityConfig::new());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_rule_similarity_compute() {
        let config = RuleSimilarityConfig::new()
            .with_satisfied_weight(1.0)
            .with_violated_weight(0.5)
            .with_mixed_weight(0.0);
        let kernel = RuleSimilarityKernel::new(dummy_rules(3), config).unwrap();

        let all_true = [1.0; 3];
        let all_false = [0.0; 3];

        // (x, y, expected normalized similarity)
        let cases = [
            (all_true, all_true, 1.0),   // 3 × 1.0 / 3
            (all_false, all_false, 0.5), // 3 × 0.5 / 3
            (all_true, all_false, 0.0),  // 3 × 0.0 / 3
        ];
        for (x, y, expected) in cases {
            assert_close(kernel.compute(&x, &y).unwrap(), expected);
        }
    }

    #[test]
    fn test_rule_similarity_dimension_mismatch() {
        let kernel =
            RuleSimilarityKernel::new(dummy_rules(3), RuleSimilarityConfig::new()).unwrap();
        assert!(kernel.compute(&[1.0, 1.0], &[1.0, 1.0, 1.0]).is_err());
    }

    #[test]
    fn test_predicate_overlap_kernel() {
        let kernel = PredicateOverlapKernel::new(4);
        assert_eq!(kernel.name(), "PredicateOverlap");

        let ones = [1.0; 4];
        let expectations = [
            ([1.0, 1.0, 1.0, 1.0], 1.0), // every predicate shared
            ([1.0, 1.0, 0.0, 0.0], 0.5), // half of them shared
            ([0.0, 0.0, 0.0, 0.0], 0.0), // nothing shared
        ];
        for (x, expected) in expectations {
            assert_close(kernel.compute(&x, &ones).unwrap(), expected);
        }
    }

    #[test]
    fn test_predicate_overlap_with_weights() {
        // Weights sum to 6.0; only the two weight-2.0 predicates are active.
        let kernel = PredicateOverlapKernel::with_weights(4, vec![1.0, 2.0, 1.0, 2.0]).unwrap();
        let active = [0.0, 1.0, 0.0, 1.0];
        assert_close(kernel.compute(&active, &active).unwrap(), 4.0 / 6.0);
    }

    #[test]
    fn test_predicate_overlap_dimension_mismatch() {
        // Five predicates requested, but each input only carries two features.
        let kernel = PredicateOverlapKernel::new(5);
        let short = [1.0, 1.0];
        assert!(kernel.compute(&short, &short).is_err());
    }

    #[test]
    fn test_kernel_matrix_computation() {
        let kernel = PredicateOverlapKernel::new(3);
        let inputs = vec![
            vec![1.0, 1.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0],
        ];

        let matrix = kernel.compute_matrix(&inputs).unwrap();
        assert_eq!(matrix.len(), 3);
        assert_eq!(matrix[0].len(), 3);

        // Self-similarity is never negative.
        for (i, row) in matrix.iter().enumerate() {
            assert!(row[i] >= 0.0, "negative self-similarity at {i}");
        }

        // A kernel matrix must be symmetric.
        for (i, row) in matrix.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert!((value - matrix[j][i]).abs() < EPS, "asymmetry at ({i}, {j})");
            }
        }
    }
}