tensorlogic_sklears_kernels/
logic_kernel.rs1use tensorlogic_ir::TLExpr;
7
8use crate::error::{KernelError, Result};
9use crate::types::{Kernel, RuleSimilarityConfig};
10
11pub struct RuleSimilarityKernel {
29 rules: Vec<TLExpr>,
31 config: RuleSimilarityConfig,
33}
34
35impl RuleSimilarityKernel {
36 pub fn new(rules: Vec<TLExpr>, config: RuleSimilarityConfig) -> Result<Self> {
38 if rules.is_empty() {
39 return Err(KernelError::InvalidParameter {
40 parameter: "rules".to_string(),
41 value: "empty".to_string(),
42 reason: "at least one rule required".to_string(),
43 });
44 }
45
46 Ok(Self { rules, config })
47 }
48
49 fn evaluate_rule(&self, input: &[f64], rule_idx: usize) -> bool {
51 if rule_idx < input.len() {
54 input[rule_idx] > 0.5
55 } else {
56 false
57 }
58 }
59
60 fn compute_agreement(&self, x: &[f64], y: &[f64], rule_idx: usize) -> f64 {
62 let x_satisfies = self.evaluate_rule(x, rule_idx);
63 let y_satisfies = self.evaluate_rule(y, rule_idx);
64
65 match (x_satisfies, y_satisfies) {
66 (true, true) => self.config.satisfied_weight,
67 (false, false) => self.config.violated_weight,
68 _ => self.config.mixed_weight,
69 }
70 }
71}
72
73impl Kernel for RuleSimilarityKernel {
74 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
75 if x.len() != y.len() {
76 return Err(KernelError::DimensionMismatch {
77 expected: vec![x.len()],
78 got: vec![y.len()],
79 context: "rule similarity kernel".to_string(),
80 });
81 }
82
83 let mut similarity = 0.0;
84 for rule_idx in 0..self.rules.len() {
85 similarity += self.compute_agreement(x, y, rule_idx);
86 }
87
88 if self.config.normalize {
89 similarity /= self.rules.len() as f64;
90 }
91
92 Ok(similarity)
93 }
94
95 fn name(&self) -> &str {
96 "RuleSimilarity"
97 }
98}
99
100pub struct PredicateOverlapKernel {
104 n_predicates: usize,
106 predicate_weights: Vec<f64>,
108}
109
110impl PredicateOverlapKernel {
111 pub fn new(n_predicates: usize) -> Self {
113 Self {
114 n_predicates,
115 predicate_weights: vec![1.0; n_predicates],
116 }
117 }
118
119 pub fn with_weights(n_predicates: usize, weights: Vec<f64>) -> Result<Self> {
121 if weights.len() != n_predicates {
122 return Err(KernelError::DimensionMismatch {
123 expected: vec![n_predicates],
124 got: vec![weights.len()],
125 context: "predicate weights".to_string(),
126 });
127 }
128
129 Ok(Self {
130 n_predicates,
131 predicate_weights: weights,
132 })
133 }
134
135 fn is_predicate_true(&self, input: &[f64], pred_idx: usize) -> bool {
137 if pred_idx < input.len() {
138 input[pred_idx] > 0.5
139 } else {
140 false
141 }
142 }
143}
144
145impl Kernel for PredicateOverlapKernel {
146 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
147 if x.len() < self.n_predicates || y.len() < self.n_predicates {
148 return Err(KernelError::DimensionMismatch {
149 expected: vec![self.n_predicates],
150 got: vec![x.len().min(y.len())],
151 context: "predicate overlap kernel".to_string(),
152 });
153 }
154
155 let mut overlap = 0.0;
156 for pred_idx in 0..self.n_predicates {
157 if self.is_predicate_true(x, pred_idx) && self.is_predicate_true(y, pred_idx) {
158 overlap += self.predicate_weights[pred_idx];
159 }
160 }
161
162 let total_weight: f64 = self.predicate_weights.iter().sum();
164 Ok(overlap / total_weight)
165 }
166
167 fn name(&self) -> &str {
168 "PredicateOverlap"
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for every floating-point comparison in this module.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` lies strictly within `EPS` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < EPS);
    }

    /// Builds `n` placeholder predicate rules named `rule_0` .. `rule_{n-1}`.
    fn create_dummy_rules(n: usize) -> Vec<TLExpr> {
        let mut rules = Vec::with_capacity(n);
        for i in 0..n {
            rules.push(TLExpr::pred(format!("rule_{}", i), vec![]));
        }
        rules
    }

    #[test]
    fn test_rule_similarity_kernel_creation() {
        let kernel =
            RuleSimilarityKernel::new(create_dummy_rules(5), RuleSimilarityConfig::new()).unwrap();
        assert_eq!(kernel.name(), "RuleSimilarity");
    }

    #[test]
    fn test_rule_similarity_kernel_empty_rules() {
        let outcome = RuleSimilarityKernel::new(Vec::new(), RuleSimilarityConfig::new());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_rule_similarity_compute() {
        let config = RuleSimilarityConfig::new()
            .with_satisfied_weight(1.0)
            .with_violated_weight(0.5)
            .with_mixed_weight(0.0);
        let kernel = RuleSimilarityKernel::new(create_dummy_rules(3), config).unwrap();

        // (x, y, expected similarity)
        let cases = [
            // Both inputs satisfy every rule.
            ([1.0, 1.0, 1.0], [1.0, 1.0, 1.0], 1.0),
            // Both inputs violate every rule.
            ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0], 0.5),
            // The inputs disagree on every rule.
            ([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 0.0),
        ];
        for (x, y, expected) in &cases {
            assert_close(kernel.compute(x, y).unwrap(), *expected);
        }
    }

    #[test]
    fn test_rule_similarity_dimension_mismatch() {
        let kernel =
            RuleSimilarityKernel::new(create_dummy_rules(3), RuleSimilarityConfig::new()).unwrap();
        let short = [1.0, 1.0];
        let full = [1.0, 1.0, 1.0];
        assert!(kernel.compute(&short, &full).is_err());
    }

    #[test]
    fn test_predicate_overlap_kernel() {
        let kernel = PredicateOverlapKernel::new(4);
        assert_eq!(kernel.name(), "PredicateOverlap");

        // (x, y, expected similarity)
        let cases = [
            // Every predicate holds in both inputs.
            ([1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], 1.0),
            // Half of the predicates hold in both inputs.
            ([1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], 0.5),
            // No predicate holds in both inputs.
            ([0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], 0.0),
        ];
        for (x, y, expected) in &cases {
            assert_close(kernel.compute(x, y).unwrap(), *expected);
        }
    }

    #[test]
    fn test_predicate_overlap_with_weights() {
        // Total weight is 6; predicates 1 and 3 (weight 2 each) are shared.
        let kernel = PredicateOverlapKernel::with_weights(4, vec![1.0, 2.0, 1.0, 2.0]).unwrap();
        let shared = [0.0, 1.0, 0.0, 1.0];
        assert_close(kernel.compute(&shared, &shared).unwrap(), 4.0 / 6.0);
    }

    #[test]
    fn test_predicate_overlap_dimension_mismatch() {
        // Both inputs are shorter than the five declared predicates.
        let kernel = PredicateOverlapKernel::new(5);
        let short = [1.0, 1.0];
        assert!(kernel.compute(&short, &short).is_err());
    }

    #[test]
    fn test_kernel_matrix_computation() {
        let kernel = PredicateOverlapKernel::new(3);
        let inputs = vec![
            vec![1.0, 1.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0],
        ];

        let matrix = kernel.compute_matrix(&inputs).unwrap();
        assert_eq!(matrix.len(), 3);
        assert_eq!(matrix[0].len(), 3);

        // Non-negative diagonal and symmetry, checked in a single pass.
        for (i, row) in matrix.iter().enumerate() {
            assert!(row[i] >= 0.0);
            for (j, &value) in row.iter().enumerate() {
                assert!((value - matrix[j][i]).abs() < EPS);
            }
        }
    }
}