1use std::collections::HashMap;
6
7use tensorlogic_ir::TLExpr;
8
9use crate::error::Result;
10
/// Converts [`TLExpr`] logic expressions into numeric feature vectors.
///
/// Holds the extraction configuration plus an optional predicate vocabulary that
/// maps predicate names to positions in the predicate-count section of the output.
#[derive(Clone, Debug)]
pub struct FeatureExtractor {
    /// Selects which feature groups are emitted and the optional output length.
    config: FeatureExtractionConfig,
    /// Predicate name -> index into the predicate-count features.
    /// Starts empty (see `FeatureExtractor::new`) and is filled by `build_vocabulary`.
    vocabulary: HashMap<String, usize>,
}
42
/// Options controlling which features a `FeatureExtractor` produces.
#[derive(Clone, Debug)]
pub struct FeatureExtractionConfig {
    /// Depth at which depth computation stops descending; the reported depth is capped here.
    pub max_depth: usize,
    /// Emit the six structural features (depth, node count, and/or/not/imply counts).
    pub encode_structure: bool,
    /// Emit the two quantifier-count features (exists, forall).
    pub encode_quantifiers: bool,
    /// When set, the feature vector is zero-padded or truncated to exactly this length.
    pub fixed_dimension: Option<usize>,
}

impl FeatureExtractionConfig {
    /// Builds the default configuration: `max_depth` of 5, structural and quantifier
    /// features enabled, and no fixed output dimension.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns this configuration with `max_depth` replaced by `depth`.
    pub fn with_max_depth(self, depth: usize) -> Self {
        Self {
            max_depth: depth,
            ..self
        }
    }

    /// Returns this configuration with structural features switched on or off.
    pub fn with_encode_structure(self, encode: bool) -> Self {
        Self {
            encode_structure: encode,
            ..self
        }
    }

    /// Returns this configuration with quantifier features switched on or off.
    pub fn with_encode_quantifiers(self, encode: bool) -> Self {
        Self {
            encode_quantifiers: encode,
            ..self
        }
    }

    /// Returns this configuration with the output length fixed to `dim`.
    pub fn with_fixed_dimension(self, dim: usize) -> Self {
        Self {
            fixed_dimension: Some(dim),
            ..self
        }
    }
}

