rust_rule_engine/backward/
optimizer.rs

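//! Query optimization for backward chaining
//!
//! This module implements query optimization techniques to improve performance:
//! - Goal reordering based on selectivity
//! - Index selection for O(1) lookups
//! - Join ordering optimization
//! - Predicate pushdown
//! - Memoization of intermediate results
//!
//! Goal reordering ([`QueryOptimizer::optimize_goals`]) and join ordering
//! ([`JoinOptimizer::optimize_joins`]) are implemented here. Index selection and
//! memoization are currently configuration switches on [`QueryOptimizer`] that this
//! module stores but does not act on, and predicate pushdown has no implementation
//! in this module yet.
//!
//! # Examples
//!
//! ```rust,ignore
//! use rust_rule_engine::backward::optimizer::QueryOptimizer;
//!
//! let mut optimizer = QueryOptimizer::new();
//!
//! // Before: item(?x) AND expensive(?x) AND in_stock(?x)
//! // After:  in_stock(?x) AND expensive(?x) AND item(?x)
//! //         (Evaluates most selective first)
//!
//! // `goals` is a `Vec<Goal>`; the reordered goals are returned.
//! let optimized = optimizer.optimize_goals(goals);
//! ```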
23
24use super::goal::Goal;
25use std::collections::HashMap;
26
27/// Query optimizer for backward chaining
28#[derive(Debug, Clone)]
29pub struct QueryOptimizer {
30    /// Selectivity estimates for predicates
31    selectivity_map: HashMap<String, f64>,
32
33    /// Whether to enable goal reordering
34    enable_reordering: bool,
35
36    /// Whether to enable index selection
37    enable_index_selection: bool,
38
39    /// Whether to enable memoization
40    enable_memoization: bool,
41
42    /// Statistics for optimization
43    stats: OptimizationStats,
44}
45
46impl QueryOptimizer {
47    /// Create a new query optimizer
48    pub fn new() -> Self {
49        Self {
50            selectivity_map: HashMap::new(),
51            enable_reordering: true,
52            enable_index_selection: true,
53            enable_memoization: true,
54            stats: OptimizationStats::new(),
55        }
56    }
57
58    /// Create optimizer with custom configuration
59    pub fn with_config(config: OptimizerConfig) -> Self {
60        Self {
61            selectivity_map: HashMap::new(),
62            enable_reordering: config.enable_reordering,
63            enable_index_selection: config.enable_index_selection,
64            enable_memoization: config.enable_memoization,
65            stats: OptimizationStats::new(),
66        }
67    }
68
69    /// Optimize a list of goals
70    ///
71    /// Returns a reordered list of goals optimized for evaluation
72    pub fn optimize_goals(&mut self, goals: Vec<Goal>) -> Vec<Goal> {
73        if !self.enable_reordering || goals.len() <= 1 {
74            return goals;
75        }
76
77        self.stats.total_optimizations += 1;
78
79        // Estimate selectivity for each goal
80        let mut goal_selectivity: Vec<(Goal, f64)> = goals
81            .into_iter()
82            .map(|g| {
83                let selectivity = self.estimate_selectivity(&g);
84                (g, selectivity)
85            })
86            .collect();
87
88        // Sort by selectivity (lower = more selective = evaluate first)
89        goal_selectivity.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
90
91        // Extract optimized goals
92        let optimized: Vec<Goal> = goal_selectivity.into_iter().map(|(g, _)| g).collect();
93
94        self.stats.goals_reordered += optimized.len();
95
96        optimized
97    }
98
99    /// Estimate selectivity of a goal (lower = more selective)
100    ///
101    /// Returns a value between 0.0 (very selective) and 1.0 (not selective)
102    pub fn estimate_selectivity(&self, goal: &Goal) -> f64 {
103        // Check if we have a known selectivity estimate
104        if let Some(&selectivity) = self.selectivity_map.get(&goal.pattern) {
105            return selectivity;
106        }
107
108        // Heuristic-based estimation
109        self.heuristic_selectivity(goal)
110    }
111
112    /// Heuristic-based selectivity estimation
113    fn heuristic_selectivity(&self, goal: &Goal) -> f64 {
114        let pattern = &goal.pattern;
115
116        // Count bound vs unbound variables
117        let (bound_count, var_count) = self.count_variables(pattern);
118
119        if var_count == 0 {
120            // No variables = most selective (exact match)
121            return 0.1;
122        }
123
124        // More bound variables = more selective
125        let bound_ratio = bound_count as f64 / var_count as f64;
126        let selectivity = 1.0 - (bound_ratio * 0.8);
127
128        // Check for specific patterns
129        if pattern.contains("in_stock") || pattern.contains("available") {
130            // Stock checks typically very selective
131            return selectivity * 0.3;
132        }
133
134        if pattern.contains("expensive") || pattern.contains("premium") {
135            // Price filters moderately selective
136            return selectivity * 0.5;
137        }
138
139        if pattern.contains("item") || pattern.contains("product") {
140            // Generic predicates less selective
141            return selectivity * 1.2;
142        }
143
144        selectivity
145    }
146
147    /// Count bound and total variables in a pattern
148    fn count_variables(&self, pattern: &str) -> (usize, usize) {
149        let mut bound = 0;
150        let mut total = 0;
151
152        let chars: Vec<char> = pattern.chars().collect();
153        let mut i = 0;
154
155        while i < chars.len() {
156            if chars[i] == '?' {
157                total += 1;
158
159                // Skip variable name
160                i += 1;
161                while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
162                    i += 1;
163                }
164
165                // Check if followed by a comparison (indicates bound)
166                while i < chars.len() && chars[i].is_whitespace() {
167                    i += 1;
168                }
169
170                if i < chars.len() && (chars[i] == '=' || chars[i] == '>' || chars[i] == '<') {
171                    bound += 1;
172                }
173            } else {
174                i += 1;
175            }
176        }
177
178        (bound, total)
179    }
180
181    /// Set selectivity estimate for a predicate
182    pub fn set_selectivity(&mut self, predicate: String, selectivity: f64) {
183        self.selectivity_map
184            .insert(predicate, selectivity.clamp(0.0, 1.0));
185    }
186
187    /// Get optimization statistics
188    pub fn stats(&self) -> &OptimizationStats {
189        &self.stats
190    }
191
192    /// Reset statistics
193    pub fn reset_stats(&mut self) {
194        self.stats = OptimizationStats::new();
195    }
196
197    /// Enable or disable goal reordering
198    pub fn set_reordering(&mut self, enabled: bool) {
199        self.enable_reordering = enabled;
200    }
201
202    /// Enable or disable index selection
203    pub fn set_index_selection(&mut self, enabled: bool) {
204        self.enable_index_selection = enabled;
205    }
206
207    /// Enable or disable memoization
208    pub fn set_memoization(&mut self, enabled: bool) {
209        self.enable_memoization = enabled;
210    }
211}
212
213impl Default for QueryOptimizer {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
/// Configuration for the query optimizer.
///
/// Passed to [`QueryOptimizer::with_config`]. The `Default` impl enables every
/// switch. The struct compares by value, so configurations can be checked with
/// `==` or `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OptimizerConfig {
    /// Enable goal reordering.
    pub enable_reordering: bool,

    /// Enable index selection.
    pub enable_index_selection: bool,

    /// Enable memoization.
    pub enable_memoization: bool,
}
231
232impl Default for OptimizerConfig {
233    fn default() -> Self {
234        Self {
235            enable_reordering: true,
236            enable_index_selection: true,
237            enable_memoization: true,
238        }
239    }
240}
241
/// Statistics for query optimization.
///
/// All counters start at zero. The struct compares by value, so snapshots can be
/// checked with `==` or `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OptimizationStats {
    /// Number of `optimize_goals` calls that actually ran the reordering sort.
    pub total_optimizations: usize,

    /// Total goals passed through the reordering sort, whether or not they moved.
    pub goals_reordered: usize,

    /// Number of index selections made (not incremented by this module).
    pub index_selections: usize,

    /// Number of memoization hits (not incremented by this module).
    pub memoization_hits: usize,

    /// Number of memoization misses (not incremented by this module).
    pub memoization_misses: usize,
}
260
261impl OptimizationStats {
262    /// Create new statistics
263    pub fn new() -> Self {
264        Self::default()
265    }
266
267    /// Get memoization hit rate
268    pub fn memoization_hit_rate(&self) -> f64 {
269        let total = self.memoization_hits + self.memoization_misses;
270        if total == 0 {
271            0.0
272        } else {
273            self.memoization_hits as f64 / total as f64
274        }
275    }
276
277    /// Get summary string
278    pub fn summary(&self) -> String {
279        format!(
280            "Optimizations: {} | Goals reordered: {} | Memo hits: {} ({:.1}%)",
281            self.total_optimizations,
282            self.goals_reordered,
283            self.memoization_hits,
284            self.memoization_hit_rate() * 100.0
285        )
286    }
287}
288
289/// Join ordering optimizer
290#[derive(Debug, Clone)]
291#[allow(dead_code)]
292pub struct JoinOptimizer {
293    /// Cost estimates for different join strategies
294    cost_model: HashMap<String, f64>,
295}
296
297impl JoinOptimizer {
298    /// Create a new join optimizer
299    pub fn new() -> Self {
300        Self {
301            cost_model: HashMap::new(),
302        }
303    }
304
305    /// Optimize join order for a set of goals
306    ///
307    /// Returns goals ordered for optimal join performance
308    pub fn optimize_joins(&self, goals: Vec<Goal>) -> Vec<Goal> {
309        if goals.len() <= 1 {
310            return goals;
311        }
312
313        // Simple heuristic: start with goals that have most bound variables
314        let mut sorted_goals = goals;
315        sorted_goals.sort_by_key(|g| {
316            // Count bound variables (negative for descending sort)
317            -(self.count_bound_vars(&g.pattern) as i32)
318        });
319
320        sorted_goals
321    }
322
323    /// Count bound variables in a pattern
324    fn count_bound_vars(&self, pattern: &str) -> usize {
325        // Simple heuristic: variables followed by comparison operators
326        let mut count = 0;
327        let chars: Vec<char> = pattern.chars().collect();
328        let mut i = 0;
329
330        while i < chars.len() {
331            if chars[i] == '?' {
332                // Skip variable name
333                i += 1;
334                while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
335                    i += 1;
336                }
337
338                // Check if followed by comparison
339                while i < chars.len() && chars[i].is_whitespace() {
340                    i += 1;
341                }
342
343                if i < chars.len() && (chars[i] == '=' || chars[i] == '>' || chars[i] == '<') {
344                    count += 1;
345                }
346            } else {
347                i += 1;
348            }
349        }
350
351        count
352    }
353}
354
355impl Default for JoinOptimizer {
356    fn default() -> Self {
357        Self::new()
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimizer_creation() {
        let optimizer = QueryOptimizer::new();
        assert!(optimizer.enable_reordering);
        assert!(optimizer.enable_index_selection);
        assert!(optimizer.enable_memoization);
    }

    #[test]
    fn test_optimizer_with_config() {
        let config = OptimizerConfig {
            enable_reordering: false,
            enable_index_selection: true,
            enable_memoization: false,
        };

        let optimizer = QueryOptimizer::with_config(config);
        assert!(!optimizer.enable_reordering);
        assert!(optimizer.enable_index_selection);
        assert!(!optimizer.enable_memoization);
    }

    #[test]
    fn test_goal_reordering() {
        let mut optimizer = QueryOptimizer::new();

        // Set selectivity estimates (lower = more selective)
        optimizer.set_selectivity("in_stock(?x)".to_string(), 0.1);
        optimizer.set_selectivity("expensive(?x)".to_string(), 0.3);
        optimizer.set_selectivity("item(?x)".to_string(), 0.9);

        let goals = vec![
            Goal::new("item(?x)".to_string()),
            Goal::new("expensive(?x)".to_string()),
            Goal::new("in_stock(?x)".to_string()),
        ];

        let optimized = optimizer.optimize_goals(goals);

        // Should be ordered: in_stock, expensive, item
        assert_eq!(optimized[0].pattern, "in_stock(?x)");
        assert_eq!(optimized[1].pattern, "expensive(?x)");
        assert_eq!(optimized[2].pattern, "item(?x)");
    }

    #[test]
    fn test_selectivity_estimation() {
        let optimizer = QueryOptimizer::new();

        // Exact match (no variables) - very selective
        let goal1 = Goal::new("employee(alice)".to_string());
        let sel1 = optimizer.estimate_selectivity(&goal1);
        assert!(sel1 < 0.5);

        // One unbound variable - less selective
        let goal2 = Goal::new("employee(?x)".to_string());
        let sel2 = optimizer.estimate_selectivity(&goal2);
        assert!(sel2 > sel1);

        // Bound variable (with comparison) - more selective
        let goal3 = Goal::new("salary(?x) WHERE ?x > 100000".to_string());
        let sel3 = optimizer.estimate_selectivity(&goal3);
        // Should be more selective than fully unbound
        assert!(sel3 < sel2);
    }

    #[test]
    fn test_set_selectivity_clamps() {
        let mut optimizer = QueryOptimizer::new();
        optimizer.set_selectivity("hi(?x)".to_string(), 5.0);
        optimizer.set_selectivity("lo(?x)".to_string(), -2.0);
        assert_eq!(
            optimizer.estimate_selectivity(&Goal::new("hi(?x)".to_string())),
            1.0
        );
        assert_eq!(
            optimizer.estimate_selectivity(&Goal::new("lo(?x)".to_string())),
            0.0
        );
    }

    #[test]
    fn test_count_variables() {
        let optimizer = QueryOptimizer::new();

        // No variables
        let (bound, total) = optimizer.count_variables("employee(alice)");
        assert_eq!(total, 0);
        assert_eq!(bound, 0);

        // One unbound variable
        let (bound, total) = optimizer.count_variables("employee(?x)");
        assert_eq!(total, 1);
        assert_eq!(bound, 0);

        // One bound variable
        let (bound, total) = optimizer.count_variables("salary(?x) WHERE ?x > 100");
        assert_eq!(total, 2); // ?x appears twice
        assert_eq!(bound, 1); // Second ?x is bound by >
    }

    #[test]
    fn test_count_variables_edge_cases() {
        let optimizer = QueryOptimizer::new();

        // Adjacent variables are counted separately
        assert_eq!(optimizer.count_variables("?x?y"), (0, 2));
        // Operator directly after the name, and after whitespace
        assert_eq!(optimizer.count_variables("r(?a) WHERE ?a= 1 AND ?b <2"), (2, 3));
        // Empty pattern
        assert_eq!(optimizer.count_variables(""), (0, 0));
    }

    #[test]
    fn test_optimization_stats() {
        let mut optimizer = QueryOptimizer::new();

        let goals = vec![
            Goal::new("a(?x)".to_string()),
            Goal::new("b(?x)".to_string()),
        ];

        optimizer.optimize_goals(goals);

        let stats = optimizer.stats();
        assert_eq!(stats.total_optimizations, 1);
        assert_eq!(stats.goals_reordered, 2);
    }

    #[test]
    fn test_single_goal_is_not_counted() {
        let mut optimizer = QueryOptimizer::new();
        let optimized = optimizer.optimize_goals(vec![Goal::new("a(?x)".to_string())]);
        assert_eq!(optimized.len(), 1);
        assert_eq!(optimizer.stats().total_optimizations, 0);
        assert_eq!(optimizer.stats().goals_reordered, 0);
    }

    #[test]
    fn test_join_optimizer() {
        let optimizer = JoinOptimizer::new();

        let goals = vec![
            Goal::new("item(?x)".to_string()),
            Goal::new("price(?x, ?p) WHERE ?p > 100".to_string()),
            Goal::new("in_stock(?x)".to_string()),
        ];

        let optimized = optimizer.optimize_joins(goals);

        // Goal with bound variable should come first
        assert!(optimized[0].pattern.contains("?p > 100"));
    }

    #[test]
    fn test_join_optimizer_keeps_tie_order() {
        let optimizer = JoinOptimizer::new();
        let goals = vec![
            Goal::new("a(?x)".to_string()),
            Goal::new("b(?x)".to_string()),
            Goal::new("c(?x) WHERE ?x = 1".to_string()),
        ];

        let optimized = optimizer.optimize_joins(goals);

        // Bound goal first; the two ties keep their original order
        assert_eq!(optimized[0].pattern, "c(?x) WHERE ?x = 1");
        assert_eq!(optimized[1].pattern, "a(?x)");
        assert_eq!(optimized[2].pattern, "b(?x)");
    }

    #[test]
    fn test_disable_reordering() {
        let mut optimizer = QueryOptimizer::new();
        optimizer.set_reordering(false);

        let goals = vec![
            Goal::new("a(?x)".to_string()),
            Goal::new("b(?x)".to_string()),
            Goal::new("c(?x)".to_string()),
        ];

        let optimized = optimizer.optimize_goals(goals.clone());

        // Order should be unchanged
        assert_eq!(optimized[0].pattern, goals[0].pattern);
        assert_eq!(optimized[1].pattern, goals[1].pattern);
        assert_eq!(optimized[2].pattern, goals[2].pattern);
    }

    #[test]
    fn test_stats_summary() {
        let mut stats = OptimizationStats::new();
        stats.total_optimizations = 10;
        stats.goals_reordered = 25;
        stats.memoization_hits = 8;
        stats.memoization_misses = 2;

        // Exact match: a `contains("8")` check would pass trivially because of "80".
        assert_eq!(
            stats.summary(),
            "Optimizations: 10 | Goals reordered: 25 | Memo hits: 8 (80.0%)"
        );
    }

    #[test]
    fn test_memoization_hit_rate() {
        let mut stats = OptimizationStats::new();

        // No data yet
        assert_eq!(stats.memoization_hit_rate(), 0.0);

        // 80% hit rate
        stats.memoization_hits = 8;
        stats.memoization_misses = 2;
        assert!((stats.memoization_hit_rate() - 0.8).abs() < 0.01);

        // 100% hit rate
        stats.memoization_hits = 10;
        stats.memoization_misses = 0;
        assert_eq!(stats.memoization_hit_rate(), 1.0);
    }
}