Skip to main content

tensorlogic_sklears_kernels/
feature_extraction.rs

1//! Automatic feature extraction from logical expressions.
2//!
3//! Converts TLExpr into feature vectors for use with kernel methods.
4
5use std::collections::HashMap;
6
7use tensorlogic_ir::TLExpr;
8
9use crate::error::Result;
10
11/// Feature extractor for logical expressions
12///
13/// Automatically converts TLExpr into numerical feature vectors
14/// suitable for kernel computation.
15///
16/// # Example
17///
18/// ```rust
19/// use tensorlogic_sklears_kernels::{FeatureExtractor, FeatureExtractionConfig};
20/// use tensorlogic_ir::TLExpr;
21///
22/// let config = FeatureExtractionConfig::new()
23///     .with_max_depth(3)
24///     .with_encode_structure(true);
25///
26/// let extractor = FeatureExtractor::new(config);
27///
28/// let expr = TLExpr::and(
29///     TLExpr::pred("tall", vec![]),
30///     TLExpr::pred("smart", vec![]),
31/// );
32///
33/// let features = extractor.extract(&expr).unwrap();
34/// println!("Feature vector: {:?}", features);
35/// ```
36#[derive(Clone, Debug)]
37pub struct FeatureExtractor {
38    config: FeatureExtractionConfig,
39    /// Vocabulary for predicate names
40    vocabulary: HashMap<String, usize>,
41}
42
/// Options controlling what a feature extractor writes into its output vector.
#[derive(Clone, Debug)]
pub struct FeatureExtractionConfig {
    /// Depth at which the tree-depth computation stops descending.
    pub max_depth: usize,
    /// Emit the structural feature group (depth, node and operator counts).
    pub encode_structure: bool,
    /// Emit the quantifier feature group (`exists` / `forall` counts).
    pub encode_quantifiers: bool,
    /// When `Some(n)`, the output is padded with zeros or truncated to `n` entries.
    pub fixed_dimension: Option<usize>,
}

impl FeatureExtractionConfig {
    /// Configuration with the stock settings: depth limit 5, structural and
    /// quantifier features enabled, no fixed output dimension.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replace the depth limit.
    pub fn with_max_depth(self, depth: usize) -> Self {
        Self {
            max_depth: depth,
            ..self
        }
    }

    /// Enable or disable the structural feature group.
    pub fn with_encode_structure(self, encode: bool) -> Self {
        Self {
            encode_structure: encode,
            ..self
        }
    }

    /// Enable or disable the quantifier feature group.
    pub fn with_encode_quantifiers(self, encode: bool) -> Self {
        Self {
            encode_quantifiers: encode,
            ..self
        }
    }

    /// Force every output vector to have exactly `dim` entries.
    pub fn with_fixed_dimension(self, dim: usize) -> Self {
        Self {
            fixed_dimension: Some(dim),
            ..self
        }
    }
}