impl Default for FeatureExtractionConfig {
    fn default() -> Self {
        Self {
            max_depth: 5,
            encode_structure: true,
            encode_quantifiers: true,
            fixed_dimension: None,
        }
    }
}
97
98impl FeatureExtractor {
99 pub fn new(config: FeatureExtractionConfig) -> Self {
101 Self {
102 config,
103 vocabulary: HashMap::new(),
104 }
105 }
106
107 pub fn extract(&self, expr: &TLExpr) -> Result<Vec<f64>> {
109 let mut features = Vec::new();
110
111 let pred_counts = self.count_predicates(expr);
113
114 if self.config.encode_structure {
116 features.extend(self.extract_structural_features(expr));
117 }
118
119 features.extend(self.extract_predicate_features(&pred_counts));
121
122 if self.config.encode_quantifiers {
124 features.extend(self.extract_quantifier_features(expr));
125 }
126
127 if let Some(dim) = self.config.fixed_dimension {
129 features.resize(dim, 0.0);
130 }
131
132 Ok(features)
133 }
134
135 pub fn extract_batch(&self, exprs: &[TLExpr]) -> Result<Vec<Vec<f64>>> {
137 exprs.iter().map(|expr| self.extract(expr)).collect()
138 }
139
140 pub fn build_vocabulary(&mut self, exprs: &[TLExpr]) {
142 let mut vocab_index = 0;
143
144 for expr in exprs {
145 self.collect_predicates(expr, &mut vocab_index);
146 }
147 }
148
149 fn collect_predicates(&mut self, expr: &TLExpr, vocab_index: &mut usize) {
151 match expr {
152 TLExpr::Pred { name, .. } => {
153 if !self.vocabulary.contains_key(name) {
154 self.vocabulary.insert(name.clone(), *vocab_index);
155 *vocab_index += 1;
156 }
157 }
158 TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
159 self.collect_predicates(left, vocab_index);
160 self.collect_predicates(right, vocab_index);
161 }
162 TLExpr::Not(inner) => {
163 self.collect_predicates(inner, vocab_index);
164 }
165 TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
166 self.collect_predicates(body, vocab_index);
167 }
168 _ => {}
169 }
170 }
171
172 fn count_predicates(&self, expr: &TLExpr) -> HashMap<String, usize> {
174 let mut counts = HashMap::new();
175 self.count_predicates_recursive(expr, &mut counts);
176 counts
177 }
178
179 #[allow(clippy::only_used_in_recursion)]
180 fn count_predicates_recursive(&self, expr: &TLExpr, counts: &mut HashMap<String, usize>) {
181 match expr {
182 TLExpr::Pred { name, .. } => {
183 *counts.entry(name.clone()).or_insert(0) += 1;
184 }
185 TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
186 self.count_predicates_recursive(left, counts);
187 self.count_predicates_recursive(right, counts);
188 }
189 TLExpr::Not(inner) => {
190 self.count_predicates_recursive(inner, counts);
191 }
192 TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
193 self.count_predicates_recursive(body, counts);
194 }
195 _ => {}
196 }
197 }
198
199 fn extract_structural_features(&self, expr: &TLExpr) -> Vec<f64> {
201 vec![
202 self.compute_depth(expr, 0) as f64,
203 self.count_nodes(expr) as f64,
204 self.count_operators(expr, "and") as f64,
205 self.count_operators(expr, "or") as f64,
206 self.count_operators(expr, "not") as f64,
207 self.count_operators(expr, "imply") as f64,
208 ]
209 }
210
211 fn compute_depth(&self, expr: &TLExpr, current_depth: usize) -> usize {
213 if current_depth >= self.config.max_depth {
214 return current_depth;
215 }
216
217 match expr {
218 TLExpr::Pred { .. } => current_depth,
219 TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
220 let left_depth = self.compute_depth(left, current_depth + 1);
221 let right_depth = self.compute_depth(right, current_depth + 1);
222 left_depth.max(right_depth)
223 }
224 TLExpr::Not(inner)
225 | TLExpr::Exists { body: inner, .. }
226 | TLExpr::ForAll { body: inner, .. } => self.compute_depth(inner, current_depth + 1),
227 _ => current_depth,
228 }
229 }
230
231 #[allow(clippy::only_used_in_recursion)]
233 fn count_nodes(&self, expr: &TLExpr) -> usize {
234 match expr {
235 TLExpr::Pred { .. } => 1,
236 TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
237 1 + self.count_nodes(left) + self.count_nodes(right)
238 }
239 TLExpr::Not(inner)
240 | TLExpr::Exists { body: inner, .. }
241 | TLExpr::ForAll { body: inner, .. } => 1 + self.count_nodes(inner),
242 _ => 1,
243 }
244 }
245
246 #[allow(clippy::only_used_in_recursion)]
248 fn count_operators(&self, expr: &TLExpr, op_type: &str) -> usize {
249 let this_count = match (op_type, expr) {
250 ("and", TLExpr::And(_, _)) => 1,
251 ("or", TLExpr::Or(_, _)) => 1,
252 ("not", TLExpr::Not(_)) => 1,
253 ("imply", TLExpr::Imply(_, _)) => 1,
254 _ => 0,
255 };
256
257 let child_count = match expr {
258 TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
259 self.count_operators(left, op_type) + self.count_operators(right, op_type)
260 }
261 TLExpr::Not(inner)
262 | TLExpr::Exists { body: inner, .. }
263 | TLExpr::ForAll { body: inner, .. } => self.count_operators(inner, op_type),
264 _ => 0,
265 };
266
267 this_count + child_count
268 }
269
270 fn extract_predicate_features(&self, counts: &HashMap<String, usize>) -> Vec<f64> {
272 if self.vocabulary.is_empty() {
273 counts.values().map(|&c| c as f64).collect()
275 } else {
276 let mut features = vec![0.0; self.vocabulary.len()];
278 for (pred, &count) in counts {
279 if let Some(&idx) = self.vocabulary.get(pred) {
280 features[idx] = count as f64;
281 }
282 }
283 features
284 }
285 }
286
287 fn extract_quantifier_features(&self, expr: &TLExpr) -> Vec<f64> {
289 vec![
290 self.count_quantifiers(expr, "exists") as f64,
291 self.count_quantifiers(expr, "forall") as f64,
292 ]
293 }
294
295 #[allow(clippy::only_used_in_recursion)]
297 fn count_quantifiers(&self, expr: &TLExpr, quant_type: &str) -> usize {
298 let this_count = match (quant_type, expr) {
299 ("exists", TLExpr::Exists { .. }) => 1,
300 ("forall", TLExpr::ForAll { .. }) => 1,
301 _ => 0,
302 };
303
304 let child_count = match expr {
305 TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
306 self.count_quantifiers(left, quant_type) + self.count_quantifiers(right, quant_type)
307 }
308 TLExpr::Not(inner)
309 | TLExpr::Exists { body: inner, .. }
310 | TLExpr::ForAll { body: inner, .. } => self.count_quantifiers(inner, quant_type),
311 _ => 0,
312 };
313
314 this_count + child_count
315 }
316
317 pub fn vocab_size(&self) -> usize {
319 self.vocabulary.len()
320 }
321
322 pub fn vocabulary(&self) -> &HashMap<String, usize> {
324 &self.vocabulary
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_feature_extraction_basic() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::pred("tall", vec![]);
        let features = extractor.extract(&expr).unwrap();

        // [depth, nodes, and, or, not, imply] + [count(tall)] + [exists, forall]
        assert_eq!(features, vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_feature_extraction_compound() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::and(TLExpr::pred("tall", vec![]), TLExpr::pred("smart", vec![]));
        let features = extractor.extract(&expr).unwrap();

        // Both predicates occur once, so the expected vector does not depend on their order.
        assert_eq!(
            features,
            vec![1.0, 3.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0]
        );
    }

    #[test]
    fn test_structural_features() {
        let config = FeatureExtractionConfig::new().with_encode_structure(true);
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::and(
            TLExpr::pred("a", vec![]),
            TLExpr::or(TLExpr::pred("b", vec![]), TLExpr::pred("c", vec![])),
        );
        let features = extractor.extract(&expr).unwrap();

        // 6 structural + 3 predicate + 2 quantifier features.
        assert_eq!(features.len(), 11);
        // depth 2, 5 nodes, one `and`, one `or`, no `not`/`imply`.
        assert_eq!(&features[..6], &[2.0, 5.0, 1.0, 1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_feature_groups_disabled() {
        let config = FeatureExtractionConfig::new()
            .with_encode_structure(false)
            .with_encode_quantifiers(false);
        let extractor = FeatureExtractor::new(config);

        let features = extractor.extract(&TLExpr::pred("tall", vec![])).unwrap();

        // Only the predicate-count section remains.
        assert_eq!(features, vec![1.0]);
    }

    #[test]
    fn test_quantifier_features() {
        let config = FeatureExtractionConfig::new().with_encode_quantifiers(true);
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::exists("x", "Person", TLExpr::pred("likes", vec![]));
        let features = extractor.extract(&expr).unwrap();

        // The quantifier section is the trailing [exists, forall] pair.
        assert_eq!(&features[features.len() - 2..], &[1.0, 0.0]);
        // The quantifier adds one level of depth and one node.
        assert_eq!(features[0], 1.0);
        assert_eq!(features[1], 2.0);
    }

    #[test]
    fn test_quantifier_counting() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::exists(
            "x",
            "Person",
            TLExpr::exists("y", "Person", TLExpr::pred("likes", vec![])),
        );

        assert_eq!(extractor.count_quantifiers(&expr, "exists"), 2);
        assert_eq!(extractor.count_quantifiers(&expr, "forall"), 0);
    }

    #[test]
    fn test_vocabulary_building() {
        let config = FeatureExtractionConfig::new();
        let mut extractor = FeatureExtractor::new(config);

        let exprs = vec![
            TLExpr::pred("tall", vec![]),
            TLExpr::pred("smart", vec![]),
            TLExpr::pred("tall", vec![]),
        ];

        extractor.build_vocabulary(&exprs);

        assert_eq!(extractor.vocab_size(), 2);
        // Indices are unique and dense.
        let mut indices: Vec<usize> = extractor.vocabulary().values().copied().collect();
        indices.sort_unstable();
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_vocabulary_feature_positions() {
        let config = FeatureExtractionConfig::new()
            .with_encode_structure(false)
            .with_encode_quantifiers(false);
        let mut extractor = FeatureExtractor::new(config);
        extractor.build_vocabulary(&[
            TLExpr::pred("tall", vec![]),
            TLExpr::pred("smart", vec![]),
        ]);

        let tall = extractor.vocabulary()["tall"];
        let smart = extractor.vocabulary()["smart"];

        let features = extractor.extract(&TLExpr::pred("tall", vec![])).unwrap();
        assert_eq!(features.len(), 2);
        assert_eq!(features[tall], 1.0);
        assert_eq!(features[smart], 0.0);

        // Predicates outside the vocabulary contribute nothing.
        let features = extractor.extract(&TLExpr::pred("short", vec![])).unwrap();
        assert_eq!(features, vec![0.0, 0.0]);
    }

    #[test]
    fn test_batch_extraction() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        let exprs = vec![
            TLExpr::pred("a", vec![]),
            TLExpr::pred("b", vec![]),
            TLExpr::and(TLExpr::pred("a", vec![]), TLExpr::pred("b", vec![])),
        ];

        let features = extractor.extract_batch(&exprs).unwrap();
        assert_eq!(features.len(), 3);
        // Batch results match single-expression extraction, in input order.
        for (expr, batch_features) in exprs.iter().zip(&features) {
            assert_eq!(batch_features, &extractor.extract(expr).unwrap());
        }
    }

    #[test]
    fn test_fixed_dimension() {
        let expr = TLExpr::pred("test", vec![]);

        // Padding: the 9 natural features are followed by zeros.
        let padded = FeatureExtractor::new(FeatureExtractionConfig::new().with_fixed_dimension(10));
        let features = padded.extract(&expr).unwrap();
        assert_eq!(features.len(), 10);
        assert_eq!(&features[9..], &[0.0]);

        // Truncation: only the leading features are kept.
        let truncated =
            FeatureExtractor::new(FeatureExtractionConfig::new().with_fixed_dimension(3));
        let features = truncated.extract(&expr).unwrap();
        assert_eq!(features, vec![0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_depth_computation() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        let expr1 = TLExpr::pred("a", vec![]);
        assert_eq!(extractor.compute_depth(&expr1, 0), 0);

        let expr2 = TLExpr::and(
            TLExpr::pred("a", vec![]),
            TLExpr::and(TLExpr::pred("b", vec![]), TLExpr::pred("c", vec![])),
        );
        assert_eq!(extractor.compute_depth(&expr2, 0), 2);

        // Depth is capped at `max_depth`.
        let capped = FeatureExtractor::new(FeatureExtractionConfig::new().with_max_depth(1));
        assert_eq!(capped.compute_depth(&expr2, 0), 1);
    }

    #[test]
    fn test_operator_counting() {
        let config = FeatureExtractionConfig::new();
        let extractor = FeatureExtractor::new(config);

        let expr = TLExpr::and(
            TLExpr::and(TLExpr::pred("a", vec![]), TLExpr::pred("b", vec![])),
            TLExpr::or(TLExpr::pred("c", vec![]), TLExpr::pred("d", vec![])),
        );

        assert_eq!(extractor.count_operators(&expr, "and"), 2);
        assert_eq!(extractor.count_operators(&expr, "or"), 1);
        assert_eq!(extractor.count_operators(&expr, "not"), 0);
        assert_eq!(extractor.count_operators(&expr, "imply"), 0);
    }
}