Skip to main content

tensorlogic_ir/expr/
strategy_selector.rs

1//! Automatic optimization strategy selection based on expression characteristics.
2//!
3//! This module provides heuristics for automatically selecting the best optimization
4//! strategy based on the structure and complexity of expressions.
5
6use super::{
7    advanced_analysis::{ComplexityMetrics, OperatorCounts},
8    optimization_pipeline::{OptimizationLevel, OptimizationPass, PipelineConfig},
9    TLExpr,
10};
11
12/// Characteristics of an expression used for strategy selection.
13#[derive(Clone, Debug, PartialEq)]
14pub struct ExpressionProfile {
15    /// Operator counts by category
16    pub operator_counts: OperatorCounts,
17    /// Complexity metrics
18    pub complexity: ComplexityMetrics,
19    /// Whether the expression has quantifiers
20    pub has_quantifiers: bool,
21    /// Whether the expression has modal operators
22    pub has_modal: bool,
23    /// Whether the expression has temporal operators
24    pub has_temporal: bool,
25    /// Whether the expression has fuzzy operators
26    pub has_fuzzy: bool,
27    /// Whether the expression has constants
28    pub has_constants: bool,
29    /// Expression size (node count)
30    pub size: usize,
31}
32
33impl ExpressionProfile {
34    /// Analyze an expression to create a profile.
35    pub fn analyze(expr: &TLExpr) -> Self {
36        let operator_counts = OperatorCounts::from_expr(expr);
37        let complexity = ComplexityMetrics::from_expr(expr);
38
39        Self {
40            has_quantifiers: operator_counts.quantifiers > 0,
41            has_modal: operator_counts.modal > 0,
42            has_temporal: operator_counts.temporal > 0,
43            has_fuzzy: operator_counts.fuzzy > 0,
44            has_constants: operator_counts.constants > 0,
45            size: operator_counts.total,
46            operator_counts,
47            complexity,
48        }
49    }
50
51    /// Check if the expression is simple (few operators, shallow depth).
52    pub fn is_simple(&self) -> bool {
53        self.size <= 10 && self.complexity.max_depth <= 3
54    }
55
56    /// Check if the expression is complex (many operators, deep nesting).
57    pub fn is_complex(&self) -> bool {
58        self.size > 50 || self.complexity.max_depth > 10
59    }
60
61    /// Check if the expression would benefit from distributive laws.
62    pub fn needs_distribution(&self) -> bool {
63        // Check for patterns like A ∧ (B ∨ C) or A ∨ (B ∧ C)
64        // Heuristic: many logical operators might benefit
65        self.operator_counts.logical > 5
66    }
67
68    /// Check if the expression has significant constant folding opportunities.
69    pub fn has_constant_opportunities(&self) -> bool {
70        self.has_constants && self.operator_counts.arithmetic > 0
71    }
72}
73
74/// Strategy selector that recommends optimization configurations.
75#[derive(Clone, Copy, Debug)]
76pub struct StrategySelector {
77    /// Default optimization level for fallback
78    _default_level: OptimizationLevel,
79}
80
81impl Default for StrategySelector {
82    fn default() -> Self {
83        Self {
84            _default_level: OptimizationLevel::Standard,
85        }
86    }
87}
88
89impl StrategySelector {
90    /// Create a new strategy selector with a default optimization level.
91    pub fn new(default_level: OptimizationLevel) -> Self {
92        Self {
93            _default_level: default_level,
94        }
95    }
96
97    /// Select an optimization level based on expression profile.
98    pub fn select_level(&self, profile: &ExpressionProfile) -> OptimizationLevel {
99        // Simple expressions: use basic optimizations
100        if profile.is_simple() {
101            return OptimizationLevel::Basic;
102        }
103
104        // Complex expressions with specific features: use aggressive
105        if profile.is_complex()
106            && (profile.has_modal || profile.has_temporal || profile.has_quantifiers)
107        {
108            return OptimizationLevel::Aggressive;
109        }
110
111        // Default to standard for most cases
112        OptimizationLevel::Standard
113    }
114
115    /// Select specific optimization passes based on expression profile.
116    pub fn select_passes(&self, profile: &ExpressionProfile) -> Vec<OptimizationPass> {
117        let mut passes = Vec::new();
118
119        // Always include constant folding if there are constants
120        if profile.has_constants {
121            passes.push(OptimizationPass::ConstantFolding);
122            passes.push(OptimizationPass::ConstantPropagation);
123        }
124
125        // Always include basic simplifications
126        passes.push(OptimizationPass::AlgebraicSimplification);
127
128        // Add NNF conversion for complex logical expressions
129        if profile.operator_counts.logical > 3 {
130            passes.push(OptimizationPass::NegationNormalForm);
131        }
132
133        // Add modal equivalences if modal operators present
134        if profile.has_modal {
135            passes.push(OptimizationPass::ModalEquivalences);
136            passes.push(OptimizationPass::DistributiveModal);
137        }
138
139        // Add temporal equivalences if temporal operators present
140        if profile.has_temporal {
141            passes.push(OptimizationPass::TemporalEquivalences);
142        }
143
144        // Add quantifier distribution if quantifiers present
145        if profile.has_quantifiers && profile.operator_counts.quantifiers > 2 {
146            passes.push(OptimizationPass::DistributiveQuantifiers);
147        }
148
149        // Add distributive laws for complex logical expressions
150        if profile.needs_distribution() {
151            passes.push(OptimizationPass::DistributiveAndOverOr);
152        }
153
154        passes
155    }
156
157    /// Create a recommended pipeline configuration for an expression.
158    pub fn recommend_config(&self, expr: &TLExpr) -> PipelineConfig {
159        let profile = ExpressionProfile::analyze(expr);
160        let level = self.select_level(&profile);
161        let custom_passes = self.select_passes(&profile);
162
163        let max_iterations = if profile.is_complex() { 15 } else { 10 };
164
165        PipelineConfig::with_level(level)
166            .with_custom_passes(custom_passes)
167            .with_max_iterations(max_iterations)
168    }
169
170    /// Quick optimization recommendation: returns whether aggressive optimization is recommended.
171    pub fn should_optimize_aggressively(&self, expr: &TLExpr) -> bool {
172        let profile = ExpressionProfile::analyze(expr);
173        matches!(self.select_level(&profile), OptimizationLevel::Aggressive)
174    }
175}
176
177/// Convenience function to automatically select and apply the best optimization strategy.
178pub fn auto_optimize(expr: TLExpr) -> (TLExpr, super::optimization_pipeline::OptimizationMetrics) {
179    let selector = StrategySelector::default();
180    let config = selector.recommend_config(&expr);
181
182    let pipeline = super::optimization_pipeline::OptimizationPipeline::new(config);
183    pipeline.optimize(expr)
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    #[test]
    fn test_profile_simple_expression() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let profile = ExpressionProfile::analyze(&expr);

        assert!(profile.is_simple());
        assert!(!profile.is_complex());
        assert!(!profile.has_constants);
    }

    #[test]
    fn test_profile_complex_expression() {
        // Build a deeply nested expression. `TLExpr::and` takes ownership, so the
        // accumulator is moved rather than cloned (cloning would copy the whole tree
        // on every iteration).
        let mut expr = TLExpr::pred("P", vec![Term::var("x")]);
        for _ in 0..15 {
            expr = TLExpr::and(expr, TLExpr::pred("Q", vec![Term::var("y")]));
        }

        let profile = ExpressionProfile::analyze(&expr);
        assert!(profile.is_complex());
        assert!(!profile.is_simple());
    }

    #[test]
    fn test_profile_with_constants() {
        let expr = TLExpr::add(TLExpr::constant(1.0), TLExpr::constant(2.0));
        let profile = ExpressionProfile::analyze(&expr);

        assert!(profile.has_constants);
        assert!(profile.has_constant_opportunities());
    }

    #[test]
    fn test_profile_with_quantifiers() {
        let expr = TLExpr::forall("x", "D", TLExpr::pred("P", vec![Term::var("x")]));
        let profile = ExpressionProfile::analyze(&expr);

        assert!(profile.has_quantifiers);
    }

    #[test]
    fn test_profile_with_modal() {
        let expr = TLExpr::modal_box(TLExpr::pred("P", vec![Term::var("x")]));
        let profile = ExpressionProfile::analyze(&expr);

        assert!(profile.has_modal);
    }

    #[test]
    fn test_profile_with_temporal() {
        let expr = TLExpr::eventually(TLExpr::pred("P", vec![Term::var("x")]));
        let profile = ExpressionProfile::analyze(&expr);

        assert!(profile.has_temporal);
    }

    #[test]
    fn test_selector_simple_expression() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let selector = StrategySelector::default();
        let profile = ExpressionProfile::analyze(&expr);

        let level = selector.select_level(&profile);
        assert_eq!(level, OptimizationLevel::Basic);
    }

    #[test]
    fn test_selector_complex_modal_expression() {
        // Build a complex modal expression (accumulator moved, not cloned)
        let mut expr = TLExpr::modal_box(TLExpr::pred("P", vec![Term::var("x")]));
        for _ in 0..12 {
            expr = TLExpr::and(expr, TLExpr::modal_box(TLExpr::pred("Q", vec![])));
        }

        let selector = StrategySelector::default();
        let profile = ExpressionProfile::analyze(&expr);

        let level = selector.select_level(&profile);
        assert_eq!(level, OptimizationLevel::Aggressive);
    }

    #[test]
    fn test_selector_pass_selection() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::modal_box(TLExpr::pred("P", vec![Term::var("x")])),
        );

        let selector = StrategySelector::default();
        let profile = ExpressionProfile::analyze(&expr);
        let passes = selector.select_passes(&profile);

        // Should include constant folding and modal equivalences
        assert!(passes.contains(&OptimizationPass::ConstantFolding));
        assert!(passes.contains(&OptimizationPass::ModalEquivalences));
    }

    #[test]
    fn test_recommend_config() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let selector = StrategySelector::default();
        let config = selector.recommend_config(&expr);

        assert_eq!(config.level, OptimizationLevel::Basic);
        assert!(config.custom_passes.is_some());
    }

    #[test]
    fn test_auto_optimize() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let (optimized, metrics) = auto_optimize(expr);

        // Should have applied optimizations
        assert!(metrics.passes_applied > 0);
        // Should have simplified the expression
        assert_eq!(optimized, TLExpr::pred("P", vec![Term::var("x")]));
    }

    #[test]
    fn test_should_optimize_aggressively() {
        let simple_expr = TLExpr::pred("P", vec![Term::var("x")]);
        let selector = StrategySelector::default();

        assert!(!selector.should_optimize_aggressively(&simple_expr));

        // Build complex expression (accumulator moved, not cloned)
        let mut complex_expr = TLExpr::modal_box(TLExpr::pred("P", vec![Term::var("x")]));
        for _ in 0..12 {
            complex_expr =
                TLExpr::and(complex_expr, TLExpr::modal_box(TLExpr::pred("Q", vec![])));
        }

        assert!(selector.should_optimize_aggressively(&complex_expr));
    }

    #[test]
    fn test_needs_distribution() {
        // Expression with many logical operators
        let mut expr = TLExpr::pred("P", vec![Term::var("x")]);
        for i in 0..7 {
            expr = TLExpr::and(
                expr,
                TLExpr::or(
                    TLExpr::pred("Q", vec![Term::var(format!("x{}", i))]),
                    TLExpr::pred("R", vec![Term::var(format!("y{}", i))]),
                ),
            );
        }

        let profile = ExpressionProfile::analyze(&expr);
        assert!(profile.needs_distribution());
    }
}