impl Default for FeatureExtractionConfig {
    /// Same settings as [`FeatureExtractionConfig::new`].
    fn default() -> Self {
        Self {
            max_depth: 5,
            encode_structure: true,
            encode_quantifiers: true,
            fixed_dimension: None,
        }
    }
}
97
98impl FeatureExtractor {
99    /// Create a new feature extractor
100    pub fn new(config: FeatureExtractionConfig) -> Self {
101        Self {
102            config,
103            vocabulary: HashMap::new(),
104        }
105    }
106
107    /// Extract features from a logical expression
108    pub fn extract(&self, expr: &TLExpr) -> Result<Vec<f64>> {
109        let mut features = Vec::new();
110
111        // Extract predicate frequencies
112        let pred_counts = self.count_predicates(expr);
113
114        // Extract structural features
115        if self.config.encode_structure {
116            features.extend(self.extract_structural_features(expr));
117        }
118
119        // Extract predicate features
120        features.extend(self.extract_predicate_features(&pred_counts));
121
122        // Extract quantifier features
123        if self.config.encode_quantifiers {
124            features.extend(self.extract_quantifier_features(expr));
125        }
126
127        // Pad or truncate to fixed dimension if specified
128        if let Some(dim) = self.config.fixed_dimension {
129            features.resize(dim, 0.0);
130        }
131
132        Ok(features)
133    }
134
135    /// Extract features from multiple expressions
136    pub fn extract_batch(&self, exprs: &[TLExpr]) -> Result<Vec<Vec<f64>>> {
137        exprs.iter().map(|expr| self.extract(expr)).collect()
138    }
139
140    /// Build vocabulary from a set of expressions
141    pub fn build_vocabulary(&mut self, exprs: &[TLExpr]) {
142        let mut vocab_index = 0;
143
144        for expr in exprs {
145            self.collect_predicates(expr, &mut vocab_index);
146        }
147    }
148
149    /// Collect predicates from expression
150    fn collect_predicates(&mut self, expr: &TLExpr, vocab_index: &mut usize) {
151        match expr {
152            TLExpr::Pred { name, .. } => {
153                if !self.vocabulary.contains_key(name) {
154                    self.vocabulary.insert(name.clone(), *vocab_index);
155                    *vocab_index += 1;
156                }
157            }
158            TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
159                self.collect_predicates(left, vocab_index);
160                self.collect_predicates(right, vocab_index);
161            }
162            TLExpr::Not(inner) => {
163                self.collect_predicates(inner, vocab_index);
164            }
165            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
166                self.collect_predicates(body, vocab_index);
167            }
168            _ => {}
169        }
170    }
171
172    /// Count predicates in expression
173    fn count_predicates(&self, expr: &TLExpr) -> HashMap<String, usize> {
174        let mut counts = HashMap::new();
175        self.count_predicates_recursive(expr, &mut counts);
176        counts
177    }
178
179    #[allow(clippy::only_used_in_recursion)]
180    fn count_predicates_recursive(&self, expr: &TLExpr, counts: &mut HashMap<String, usize>) {
181        match expr {
182            TLExpr::Pred { name, .. } => {
183                *counts.entry(name.clone()).or_insert(0) += 1;
184            }
185            TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
186                self.count_predicates_recursive(left, counts);
187                self.count_predicates_recursive(right, counts);
188            }
189            TLExpr::Not(inner) => {
190                self.count_predicates_recursive(inner, counts);
191            }
192            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
193                self.count_predicates_recursive(body, counts);
194            }
195            _ => {}
196        }
197    }
198
199    /// Extract structural features
200    fn extract_structural_features(&self, expr: &TLExpr) -> Vec<f64> {
201        vec![
202            self.compute_depth(expr, 0) as f64,
203            self.count_nodes(expr) as f64,
204            self.count_operators(expr, "and") as f64,
205            self.count_operators(expr, "or") as f64,
206            self.count_operators(expr, "not") as f64,
207            self.count_operators(expr, "imply") as f64,
208        ]
209    }
210
211    /// Compute tree depth
212    fn compute_depth(&self, expr: &TLExpr, current_depth: usize) -> usize {
213        if current_depth >= self.config.max_depth {
214            return current_depth;
215        }
216
217        match expr {
218            TLExpr::Pred { .. } => current_depth,
219            TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
220                let left_depth = self.compute_depth(left, current_depth + 1);
221                let right_depth = self.compute_depth(right, current_depth + 1);
222                left_depth.max(right_depth)
223            }
224            TLExpr::Not(inner)
225            | TLExpr::Exists { body: inner, .. }
226            | TLExpr::ForAll { body: inner, .. } => self.compute_depth(inner, current_depth + 1),
227            _ => current_depth,
228        }
229    }
230
231    /// Count total nodes
232    #[allow(clippy::only_used_in_recursion)]
233    fn count_nodes(&self, expr: &TLExpr) -> usize {
234        match expr {
235            TLExpr::Pred { .. } => 1,
236            TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
237                1 + self.count_nodes(left) + self.count_nodes(right)
238            }
239            TLExpr::Not(inner)
240            | TLExpr::Exists { body: inner, .. }
241            | TLExpr::ForAll { body: inner, .. } => 1 + self.count_nodes(inner),
242            _ => 1,
243        }
244    }
245
246    /// Count specific operators
247    #[allow(clippy::only_used_in_recursion)]
248    fn count_operators(&self, expr: &TLExpr, op_type: &str) -> usize {
249        let this_count = match (op_type, expr) {
250            ("and", TLExpr::And(_, _)) => 1,
251            ("or", TLExpr::Or(_, _)) => 1,
252            ("not", TLExpr::Not(_)) => 1,
253            ("imply", TLExpr::Imply(_, _)) => 1,
254            _ => 0,
255        };
256
257        let child_count = match expr {
258            TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
259                self.count_operators(left, op_type) + self.count_operators(right, op_type)
260            }
261            TLExpr::Not(inner)
262            | TLExpr::Exists { body: inner, .. }
263            | TLExpr::ForAll { body: inner, .. } => self.count_operators(inner, op_type),
264            _ => 0,
265        };
266
267        this_count + child_count
268    }
269
270    /// Extract predicate features
271    fn extract_predicate_features(&self, counts: &HashMap<String, usize>) -> Vec<f64> {
272        if self.vocabulary.is_empty() {
273            // If no vocabulary, return counts as-is
274            counts.values().map(|&c| c as f64).collect()
275        } else {
276            // Use vocabulary for consistent ordering
277            let mut features = vec![0.0; self.vocabulary.len()];
278            for (pred, &count) in counts {
279                if let Some(&idx) = self.vocabulary.get(pred) {
280                    features[idx] = count as f64;
281                }
282            }
283            features
284        }
285    }
286
287    /// Extract quantifier features
288    fn extract_quantifier_features(&self, expr: &TLExpr) -> Vec<f64> {
289        vec![
290            self.count_quantifiers(expr, "exists") as f64,
291            self.count_quantifiers(expr, "forall") as f64,
292        ]
293    }
294
295    /// Count quantifiers
296    #[allow(clippy::only_used_in_recursion)]
297    fn count_quantifiers(&self, expr: &TLExpr, quant_type: &str) -> usize {
298        let this_count = match (quant_type, expr) {
299            ("exists", TLExpr::Exists { .. }) => 1,
300            ("forall", TLExpr::ForAll { .. }) => 1,
301            _ => 0,
302        };
303
304        let child_count = match expr {
305            TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
306                self.count_quantifiers(left, quant_type) + self.count_quantifiers(right, quant_type)
307            }
308            TLExpr::Not(inner)
309            | TLExpr::Exists { body: inner, .. }
310            | TLExpr::ForAll { body: inner, .. } => self.count_quantifiers(inner, quant_type),
311            _ => 0,
312        };
313
314        this_count + child_count
315    }
316
317    /// Get vocabulary size
318    pub fn vocab_size(&self) -> usize {
319        self.vocabulary.len()
320    }
321
322    /// Get vocabulary
323    pub fn vocabulary(&self) -> &HashMap<String, usize> {
324        &self.vocabulary
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    // Default feature layout used throughout these tests:
    // [depth, nodes, #and, #or, #not, #imply | predicate counts | #exists, #forall]

    /// A lone predicate yields zero depth, one node, no operators, one
    /// predicate occurrence and no quantifiers.
    #[test]
    fn test_feature_extraction_basic() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::pred("tall", vec![]);
        let features = extractor.extract(&expr).unwrap();

        assert_eq!(
            features,
            vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
        );
    }

    /// A conjunction of two distinct predicates: both predicate counts are 1,
    /// so the vector is the same whatever order they are emitted in.
    #[test]
    fn test_feature_extraction_compound() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::and(TLExpr::pred("tall", vec![]), TLExpr::pred("smart", vec![]));

        let features = extractor.extract(&expr).unwrap();
        assert_eq!(
            features,
            vec![1.0, 3.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0]
        );
    }

    /// The six structural features come first and reflect the tree shape.
    #[test]
    fn test_structural_features() {
        let config = FeatureExtractionConfig::new().with_encode_structure(true);
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::and(
            TLExpr::pred("a", vec![]),
            TLExpr::or(TLExpr::pred("b", vec![]), TLExpr::pred("c", vec![])),
        );

        let features = extractor.extract(&expr).unwrap();

        // depth 2, five nodes, one `and`, one `or`, no `not`, no `imply`
        assert_eq!(
            features[..6].to_vec(),
            vec![2.0, 5.0, 1.0, 1.0, 0.0, 0.0]
        );
    }

    /// Quantifier counts are the last two entries: `exists` then `forall`.
    #[test]
    fn test_quantifier_features() {
        let config = FeatureExtractionConfig::new().with_encode_quantifiers(true);
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::exists("x", "Person", TLExpr::pred("likes", vec![]));

        let features = extractor.extract(&expr).unwrap();
        assert!(features.len() >= 2);
        assert_eq!(features[features.len() - 2..].to_vec(), vec![1.0, 0.0]);
    }

    /// Duplicate names are stored once and indices follow first appearance.
    #[test]
    fn test_vocabulary_building() {
        let config = FeatureExtractionConfig::new();
        let mut extractor = FeatureExtractor::new(config);

        let exprs = vec![
            TLExpr::pred("tall", vec![]),
            TLExpr::pred("smart", vec![]),
            TLExpr::pred("tall", vec![]),
        ];

        extractor.build_vocabulary(&exprs);

        assert_eq!(extractor.vocab_size(), 2); // tall and smart
        assert_eq!(extractor.vocabulary().get("tall"), Some(&0));
        assert_eq!(extractor.vocabulary().get("smart"), Some(&1));
    }

    /// With a vocabulary, predicate counts land at their vocabulary index and
    /// absent predicates contribute zero.
    #[test]
    fn test_vocabulary_ordered_features() {
        let config = FeatureExtractionConfig::new()
            .with_encode_structure(false)
            .with_encode_quantifiers(false);
        let mut extractor = FeatureExtractor::new(config);

        extractor.build_vocabulary(&[
            TLExpr::pred("tall", vec![]),
            TLExpr::pred("smart", vec![]),
        ]);

        let expr = TLExpr::and(TLExpr::pred("smart", vec![]), TLExpr::pred("smart", vec![]));
        let features = extractor.extract(&expr).unwrap();

        assert_eq!(features, vec![0.0, 2.0]);
    }

    /// Batch extraction returns one row per input, matching single extraction.
    #[test]
    fn test_batch_extraction() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        let exprs = vec![
            TLExpr::pred("a", vec![]),
            TLExpr::pred("b", vec![]),
            TLExpr::and(TLExpr::pred("a", vec![]), TLExpr::pred("b", vec![])),
        ];

        let features = extractor.extract_batch(&exprs).unwrap();
        assert_eq!(features.len(), 3);
        assert_eq!(features[0], extractor.extract(&exprs[0]).unwrap());
        assert_eq!(features[2], extractor.extract(&exprs[2]).unwrap());
    }

    /// A fixed dimension pads short vectors with zeros and truncates long ones.
    #[test]
    fn test_fixed_dimension() {
        let expr = TLExpr::pred("test", vec![]);

        let padded = FeatureExtractor::new(FeatureExtractionConfig::new().with_fixed_dimension(10))
            .extract(&expr)
            .unwrap();
        assert_eq!(padded.len(), 10);
        // The natural length is 9, so the last slot is padding.
        assert_eq!(padded[9], 0.0);

        let truncated =
            FeatureExtractor::new(FeatureExtractionConfig::new().with_fixed_dimension(3))
                .extract(&expr)
                .unwrap();
        assert_eq!(truncated, vec![0.0, 1.0, 0.0]);
    }

    /// Depth is measured in edges from the root and is capped by `max_depth`.
    #[test]
    fn test_depth_computation() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        // Depth 0: single predicate
        let expr1 = TLExpr::pred("a", vec![]);
        assert_eq!(extractor.compute_depth(&expr1, 0), 0);

        // Depth 2: nested structure
        let expr2 = TLExpr::and(
            TLExpr::pred("a", vec![]),
            TLExpr::and(TLExpr::pred("b", vec![]), TLExpr::pred("c", vec![])),
        );
        assert_eq!(extractor.compute_depth(&expr2, 0), 2);

        // The same tree reports the limit when the limit is lower.
        let shallow = FeatureExtractor::new(FeatureExtractionConfig::new().with_max_depth(1));
        assert_eq!(shallow.compute_depth(&expr2, 0), 1);
    }

    /// Operators are counted per kind across the whole tree.
    #[test]
    fn test_operator_counting() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::and(
            TLExpr::and(TLExpr::pred("a", vec![]), TLExpr::pred("b", vec![])),
            TLExpr::or(TLExpr::pred("c", vec![]), TLExpr::pred("d", vec![])),
        );

        assert_eq!(extractor.count_operators(&expr, "and"), 2);
        assert_eq!(extractor.count_operators(&expr, "or"), 1);
        assert_eq!(extractor.count_operators(&expr, "not"), 0);
        assert_eq!(extractor.count_operators(&expr, "imply"), 0);
        // Unknown operator names match nothing.
        assert_eq!(extractor.count_operators(&expr, "xor"), 0);
    }
}